refactor(tvix/nix-compat/wire): mv Bytes{WriterState,PacketPosition}

This is perfectly fine to track the position inside a reader too, so
rename it to reflect that.
Also make the docstring a bit less write-specific.

Change-Id: I831b0a8fe44a2477d4af96fefc692b9aabc378f1
Reviewed-on: https://cl.tvl.fyi/c/depot/+/11382
Reviewed-by: picnoir picnoir <picnoir@alternativebit.fr>
Tested-by: BuildkiteCI
Autosubmit: flokli <flokli@flokli.de>
This commit is contained in:
Florian Klink 2024-04-08 17:43:12 +03:00 committed by clbot
parent 35d70f94b7
commit 24fd4e963a

View file

@ -40,21 +40,22 @@ pin_project! {
#[pin]
inner: W,
payload_len: u64,
state: BytesWriterState,
state: BytesPacketPosition,
}
}
/// Models the state [BytesWriter] currently is in.
/// It can be in three stages, writing size, payload or padding fields.
/// The number tracks the number of bytes written in the current state.
/// Models the position inside a "bytes wire packet" that the reader or writer
/// is in.
/// It can be in three different stages, inside size, payload or padding fields.
/// The number tracks the number of bytes written inside the specific field.
/// There shall be no ambiguous states, at the end of a stage we immediately
/// move to the beginning of the next one:
/// - Size(LEN_SIZE) must be expressed as Payload(0)
/// - Payload(self.payload_len) must be expressed as Padding(0)
///
/// Padding(padding_len) means everything that needed to be written was written.
/// Padding(padding_len) means we're at the end of the bytes wire packet.
#[derive(Clone, Debug, PartialEq, Eq)]
enum BytesWriterState {
pub(crate) enum BytesPacketPosition {
Size(usize),
Payload(u64),
Padding(usize),
@ -69,7 +70,7 @@ where
Self {
inner: w,
payload_len,
state: BytesWriterState::Size(0),
state: BytesPacketPosition::Size(0),
}
}
}
@ -100,8 +101,8 @@ where
loop {
match *this.state {
BytesWriterState::Size(LEN_SIZE) => unreachable!(),
BytesWriterState::Size(pos) => {
BytesPacketPosition::Size(LEN_SIZE) => unreachable!(),
BytesPacketPosition::Size(pos) => {
let size_field = &this.payload_len.to_le_bytes();
let bytes_written = ensure_nonzero_bytes_written(ready!(this
@ -111,12 +112,12 @@ where
let new_pos = pos + bytes_written;
if new_pos == LEN_SIZE {
*this.state = BytesWriterState::Payload(0);
*this.state = BytesPacketPosition::Payload(0);
} else {
*this.state = BytesWriterState::Size(new_pos);
*this.state = BytesPacketPosition::Size(new_pos);
}
}
BytesWriterState::Payload(pos) => {
BytesPacketPosition::Payload(pos) => {
// Ensure we still have space for more payload
if pos + (buf.len() as u64) > *this.payload_len {
return Poll::Ready(Err(std::io::Error::new(
@ -128,15 +129,15 @@ where
ensure_nonzero_bytes_written(bytes_written)?;
let new_pos = pos + (bytes_written as u64);
if new_pos == *this.payload_len {
*this.state = BytesWriterState::Padding(0)
*this.state = BytesPacketPosition::Padding(0)
} else {
*this.state = BytesWriterState::Payload(new_pos)
*this.state = BytesPacketPosition::Payload(new_pos)
}
return Poll::Ready(Ok(bytes_written));
}
// If we're already in padding state, there should be no more payload left to write!
BytesWriterState::Padding(_pos) => {
BytesPacketPosition::Padding(_pos) => {
return Poll::Ready(Err(std::io::Error::new(
std::io::ErrorKind::InvalidData,
"tried to write excess bytes",
@ -154,8 +155,8 @@ where
loop {
match *this.state {
BytesWriterState::Size(LEN_SIZE) => unreachable!(),
BytesWriterState::Size(pos) => {
BytesPacketPosition::Size(LEN_SIZE) => unreachable!(),
BytesPacketPosition::Size(pos) => {
// More bytes to write in the size field
let size_field = &this.payload_len.to_le_bytes()[..];
let bytes_written = ensure_nonzero_bytes_written(ready!(this
@ -165,23 +166,23 @@ where
let new_pos = pos + bytes_written;
if new_pos == LEN_SIZE {
// Size field written, now ready to receive payload
*this.state = BytesWriterState::Payload(0);
*this.state = BytesPacketPosition::Payload(0);
} else {
*this.state = BytesWriterState::Size(new_pos);
*this.state = BytesPacketPosition::Size(new_pos);
}
}
BytesWriterState::Payload(_pos) => {
BytesPacketPosition::Payload(_pos) => {
// If we're at position 0 and want to write 0 bytes of payload
// in total, we can transition to padding.
// Otherwise, break, as we're expecting more payload to
// be written.
if *this.payload_len == 0 {
*this.state = BytesWriterState::Padding(0);
*this.state = BytesPacketPosition::Padding(0);
} else {
break;
}
}
BytesWriterState::Padding(pos) => {
BytesPacketPosition::Padding(pos) => {
// Write remaining padding, if there is padding to write.
let padding_len = super::bytes::padding_len(*this.payload_len) as usize;
@ -190,7 +191,7 @@ where
.inner
.as_mut()
.poll_write(cx, &EMPTY_BYTES[..padding_len]))?)?;
*this.state = BytesWriterState::Padding(pos + bytes_written);
*this.state = BytesPacketPosition::Padding(pos + bytes_written);
} else {
// everything written, break
break;
@ -213,7 +214,7 @@ where
// After a flush, being inside the padding state, and at the end of the padding
// is the only way to prevent a dirty shutdown.
if let BytesWriterState::Padding(pos) = *this.state {
if let BytesPacketPosition::Padding(pos) = *this.state {
let padding_len = super::bytes::padding_len(*this.payload_len) as usize;
if padding_len == pos {
// Shutdown the underlying writer