2024-06-17 01:10:55 +02:00
|
|
|
//! The composition module allows composing different kinds of services based on a set of service
|
|
|
|
//! configurations _at runtime_.
|
|
|
|
//!
|
|
|
|
//! Store configs are deserialized with serde. The registry provides a stateful mapping from the
|
|
|
|
//! `type` tag of an internally tagged enum on the serde side to a Config struct which is
|
|
|
|
//! deserialized and then returned as a `Box<dyn ServiceBuilder<Output = dyn BlobService>>`
|
|
|
|
//! (the same for DirectoryService instead of BlobService etc).
|
|
|
|
//!
|
|
|
|
//! ### Example 1.: Implementing a new BlobService
|
|
|
|
//!
|
|
|
|
//! You need a Config struct which implements `DeserializeOwned` and
|
|
|
|
//! `ServiceBuilder<Output = dyn BlobService>`.
|
|
|
|
//! Provide the user with a function to call with
|
|
|
|
//! their registry. You register your new type as:
|
|
|
|
//!
|
|
|
|
//! ```
|
|
|
|
//! use std::sync::Arc;
|
|
|
|
//!
|
|
|
|
//! use tvix_castore::composition::*;
|
|
|
|
//! use tvix_castore::blobservice::BlobService;
|
|
|
|
//!
|
|
|
|
//! #[derive(serde::Deserialize)]
|
|
|
|
//! struct MyBlobServiceConfig {
|
|
|
|
//! }
|
|
|
|
//!
|
|
|
|
//! #[tonic::async_trait]
|
|
|
|
//! impl ServiceBuilder for MyBlobServiceConfig {
|
|
|
|
//! type Output = dyn BlobService;
|
2024-07-19 10:51:05 +02:00
|
|
|
//! async fn build(&self, _: &str, _: &CompositionContext) -> Result<Arc<Self::Output>, Box<dyn std::error::Error + Send + Sync + 'static>> {
|
2024-06-17 01:10:55 +02:00
|
|
|
//! todo!()
|
|
|
|
//! }
|
|
|
|
//! }
|
|
|
|
//!
|
2024-07-18 19:09:07 +02:00
|
|
|
//! impl TryFrom<url::Url> for MyBlobServiceConfig {
|
|
|
|
//! type Error = Box<dyn std::error::Error + Send + Sync>;
|
|
|
|
//! fn try_from(url: url::Url) -> Result<Self, Self::Error> {
|
|
|
|
//! todo!()
|
|
|
|
//! }
|
|
|
|
//! }
|
|
|
|
//!
|
2024-06-17 01:10:55 +02:00
|
|
|
//! pub fn add_my_service(reg: &mut Registry) {
|
|
|
|
//! reg.register::<Box<dyn ServiceBuilder<Output = dyn BlobService>>, MyBlobServiceConfig>("myblobservicetype");
|
|
|
|
//! }
|
|
|
|
//! ```
|
|
|
|
//!
|
|
|
|
//! Now, when a user deserializes a store config with the type tag "myblobservicetype" into a
|
|
|
|
//! `Box<dyn ServiceBuilder<Output = dyn BlobService>>`, it will be done via `MyBlobServiceConfig`.
|
|
|
|
//!
|
|
|
|
//! ### Example 2.: Composing stores to get one store
|
|
|
|
//!
|
|
|
|
//! ```
|
2024-07-19 10:51:05 +02:00
|
|
|
//! use std::sync::Arc;
|
2024-06-17 01:10:55 +02:00
|
|
|
//! use tvix_castore::composition::*;
|
|
|
|
//! use tvix_castore::blobservice::BlobService;
|
|
|
|
//!
|
|
|
|
//! # fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|
|
|
//! # tokio::runtime::Builder::new_current_thread().enable_all().build().unwrap().block_on(async move {
|
|
|
|
//! let blob_services_configs_json = serde_json::json!({
|
|
|
|
//! "blobstore1": {
|
|
|
|
//! "type": "memory"
|
|
|
|
//! },
|
|
|
|
//! "blobstore2": {
|
|
|
|
//! "type": "memory"
|
|
|
|
//! },
|
|
|
|
//! "default": {
|
|
|
|
//! "type": "combined",
|
|
|
|
//! "local": "blobstore1",
|
|
|
|
//! "remote": "blobstore2"
|
|
|
|
//! }
|
|
|
|
//! });
|
|
|
|
//!
|
|
|
|
//! let blob_services_configs = with_registry(&REG, || serde_json::from_value(blob_services_configs_json))?;
|
2024-07-19 10:51:05 +02:00
|
|
|
//! let mut blob_service_composition = Composition::default();
|
|
|
|
//! blob_service_composition.extend_with_configs::<dyn BlobService>(blob_services_configs);
|
|
|
|
//! let blob_service: Arc<dyn BlobService> = blob_service_composition.build("default").await?;
|
2024-06-17 01:10:55 +02:00
|
|
|
//! # Ok(())
|
|
|
|
//! # })
|
|
|
|
//! # }
|
|
|
|
//! ```
|
|
|
|
//!
|
|
|
|
//! ### Example 3.: Creating another registry extending the default registry with third-party types
|
|
|
|
//!
|
|
|
|
//! ```
|
|
|
|
//! # pub fn add_my_service(reg: &mut tvix_castore::composition::Registry) {}
|
|
|
|
//! let mut my_registry = tvix_castore::composition::Registry::default();
|
|
|
|
//! tvix_castore::composition::add_default_services(&mut my_registry);
|
|
|
|
//! add_my_service(&mut my_registry);
|
|
|
|
//! ```
|
|
|
|
//!
|
|
|
|
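//! Continue with Example 2, using `my_registry` instead of `REG`. Note that `with_registry`
//! takes a `&'static Registry`, so `my_registry` has to be made available with a `'static`
//! lifetime first (for example by placing it in a static).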
|
|
|
|
|
|
|
|
use erased_serde::deserialize;
|
|
|
|
use futures::future::BoxFuture;
|
|
|
|
use futures::FutureExt;
|
|
|
|
use lazy_static::lazy_static;
|
|
|
|
use serde::de::DeserializeOwned;
|
|
|
|
use serde_tagged::de::{BoxFnSeed, SeedFactory};
|
|
|
|
use serde_tagged::util::TagString;
|
|
|
|
use std::any::{Any, TypeId};
|
|
|
|
use std::cell::Cell;
|
|
|
|
use std::collections::BTreeMap;
|
|
|
|
use std::collections::HashMap;
|
|
|
|
use std::marker::PhantomData;
|
|
|
|
use std::sync::Arc;
|
|
|
|
use tonic::async_trait;
|
|
|
|
|
|
|
|
/// Resolves tag names to the corresponding Config type.
///
/// Entries are keyed by the `TypeId` of the target type together with the type tag.
//
// Notes on the (admittedly ugly) representation:
//
// What we would really like is one generic static per target type:
//
// ```
// struct Registry<T>(BTreeMap<&'static str, RegistryEntry<T>>);
// static REG<T>: Registry<T>;
// ```
//
// with one instance of the static generated for every type the registry is used with.
// That cannot be expressed: generics only exist on functions, and even there they do not
// interact with static items, see
// https://doc.rust-lang.org/reference/items/static-items.html#statics--generics
//
// The lookup is therefore done at runtime by making the TypeId part of the key.
// Since the map itself has no generic parameter `T`, a `RegistryEntry<T>` cannot be
// stored directly; it is stored type-erased as `Box<dyn Any + Sync>` and downcast back to
// `&RegistryEntry<T>` when it is looked up.
#[derive(Default)]
pub struct Registry(BTreeMap<(TypeId, &'static str), Box<dyn Any + Sync>>);
|
2024-07-18 19:09:07 +02:00
|
|
|
pub type FromUrlSeed<T> =
|
|
|
|
Box<dyn Fn(url::Url) -> Result<T, Box<dyn std::error::Error + Send + Sync>> + Sync>;
|
|
|
|
pub struct RegistryEntry<T> {
|
|
|
|
serde_deserialize_seed: BoxFnSeed<DeserializeWithRegistry<T>>,
|
|
|
|
from_url_seed: FromUrlSeed<DeserializeWithRegistry<T>>,
|
|
|
|
}
|
2024-06-17 01:10:55 +02:00
|
|
|
|
|
|
|
/// A borrowed [`Registry`] paired with the target type `T` as a type-level marker only.
///
/// The marker selects which `TypeId` is used when this value acts as a `SeedFactory`.
struct RegistryWithFakeType<'r, T>(&'r Registry, PhantomData<T>);
|
|
|
|
|
|
|
|
impl<'r, 'de: 'r, T: 'static> SeedFactory<'de, TagString<'de>> for RegistryWithFakeType<'r, T> {
|
|
|
|
type Value = DeserializeWithRegistry<T>;
|
|
|
|
type Seed = &'r BoxFnSeed<Self::Value>;
|
|
|
|
|
|
|
|
// Required method
|
|
|
|
fn seed<E>(self, tag: TagString<'de>) -> Result<Self::Seed, E>
|
|
|
|
where
|
|
|
|
E: serde::de::Error,
|
|
|
|
{
|
|
|
|
// using find() and not get() because of https://github.com/rust-lang/rust/issues/80389
|
|
|
|
let seed: &Box<dyn Any + Sync> = self
|
|
|
|
.0
|
|
|
|
.0
|
|
|
|
.iter()
|
|
|
|
.find(|(k, _)| *k == &(TypeId::of::<T>(), tag.as_ref()))
|
2024-07-22 13:14:21 +02:00
|
|
|
.ok_or_else(|| serde::de::Error::custom(format!("Unknown type: {}", tag)))?
|
2024-06-17 01:10:55 +02:00
|
|
|
.1;
|
|
|
|
|
2024-07-18 19:09:07 +02:00
|
|
|
let entry: &RegistryEntry<T> = <dyn Any>::downcast_ref(&**seed).unwrap();
|
|
|
|
|
|
|
|
Ok(&entry.serde_deserialize_seed)
|
2024-06-17 01:10:55 +02:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Newtype whose `Deserialize` implementation goes through the registry.
///
/// To deserialize a type via a registry, wrap it in this type, e.g.
/// `DeserializeWithRegistry<Box<dyn MyTrait>>`; the types that were registered for
/// `Box<dyn MyTrait>` are then the ones considered.
pub struct DeserializeWithRegistry<T>(pub T);
|
2024-06-17 01:10:55 +02:00
|
|
|
|
|
|
|
impl Registry {
|
|
|
|
/// Registers a mapping from type tag to a concrete type into the registry.
|
|
|
|
///
|
|
|
|
/// The type parameters are very important:
|
|
|
|
/// After calling `register::<Box<dyn FooTrait>, FooStruct>("footype")`, when a user
|
|
|
|
/// deserializes into an input with the type tag "myblobservicetype" into a
|
|
|
|
/// `Box<dyn FooTrait>`, it will first call the Deserialize imple of `FooStruct` and
|
|
|
|
/// then convert it into a `Box<dyn FooTrait>` using From::from.
|
2024-07-18 19:09:07 +02:00
|
|
|
pub fn register<
|
|
|
|
T: 'static,
|
|
|
|
C: DeserializeOwned
|
|
|
|
+ TryFrom<url::Url, Error = Box<dyn std::error::Error + Send + Sync>>
|
|
|
|
+ Into<T>,
|
|
|
|
>(
|
|
|
|
&mut self,
|
|
|
|
type_name: &'static str,
|
|
|
|
) {
|
|
|
|
self.0.insert(
|
|
|
|
(TypeId::of::<T>(), type_name),
|
|
|
|
Box::new(RegistryEntry {
|
|
|
|
serde_deserialize_seed: BoxFnSeed::new(|x| {
|
|
|
|
deserialize::<C>(x)
|
|
|
|
.map(Into::into)
|
|
|
|
.map(DeserializeWithRegistry)
|
|
|
|
}),
|
|
|
|
from_url_seed: Box::new(|url| {
|
|
|
|
C::try_from(url)
|
|
|
|
.map(Into::into)
|
|
|
|
.map(DeserializeWithRegistry)
|
|
|
|
}),
|
|
|
|
}),
|
|
|
|
);
|
2024-06-17 01:10:55 +02:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
impl<'de, T: 'static> serde::Deserialize<'de> for DeserializeWithRegistry<T> {
|
|
|
|
fn deserialize<D>(de: D) -> std::result::Result<Self, D::Error>
|
|
|
|
where
|
|
|
|
D: serde::Deserializer<'de>,
|
|
|
|
{
|
|
|
|
serde_tagged::de::internal::deserialize(
|
|
|
|
de,
|
|
|
|
"type",
|
|
|
|
RegistryWithFakeType(ACTIVE_REG.get().unwrap(), PhantomData::<T>),
|
|
|
|
)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2024-07-18 19:09:07 +02:00
|
|
|
#[derive(Debug, thiserror::Error)]
|
|
|
|
enum TryFromUrlError {
|
2024-07-22 13:14:21 +02:00
|
|
|
#[error("Unknown type: {0}")]
|
2024-07-18 19:09:07 +02:00
|
|
|
UnknownTag(String),
|
|
|
|
}
|
|
|
|
|
|
|
|
impl<T: 'static> TryFrom<url::Url> for DeserializeWithRegistry<T> {
|
|
|
|
type Error = Box<dyn std::error::Error + Send + Sync>;
|
|
|
|
fn try_from(url: url::Url) -> Result<Self, Self::Error> {
|
|
|
|
let tag = url.scheme().split('+').next().unwrap();
|
|
|
|
// same as in the SeedFactory impl: using find() and not get() because of https://github.com/rust-lang/rust/issues/80389
|
|
|
|
let seed = ACTIVE_REG
|
|
|
|
.get()
|
|
|
|
.unwrap()
|
|
|
|
.0
|
|
|
|
.iter()
|
|
|
|
.find(|(k, _)| *k == &(TypeId::of::<T>(), tag))
|
|
|
|
.ok_or_else(|| Box::new(TryFromUrlError::UnknownTag(tag.into())))?
|
|
|
|
.1;
|
|
|
|
let entry: &RegistryEntry<T> = <dyn Any>::downcast_ref(&**seed).unwrap();
|
|
|
|
(entry.from_url_seed)(url)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2024-06-17 01:10:55 +02:00
|
|
|
thread_local! {
|
|
|
|
/// The active Registry is global state, because there is no convenient and universal way to pass state
|
|
|
|
/// into the functions usually used for deserialization, e.g. `serde_json::from_str`, `toml::from_str`,
|
|
|
|
/// `serde_qs::from_str`.
|
|
|
|
static ACTIVE_REG: Cell<Option<&'static Registry>> = panic!("reg was accessed before initialization");
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Run the provided closure with a registry context.
|
|
|
|
/// Any serde deserialize calls within the closure will use the registry to resolve tag names to
|
|
|
|
/// the corresponding Config type.
|
|
|
|
pub fn with_registry<R>(reg: &'static Registry, f: impl FnOnce() -> R) -> R {
|
|
|
|
ACTIVE_REG.set(Some(reg));
|
|
|
|
let result = f();
|
|
|
|
ACTIVE_REG.set(None);
|
|
|
|
result
|
|
|
|
}
|
|
|
|
|
|
|
|
lazy_static! {
    /// The registry provided by tvix_castore, populated with all builtin
    /// BlobStore/DirectoryStore implementations.
    pub static ref REG: Registry = {
        let mut builtin = Registry::default();
        add_default_services(&mut builtin);
        builtin
    };
}
|
|
|
|
|
|
|
|
// ---------- End of generic registry code --------- //
|
|
|
|
|
|
|
|
/// Register the builtin services of tvix_castore with the given registry.
|
|
|
|
/// This is useful for creating your own registry with the builtin types _and_
|
|
|
|
/// extra third party types.
|
|
|
|
pub fn add_default_services(reg: &mut Registry) {
|
|
|
|
crate::blobservice::register_blob_services(reg);
|
|
|
|
crate::directoryservice::register_directory_services(reg);
|
|
|
|
}
|
|
|
|
|
2024-07-19 10:51:05 +02:00
|
|
|
pub struct CompositionContext<'a> {
|
2024-07-21 15:28:48 +02:00
|
|
|
// The stack used to detect recursive instantiations and prevent deadlocks
|
|
|
|
// The TypeId of the trait object is included to distinguish e.g. the
|
|
|
|
// BlobService "default" and the DirectoryService "default".
|
2024-07-21 13:23:37 +02:00
|
|
|
stack: Vec<(TypeId, String)>,
|
2024-07-19 10:51:05 +02:00
|
|
|
composition: Option<&'a Composition>,
|
2024-06-17 01:10:55 +02:00
|
|
|
}
|
|
|
|
|
2024-07-19 10:51:05 +02:00
|
|
|
impl<'a> CompositionContext<'a> {
|
2024-07-18 17:27:08 +02:00
|
|
|
pub fn blank() -> Self {
|
|
|
|
Self {
|
|
|
|
stack: Default::default(),
|
|
|
|
composition: None,
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2024-07-19 10:51:05 +02:00
|
|
|
pub async fn resolve<T: ?Sized + Send + Sync + 'static>(
|
2024-06-17 01:10:55 +02:00
|
|
|
&self,
|
|
|
|
entrypoint: String,
|
|
|
|
) -> Result<Arc<T>, Box<dyn std::error::Error + Send + Sync + 'static>> {
|
|
|
|
// disallow recursion
|
2024-07-21 13:23:37 +02:00
|
|
|
if self
|
|
|
|
.stack
|
|
|
|
.contains(&(TypeId::of::<T>(), entrypoint.clone()))
|
|
|
|
{
|
|
|
|
return Err(CompositionError::Recursion(
|
|
|
|
self.stack.iter().map(|(_, n)| n.clone()).collect(),
|
|
|
|
)
|
|
|
|
.into());
|
2024-06-17 01:10:55 +02:00
|
|
|
}
|
2024-07-18 17:27:08 +02:00
|
|
|
match self.composition {
|
|
|
|
Some(comp) => Ok(comp.build_internal(self.stack.clone(), entrypoint).await?),
|
|
|
|
None => Err(CompositionError::NotFound(entrypoint).into()),
|
|
|
|
}
|
2024-06-17 01:10:55 +02:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
#[async_trait]
|
|
|
|
/// This is the trait usually implemented on a per-store-type Config struct and
|
|
|
|
/// used to instantiate it.
|
|
|
|
pub trait ServiceBuilder: Send + Sync {
|
|
|
|
type Output: ?Sized;
|
|
|
|
async fn build(
|
|
|
|
&self,
|
|
|
|
instance_name: &str,
|
2024-07-19 10:51:05 +02:00
|
|
|
context: &CompositionContext,
|
2024-06-17 01:10:55 +02:00
|
|
|
) -> Result<Arc<Self::Output>, Box<dyn std::error::Error + Send + Sync + 'static>>;
|
|
|
|
}
|
|
|
|
|
|
|
|
impl<T: ?Sized, S: ServiceBuilder<Output = T> + 'static> From<S>
|
|
|
|
for Box<dyn ServiceBuilder<Output = T>>
|
|
|
|
{
|
|
|
|
fn from(t: S) -> Self {
|
|
|
|
Box::new(t)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
enum InstantiationState<T: ?Sized> {
|
|
|
|
Config(Box<dyn ServiceBuilder<Output = T>>),
|
|
|
|
InProgress(tokio::sync::watch::Receiver<Option<Result<Arc<T>, CompositionError>>>),
|
|
|
|
Done(Result<Arc<T>, CompositionError>),
|
|
|
|
}
|
|
|
|
|
2024-07-19 10:51:05 +02:00
|
|
|
/// A set of named service configs from which services are instantiated on demand.
#[derive(Default)]
pub struct Composition {
    /// Keyed by the `TypeId` of the service type and the instance name. Each value is a
    /// type-erased `InstantiationState<T>` for that service type.
    stores: std::sync::Mutex<HashMap<(TypeId, String), Box<dyn Any + Send + Sync>>>,
}
|
|
|
|
|
|
|
|
#[derive(thiserror::Error, Clone, Debug)]
|
|
|
|
pub enum CompositionError {
|
|
|
|
#[error("store not found: {0}")]
|
|
|
|
NotFound(String),
|
|
|
|
#[error("recursion not allowed {0:?}")]
|
|
|
|
Recursion(Vec<String>),
|
|
|
|
#[error("store construction panicked {0}")]
|
|
|
|
Poisoned(String),
|
|
|
|
#[error("instantiation of service {0} failed: {1}")]
|
|
|
|
Failed(String, Arc<dyn std::error::Error + Send + Sync>),
|
|
|
|
}
|
|
|
|
|
2024-07-19 10:51:05 +02:00
|
|
|
impl<T: ?Sized + Send + Sync + 'static>
|
|
|
|
Extend<(
|
|
|
|
String,
|
|
|
|
DeserializeWithRegistry<Box<dyn ServiceBuilder<Output = T>>>,
|
|
|
|
)> for Composition
|
|
|
|
{
|
|
|
|
fn extend<I>(&mut self, configs: I)
|
|
|
|
where
|
|
|
|
I: IntoIterator<
|
2024-06-17 01:10:55 +02:00
|
|
|
Item = (
|
|
|
|
String,
|
|
|
|
DeserializeWithRegistry<Box<dyn ServiceBuilder<Output = T>>>,
|
|
|
|
),
|
|
|
|
>,
|
2024-07-19 10:51:05 +02:00
|
|
|
{
|
|
|
|
self.stores
|
|
|
|
.lock()
|
|
|
|
.unwrap()
|
|
|
|
.extend(configs.into_iter().map(|(k, v)| {
|
|
|
|
(
|
|
|
|
(TypeId::of::<T>(), k),
|
|
|
|
Box::new(InstantiationState::Config(v.0)) as Box<dyn Any + Send + Sync>,
|
|
|
|
)
|
|
|
|
}))
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
impl Composition {
|
|
|
|
pub fn extend_with_configs<T: ?Sized + Send + Sync + 'static>(
|
|
|
|
&mut self,
|
|
|
|
// Keep the concrete `HashMap` type here since it allows for type
|
|
|
|
// inference of what type is previously being deserialized.
|
|
|
|
configs: HashMap<String, DeserializeWithRegistry<Box<dyn ServiceBuilder<Output = T>>>>,
|
|
|
|
) {
|
|
|
|
self.extend(configs);
|
2024-06-17 01:10:55 +02:00
|
|
|
}
|
|
|
|
|
2024-07-19 10:51:05 +02:00
|
|
|
pub async fn build<T: ?Sized + Send + Sync + 'static>(
|
|
|
|
&self,
|
|
|
|
entrypoint: &str,
|
|
|
|
) -> Result<Arc<T>, CompositionError> {
|
2024-06-17 01:10:55 +02:00
|
|
|
self.build_internal(vec![], entrypoint.to_string()).await
|
|
|
|
}
|
|
|
|
|
2024-07-19 10:51:05 +02:00
|
|
|
fn build_internal<T: ?Sized + Send + Sync + 'static>(
|
2024-06-17 01:10:55 +02:00
|
|
|
&self,
|
2024-07-21 13:23:37 +02:00
|
|
|
stack: Vec<(TypeId, String)>,
|
2024-06-17 01:10:55 +02:00
|
|
|
entrypoint: String,
|
|
|
|
) -> BoxFuture<'_, Result<Arc<T>, CompositionError>> {
|
|
|
|
let mut stores = self.stores.lock().unwrap();
|
2024-07-19 10:51:05 +02:00
|
|
|
let entry = match stores.get_mut(&(TypeId::of::<T>(), entrypoint.clone())) {
|
2024-06-17 01:10:55 +02:00
|
|
|
Some(v) => v,
|
|
|
|
None => return Box::pin(futures::future::err(CompositionError::NotFound(entrypoint))),
|
|
|
|
};
|
|
|
|
// for lifetime reasons, we put a placeholder value in the hashmap while we figure out what
|
|
|
|
// the new value should be. the Mutex stays locked the entire time, so nobody will ever see
|
|
|
|
// this temporary value.
|
|
|
|
let prev_val = std::mem::replace(
|
|
|
|
entry,
|
2024-07-19 10:51:05 +02:00
|
|
|
Box::new(InstantiationState::<T>::Done(Err(
|
|
|
|
CompositionError::Poisoned(entrypoint.clone()),
|
|
|
|
))),
|
2024-06-17 01:10:55 +02:00
|
|
|
);
|
2024-07-19 10:51:05 +02:00
|
|
|
let (new_val, ret) = match *prev_val.downcast::<InstantiationState<T>>().unwrap() {
|
2024-06-17 01:10:55 +02:00
|
|
|
InstantiationState::Done(service) => (
|
|
|
|
InstantiationState::Done(service.clone()),
|
|
|
|
futures::future::ready(service).boxed(),
|
|
|
|
),
|
|
|
|
// the construction of the store has not started yet.
|
|
|
|
InstantiationState::Config(config) => {
|
|
|
|
let (tx, rx) = tokio::sync::watch::channel(None);
|
|
|
|
(
|
|
|
|
InstantiationState::InProgress(rx),
|
|
|
|
(async move {
|
|
|
|
let mut new_context = CompositionContext {
|
|
|
|
stack: stack.clone(),
|
2024-07-18 17:27:08 +02:00
|
|
|
composition: Some(self),
|
2024-06-17 01:10:55 +02:00
|
|
|
};
|
2024-07-21 13:23:37 +02:00
|
|
|
new_context
|
|
|
|
.stack
|
|
|
|
.push((TypeId::of::<T>(), entrypoint.clone()));
|
2024-07-21 15:19:16 +02:00
|
|
|
let res =
|
|
|
|
config.build(&entrypoint, &new_context).await.map_err(|e| {
|
|
|
|
match e.downcast() {
|
|
|
|
Ok(e) => *e,
|
|
|
|
Err(e) => CompositionError::Failed(entrypoint, e.into()),
|
|
|
|
}
|
|
|
|
});
|
2024-06-17 01:10:55 +02:00
|
|
|
tx.send(Some(res.clone())).unwrap();
|
|
|
|
res
|
|
|
|
})
|
|
|
|
.boxed(),
|
|
|
|
)
|
|
|
|
}
|
|
|
|
// there is already a task driving forward the construction of this store, wait for it
|
|
|
|
// to notify us via the provided channel
|
|
|
|
InstantiationState::InProgress(mut recv) => {
|
|
|
|
(InstantiationState::InProgress(recv.clone()), {
|
|
|
|
(async move {
|
|
|
|
loop {
|
|
|
|
if let Some(v) =
|
|
|
|
recv.borrow_and_update().as_ref().map(|res| res.clone())
|
|
|
|
{
|
|
|
|
break v;
|
|
|
|
}
|
|
|
|
recv.changed().await.unwrap();
|
|
|
|
}
|
|
|
|
})
|
|
|
|
.boxed()
|
|
|
|
})
|
|
|
|
}
|
|
|
|
};
|
2024-07-19 10:51:05 +02:00
|
|
|
*entry = Box::new(new_val);
|
2024-06-17 01:10:55 +02:00
|
|
|
ret
|
|
|
|
}
|
2024-07-19 15:22:48 +02:00
|
|
|
|
|
|
|
pub fn context(&self) -> CompositionContext {
|
|
|
|
CompositionContext {
|
|
|
|
stack: vec![],
|
|
|
|
composition: Some(self),
|
|
|
|
}
|
|
|
|
}
|
2024-06-17 01:10:55 +02:00
|
|
|
}
|
2024-07-21 15:19:46 +02:00
|
|
|
|
|
|
|
#[cfg(test)]
mod test {
    use super::*;
    use crate::blobservice::BlobService;
    use std::sync::Arc;

    /// Test that we return a reference to the same instance of MemoryBlobService (via ptr_eq)
    /// when instantiating the same entrypoint twice. By instantiating concurrently, we also
    /// test the channels notifying the second consumer when the store has been instantiated.
    #[tokio::test]
    async fn concurrent() {
        let blob_services_configs_json = serde_json::json!({
            "default": {
                "type": "memory",
            }
        });

        let blob_services_configs =
            with_registry(&REG, || serde_json::from_value(blob_services_configs_json)).unwrap();
        let mut blob_service_composition = Composition::default();
        blob_service_composition.extend_with_configs::<dyn BlobService>(blob_services_configs);
        let (blob_service1, blob_service2) = tokio::join!(
            blob_service_composition.build::<dyn BlobService>("default"),
            blob_service_composition.build::<dyn BlobService>("default")
        );
        assert!(Arc::ptr_eq(
            &blob_service1.unwrap(),
            &blob_service2.unwrap()
        ));
    }

    /// Test that we throw the correct error when an instantiation would recurse (deadlock)
    #[tokio::test]
    async fn reject_recursion() {
        let blob_services_configs_json = serde_json::json!({
            "default": {
                "type": "combined",
                "local": "other",
                "remote": "other"
            },
            "other": {
                "type": "combined",
                "local": "default",
                "remote": "default"
            }
        });

        let blob_services_configs =
            with_registry(&REG, || serde_json::from_value(blob_services_configs_json)).unwrap();
        let mut blob_service_composition = Composition::default();
        blob_service_composition.extend_with_configs::<dyn BlobService>(blob_services_configs);
        match blob_service_composition
            .build::<dyn BlobService>("default")
            .await
        {
            // the reported stack lists the instantiations that were in flight when the
            // recursion was detected
            Err(CompositionError::Recursion(stack)) => {
                assert_eq!(stack, vec!["default".to_string(), "other".to_string()])
            }
            other => panic!("should have returned an error, returned: {:?}", other.err()),
        }
    }
}
|