netstack3_base/data_structures/
ref_counted_hash_map.rs

1// Copyright 2021 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use core::hash::Hash;
6use core::num::NonZeroUsize;
7
8use netstack3_hashmap::hash_map::{Entry, HashMap};
9
/// The outcome of inserting a key into a [`RefCountedHashMap`].
#[derive(PartialEq, Eq, Debug)]
pub enum InsertResult<O> {
    /// The key was absent, so a new entry was created; carries the output
    /// produced alongside the new entry's value.
    Inserted(O),
    /// The key already had an entry, so only its reference count was
    /// incremented.
    AlreadyPresent,
}
19
/// The outcome of decrementing an entry's reference count in a
/// [`RefCountedHashMap`].
#[derive(PartialEq, Eq, Debug)]
pub enum RemoveResult<V> {
    /// The reference count dropped to 0, so the entry was taken out of the
    /// map; carries the entry's value.
    Removed(V),
    /// The reference count is still non-zero, so the entry remains in the
    /// map.
    StillPresent,
    /// The map had no entry for the key.
    NotPresent,
}
30
/// A [`HashMap`] wrapper that tracks a reference count for every entry.
///
/// Each entry stores a non-zero count next to its value. An entry is removed
/// from the map once its count is decremented to zero.
#[derive(Debug)]
pub struct RefCountedHashMap<K, V> {
    /// Maps each key to its (always non-zero) reference count and its value.
    inner: HashMap<K, (NonZeroUsize, V)>,
}
36
37impl<K, V> Default for RefCountedHashMap<K, V> {
38    fn default() -> RefCountedHashMap<K, V> {
39        RefCountedHashMap { inner: HashMap::default() }
40    }
41}
42
43impl<K: Eq + Hash, V> RefCountedHashMap<K, V> {
44    /// Increments the reference count of the entry with the given key.
45    ///
46    /// If the key isn't in the map, the given function is called to create its
47    /// associated value.
48    pub fn insert_with<O, F: FnOnce() -> (V, O)>(&mut self, key: K, f: F) -> InsertResult<O> {
49        match self.inner.entry(key) {
50            Entry::Occupied(mut entry) => {
51                let (refcnt, _): &mut (NonZeroUsize, V) = entry.get_mut();
52                *refcnt = refcnt.checked_add(1).unwrap();
53                InsertResult::AlreadyPresent
54            }
55            Entry::Vacant(entry) => {
56                let (value, output) = f();
57                let _: &mut (NonZeroUsize, V) =
58                    entry.insert((NonZeroUsize::new(1).unwrap(), value));
59                InsertResult::Inserted(output)
60            }
61        }
62    }
63
64    /// Decrements the reference count of the entry with the given key.
65    ///
66    /// If the reference count reaches 0, the entry will be removed and its
67    /// value returned.
68    pub fn remove(&mut self, key: K) -> RemoveResult<V> {
69        match self.inner.entry(key) {
70            Entry::Vacant(_) => RemoveResult::NotPresent,
71            Entry::Occupied(mut entry) => {
72                let (refcnt, _): &mut (NonZeroUsize, V) = entry.get_mut();
73                match NonZeroUsize::new(refcnt.get() - 1) {
74                    None => {
75                        let (_, value): (NonZeroUsize, V) = entry.remove();
76                        RemoveResult::Removed(value)
77                    }
78                    Some(new_refcnt) => {
79                        *refcnt = new_refcnt;
80                        RemoveResult::StillPresent
81                    }
82                }
83            }
84        }
85    }
86
87    /// Returns `true` if the map contains a value for the specified key.
88    pub fn contains_key(&self, key: &K) -> bool {
89        self.inner.contains_key(key)
90    }
91
92    /// Returns a reference to the value corresponding to the key.
93    pub fn get(&self, key: &K) -> Option<&V> {
94        self.inner.get(key).map(|(_, value)| value)
95    }
96
97    /// Returns a mutable reference to the value corresponding to the key.
98    pub fn get_mut(&mut self, key: &K) -> Option<&mut V> {
99        self.inner.get_mut(key).map(|(_, value)| value)
100    }
101
102    /// An iterator visiting all key-value pairs in arbitrary order, with
103    /// mutable references to the values.
104    pub fn iter_mut<'a>(&'a mut self) -> impl 'a + Iterator<Item = (&'a K, &'a mut V)> {
105        self.inner.iter_mut().map(|(key, (_, value))| (key, value))
106    }
107
108    /// An iterator visiting all key-value pairs in arbitrary order, with
109    /// non-mutable references to the values.
110    pub fn iter<'a>(&'a self) -> impl 'a + Iterator<Item = (&'a K, &'a V)> + Clone {
111        self.inner.iter().map(|(key, (_, value))| (key, value))
112    }
113
114    /// An iterator visiting all keys in arbitrary order with the reference
115    /// count for each key.
116    pub fn iter_ref_counts<'a>(
117        &'a self,
118    ) -> impl 'a + Iterator<Item = (&'a K, &'a NonZeroUsize)> + Clone {
119        self.inner.iter().map(|(key, (count, _))| (key, count))
120    }
121
122    /// Returns whether the map is empty.
123    pub fn is_empty(&self) -> bool {
124        self.inner.is_empty()
125    }
126}
127
128/// A [`RefCountedHashMap`] where the value is `()`.
129#[derive(Debug)]
130pub struct RefCountedHashSet<T> {
131    inner: RefCountedHashMap<T, ()>,
132}
133
134impl<T> Default for RefCountedHashSet<T> {
135    fn default() -> RefCountedHashSet<T> {
136        RefCountedHashSet { inner: RefCountedHashMap::default() }
137    }
138}
139
140impl<T: Eq + Hash> RefCountedHashSet<T> {
141    /// Increments the reference count of the given value.
142    pub fn insert(&mut self, value: T) -> InsertResult<()> {
143        self.inner.insert_with(value, || ((), ()))
144    }
145
146    /// Decrements the reference count of the given value.
147    ///
148    /// If the reference count reaches 0, the value will be removed from the
149    /// set.
150    pub fn remove(&mut self, value: T) -> RemoveResult<()> {
151        self.inner.remove(value)
152    }
153
154    /// Returns `true` if the set contains the given value.
155    pub fn contains(&self, value: &T) -> bool {
156        self.inner.contains_key(value)
157    }
158
159    /// Returns the number of values in the set.
160    pub fn len(&self) -> usize {
161        self.inner.inner.len()
162    }
163
164    /// Iterates over values and reference counts.
165    pub fn iter_counts(&self) -> impl Iterator<Item = (&'_ T, NonZeroUsize)> + '_ {
166        self.inner.inner.iter().map(|(key, (count, ()))| (key, *count))
167    }
168}
169
170impl<T: Eq + Hash> core::iter::FromIterator<T> for RefCountedHashSet<T> {
171    fn from_iter<I: IntoIterator<Item = T>>(iter: I) -> Self {
172        iter.into_iter().fold(Self::default(), |mut set, t| {
173            let _: InsertResult<()> = set.insert(t);
174            set
175        })
176    }
177}
178
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn test_ref_counted_hash_map() {
        let mut map = RefCountedHashMap::<&str, ()>::default();
        let key = "key";

        // Test refcounts 1 and 2. The behavioral difference is that testing
        // only with a refcount of 1 doesn't exercise the refcount incrementing
        // functionality - it only exercises the functionality of initializing a
        // new entry with a refcount of 1.
        for refcount in 1..=2 {
            assert!(!map.contains_key(&key));

            // Insert an entry for the first time, initializing the refcount to
            // 1.
            assert_eq!(map.insert_with(key, || ((), ())), InsertResult::Inserted(()));
            assert!(map.contains_key(&key));
            assert_refcount(&map, key, 1, "after initial insert");

            // Increase the refcount to `refcount`.
            for i in 1..refcount {
                // Since the refcount starts at 1, the entry is always already
                // in the map.
                assert_eq!(map.insert_with(key, || ((), ())), InsertResult::AlreadyPresent);
                assert!(map.contains_key(&key));
                assert_refcount(&map, key, i + 1, "after subsequent insert");
            }

            // Decrement the refcount to 1.
            for i in 1..refcount {
                // Since we don't decrement the refcount past 1, the entry is
                // always still present.
                assert_eq!(map.remove(key), RemoveResult::StillPresent);
                assert!(map.contains_key(&key));
                assert_refcount(&map, key, refcount - i, "after decrement refcount");
            }

            assert_refcount(&map, key, 1, "before entry removed");
            // Remove the entry when the refcount is 1.
            assert_eq!(map.remove(key), RemoveResult::Removed(()));
            assert!(!map.contains_key(&key));

            // Try to remove an entry that no longer exists.
            assert_eq!(map.remove(key), RemoveResult::NotPresent);
        }
    }

    #[test]
    fn test_ref_counted_hash_map_values() {
        let mut map = RefCountedHashMap::<u8, u32>::default();
        assert!(map.is_empty());

        // The first insert runs the constructor and returns its output.
        assert_eq!(map.insert_with(1, || (10, "created")), InsertResult::Inserted("created"));
        assert_eq!(map.get(&1), Some(&10));

        // A repeated insert must not run the constructor again.
        let mut constructed = false;
        assert_eq!(
            map.insert_with(1, || {
                constructed = true;
                (20, "created")
            }),
            InsertResult::AlreadyPresent
        );
        assert!(!constructed);
        assert_eq!(map.get(&1), Some(&10));

        *map.get_mut(&1).expect("key 1 should be present") += 1;
        assert_eq!(map.get(&1), Some(&11));
        assert_eq!(map.get(&2), None);

        for (_, value) in map.iter_mut() {
            *value *= 2;
        }
        assert_eq!(map.iter().count(), 1);
        assert_eq!(map.iter().next(), Some((&1, &22)));
        assert_eq!(map.iter_ref_counts().next().map(|(k, c)| (*k, c.get())), Some((1, 2)));

        assert_eq!(map.remove(1), RemoveResult::StillPresent);
        assert_eq!(map.remove(1), RemoveResult::Removed(22));
        assert!(map.is_empty());
    }

    #[test]
    fn test_ref_counted_hash_set() {
        let mut set = RefCountedHashSet::<u8>::default();
        assert_eq!(set.len(), 0);
        assert!(!set.contains(&1));

        assert_eq!(set.insert(1), InsertResult::Inserted(()));
        assert_eq!(set.insert(1), InsertResult::AlreadyPresent);
        assert_eq!(set.insert(2), InsertResult::Inserted(()));
        assert!(set.contains(&1));
        assert!(set.contains(&2));
        // `len` counts distinct values, not the sum of reference counts.
        assert_eq!(set.len(), 2);

        assert_eq!(set.iter_counts().count(), 2);
        for (value, count) in set.iter_counts() {
            let expected = match *value {
                1 => 2,
                2 => 1,
                other => panic!("unexpected value {} in set", other),
            };
            assert_eq!(count.get(), expected, "refcount of {}", value);
        }

        assert_eq!(set.remove(1), RemoveResult::StillPresent);
        assert_eq!(set.remove(1), RemoveResult::Removed(()));
        assert_eq!(set.remove(1), RemoveResult::NotPresent);
        assert!(!set.contains(&1));
        assert_eq!(set.len(), 1);
    }

    #[test]
    fn test_ref_counted_hash_set_from_iter() {
        let set: RefCountedHashSet<u8> = [1u8, 2, 1, 1].iter().copied().collect();
        assert_eq!(set.len(), 2);
        for (value, count) in set.iter_counts() {
            let expected = match *value {
                1 => 3,
                2 => 1,
                other => panic!("unexpected value {} in set", other),
            };
            assert_eq!(count.get(), expected, "refcount of {}", value);
        }
    }

    /// Asserts that `key` is present in `map` with exactly `expected_refcount`
    /// references; `context` describes the test phase for failure messages.
    fn assert_refcount(
        map: &RefCountedHashMap<&str, ()>,
        key: &str,
        expected_refcount: usize,
        context: &str,
    ) {
        let (actual_refcount, _value) =
            map.inner.get(key).unwrap_or_else(|| panic!("refcount should be non-zero {}", context));
        assert_eq!(actual_refcount.get(), expected_refcount, "unexpected refcount {}", context);
    }
}