colmena/src/util.rs

370 lines
10 KiB
Rust
Raw Normal View History

use std::convert::TryFrom;
use std::path::PathBuf;
use std::process::Stdio;
2020-12-18 10:27:44 +01:00
use async_trait::async_trait;
use clap::{parser::ValueSource as ClapValueSource, Arg, ArgMatches, Command as ClapCommand};
use futures::future::join3;
use serde::de::DeserializeOwned;
2022-07-30 07:13:09 +02:00
use tokio::io::{AsyncBufReadExt, AsyncRead, BufReader};
use tokio::process::Command;
2020-12-16 05:21:26 +01:00
2022-07-30 07:13:09 +02:00
use super::error::{ColmenaError, ColmenaResult};
use super::job::JobHandle;
2022-07-30 07:13:09 +02:00
use super::nix::deployment::TargetNodeMap;
use super::nix::{Flake, Hive, HivePath, StorePath};
2020-12-18 10:27:44 +01:00
/// Byte value of the line feed character; captured output is split into lines on it.
const NEWLINE: u8 = b'\n';
/// Non-interactive execution of an arbitrary command.
///
/// When run, the command gets a null stdin and piped stdout/stderr. Both output
/// streams are captured line by line and can be read back with `get_logs`.
pub struct CommandExecution {
    /// The command to execute.
    command: Command,
    /// Job that captured output lines are forwarded to, if any.
    job: Option<JobHandle>,
    /// Whether stdout lines are withheld from `job`; they are still captured.
    hide_stdout: bool,
    /// Captured stdout of the last invocation; `None` until `run` has been called.
    stdout: Option<String>,
    /// Captured stderr of the last invocation; `None` until `run` has been called.
    stderr: Option<String>,
}
/// Helper extensions for Commands.
#[async_trait]
pub trait CommandExt {
/// Runs the command with stdout and stderr passed through to the user.
async fn passthrough(&mut self) -> ColmenaResult<()>;
/// Runs the command, capturing the output as a String.
async fn capture_output(&mut self) -> ColmenaResult<String>;
/// Runs the command, capturing deserialized output from JSON.
2022-07-30 07:13:09 +02:00
async fn capture_json<T>(&mut self) -> ColmenaResult<T>
where
T: DeserializeOwned;
/// Runs the command, capturing a single store path.
async fn capture_store_path(&mut self) -> ColmenaResult<StorePath>;
}
impl CommandExecution {
2021-02-10 04:28:45 +01:00
pub fn new(command: Command) -> Self {
Self {
command,
job: None,
hide_stdout: false,
stdout: None,
stderr: None,
}
}
/// Sets the job associated with this execution.
pub fn set_job(&mut self, job: Option<JobHandle>) {
self.job = job;
}
/// Sets whether to hide stdout.
pub fn set_hide_stdout(&mut self, hide_stdout: bool) {
self.hide_stdout = hide_stdout;
}
/// Returns logs from the last invocation.
pub fn get_logs(&self) -> (Option<&String>, Option<&String>) {
(self.stdout.as_ref(), self.stderr.as_ref())
}
/// Runs the command.
pub async fn run(&mut self) -> ColmenaResult<()> {
self.command.stdin(Stdio::null());
self.command.stdout(Stdio::piped());
self.command.stderr(Stdio::piped());
self.stdout = Some(String::new());
self.stderr = Some(String::new());
let mut child = self.command.spawn()?;
let stdout = BufReader::new(child.stdout.take().unwrap());
let stderr = BufReader::new(child.stderr.take().unwrap());
2022-07-30 07:13:09 +02:00
let stdout_job = if self.hide_stdout {
None
} else {
self.job.clone()
};
let futures = join3(
capture_stream(stdout, stdout_job, false),
capture_stream(stderr, self.job.clone(), true),
child.wait(),
);
2021-12-05 10:14:12 +01:00
let (stdout, stderr, wait) = futures.await;
self.stdout = Some(stdout?);
self.stderr = Some(stderr?);
let exit = wait?;
if exit.success() {
Ok(())
} else {
2021-04-29 00:09:40 +02:00
Err(exit.into())
}
}
}
#[async_trait]
impl CommandExt for Command {
/// Runs the command with stdout and stderr passed through to the user.
async fn passthrough(&mut self) -> ColmenaResult<()> {
2022-07-30 07:13:09 +02:00
let exit = self.spawn()?.wait().await?;
if exit.success() {
Ok(())
} else {
Err(exit.into())
}
}
/// Captures output as a String.
async fn capture_output(&mut self) -> ColmenaResult<String> {
// We want the user to see the raw errors
let output = self
.stdout(Stdio::piped())
.stderr(Stdio::inherit())
.spawn()?
.wait_with_output()
.await?;
if output.status.success() {
// FIXME: unwrap
Ok(String::from_utf8(output.stdout).unwrap())
} else {
Err(output.status.into())
}
}
/// Captures deserialized output from JSON.
2022-07-30 07:13:09 +02:00
async fn capture_json<T>(&mut self) -> ColmenaResult<T>
where
T: DeserializeOwned,
{
let output = self.capture_output().await?;
serde_json::from_str(&output).map_err(|_| ColmenaError::BadOutput {
2022-07-30 07:13:09 +02:00
output: output.clone(),
})
}
/// Captures a single store path.
async fn capture_store_path(&mut self) -> ColmenaResult<StorePath> {
let output = self.capture_output().await?;
let path = output.trim_end().to_owned();
StorePath::try_from(path)
}
}
#[async_trait]
impl CommandExt for CommandExecution {
async fn passthrough(&mut self) -> ColmenaResult<()> {
self.run().await
}
/// Captures output as a String.
async fn capture_output(&mut self) -> ColmenaResult<String> {
self.run().await?;
let (stdout, _) = self.get_logs();
Ok(stdout.unwrap().to_owned())
}
/// Captures deserialized output from JSON.
2022-07-30 07:13:09 +02:00
async fn capture_json<T>(&mut self) -> ColmenaResult<T>
where
T: DeserializeOwned,
{
let output = self.capture_output().await?;
serde_json::from_str(&output).map_err(|_| ColmenaError::BadOutput {
2022-07-30 07:13:09 +02:00
output: output.clone(),
})
}
/// Captures a single store path.
async fn capture_store_path(&mut self) -> ColmenaResult<StorePath> {
let output = self.capture_output().await?;
let path = output.trim_end().to_owned();
StorePath::try_from(path)
}
}
pub async fn hive_from_args(args: &ArgMatches) -> ColmenaResult<Hive> {
let path = match args.value_source("config").unwrap() {
ClapValueSource::DefaultValue => {
// traverse upwards until we find hive.nix
let mut cur = std::env::current_dir()?;
let mut file_path = None;
loop {
2021-10-28 23:09:35 +02:00
let flake = cur.join("flake.nix");
if flake.is_file() {
file_path = Some(flake);
break;
}
2021-10-28 23:09:35 +02:00
let legacy = cur.join("hive.nix");
if legacy.is_file() {
file_path = Some(legacy);
break;
}
match cur.parent() {
Some(parent) => {
cur = parent.to_owned();
}
None => {
break;
}
}
}
if file_path.is_none() {
2022-07-30 07:13:09 +02:00
log::error!(
"Could not find `hive.nix` or `flake.nix` in {:?} or any parent directory",
std::env::current_dir()?
);
}
file_path.unwrap()
}
ClapValueSource::CommandLine => {
2022-07-30 07:13:09 +02:00
let path = args
.get_one::<String>("config")
2022-07-30 07:13:09 +02:00
.expect("The config arg should exist")
.to_owned();
let fpath = PathBuf::from(&path);
2021-11-23 22:33:23 +01:00
if !fpath.exists() && path.contains(':') {
// Treat as flake URI
let flake = Flake::from_uri(path).await?;
log::info!("Using flake: {}", flake.uri());
let hive_path = HivePath::Flake(flake);
let mut hive = Hive::new(hive_path).await?;
if args.get_flag("show-trace") {
hive.set_show_trace(true);
}
if args.get_flag("impure") {
2022-08-17 04:15:43 +02:00
hive.set_impure(true);
}
return Ok(hive);
}
fpath
}
x => panic!("Unexpected value source for config: {:?}", x),
};
let hive_path = HivePath::from_path(path).await?;
2021-12-04 10:03:26 +01:00
match &hive_path {
HivePath::Legacy(p) => {
log::info!("Using configuration: {}", p.to_string_lossy());
}
HivePath::Flake(flake) => {
log::info!("Using flake: {}", flake.uri());
}
}
let mut hive = Hive::new(hive_path).await?;
if args.get_flag("show-trace") {
hive.set_show_trace(true);
}
if args.get_flag("impure") {
2022-08-17 04:15:43 +02:00
hive.set_impure(true);
}
Ok(hive)
}
2022-03-08 07:02:04 +01:00
pub fn register_selector_args(command: ClapCommand) -> ClapCommand {
2020-12-29 06:35:43 +01:00
command
2022-01-03 19:37:03 +01:00
.arg(Arg::new("on")
2020-12-16 05:21:26 +01:00
.long("on")
2021-01-02 05:45:41 +01:00
.value_name("NODES")
.help("Node selector")
.long_help(r#"Select a list of nodes to deploy to.
The list is comma-separated and globs are supported. To match tags, prepend the filter by @. Valid examples:
2020-12-16 05:21:26 +01:00
- host1,host2,host3
- edge-*
2020-12-18 10:27:44 +01:00
- edge-*,core-*
- @a-tag,@tags-can-have-*"#)
.num_args(1))
2020-12-16 05:21:26 +01:00
}
2022-07-30 07:13:09 +02:00
pub async fn capture_stream<R>(
mut stream: BufReader<R>,
job: Option<JobHandle>,
stderr: bool,
) -> ColmenaResult<String>
where
R: AsyncRead + Unpin,
2021-12-05 10:14:12 +01:00
{
let mut log = String::new();
loop {
let mut line = Vec::new();
let len = stream.read_until(NEWLINE, &mut line).await?;
let line = String::from_utf8_lossy(&line);
if len == 0 {
break;
}
let trimmed = line.trim_end();
if let Some(job) = &job {
if stderr {
2021-12-05 10:14:12 +01:00
job.stderr(trimmed.to_string())?;
} else {
2021-12-05 10:14:12 +01:00
job.stdout(trimmed.to_string())?;
}
}
log += trimmed;
log += "\n";
}
2021-12-05 10:14:12 +01:00
Ok(log)
}
/// Returns the length of the longest node name in `targets`, or `None` when
/// `targets` is empty.
pub fn get_label_width(targets: &TargetNodeMap) -> Option<usize> {
    let mut widest: Option<usize> = None;
    for name in targets.keys() {
        let len = name.len();
        widest = Some(widest.map_or(len, |current| current.max(len)));
    }
    widest
}
#[cfg(test)]
mod tests {
    use super::*;

    use tokio::io::BufReader;
    use tokio_test::block_on;

    /// Runs `capture_stream` over `input` with no job attached.
    fn capture(input: &[u8]) -> String {
        block_on(async { capture_stream(BufReader::new(input), None, false).await.unwrap() })
    }

    #[test]
    fn test_capture_stream() {
        let expected = "Hello\nWorld\n";
        assert_eq!(expected, capture(expected.as_bytes()));
    }

    #[test]
    fn test_capture_stream_with_invalid_utf8() {
        assert_eq!("\u{fffd}\n", capture(&[0x80, 0xa]));
    }

    #[test]
    fn test_capture_stream_empty() {
        assert_eq!("", capture(b""));
    }

    #[test]
    fn test_capture_stream_without_trailing_newline() {
        // The final unterminated line still gets a newline in the log.
        assert_eq!("Hello\nWorld\n", capture(b"Hello\nWorld"));
    }

    #[test]
    fn test_capture_stream_trims_trailing_whitespace() {
        // CRLF endings and trailing spaces/tabs are stripped from each line.
        assert_eq!("a\nb\n", capture(b"a \r\nb\t\n"));
    }
}