refactor(nix-compat/wire/bytes): drop pin_project, clean up

We already require R: Unpin in the constructor, so there's not much use
to pin projection.

Change-Id: Ia7bf734dc3aa86ffa6d1d5de778939baa9676bb9
Reviewed-on: https://cl.tvl.fyi/c/depot/+/11516
Tested-by: BuildkiteCI
Reviewed-by: flokli <flokli@flokli.de>
This commit is contained in:
edef 2024-04-25 22:39:57 +00:00
parent 652413c97d
commit 859bfcb68b

View file

@ -1,45 +1,39 @@
use pin_project_lite::pin_project;
use std::{
io,
ops::{Bound, RangeBounds, RangeInclusive},
task::{ready, Poll},
pin::Pin,
task::{self, ready, Poll},
};
use tokio::io::AsyncRead;
use super::{padding_len, BytesPacketPosition, LEN_SIZE};
pin_project! {
/// Reads a "bytes wire packet" from the underlying reader.
/// The format is the same as in [crate::wire::bytes::read_bytes],
/// however this structure provides a [AsyncRead] interface,
/// allowing to not having to pass around the entire payload in memory.
///
/// After being constructed with the underlying reader and an allowed size,
/// subsequent requests to poll_read will return payload data until the end
/// of the packet is reached.
///
/// Internally, it will first read over the size packet, filling payload_size,
/// ensuring it fits allowed_size, then return payload data.
/// It will only signal EOF (returning `Ok(())` without filling the buffer anymore)
/// when all padding has been successfully consumed too.
///
/// This also means, it's important for a user to always read to the end,
/// and not just call read_exact - otherwise it might not skip over the
/// padding, and return garbage when reading the next packet.
///
/// In case of an error due to size constraints, or in case of not reading
/// all the way to the end (and getting a EOF), the underlying reader is no
/// longer usable and might return garbage.
pub struct BytesReader<R>
where
R: AsyncRead
{
#[pin]
inner: R,
allowed_size: RangeInclusive<u64>,
payload_size: [u8; 8],
state: BytesPacketPosition,
}
/// Reads a "bytes wire packet" from the underlying reader.
/// The format is the same as in [crate::wire::bytes::read_bytes],
/// however this structure provides a [AsyncRead] interface,
/// allowing to not having to pass around the entire payload in memory.
///
/// After being constructed with the underlying reader and an allowed size,
/// subsequent requests to poll_read will return payload data until the end
/// of the packet is reached.
///
/// Internally, it will first read over the size packet, filling payload_size,
/// ensuring it fits allowed_size, then return payload data.
/// It will only signal EOF (returning `Ok(())` without filling the buffer anymore)
/// when all padding has been successfully consumed too.
///
/// This also means, it's important for a user to always read to the end,
/// and not just call read_exact - otherwise it might not skip over the
/// padding, and return garbage when reading the next packet.
///
/// In case of an error due to size constraints, or in case of not reading
/// all the way to the end (and getting a EOF), the underlying reader is no
/// longer usable and might return garbage.
pub struct BytesReader<R> {
inner: R,
allowed_size: RangeInclusive<u64>,
payload_size: [u8; 8],
state: BytesPacketPosition,
}
impl<R> BytesReader<R>
@ -70,10 +64,10 @@ where
}
/// Returns an error if the passed usize is 0.
#[inline]
fn ensure_nonzero_bytes_read(bytes_read: usize) -> Result<usize, std::io::Error> {
fn ensure_nonzero_bytes_read(bytes_read: usize) -> Result<usize, io::Error> {
if bytes_read == 0 {
Err(std::io::Error::new(
std::io::ErrorKind::UnexpectedEof,
Err(io::Error::new(
io::ErrorKind::UnexpectedEof,
"underlying reader returned EOF",
))
} else {
@ -83,60 +77,60 @@ fn ensure_nonzero_bytes_read(bytes_read: usize) -> Result<usize, std::io::Error>
impl<R> AsyncRead for BytesReader<R>
where
R: AsyncRead,
R: AsyncRead + Unpin,
{
fn poll_read(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &mut tokio::io::ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
let mut this = self.project();
self: Pin<&mut Self>,
cx: &mut task::Context,
buf: &mut tokio::io::ReadBuf,
) -> Poll<io::Result<()>> {
let this = self.get_mut();
// Use a loop, so we can deal with (multiple) state transitions.
loop {
match *this.state {
match this.state {
BytesPacketPosition::Size(LEN_SIZE) => {
// used in case an invalid size was signalled.
Err(std::io::Error::new(
std::io::ErrorKind::InvalidData,
Err(io::Error::new(
io::ErrorKind::InvalidData,
"signalled package size not in allowed range",
))?
}
BytesPacketPosition::Size(pos) => {
// try to read more of the size field.
// We wrap a ReadBuf around this.payload_size here, and set_filled.
let mut read_buf = tokio::io::ReadBuf::new(this.payload_size);
let mut read_buf = tokio::io::ReadBuf::new(&mut this.payload_size);
read_buf.advance(pos);
ready!(this.inner.as_mut().poll_read(cx, &mut read_buf))?;
ready!(Pin::new(&mut this.inner).poll_read(cx, &mut read_buf))?;
ensure_nonzero_bytes_read(read_buf.filled().len() - pos)?;
let total_size_read = read_buf.filled().len();
if total_size_read == LEN_SIZE {
// If the entire payload size was read, parse it
let payload_size = u64::from_le_bytes(*this.payload_size);
let payload_size = u64::from_le_bytes(this.payload_size);
if !this.allowed_size.contains(&payload_size) {
// If it's not in the allowed
// range, transition to failure mode
// `BytesPacketPosition::Size(LEN_SIZE)`, where only
// an error is returned.
*this.state = BytesPacketPosition::Size(LEN_SIZE)
this.state = BytesPacketPosition::Size(LEN_SIZE)
} else if payload_size == 0 {
// If the payload size is 0, move on to reading padding directly.
*this.state = BytesPacketPosition::Padding(0)
this.state = BytesPacketPosition::Padding(0)
} else {
// Else, transition to reading the payload.
*this.state = BytesPacketPosition::Payload(0)
this.state = BytesPacketPosition::Payload(0)
}
} else {
// If we still need to read more of payload size, update
// our position in the state.
*this.state = BytesPacketPosition::Size(total_size_read)
this.state = BytesPacketPosition::Size(total_size_read)
}
}
BytesPacketPosition::Payload(pos) => {
let signalled_size = u64::from_le_bytes(*this.payload_size);
let signalled_size = u64::from_le_bytes(this.payload_size);
// We don't enter this match arm at all if we're expecting empty payload
debug_assert!(signalled_size > 0, "signalled size must be larger than 0");
@ -147,7 +141,7 @@ where
// Reducing these two u64 to usize on 32bits is fine - we
// only care about not reading too much, not too less.
let mut limited_buf = buf.take((signalled_size - pos) as usize);
ready!(this.inner.as_mut().poll_read(cx, &mut limited_buf))?;
ready!(Pin::new(&mut this.inner).poll_read(cx, &mut limited_buf))?;
limited_buf.filled().len()
})?;
@ -158,11 +152,11 @@ where
if pos + bytes_read as u64 == signalled_size {
// If we now read all payload, transition to padding
// state.
*this.state = BytesPacketPosition::Padding(0);
this.state = BytesPacketPosition::Padding(0);
} else {
// if we didn't read everything yet, update our position
// in the state.
*this.state = BytesPacketPosition::Payload(pos + bytes_read as u64);
this.state = BytesPacketPosition::Payload(pos + bytes_read as u64);
}
// We return from poll_read here.
@ -181,7 +175,7 @@ where
// bytes. Only return `Ready(Ok(()))` once we're past the
// padding (or in cases where polling the inner reader
// returns `Poll::Pending`).
let signalled_size = u64::from_le_bytes(*this.payload_size);
let signalled_size = u64::from_le_bytes(this.payload_size);
let total_padding_len = padding_len(signalled_size) as usize;
let padding_len_remaining = total_padding_len - pos;
@ -192,15 +186,15 @@ where
let mut padding_buf = padding_buf.take(padding_len_remaining);
// read into padding_buf.
ready!(this.inner.as_mut().poll_read(cx, &mut padding_buf))?;
ready!(Pin::new(&mut this.inner).poll_read(cx, &mut padding_buf))?;
let bytes_read = ensure_nonzero_bytes_read(padding_buf.filled().len())?;
*this.state = BytesPacketPosition::Padding(pos + bytes_read);
this.state = BytesPacketPosition::Padding(pos + bytes_read);
// ensure the bytes are not null bytes
if !padding_buf.filled().iter().all(|e| *e == b'\0') {
return Err(std::io::Error::new(
std::io::ErrorKind::InvalidData,
return Err(io::Error::new(
io::ErrorKind::InvalidData,
"padding is not all zeroes",
))
.into();