refactor(tvix/nix-compat/wire): mv Bytes{WriterState,PacketPosition}
This is perfectly fine to track the position inside a reader too, so rename it to reflect that. Also make the docstring a bit less write-specific. Change-Id: I831b0a8fe44a2477d4af96fefc692b9aabc378f1 Reviewed-on: https://cl.tvl.fyi/c/depot/+/11382 Reviewed-by: picnoir picnoir <picnoir@alternativebit.fr> Tested-by: BuildkiteCI Autosubmit: flokli <flokli@flokli.de>
This commit is contained in:
parent
35d70f94b7
commit
24fd4e963a
1 changed files with 25 additions and 24 deletions
|
@ -40,21 +40,22 @@ pin_project! {
|
|||
#[pin]
|
||||
inner: W,
|
||||
payload_len: u64,
|
||||
state: BytesWriterState,
|
||||
state: BytesPacketPosition,
|
||||
}
|
||||
}
|
||||
|
||||
/// Models the state [BytesWriter] currently is in.
|
||||
/// It can be in three stages, writing size, payload or padding fields.
|
||||
/// The number tracks the number of bytes written in the current state.
|
||||
/// Models the position inside a "bytes wire packet" that the reader or writer
|
||||
/// is in.
|
||||
/// It can be in three different stages, inside size, payload or padding fields.
|
||||
/// The number tracks the number of bytes written inside the specific field.
|
||||
/// There shall be no ambiguous states, at the end of a stage we immediately
|
||||
/// move to the beginning of the next one:
|
||||
/// - Size(LEN_SIZE) must be expressed as Payload(0)
|
||||
/// - Payload(self.payload_len) must be expressed as Padding(0)
|
||||
///
|
||||
/// Padding(padding_len) means everything that needed to be written was written.
|
||||
/// Padding(padding_len) means we're at the end of the bytes wire packet.
|
||||
#[derive(Clone, Debug, PartialEq, Eq)]
|
||||
enum BytesWriterState {
|
||||
pub(crate) enum BytesPacketPosition {
|
||||
Size(usize),
|
||||
Payload(u64),
|
||||
Padding(usize),
|
||||
|
@ -69,7 +70,7 @@ where
|
|||
Self {
|
||||
inner: w,
|
||||
payload_len,
|
||||
state: BytesWriterState::Size(0),
|
||||
state: BytesPacketPosition::Size(0),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -100,8 +101,8 @@ where
|
|||
|
||||
loop {
|
||||
match *this.state {
|
||||
BytesWriterState::Size(LEN_SIZE) => unreachable!(),
|
||||
BytesWriterState::Size(pos) => {
|
||||
BytesPacketPosition::Size(LEN_SIZE) => unreachable!(),
|
||||
BytesPacketPosition::Size(pos) => {
|
||||
let size_field = &this.payload_len.to_le_bytes();
|
||||
|
||||
let bytes_written = ensure_nonzero_bytes_written(ready!(this
|
||||
|
@ -111,12 +112,12 @@ where
|
|||
|
||||
let new_pos = pos + bytes_written;
|
||||
if new_pos == LEN_SIZE {
|
||||
*this.state = BytesWriterState::Payload(0);
|
||||
*this.state = BytesPacketPosition::Payload(0);
|
||||
} else {
|
||||
*this.state = BytesWriterState::Size(new_pos);
|
||||
*this.state = BytesPacketPosition::Size(new_pos);
|
||||
}
|
||||
}
|
||||
BytesWriterState::Payload(pos) => {
|
||||
BytesPacketPosition::Payload(pos) => {
|
||||
// Ensure we still have space for more payload
|
||||
if pos + (buf.len() as u64) > *this.payload_len {
|
||||
return Poll::Ready(Err(std::io::Error::new(
|
||||
|
@ -128,15 +129,15 @@ where
|
|||
ensure_nonzero_bytes_written(bytes_written)?;
|
||||
let new_pos = pos + (bytes_written as u64);
|
||||
if new_pos == *this.payload_len {
|
||||
*this.state = BytesWriterState::Padding(0)
|
||||
*this.state = BytesPacketPosition::Padding(0)
|
||||
} else {
|
||||
*this.state = BytesWriterState::Payload(new_pos)
|
||||
*this.state = BytesPacketPosition::Payload(new_pos)
|
||||
}
|
||||
|
||||
return Poll::Ready(Ok(bytes_written));
|
||||
}
|
||||
// If we're already in padding state, there should be no more payload left to write!
|
||||
BytesWriterState::Padding(_pos) => {
|
||||
BytesPacketPosition::Padding(_pos) => {
|
||||
return Poll::Ready(Err(std::io::Error::new(
|
||||
std::io::ErrorKind::InvalidData,
|
||||
"tried to write excess bytes",
|
||||
|
@ -154,8 +155,8 @@ where
|
|||
|
||||
loop {
|
||||
match *this.state {
|
||||
BytesWriterState::Size(LEN_SIZE) => unreachable!(),
|
||||
BytesWriterState::Size(pos) => {
|
||||
BytesPacketPosition::Size(LEN_SIZE) => unreachable!(),
|
||||
BytesPacketPosition::Size(pos) => {
|
||||
// More bytes to write in the size field
|
||||
let size_field = &this.payload_len.to_le_bytes()[..];
|
||||
let bytes_written = ensure_nonzero_bytes_written(ready!(this
|
||||
|
@ -165,23 +166,23 @@ where
|
|||
let new_pos = pos + bytes_written;
|
||||
if new_pos == LEN_SIZE {
|
||||
// Size field written, now ready to receive payload
|
||||
*this.state = BytesWriterState::Payload(0);
|
||||
*this.state = BytesPacketPosition::Payload(0);
|
||||
} else {
|
||||
*this.state = BytesWriterState::Size(new_pos);
|
||||
*this.state = BytesPacketPosition::Size(new_pos);
|
||||
}
|
||||
}
|
||||
BytesWriterState::Payload(_pos) => {
|
||||
BytesPacketPosition::Payload(_pos) => {
|
||||
// If we're at position 0 and want to write 0 bytes of payload
|
||||
// in total, we can transition to padding.
|
||||
// Otherwise, break, as we're expecting more payload to
|
||||
// be written.
|
||||
if *this.payload_len == 0 {
|
||||
*this.state = BytesWriterState::Padding(0);
|
||||
*this.state = BytesPacketPosition::Padding(0);
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
BytesWriterState::Padding(pos) => {
|
||||
BytesPacketPosition::Padding(pos) => {
|
||||
// Write remaining padding, if there is padding to write.
|
||||
let padding_len = super::bytes::padding_len(*this.payload_len) as usize;
|
||||
|
||||
|
@ -190,7 +191,7 @@ where
|
|||
.inner
|
||||
.as_mut()
|
||||
.poll_write(cx, &EMPTY_BYTES[..padding_len]))?)?;
|
||||
*this.state = BytesWriterState::Padding(pos + bytes_written);
|
||||
*this.state = BytesPacketPosition::Padding(pos + bytes_written);
|
||||
} else {
|
||||
// everything written, break
|
||||
break;
|
||||
|
@ -213,7 +214,7 @@ where
|
|||
|
||||
// After a flush, being inside the padding state, and at the end of the padding
|
||||
// is the only way to prevent a dirty shutdown.
|
||||
if let BytesWriterState::Padding(pos) = *this.state {
|
||||
if let BytesPacketPosition::Padding(pos) = *this.state {
|
||||
let padding_len = super::bytes::padding_len(*this.payload_len) as usize;
|
||||
if padding_len == pos {
|
||||
// Shutdown the underlying writer
|
||||
|
|
Loading…
Reference in a new issue