use std::convert::TryFrom; use std::error::Error; use std::process::Stdio; use async_trait::async_trait; use clap::{Arg, Command as ClapCommand}; use futures::future::join3; use serde::de::DeserializeOwned; use tokio::io::{AsyncBufReadExt, AsyncRead, BufReader}; use tokio::process::Command; use super::error::{ColmenaError, ColmenaResult}; use super::job::JobHandle; use super::nix::deployment::TargetNodeMap; use super::nix::StorePath; const NEWLINE: u8 = 0xa; /// Non-interactive execution of an arbitrary command. pub struct CommandExecution { command: Command, job: Option, hide_stdout: bool, stdout: Option, stderr: Option, } /// Helper extensions for Commands. #[async_trait] pub trait CommandExt { /// Runs the command with stdout and stderr passed through to the user. async fn passthrough(&mut self) -> ColmenaResult<()>; /// Runs the command, capturing the output as a String. async fn capture_output(&mut self) -> ColmenaResult; /// Runs the command, capturing deserialized output from JSON. async fn capture_json(&mut self) -> ColmenaResult where T: DeserializeOwned; /// Runs the command, capturing a single store path. async fn capture_store_path(&mut self) -> ColmenaResult; } impl CommandExecution { pub fn new(command: Command) -> Self { Self { command, job: None, hide_stdout: false, stdout: None, stderr: None, } } /// Sets the job associated with this execution. pub fn set_job(&mut self, job: Option) { self.job = job; } /// Sets whether to hide stdout. pub fn set_hide_stdout(&mut self, hide_stdout: bool) { self.hide_stdout = hide_stdout; } /// Returns logs from the last invocation. pub fn get_logs(&self) -> (Option<&String>, Option<&String>) { (self.stdout.as_ref(), self.stderr.as_ref()) } /// Runs the command. pub async fn run(&mut self) -> ColmenaResult<()> { self.command.stdin(Stdio::null()); self.command.stdout(Stdio::piped()); self.command.stderr(Stdio::piped()); self.stdout = Some(String::new()); self.stderr = Some(String::new()); let mut child = self.command.spawn()?; let stdout = BufReader::new(child.stdout.take().unwrap()); let stderr = BufReader::new(child.stderr.take().unwrap()); let stdout_job = if self.hide_stdout { None } else { self.job.clone() }; let futures = join3( capture_stream(stdout, stdout_job, false), capture_stream(stderr, self.job.clone(), true), child.wait(), ); let (stdout, stderr, wait) = futures.await; self.stdout = Some(stdout?); self.stderr = Some(stderr?); let exit = wait?; if exit.success() { Ok(()) } else { Err(exit.into()) } } } #[async_trait] impl CommandExt for Command { /// Runs the command with stdout and stderr passed through to the user. async fn passthrough(&mut self) -> ColmenaResult<()> { let exit = self.spawn()?.wait().await?; if exit.success() { Ok(()) } else { Err(exit.into()) } } /// Captures output as a String. async fn capture_output(&mut self) -> ColmenaResult { // We want the user to see the raw errors let output = self .stdout(Stdio::piped()) .stderr(Stdio::inherit()) .spawn()? .wait_with_output() .await?; if output.status.success() { // FIXME: unwrap Ok(String::from_utf8(output.stdout).unwrap()) } else { Err(output.status.into()) } } /// Captures deserialized output from JSON. async fn capture_json(&mut self) -> ColmenaResult where T: DeserializeOwned, { let output = self.capture_output().await?; serde_json::from_str(&output).map_err(|_| ColmenaError::BadOutput { output: output.clone(), }) } /// Captures a single store path. async fn capture_store_path(&mut self) -> ColmenaResult { let output = self.capture_output().await?; let path = output.trim_end().to_owned(); StorePath::try_from(path) } } #[async_trait] impl CommandExt for CommandExecution { async fn passthrough(&mut self) -> ColmenaResult<()> { self.run().await } /// Captures output as a String. async fn capture_output(&mut self) -> ColmenaResult { self.run().await?; let (stdout, _) = self.get_logs(); Ok(stdout.unwrap().to_owned()) } /// Captures deserialized output from JSON. async fn capture_json(&mut self) -> ColmenaResult where T: DeserializeOwned, { let output = self.capture_output().await?; serde_json::from_str(&output).map_err(|_| ColmenaError::BadOutput { output: output.clone(), }) } /// Captures a single store path. async fn capture_store_path(&mut self) -> ColmenaResult { let output = self.capture_output().await?; let path = output.trim_end().to_owned(); StorePath::try_from(path) } } /// Parse a single key-value pair pub fn parse_key_val(s: &str) -> Result<(T, U), Box> where T: std::str::FromStr, T::Err: Error + Send + Sync + 'static, U: std::str::FromStr, U::Err: Error + Send + Sync + 'static, { let pos = s .find('=') .ok_or_else(|| format!("invalid KEY=value: no `=` found in `{s}`"))?; Ok((s[..pos].parse()?, s[pos + 1..].parse()?)) } pub fn register_selector_args(command: ClapCommand) -> ClapCommand { command .arg(Arg::new("on") .long("on") .value_name("NODES") .help("Node selector") .long_help(r#"Select a list of nodes to deploy to. The list is comma-separated and globs are supported. To match tags, prepend the filter by @. Valid examples: - host1,host2,host3 - edge-* - edge-*,core-* - @a-tag,@tags-can-have-*"#) .num_args(1)) } pub async fn capture_stream( mut stream: BufReader, job: Option, stderr: bool, ) -> ColmenaResult where R: AsyncRead + Unpin, { let mut log = String::new(); loop { let mut line = Vec::new(); let len = stream.read_until(NEWLINE, &mut line).await?; let line = String::from_utf8_lossy(&line); if len == 0 { break; } let trimmed = line.trim_end(); if let Some(job) = &job { if stderr { job.stderr(trimmed.to_string())?; } else { job.stdout(trimmed.to_string())?; } } log += trimmed; log += "\n"; } Ok(log) } pub fn get_label_width(targets: &TargetNodeMap) -> Option { targets.keys().map(|n| n.len()).max() } #[cfg(test)] mod tests { use super::*; use tokio::io::BufReader; use tokio_test::block_on; #[test] fn test_capture_stream() { let expected = "Hello\nWorld\n"; let stream = BufReader::new(expected.as_bytes()); let captured = block_on(async { capture_stream(stream, None, false).await.unwrap() }); assert_eq!(expected, captured); } #[test] fn test_capture_stream_with_invalid_utf8() { let stream = BufReader::new([0x80, 0xa].as_slice()); let captured = block_on(async { capture_stream(stream, None, false).await.unwrap() }); assert_eq!("\u{fffd}\n", captured); } }