expando/
lib.rs

1// Copyright 2024 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use starnix_sync::Mutex;
6use std::any::{Any, TypeId};
7use std::collections::BTreeMap;
8use std::marker::{Send, Sync};
9use std::ops::Deref;
10use std::sync::Arc;
11
/// Type-erased storage for a single entry of an `Expando`.
///
/// The stored pointer is an `Arc<T>` for some concrete `T` that has been coerced to
/// `Arc<dyn Any + Send + Sync>`; use [`ExpandoSlot::downcast`] to get the typed pointer back.
#[derive(Debug)]
struct ExpandoSlot {
    value: Arc<dyn Any + Send + Sync>,
}

impl ExpandoSlot {
    /// Wraps an already type-erased shared value in a slot.
    fn new(value: Arc<dyn Any + Send + Sync>) -> Self {
        Self { value }
    }

    /// Returns a new strong reference to the stored value typed as `Arc<T>`.
    ///
    /// Yields `None` when the stored value is not a `T`. The slot keeps its own reference either
    /// way, so this can be called repeatedly.
    fn downcast<T: Any + Send + Sync>(&self) -> Option<Arc<T>> {
        // `Arc::downcast` consumes its receiver, so work on a fresh handle to the same allocation.
        let erased: Arc<dyn Any + Send + Sync> = Arc::clone(&self.value);
        erased.downcast::<T>().ok()
    }
}
29
30/// A lazy collection of values of every type.
31///
32/// An Expando contains a single instance of every type. The values are instantiated lazily
33/// when accessed. Useful for letting modules add their own state to context objects without
34/// requiring the context object itself to know about the types in every module.
35///
36/// Typically the type a module uses in the Expando will be private to that module, which lets
37/// the module know that no other code is accessing its slot on the expando.
38#[derive(Debug, Default)]
39pub struct Expando {
40    properties: Mutex<BTreeMap<TypeId, ExpandoSlot>>,
41}
42
43impl Expando {
44    /// Get the slot in the expando associated with the given type.
45    ///
46    /// The slot is added to the expando lazily but the same instance is returned every time the
47    /// expando is queried for the same type.
48    pub fn get<T: Any + Send + Sync + Default + 'static>(&self) -> Arc<T> {
49        let mut properties = self.properties.lock();
50        let type_id = TypeId::of::<T>();
51        let slot =
52            properties.entry(type_id).or_insert_with(|| ExpandoSlot::new(Arc::new(T::default())));
53        assert_eq!(type_id, slot.value.deref().type_id());
54        slot.downcast().expect("downcast of expando slot was successful")
55    }
56
57    /// Get the slot in the expando associated with the given type, running `init` to initialize
58    /// the slot if needed.
59    ///
60    /// The slot is added to the expando lazily but the same instance is returned every time the
61    /// expando is queried for the same type.
62    pub fn get_or_init<T: Any + Send + Sync + 'static>(&self, init: impl FnOnce() -> T) -> Arc<T> {
63        self.get_or_try_init::<T, ()>(|| Ok(init())).expect("infallible initializer")
64    }
65
66    /// Get the slot in the expando associated with the given type, running `try_init` to initialize
67    /// the slot if needed. Returns an error only if `try_init` returns an error.
68    ///
69    /// The slot is added to the expando lazily but the same instance is returned every time the
70    /// expando is queried for the same type.
71    pub fn get_or_try_init<T: Any + Send + Sync + 'static, E>(
72        &self,
73        try_init: impl FnOnce() -> Result<T, E>,
74    ) -> Result<Arc<T>, E> {
75        let type_id = TypeId::of::<T>();
76
77        // Acquire the lock each time we want to look at the map so that user-provided initializer
78        // can use the expando too.
79        if let Some(slot) = self.properties.lock().get(&type_id) {
80            assert_eq!(type_id, slot.value.deref().type_id());
81            return Ok(slot.downcast().expect("downcast of expando slot was successful"));
82        }
83
84        // Initialize the new value without holding the lock.
85        let newly_init = Arc::new(try_init()?);
86
87        // Only insert the newly-initialized value if no other threads got there first.
88        let mut properties = self.properties.lock();
89        let slot = properties.entry(type_id).or_insert_with(|| ExpandoSlot::new(newly_init));
90        assert_eq!(type_id, slot.value.deref().type_id());
91        Ok(slot.downcast().expect("downcast of expando slot was successful"))
92    }
93
94    /// Get the slot in the expando associated with the given type if it has previously been
95    /// initialized.
96    pub fn peek<T: Any + Send + Sync + 'static>(&self) -> Option<Arc<T>> {
97        let properties = self.properties.lock();
98        let type_id = TypeId::of::<T>();
99        let slot = properties.get(&type_id)?;
100        assert_eq!(type_id, slot.value.deref().type_id());
101        Some(slot.downcast().expect("downcast of expando slot was successful"))
102    }
103
104    /// Remove the provided type from the expando if it is present.
105    pub fn remove<T: Any + Send + Sync + 'static>(&self) -> Option<Arc<T>> {
106        let mut properties = self.properties.lock();
107        let type_id = TypeId::of::<T>();
108        let slot = properties.remove(&type_id)?;
109        assert_eq!(type_id, slot.value.deref().type_id());
110        Some(slot.downcast().expect("downcast of expando slot was successful"))
111    }
112}
113
#[cfg(test)]
mod tests {
    use super::*;

    /// Default-constructible fixture with interior state, used to observe instance sharing.
    #[derive(Debug, Default)]
    struct MyStruct {
        counter: Mutex<i32>,
    }

    /// `get` default-constructs on first use and returns the same instance afterwards.
    #[test]
    fn basic_test() {
        let expando = Expando::default();

        let initial = expando.get::<MyStruct>();
        {
            let mut counter = initial.counter.lock();
            assert_eq!(*counter, 0);
            *counter += 1;
        }

        // A mutation made through the first handle must be visible through the second one.
        let again = expando.get::<MyStruct>();
        assert_eq!(*again.counter.lock(), 1);
    }

    /// Only the first initializer passed to `get_or_init` takes effect.
    #[test]
    fn user_initializer() {
        let expando = Expando::default();

        let original = expando.get_or_init(|| String::from("hello"));
        assert_eq!(original.as_str(), "hello");

        let repeated = expando.get_or_init(|| String::from("world"));
        assert_eq!(
            repeated.as_str(),
            "hello",
            "expando must have preserved value from original initializer"
        );
        assert!(Arc::ptr_eq(&original, &repeated));
    }

    /// An initializer may itself read from the expando it is initializing.
    #[test]
    fn nested_user_initializer() {
        let expando = Expando::default();
        let render_u32 = || expando.get::<u32>().to_string();

        let original = expando.get_or_init(render_u32);
        assert_eq!(original.as_str(), "0");

        let repeated = expando.get_or_init(render_u32);
        assert!(Arc::ptr_eq(&original, &repeated));
    }

    /// A failing initializer stores nothing, so a later attempt can still succeed.
    #[test]
    fn failed_init_can_be_retried() {
        let expando = Expando::default();

        let failure: Result<Arc<String>, String> =
            expando.get_or_try_init(|| Err(String::from("oops")));
        assert_eq!(failure.unwrap_err().as_str(), "oops");

        let success: Result<Arc<String>, String> =
            expando.get_or_try_init(|| Ok(String::from("hurray")));
        assert_eq!(success.unwrap().as_str(), "hurray");
    }

    /// `peek` returns nothing before initialization and the shared instance afterwards.
    #[test]
    fn peek_works() {
        let expando = Expando::default();
        assert!(expando.peek::<String>().is_none());

        let created = expando.get_or_init(|| String::from("hello"));
        let peeked = expando.peek::<String>().unwrap();
        assert_eq!(peeked.as_str(), "hello");
        assert!(Arc::ptr_eq(&created, &peeked));
    }

    /// `remove` hands back the stored instance and leaves the slot empty.
    #[test]
    fn remove_works() {
        let expando = Expando::default();
        assert!(expando.peek::<String>().is_none());

        let created = expando.get_or_init(|| String::from("hello"));
        let taken = expando.remove::<String>().unwrap();
        assert_eq!(taken.as_str(), "hello");
        assert!(Arc::ptr_eq(&created, &taken));

        assert!(expando.peek::<String>().is_none());
    }
}
186}