chore(tvix/nix-daemon): Implement framed protocol
When sending nars over the wire to the nix-daemon, nix protocol versions >= 1.23 use this framing protocol. This change implements an AsyncRead for this protocol, to be used in AddToStoreNar and any other operations when necessary. Change-Id: I5f7972fe1c9ea145780bf449321bd3efeb833d18 Reviewed-on: https://cl.tvl.fyi/c/depot/+/12814 Tested-by: BuildkiteCI Reviewed-by: flokli <flokli@flokli.de>
This commit is contained in:
parent
db13b6c092
commit
654cc3e43a
2 changed files with 191 additions and 0 deletions
189
tvix/nix-compat/src/nix_daemon/framing/framed_read.rs
Normal file
189
tvix/nix-compat/src/nix_daemon/framing/framed_read.rs
Normal file
|
@ -0,0 +1,189 @@
|
|||
use std::{
|
||||
io::Result,
|
||||
pin::Pin,
|
||||
task::{ready, Poll},
|
||||
};
|
||||
|
||||
use pin_project_lite::pin_project;
|
||||
use tokio::io::{AsyncRead, ReadBuf};
|
||||
|
||||
/// State machine for [`NixFramedReader`].
|
||||
///
|
||||
/// As the reader progresses it linearly cycles through the states.
|
||||
#[derive(Debug)]
|
||||
enum NixFramedReaderState {
|
||||
/// The reader always starts in this state.
|
||||
///
|
||||
/// Before the payload, the client first sends its size.
|
||||
/// The size is a u64 which is 8 bytes long, while it's likely that we will receive
|
||||
/// the whole u64 in one read, it's possible that it will arrive in smaller chunks.
|
||||
/// So in this state we read up to 8 bytes and transition to
|
||||
/// [`NixFramedReaderState::ReadingPayload`] when done if the read size is not zero,
|
||||
/// otherwise we reset filled to 0, and read the next size value.
|
||||
ReadingSize { buf: [u8; 8], filled: usize },
|
||||
/// This is where we read the actual payload that is sent to us.
|
||||
///
|
||||
/// Once we've read the expected number of bytes, we go back to the
|
||||
/// [`NixFramedReaderState::ReadingSize`] state.
|
||||
ReadingPayload {
|
||||
/// Represents the remaining number of bytes we expect to read based on the value
|
||||
/// read in the previous state.
|
||||
remaining: u64,
|
||||
},
|
||||
}
|
||||
|
||||
pin_project! {
|
||||
/// Implements Nix's Framed reader protocol for protocol versions >= 1.23.
|
||||
///
|
||||
/// See serialization.md#framed and [`NixFramedReaderState`] for details.
|
||||
pub struct NixFramedReader<R> {
|
||||
#[pin]
|
||||
reader: R,
|
||||
state: NixFramedReaderState,
|
||||
}
|
||||
}
|
||||
|
||||
impl<R> NixFramedReader<R> {
|
||||
pub fn new(reader: R) -> Self {
|
||||
Self {
|
||||
reader,
|
||||
state: NixFramedReaderState::ReadingSize {
|
||||
buf: [0; 8],
|
||||
filled: 0,
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<R: AsyncRead> AsyncRead for NixFramedReader<R> {
|
||||
fn poll_read(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
read_buf: &mut ReadBuf<'_>,
|
||||
) -> Poll<Result<()>> {
|
||||
let mut this = self.as_mut().project();
|
||||
match this.state {
|
||||
NixFramedReaderState::ReadingSize { buf, filled } => {
|
||||
if *filled < buf.len() {
|
||||
let mut size_buf = ReadBuf::new(buf);
|
||||
size_buf.advance(*filled);
|
||||
|
||||
ready!(this.reader.poll_read(cx, &mut size_buf))?;
|
||||
let bytes_read = size_buf.filled().len() - *filled;
|
||||
if bytes_read == 0 {
|
||||
// oef
|
||||
return Poll::Ready(Ok(()));
|
||||
}
|
||||
*filled += bytes_read;
|
||||
// Schedule ourselves to run again.
|
||||
return self.poll_read(cx, read_buf);
|
||||
}
|
||||
let size = u64::from_le_bytes(*buf);
|
||||
if size == 0 {
|
||||
// eof
|
||||
*filled = 0;
|
||||
return Poll::Ready(Ok(()));
|
||||
}
|
||||
*this.state = NixFramedReaderState::ReadingPayload { remaining: size };
|
||||
self.poll_read(cx, read_buf)
|
||||
}
|
||||
NixFramedReaderState::ReadingPayload { remaining } => {
|
||||
// Make sure we never try to read more than usize which is 4 bytes on 32-bit platforms.
|
||||
let safe_remaining = if *remaining <= usize::MAX as u64 {
|
||||
*remaining as usize
|
||||
} else {
|
||||
usize::MAX
|
||||
};
|
||||
if safe_remaining > 0 {
|
||||
// The buffer is no larger than the amount of data that we expect.
|
||||
// Otherwise we will trim the buffer below and come back here.
|
||||
if read_buf.remaining() <= safe_remaining {
|
||||
let filled_before = read_buf.filled().len();
|
||||
|
||||
ready!(this.reader.as_mut().poll_read(cx, read_buf))?;
|
||||
let bytes_read = read_buf.filled().len() - filled_before;
|
||||
|
||||
*remaining -= bytes_read as u64;
|
||||
if *remaining == 0 {
|
||||
*this.state = NixFramedReaderState::ReadingSize {
|
||||
buf: [0; 8],
|
||||
filled: 0,
|
||||
};
|
||||
}
|
||||
return Poll::Ready(Ok(()));
|
||||
}
|
||||
// Don't read more than remaining + pad bytes, it avoids unnecessary allocations and makes
|
||||
// internal bookkeeping simpler.
|
||||
let mut smaller_buf = read_buf.take(safe_remaining);
|
||||
ready!(self.as_mut().poll_read(cx, &mut smaller_buf))?;
|
||||
|
||||
let bytes_read = smaller_buf.filled().len();
|
||||
|
||||
// SAFETY: we just read this number of bytes into read_buf's backing slice above.
|
||||
unsafe { read_buf.assume_init(bytes_read) };
|
||||
read_buf.advance(bytes_read);
|
||||
return Poll::Ready(Ok(()));
|
||||
}
|
||||
*this.state = NixFramedReaderState::ReadingSize {
|
||||
buf: [0; 8],
|
||||
filled: 0,
|
||||
};
|
||||
self.poll_read(cx, read_buf)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod nix_framed_tests {
|
||||
use std::time::Duration;
|
||||
|
||||
use tokio::io::AsyncReadExt;
|
||||
use tokio_test::io::Builder;
|
||||
|
||||
use crate::nix_daemon::framing::NixFramedReader;
|
||||
|
||||
#[tokio::test]
|
||||
async fn read_hello_world_in_two_frames() {
|
||||
let mut mock = Builder::new()
|
||||
// The client sends len
|
||||
.read(&5u64.to_le_bytes())
|
||||
// Immediately followed by the bytes
|
||||
.read("hello".as_bytes())
|
||||
.wait(Duration::ZERO)
|
||||
// Send more data separately
|
||||
.read(&6u64.to_le_bytes())
|
||||
.read(" world".as_bytes())
|
||||
.build();
|
||||
|
||||
let mut reader = NixFramedReader::new(&mut mock);
|
||||
let mut result = String::new();
|
||||
reader
|
||||
.read_to_string(&mut result)
|
||||
.await
|
||||
.expect("Could not read into result");
|
||||
assert_eq!("hello world", result);
|
||||
}
|
||||
#[tokio::test]
|
||||
async fn read_hello_world_in_two_frames_followed_by_zero_sized_frame() {
|
||||
let mut mock = Builder::new()
|
||||
// The client sends len
|
||||
.read(&5u64.to_le_bytes())
|
||||
// Immediately followed by the bytes
|
||||
.read("hello".as_bytes())
|
||||
.wait(Duration::ZERO)
|
||||
// Send more data separately
|
||||
.read(&6u64.to_le_bytes())
|
||||
.read(" world".as_bytes())
|
||||
.read(&0u64.to_le_bytes())
|
||||
.build();
|
||||
|
||||
let mut reader = NixFramedReader::new(&mut mock);
|
||||
let mut result = String::new();
|
||||
reader
|
||||
.read_to_string(&mut result)
|
||||
.await
|
||||
.expect("Could not read into result");
|
||||
assert_eq!("hello world", result);
|
||||
}
|
||||
}
|
|
@ -1,2 +1,4 @@
|
|||
mod framed_read;
|
||||
pub use framed_read::NixFramedReader;
|
||||
mod stderr_read;
|
||||
pub use stderr_read::StderrReadFramedReader;
|
||||
|
|
Loading…
Reference in a new issue