starnix_sync/
locks.rs

1// Copyright 2022 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5// Use these crates so that we don't need to make the dependencies conditional.
6use {fuchsia_sync as _, lock_api as _, tracing_mutex as _};
7
8use crate::{LockAfter, LockBefore, LockFor, Locked, RwLockFor, UninterruptibleLock};
9use core::marker::PhantomData;
10use std::{any, fmt};
11
12#[cfg(not(debug_assertions))]
13pub type Mutex<T> = fuchsia_sync::Mutex<T>;
14#[cfg(not(debug_assertions))]
15pub type MutexGuard<'a, T> = fuchsia_sync::MutexGuard<'a, T>;
16#[allow(unused)]
17#[cfg(not(debug_assertions))]
18pub type MappedMutexGuard<'a, T> = fuchsia_sync::MappedMutexGuard<'a, T>;
19
20#[cfg(not(debug_assertions))]
21pub type RwLock<T> = fuchsia_sync::RwLock<T>;
22#[cfg(not(debug_assertions))]
23pub type RwLockReadGuard<'a, T> = fuchsia_sync::RwLockReadGuard<'a, T>;
24#[cfg(not(debug_assertions))]
25pub type RwLockWriteGuard<'a, T> = fuchsia_sync::RwLockWriteGuard<'a, T>;
26
27#[cfg(debug_assertions)]
28type RawTracingMutex = tracing_mutex::lockapi::TracingWrapper<fuchsia_sync::RawSyncMutex>;
29#[cfg(debug_assertions)]
30pub type Mutex<T> = lock_api::Mutex<RawTracingMutex, T>;
31#[cfg(debug_assertions)]
32pub type MutexGuard<'a, T> = lock_api::MutexGuard<'a, RawTracingMutex, T>;
33#[allow(unused)]
34#[cfg(debug_assertions)]
35pub type MappedMutexGuard<'a, T> = lock_api::MappedMutexGuard<'a, RawTracingMutex, T>;
36
37#[cfg(debug_assertions)]
38type RawTracingRwLock = tracing_mutex::lockapi::TracingWrapper<fuchsia_sync::RawSyncRwLock>;
39#[cfg(debug_assertions)]
40pub type RwLock<T> = lock_api::RwLock<RawTracingRwLock, T>;
41#[cfg(debug_assertions)]
42pub type RwLockReadGuard<'a, T> = lock_api::RwLockReadGuard<'a, RawTracingRwLock, T>;
43#[cfg(debug_assertions)]
44pub type RwLockWriteGuard<'a, T> = lock_api::RwLockWriteGuard<'a, RawTracingRwLock, T>;
45
46/// Lock `m1` and `m2` in a consistent order (using the memory address of m1 and m2 and returns the
47/// associated guard. This ensure that `ordered_lock(m1, m2)` and `ordered_lock(m2, m1)` will not
48/// deadlock.
49pub fn ordered_lock<'a, T>(
50    m1: &'a Mutex<T>,
51    m2: &'a Mutex<T>,
52) -> (MutexGuard<'a, T>, MutexGuard<'a, T>) {
53    let ptr1: *const Mutex<T> = m1;
54    let ptr2: *const Mutex<T> = m2;
55    if ptr1 < ptr2 {
56        let g1 = m1.lock();
57        let g2 = m2.lock();
58        (g1, g2)
59    } else {
60        let g2 = m2.lock();
61        let g1 = m1.lock();
62        (g1, g2)
63    }
64}
65
66/// Acquires multiple mutexes in a consistent order based on their memory addresses.
67/// This helps prevent deadlocks.
68pub fn ordered_lock_vec<'a, T>(mutexes: &[&'a Mutex<T>]) -> Vec<MutexGuard<'a, T>> {
69    // Create a vector of tuples containing the mutex and its original index.
70    let mut indexed_mutexes =
71        mutexes.into_iter().enumerate().map(|(i, m)| (i, *m)).collect::<Vec<_>>();
72
73    // Sort the indexed mutexes by their memory addresses.
74    indexed_mutexes.sort_by_key(|(_, m)| *m as *const Mutex<T>);
75
76    // Acquire the locks in the sorted order.
77    let mut guards = indexed_mutexes.into_iter().map(|(i, m)| (i, m.lock())).collect::<Vec<_>>();
78
79    // Reorder the guards to match the original order of the mutexes.
80    guards.sort_by_key(|(i, _)| *i);
81
82    guards.into_iter().map(|(_, g)| g).collect::<Vec<_>>()
83}
84
85/// A wrapper for mutex that requires a `Locked` context to acquire.
86/// This context must be of a level that precedes `L` in the lock ordering graph
87/// where `L` is a level associated with this mutex.
88pub struct OrderedMutex<T, L: LockAfter<UninterruptibleLock>> {
89    mutex: Mutex<T>,
90    _phantom: PhantomData<L>,
91}
92
93impl<T: Default, L: LockAfter<UninterruptibleLock>> Default for OrderedMutex<T, L> {
94    fn default() -> Self {
95        Self { mutex: Default::default(), _phantom: Default::default() }
96    }
97}
98
99impl<T: fmt::Debug, L: LockAfter<UninterruptibleLock>> fmt::Debug for OrderedMutex<T, L> {
100    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
101        write!(f, "OrderedMutex({:?}, {})", self.mutex, any::type_name::<L>())
102    }
103}
104
105impl<T, L: LockAfter<UninterruptibleLock>> LockFor<L> for OrderedMutex<T, L> {
106    type Data = T;
107    type Guard<'a>
108        = MutexGuard<'a, T>
109    where
110        T: 'a,
111        L: 'a;
112    fn lock(&self) -> Self::Guard<'_> {
113        self.mutex.lock()
114    }
115}
116
117impl<T, L: LockAfter<UninterruptibleLock>> OrderedMutex<T, L> {
118    pub const fn new(t: T) -> Self {
119        Self { mutex: Mutex::new(t), _phantom: PhantomData }
120    }
121
122    pub fn lock<'a, P>(&'a self, locked: &'a mut Locked<'_, P>) -> <Self as LockFor<L>>::Guard<'a>
123    where
124        P: LockBefore<L>,
125    {
126        locked.lock(self)
127    }
128
129    pub fn lock_and<'a, P>(
130        &'a self,
131        locked: &'a mut Locked<'_, P>,
132    ) -> (<Self as LockFor<L>>::Guard<'a>, Locked<'a, L>)
133    where
134        P: LockBefore<L>,
135    {
136        locked.lock_and(self)
137    }
138}
139
140/// Lock two OrderedMutex of the same level in the consistent order. Returns both
141/// guards and a new locked context.
142pub fn lock_both<'a, T, L: LockAfter<UninterruptibleLock>, P>(
143    locked: &'a mut Locked<'_, P>,
144    m1: &'a OrderedMutex<T, L>,
145    m2: &'a OrderedMutex<T, L>,
146) -> (MutexGuard<'a, T>, MutexGuard<'a, T>, Locked<'a, L>)
147where
148    P: LockBefore<L>,
149{
150    locked.lock_both_and(m1, m2)
151}
152
153/// A wrapper for an RwLock that requires a `Locked` context to acquire.
154/// This context must be of a level that precedes `L` in the lock ordering graph
155/// where `L` is a level associated with this RwLock.
156pub struct OrderedRwLock<T, L: LockAfter<UninterruptibleLock>> {
157    rwlock: RwLock<T>,
158    _phantom: PhantomData<L>,
159}
160
161impl<T: Default, L: LockAfter<UninterruptibleLock>> Default for OrderedRwLock<T, L> {
162    fn default() -> Self {
163        Self { rwlock: Default::default(), _phantom: Default::default() }
164    }
165}
166
167impl<T: fmt::Debug, L: LockAfter<UninterruptibleLock>> fmt::Debug for OrderedRwLock<T, L> {
168    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
169        write!(f, "OrderedRwLock({:?}, {})", self.rwlock, any::type_name::<L>())
170    }
171}
172
173impl<T, L: LockAfter<UninterruptibleLock>> RwLockFor<L> for OrderedRwLock<T, L> {
174    type Data = T;
175    type ReadGuard<'a>
176        = RwLockReadGuard<'a, T>
177    where
178        T: 'a,
179        L: 'a;
180    type WriteGuard<'a>
181        = RwLockWriteGuard<'a, T>
182    where
183        T: 'a,
184        L: 'a;
185    fn read_lock(&self) -> Self::ReadGuard<'_> {
186        self.rwlock.read()
187    }
188    fn write_lock(&self) -> Self::WriteGuard<'_> {
189        self.rwlock.write()
190    }
191}
192
193impl<T, L: LockAfter<UninterruptibleLock>> OrderedRwLock<T, L> {
194    pub const fn new(t: T) -> Self {
195        Self { rwlock: RwLock::new(t), _phantom: PhantomData }
196    }
197
198    pub fn read<'a, P>(
199        &'a self,
200        locked: &'a mut Locked<'_, P>,
201    ) -> <Self as RwLockFor<L>>::ReadGuard<'a>
202    where
203        P: LockBefore<L>,
204    {
205        locked.read_lock(self)
206    }
207
208    pub fn write<'a, P>(
209        &'a self,
210        locked: &'a mut Locked<'_, P>,
211    ) -> <Self as RwLockFor<L>>::WriteGuard<'a>
212    where
213        P: LockBefore<L>,
214    {
215        locked.write_lock(self)
216    }
217
218    pub fn read_and<'a, P>(
219        &'a self,
220        locked: &'a mut Locked<'_, P>,
221    ) -> (<Self as RwLockFor<L>>::ReadGuard<'a>, Locked<'a, L>)
222    where
223        P: LockBefore<L>,
224    {
225        locked.read_lock_and(self)
226    }
227
228    pub fn write_and<'a, P>(
229        &'a self,
230        locked: &'a mut Locked<'_, P>,
231    ) -> (<Self as RwLockFor<L>>::WriteGuard<'a>, Locked<'a, L>)
232    where
233        P: LockBefore<L>,
234    {
235        locked.write_lock_and(self)
236    }
237}
238
#[cfg(test)]
mod test {
    use super::*;
    use crate::Unlocked;

    mod lock_levels {
        //! Lock ordering tree:
        //! Unlocked -> A -> B -> C
        //!          -> D -> E -> F
        use crate::{LockAfter, UninterruptibleLock, Unlocked};
        use lock_ordering_macro::lock_ordering;

        lock_ordering! {
            Unlocked => A,
            A => B,
            B => C,
            Unlocked => D,
            D => E,
            E => F,
        }

        // The ordered wrappers require every level to come after `UninterruptibleLock`.
        impl LockAfter<UninterruptibleLock> for A {}
        impl LockAfter<UninterruptibleLock> for B {}
        impl LockAfter<UninterruptibleLock> for C {}
        impl LockAfter<UninterruptibleLock> for D {}
        impl LockAfter<UninterruptibleLock> for E {}
        impl LockAfter<UninterruptibleLock> for F {}
    }

    use lock_levels::{A, B, C, D, E, F};

    /// Copies out the values currently protected by `guards`, preserving their order.
    fn guarded_values(guards: &[MutexGuard<'_, i32>]) -> Vec<i32> {
        guards.iter().map(|guard| **guard).collect()
    }

    #[::fuchsia::test]
    fn test_lock_ordering() {
        let one = Mutex::new(1);
        let two = Mutex::new(2);

        {
            let (guard_one, guard_two) = ordered_lock(&one, &two);
            assert_eq!((*guard_one, *guard_two), (1, 2));
        }
        {
            // Swapped arguments: the guards come back in the swapped order as well.
            let (guard_two, guard_one) = ordered_lock(&two, &one);
            assert_eq!((*guard_one, *guard_two), (1, 2));
        }
    }

    #[::fuchsia::test]
    fn test_vec_lock_ordering() {
        let one = Mutex::new(1);
        let zero = Mutex::new(0);
        let two = Mutex::new(2);

        {
            let guards = ordered_lock_vec(&[&zero, &one, &two]);
            assert_eq!(guarded_values(&guards), [0, 1, 2]);
        }
        {
            let guards = ordered_lock_vec(&[&two, &one, &zero]);
            assert_eq!(guarded_values(&guards), [2, 1, 0]);
        }
    }

    #[test]
    fn test_ordered_mutex() {
        let a: OrderedMutex<u8, A> = OrderedMutex::new(15);
        let _b: OrderedMutex<u16, B> = OrderedMutex::new(30);
        let c: OrderedMutex<u32, C> = OrderedMutex::new(45);

        // SAFETY: assumed sound because this test creates only this one root context.
        // TODO confirm against the contract of `Unlocked::new`.
        let mut locked = unsafe { Unlocked::new() };

        let (a_data, mut next_locked) = a.lock_and(&mut locked);
        let c_data = c.lock(&mut next_locked);

        // Neither of these compiles:
        //let _b_data = _b.lock(&mut locked);
        //let _b_data = _b.lock(&mut next_locked);

        assert_eq!(*a_data, 15);
        assert_eq!(*c_data, 45);
    }

    #[test]
    fn test_ordered_rwlock() {
        let d: OrderedRwLock<u8, D> = OrderedRwLock::new(15);
        let _e: OrderedRwLock<u16, E> = OrderedRwLock::new(30);
        let f: OrderedRwLock<u32, F> = OrderedRwLock::new(45);

        // SAFETY: assumed sound because this test creates only this one root context.
        // TODO confirm against the contract of `Unlocked::new`.
        let mut locked = unsafe { Unlocked::new() };

        // Writer on D, reader on F.
        {
            let (d_data, mut next_locked) = d.write_and(&mut locked);
            let f_data = f.read(&mut next_locked);

            // Neither of these compiles:
            //let _e_data = _e.read(&mut locked);
            //let _e_data = _e.read(&mut next_locked);

            assert_eq!(*d_data, 15);
            assert_eq!(*f_data, 45);
        }

        // Reader on D, writer on F.
        {
            let (d_data, mut next_locked) = d.read_and(&mut locked);
            let f_data = f.write(&mut next_locked);

            // Neither of these compiles:
            //let _e_data = _e.write(&mut locked);
            //let _e_data = _e.write(&mut next_locked);

            assert_eq!(*d_data, 15);
            assert_eq!(*f_data, 45);
        }
    }

    #[test]
    fn test_lock_both() {
        let first: OrderedMutex<u8, A> = OrderedMutex::new(15);
        let second: OrderedMutex<u8, A> = OrderedMutex::new(30);

        // SAFETY: assumed sound because this test creates only this one root context.
        // TODO confirm against the contract of `Unlocked::new`.
        let mut locked = unsafe { Unlocked::new() };

        {
            let (first_data, second_data, _) = lock_both(&mut locked, &first, &second);
            assert_eq!((*first_data, *second_data), (15, 30));
        }
        {
            let (second_data, first_data, _) = lock_both(&mut locked, &second, &first);
            assert_eq!((*first_data, *second_data), (15, 30));
        }
    }
}