diff --git a/src/nix/host/local.rs b/src/nix/host/local.rs index 614bbf1..1a2e8ed 100644 --- a/src/nix/host/local.rs +++ b/src/nix/host/local.rs @@ -1,9 +1,9 @@ use std::convert::TryInto; use std::collections::HashMap; use std::fs; -use std::io::Write; use async_trait::async_trait; +use tokio::fs::OpenOptions; use tokio::process::Command; use tempfile::NamedTempFile; @@ -109,10 +109,11 @@ impl Local { let dest_path = key.dest_dir.join(name); - let mut temp = NamedTempFile::new()?; - temp.write_all(key.text.as_bytes())?; - + let temp = NamedTempFile::new()?; let (_, temp_path) = temp.keep().map_err(|pe| pe.error)?; + let mut reader = key.reader().await?; + let mut writer = OpenOptions::new().write(true).open(&temp_path).await?; + tokio::io::copy(reader.as_mut(), &mut writer).await?; // Well, we need the userspace chmod program to parse the // permission, for NixOps compatibility @@ -149,7 +150,13 @@ impl Local { let parent_dir = dest_path.parent().unwrap(); fs::create_dir_all(parent_dir)?; - fs::rename(temp_path, dest_path)?; + + if fs::rename(&temp_path, &dest_path).is_err() { + // Linux can not rename across different filesystems, try copy-then-remove + let copy_result = fs::copy(&temp_path, &dest_path); + fs::remove_file(&temp_path)?; + copy_result?; + } Ok(()) } diff --git a/src/nix/host/ssh.rs b/src/nix/host/ssh.rs index 2ced2b3..bb04310 100644 --- a/src/nix/host/ssh.rs +++ b/src/nix/host/ssh.rs @@ -214,7 +214,8 @@ impl Ssh { let mut child = command.spawn()?; let mut stdin = child.stdin.take().unwrap(); - stdin.write_all(key.text.as_bytes()).await?; + let mut reader = key.reader().await?; + tokio::io::copy(reader.as_mut(), &mut stdin).await?; stdin.flush().await?; drop(stdin); diff --git a/src/nix/key.rs b/src/nix/key.rs index 58ee2a1..eb8021e 100644 --- a/src/nix/key.rs +++ b/src/nix/key.rs @@ -1,12 +1,18 @@ -use std::path::PathBuf; +use std::{ + io::{self, Cursor}, + path::PathBuf, +}; use regex::Regex; -use serde::{Serialize, Deserialize}; +use serde::{Deserialize, Serialize}; +use tokio::{fs::File, io::AsyncRead}; use validator::{Validate, ValidationError}; #[derive(Debug, Clone, Validate, Serialize, Deserialize)] pub struct Key { - pub(crate) text: String, + pub(crate) text: Option, + #[serde(rename = "keyFile")] + pub(crate) key_file: Option, #[validate(custom = "validate_dest_dir")] #[serde(rename = "destDir")] pub(super) dest_dir: PathBuf, @@ -17,6 +23,18 @@ pub struct Key { pub(super) permissions: String, } +impl Key { + pub(crate) async fn reader(&'_ self,) -> Result, io::Error> { + if let Some(ref t) = self.text { + Ok(Box::new(Cursor::new(t))) + } else if let Some(ref p) = self.key_file { + Ok(Box::new(File::open(p).await?)) + } else { + unreachable!("Neither `text` nor `keyFile` set. This should have been validated by Nix assertions."); + } + } +} + fn validate_unix_name(name: &str) -> Result<(), ValidationError> { let re = Regex::new(r"^[a-z][-a-z0-9]*$").unwrap(); if re.is_match(name) {