refactor(nix-compat/wire/bytes): style fixes

Change-Id: I65c3c43df83e0c364a4b7f1f3054c5b676bd07d5
Reviewed-on: https://cl.tvl.fyi/c/depot/+/11605
Reviewed-by: flokli <flokli@flokli.de>
Tested-by: BuildkiteCI
This commit is contained in:
edef 2024-05-08 08:01:27 +00:00
parent ca10a8726f
commit 1eedb88939

View file

@ -2,7 +2,7 @@ use std::{
io::{Error, ErrorKind}, io::{Error, ErrorKind},
ops::RangeInclusive, ops::RangeInclusive,
}; };
use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::io::{self, AsyncReadExt, AsyncWriteExt};
pub(crate) mod reader; pub(crate) mod reader;
pub use reader::BytesReader; pub use reader::BytesReader;
@ -35,7 +35,7 @@ const LEN_SIZE: usize = 8;
pub async fn read_bytes<R: ?Sized>( pub async fn read_bytes<R: ?Sized>(
r: &mut R, r: &mut R,
allowed_size: RangeInclusive<usize>, allowed_size: RangeInclusive<usize>,
) -> std::io::Result<Vec<u8>> ) -> io::Result<Vec<u8>>
where where
R: AsyncReadExt + Unpin, R: AsyncReadExt + Unpin,
{ {
@ -46,8 +46,8 @@ where
.ok() .ok()
.filter(|len| allowed_size.contains(len)) .filter(|len| allowed_size.contains(len))
.ok_or_else(|| { .ok_or_else(|| {
std::io::Error::new( io::Error::new(
std::io::ErrorKind::InvalidData, io::ErrorKind::InvalidData,
"signalled package size not in allowed range", "signalled package size not in allowed range",
) )
})?; })?;
@ -63,15 +63,15 @@ where
// make sure we got exactly the number of bytes, and not less. // make sure we got exactly the number of bytes, and not less.
if s as u64 != padded_len { if s as u64 != padded_len {
return Err(std::io::ErrorKind::UnexpectedEof.into()); return Err(io::ErrorKind::UnexpectedEof.into());
} }
let (_content, padding) = buf.split_at(len); let (_content, padding) = buf.split_at(len);
// ensure the padding is all zeroes. // ensure the padding is all zeroes.
if !padding.iter().all(|e| *e == b'\0') { if padding.iter().any(|&b| b != 0) {
return Err(std::io::Error::new( return Err(io::Error::new(
std::io::ErrorKind::InvalidData, io::ErrorKind::InvalidData,
"padding is not all zeroes", "padding is not all zeroes",
)); ));
} }
@ -84,10 +84,7 @@ where
/// Read a "bytes wire packet" of from the AsyncRead and tries to parse as string. /// Read a "bytes wire packet" of from the AsyncRead and tries to parse as string.
/// Internally uses [read_bytes]. /// Internally uses [read_bytes].
/// Rejects reading more than `allowed_size` bytes of payload. /// Rejects reading more than `allowed_size` bytes of payload.
pub async fn read_string<R>( pub async fn read_string<R>(r: &mut R, allowed_size: RangeInclusive<usize>) -> io::Result<String>
r: &mut R,
allowed_size: RangeInclusive<usize>,
) -> std::io::Result<String>
where where
R: AsyncReadExt + Unpin, R: AsyncReadExt + Unpin,
{ {
@ -107,7 +104,7 @@ where
pub async fn write_bytes<W: AsyncWriteExt + Unpin, B: AsRef<[u8]>>( pub async fn write_bytes<W: AsyncWriteExt + Unpin, B: AsRef<[u8]>>(
w: &mut W, w: &mut W,
b: B, b: B,
) -> std::io::Result<()> { ) -> io::Result<()> {
// write the size packet. // write the size packet.
w.write_u64_le(b.as_ref().len() as u64).await?; w.write_u64_le(b.as_ref().len() as u64).await?;