1use starnix_sync::Mutex;
6use std::any::{Any, TypeId};
7use std::collections::BTreeMap;
8use std::marker::{Send, Sync};
9use std::ops::Deref;
10use std::sync::Arc;
11
/// One entry of an [`Expando`]: a reference-counted value whose concrete type has been erased.
#[derive(Debug)]
struct ExpandoSlot {
    /// The stored value; its concrete type is recovered by downcasting.
    value: Arc<dyn Any + Sync + Send>,
}
19
20impl ExpandoSlot {
21 fn new(value: Arc<dyn Any + Send + Sync>) -> Self {
22 ExpandoSlot { value }
23 }
24
25 fn downcast<T: Any + Send + Sync>(&self) -> Option<Arc<T>> {
26 self.value.clone().downcast::<T>().ok()
27 }
28}
29
30#[derive(Debug, Default)]
39pub struct Expando {
40 properties: Mutex<BTreeMap<TypeId, ExpandoSlot>>,
41}
42
43impl Expando {
44 pub fn get<T: Any + Send + Sync + Default + 'static>(&self) -> Arc<T> {
49 let mut properties = self.properties.lock();
50 let type_id = TypeId::of::<T>();
51 let slot =
52 properties.entry(type_id).or_insert_with(|| ExpandoSlot::new(Arc::new(T::default())));
53 assert_eq!(type_id, slot.value.deref().type_id());
54 slot.downcast().expect("downcast of expando slot was successful")
55 }
56
57 pub fn get_or_init<T: Any + Send + Sync + 'static>(&self, init: impl FnOnce() -> T) -> Arc<T> {
63 self.get_or_try_init::<T, ()>(|| Ok(init())).expect("infallible initializer")
64 }
65
66 pub fn get_or_try_init<T: Any + Send + Sync + 'static, E>(
72 &self,
73 try_init: impl FnOnce() -> Result<T, E>,
74 ) -> Result<Arc<T>, E> {
75 let type_id = TypeId::of::<T>();
76
77 if let Some(slot) = self.properties.lock().get(&type_id) {
80 assert_eq!(type_id, slot.value.deref().type_id());
81 return Ok(slot.downcast().expect("downcast of expando slot was successful"));
82 }
83
84 let newly_init = Arc::new(try_init()?);
86
87 let mut properties = self.properties.lock();
89 let slot = properties.entry(type_id).or_insert_with(|| ExpandoSlot::new(newly_init));
90 assert_eq!(type_id, slot.value.deref().type_id());
91 Ok(slot.downcast().expect("downcast of expando slot was successful"))
92 }
93
94 pub fn peek<T: Any + Send + Sync + 'static>(&self) -> Option<Arc<T>> {
97 let properties = self.properties.lock();
98 let type_id = TypeId::of::<T>();
99 let slot = properties.get(&type_id)?;
100 assert_eq!(type_id, slot.value.deref().type_id());
101 Some(slot.downcast().expect("downcast of expando slot was successful"))
102 }
103
104 pub fn remove<T: Any + Send + Sync + 'static>(&self) -> Option<Arc<T>> {
106 let mut properties = self.properties.lock();
107 let type_id = TypeId::of::<T>();
108 let slot = properties.remove(&type_id)?;
109 assert_eq!(type_id, slot.value.deref().type_id());
110 Some(slot.downcast().expect("downcast of expando slot was successful"))
111 }
112}
113
#[cfg(test)]
mod tests {
    use super::*;

    /// Test payload whose `Default` gives a zero counter.
    #[derive(Debug, Default)]
    struct MyStruct {
        counter: Mutex<i32>,
    }

    #[test]
    fn basic_test() {
        let expando = Expando::default();
        let first = expando.get::<MyStruct>();
        assert_eq!(*first.counter.lock(), 0);
        *first.counter.lock() += 1;
        let second = expando.get::<MyStruct>();
        assert_eq!(*second.counter.lock(), 1);
    }

    #[test]
    fn user_initializer() {
        let expando = Expando::default();
        let first = expando.get_or_init(|| String::from("hello"));
        assert_eq!(first.as_str(), "hello");
        let second = expando.get_or_init(|| String::from("world"));
        assert_eq!(
            second.as_str(),
            "hello",
            "expando must have preserved value from original initializer"
        );
        assert_eq!(Arc::as_ptr(&first), Arc::as_ptr(&second));
    }

    #[test]
    fn nested_user_initializer() {
        let expando = Expando::default();
        let first = expando.get_or_init(|| expando.get::<u32>().to_string());
        assert_eq!(first.as_str(), "0");
        let second = expando.get_or_init(|| expando.get::<u32>().to_string());
        assert_eq!(Arc::as_ptr(&first), Arc::as_ptr(&second));
    }

    #[test]
    fn failed_init_can_be_retried() {
        let expando = Expando::default();
        let failed = expando.get_or_try_init::<String, String>(|| Err(String::from("oops")));
        assert_eq!(failed.unwrap_err().as_str(), "oops");

        let succeeded = expando.get_or_try_init::<String, String>(|| Ok(String::from("hurray")));
        assert_eq!(succeeded.unwrap().as_str(), "hurray");
    }

    #[test]
    fn initializer_not_called_when_present() {
        let expando = Expando::default();
        let _ = expando.get_or_init(|| 1u8);
        let result = expando.get_or_try_init::<u8, ()>(|| panic!("initializer must not run"));
        assert_eq!(*result.unwrap(), 1);
    }

    #[test]
    fn distinct_types_are_independent() {
        let expando = Expando::default();
        let number = expando.get_or_init(|| 7u32);
        let text = expando.get_or_init(|| String::from("seven"));
        assert_eq!(*number, 7);
        assert_eq!(text.as_str(), "seven");
        assert!(Arc::ptr_eq(&number, &expando.get::<u32>()));
        assert_eq!(expando.peek::<u64>(), None);
    }

    #[test]
    fn peek_works() {
        let expando = Expando::default();
        assert_eq!(expando.peek::<String>(), None);
        let from_init = expando.get_or_init(|| String::from("hello"));
        let from_peek = expando.peek::<String>().unwrap();
        assert_eq!(from_peek.as_str(), "hello");
        assert_eq!(Arc::as_ptr(&from_init), Arc::as_ptr(&from_peek));
    }

    #[test]
    fn remove_works() {
        let expando = Expando::default();
        assert_eq!(expando.peek::<String>(), None);
        let from_init = expando.get_or_init(|| String::from("hello"));
        let removed = expando.remove::<String>().unwrap();
        assert_eq!(removed.as_str(), "hello");
        assert_eq!(Arc::as_ptr(&from_init), Arc::as_ptr(&removed));
        assert_eq!(expando.peek::<String>(), None);
    }

    #[test]
    fn remove_missing_returns_none() {
        let expando = Expando::default();
        assert_eq!(expando.remove::<String>(), None);
    }

    #[test]
    fn get_after_remove_creates_fresh_value() {
        let expando = Expando::default();
        let first = expando.get::<MyStruct>();
        *first.counter.lock() += 1;
        let removed = expando.remove::<MyStruct>().unwrap();
        assert!(Arc::ptr_eq(&first, &removed));

        let second = expando.get::<MyStruct>();
        assert!(!Arc::ptr_eq(&first, &second));
        assert_eq!(*second.counter.lock(), 0);
    }
}