starnix_sync/
locks.rs

1// Copyright 2022 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5// Use these crates so that we don't need to make the dependencies conditional.
6use {fuchsia_sync as _, lock_api as _, tracing_mutex as _};
7
8use crate::{LockAfter, LockBefore, LockFor, Locked, RwLockFor, UninterruptibleLock};
9use core::marker::PhantomData;
10use std::{any, fmt};
11
12#[cfg(not(debug_assertions))]
13pub type Mutex<T> = fuchsia_sync::Mutex<T>;
14#[cfg(not(debug_assertions))]
15pub type MutexGuard<'a, T> = fuchsia_sync::MutexGuard<'a, T>;
16#[allow(unused)]
17#[cfg(not(debug_assertions))]
18pub type MappedMutexGuard<'a, T> = fuchsia_sync::MappedMutexGuard<'a, T>;
19
20#[cfg(not(debug_assertions))]
21pub type RwLock<T> = fuchsia_sync::RwLock<T>;
22#[cfg(not(debug_assertions))]
23pub type RwLockReadGuard<'a, T> = fuchsia_sync::RwLockReadGuard<'a, T>;
24#[cfg(not(debug_assertions))]
25pub type RwLockWriteGuard<'a, T> = fuchsia_sync::RwLockWriteGuard<'a, T>;
26
27#[cfg(debug_assertions)]
28type RawTracingMutex = tracing_mutex::lockapi::TracingWrapper<fuchsia_sync::RawSyncMutex>;
29#[cfg(debug_assertions)]
30pub type Mutex<T> = lock_api::Mutex<RawTracingMutex, T>;
31#[cfg(debug_assertions)]
32pub type MutexGuard<'a, T> = lock_api::MutexGuard<'a, RawTracingMutex, T>;
33#[allow(unused)]
34#[cfg(debug_assertions)]
35pub type MappedMutexGuard<'a, T> = lock_api::MappedMutexGuard<'a, RawTracingMutex, T>;
36
37#[cfg(debug_assertions)]
38type RawTracingRwLock = tracing_mutex::lockapi::TracingWrapper<fuchsia_sync::RawSyncRwLock>;
39#[cfg(debug_assertions)]
40pub type RwLock<T> = lock_api::RwLock<RawTracingRwLock, T>;
41#[cfg(debug_assertions)]
42pub type RwLockReadGuard<'a, T> = lock_api::RwLockReadGuard<'a, RawTracingRwLock, T>;
43#[cfg(debug_assertions)]
44pub type RwLockWriteGuard<'a, T> = lock_api::RwLockWriteGuard<'a, RawTracingRwLock, T>;
45
46/// Lock `m1` and `m2` in a consistent order (using the memory address of m1 and m2 and returns the
47/// associated guard. This ensure that `ordered_lock(m1, m2)` and `ordered_lock(m2, m1)` will not
48/// deadlock.
49pub fn ordered_lock<'a, T>(
50    m1: &'a Mutex<T>,
51    m2: &'a Mutex<T>,
52) -> (MutexGuard<'a, T>, MutexGuard<'a, T>) {
53    let ptr1: *const Mutex<T> = m1;
54    let ptr2: *const Mutex<T> = m2;
55    if ptr1 < ptr2 {
56        let g1 = m1.lock();
57        let g2 = m2.lock();
58        (g1, g2)
59    } else {
60        let g2 = m2.lock();
61        let g1 = m1.lock();
62        (g1, g2)
63    }
64}
65
66/// Acquires multiple mutexes in a consistent order based on their memory addresses.
67/// This helps prevent deadlocks.
68pub fn ordered_lock_vec<'a, T>(mutexes: &[&'a Mutex<T>]) -> Vec<MutexGuard<'a, T>> {
69    // Create a vector of tuples containing the mutex and its original index.
70    let mut indexed_mutexes =
71        mutexes.into_iter().enumerate().map(|(i, m)| (i, *m)).collect::<Vec<_>>();
72
73    // Sort the indexed mutexes by their memory addresses.
74    indexed_mutexes.sort_by_key(|(_, m)| *m as *const Mutex<T>);
75
76    // Acquire the locks in the sorted order.
77    let mut guards = indexed_mutexes.into_iter().map(|(i, m)| (i, m.lock())).collect::<Vec<_>>();
78
79    // Reorder the guards to match the original order of the mutexes.
80    guards.sort_by_key(|(i, _)| *i);
81
82    guards.into_iter().map(|(_, g)| g).collect::<Vec<_>>()
83}
84
85/// A wrapper for mutex that requires a `Locked` context to acquire.
86/// This context must be of a level that precedes `L` in the lock ordering graph
87/// where `L` is a level associated with this mutex.
88pub struct OrderedMutex<T, L: LockAfter<UninterruptibleLock>> {
89    mutex: Mutex<T>,
90    _phantom: PhantomData<L>,
91}
92
93impl<T: Default, L: LockAfter<UninterruptibleLock>> Default for OrderedMutex<T, L> {
94    fn default() -> Self {
95        Self { mutex: Default::default(), _phantom: Default::default() }
96    }
97}
98
99impl<T: fmt::Debug, L: LockAfter<UninterruptibleLock>> fmt::Debug for OrderedMutex<T, L> {
100    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
101        write!(f, "OrderedMutex({:?}, {})", self.mutex, any::type_name::<L>())
102    }
103}
104
105impl<T, L: LockAfter<UninterruptibleLock>> LockFor<L> for OrderedMutex<T, L> {
106    type Data = T;
107    type Guard<'a>
108        = MutexGuard<'a, T>
109    where
110        T: 'a,
111        L: 'a;
112    fn lock(&self) -> Self::Guard<'_> {
113        self.mutex.lock()
114    }
115}
116
117impl<T, L: LockAfter<UninterruptibleLock>> OrderedMutex<T, L> {
118    pub const fn new(t: T) -> Self {
119        Self { mutex: Mutex::new(t), _phantom: PhantomData }
120    }
121
122    pub fn lock<'a, P>(&'a self, locked: &'a mut Locked<P>) -> <Self as LockFor<L>>::Guard<'a>
123    where
124        P: LockBefore<L>,
125    {
126        locked.lock(self)
127    }
128
129    pub fn lock_and<'a, P>(
130        &'a self,
131        locked: &'a mut Locked<P>,
132    ) -> (<Self as LockFor<L>>::Guard<'a>, &'a mut Locked<L>)
133    where
134        P: LockBefore<L>,
135    {
136        locked.lock_and(self)
137    }
138}
139
140/// Lock two OrderedMutex of the same level in the consistent order. Returns both
141/// guards and a new locked context.
142pub fn lock_both<'a, T, L: LockAfter<UninterruptibleLock>, P>(
143    locked: &'a mut Locked<P>,
144    m1: &'a OrderedMutex<T, L>,
145    m2: &'a OrderedMutex<T, L>,
146) -> (MutexGuard<'a, T>, MutexGuard<'a, T>, &'a mut Locked<L>)
147where
148    P: LockBefore<L>,
149{
150    locked.lock_both_and(m1, m2)
151}
152
153/// A wrapper for an RwLock that requires a `Locked` context to acquire.
154/// This context must be of a level that precedes `L` in the lock ordering graph
155/// where `L` is a level associated with this RwLock.
156pub struct OrderedRwLock<T, L: LockAfter<UninterruptibleLock>> {
157    rwlock: RwLock<T>,
158    _phantom: PhantomData<L>,
159}
160
161impl<T: Default, L: LockAfter<UninterruptibleLock>> Default for OrderedRwLock<T, L> {
162    fn default() -> Self {
163        Self { rwlock: Default::default(), _phantom: Default::default() }
164    }
165}
166
167impl<T: fmt::Debug, L: LockAfter<UninterruptibleLock>> fmt::Debug for OrderedRwLock<T, L> {
168    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
169        write!(f, "OrderedRwLock({:?}, {})", self.rwlock, any::type_name::<L>())
170    }
171}
172
173impl<T, L: LockAfter<UninterruptibleLock>> RwLockFor<L> for OrderedRwLock<T, L> {
174    type Data = T;
175    type ReadGuard<'a>
176        = RwLockReadGuard<'a, T>
177    where
178        T: 'a,
179        L: 'a;
180    type WriteGuard<'a>
181        = RwLockWriteGuard<'a, T>
182    where
183        T: 'a,
184        L: 'a;
185    fn read_lock(&self) -> Self::ReadGuard<'_> {
186        self.rwlock.read()
187    }
188    fn write_lock(&self) -> Self::WriteGuard<'_> {
189        self.rwlock.write()
190    }
191}
192
193impl<T, L: LockAfter<UninterruptibleLock>> OrderedRwLock<T, L> {
194    pub const fn new(t: T) -> Self {
195        Self { rwlock: RwLock::new(t), _phantom: PhantomData }
196    }
197
198    pub fn read<'a, P>(&'a self, locked: &'a mut Locked<P>) -> <Self as RwLockFor<L>>::ReadGuard<'a>
199    where
200        P: LockBefore<L>,
201    {
202        locked.read_lock(self)
203    }
204
205    pub fn write<'a, P>(
206        &'a self,
207        locked: &'a mut Locked<P>,
208    ) -> <Self as RwLockFor<L>>::WriteGuard<'a>
209    where
210        P: LockBefore<L>,
211    {
212        locked.write_lock(self)
213    }
214
215    pub fn read_and<'a, P>(
216        &'a self,
217        locked: &'a mut Locked<P>,
218    ) -> (<Self as RwLockFor<L>>::ReadGuard<'a>, &'a mut Locked<L>)
219    where
220        P: LockBefore<L>,
221    {
222        locked.read_lock_and(self)
223    }
224
225    pub fn write_and<'a, P>(
226        &'a self,
227        locked: &'a mut Locked<P>,
228    ) -> (<Self as RwLockFor<L>>::WriteGuard<'a>, &'a mut Locked<L>)
229    where
230        P: LockBefore<L>,
231    {
232        locked.write_lock_and(self)
233    }
234}
235
// Unit tests for the address-ordered locking helpers and the level-tagged lock wrappers.
#[cfg(test)]
mod test {
    use super::*;
    use crate::Unlocked;

    // `ordered_lock` must succeed for both argument orders, and the guards always come back in
    // argument order (the first guard belongs to the first argument).
    #[::fuchsia::test]
    fn test_lock_ordering() {
        let l1 = Mutex::new(1);
        let l2 = Mutex::new(2);

        {
            let (g1, g2) = ordered_lock(&l1, &l2);
            assert_eq!(*g1, 1);
            assert_eq!(*g2, 2);
        }
        {
            let (g2, g1) = ordered_lock(&l2, &l1);
            assert_eq!(*g1, 1);
            assert_eq!(*g2, 2);
        }
    }

    // `ordered_lock_vec` returns guards in input order regardless of the internal address-sorted
    // acquisition order. `l1` is declared first so declaration order differs from value order.
    #[::fuchsia::test]
    fn test_vec_lock_ordering() {
        let l1 = Mutex::new(1);
        let l0 = Mutex::new(0);
        let l2 = Mutex::new(2);

        {
            let guards = ordered_lock_vec(&[&l0, &l1, &l2]);
            assert_eq!(*guards[0], 0);
            assert_eq!(*guards[1], 1);
            assert_eq!(*guards[2], 2);
        }
        {
            let guards = ordered_lock_vec(&[&l2, &l1, &l0]);
            assert_eq!(*guards[0], 2);
            assert_eq!(*guards[1], 1);
            assert_eq!(*guards[2], 0);
        }
    }

    // Test-only lock levels declared with `lock_ordering!`, forming two chains rooted at
    // `Unlocked`. Each level also gets `LockAfter<UninterruptibleLock>`, which the
    // `OrderedMutex` / `OrderedRwLock` wrappers require of their level parameter.
    mod lock_levels {
        //! Lock ordering tree:
        //! Unlocked -> A -> B -> C
        //!          -> D -> E -> F
        use crate::{LockAfter, UninterruptibleLock, Unlocked};
        use lock_ordering_macro::lock_ordering;
        lock_ordering! {
            Unlocked => A,
            A => B,
            B => C,
            Unlocked => D,
            D => E,
            E => F,
        }

        impl LockAfter<UninterruptibleLock> for A {}
        impl LockAfter<UninterruptibleLock> for B {}
        impl LockAfter<UninterruptibleLock> for C {}
        impl LockAfter<UninterruptibleLock> for D {}
        impl LockAfter<UninterruptibleLock> for E {}
        impl LockAfter<UninterruptibleLock> for F {}
    }

    use lock_levels::{A, B, C, D, E, F};

    // Acquires A from the root context, then C from the context returned by `lock_and`.
    #[test]
    fn test_ordered_mutex() {
        let a: OrderedMutex<u8, A> = OrderedMutex::new(15);
        let _b: OrderedMutex<u16, B> = OrderedMutex::new(30);
        let c: OrderedMutex<u32, C> = OrderedMutex::new(45);

        // SAFETY: NOTE(review): presumably `Unlocked::new` requires that no other lock context
        // is live on this thread; this test creates the only one — confirm against its
        // `# Safety` docs.
        let mut locked = unsafe { Unlocked::new() };

        let (a_data, mut next_locked) = a.lock_and(&mut locked);
        let c_data = c.lock(&mut next_locked);

        // This won't compile
        // NOTE(review): both contexts are still mutably borrowed by the live guards `a_data` and
        // `c_data`, so acquiring B here is expected to be rejected at compile time.
        //let _b_data = _b.lock(&mut locked);
        //let _b_data = _b.lock(&mut next_locked);

        assert_eq!(&*a_data, &15);
        assert_eq!(&*c_data, &45);
    }
    // Exercises write-then-read and read-then-write through D -> F. Each scope drops its guards
    // so that `locked` can be borrowed again by the next scope.
    #[test]
    fn test_ordered_rwlock() {
        let d: OrderedRwLock<u8, D> = OrderedRwLock::new(15);
        let _e: OrderedRwLock<u16, E> = OrderedRwLock::new(30);
        let f: OrderedRwLock<u32, F> = OrderedRwLock::new(45);

        // SAFETY: NOTE(review): same assumption as in `test_ordered_mutex` — confirm.
        let mut locked = unsafe { Unlocked::new() };
        {
            let (d_data, mut next_locked) = d.write_and(&mut locked);
            let f_data = f.read(&mut next_locked);

            // This won't compile
            // NOTE(review): expected to be rejected because `locked` / `next_locked` are still
            // mutably borrowed by the live guards.
            //let _e_data = _e.read(&mut locked);
            //let _e_data = _e.read(&mut next_locked);

            assert_eq!(&*d_data, &15);
            assert_eq!(&*f_data, &45);
        }
        {
            let (d_data, mut next_locked) = d.read_and(&mut locked);
            let f_data = f.write(&mut next_locked);

            // This won't compile
            // NOTE(review): as above, the contexts are still borrowed by the live guards.
            //let _e_data = _e.write(&mut locked);
            //let _e_data = _e.write(&mut next_locked);

            assert_eq!(&*d_data, &15);
            assert_eq!(&*f_data, &45);
        }
    }

    // `lock_both` must succeed for both argument orders of two same-level mutexes and return the
    // guards in argument order.
    #[test]
    fn test_lock_both() {
        let a1: OrderedMutex<u8, A> = OrderedMutex::new(15);
        let a2: OrderedMutex<u8, A> = OrderedMutex::new(30);
        // SAFETY: NOTE(review): same assumption as in `test_ordered_mutex` — confirm.
        let mut locked = unsafe { Unlocked::new() };
        {
            let (a1_data, a2_data, _) = lock_both(&mut locked, &a1, &a2);
            assert_eq!(&*a1_data, &15);
            assert_eq!(&*a2_data, &30);
        }
        {
            let (a2_data, a1_data, _) = lock_both(&mut locked, &a2, &a1);
            assert_eq!(&*a1_data, &15);
            assert_eq!(&*a2_data, &30);
        }
    }
}