refactor(tvix/nix-compat): move worker_protocol into nix_daemon mod

This doesn't have much to do with the plain "wire" format, it's merely
one user of it.

Also, use the more "public" `wire::` API to read/write bytes, strings,
bools and u64s.

Change-Id: I98dddcc3004dfde7a0c009958fe84a840f77b188
Reviewed-on: https://cl.tvl.fyi/c/depot/+/11390
Tested-by: BuildkiteCI
Autosubmit: flokli <flokli@flokli.de>
Reviewed-by: raitobezarius <tvl@lahfa.xyz>
Reviewed-by: Brian Olsen <me@griff.name>
This commit is contained in:
Florian Klink 2024-04-10 15:43:15 +03:00 committed by flokli
parent 36b296609b
commit 742937d55c
5 changed files with 34 additions and 34 deletions

View file

@ -9,3 +9,8 @@ pub mod store_path;
#[cfg(feature = "wire")] #[cfg(feature = "wire")]
pub mod wire; pub mod wire;
#[cfg(feature = "wire")]
mod nix_daemon;
#[cfg(feature = "wire")]
pub use nix_daemon::worker_protocol;

View file

@ -0,0 +1 @@
pub mod worker_protocol;

View file

@ -7,9 +7,7 @@ use enum_primitive_derive::Primitive;
use num_traits::{FromPrimitive, ToPrimitive}; use num_traits::{FromPrimitive, ToPrimitive};
use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::io::{AsyncReadExt, AsyncWriteExt};
use crate::wire::{bytes, primitive}; use crate::wire;
use super::bytes::read_string;
static WORKER_MAGIC_1: u64 = 0x6e697863; // "nixc" static WORKER_MAGIC_1: u64 = 0x6e697863; // "nixc"
static WORKER_MAGIC_2: u64 = 0x6478696f; // "dxio" static WORKER_MAGIC_2: u64 = 0x6478696f; // "dxio"
@ -131,30 +129,30 @@ pub async fn read_client_settings<R: AsyncReadExt + Unpin>(
r: &mut R, r: &mut R,
client_version: u64, client_version: u64,
) -> std::io::Result<ClientSettings> { ) -> std::io::Result<ClientSettings> {
let keep_failed = primitive::read_bool(r).await?; let keep_failed = wire::read_bool(r).await?;
let keep_going = primitive::read_bool(r).await?; let keep_going = wire::read_bool(r).await?;
let try_fallback = primitive::read_bool(r).await?; let try_fallback = wire::read_bool(r).await?;
let verbosity_uint = primitive::read_u64(r).await?; let verbosity_uint = wire::read_u64(r).await?;
let verbosity = Verbosity::from_u64(verbosity_uint).ok_or_else(|| { let verbosity = Verbosity::from_u64(verbosity_uint).ok_or_else(|| {
Error::new( Error::new(
ErrorKind::InvalidData, ErrorKind::InvalidData,
format!("Can't convert integer {} to verbosity", verbosity_uint), format!("Can't convert integer {} to verbosity", verbosity_uint),
) )
})?; })?;
let max_build_jobs = primitive::read_u64(r).await?; let max_build_jobs = wire::read_u64(r).await?;
let max_silent_time = primitive::read_u64(r).await?; let max_silent_time = wire::read_u64(r).await?;
_ = primitive::read_u64(r).await?; // obsolete useBuildHook _ = wire::read_u64(r).await?; // obsolete useBuildHook
let verbose_build = primitive::read_bool(r).await?; let verbose_build = wire::read_bool(r).await?;
_ = primitive::read_u64(r).await?; // obsolete logType _ = wire::read_u64(r).await?; // obsolete logType
_ = primitive::read_u64(r).await?; // obsolete printBuildTrace _ = wire::read_u64(r).await?; // obsolete printBuildTrace
let build_cores = primitive::read_u64(r).await?; let build_cores = wire::read_u64(r).await?;
let use_substitutes = primitive::read_bool(r).await?; let use_substitutes = wire::read_bool(r).await?;
let mut overrides = HashMap::new(); let mut overrides = HashMap::new();
if client_version >= 12 { if client_version >= 12 {
let num_overrides = primitive::read_u64(r).await?; let num_overrides = wire::read_u64(r).await?;
for _ in 0..num_overrides { for _ in 0..num_overrides {
let name = read_string(r, 0..MAX_SETTING_SIZE).await?; let name = wire::read_string(r, 0..MAX_SETTING_SIZE).await?;
let value = read_string(r, 0..MAX_SETTING_SIZE).await?; let value = wire::read_string(r, 0..MAX_SETTING_SIZE).await?;
overrides.insert(name, value); overrides.insert(name, value);
} }
} }
@ -197,17 +195,17 @@ pub async fn server_handshake_client<'a, RW: 'a>(
where where
&'a mut RW: AsyncReadExt + AsyncWriteExt + Unpin, &'a mut RW: AsyncReadExt + AsyncWriteExt + Unpin,
{ {
let worker_magic_1 = primitive::read_u64(&mut conn).await?; let worker_magic_1 = wire::read_u64(&mut conn).await?;
if worker_magic_1 != WORKER_MAGIC_1 { if worker_magic_1 != WORKER_MAGIC_1 {
Err(std::io::Error::new( Err(std::io::Error::new(
ErrorKind::InvalidData, ErrorKind::InvalidData,
format!("Incorrect worker magic number received: {}", worker_magic_1), format!("Incorrect worker magic number received: {}", worker_magic_1),
)) ))
} else { } else {
primitive::write_u64(&mut conn, WORKER_MAGIC_2).await?; wire::write_u64(&mut conn, WORKER_MAGIC_2).await?;
conn.write_all(&PROTOCOL_VERSION).await?; conn.write_all(&PROTOCOL_VERSION).await?;
conn.flush().await?; conn.flush().await?;
let client_version = primitive::read_u64(&mut conn).await?; let client_version = wire::read_u64(&mut conn).await?;
if client_version < 0x10a { if client_version < 0x10a {
return Err(Error::new( return Err(Error::new(
ErrorKind::Unsupported, ErrorKind::Unsupported,
@ -218,20 +216,20 @@ where
let _protocol_major = client_version & 0xff00; let _protocol_major = client_version & 0xff00;
if protocol_minor >= 14 { if protocol_minor >= 14 {
// Obsolete CPU affinity. // Obsolete CPU affinity.
let read_affinity = primitive::read_u64(&mut conn).await?; let read_affinity = wire::read_u64(&mut conn).await?;
if read_affinity != 0 { if read_affinity != 0 {
let _cpu_affinity = primitive::read_u64(&mut conn).await?; let _cpu_affinity = wire::read_u64(&mut conn).await?;
}; };
} }
if protocol_minor >= 11 { if protocol_minor >= 11 {
// Obsolete reserveSpace // Obsolete reserveSpace
let _reserve_space = primitive::read_u64(&mut conn).await?; let _reserve_space = wire::read_u64(&mut conn).await?;
} }
if protocol_minor >= 33 { if protocol_minor >= 33 {
// Nix version. We're plain lying, we're not Nix, but eh… // Nix version. We're plain lying, we're not Nix, but eh…
// Setting it to the 2.3 lineage. Not 100% sure this is a // Setting it to the 2.3 lineage. Not 100% sure this is a
// good idea. // good idea.
bytes::write_bytes(&mut conn, nix_version).await?; wire::write_bytes(&mut conn, nix_version).await?;
conn.flush().await?; conn.flush().await?;
} }
if protocol_minor >= 35 { if protocol_minor >= 35 {
@ -243,7 +241,7 @@ where
/// Read a worker [Operation] from the wire. /// Read a worker [Operation] from the wire.
pub async fn read_op<R: AsyncReadExt + Unpin>(r: &mut R) -> std::io::Result<Operation> { pub async fn read_op<R: AsyncReadExt + Unpin>(r: &mut R) -> std::io::Result<Operation> {
let op_number = primitive::read_u64(r).await?; let op_number = wire::read_u64(r).await?;
Operation::from_u64(op_number).ok_or(Error::new( Operation::from_u64(op_number).ok_or(Error::new(
ErrorKind::InvalidData, ErrorKind::InvalidData,
format!("Invalid OP number {}", op_number), format!("Invalid OP number {}", op_number),
@ -276,8 +274,8 @@ where
W: AsyncReadExt + AsyncWriteExt + Unpin, W: AsyncReadExt + AsyncWriteExt + Unpin,
{ {
match t { match t {
Trust::Trusted => primitive::write_u64(conn, 1).await, Trust::Trusted => wire::write_u64(conn, 1).await,
Trust::NotTrusted => primitive::write_u64(conn, 2).await, Trust::NotTrusted => wire::write_u64(conn, 2).await,
} }
} }

View file

@ -6,5 +6,3 @@ pub use bytes::*;
mod primitive; mod primitive;
pub use primitive::*; pub use primitive::*;
pub mod worker_protocol;

View file

@ -3,10 +3,8 @@ use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio_listener::{self, SystemOptions, UserOptions}; use tokio_listener::{self, SystemOptions, UserOptions};
use tracing::{debug, error, info, instrument, Level}; use tracing::{debug, error, info, instrument, Level};
use nix_compat::wire::{ use nix_compat::wire;
self, use nix_compat::worker_protocol::{self, server_handshake_client, ClientSettings, Trust};
worker_protocol::{self, server_handshake_client, ClientSettings, Trust},
};
#[derive(Parser, Debug)] #[derive(Parser, Debug)]
struct Cli { struct Cli {