diff --git a/tvix/castore/src/import/archive.rs b/tvix/castore/src/import/archive.rs index eab695d65..84e280faf 100644 --- a/tvix/castore/src/import/archive.rs +++ b/tvix/castore/src/import/archive.rs @@ -1,28 +1,45 @@ +use std::io::{Cursor, Write}; +use std::sync::Arc; use std::{collections::HashMap, path::PathBuf}; use petgraph::graph::{DiGraph, NodeIndex}; use petgraph::visit::{DfsPostOrder, EdgeRef}; use petgraph::Direction; use tokio::io::AsyncRead; +use tokio::sync::Semaphore; +use tokio::task::JoinSet; use tokio_stream::StreamExt; use tokio_tar::Archive; +use tokio_util::io::InspectReader; use tracing::{instrument, warn, Level}; use crate::blobservice::BlobService; use crate::directoryservice::DirectoryService; use crate::import::{ingest_entries, Error, IngestionEntry}; use crate::proto::node::Node; +use crate::B3Digest; + +/// Files no larger than this threshold, in bytes, are uploaded to the [BlobService] in the +/// background. +/// +/// This is a u32 since we acquire a weighted semaphore using the size of the blob. +/// [Semaphore::acquire_many_owned] takes a u32, so we need to ensure the size of +/// the blob can be represented using a u32 and will not cause an overflow. +const CONCURRENT_BLOB_UPLOAD_THRESHOLD: u32 = 1024 * 1024; + +/// The maximum number of bytes allowed to be buffered in memory to perform async blob uploads. +const MAX_TARBALL_BUFFER_SIZE: usize = 128 * 1024 * 1024; /// Ingests elements from the given tar [`Archive`] into a the passed [`BlobService`] and /// [`DirectoryService`]. #[instrument(skip_all, ret(level = Level::TRACE), err)] -pub async fn ingest_archive<'a, BS, DS, R>( +pub async fn ingest_archive( blob_service: BS, directory_service: DS, mut archive: Archive, ) -> Result where - BS: AsRef + Clone, + BS: BlobService + Clone + 'static, DS: AsRef, R: AsyncRead + Unpin, { @@ -33,21 +50,80 @@ where // In the first phase, collect up all the regular files and symlinks. 
let mut nodes = IngestionEntryGraph::new(); + let semaphore = Arc::new(Semaphore::new(MAX_TARBALL_BUFFER_SIZE)); + let mut async_blob_uploads: JoinSet> = JoinSet::new(); + let mut entries_iter = archive.entries().map_err(Error::Archive)?; while let Some(mut entry) = entries_iter.try_next().await.map_err(Error::Archive)? { let path: PathBuf = entry.path().map_err(Error::Archive)?.into(); - let entry = match entry.header().entry_type() { + let header = entry.header(); + let entry = match header.entry_type() { tokio_tar::EntryType::Regular | tokio_tar::EntryType::GNUSparse | tokio_tar::EntryType::Continuous => { - // TODO: If the same path is overwritten in the tarball, we may leave - // an unreferenced blob after uploading. - let mut writer = blob_service.as_ref().open_write().await; - let size = tokio::io::copy(&mut entry, &mut writer) - .await - .map_err(Error::Archive)?; - let digest = writer.close().await.map_err(Error::Archive)?; + let header_size = header.size().map_err(Error::Archive)?; + + // If the blob is small enough, read it off the wire, compute the digest, + // and upload it to the [BlobService] in the background. + let (size, digest) = if header_size <= CONCURRENT_BLOB_UPLOAD_THRESHOLD as u64 { + let mut buffer = Vec::with_capacity(header_size as usize); + let mut hasher = blake3::Hasher::new(); + let mut reader = InspectReader::new(&mut entry, |bytes| { + hasher.write_all(bytes).unwrap(); + }); + + // Ensure that we don't buffer into memory until we've acquired a permit. + // This prevents consuming too much memory when performing concurrent + // blob uploads. + let permit = semaphore + .clone() + // This cast is safe because we ensure the header_size is at most + // CONCURRENT_BLOB_UPLOAD_THRESHOLD, which is a u32. 
+ .acquire_many_owned(header_size as u32) + .await + .unwrap(); + let size = tokio::io::copy(&mut reader, &mut buffer) + .await + .map_err(Error::Archive)?; + + let digest: B3Digest = hasher.finalize().as_bytes().into(); + + { + let blob_service = blob_service.clone(); + let digest = digest.clone(); + async_blob_uploads.spawn({ + async move { + let mut writer = blob_service.open_write().await; + + tokio::io::copy(&mut Cursor::new(buffer), &mut writer) + .await + .map_err(Error::Archive)?; + + let blob_digest = writer.close().await.map_err(Error::Archive)?; + + assert_eq!(digest, blob_digest, "Tvix bug: blob digest mismatch"); + + // Make sure we hold the permit until we finish writing the blob + // to the [BlobService]. + drop(permit); + Ok(()) + } + }); + } + + (size, digest) + } else { + let mut writer = blob_service.open_write().await; + + let size = tokio::io::copy(&mut entry, &mut writer) + .await + .map_err(Error::Archive)?; + + let digest = writer.close().await.map_err(Error::Archive)?; + + (size, digest) + }; IngestionEntry::Regular { path, @@ -77,6 +153,10 @@ where nodes.add(entry)?; } + while let Some(result) = async_blob_uploads.join_next().await { + result.expect("task panicked")?; + } + ingest_entries( directory_service, futures::stream::iter(nodes.finalize()?.into_iter().map(Ok)), diff --git a/tvix/glue/src/fetchers.rs b/tvix/glue/src/fetchers.rs index 7981770eb..7560c447d 100644 --- a/tvix/glue/src/fetchers.rs +++ b/tvix/glue/src/fetchers.rs @@ -168,7 +168,7 @@ async fn hash( impl Fetcher where - BS: AsRef<(dyn BlobService + 'static)> + Send + Sync, + BS: AsRef<(dyn BlobService + 'static)> + Clone + Send + Sync + 'static, DS: AsRef<(dyn DirectoryService + 'static)>, PS: PathInfoService, { @@ -242,7 +242,7 @@ where // Ingest the archive, get the root node let node = tvix_castore::import::archive::ingest_archive( - &self.blob_service, + self.blob_service.clone(), &self.directory_service, archive, )