netstack3_base/data_structures/
rcu.rs

// Copyright 2025 The Fuchsia Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
4
5use core::marker::PhantomData;
6use core::mem::ManuallyDrop;
7use core::ops::{Deref, DerefMut};
8
9use alloc::sync::Arc;
10use arc_swap::ArcSwap;
11use netstack3_sync::{LockGuard, Mutex};
12
/// An RCU (Read-Copy-Update) data structure that uses a `Mutex` to synchronize
/// writers.
///
/// Readers obtain the current value through [`SynchronizedWriterRcu::read`]
/// without taking the writer lock. Writers go through
/// [`SynchronizedWriterRcu::write`], [`SynchronizedWriterRcu::write_with`] or
/// [`SynchronizedWriterRcu::replace`], all of which hold the lock while the
/// new value is stored.
#[derive(Default)]
pub struct SynchronizedWriterRcu<T> {
    // Held by writers for the whole duration of an update. It protects no data
    // of its own; it only serializes writers against each other.
    lock: Mutex<()>,
    // The currently published value, shared with readers.
    data: ArcSwap<T>,
}
20
21impl<T> SynchronizedWriterRcu<T> {
22    /// Creates a new `SingleWriterRcu` with `value`.
23    pub fn new(value: T) -> Self {
24        Self { lock: Mutex::new(()), data: ArcSwap::new(Arc::new(value)) }
25    }
26
27    /// Acquires a read guard on the RCU.
28    pub fn read(&self) -> ReadGuard<'_, T> {
29        ReadGuard { guard: self.data.load(), _marker: PhantomData }
30    }
31
32    /// Acquires a write guard on the RCU by cloning `T`.
33    ///
34    /// See [`WriteGuard::write_with``].
35    pub fn write(&self) -> WriteGuard<'_, T>
36    where
37        T: Clone,
38    {
39        self.write_with(T::clone)
40    }
41
42    /// Acquires a write guard on the RCU.
43    ///
44    /// [`WriteGuard`] provides mutable access to a new copy of the data kept by
45    /// `SingleWriterRcu` via `f`.
46    ///
47    /// Dropping the returned [`WriteGuard`] commits any changes back to the
48    /// RCU. Changes may be discarded with [`WriteGuard::discard`].
49    pub fn write_with<F: FnOnce(&T) -> T>(&self, f: F) -> WriteGuard<'_, T> {
50        let Self { lock, data } = self;
51        // Lock before reading the data.
52        let lock_guard = lock.lock();
53        let copy = f(&*data.load());
54        WriteGuard(ManuallyDrop::new(WriteGuardInner { copy, lock_guard, data: &self.data }))
55    }
56
57    /// Replaces the value in the RCU with `value` without reading the current
58    /// value.
59    ///
60    /// *WARNING*: do *NOT* use this method with a value built from a clone of
61    /// the data from [`SingleWriterRcu::read`], this is only meant to be used
62    /// when the new value is produced independently of the previous value. The
63    /// value may be changed by another thread between `read` and `replace` -
64    /// these changes would be lost. Use [`SingleWriterRcu::write`] to ensure
65    /// writer synchronization is applied.
66    pub fn replace(&self, value: T) {
67        let Self { lock, data } = self;
68        let guard = lock.lock();
69        data.store(Arc::new(value));
70        // Only drop the guard after we've stored the new value in the ArcSwap.
71        core::mem::drop(guard);
72    }
73}
74
/// A read guard on [`SynchronizedWriterRcu`].
///
/// Implements [`Deref`] to get to the contained data.
///
/// The guard keeps exposing the value it was created with, even if a writer
/// commits a newer value afterwards.
pub struct ReadGuard<'a, T> {
    // The value loaded from the RCU when the guard was created.
    guard: arc_swap::Guard<Arc<T>>,
    // Ties the guard to the `'a` borrow of the RCU it was read from; `guard`
    // itself carries no lifetime parameter.
    _marker: PhantomData<&'a ()>,
}
82
83impl<'a, T> Deref for ReadGuard<'a, T> {
84    type Target = T;
85
86    fn deref(&self) -> &Self::Target {
87        &*self.guard
88    }
89}
90
/// A write guard on [`SynchronizedWriterRcu`].
///
/// Implements [`Deref`] and [`DerefMut`] to get to the contained data.
///
/// Changes to the contained data are applied back to the RCU on drop.
// The inner state is wrapped in `ManuallyDrop` so that the `Drop`
// implementation can move it out by value and `discard` can destroy it without
// committing.
pub struct WriteGuard<'a, T>(ManuallyDrop<WriteGuardInner<'a, T>>);
97
/// The state carried by a [`WriteGuard`].
///
/// Fields are dropped in declaration order, so when the whole struct is
/// dropped the lock guard is released last.
struct WriteGuardInner<'a, T> {
    /// The writer's private copy of the data; this is what gets committed.
    copy: T,
    /// Where the copy is stored when the write is committed.
    data: &'a ArcSwap<T>,
    /// Keeps the writer lock held for as long as the write guard is alive.
    lock_guard: LockGuard<'a, ()>,
}
103
104impl<'a, T> Deref for WriteGuard<'a, T> {
105    type Target = T;
106
107    fn deref(&self) -> &Self::Target {
108        let Self(inner) = self;
109        &inner.copy
110    }
111}
112
113impl<'a, T> DerefMut for WriteGuard<'a, T> {
114    fn deref_mut(&mut self) -> &mut Self::Target {
115        let Self(inner) = self;
116        &mut inner.copy
117    }
118}
119
120impl<'a, T> WriteGuard<'a, T> {
121    /// Discards this attempt at writing to the RCU. No changes will be
122    /// committed.
123    pub fn discard(mut self) {
124        // Drop the inner slot without dropping self.
125        let Self(inner) = &mut self;
126        // SAFETY: inner is not used again. We prevent the inner drop
127        // implementation from running by forgetting self.
128        unsafe {
129            ManuallyDrop::drop(inner);
130        }
131        core::mem::forget(self);
132    }
133}
134
135impl<'a, T> Drop for WriteGuard<'a, T> {
136    fn drop(&mut self) {
137        let Self(inner) = self;
138        // SAFETY: inner is not used again, self is being dropped.
139        let WriteGuardInner { copy, data, lock_guard } = unsafe { ManuallyDrop::take(inner) };
140        data.store(Arc::new(copy));
141        // Only drop the lock once we're done.
142        core::mem::drop(lock_guard);
143    }
144}
145
#[cfg(test)]
mod tests {
    use super::*;

    /// Increments the value in `rcu` `rounds` times through write guards,
    /// asserting that every value this writer observes is strictly greater
    /// than the one it observed on the previous round.
    fn increment_checked(rcu: &SynchronizedWriterRcu<usize>, rounds: usize) {
        let mut previous: Option<usize> = None;
        for _ in 0..rounds {
            let mut guard = rcu.write();
            if let Some(seen) = previous {
                assert!(*guard > seen);
            }
            previous = Some(*guard);
            *guard += 1;
            // `guard` is dropped at the end of the iteration, committing the
            // increment.
        }
    }

    #[test]
    fn race_writers() {
        const ROUNDS: usize = 100;
        let rcu = Arc::new(SynchronizedWriterRcu::new(0usize));

        let spawn_writer = || {
            let rcu = rcu.clone();
            teststd::thread::spawn(move || increment_checked(&rcu, ROUNDS))
        };
        let w1 = spawn_writer();
        let w2 = spawn_writer();

        let reader = {
            let rcu = rcu.clone();
            teststd::thread::spawn(move || {
                let mut previous: Option<usize> = None;
                for _ in 0..(ROUNDS * 2) {
                    let data = rcu.read();
                    // Readers must never observe the value going backwards.
                    if let Some(seen) = previous {
                        assert!(*data >= seen);
                    }
                    previous = Some(*data);
                }
            })
        };

        w1.join().expect("join w1");
        w2.join().expect("join w2");
        reader.join().expect("join reader");
        assert_eq!(*rcu.read(), ROUNDS * 2);
    }

    #[test]
    fn race_replace() {
        const ROUNDS: usize = 100;
        const DELTA: usize = 1000;
        let rcu = Arc::new(SynchronizedWriterRcu::new(0usize));

        let w1 = {
            let rcu = rcu.clone();
            teststd::thread::spawn(move || increment_checked(&rcu, ROUNDS))
        };
        let w2 = {
            let rcu = rcu.clone();
            teststd::thread::spawn(move || {
                for round in 1..=ROUNDS {
                    let step = round * DELTA;
                    rcu.replace(step);
                    // If replace didn't properly hold a lock this would fail
                    // because the writer thread would have an out of date copy.
                    assert!(*rcu.read() >= step);
                }
            })
        };

        w1.join().expect("join w1");
        w2.join().expect("join w2");

        // The final value is the last replaced value plus however many
        // increments landed after it.
        let value = *rcu.read();
        let min = ROUNDS * DELTA;
        let max = min + ROUNDS;
        assert_eq!(value.min(min), min);
        assert_eq!(value.max(max), max);
    }

    #[test]
    fn read_guard_post_write() {
        let rcu = SynchronizedWriterRcu::new(0usize);
        let before = rcu.read();
        assert_eq!(*before, 0);
        {
            let mut pending = rcu.write();
            *pending = 1;
            // Leaving this scope drops the guard, which commits.
        }
        let after = rcu.read();
        // The older read guard still sees the value it was created with.
        assert_eq!(*before, 0);
        assert_eq!(*after, 1);
    }

    #[test]
    fn write_guard_discard() {
        let rcu = SynchronizedWriterRcu::new(0usize);
        let before = rcu.read();
        assert_eq!(*before, 0);
        let mut pending = rcu.write();
        *pending = 1;
        pending.discard();
        let after = rcu.read();
        // Nothing was committed, so both guards see the initial value.
        assert_eq!(*before, 0);
        assert_eq!(*after, 0);
    }
}