netstack3_base/data_structures/
rcu.rs1use core::marker::PhantomData;
6use core::mem::ManuallyDrop;
7use core::ops::{Deref, DerefMut};
8
9use alloc::sync::Arc;
10use arc_swap::ArcSwap;
11use netstack3_sync::{LockGuard, Mutex};
12
/// A read-copy-update cell whose writers are serialized by a lock.
///
/// Readers load the current value from `data` without touching `lock`.
/// Writers hold `lock` while they produce and publish a replacement value, so
/// at most one write (or [`SynchronizedWriterRcu::replace`]) is in progress at
/// a time.
#[derive(Default)]
pub struct SynchronizedWriterRcu<T> {
    /// Serializes writers; it protects no data of its own.
    lock: Mutex<()>,
    /// The currently published value.
    data: ArcSwap<T>,
}
20
21impl<T> SynchronizedWriterRcu<T> {
22 pub fn new(value: T) -> Self {
24 Self { lock: Mutex::new(()), data: ArcSwap::new(Arc::new(value)) }
25 }
26
27 pub fn read(&self) -> ReadGuard<'_, T> {
29 ReadGuard { guard: self.data.load(), _marker: PhantomData }
30 }
31
32 pub fn write(&self) -> WriteGuard<'_, T>
36 where
37 T: Clone,
38 {
39 self.write_with(T::clone)
40 }
41
42 pub fn write_with<F: FnOnce(&T) -> T>(&self, f: F) -> WriteGuard<'_, T> {
50 let Self { lock, data } = self;
51 let lock_guard = lock.lock();
53 let copy = f(&*data.load());
54 WriteGuard(ManuallyDrop::new(WriteGuardInner { copy, lock_guard, data: &self.data }))
55 }
56
57 pub fn replace(&self, value: T) {
67 let Self { lock, data } = self;
68 let guard = lock.lock();
69 data.store(Arc::new(value));
70 core::mem::drop(guard);
72 }
73}
74
/// A snapshot of the value published in a [`SynchronizedWriterRcu`] at the
/// time [`SynchronizedWriterRcu::read`] was called.
///
/// The snapshot is unaffected by values published after it was taken.
pub struct ReadGuard<'a, T> {
    /// Keeps the loaded `Arc<T>` alive for the guard's lifetime.
    guard: arc_swap::Guard<Arc<T>>,
    /// Ties the guard to the borrow of the cell it was read from.
    _marker: PhantomData<&'a ()>,
}
82
83impl<'a, T> Deref for ReadGuard<'a, T> {
84 type Target = T;
85
86 fn deref(&self) -> &Self::Target {
87 &*self.guard
88 }
89}
90
/// Exclusive, lock-holding access to a working copy of the value in a
/// [`SynchronizedWriterRcu`].
///
/// Dropping the guard publishes the copy and then releases the writer lock.
/// [`WriteGuard::discard`] releases the lock without publishing.
// `ManuallyDrop` lets `Drop` move the inner fields out by value, and lets
// `discard` drop them in place without running the publishing `Drop` logic.
pub struct WriteGuard<'a, T>(ManuallyDrop<WriteGuardInner<'a, T>>);
97
/// The state owned by a [`WriteGuard`].
struct WriteGuardInner<'a, T> {
    /// The working copy, published when the guard is dropped.
    copy: T,
    /// The cell the copy is published into.
    data: &'a ArcSwap<T>,
    /// The writer lock, held for the guard's lifetime. Fields drop in
    /// declaration order, so when dropped in place (as `discard` does) `copy`
    /// is dropped before the lock is released.
    lock_guard: LockGuard<'a, ()>,
}
103
104impl<'a, T> Deref for WriteGuard<'a, T> {
105 type Target = T;
106
107 fn deref(&self) -> &Self::Target {
108 let Self(inner) = self;
109 &inner.copy
110 }
111}
112
113impl<'a, T> DerefMut for WriteGuard<'a, T> {
114 fn deref_mut(&mut self) -> &mut Self::Target {
115 let Self(inner) = self;
116 &mut inner.copy
117 }
118}
119
impl<'a, T> WriteGuard<'a, T> {
    /// Releases the writer lock without publishing the working copy.
    ///
    /// The value published before this guard was created remains the current
    /// value.
    pub fn discard(mut self) {
        let Self(inner) = &mut self;
        // SAFETY: `inner` is dropped exactly once here and never accessed
        // again: `self` is forgotten immediately below, so `Drop::drop`
        // (which would `take` the same value) never runs.
        unsafe {
            ManuallyDrop::drop(inner);
        }
        core::mem::forget(self);
    }
}
134
impl<'a, T> Drop for WriteGuard<'a, T> {
    /// Publishes the working copy, then releases the writer lock.
    fn drop(&mut self) {
        let Self(inner) = self;
        // SAFETY: `drop` runs at most once and `inner` is not touched again
        // after this `take`. `discard` forgets `self` after dropping `inner`,
        // so this never runs on an already-dropped value.
        let WriteGuardInner { copy, data, lock_guard } = unsafe { ManuallyDrop::take(inner) };
        // Store before unlocking so the next writer, which loads only after
        // acquiring the lock, starts from this copy.
        data.store(Arc::new(copy));
        core::mem::drop(lock_guard);
    }
}
145
#[cfg(test)]
mod tests {
    use super::*;

    /// Two concurrent writers each increment `ROUNDS` times. A concurrent
    /// reader must only ever observe a non-decreasing value.
    #[test]
    fn race_writers() {
        const ROUNDS: usize = 100;
        let rcu = Arc::new(SynchronizedWriterRcu::new(0usize));
        let rcu_clone = rcu.clone();
        let writer = move || {
            let mut last = None;
            for _ in 0..ROUNDS {
                let mut data = rcu_clone.write();
                // This writer's own previous increment is visible, and the
                // other writer can only have increased the value further.
                assert!(last.is_none_or(|l| *data > l));
                last = Some(*data);
                *data += 1;
            }
        };
        let w1 = teststd::thread::spawn(writer.clone());
        let w2 = teststd::thread::spawn(writer);
        let rcu_clone = rcu.clone();
        let reader = teststd::thread::spawn(move || {
            let mut last = None;
            for _ in 0..(ROUNDS * 2) {
                let data = rcu_clone.read();
                assert!(last.is_none_or(|l| *data >= l));
                last = Some(*data);
            }
        });

        w1.join().expect("join w1");
        w2.join().expect("join w2");
        reader.join().expect("join reader");
        assert_eq!(*rcu.read(), ROUNDS * 2);
    }

    /// An incrementing writer races a thread calling `replace` with
    /// increasing values. Both are serialized by the writer lock.
    #[test]
    fn race_replace() {
        const ROUNDS: usize = 100;
        const DELTA: usize = 1000;
        let rcu = Arc::new(SynchronizedWriterRcu::new(0usize));
        let rcu_clone = rcu.clone();
        let w1 = teststd::thread::spawn(move || {
            let mut last = None;
            for _ in 0..ROUNDS {
                let mut data = rcu_clone.write();
                assert!(last.is_none_or(|l| *data > l));
                last = Some(*data);
                *data += 1;
            }
        });
        let rcu_clone = rcu.clone();
        let w2 = teststd::thread::spawn(move || {
            for i in 1..=ROUNDS {
                let step = i * DELTA;
                rcu_clone.replace(step);
                // `w1` may have incremented since the replace, but it can
                // never publish a value below `step` because its
                // read-modify-write is serialized with `replace`.
                assert!(*rcu_clone.read() >= step);
            }
        });
        w1.join().expect("join w1");
        w2.join().expect("join w2");
        let value = *rcu.read();
        // The final `replace` sets `ROUNDS * DELTA`; at most `ROUNDS`
        // increments from `w1` can land after it.
        let min = ROUNDS * DELTA;
        let max = min + ROUNDS;
        assert!((min..=max).contains(&value), "value {value} not in [{min}, {max}]");
    }

    #[test]
    fn read_guard_post_write() {
        let rcu = SynchronizedWriterRcu::new(0usize);
        let read1 = rcu.read();
        assert_eq!(*read1, 0);
        let mut write = rcu.write();
        *write = 1;
        core::mem::drop(write);
        // `read1` still holds the snapshot taken before the write was
        // published; a fresh read observes the new value.
        let read2 = rcu.read();
        assert_eq!(*read1, 0);
        assert_eq!(*read2, 1);
    }

    #[test]
    fn write_guard_discard() {
        let rcu = SynchronizedWriterRcu::new(0usize);
        let read1 = rcu.read();
        assert_eq!(*read1, 0);
        let mut write = rcu.write();
        *write = 1;
        write.discard();
        let read2 = rcu.read();
        assert_eq!(*read1, 0);
        assert_eq!(*read2, 0);
    }
}