netstack3_base/data_structures/
ref_counted_hash_map.rs1use alloc::collections::hash_map::{Entry, HashMap};
6use core::hash::Hash;
7use core::num::NonZeroUsize;
8
/// The outcome of adding a reference to a key in a [`RefCountedHashMap`] (or
/// a value in a [`RefCountedHashSet`]).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum InsertResult<O> {
    /// The key was not present, so a new entry was created with a reference
    /// count of 1. Carries the output produced by the caller's constructor
    /// closure.
    Inserted(O),
    /// The key was already present. Its reference count was incremented and
    /// the constructor closure was not called.
    AlreadyPresent,
}
18
/// The outcome of releasing a reference to a key in a [`RefCountedHashMap`]
/// (or a value in a [`RefCountedHashSet`]).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum RemoveResult<V> {
    /// The last reference was released, so the entry was removed. Carries the
    /// value that was stored.
    Removed(V),
    /// The reference count was decremented but is still non-zero, so the
    /// entry remains.
    StillPresent,
    /// The key was not present; nothing was changed.
    NotPresent,
}
29
/// A hash map whose entries are reference counted.
///
/// Each key maps to a value together with a count of outstanding references.
/// Inserting an existing key increments its count instead of replacing the
/// value. An entry is removed only once every reference has been released.
#[derive(Debug)]
pub struct RefCountedHashMap<K, V> {
    // Invariant: every stored count is at least 1 (enforced by `NonZeroUsize`).
    // An entry whose count would drop to zero is removed from the map instead.
    inner: HashMap<K, (NonZeroUsize, V)>,
}
35
36impl<K, V> Default for RefCountedHashMap<K, V> {
37 fn default() -> RefCountedHashMap<K, V> {
38 RefCountedHashMap { inner: HashMap::default() }
39 }
40}
41
42impl<K: Eq + Hash, V> RefCountedHashMap<K, V> {
43 pub fn insert_with<O, F: FnOnce() -> (V, O)>(&mut self, key: K, f: F) -> InsertResult<O> {
48 match self.inner.entry(key) {
49 Entry::Occupied(mut entry) => {
50 let (refcnt, _): &mut (NonZeroUsize, V) = entry.get_mut();
51 *refcnt = refcnt.checked_add(1).unwrap();
52 InsertResult::AlreadyPresent
53 }
54 Entry::Vacant(entry) => {
55 let (value, output) = f();
56 let _: &mut (NonZeroUsize, V) =
57 entry.insert((NonZeroUsize::new(1).unwrap(), value));
58 InsertResult::Inserted(output)
59 }
60 }
61 }
62
63 pub fn remove(&mut self, key: K) -> RemoveResult<V> {
68 match self.inner.entry(key) {
69 Entry::Vacant(_) => RemoveResult::NotPresent,
70 Entry::Occupied(mut entry) => {
71 let (refcnt, _): &mut (NonZeroUsize, V) = entry.get_mut();
72 match NonZeroUsize::new(refcnt.get() - 1) {
73 None => {
74 let (_, value): (NonZeroUsize, V) = entry.remove();
75 RemoveResult::Removed(value)
76 }
77 Some(new_refcnt) => {
78 *refcnt = new_refcnt;
79 RemoveResult::StillPresent
80 }
81 }
82 }
83 }
84 }
85
86 pub fn contains_key(&self, key: &K) -> bool {
88 self.inner.contains_key(key)
89 }
90
91 pub fn get(&self, key: &K) -> Option<&V> {
93 self.inner.get(key).map(|(_, value)| value)
94 }
95
96 pub fn get_mut(&mut self, key: &K) -> Option<&mut V> {
98 self.inner.get_mut(key).map(|(_, value)| value)
99 }
100
101 pub fn iter_mut<'a>(&'a mut self) -> impl 'a + Iterator<Item = (&'a K, &'a mut V)> {
104 self.inner.iter_mut().map(|(key, (_, value))| (key, value))
105 }
106
107 pub fn iter<'a>(&'a self) -> impl 'a + Iterator<Item = (&'a K, &'a V)> + Clone {
110 self.inner.iter().map(|(key, (_, value))| (key, value))
111 }
112
113 pub fn iter_ref_counts<'a>(
116 &'a self,
117 ) -> impl 'a + Iterator<Item = (&'a K, &'a NonZeroUsize)> + Clone {
118 self.inner.iter().map(|(key, (count, _))| (key, count))
119 }
120
121 pub fn is_empty(&self) -> bool {
123 self.inner.is_empty()
124 }
125}
126
127#[derive(Debug)]
129pub struct RefCountedHashSet<T> {
130 inner: RefCountedHashMap<T, ()>,
131}
132
133impl<T> Default for RefCountedHashSet<T> {
134 fn default() -> RefCountedHashSet<T> {
135 RefCountedHashSet { inner: RefCountedHashMap::default() }
136 }
137}
138
139impl<T: Eq + Hash> RefCountedHashSet<T> {
140 pub fn insert(&mut self, value: T) -> InsertResult<()> {
142 self.inner.insert_with(value, || ((), ()))
143 }
144
145 pub fn remove(&mut self, value: T) -> RemoveResult<()> {
150 self.inner.remove(value)
151 }
152
153 pub fn contains(&self, value: &T) -> bool {
155 self.inner.contains_key(value)
156 }
157
158 pub fn len(&self) -> usize {
160 self.inner.inner.len()
161 }
162
163 pub fn iter_counts(&self) -> impl Iterator<Item = (&'_ T, NonZeroUsize)> + '_ {
165 self.inner.inner.iter().map(|(key, (count, ()))| (key, *count))
166 }
167}
168
169impl<T: Eq + Hash> core::iter::FromIterator<T> for RefCountedHashSet<T> {
170 fn from_iter<I: IntoIterator<Item = T>>(iter: I) -> Self {
171 iter.into_iter().fold(Self::default(), |mut set, t| {
172 let _: InsertResult<()> = set.insert(t);
173 set
174 })
175 }
176}
177
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn test_ref_counted_hash_map() {
        let mut map = RefCountedHashMap::<&str, ()>::default();
        let key = "key";

        // Exercise several reference counts. Each iteration must leave the map
        // empty again so the next one starts from a clean state.
        for refcount in 1..=3 {
            assert!(!map.contains_key(&key));

            // The first insert creates the entry with a refcount of 1.
            assert_eq!(map.insert_with(key, || ((), ())), InsertResult::Inserted(()));
            assert!(map.contains_key(&key));
            assert_refcount(&map, key, 1, "after initial insert");

            // Subsequent inserts only bump the refcount.
            for i in 1..refcount {
                assert_eq!(map.insert_with(key, || ((), ())), InsertResult::AlreadyPresent);
                assert!(map.contains_key(&key));
                assert_refcount(&map, key, i + 1, "after subsequent insert");
            }

            // Releasing all but the last reference keeps the entry alive.
            for i in 1..refcount {
                assert_eq!(map.remove(key), RemoveResult::StillPresent);
                assert!(map.contains_key(&key));
                assert_refcount(&map, key, refcount - i, "after decrement refcount");
            }

            // Releasing the final reference removes the entry and yields its value.
            assert_refcount(&map, key, 1, "before entry removed");
            assert_eq!(map.remove(key), RemoveResult::Removed(()));
            assert!(!map.contains_key(&key));

            // Removing an absent key is reported, not treated as an error.
            assert_eq!(map.remove(key), RemoveResult::NotPresent);
            assert!(map.is_empty());
        }
    }

    #[test]
    fn test_ref_counted_hash_set() {
        // Collecting a duplicate value counts it once per occurrence.
        let mut set: RefCountedHashSet<u8> = [1, 2, 1].iter().copied().collect();
        assert_eq!(set.len(), 2);
        assert!(set.contains(&1));
        assert!(set.contains(&2));
        assert!(!set.contains(&3));
        assert_eq!(set_count(&set, 1), Some(2));
        assert_eq!(set_count(&set, 2), Some(1));
        assert_eq!(set_count(&set, 3), None);

        assert_eq!(set.insert(3), InsertResult::Inserted(()));
        assert_eq!(set.insert(3), InsertResult::AlreadyPresent);
        assert_eq!(set.len(), 3);

        assert_eq!(set.remove(1), RemoveResult::StillPresent);
        assert_eq!(set.remove(1), RemoveResult::Removed(()));
        assert_eq!(set.remove(1), RemoveResult::NotPresent);
        assert!(!set.contains(&1));
        assert_eq!(set.len(), 2);
    }

    /// Returns the reference count of `value` in `set`, or `None` if absent.
    fn set_count(set: &RefCountedHashSet<u8>, value: u8) -> Option<usize> {
        set.iter_counts().find(|(v, _)| **v == value).map(|(_, count)| count.get())
    }

    /// Asserts that `key` is present in `map` with exactly `expected_refcount`
    /// references. `context` is included in every failure message.
    fn assert_refcount(
        map: &RefCountedHashMap<&str, ()>,
        key: &str,
        expected_refcount: usize,
        context: &str,
    ) {
        let (actual_refcount, _value) = map
            .inner
            .get(key)
            .unwrap_or_else(|| panic!("refcount should be non-zero {}", context));
        assert_eq!(actual_refcount.get(), expected_refcount, "unexpected refcount {}", context);
    }
}