feat(nix-compat/wire): add read_bytes[_unchecked]
This introduces a version reading sized byte packets. Both read_bytes, accepting a range of allowed sizes, as well as read_bytes_unchecked, which doesn't care, are added, including tests. Co-Authored-By: picnoir <picnoir@alternativebit.fr> Change-Id: I9fc1c61eb561105e649eecca832af28badfdaaa8 Reviewed-on: https://cl.tvl.fyi/c/depot/+/11150 Autosubmit: flokli <flokli@flokli.de> Reviewed-by: picnoir picnoir <picnoir@alternativebit.fr> Tested-by: BuildkiteCI
This commit is contained in:
parent
c364c0b4de
commit
5fccbe5939
2 changed files with 133 additions and 0 deletions
130
tvix/nix-compat/src/wire/bytes.rs
Normal file
130
tvix/nix-compat/src/wire/bytes.rs
Normal file
|
@ -0,0 +1,130 @@
|
|||
use std::ops::RangeBounds;
|
||||
|
||||
use tokio::io::AsyncReadExt;
|
||||
|
||||
use super::primitive;
|
||||
|
||||
#[allow(dead_code)]
|
||||
/// Read a limited number of bytes from the AsyncRead.
|
||||
/// Rejects reading more than `allowed_size` bytes of payload.
|
||||
/// Internally takes care of dealing with the padding, so the returned Vec<u8>
|
||||
/// only contains the payload.
|
||||
/// This always buffers the entire contents into memory, we'll add a streaming
|
||||
/// version later.
|
||||
pub async fn read_bytes<R, S>(r: &mut R, allowed_size: S) -> std::io::Result<Vec<u8>>
|
||||
where
|
||||
R: AsyncReadExt + Unpin,
|
||||
S: RangeBounds<u64>,
|
||||
{
|
||||
// read the length field
|
||||
let len = primitive::read_u64(r).await?;
|
||||
|
||||
if !allowed_size.contains(&len) {
|
||||
return Err(std::io::Error::new(
|
||||
std::io::ErrorKind::InvalidData,
|
||||
"signalled package size not in allowed range",
|
||||
));
|
||||
}
|
||||
|
||||
// calculate the total length, including padding.
|
||||
// byte packets are padded to 8 byte blocks each.
|
||||
let padded_len = if len % 8 == 0 {
|
||||
len
|
||||
} else {
|
||||
len + (8 - len % 8)
|
||||
};
|
||||
|
||||
let mut limited_reader = r.take(padded_len);
|
||||
|
||||
let mut buf = Vec::new();
|
||||
|
||||
let s = limited_reader.read_to_end(&mut buf).await?;
|
||||
|
||||
// make sure we got exactly the number of bytes, and not less.
|
||||
if s as u64 != padded_len {
|
||||
return Err(std::io::Error::new(
|
||||
std::io::ErrorKind::InvalidData,
|
||||
"got less bytes than expected",
|
||||
));
|
||||
}
|
||||
|
||||
let (_content, padding) = buf.split_at(len as usize);
|
||||
|
||||
// ensure the padding is all zeroes.
|
||||
if !padding.iter().all(|e| *e == b'\0') {
|
||||
return Err(std::io::Error::new(
|
||||
std::io::ErrorKind::InvalidData,
|
||||
"padding is not all zeroes",
|
||||
));
|
||||
}
|
||||
|
||||
// return the data without the padding
|
||||
buf.truncate(len as usize);
|
||||
Ok(buf)
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
/// Read an unlimited number of bytes from the AsyncRead.
|
||||
/// Note this can exhaust memory.
|
||||
/// Internally uses [read_bytes], which takes care of dealing with the padding,
|
||||
/// so the returned Vec<u8> only contains the payload.
|
||||
pub async fn read_bytes_unchecked<R: AsyncReadExt + Unpin>(r: &mut R) -> std::io::Result<Vec<u8>> {
|
||||
read_bytes(r, 0u64..).await
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use tokio_test::io::Builder;
|
||||
|
||||
use super::*;
|
||||
use hex_literal::hex;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_read_8_bytes_unchecked() {
|
||||
let mut mock = Builder::new()
|
||||
.read(&8u64.to_le_bytes())
|
||||
.read(&12345678u64.to_le_bytes())
|
||||
.build();
|
||||
|
||||
assert_eq!(
|
||||
&12345678u64.to_le_bytes(),
|
||||
read_bytes_unchecked(&mut mock).await.unwrap().as_slice()
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_read_9_bytes_unchecked() {
|
||||
let mut mock = Builder::new()
|
||||
.read(&9u64.to_le_bytes())
|
||||
.read(&hex!("01020304050607080900000000000000"))
|
||||
.build();
|
||||
|
||||
assert_eq!(
|
||||
hex!("010203040506070809"),
|
||||
read_bytes_unchecked(&mut mock).await.unwrap().as_slice()
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_read_0_bytes_unchecked() {
|
||||
// A empty byte packet is essentially just the 0 length field.
|
||||
// No data is read, and there's zero padding.
|
||||
let mut mock = Builder::new().read(&0u64.to_le_bytes()).build();
|
||||
|
||||
assert_eq!(
|
||||
hex!(""),
|
||||
read_bytes_unchecked(&mut mock).await.unwrap().as_slice()
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
/// Ensure we don't read any further than the size field if the length
|
||||
/// doesn't match the range we want to accept.
|
||||
async fn test_reject_too_large() {
|
||||
let mut mock = Builder::new().read(&100u64.to_le_bytes()).build();
|
||||
|
||||
read_bytes(&mut mock, 10..10)
|
||||
.await
|
||||
.expect_err("expect this to fail");
|
||||
}
|
||||
}
|
|
@ -1,5 +1,8 @@
|
|||
//! Module parsing and emitting the wire format used by Nix, both in the
|
||||
//! nix-daemon protocol as well as in the NAR format.
|
||||
|
||||
#[cfg(feature = "async")]
|
||||
pub mod bytes;
|
||||
|
||||
#[cfg(feature = "async")]
|
||||
pub mod primitive;
|
||||
|
|
Loading…
Reference in a new issue