refactor(nix-compat/wire): drop primitive functions
These may as well be inlined, and hardly need tests, since they just alias AsyncReadExt::read_u64_le / AsyncWriteExt::write_u64_le. Boolean reading is worth making explicit, since callers may differ on how they want to handle values other than 0 and 1. Boolean writing simplifies to `.write_u64_le(x as u64)`, which is also fine to inline. Change-Id: Ief9722fe886688693feb924ff0306b5bc68dd7a2 Reviewed-on: https://cl.tvl.fyi/c/depot/+/11549 Reviewed-by: flokli <flokli@flokli.de> Tested-by: BuildkiteCI
This commit is contained in:
parent
b3305ea6e2
commit
095f715a80
6 changed files with 33 additions and 112 deletions
|
@ -131,27 +131,27 @@ pub async fn read_client_settings<R: AsyncReadExt + Unpin>(
|
||||||
r: &mut R,
|
r: &mut R,
|
||||||
client_version: ProtocolVersion,
|
client_version: ProtocolVersion,
|
||||||
) -> std::io::Result<ClientSettings> {
|
) -> std::io::Result<ClientSettings> {
|
||||||
let keep_failed = wire::read_bool(r).await?;
|
let keep_failed = r.read_u64_le().await? != 0;
|
||||||
let keep_going = wire::read_bool(r).await?;
|
let keep_going = r.read_u64_le().await? != 0;
|
||||||
let try_fallback = wire::read_bool(r).await?;
|
let try_fallback = r.read_u64_le().await? != 0;
|
||||||
let verbosity_uint = wire::read_u64(r).await?;
|
let verbosity_uint = r.read_u64_le().await?;
|
||||||
let verbosity = Verbosity::from_u64(verbosity_uint).ok_or_else(|| {
|
let verbosity = Verbosity::from_u64(verbosity_uint).ok_or_else(|| {
|
||||||
Error::new(
|
Error::new(
|
||||||
ErrorKind::InvalidData,
|
ErrorKind::InvalidData,
|
||||||
format!("Can't convert integer {} to verbosity", verbosity_uint),
|
format!("Can't convert integer {} to verbosity", verbosity_uint),
|
||||||
)
|
)
|
||||||
})?;
|
})?;
|
||||||
let max_build_jobs = wire::read_u64(r).await?;
|
let max_build_jobs = r.read_u64_le().await?;
|
||||||
let max_silent_time = wire::read_u64(r).await?;
|
let max_silent_time = r.read_u64_le().await?;
|
||||||
_ = wire::read_u64(r).await?; // obsolete useBuildHook
|
_ = r.read_u64_le().await?; // obsolete useBuildHook
|
||||||
let verbose_build = wire::read_bool(r).await?;
|
let verbose_build = r.read_u64_le().await? != 0;
|
||||||
_ = wire::read_u64(r).await?; // obsolete logType
|
_ = r.read_u64_le().await?; // obsolete logType
|
||||||
_ = wire::read_u64(r).await?; // obsolete printBuildTrace
|
_ = r.read_u64_le().await?; // obsolete printBuildTrace
|
||||||
let build_cores = wire::read_u64(r).await?;
|
let build_cores = r.read_u64_le().await?;
|
||||||
let use_substitutes = wire::read_bool(r).await?;
|
let use_substitutes = r.read_u64_le().await? != 0;
|
||||||
let mut overrides = HashMap::new();
|
let mut overrides = HashMap::new();
|
||||||
if client_version.minor() >= 12 {
|
if client_version.minor() >= 12 {
|
||||||
let num_overrides = wire::read_u64(r).await?;
|
let num_overrides = r.read_u64_le().await?;
|
||||||
for _ in 0..num_overrides {
|
for _ in 0..num_overrides {
|
||||||
let name = wire::read_string(r, 0..MAX_SETTING_SIZE).await?;
|
let name = wire::read_string(r, 0..MAX_SETTING_SIZE).await?;
|
||||||
let value = wire::read_string(r, 0..MAX_SETTING_SIZE).await?;
|
let value = wire::read_string(r, 0..MAX_SETTING_SIZE).await?;
|
||||||
|
@ -197,17 +197,17 @@ pub async fn server_handshake_client<'a, RW: 'a>(
|
||||||
where
|
where
|
||||||
&'a mut RW: AsyncReadExt + AsyncWriteExt + Unpin,
|
&'a mut RW: AsyncReadExt + AsyncWriteExt + Unpin,
|
||||||
{
|
{
|
||||||
let worker_magic_1 = wire::read_u64(&mut conn).await?;
|
let worker_magic_1 = conn.read_u64_le().await?;
|
||||||
if worker_magic_1 != WORKER_MAGIC_1 {
|
if worker_magic_1 != WORKER_MAGIC_1 {
|
||||||
Err(std::io::Error::new(
|
Err(std::io::Error::new(
|
||||||
ErrorKind::InvalidData,
|
ErrorKind::InvalidData,
|
||||||
format!("Incorrect worker magic number received: {}", worker_magic_1),
|
format!("Incorrect worker magic number received: {}", worker_magic_1),
|
||||||
))
|
))
|
||||||
} else {
|
} else {
|
||||||
wire::write_u64(&mut conn, WORKER_MAGIC_2).await?;
|
conn.write_u64_le(WORKER_MAGIC_2).await?;
|
||||||
wire::write_u64(&mut conn, PROTOCOL_VERSION.into()).await?;
|
conn.write_u64_le(PROTOCOL_VERSION.into()).await?;
|
||||||
conn.flush().await?;
|
conn.flush().await?;
|
||||||
let client_version = wire::read_u64(&mut conn).await?;
|
let client_version = conn.read_u64_le().await?;
|
||||||
// Parse into ProtocolVersion.
|
// Parse into ProtocolVersion.
|
||||||
let client_version: ProtocolVersion = client_version
|
let client_version: ProtocolVersion = client_version
|
||||||
.try_into()
|
.try_into()
|
||||||
|
@ -220,14 +220,14 @@ where
|
||||||
}
|
}
|
||||||
if client_version.minor() >= 14 {
|
if client_version.minor() >= 14 {
|
||||||
// Obsolete CPU affinity.
|
// Obsolete CPU affinity.
|
||||||
let read_affinity = wire::read_u64(&mut conn).await?;
|
let read_affinity = conn.read_u64_le().await?;
|
||||||
if read_affinity != 0 {
|
if read_affinity != 0 {
|
||||||
let _cpu_affinity = wire::read_u64(&mut conn).await?;
|
let _cpu_affinity = conn.read_u64_le().await?;
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
if client_version.minor() >= 11 {
|
if client_version.minor() >= 11 {
|
||||||
// Obsolete reserveSpace
|
// Obsolete reserveSpace
|
||||||
let _reserve_space = wire::read_u64(&mut conn).await?;
|
let _reserve_space = conn.read_u64_le().await?;
|
||||||
}
|
}
|
||||||
if client_version.minor() >= 33 {
|
if client_version.minor() >= 33 {
|
||||||
// Nix version. We're plain lying, we're not Nix, but eh…
|
// Nix version. We're plain lying, we're not Nix, but eh…
|
||||||
|
@ -245,7 +245,7 @@ where
|
||||||
|
|
||||||
/// Read a worker [Operation] from the wire.
|
/// Read a worker [Operation] from the wire.
|
||||||
pub async fn read_op<R: AsyncReadExt + Unpin>(r: &mut R) -> std::io::Result<Operation> {
|
pub async fn read_op<R: AsyncReadExt + Unpin>(r: &mut R) -> std::io::Result<Operation> {
|
||||||
let op_number = wire::read_u64(r).await?;
|
let op_number = r.read_u64_le().await?;
|
||||||
Operation::from_u64(op_number).ok_or(Error::new(
|
Operation::from_u64(op_number).ok_or(Error::new(
|
||||||
ErrorKind::InvalidData,
|
ErrorKind::InvalidData,
|
||||||
format!("Invalid OP number {}", op_number),
|
format!("Invalid OP number {}", op_number),
|
||||||
|
@ -278,8 +278,8 @@ where
|
||||||
W: AsyncReadExt + AsyncWriteExt + Unpin,
|
W: AsyncReadExt + AsyncWriteExt + Unpin,
|
||||||
{
|
{
|
||||||
match t {
|
match t {
|
||||||
Trust::Trusted => wire::write_u64(conn, 1).await,
|
Trust::Trusted => conn.write_u64_le(1).await,
|
||||||
Trust::NotTrusted => wire::write_u64(conn, 2).await,
|
Trust::NotTrusted => conn.write_u64_le(2).await,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -9,8 +9,6 @@ pub use reader::BytesReader;
|
||||||
mod writer;
|
mod writer;
|
||||||
pub use writer::BytesWriter;
|
pub use writer::BytesWriter;
|
||||||
|
|
||||||
use super::primitive;
|
|
||||||
|
|
||||||
/// 8 null bytes, used to write out padding.
|
/// 8 null bytes, used to write out padding.
|
||||||
const EMPTY_BYTES: &[u8; 8] = &[0u8; 8];
|
const EMPTY_BYTES: &[u8; 8] = &[0u8; 8];
|
||||||
|
|
||||||
|
@ -41,7 +39,7 @@ where
|
||||||
S: RangeBounds<u64>,
|
S: RangeBounds<u64>,
|
||||||
{
|
{
|
||||||
// read the length field
|
// read the length field
|
||||||
let len = primitive::read_u64(r).await?;
|
let len = r.read_u64_le().await?;
|
||||||
|
|
||||||
if !allowed_size.contains(&len) {
|
if !allowed_size.contains(&len) {
|
||||||
return Err(std::io::Error::new(
|
return Err(std::io::Error::new(
|
||||||
|
@ -52,7 +50,7 @@ where
|
||||||
|
|
||||||
// calculate the total length, including padding.
|
// calculate the total length, including padding.
|
||||||
// byte packets are padded to 8 byte blocks each.
|
// byte packets are padded to 8 byte blocks each.
|
||||||
let padded_len = padding_len(len) as u64 + (len as u64);
|
let padded_len = padding_len(len) as u64 + len;
|
||||||
let mut limited_reader = r.take(padded_len);
|
let mut limited_reader = r.take(padded_len);
|
||||||
|
|
||||||
let mut buf = Vec::new();
|
let mut buf = Vec::new();
|
||||||
|
@ -105,7 +103,7 @@ pub async fn write_bytes<W: AsyncWriteExt + Unpin, B: AsRef<[u8]>>(
|
||||||
b: B,
|
b: B,
|
||||||
) -> std::io::Result<()> {
|
) -> std::io::Result<()> {
|
||||||
// write the size packet.
|
// write the size packet.
|
||||||
primitive::write_u64(w, b.as_ref().len() as u64).await?;
|
w.write_u64_le(b.as_ref().len() as u64).await?;
|
||||||
|
|
||||||
// write the payload
|
// write the payload
|
||||||
w.write_all(b.as_ref()).await?;
|
w.write_all(b.as_ref()).await?;
|
||||||
|
|
|
@ -5,9 +5,7 @@ use std::{
|
||||||
pin::Pin,
|
pin::Pin,
|
||||||
task::{self, ready, Poll},
|
task::{self, ready, Poll},
|
||||||
};
|
};
|
||||||
use tokio::io::{AsyncRead, ReadBuf};
|
use tokio::io::{AsyncRead, AsyncReadExt, ReadBuf};
|
||||||
|
|
||||||
use crate::wire::read_u64;
|
|
||||||
|
|
||||||
use trailer::{read_trailer, ReadTrailer, Trailer};
|
use trailer::{read_trailer, ReadTrailer, Trailer};
|
||||||
mod trailer;
|
mod trailer;
|
||||||
|
@ -52,7 +50,7 @@ where
|
||||||
{
|
{
|
||||||
/// Constructs a new BytesReader, using the underlying passed reader.
|
/// Constructs a new BytesReader, using the underlying passed reader.
|
||||||
pub async fn new<S: RangeBounds<u64>>(mut reader: R, allowed_size: S) -> io::Result<Self> {
|
pub async fn new<S: RangeBounds<u64>>(mut reader: R, allowed_size: S) -> io::Result<Self> {
|
||||||
let size = read_u64(&mut reader).await?;
|
let size = reader.read_u64_le().await?;
|
||||||
|
|
||||||
if !allowed_size.contains(&size) {
|
if !allowed_size.contains(&size) {
|
||||||
return Err(io::Error::new(io::ErrorKind::InvalidData, "invalid size"));
|
return Err(io::Error::new(io::ErrorKind::InvalidData, "invalid size"));
|
||||||
|
|
|
@ -3,6 +3,3 @@
|
||||||
|
|
||||||
mod bytes;
|
mod bytes;
|
||||||
pub use bytes::*;
|
pub use bytes::*;
|
||||||
|
|
||||||
mod primitive;
|
|
||||||
pub use primitive::*;
|
|
||||||
|
|
|
@ -1,74 +0,0 @@
|
||||||
// SPDX-FileCopyrightText: 2023 embr <git@liclac.eu>
|
|
||||||
//
|
|
||||||
// SPDX-License-Identifier: EUPL-1.2
|
|
||||||
|
|
||||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
|
||||||
|
|
||||||
#[allow(dead_code)]
|
|
||||||
/// Read a u64 from the AsyncRead (little endian).
|
|
||||||
pub async fn read_u64<R: AsyncReadExt + Unpin>(r: &mut R) -> std::io::Result<u64> {
|
|
||||||
r.read_u64_le().await
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Write a u64 to the AsyncWrite (little endian).
|
|
||||||
pub async fn write_u64<W: AsyncWrite + Unpin>(w: &mut W, v: u64) -> std::io::Result<()> {
|
|
||||||
w.write_u64_le(v).await
|
|
||||||
}
|
|
||||||
|
|
||||||
#[allow(dead_code)]
|
|
||||||
/// Read a boolean from the AsyncRead, encoded as u64 (>0 is true).
|
|
||||||
pub async fn read_bool<R: AsyncRead + Unpin>(r: &mut R) -> std::io::Result<bool> {
|
|
||||||
Ok(read_u64(r).await? > 0)
|
|
||||||
}
|
|
||||||
|
|
||||||
#[allow(dead_code)]
|
|
||||||
/// Write a boolean to the AsyncWrite, encoded as u64 (>0 is true).
|
|
||||||
pub async fn write_bool<W: AsyncWrite + Unpin>(w: &mut W, v: bool) -> std::io::Result<()> {
|
|
||||||
write_u64(w, if v { 1u64 } else { 0u64 }).await
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(test)]
|
|
||||||
mod tests {
|
|
||||||
use super::*;
|
|
||||||
use tokio_test::io::Builder;
|
|
||||||
|
|
||||||
// Integers.
|
|
||||||
#[tokio::test]
|
|
||||||
async fn test_read_u64() {
|
|
||||||
let mut mock = Builder::new().read(&1234567890u64.to_le_bytes()).build();
|
|
||||||
assert_eq!(1234567890u64, read_u64(&mut mock).await.unwrap());
|
|
||||||
}
|
|
||||||
#[tokio::test]
|
|
||||||
async fn test_write_u64() {
|
|
||||||
let mut mock = Builder::new().write(&1234567890u64.to_le_bytes()).build();
|
|
||||||
write_u64(&mut mock, 1234567890).await.unwrap();
|
|
||||||
}
|
|
||||||
|
|
||||||
// Booleans.
|
|
||||||
#[tokio::test]
|
|
||||||
async fn test_read_bool_0() {
|
|
||||||
let mut mock = Builder::new().read(&0u64.to_le_bytes()).build();
|
|
||||||
assert!(!read_bool(&mut mock).await.unwrap());
|
|
||||||
}
|
|
||||||
#[tokio::test]
|
|
||||||
async fn test_read_bool_1() {
|
|
||||||
let mut mock = Builder::new().read(&1u64.to_le_bytes()).build();
|
|
||||||
assert!(read_bool(&mut mock).await.unwrap());
|
|
||||||
}
|
|
||||||
#[tokio::test]
|
|
||||||
async fn test_read_bool_2() {
|
|
||||||
let mut mock = Builder::new().read(&2u64.to_le_bytes()).build();
|
|
||||||
assert!(read_bool(&mut mock).await.unwrap());
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn test_write_bool_false() {
|
|
||||||
let mut mock = Builder::new().write(&0u64.to_le_bytes()).build();
|
|
||||||
write_bool(&mut mock, false).await.unwrap();
|
|
||||||
}
|
|
||||||
#[tokio::test]
|
|
||||||
async fn test_write_bool_true() {
|
|
||||||
let mut mock = Builder::new().write(&1u64.to_le_bytes()).build();
|
|
||||||
write_bool(&mut mock, true).await.unwrap();
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -4,7 +4,7 @@ use tokio_listener::{self, SystemOptions, UserOptions};
|
||||||
use tracing::{debug, error, info, instrument, Level};
|
use tracing::{debug, error, info, instrument, Level};
|
||||||
|
|
||||||
use nix_compat::worker_protocol::{self, server_handshake_client, ClientSettings, Trust};
|
use nix_compat::worker_protocol::{self, server_handshake_client, ClientSettings, Trust};
|
||||||
use nix_compat::{wire, ProtocolVersion};
|
use nix_compat::ProtocolVersion;
|
||||||
|
|
||||||
#[derive(Parser, Debug)]
|
#[derive(Parser, Debug)]
|
||||||
struct Cli {
|
struct Cli {
|
||||||
|
@ -78,7 +78,9 @@ where
|
||||||
// TODO: implement logging. For now, we'll just send
|
// TODO: implement logging. For now, we'll just send
|
||||||
// STDERR_LAST, which is good enough to get Nix respond to
|
// STDERR_LAST, which is good enough to get Nix respond to
|
||||||
// us.
|
// us.
|
||||||
wire::write_u64(&mut client_connection.conn, worker_protocol::STDERR_LAST)
|
client_connection
|
||||||
|
.conn
|
||||||
|
.write_u64_le(worker_protocol::STDERR_LAST)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
loop {
|
loop {
|
||||||
|
@ -109,6 +111,6 @@ where
|
||||||
let settings = worker_protocol::read_client_settings(&mut conn.conn, conn.version).await?;
|
let settings = worker_protocol::read_client_settings(&mut conn.conn, conn.version).await?;
|
||||||
// The client expects us to send some logs when we're processing
|
// The client expects us to send some logs when we're processing
|
||||||
// the settings. Sending STDERR_LAST signal we're done processing.
|
// the settings. Sending STDERR_LAST signal we're done processing.
|
||||||
wire::write_u64(&mut conn.conn, worker_protocol::STDERR_LAST).await?;
|
conn.conn.write_u64_le(worker_protocol::STDERR_LAST).await?;
|
||||||
Ok(settings)
|
Ok(settings)
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue