refactor(nix-compat/wire/bytes): fold TrailerReader into BytesReader

The TrailerReader has no purpose separate from BytesReader, and the
code gets a fair bit simpler this way.

EOF handling is simplified, since we just rely on the implicit
behaviour of the existing case.

Change-Id: Id9b9f022c7c89fbc47968a96032fc43553af8290
Reviewed-on: https://cl.tvl.fyi/c/depot/+/11539
Reviewed-by: Brian Olsen <me@griff.name>
Tested-by: BuildkiteCI
Reviewed-by: flokli <flokli@flokli.de>
This commit is contained in:
edef 2024-04-29 14:52:34 +00:00
parent 44bd9543a6
commit fdecf52a52
2 changed files with 41 additions and 91 deletions

View file

@ -1,4 +1,5 @@
use std::{ use std::{
future::Future,
io, io,
ops::{Bound, RangeBounds}, ops::{Bound, RangeBounds},
pin::Pin, pin::Pin,
@ -6,7 +7,7 @@ use std::{
}; };
use tokio::io::{AsyncRead, ReadBuf}; use tokio::io::{AsyncRead, ReadBuf};
use trailer::TrailerReader; use trailer::{read_trailer, ReadTrailer, Trailer};
mod trailer; mod trailer;
/// Reads a "bytes wire packet" from the underlying reader. /// Reads a "bytes wire packet" from the underlying reader.
@ -33,6 +34,7 @@ pub struct BytesReader<R> {
#[derive(Debug)] #[derive(Debug)]
enum State<R> { enum State<R> {
/// The data size is being read.
Size { Size {
reader: Option<R>, reader: Option<R>,
/// Minimum length (inclusive) /// Minimum length (inclusive)
@ -42,12 +44,18 @@ enum State<R> {
filled: u8, filled: u8,
buf: [u8; 8], buf: [u8; 8],
}, },
/// Full 8-byte blocks are being read and released to the caller.
Body { Body {
reader: Option<R>, reader: Option<R>,
consumed: u64, consumed: u64,
/// The total length of all user data contained in both the body and trailer.
user_len: u64, user_len: u64,
}, },
Trailer(TrailerReader<R>), /// The trailer is in the process of being read.
ReadTrailer(ReadTrailer<R>),
/// The trailer has been fully read and validated,
/// and data can now be released to the caller.
ReleaseTrailer { consumed: u8, data: Trailer },
} }
impl<R> BytesReader<R> impl<R> BytesReader<R>
@ -100,7 +108,10 @@ where
State::Body { State::Body {
consumed, user_len, .. consumed, user_len, ..
} => Some(user_len - consumed), } => Some(user_len - consumed),
State::Trailer(ref r) => Some(r.len() as u64), State::ReadTrailer(ref fut) => Some(fut.len() as u64),
State::ReleaseTrailer { consumed, ref data } => {
Some(data.len() as u64 - consumed as u64)
}
} }
} }
} }
@ -166,7 +177,7 @@ impl<R: AsyncRead + Unpin> AsyncRead for BytesReader<R> {
let reader = if remaining == 0 { let reader = if remaining == 0 {
let reader = reader.take().unwrap(); let reader = reader.take().unwrap();
let user_len = (*user_len & 7) as u8; let user_len = (*user_len & 7) as u8;
*this = State::Trailer(TrailerReader::new(reader, user_len)); *this = State::ReadTrailer(read_trailer(reader, user_len));
continue; continue;
} else { } else {
reader.as_mut().unwrap() reader.as_mut().unwrap()
@ -188,8 +199,20 @@ impl<R: AsyncRead + Unpin> AsyncRead for BytesReader<R> {
} }
.into(); .into();
} }
State::Trailer(reader) => { State::ReadTrailer(fut) => {
return Pin::new(reader).poll_read(cx, buf); *this = State::ReleaseTrailer {
consumed: 0,
data: ready!(Pin::new(fut).poll(cx))?,
};
}
State::ReleaseTrailer { consumed, data } => {
let data = &data[*consumed as usize..];
let data = &data[..usize::min(data.len(), buf.remaining())];
buf.put_slice(data);
*consumed += data.len() as u8;
return Ok(()).into();
} }
} }
} }

View file

@ -53,7 +53,7 @@ impl Tag for Pad {
} }
#[derive(Debug)] #[derive(Debug)]
pub(crate) struct ReadTrailer<R, T: Tag> { pub(crate) struct ReadTrailer<R, T: Tag = Pad> {
reader: R, reader: R,
data_len: u8, data_len: u8,
filled: u8, filled: u8,
@ -90,7 +90,7 @@ impl<R, T: Tag> ReadTrailer<R, T> {
impl<R: AsyncRead + Unpin, T: Tag> Future for ReadTrailer<R, T> { impl<R: AsyncRead + Unpin, T: Tag> Future for ReadTrailer<R, T> {
type Output = io::Result<Trailer>; type Output = io::Result<Trailer>;
fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context) -> task::Poll<Self::Output> { fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context) -> Poll<Self::Output> {
let this = &mut *self; let this = &mut *self;
loop { loop {
@ -136,72 +136,9 @@ impl<R: AsyncRead + Unpin, T: Tag> Future for ReadTrailer<R, T> {
} }
} }
#[derive(Debug)]
pub(crate) enum TrailerReader<R> {
Reading(ReadTrailer<R, Pad>),
Releasing { off: u8, data: Trailer },
Done,
}
impl<R: AsyncRead + Unpin> TrailerReader<R> {
pub fn new(reader: R, data_len: u8) -> Self {
Self::Reading(read_trailer(reader, data_len))
}
pub fn len(&self) -> u8 {
match self {
TrailerReader::Reading(fut) => fut.len(),
&TrailerReader::Releasing {
off,
data: Trailer { data_len, .. },
} => data_len - off,
TrailerReader::Done => 0,
}
}
}
impl<R: AsyncRead + Unpin> AsyncRead for TrailerReader<R> {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut task::Context,
user_buf: &mut ReadBuf,
) -> Poll<io::Result<()>> {
let this = &mut *self;
loop {
match this {
Self::Reading(fut) => {
*this = Self::Releasing {
off: 0,
data: ready!(Pin::new(fut).poll(cx))?,
};
}
Self::Releasing { off: 8, .. } => {
*this = Self::Done;
}
Self::Releasing { off, data } => {
assert_ne!(user_buf.remaining(), 0);
let buf = &data[*off as usize..];
let buf = &buf[..usize::min(buf.len(), user_buf.remaining())];
user_buf.put_slice(buf);
*off += buf.len() as u8;
break;
}
Self::Done => break,
}
}
Ok(()).into()
}
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use std::time::Duration; use std::time::Duration;
use tokio::io::AsyncReadExt;
use super::*; use super::*;
@ -213,11 +150,8 @@ mod tests {
.read(&[0xef, 0x00]) .read(&[0xef, 0x00])
.build(); .build();
let mut reader = TrailerReader::new(reader, 2);
let mut buf = vec![];
assert_eq!( assert_eq!(
reader.read_to_end(&mut buf).await.unwrap_err().kind(), read_trailer::<_, Pad>(reader, 2).await.unwrap_err().kind(),
io::ErrorKind::UnexpectedEof io::ErrorKind::UnexpectedEof
); );
} }
@ -231,11 +165,8 @@ mod tests {
.wait(Duration::ZERO) .wait(Duration::ZERO)
.build(); .build();
let mut reader = TrailerReader::new(reader, 2);
let mut buf = vec![];
assert_eq!( assert_eq!(
reader.read_to_end(&mut buf).await.unwrap_err().kind(), read_trailer::<_, Pad>(reader, 2).await.unwrap_err().kind(),
io::ErrorKind::InvalidData io::ErrorKind::InvalidData
); );
} }
@ -250,21 +181,17 @@ mod tests {
.read(&[0x00, 0x00, 0x00, 0x00, 0x00]) .read(&[0x00, 0x00, 0x00, 0x00, 0x00])
.build(); .build();
let mut reader = TrailerReader::new(reader, 2); assert_eq!(
&*read_trailer::<_, Pad>(reader, 2).await.unwrap(),
let mut buf = vec![]; &[0xed, 0xef]
reader.read_to_end(&mut buf).await.unwrap(); );
assert_eq!(buf, &[0xed, 0xef]);
} }
#[tokio::test] #[tokio::test]
async fn no_padding() { async fn no_padding() {
let reader = tokio_test::io::Builder::new().build(); assert!(read_trailer::<_, Pad>(io::empty(), 0)
let mut reader = TrailerReader::new(reader, 0); .await
.unwrap()
let mut buf = vec![]; .is_empty());
reader.read_to_end(&mut buf).await.unwrap();
assert!(buf.is_empty());
} }
} }