netstack3_base/data_structures/
ref_counted_hash_map.rs1use core::hash::Hash;
6use core::num::NonZeroUsize;
7
8use netstack3_hashmap::hash_map::{Entry, HashMap};
9
/// Outcome of inserting a key into a reference-counted map.
#[derive(PartialEq, Eq, Debug)]
pub enum InsertResult<O> {
    /// The key had no entry; a new entry was created. Carries the extra
    /// output produced while constructing the new value.
    Inserted(O),
    /// The key already had an entry; only its reference count changed.
    AlreadyPresent,
}
19
/// Outcome of releasing one reference to a key in a reference-counted map.
#[derive(PartialEq, Eq, Debug)]
pub enum RemoveResult<V> {
    /// The last reference was released; the entry was removed and its value
    /// is handed back.
    Removed(V),
    /// A reference was released but others remain, so the entry is kept.
    StillPresent,
    /// The key had no entry to release.
    NotPresent,
}
30
/// A hash map in which every entry carries a reference count.
///
/// Inserting an existing key bumps its count instead of replacing the value,
/// and removing a key only drops the entry once its count reaches zero.
#[derive(Debug)]
pub struct RefCountedHashMap<K, V> {
    // Each stored value is paired with its reference count. `NonZeroUsize`
    // encodes the invariant that a stored entry is referenced at least once.
    inner: HashMap<K, (NonZeroUsize, V)>,
}

// Implemented by hand rather than derived so that `K` and `V` are not
// required to implement `Default` themselves.
impl<K, V> Default for RefCountedHashMap<K, V> {
    /// Creates an empty map.
    fn default() -> Self {
        Self { inner: Default::default() }
    }
}
42
43impl<K: Eq + Hash, V> RefCountedHashMap<K, V> {
44 pub fn insert_with<O, F: FnOnce() -> (V, O)>(&mut self, key: K, f: F) -> InsertResult<O> {
49 match self.inner.entry(key) {
50 Entry::Occupied(mut entry) => {
51 let (refcnt, _): &mut (NonZeroUsize, V) = entry.get_mut();
52 *refcnt = refcnt.checked_add(1).unwrap();
53 InsertResult::AlreadyPresent
54 }
55 Entry::Vacant(entry) => {
56 let (value, output) = f();
57 let _: &mut (NonZeroUsize, V) =
58 entry.insert((NonZeroUsize::new(1).unwrap(), value));
59 InsertResult::Inserted(output)
60 }
61 }
62 }
63
64 pub fn remove(&mut self, key: K) -> RemoveResult<V> {
69 match self.inner.entry(key) {
70 Entry::Vacant(_) => RemoveResult::NotPresent,
71 Entry::Occupied(mut entry) => {
72 let (refcnt, _): &mut (NonZeroUsize, V) = entry.get_mut();
73 match NonZeroUsize::new(refcnt.get() - 1) {
74 None => {
75 let (_, value): (NonZeroUsize, V) = entry.remove();
76 RemoveResult::Removed(value)
77 }
78 Some(new_refcnt) => {
79 *refcnt = new_refcnt;
80 RemoveResult::StillPresent
81 }
82 }
83 }
84 }
85 }
86
87 pub fn contains_key(&self, key: &K) -> bool {
89 self.inner.contains_key(key)
90 }
91
92 pub fn get(&self, key: &K) -> Option<&V> {
94 self.inner.get(key).map(|(_, value)| value)
95 }
96
97 pub fn get_mut(&mut self, key: &K) -> Option<&mut V> {
99 self.inner.get_mut(key).map(|(_, value)| value)
100 }
101
102 pub fn iter_mut<'a>(&'a mut self) -> impl 'a + Iterator<Item = (&'a K, &'a mut V)> {
105 self.inner.iter_mut().map(|(key, (_, value))| (key, value))
106 }
107
108 pub fn iter<'a>(&'a self) -> impl 'a + Iterator<Item = (&'a K, &'a V)> + Clone {
111 self.inner.iter().map(|(key, (_, value))| (key, value))
112 }
113
114 pub fn iter_ref_counts<'a>(
117 &'a self,
118 ) -> impl 'a + Iterator<Item = (&'a K, &'a NonZeroUsize)> + Clone {
119 self.inner.iter().map(|(key, (count, _))| (key, count))
120 }
121
122 pub fn is_empty(&self) -> bool {
124 self.inner.is_empty()
125 }
126}
127
128#[derive(Debug)]
130pub struct RefCountedHashSet<T> {
131 inner: RefCountedHashMap<T, ()>,
132}
133
134impl<T> Default for RefCountedHashSet<T> {
135 fn default() -> RefCountedHashSet<T> {
136 RefCountedHashSet { inner: RefCountedHashMap::default() }
137 }
138}
139
140impl<T: Eq + Hash> RefCountedHashSet<T> {
141 pub fn insert(&mut self, value: T) -> InsertResult<()> {
143 self.inner.insert_with(value, || ((), ()))
144 }
145
146 pub fn remove(&mut self, value: T) -> RemoveResult<()> {
151 self.inner.remove(value)
152 }
153
154 pub fn contains(&self, value: &T) -> bool {
156 self.inner.contains_key(value)
157 }
158
159 pub fn len(&self) -> usize {
161 self.inner.inner.len()
162 }
163
164 pub fn iter_counts(&self) -> impl Iterator<Item = (&'_ T, NonZeroUsize)> + '_ {
166 self.inner.inner.iter().map(|(key, (count, ()))| (key, *count))
167 }
168}
169
170impl<T: Eq + Hash> core::iter::FromIterator<T> for RefCountedHashSet<T> {
171 fn from_iter<I: IntoIterator<Item = T>>(iter: I) -> Self {
172 iter.into_iter().fold(Self::default(), |mut set, t| {
173 let _: InsertResult<()> = set.insert(t);
174 set
175 })
176 }
177}
178
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn test_ref_counted_hash_map() {
        let mut map = RefCountedHashMap::<&str, ()>::default();
        let key = "key";

        // For each target refcount: insert up to it, release back down to one,
        // then remove the entry entirely. Reuses the same map each round, which
        // also checks that the map returns to a clean state.
        for refcount in 1..=3 {
            assert!(!map.contains_key(&key));

            assert_eq!(map.insert_with(key, || ((), ())), InsertResult::Inserted(()));
            assert!(map.contains_key(&key));
            assert_refcount(&map, key, 1, "after initial insert");

            for i in 1..refcount {
                assert_eq!(map.insert_with(key, || ((), ())), InsertResult::AlreadyPresent);
                assert!(map.contains_key(&key));
                assert_refcount(&map, key, i + 1, "after subsequent insert");
            }

            for i in 1..refcount {
                assert_eq!(map.remove(key), RemoveResult::StillPresent);
                assert!(map.contains_key(&key));
                assert_refcount(&map, key, refcount - i, "after decrement refcount");
            }

            assert_refcount(&map, key, 1, "before entry removed");
            assert_eq!(map.remove(key), RemoveResult::Removed(()));
            assert!(!map.contains_key(&key));

            assert_eq!(map.remove(key), RemoveResult::NotPresent);
            assert!(map.is_empty());
        }
    }

    #[test]
    fn test_insert_with_calls_f_only_when_vacant() {
        let mut map = RefCountedHashMap::<u8, u32>::default();

        assert_eq!(map.insert_with(1, || (100, "made")), InsertResult::Inserted("made"));
        // `f` must not run for a key that already has an entry.
        assert_eq!(
            map.insert_with(1, || -> (u32, ()) { panic!("f must not be called for a present key") }),
            InsertResult::AlreadyPresent
        );

        // The value from the first insertion is retained and can be mutated
        // without affecting the reference count.
        assert_eq!(map.get(&1), Some(&100));
        *map.get_mut(&1).expect("key 1 should be present") = 7;
        assert_eq!(map.get(&1), Some(&7));

        assert_eq!(map.remove(1), RemoveResult::StillPresent);
        assert_eq!(map.remove(1), RemoveResult::Removed(7));
        assert_eq!(map.get(&1), None);
        assert!(map.is_empty());
    }

    #[test]
    fn test_ref_counted_hash_set() {
        let mut set: RefCountedHashSet<u32> = [1u32, 2, 1].iter().copied().collect();

        // Duplicates collapse into one value with a higher count.
        assert_eq!(set.len(), 2);
        assert!(set.contains(&1));
        assert!(set.contains(&2));
        assert!(!set.contains(&3));

        assert_eq!(set.iter_counts().count(), 2);
        for (value, count) in set.iter_counts() {
            let expected = match *value {
                1 => 2,
                2 => 1,
                other => panic!("unexpected value {}", other),
            };
            assert_eq!(count.get(), expected, "refcount of {}", value);
        }

        assert_eq!(set.remove(1), RemoveResult::StillPresent);
        assert!(set.contains(&1));
        assert_eq!(set.remove(1), RemoveResult::Removed(()));
        assert!(!set.contains(&1));
        assert_eq!(set.remove(1), RemoveResult::NotPresent);
        assert_eq!(set.len(), 1);

        assert_eq!(set.insert(2), InsertResult::AlreadyPresent);
        assert_eq!(set.insert(3), InsertResult::Inserted(()));
        assert_eq!(set.len(), 2);
    }

    /// Asserts that `key` is present in `map` with exactly `expected_refcount`
    /// references. `context` is included in the failure message on either the
    /// missing-key or the wrong-count path.
    fn assert_refcount(
        map: &RefCountedHashMap<&str, ()>,
        key: &str,
        expected_refcount: usize,
        context: &str,
    ) {
        let (actual_refcount, _value) =
            map.inner.get(key).unwrap_or_else(|| panic!("refcount should be non-zero {}", context));
        assert_eq!(actual_refcount.get(), expected_refcount, "unexpected refcount {}", context);
    }
}