// starnix_sync/locks.rs
1// Copyright 2022 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5// Use these crates so that we don't need to make the dependencies conditional.
6use {fuchsia_sync as _, lock_api as _, tracing_mutex as _};
7
8use crate::{LockAfter, LockBefore, LockFor, Locked, RwLockFor, UninterruptibleLock};
9use core::marker::PhantomData;
10use lock_api::RawMutex;
11use std::{any, fmt};
12
#[cfg(not(detect_lock_cycles))]
pub mod internal {
    //! Production configuration: re-export the `fuchsia_sync` primitives
    //! directly, with no runtime lock-cycle tracking overhead.
    pub type Mutex<T> = fuchsia_sync::Mutex<T>;
    pub type MutexGuard<'a, T> = fuchsia_sync::MutexGuard<'a, T>;
    pub type MappedMutexGuard<'a, T> = fuchsia_sync::MappedMutexGuard<'a, T>;
    pub type RwLock<T> = fuchsia_sync::RwLock<T>;
    pub type RwLockReadGuard<'a, T> = fuchsia_sync::RwLockReadGuard<'a, T>;
    pub type RwLockWriteGuard<'a, T> = fuchsia_sync::RwLockWriteGuard<'a, T>;
}
22
#[cfg(detect_lock_cycles)]
pub mod internal {
    //! Cycle-detection configuration: wrap the raw `fuchsia_sync` primitives in
    //! `tracing_mutex::lockapi::TracingWrapper` so acquisitions are tracked at
    //! runtime and lock cycles can be detected (per the `detect_lock_cycles` cfg).
    type RawTracingMutex = tracing_mutex::lockapi::TracingWrapper<fuchsia_sync::RawSyncMutex>;
    pub type Mutex<T> = lock_api::Mutex<RawTracingMutex, T>;
    pub type MutexGuard<'a, T> = lock_api::MutexGuard<'a, RawTracingMutex, T>;
    pub type MappedMutexGuard<'a, T> = lock_api::MappedMutexGuard<'a, RawTracingMutex, T>;
    type RawTracingRwLock = tracing_mutex::lockapi::TracingWrapper<fuchsia_sync::RawSyncRwLock>;
    pub type RwLock<T> = lock_api::RwLock<RawTracingRwLock, T>;
    pub type RwLockReadGuard<'a, T> = lock_api::RwLockReadGuard<'a, RawTracingRwLock, T>;
    pub type RwLockWriteGuard<'a, T> = lock_api::RwLockWriteGuard<'a, RawTracingRwLock, T>;
}
34
35pub use internal::{
36    MappedMutexGuard, Mutex, MutexGuard, RwLock, RwLockReadGuard, RwLockWriteGuard,
37};
38
/// A trait for lock guards that can be temporarily unlocked asynchronously.
/// This is useful for performing async operations while holding a lock, without
/// causing deadlocks or holding the lock for an extended period.
#[async_trait::async_trait(?Send)]
pub trait AsyncUnlockable {
    /// Temporarily unlocks the guard `s`, executes the async function `f`, and then
    /// re-locks the guard.
    ///
    /// The lock is guaranteed to be re-acquired before this function returns.
    ///
    /// NOTE(review): while `f` runs the lock is released, so other holders may
    /// observe or mutate the protected data; callers must not rely on state read
    /// before this call still being valid afterwards.
    async fn unlocked_async<F, U>(s: &mut Self, f: F) -> U
    where
        F: AsyncFnOnce() -> U;
}
51
#[async_trait::async_trait(?Send)]
impl<'a, T> crate::AsyncUnlockable for MutexGuard<'a, T> {
    /// Unlocks the mutex behind the guard `s`, runs `f`, then re-locks it via
    /// the deferred closure below before returning.
    async fn unlocked_async<F, U>(s: &mut Self, f: F) -> U
    where
        F: AsyncFnOnce() -> U,
    {
        // SAFETY: The existence of the guard `s` proves this context currently
        // holds the mutex locked, so unlocking it here is sound.
        unsafe {
            Self::mutex(s).raw().unlock();
        }
        scopeguard::defer!(
            // SAFETY: The mutex was unlocked above, and the guard `s` — which
            // expects a locked mutex when it is eventually dropped — is still
            // alive, so the mutex must be re-locked before this scope exits.
            unsafe { Self::mutex(s).raw().lock() }
        );
        f().await
    }
}
69
70/// Lock `m1` and `m2` in a consistent order (using the memory address of m1 and m2 and returns the
71/// associated guard. This ensure that `ordered_lock(m1, m2)` and `ordered_lock(m2, m1)` will not
72/// deadlock.
73pub fn ordered_lock<'a, T>(
74    m1: &'a Mutex<T>,
75    m2: &'a Mutex<T>,
76) -> (MutexGuard<'a, T>, MutexGuard<'a, T>) {
77    let ptr1: *const Mutex<T> = m1;
78    let ptr2: *const Mutex<T> = m2;
79    if ptr1 < ptr2 {
80        let g1 = m1.lock();
81        let g2 = m2.lock();
82        (g1, g2)
83    } else {
84        let g2 = m2.lock();
85        let g1 = m1.lock();
86        (g1, g2)
87    }
88}
89
90/// Acquires multiple mutexes in a consistent order based on their memory addresses.
91/// This helps prevent deadlocks.
92pub fn ordered_lock_vec<'a, T>(mutexes: &[&'a Mutex<T>]) -> Vec<MutexGuard<'a, T>> {
93    // Create a vector of tuples containing the mutex and its original index.
94    let mut indexed_mutexes =
95        mutexes.into_iter().enumerate().map(|(i, m)| (i, *m)).collect::<Vec<_>>();
96
97    // Sort the indexed mutexes by their memory addresses.
98    indexed_mutexes.sort_by_key(|(_, m)| *m as *const Mutex<T>);
99
100    // Acquire the locks in the sorted order.
101    let mut guards = indexed_mutexes.into_iter().map(|(i, m)| (i, m.lock())).collect::<Vec<_>>();
102
103    // Reorder the guards to match the original order of the mutexes.
104    guards.sort_by_key(|(i, _)| *i);
105
106    guards.into_iter().map(|(_, g)| g).collect::<Vec<_>>()
107}
108
/// A wrapper for mutex that requires a `Locked` context to acquire.
/// This context must be of a level that precedes `L` in the lock ordering graph
/// where `L` is a level associated with this mutex.
pub struct OrderedMutex<T, L: LockAfter<UninterruptibleLock>> {
    // The wrapped mutex holding the protected data.
    mutex: Mutex<T>,
    // Zero-sized marker tying this mutex to its lock-ordering level `L`.
    _phantom: PhantomData<L>,
}
116
117impl<T: Default, L: LockAfter<UninterruptibleLock>> Default for OrderedMutex<T, L> {
118    fn default() -> Self {
119        Self { mutex: Default::default(), _phantom: Default::default() }
120    }
121}
122
123impl<T: fmt::Debug, L: LockAfter<UninterruptibleLock>> fmt::Debug for OrderedMutex<T, L> {
124    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
125        write!(f, "OrderedMutex({:?}, {})", self.mutex, any::type_name::<L>())
126    }
127}
128
impl<T, L: LockAfter<UninterruptibleLock>> LockFor<L> for OrderedMutex<T, L> {
    type Data = T;
    type Guard<'a>
        = MutexGuard<'a, T>
    where
        T: 'a,
        L: 'a;
    // Simply delegates to the inner mutex; the lock-ordering discipline is
    // enforced by the `Locked`-taking wrappers on `OrderedMutex`, not here.
    fn lock(&self) -> Self::Guard<'_> {
        self.mutex.lock()
    }
}
140
impl<T, L: LockAfter<UninterruptibleLock>> OrderedMutex<T, L> {
    /// Creates a new `OrderedMutex` protecting `t`.
    pub const fn new(t: T) -> Self {
        Self { mutex: Mutex::new(t), _phantom: PhantomData }
    }

    /// Acquires the mutex. Requires a `Locked` context at a level `P` that
    /// statically precedes `L` (`P: LockBefore<L>`), so ordering violations
    /// fail to compile.
    pub fn lock<'a, P>(&'a self, locked: &'a mut Locked<P>) -> <Self as LockFor<L>>::Guard<'a>
    where
        P: LockBefore<L>,
    {
        locked.lock(self)
    }

    /// Like `lock`, but also returns a new `Locked` context at level `L` that
    /// can be used to acquire further locks ordered after `L`.
    pub fn lock_and<'a, P>(
        &'a self,
        locked: &'a mut Locked<P>,
    ) -> (<Self as LockFor<L>>::Guard<'a>, &'a mut Locked<L>)
    where
        P: LockBefore<L>,
    {
        locked.lock_and(self)
    }
}
163
/// Lock two OrderedMutex of the same level in the consistent order. Returns both
/// guards and a new locked context.
///
/// The guards are returned in `(m1, m2)` argument order; the consistent
/// acquisition order is handled by `Locked::lock_both_and`.
pub fn lock_both<'a, T, L: LockAfter<UninterruptibleLock>, P>(
    locked: &'a mut Locked<P>,
    m1: &'a OrderedMutex<T, L>,
    m2: &'a OrderedMutex<T, L>,
) -> (MutexGuard<'a, T>, MutexGuard<'a, T>, &'a mut Locked<L>)
where
    P: LockBefore<L>,
{
    locked.lock_both_and(m1, m2)
}
176
/// A wrapper for an RwLock that requires a `Locked` context to acquire.
/// This context must be of a level that precedes `L` in the lock ordering graph
/// where `L` is a level associated with this RwLock.
pub struct OrderedRwLock<T, L: LockAfter<UninterruptibleLock>> {
    // The wrapped reader-writer lock holding the protected data.
    rwlock: RwLock<T>,
    // Zero-sized marker tying this lock to its lock-ordering level `L`.
    _phantom: PhantomData<L>,
}
184
185impl<T: Default, L: LockAfter<UninterruptibleLock>> Default for OrderedRwLock<T, L> {
186    fn default() -> Self {
187        Self { rwlock: Default::default(), _phantom: Default::default() }
188    }
189}
190
191impl<T: fmt::Debug, L: LockAfter<UninterruptibleLock>> fmt::Debug for OrderedRwLock<T, L> {
192    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
193        write!(f, "OrderedRwLock({:?}, {})", self.rwlock, any::type_name::<L>())
194    }
195}
196
impl<T, L: LockAfter<UninterruptibleLock>> RwLockFor<L> for OrderedRwLock<T, L> {
    type Data = T;
    type ReadGuard<'a>
        = RwLockReadGuard<'a, T>
    where
        T: 'a,
        L: 'a;
    type WriteGuard<'a>
        = RwLockWriteGuard<'a, T>
    where
        T: 'a,
        L: 'a;
    // Delegates to the inner lock; ordering enforcement lives in the
    // `Locked`-taking wrappers on `OrderedRwLock`, not here.
    fn read_lock(&self) -> Self::ReadGuard<'_> {
        self.rwlock.read()
    }
    fn write_lock(&self) -> Self::WriteGuard<'_> {
        self.rwlock.write()
    }
}
216
impl<T, L: LockAfter<UninterruptibleLock>> OrderedRwLock<T, L> {
    /// Creates a new `OrderedRwLock` protecting `t`.
    pub const fn new(t: T) -> Self {
        Self { rwlock: RwLock::new(t), _phantom: PhantomData }
    }

    /// Acquires the lock for reading. Requires a `Locked` context at a level `P`
    /// that statically precedes `L` (`P: LockBefore<L>`).
    pub fn read<'a, P>(&'a self, locked: &'a mut Locked<P>) -> <Self as RwLockFor<L>>::ReadGuard<'a>
    where
        P: LockBefore<L>,
    {
        locked.read_lock(self)
    }

    /// Acquires the lock for writing. Requires a `Locked` context at a level `P`
    /// that statically precedes `L`.
    pub fn write<'a, P>(
        &'a self,
        locked: &'a mut Locked<P>,
    ) -> <Self as RwLockFor<L>>::WriteGuard<'a>
    where
        P: LockBefore<L>,
    {
        locked.write_lock(self)
    }

    /// Like `read`, but also returns a new `Locked` context at level `L` for
    /// acquiring further locks ordered after `L`.
    pub fn read_and<'a, P>(
        &'a self,
        locked: &'a mut Locked<P>,
    ) -> (<Self as RwLockFor<L>>::ReadGuard<'a>, &'a mut Locked<L>)
    where
        P: LockBefore<L>,
    {
        locked.read_lock_and(self)
    }

    /// Like `write`, but also returns a new `Locked` context at level `L` for
    /// acquiring further locks ordered after `L`.
    pub fn write_and<'a, P>(
        &'a self,
        locked: &'a mut Locked<P>,
    ) -> (<Self as RwLockFor<L>>::WriteGuard<'a>, &'a mut Locked<L>)
    where
        P: LockBefore<L>,
    {
        locked.write_lock_and(self)
    }
}
259
#[cfg(test)]
mod test {
    use super::*;
    use crate::Unlocked;

    /// `ordered_lock` must return the guards in argument order, regardless of
    /// which mutex happens to have the lower address.
    #[::fuchsia::test]
    fn test_lock_ordering() {
        let l1 = Mutex::new(1);
        let l2 = Mutex::new(2);

        {
            let (g1, g2) = ordered_lock(&l1, &l2);
            assert_eq!(*g1, 1);
            assert_eq!(*g2, 2);
        }
        {
            let (g2, g1) = ordered_lock(&l2, &l1);
            assert_eq!(*g1, 1);
            assert_eq!(*g2, 2);
        }
    }

    /// `ordered_lock_vec` must return the guards in the same order as the
    /// input slice, for any input permutation.
    #[::fuchsia::test]
    fn test_vec_lock_ordering() {
        let l1 = Mutex::new(1);
        let l0 = Mutex::new(0);
        let l2 = Mutex::new(2);

        {
            let guards = ordered_lock_vec(&[&l0, &l1, &l2]);
            assert_eq!(*guards[0], 0);
            assert_eq!(*guards[1], 1);
            assert_eq!(*guards[2], 2);
        }
        {
            let guards = ordered_lock_vec(&[&l2, &l1, &l0]);
            assert_eq!(*guards[0], 2);
            assert_eq!(*guards[1], 1);
            assert_eq!(*guards[2], 0);
        }
    }

    mod lock_levels {
        //! Lock ordering tree:
        //! Unlocked -> A -> B -> C
        //!          -> D -> E -> F
        use crate::{LockAfter, UninterruptibleLock, Unlocked};
        use lock_ordering_macro::lock_ordering;
        lock_ordering! {
            Unlocked => A,
            A => B,
            B => C,
            Unlocked => D,
            D => E,
            E => F,
        }

        // Every level may be used with OrderedMutex/OrderedRwLock, whose `L`
        // parameter requires `LockAfter<UninterruptibleLock>`.
        impl LockAfter<UninterruptibleLock> for A {}
        impl LockAfter<UninterruptibleLock> for B {}
        impl LockAfter<UninterruptibleLock> for C {}
        impl LockAfter<UninterruptibleLock> for D {}
        impl LockAfter<UninterruptibleLock> for E {}
        impl LockAfter<UninterruptibleLock> for F {}
    }

    use lock_levels::{A, B, C, D, E, F};

    /// Exercises `OrderedMutex::lock`/`lock_and` along the A -> B -> C chain;
    /// the commented-out lines document what must NOT compile.
    #[test]
    fn test_ordered_mutex() {
        let a: OrderedMutex<u8, A> = OrderedMutex::new(15);
        let _b: OrderedMutex<u16, B> = OrderedMutex::new(30);
        let c: OrderedMutex<u32, C> = OrderedMutex::new(45);

        #[allow(
            clippy::undocumented_unsafe_blocks,
            reason = "Force documented unsafe blocks in Starnix"
        )]
        let locked = unsafe { Unlocked::new() };

        let (a_data, mut next_locked) = a.lock_and(locked);
        let c_data = c.lock(&mut next_locked);

        // This won't compile
        //let _b_data = _b.lock(locked);
        //let _b_data = _b.lock(&mut next_locked);

        assert_eq!(&*a_data, &15);
        assert_eq!(&*c_data, &45);
    }
    /// Exercises `OrderedRwLock` read/write combinations along D -> E -> F;
    /// the commented-out lines document what must NOT compile.
    #[test]
    fn test_ordered_rwlock() {
        let d: OrderedRwLock<u8, D> = OrderedRwLock::new(15);
        let _e: OrderedRwLock<u16, E> = OrderedRwLock::new(30);
        let f: OrderedRwLock<u32, F> = OrderedRwLock::new(45);

        #[allow(
            clippy::undocumented_unsafe_blocks,
            reason = "Force documented unsafe blocks in Starnix"
        )]
        let locked = unsafe { Unlocked::new() };
        {
            let (d_data, mut next_locked) = d.write_and(locked);
            let f_data = f.read(&mut next_locked);

            // This won't compile
            //let _e_data = _e.read(locked);
            //let _e_data = _e.read(&mut next_locked);

            assert_eq!(&*d_data, &15);
            assert_eq!(&*f_data, &45);
        }
        {
            let (d_data, mut next_locked) = d.read_and(locked);
            let f_data = f.write(&mut next_locked);

            // This won't compile
            //let _e_data = _e.write(locked);
            //let _e_data = _e.write(&mut next_locked);

            assert_eq!(&*d_data, &15);
            assert_eq!(&*f_data, &45);
        }
    }

    /// `lock_both` must return guards matching argument order for two mutexes
    /// at the same level, in either argument order.
    #[test]
    fn test_lock_both() {
        let a1: OrderedMutex<u8, A> = OrderedMutex::new(15);
        let a2: OrderedMutex<u8, A> = OrderedMutex::new(30);
        #[allow(
            clippy::undocumented_unsafe_blocks,
            reason = "Force documented unsafe blocks in Starnix"
        )]
        let locked = unsafe { Unlocked::new() };
        {
            let (a1_data, a2_data, _) = lock_both(locked, &a1, &a2);
            assert_eq!(&*a1_data, &15);
            assert_eq!(&*a2_data, &30);
        }
        {
            let (a2_data, a1_data, _) = lock_both(locked, &a2, &a1);
            assert_eq!(&*a1_data, &15);
            assert_eq!(&*a2_data, &30);
        }
    }
}