feat(nix-compat/wire/bytes): allow specifying a pre-read size

Change-Id: I9c94239c308cfbc2e6dae871ba77fb33507433c9
Reviewed-on: https://cl.tvl.fyi/c/depot/+/11517
Tested-by: BuildkiteCI
Reviewed-by: flokli <flokli@flokli.de>
This commit is contained in:
edef 2024-04-25 22:43:59 +00:00
parent 859bfcb68b
commit 70c679eac4

View file

@ -61,6 +61,20 @@ where
state: BytesPacketPosition::Size(0),
}
}
/// Construct a new BytesReader with a known, and already-read size.
pub fn with_size(r: R, size: u64) -> Self {
Self {
inner: r,
allowed_size: size..=size,
payload_size: u64::to_le_bytes(size),
state: if size != 0 {
BytesPacketPosition::Payload(0)
} else {
BytesPacketPosition::Padding(0)
},
}
}
}
/// Returns an error if the passed usize is 0.
#[inline]
@ -261,6 +275,33 @@ mod tests {
assert_eq!(payload, &buf[..]);
}
/// Read bytes packets of various length, and ensure read_to_end returns the
/// expected payload.
#[rstest]
#[case::empty(&[])] // empty bytes packet
#[case::size_1b(&[0xff])] // 1 bytes payload
#[case::size_8b(&hex!("0001020304050607"))] // 8 bytes payload (no padding)
#[case::size_9b( &hex!("000102030405060708"))] // 9 bytes payload (7 bytes padding)
#[case::size_1m(LARGE_PAYLOAD.as_slice())] // larger bytes packet
#[tokio::test]
async fn read_payload_correct_known(#[case] payload: &[u8]) {
let packet = produce_packet_bytes(payload).await;
let size = u64::from_le_bytes({
let mut buf = [0; 8];
buf.copy_from_slice(&packet[..8]);
buf
});
let mut mock = Builder::new().read(&packet[8..]).build();
let mut r = BytesReader::with_size(&mut mock, size);
let mut buf = Vec::new();
r.read_to_end(&mut buf).await.expect("must succeed");
assert_eq!(payload, &buf[..]);
}
/// Fail if the bytes packet is larger than allowed
#[tokio::test]
async fn read_bigger_than_allowed_fail() {