// Copyright 2021 The Fuchsia Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
use core::hash::Hash;
use core::num::NonZeroUsize;

use netstack3_hashmap::hash_map::{Entry, HashMap};
/// Outcome of inserting into a [`RefCountedHashMap`] (or a
/// [`RefCountedHashSet`]).
#[derive(PartialEq, Eq, Debug)]
pub enum InsertResult<O> {
    /// The key was absent, so a fresh entry was created. Carries the extra
    /// output produced alongside the new value.
    Inserted(O),
    /// The key was already present, so only its reference count was bumped.
    AlreadyPresent,
}
/// Outcome of decrementing a reference count in a [`RefCountedHashMap`] (or a
/// [`RefCountedHashSet`]).
#[derive(PartialEq, Eq, Debug)]
pub enum RemoveResult<V> {
    /// The count dropped to zero, so the entry was taken out of the map.
    /// Carries the removed value.
    Removed(V),
    /// The count was decremented but is still non-zero; the entry remains.
    StillPresent,
    /// No entry existed for the key.
    NotPresent,
}
/// A [`HashMap`] that tracks how many outstanding references each key has.
///
/// Every entry stores a non-zero reference count next to its value.
#[derive(Debug)]
pub struct RefCountedHashMap<K, V> {
    /// Maps each key to `(reference count, value)`.
    inner: HashMap<K, (NonZeroUsize, V)>,
}
3637impl<K, V> Default for RefCountedHashMap<K, V> {
38fn default() -> RefCountedHashMap<K, V> {
39 RefCountedHashMap { inner: HashMap::default() }
40 }
41}
4243impl<K: Eq + Hash, V> RefCountedHashMap<K, V> {
44/// Increments the reference count of the entry with the given key.
45 ///
46 /// If the key isn't in the map, the given function is called to create its
47 /// associated value.
48pub fn insert_with<O, F: FnOnce() -> (V, O)>(&mut self, key: K, f: F) -> InsertResult<O> {
49match self.inner.entry(key) {
50 Entry::Occupied(mut entry) => {
51let (refcnt, _): &mut (NonZeroUsize, V) = entry.get_mut();
52*refcnt = refcnt.checked_add(1).unwrap();
53 InsertResult::AlreadyPresent
54 }
55 Entry::Vacant(entry) => {
56let (value, output) = f();
57let _: &mut (NonZeroUsize, V) =
58 entry.insert((NonZeroUsize::new(1).unwrap(), value));
59 InsertResult::Inserted(output)
60 }
61 }
62 }
6364/// Decrements the reference count of the entry with the given key.
65 ///
66 /// If the reference count reaches 0, the entry will be removed and its
67 /// value returned.
68pub fn remove(&mut self, key: K) -> RemoveResult<V> {
69match self.inner.entry(key) {
70 Entry::Vacant(_) => RemoveResult::NotPresent,
71 Entry::Occupied(mut entry) => {
72let (refcnt, _): &mut (NonZeroUsize, V) = entry.get_mut();
73match NonZeroUsize::new(refcnt.get() - 1) {
74None => {
75let (_, value): (NonZeroUsize, V) = entry.remove();
76 RemoveResult::Removed(value)
77 }
78Some(new_refcnt) => {
79*refcnt = new_refcnt;
80 RemoveResult::StillPresent
81 }
82 }
83 }
84 }
85 }
8687/// Returns `true` if the map contains a value for the specified key.
88pub fn contains_key(&self, key: &K) -> bool {
89self.inner.contains_key(key)
90 }
9192/// Returns a reference to the value corresponding to the key.
93pub fn get(&self, key: &K) -> Option<&V> {
94self.inner.get(key).map(|(_, value)| value)
95 }
9697/// Returns a mutable reference to the value corresponding to the key.
98pub fn get_mut(&mut self, key: &K) -> Option<&mut V> {
99self.inner.get_mut(key).map(|(_, value)| value)
100 }
101102/// An iterator visiting all key-value pairs in arbitrary order, with
103 /// mutable references to the values.
104pub fn iter_mut<'a>(&'a mut self) -> impl 'a + Iterator<Item = (&'a K, &'a mut V)> {
105self.inner.iter_mut().map(|(key, (_, value))| (key, value))
106 }
107108/// An iterator visiting all key-value pairs in arbitrary order, with
109 /// non-mutable references to the values.
110pub fn iter<'a>(&'a self) -> impl 'a + Iterator<Item = (&'a K, &'a V)> + Clone {
111self.inner.iter().map(|(key, (_, value))| (key, value))
112 }
113114/// An iterator visiting all keys in arbitrary order with the reference
115 /// count for each key.
116pub fn iter_ref_counts<'a>(
117&'a self,
118 ) -> impl 'a + Iterator<Item = (&'a K, &'a NonZeroUsize)> + Clone {
119self.inner.iter().map(|(key, (count, _))| (key, count))
120 }
121122/// Returns whether the map is empty.
123pub fn is_empty(&self) -> bool {
124self.inner.is_empty()
125 }
126}
127128/// A [`RefCountedHashMap`] where the value is `()`.
129#[derive(Debug)]
130pub struct RefCountedHashSet<T> {
131 inner: RefCountedHashMap<T, ()>,
132}
133134impl<T> Default for RefCountedHashSet<T> {
135fn default() -> RefCountedHashSet<T> {
136 RefCountedHashSet { inner: RefCountedHashMap::default() }
137 }
138}
139140impl<T: Eq + Hash> RefCountedHashSet<T> {
141/// Increments the reference count of the given value.
142pub fn insert(&mut self, value: T) -> InsertResult<()> {
143self.inner.insert_with(value, || ((), ()))
144 }
145146/// Decrements the reference count of the given value.
147 ///
148 /// If the reference count reaches 0, the value will be removed from the
149 /// set.
150pub fn remove(&mut self, value: T) -> RemoveResult<()> {
151self.inner.remove(value)
152 }
153154/// Returns `true` if the set contains the given value.
155pub fn contains(&self, value: &T) -> bool {
156self.inner.contains_key(value)
157 }
158159/// Returns the number of values in the set.
160pub fn len(&self) -> usize {
161self.inner.inner.len()
162 }
163164/// Iterates over values and reference counts.
165pub fn iter_counts(&self) -> impl Iterator<Item = (&'_ T, NonZeroUsize)> + '_ {
166self.inner.inner.iter().map(|(key, (count, ()))| (key, *count))
167 }
168}
169170impl<T: Eq + Hash> core::iter::FromIterator<T> for RefCountedHashSet<T> {
171fn from_iter<I: IntoIterator<Item = T>>(iter: I) -> Self {
172 iter.into_iter().fold(Self::default(), |mut set, t| {
173let _: InsertResult<()> = set.insert(t);
174 set
175 })
176 }
177}
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn test_ref_counted_hash_map() {
        let mut map = RefCountedHashMap::<&str, ()>::default();
        let key = "key";

        // Test refcounts 1 and 2. The behavioral difference is that testing
        // only with a refcount of 1 doesn't exercise the refcount incrementing
        // functionality - it only exercises the functionality of initializing a
        // new entry with a refcount of 1.
        for refcount in 1..=2 {
            assert!(!map.contains_key(&key));

            // Insert an entry for the first time, initializing the refcount to
            // 1.
            assert_eq!(map.insert_with(key, || ((), ())), InsertResult::Inserted(()));
            assert!(map.contains_key(&key));
            assert_refcount(&map, key, 1, "after initial insert");

            // Increase the refcount to `refcount`.
            for i in 1..refcount {
                // Since the refcount starts at 1, the entry is always already
                // in the map.
                assert_eq!(map.insert_with(key, || ((), ())), InsertResult::AlreadyPresent);
                assert!(map.contains_key(&key));
                assert_refcount(&map, key, i + 1, "after subsequent insert");
            }

            // Decrement the refcount to 1.
            for i in 1..refcount {
                // Since we don't decrement the refcount past 1, the entry is
                // always still present.
                assert_eq!(map.remove(key), RemoveResult::StillPresent);
                assert!(map.contains_key(&key));
                assert_refcount(&map, key, refcount - i, "after decrement refcount");
            }

            assert_refcount(&map, key, 1, "before entry removed");
            // Remove the entry when the refcount is 1.
            assert_eq!(map.remove(key), RemoveResult::Removed(()));
            assert!(!map.contains_key(&key));

            // Try to remove an entry that no longer exists.
            assert_eq!(map.remove(key), RemoveResult::NotPresent);
        }
    }

    /// Exercises value access (`get`, `get_mut`, `iter`, `iter_mut`) and
    /// `iter_ref_counts`, none of which are covered by the refcount test above.
    #[test]
    fn test_ref_counted_hash_map_values() {
        let mut map = RefCountedHashMap::<u8, u32>::default();
        assert!(map.is_empty());

        assert_eq!(map.insert_with(1, || (10, "first")), InsertResult::Inserted("first"));
        // The key is already present, so the closure's value must not replace
        // the stored one.
        assert_eq!(map.insert_with(1, || (99, "ignored")), InsertResult::AlreadyPresent);
        assert_eq!(map.insert_with(2, || (20, "second")), InsertResult::Inserted("second"));
        assert!(!map.is_empty());
        assert_eq!(map.get(&1), Some(&10));
        assert_eq!(map.get(&3), None);

        *map.get_mut(&2).expect("key 2 should be present") += 5;
        assert_eq!(map.get(&2), Some(&25));

        for (_, value) in map.iter_mut() {
            *value *= 2;
        }
        let mut values: Vec<(u8, u32)> = map.iter().map(|(k, v)| (*k, *v)).collect();
        values.sort();
        assert_eq!(values, vec![(1, 20), (2, 50)]);

        let mut counts: Vec<(u8, usize)> =
            map.iter_ref_counts().map(|(k, c)| (*k, c.get())).collect();
        counts.sort();
        assert_eq!(counts, vec![(1, 2), (2, 1)]);

        assert_eq!(map.remove(1), RemoveResult::StillPresent);
        assert_eq!(map.remove(1), RemoveResult::Removed(20));
        assert_eq!(map.remove(2), RemoveResult::Removed(50));
        assert!(map.is_empty());
    }

    #[test]
    fn test_ref_counted_hash_set() {
        let mut set = RefCountedHashSet::<&str>::default();
        assert_eq!(set.len(), 0);
        assert!(!set.contains(&"a"));

        assert_eq!(set.insert("a"), InsertResult::Inserted(()));
        assert_eq!(set.insert("a"), InsertResult::AlreadyPresent);
        assert_eq!(set.insert("b"), InsertResult::Inserted(()));
        // `len` counts distinct values, not total references.
        assert_eq!(set.len(), 2);
        assert!(set.contains(&"a"));
        assert!(set.contains(&"b"));

        let mut counts: Vec<(&str, usize)> =
            set.iter_counts().map(|(k, c)| (*k, c.get())).collect();
        counts.sort();
        assert_eq!(counts, vec![("a", 2), ("b", 1)]);

        assert_eq!(set.remove("a"), RemoveResult::StillPresent);
        assert!(set.contains(&"a"));
        assert_eq!(set.remove("a"), RemoveResult::Removed(()));
        assert!(!set.contains(&"a"));
        assert_eq!(set.remove("a"), RemoveResult::NotPresent);
        assert_eq!(set.len(), 1);
    }

    #[test]
    fn test_ref_counted_hash_set_from_iter() {
        let set: RefCountedHashSet<u8> = [1, 2, 1, 3, 1].iter().copied().collect();
        assert_eq!(set.len(), 3);
        let mut counts: Vec<(u8, usize)> =
            set.iter_counts().map(|(k, c)| (*k, c.get())).collect();
        counts.sort();
        assert_eq!(counts, vec![(1, 3), (2, 1), (3, 1)]);
    }

    /// Asserts that `key` is present in `map` with exactly
    /// `expected_refcount` references; `context` is included in the panic
    /// message when the key is missing.
    fn assert_refcount(
        map: &RefCountedHashMap<&str, ()>,
        key: &str,
        expected_refcount: usize,
        context: &str,
    ) {
        let (actual_refcount, _value) =
            map.inner.get(key).unwrap_or_else(|| panic!("refcount should be non-zero {}", context));
        assert_eq!(actual_refcount.get(), expected_refcount);
    }
}