refactor(nix-compat/wire/bytes): drop pin_project, clean up
We already require R: Unpin in the constructor, so there's not much use to pin projection. Change-Id: Ia7bf734dc3aa86ffa6d1d5de778939baa9676bb9 Reviewed-on: https://cl.tvl.fyi/c/depot/+/11516 Tested-by: BuildkiteCI Reviewed-by: flokli <flokli@flokli.de>
This commit is contained in:
parent
652413c97d
commit
859bfcb68b
1 changed files with 57 additions and 63 deletions
|
@ -1,45 +1,39 @@
|
|||
use pin_project_lite::pin_project;
|
||||
use std::{
|
||||
io,
|
||||
ops::{Bound, RangeBounds, RangeInclusive},
|
||||
task::{ready, Poll},
|
||||
pin::Pin,
|
||||
task::{self, ready, Poll},
|
||||
};
|
||||
use tokio::io::AsyncRead;
|
||||
|
||||
use super::{padding_len, BytesPacketPosition, LEN_SIZE};
|
||||
|
||||
pin_project! {
|
||||
/// Reads a "bytes wire packet" from the underlying reader.
|
||||
/// The format is the same as in [crate::wire::bytes::read_bytes],
|
||||
/// however this structure provides a [AsyncRead] interface,
|
||||
/// allowing to not having to pass around the entire payload in memory.
|
||||
///
|
||||
/// After being constructed with the underlying reader and an allowed size,
|
||||
/// subsequent requests to poll_read will return payload data until the end
|
||||
/// of the packet is reached.
|
||||
///
|
||||
/// Internally, it will first read over the size packet, filling payload_size,
|
||||
/// ensuring it fits allowed_size, then return payload data.
|
||||
/// It will only signal EOF (returning `Ok(())` without filling the buffer anymore)
|
||||
/// when all padding has been successfully consumed too.
|
||||
///
|
||||
/// This also means, it's important for a user to always read to the end,
|
||||
/// and not just call read_exact - otherwise it might not skip over the
|
||||
/// padding, and return garbage when reading the next packet.
|
||||
///
|
||||
/// In case of an error due to size constraints, or in case of not reading
|
||||
/// all the way to the end (and getting a EOF), the underlying reader is no
|
||||
/// longer usable and might return garbage.
|
||||
pub struct BytesReader<R>
|
||||
where
|
||||
R: AsyncRead
|
||||
{
|
||||
#[pin]
|
||||
inner: R,
|
||||
|
||||
allowed_size: RangeInclusive<u64>,
|
||||
payload_size: [u8; 8],
|
||||
state: BytesPacketPosition,
|
||||
}
|
||||
/// Reads a "bytes wire packet" from the underlying reader.
|
||||
/// The format is the same as in [crate::wire::bytes::read_bytes],
|
||||
/// however this structure provides a [AsyncRead] interface,
|
||||
/// allowing to not having to pass around the entire payload in memory.
|
||||
///
|
||||
/// After being constructed with the underlying reader and an allowed size,
|
||||
/// subsequent requests to poll_read will return payload data until the end
|
||||
/// of the packet is reached.
|
||||
///
|
||||
/// Internally, it will first read over the size packet, filling payload_size,
|
||||
/// ensuring it fits allowed_size, then return payload data.
|
||||
/// It will only signal EOF (returning `Ok(())` without filling the buffer anymore)
|
||||
/// when all padding has been successfully consumed too.
|
||||
///
|
||||
/// This also means, it's important for a user to always read to the end,
|
||||
/// and not just call read_exact - otherwise it might not skip over the
|
||||
/// padding, and return garbage when reading the next packet.
|
||||
///
|
||||
/// In case of an error due to size constraints, or in case of not reading
|
||||
/// all the way to the end (and getting a EOF), the underlying reader is no
|
||||
/// longer usable and might return garbage.
|
||||
pub struct BytesReader<R> {
|
||||
inner: R,
|
||||
allowed_size: RangeInclusive<u64>,
|
||||
payload_size: [u8; 8],
|
||||
state: BytesPacketPosition,
|
||||
}
|
||||
|
||||
impl<R> BytesReader<R>
|
||||
|
@ -70,10 +64,10 @@ where
|
|||
}
|
||||
/// Returns an error if the passed usize is 0.
|
||||
#[inline]
|
||||
fn ensure_nonzero_bytes_read(bytes_read: usize) -> Result<usize, std::io::Error> {
|
||||
fn ensure_nonzero_bytes_read(bytes_read: usize) -> Result<usize, io::Error> {
|
||||
if bytes_read == 0 {
|
||||
Err(std::io::Error::new(
|
||||
std::io::ErrorKind::UnexpectedEof,
|
||||
Err(io::Error::new(
|
||||
io::ErrorKind::UnexpectedEof,
|
||||
"underlying reader returned EOF",
|
||||
))
|
||||
} else {
|
||||
|
@ -83,60 +77,60 @@ fn ensure_nonzero_bytes_read(bytes_read: usize) -> Result<usize, std::io::Error>
|
|||
|
||||
impl<R> AsyncRead for BytesReader<R>
|
||||
where
|
||||
R: AsyncRead,
|
||||
R: AsyncRead + Unpin,
|
||||
{
|
||||
fn poll_read(
|
||||
self: std::pin::Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
buf: &mut tokio::io::ReadBuf<'_>,
|
||||
) -> Poll<std::io::Result<()>> {
|
||||
let mut this = self.project();
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut task::Context,
|
||||
buf: &mut tokio::io::ReadBuf,
|
||||
) -> Poll<io::Result<()>> {
|
||||
let this = self.get_mut();
|
||||
|
||||
// Use a loop, so we can deal with (multiple) state transitions.
|
||||
loop {
|
||||
match *this.state {
|
||||
match this.state {
|
||||
BytesPacketPosition::Size(LEN_SIZE) => {
|
||||
// used in case an invalid size was signalled.
|
||||
Err(std::io::Error::new(
|
||||
std::io::ErrorKind::InvalidData,
|
||||
Err(io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
"signalled package size not in allowed range",
|
||||
))?
|
||||
}
|
||||
BytesPacketPosition::Size(pos) => {
|
||||
// try to read more of the size field.
|
||||
// We wrap a ReadBuf around this.payload_size here, and set_filled.
|
||||
let mut read_buf = tokio::io::ReadBuf::new(this.payload_size);
|
||||
let mut read_buf = tokio::io::ReadBuf::new(&mut this.payload_size);
|
||||
read_buf.advance(pos);
|
||||
ready!(this.inner.as_mut().poll_read(cx, &mut read_buf))?;
|
||||
ready!(Pin::new(&mut this.inner).poll_read(cx, &mut read_buf))?;
|
||||
|
||||
ensure_nonzero_bytes_read(read_buf.filled().len() - pos)?;
|
||||
|
||||
let total_size_read = read_buf.filled().len();
|
||||
if total_size_read == LEN_SIZE {
|
||||
// If the entire payload size was read, parse it
|
||||
let payload_size = u64::from_le_bytes(*this.payload_size);
|
||||
let payload_size = u64::from_le_bytes(this.payload_size);
|
||||
|
||||
if !this.allowed_size.contains(&payload_size) {
|
||||
// If it's not in the allowed
|
||||
// range, transition to failure mode
|
||||
// `BytesPacketPosition::Size(LEN_SIZE)`, where only
|
||||
// an error is returned.
|
||||
*this.state = BytesPacketPosition::Size(LEN_SIZE)
|
||||
this.state = BytesPacketPosition::Size(LEN_SIZE)
|
||||
} else if payload_size == 0 {
|
||||
// If the payload size is 0, move on to reading padding directly.
|
||||
*this.state = BytesPacketPosition::Padding(0)
|
||||
this.state = BytesPacketPosition::Padding(0)
|
||||
} else {
|
||||
// Else, transition to reading the payload.
|
||||
*this.state = BytesPacketPosition::Payload(0)
|
||||
this.state = BytesPacketPosition::Payload(0)
|
||||
}
|
||||
} else {
|
||||
// If we still need to read more of payload size, update
|
||||
// our position in the state.
|
||||
*this.state = BytesPacketPosition::Size(total_size_read)
|
||||
this.state = BytesPacketPosition::Size(total_size_read)
|
||||
}
|
||||
}
|
||||
BytesPacketPosition::Payload(pos) => {
|
||||
let signalled_size = u64::from_le_bytes(*this.payload_size);
|
||||
let signalled_size = u64::from_le_bytes(this.payload_size);
|
||||
// We don't enter this match arm at all if we're expecting empty payload
|
||||
debug_assert!(signalled_size > 0, "signalled size must be larger than 0");
|
||||
|
||||
|
@ -147,7 +141,7 @@ where
|
|||
// Reducing these two u64 to usize on 32bits is fine - we
|
||||
// only care about not reading too much, not too less.
|
||||
let mut limited_buf = buf.take((signalled_size - pos) as usize);
|
||||
ready!(this.inner.as_mut().poll_read(cx, &mut limited_buf))?;
|
||||
ready!(Pin::new(&mut this.inner).poll_read(cx, &mut limited_buf))?;
|
||||
limited_buf.filled().len()
|
||||
})?;
|
||||
|
||||
|
@ -158,11 +152,11 @@ where
|
|||
if pos + bytes_read as u64 == signalled_size {
|
||||
// If we now read all payload, transition to padding
|
||||
// state.
|
||||
*this.state = BytesPacketPosition::Padding(0);
|
||||
this.state = BytesPacketPosition::Padding(0);
|
||||
} else {
|
||||
// if we didn't read everything yet, update our position
|
||||
// in the state.
|
||||
*this.state = BytesPacketPosition::Payload(pos + bytes_read as u64);
|
||||
this.state = BytesPacketPosition::Payload(pos + bytes_read as u64);
|
||||
}
|
||||
|
||||
// We return from poll_read here.
|
||||
|
@ -181,7 +175,7 @@ where
|
|||
// bytes. Only return `Ready(Ok(()))` once we're past the
|
||||
// padding (or in cases where polling the inner reader
|
||||
// returns `Poll::Pending`).
|
||||
let signalled_size = u64::from_le_bytes(*this.payload_size);
|
||||
let signalled_size = u64::from_le_bytes(this.payload_size);
|
||||
let total_padding_len = padding_len(signalled_size) as usize;
|
||||
|
||||
let padding_len_remaining = total_padding_len - pos;
|
||||
|
@ -192,15 +186,15 @@ where
|
|||
let mut padding_buf = padding_buf.take(padding_len_remaining);
|
||||
|
||||
// read into padding_buf.
|
||||
ready!(this.inner.as_mut().poll_read(cx, &mut padding_buf))?;
|
||||
ready!(Pin::new(&mut this.inner).poll_read(cx, &mut padding_buf))?;
|
||||
let bytes_read = ensure_nonzero_bytes_read(padding_buf.filled().len())?;
|
||||
|
||||
*this.state = BytesPacketPosition::Padding(pos + bytes_read);
|
||||
this.state = BytesPacketPosition::Padding(pos + bytes_read);
|
||||
|
||||
// ensure the bytes are not null bytes
|
||||
if !padding_buf.filled().iter().all(|e| *e == b'\0') {
|
||||
return Err(std::io::Error::new(
|
||||
std::io::ErrorKind::InvalidData,
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
"padding is not all zeroes",
|
||||
))
|
||||
.into();
|
||||
|
|
Loading…
Reference in a new issue