diff --git a/Cargo.lock b/Cargo.lock index 49936d5..e20f5c2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -171,6 +171,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "34d21f9bf1b425d2968943631ec91202fe5e837264063503708b83013f8fc938" dependencies = [ "clap_builder", + "clap_derive", + "once_cell", ] [[package]] @@ -195,6 +197,18 @@ dependencies = [ "clap", ] +[[package]] +name = "clap_derive" +version = "4.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f9644cd56d6b87dbe899ef8b053e331c0637664e9e21a33dfcdc36093f5c5c4" +dependencies = [ + "heck", + "proc-macro2", + "quote", + "syn 2.0.16", +] + [[package]] name = "clap_lex" version = "0.4.1" diff --git a/Cargo.toml b/Cargo.toml index d358f2e..7699ca2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,7 +10,7 @@ edition = "2021" async-stream = "0.3.5" async-trait = "0.1.68" atty = "0.2" -clap = "4.2.7" +clap = { version = "4.2.7", features = ["derive"] } clap_complete = "4.2.3" clicolors-control = "1" console = "0.15.5" diff --git a/src/command/apply.rs b/src/command/apply.rs index 626022a..ba58e2b 100644 --- a/src/command/apply.rs +++ b/src/command/apply.rs @@ -4,16 +4,16 @@ use std::str::FromStr; use clap::{ builder::{ArgPredicate, PossibleValuesParser, ValueParser}, - value_parser, Arg, ArgMatches, Command as ClapCommand, + value_parser, Arg, ArgMatches, Command as ClapCommand, FromArgMatches, }; -use crate::error::ColmenaError; use crate::nix::deployment::{ Deployment, EvaluationNodeLimit, EvaluatorType, Goal, Options, ParallelismLimit, }; use crate::nix::NodeFilter; use crate::progress::SimpleProgressOutput; use crate::util; +use crate::{error::ColmenaError, nix::hive::HiveArgs}; pub fn register_deploy_args(command: ClapCommand) -> ClapCommand { command @@ -160,7 +160,11 @@ Same as the targets for switch-to-configuration, with the following extra pseudo } pub async fn run(_global_args: &ArgMatches, local_args: &ArgMatches) -> Result<(), ColmenaError> { - let hive = util::hive_from_args(local_args).await?; + let hive = HiveArgs::from_arg_matches(local_args) + .unwrap() + .into_hive() + .await + .unwrap(); let ssh_config = env::var("SSH_CONFIG_FILE").ok().map(PathBuf::from); diff --git a/src/command/apply_local.rs b/src/command/apply_local.rs index 6111e98..5b01ebd 100644 --- a/src/command/apply_local.rs +++ b/src/command/apply_local.rs @@ -2,14 +2,17 @@ use regex::Regex; use std::collections::HashMap; use std::str::FromStr; -use clap::{builder::PossibleValuesParser, Arg, ArgMatches, Command as ClapCommand}; +use clap::{ + builder::PossibleValuesParser, Arg, ArgMatches, Command as ClapCommand, FromArgMatches, +}; use tokio::fs; use crate::error::ColmenaError; + use crate::nix::deployment::{Deployment, Goal, Options, TargetNode}; +use crate::nix::hive::HiveArgs; use crate::nix::{host::Local as LocalHost, NodeName}; use crate::progress::SimpleProgressOutput; -use crate::util; pub fn subcommand() -> ClapCommand { ClapCommand::new("apply-local") @@ -89,7 +92,11 @@ pub async fn run(_global_args: &ArgMatches, local_args: &ArgMatches) -> Result<( } } - let hive = util::hive_from_args(local_args).await.unwrap(); + let hive = HiveArgs::from_arg_matches(local_args) + .unwrap() + .into_hive() + .await + .unwrap(); let hostname = { let s = if local_args.contains_id("node") { local_args.get_one::("node").unwrap().to_owned() diff --git a/src/command/eval.rs b/src/command/eval.rs index 0fc271a..0667444 100644 --- a/src/command/eval.rs +++ b/src/command/eval.rs @@ -1,9 +1,9 @@ use std::path::PathBuf; -use clap::{value_parser, Arg, ArgMatches, Command as ClapCommand}; +use clap::{value_parser, Arg, ArgMatches, Command as ClapCommand, FromArgMatches}; use crate::error::ColmenaError; -use crate::util; +use crate::nix::hive::HiveArgs; pub fn subcommand() -> ClapCommand { subcommand_gen("eval") @@ -48,7 +48,11 @@ pub async fn run(global_args: &ArgMatches, local_args: &ArgMatches) -> Result<() ); } - let hive = util::hive_from_args(local_args).await?; + let hive = HiveArgs::from_arg_matches(local_args) + .unwrap() + .into_hive() + .await + .unwrap(); if !(local_args.contains_id("expression") ^ local_args.contains_id("expression_file")) { log::error!("Either an expression (-E) or a .nix file containing an expression should be specified, not both."); diff --git a/src/command/exec.rs b/src/command/exec.rs index b723f6b..05131c5 100644 --- a/src/command/exec.rs +++ b/src/command/exec.rs @@ -2,12 +2,13 @@ use std::env; use std::path::PathBuf; use std::sync::Arc; -use clap::{value_parser, Arg, ArgMatches, Command as ClapCommand}; +use clap::{value_parser, Arg, ArgMatches, Command as ClapCommand, FromArgMatches}; use futures::future::join_all; use tokio::sync::Semaphore; use crate::error::ColmenaError; use crate::job::{JobMonitor, JobState, JobType}; +use crate::nix::hive::HiveArgs; use crate::nix::NodeFilter; use crate::progress::SimpleProgressOutput; use crate::util; @@ -60,7 +61,11 @@ It's recommended to use -- to separate Colmena options from the command to run. } pub async fn run(_global_args: &ArgMatches, local_args: &ArgMatches) -> Result<(), ColmenaError> { - let hive = util::hive_from_args(local_args).await?; + let hive = HiveArgs::from_arg_matches(local_args) + .unwrap() + .into_hive() + .await + .unwrap(); let ssh_config = env::var("SSH_CONFIG_FILE").ok().map(PathBuf::from); // FIXME: Just get_one:: diff --git a/src/command/repl.rs b/src/command/repl.rs index 5e1e93e..e9c58ad 100644 --- a/src/command/repl.rs +++ b/src/command/repl.rs @@ -1,12 +1,12 @@ use std::io::Write; -use clap::{ArgMatches, Command as ClapCommand}; +use clap::{ArgMatches, Command as ClapCommand, FromArgMatches}; use tempfile::Builder as TempFileBuilder; use tokio::process::Command; use crate::error::ColmenaError; +use crate::nix::hive::HiveArgs; use crate::nix::info::NixCheck; -use crate::util; pub fn subcommand() -> ClapCommand { ClapCommand::new("repl") @@ -24,7 +24,11 @@ pub async fn run(_global_args: &ArgMatches, local_args: &ArgMatches) -> Result<( let nix_check = NixCheck::detect().await; let nix_version = nix_check.version().expect("Could not detect Nix version"); - let hive = util::hive_from_args(local_args).await?; + let hive = HiveArgs::from_arg_matches(local_args) + .unwrap() + .into_hive() + .await + .unwrap(); let expr = hive.get_repl_expression(); diff --git a/src/nix/flake.rs b/src/nix/flake.rs index 4e7f72d..5f410ea 100644 --- a/src/nix/flake.rs +++ b/src/nix/flake.rs @@ -53,10 +53,10 @@ impl Flake { } /// Creates a flake from a Flake URI. - pub async fn from_uri(uri: String) -> ColmenaResult { + pub async fn from_uri(uri: impl AsRef) -> ColmenaResult { NixCheck::require_flake_support().await?; - let metadata = FlakeMetadata::resolve(&uri).await?; + let metadata = FlakeMetadata::resolve(uri.as_ref()).await?; Ok(Self { metadata, diff --git a/src/nix/hive/mod.rs b/src/nix/hive/mod.rs index 8831cc8..d8a897b 100644 --- a/src/nix/hive/mod.rs +++ b/src/nix/hive/mod.rs @@ -6,7 +6,9 @@ mod tests; use std::collections::HashMap; use std::convert::AsRef; use std::path::{Path, PathBuf}; +use std::str::FromStr; +use clap::Args; use tokio::process::Command; use tokio::sync::OnceCell; use validator::Validate; @@ -16,11 +18,93 @@ use super::{ Flake, MetaConfig, NixExpression, NixFlags, NodeConfig, NodeFilter, NodeName, ProfileDerivation, SerializedNixExpression, StorePath, }; -use crate::error::ColmenaResult; +use crate::error::{ColmenaError, ColmenaResult}; use crate::job::JobHandle; use crate::util::{CommandExecution, CommandExt}; use assets::Assets; +#[derive(Debug, Args)] +pub struct HiveArgs { + #[arg(short = 'f', long, value_name = "CONFIG")] + config: Option, + #[arg(long)] + show_trace: bool, + #[arg(long)] + impure: bool, + #[arg(long, value_parser = crate::util::parse_key_val::)] + nix_option: Vec<(String, String)>, +} + +impl HiveArgs { + pub async fn into_hive(self) -> ColmenaResult { + let path = match self.config { + Some(path) => path, + None => { + // traverse upwards until we find hive.nix + let mut cur = std::env::current_dir()?; + let mut file_path = None; + + loop { + let flake = cur.join("flake.nix"); + if flake.is_file() { + file_path = Some(flake); + break; + } + + let legacy = cur.join("hive.nix"); + if legacy.is_file() { + file_path = Some(legacy); + break; + } + + match cur.parent() { + Some(parent) => { + cur = parent.to_owned(); + } + None => { + break; + } + } + } + + if file_path.is_none() { + log::error!( + "Could not find `hive.nix` or `flake.nix` in {:?} or any parent directory", + std::env::current_dir()? + ); + } + + HivePath::from_path(file_path.unwrap()).await? + } + }; + + match &path { + HivePath::Legacy(p) => { + log::info!("Using configuration: {}", p.to_string_lossy()); + } + HivePath::Flake(flake) => { + log::info!("Using flake: {}", flake.uri()); + } + } + + let mut hive = Hive::new(path).await?; + + if self.show_trace { + hive.set_show_trace(true); + } + + if self.impure { + hive.set_impure(true); + } + + for (name, value) in self.nix_option { + hive.add_nix_option(name, value); + } + + Ok(hive) + } +} + #[derive(Debug, Clone)] pub enum HivePath { /// A Nix Flake. @@ -32,6 +116,29 @@ pub enum HivePath { Legacy(PathBuf), } +impl FromStr for HivePath { + type Err = ColmenaError; + + fn from_str(s: &str) -> Result { + // TODO: check for escaped colon maybe? + + let path = std::path::Path::new(s); + let handle = tokio::runtime::Handle::try_current() + .expect("We should always be executed after we have a runtime"); + + if !path.exists() && s.contains(':') { + // Treat as flake URI + let flake = handle.block_on(Flake::from_uri(s))?; + + log::info!("Using flake: {}", flake.uri()); + + Ok(Self::Flake(flake)) + } else { + handle.block_on(HivePath::from_path(path)) + } + } +} + #[derive(Debug)] pub struct Hive { /// Path to the hive. diff --git a/src/util.rs b/src/util.rs index e131788..93fee77 100644 --- a/src/util.rs +++ b/src/util.rs @@ -1,9 +1,10 @@ use std::convert::TryFrom; -use std::path::PathBuf; +use std::error::Error; + use std::process::Stdio; use async_trait::async_trait; -use clap::{parser::ValueSource as ClapValueSource, Arg, ArgMatches, Command as ClapCommand}; +use clap::{Arg, Command as ClapCommand}; use futures::future::join3; use serde::de::DeserializeOwned; use tokio::io::{AsyncBufReadExt, AsyncRead, BufReader}; @@ -12,7 +13,7 @@ use tokio::process::Command; use super::error::{ColmenaError, ColmenaResult}; use super::job::JobHandle; use super::nix::deployment::TargetNodeMap; -use super::nix::{Flake, Hive, HivePath, StorePath}; +use super::nix::StorePath; const NEWLINE: u8 = 0xa; @@ -192,104 +193,18 @@ impl CommandExt for CommandExecution { } } -pub async fn hive_from_args(args: &ArgMatches) -> ColmenaResult { - let path = match args.value_source("config").unwrap() { - ClapValueSource::DefaultValue => { - // traverse upwards until we find hive.nix - let mut cur = std::env::current_dir()?; - let mut file_path = None; - - loop { - let flake = cur.join("flake.nix"); - if flake.is_file() { - file_path = Some(flake); - break; - } - - let legacy = cur.join("hive.nix"); - if legacy.is_file() { - file_path = Some(legacy); - break; - } - - match cur.parent() { - Some(parent) => { - cur = parent.to_owned(); - } - None => { - break; - } - } - } - - if file_path.is_none() { - log::error!( - "Could not find `hive.nix` or `flake.nix` in {:?} or any parent directory", - std::env::current_dir()? - ); - } - - file_path.unwrap() - } - ClapValueSource::CommandLine => { - let path = args - .get_one::("config") - .expect("The config arg should exist") - .to_owned(); - let fpath = PathBuf::from(&path); - - if !fpath.exists() && path.contains(':') { - // Treat as flake URI - let flake = Flake::from_uri(path).await?; - log::info!("Using flake: {}", flake.uri()); - - let hive_path = HivePath::Flake(flake); - - return hive_from_path(hive_path, args).await; - } - - fpath - } - x => panic!("Unexpected value source for config: {:?}", x), - }; - - let hive_path = HivePath::from_path(path).await?; - - hive_from_path(hive_path, args).await -} - -pub async fn hive_from_path(hive_path: HivePath, args: &ArgMatches) -> ColmenaResult { - match &hive_path { - HivePath::Legacy(p) => { - log::info!("Using configuration: {}", p.to_string_lossy()); - } - HivePath::Flake(flake) => { - log::info!("Using flake: {}", flake.uri()); - } - } - - let mut hive = Hive::new(hive_path).await?; - - if args.get_flag("show-trace") { - hive.set_show_trace(true); - } - - if args.get_flag("impure") { - hive.set_impure(true); - } - - if let Some(opts) = args.get_many::("nix-option") { - let iter = opts; - - let names = iter.clone().step_by(2); - let values = iter.clone().skip(1).step_by(2); - - for (name, value) in names.zip(values) { - hive.add_nix_option(name.to_owned(), value.to_owned()); - } - } - - Ok(hive) +/// Parse a single key-value pair +pub fn parse_key_val(s: &str) -> Result<(T, U), Box> +where + T: std::str::FromStr, + T::Err: Error + Send + Sync + 'static, + U: std::str::FromStr, + U::Err: Error + Send + Sync + 'static, +{ + let pos = s + .find('=') + .ok_or_else(|| format!("invalid KEY=value: no `=` found in `{s}`"))?; + Ok((s[..pos].parse()?, s[pos + 1..].parse()?)) } pub fn register_selector_args(command: ClapCommand) -> ClapCommand {