starnix_sync/
lock_sequence.rs

1// Copyright 2023 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
//! Tools for describing and enforcing lock acquisition order.
//!
//! To use these tools:
//! 1. A lock level must be defined for each type of lock. This can be a simple enum.
//! 2. Then a relation `LockAfter` between these levels must be described,
//!    forming a graph. This graph must be acyclic, since a cycle would indicate
//!    a potential deadlock.
//! 3. Each time a lock is acquired, it must be done using an object of a `Locked<P>`
//!    type, where `P` is any lock level that comes before the level `L` that is
//!    associated with this lock. Doing so yields a new object of type `Locked<L>`
//!    that can be used to acquire subsequent locks.
//! 4. Each place where a lock is used must be marked with the maximum lock level
//!    that can already be held before attempting to acquire this lock. To do this,
//!    it takes a special marker object `Locked<P>`, where `P` is a lock level that
//!    must come before the level associated with this lock in the graph. This object
//!    is then used to acquire the lock, and a new object `Locked<L>` is returned, with
//!    a new lock level `L` that comes after `P` in the lock ordering graph.
//!
//! ## Example
//! See also tests for this crate.
//!
//! ```
//! use std::sync::Mutex;
//! use starnix_sync::{lock_ordering, lock::LockFor, relation::LockAfter, Unlocked};
//!
//! #[derive(Default)]
//! struct HoldsLocks {
//!    a: Mutex<u8>,
//!    b: Mutex<u32>,
//! }
//!
//! lock_ordering! {
//!    // LockA is the top of the lock hierarchy.
//!    Unlocked => LockA,
//!    // LockA can be acquired before LockB.
//!    LockA => LockB,
//! }
//!
//! impl LockFor<LockA> for HoldsLocks {
//!    type Data = u8;
//!    type Guard<'l> = std::sync::MutexGuard<'l, u8>
//!        where Self: 'l;
//!    fn lock(&self) -> Self::Guard<'_> {
//!        self.a.lock().unwrap()
//!    }
//! }
//!
//! impl LockFor<LockB> for HoldsLocks {
//!    type Data = u32;
//!    type Guard<'l> = std::sync::MutexGuard<'l, u32>
//!        where Self: 'l;
//!    fn lock(&self) -> Self::Guard<'_> {
//!        self.b.lock().unwrap()
//!    }
//! }
//!
//! // Accessing locked state looks like this:
//!
//! let state = HoldsLocks::default();
//! // Create a new lock session with the root lock level, `Unlocked`.
//! // SAFETY: no lock has been acquired yet.
//! let mut locked = unsafe { Unlocked::new() };
//! // Access locked state.
//! let (a, locked_a) = locked.lock_and::<LockA, _>(&state);
//! let b = locked_a.lock::<LockB, _>(&state);
//! ```
70//!
71//! The [lock_ordering] macro provides definitions for lock levels and
72//! implementations of [LockAfter] for all the locks that are connected
73//! in the graph (one can be locked after another). It also prevents
74//! accidental lock ordering inversion introduced while defining the graph
75//! by detecting cycles in it.
76//!
//! This won't compile:
//! ```compile_fail
//! # use starnix_sync::{lock_ordering, Unlocked};
//! lock_ordering!{
//!     Unlocked => A,
//!     A => B,
//!     B => A,
//! }
//! ```
85//!
86//! The methods on [Locked] prevent out-of-order locking according to the
87//! specified lock relationships.
88//!
//! This won't compile because `LockB` does not implement `LockBefore<LockA>`:
//! ```compile_fail
//! # use std::sync::Mutex;
//! # use starnix_sync::{lock_ordering, lock::LockFor, Locked, Unlocked};
//! #
//! # #[derive(Default)]
//! # struct HoldsLocks {
//! #    a: Mutex<u8>,
//! #    b: Mutex<u32>,
//! # }
//! #
//! # lock_ordering! {
//! #    // LockA is the top of the lock hierarchy.
//! #    Unlocked => LockA,
//! #    // LockA can be acquired before LockB.
//! #    LockA => LockB,
//! # }
//! #
//! # impl LockFor<LockA> for HoldsLocks {
//! #    type Data = u8;
//! #    type Guard<'l> = std::sync::MutexGuard<'l, u8>
//! #        where Self: 'l;
//! #    fn lock(&self) -> Self::Guard<'_> {
//! #        self.a.lock().unwrap()
//! #    }
//! # }
//! #
//! # impl LockFor<LockB> for HoldsLocks {
//! #     type Data = u32;
//! #     type Guard<'l> = std::sync::MutexGuard<'l, u32>
//! #         where Self: 'l;
//! #     fn lock(&self) -> Self::Guard<'_> {
//! #         self.b.lock().unwrap()
//! #     }
//! # }
//! #
//!
//! let state = HoldsLocks::default();
//! let mut locked = unsafe { Unlocked::new() };
//!
//! // Locking B without A is fine, but locking A after B is not.
//! let (b, locked_b) = locked.lock_and::<LockB, _>(&state);
//! // compile error: LockB does not implement LockBefore<LockA>
//! let a = locked_b.lock::<LockA, _>(&state);
//! ```
134//!
//! Even if the lock guard goes out of scope, the new `Locked` instance returned
//! by [Locked::lock_and] will prevent the original one from being used to
//! access state for as long as the new instance is still in use. This doesn't work:
//!
//! ```compile_fail
//! # use std::sync::Mutex;
//! # use starnix_sync::{lock_ordering, lock::LockFor, Locked, Unlocked};
//! #
//! # #[derive(Default)]
//! # struct HoldsLocks {
//! #     a: Mutex<u8>,
//! #     b: Mutex<u32>,
//! # }
//! #
//! # lock_ordering! {
//! #    // LockA is the top of the lock hierarchy.
//! #    Unlocked => LockA,
//! #    // LockA can be acquired before LockB.
//! #    LockA => LockB,
//! # }
//! #
//! # impl LockFor<LockA> for HoldsLocks {
//! #     type Data = u8;
//! #     type Guard<'l> = std::sync::MutexGuard<'l, u8>
//! #         where Self: 'l;
//! #     fn lock(&self) -> Self::Guard<'_> {
//! #         self.a.lock().unwrap()
//! #     }
//! # }
//! #
//! # impl LockFor<LockB> for HoldsLocks {
//! #     type Data = u32;
//! #     type Guard<'l> = std::sync::MutexGuard<'l, u32>
//! #         where Self: 'l;
//! #     fn lock(&self) -> Self::Guard<'_> {
//! #         self.b.lock().unwrap()
//! #     }
//! # }
//!
//! let state = HoldsLocks::default();
//! let mut locked = unsafe { Unlocked::new() };
//!
//! let (a, locked_a) = locked.lock_and::<LockA, _>(&state);
//! drop(a);
//! // Won't work; `locked` is still mutably borrowed by `locked_a`, which is
//! // used again on the next line.
//! let a = locked.lock::<LockA, _>(&state);
//! let b = locked_a.lock::<LockB, _>(&state);
//! ```
183
184use core::marker::PhantomData;
185use static_assertions::const_assert_eq;
186
187pub use crate::{LockBefore, LockEqualOrBefore, LockFor, RwLockFor};
188
/// Compile-time witness of the current lock level `L`.
///
/// Every lock acquisition goes through a `Locked` of a suitable level: calling
/// [`Locked::lock_and`] (or one of its siblings) yields the lock guard plus a
/// fresh `Locked` at the newly acquired level. That fresh context mutably
/// borrows the one it was derived from. As a result, the older context cannot
/// acquire further locks while the newer one is still in use.
///
/// `Locked` carries no data at runtime; `L` is tracked only through `PhantomData`.
pub struct Locked<L>(PhantomData<L>);
199
200/// "Highest" lock level
201///
202/// The lock level for the thing returned by `Locked::new`. Users of this crate
203/// should implement `LockAfter<Unlocked>` for the root of any lock ordering
204/// trees.
205pub enum Unlocked {}
206
207const_assert_eq!(std::mem::size_of::<Locked<Unlocked>>(), 0);
208
209impl Unlocked {
210    /// Entry point for locked access.
211    ///
212    /// `Unlocked` is the "root" lock level and can be acquired before any lock.
213    ///
214    /// # Safety
215    /// `Unlocked` should only be used before any lock in the program has been acquired.
216    #[inline(always)]
217    pub unsafe fn new() -> Locked<Unlocked> {
218        Locked::<Unlocked>(Default::default())
219    }
220}
221impl LockEqualOrBefore<Unlocked> for Unlocked {}
222
223impl<L> Locked<L> {
224    /// Acquire the given lock.
225    ///
226    /// This requires that `M` can be locked after `L`.
227    #[inline(always)]
228    pub fn lock<'a, M, S>(&'a mut self, source: &'a S) -> S::Guard<'a>
229    where
230        M: 'a,
231        S: LockFor<M>,
232        L: LockBefore<M>,
233    {
234        let (data, _) = self.lock_and::<M, S>(source);
235        data
236    }
237
238    /// Acquire the given lock and a new locked context.
239    ///
240    /// This requires that `M` can be locked after `L`.
241    #[inline(always)]
242    pub fn lock_and<'a, M, S>(&'a mut self, source: &'a S) -> (S::Guard<'a>, &'a mut Locked<M>)
243    where
244        M: 'a,
245        S: LockFor<M>,
246        L: LockBefore<M>,
247    {
248        let data = S::lock(source);
249        (data, Locked::fabricate())
250    }
251
252    /// Acquire two locks that are on the same level, in a consistent order (sorted by memory address) and return both guards
253    /// as well as the new locked context.
254    ///
255    /// This requires that `M` can be locked after `L`.
256    #[inline(always)]
257    pub fn lock_both_and<'a, M, S>(
258        &'a mut self,
259        source1: &'a S,
260        source2: &'a S,
261    ) -> (S::Guard<'a>, S::Guard<'a>, &mut Locked<M>)
262    where
263        M: 'a,
264        S: LockFor<M>,
265        L: LockBefore<M>,
266    {
267        let ptr1: *const S = source1;
268        let ptr2: *const S = source2;
269        if ptr1 < ptr2 {
270            let data1 = S::lock(source1);
271            let data2 = S::lock(source2);
272            (data1, data2, Locked::fabricate())
273        } else {
274            let data2 = S::lock(source2);
275            let data1 = S::lock(source1);
276            (data1, data2, Locked::fabricate())
277        }
278    }
279    /// Acquire two locks that are on the same level, in a consistent order (sorted by memory address) and return both guards.
280    ///
281    /// This requires that `M` can be locked after `L`.
282    #[inline(always)]
283    pub fn lock_both<'a, M, S>(
284        &'a mut self,
285        source1: &'a S,
286        source2: &'a S,
287    ) -> (S::Guard<'a>, S::Guard<'a>)
288    where
289        M: 'a,
290        S: LockFor<M>,
291        L: LockBefore<M>,
292    {
293        let (data1, data2, _) = self.lock_both_and(source1, source2);
294        (data1, data2)
295    }
296
297    /// Attempt to acquire the given read lock and a new locked context.
298    ///
299    /// For accessing state via reader/writer locks. This requires that `M` can
300    /// be locked after `L`.
301    #[inline(always)]
302    pub fn read_lock_and<'a, M, S>(
303        &'a mut self,
304        source: &'a S,
305    ) -> (S::ReadGuard<'a>, &mut Locked<M>)
306    where
307        M: 'a,
308        S: RwLockFor<M>,
309        L: LockBefore<M>,
310    {
311        let data = S::read_lock(source);
312        (data, Locked::fabricate())
313    }
314
315    /// Attempt to acquire the given read lock.
316    ///
317    /// For accessing state via reader/writer locks. This requires that `M` can
318    /// be locked after `L`.
319    #[inline(always)]
320    pub fn read_lock<'a, M, S>(&'a mut self, source: &'a S) -> S::ReadGuard<'a>
321    where
322        M: 'a,
323        S: RwLockFor<M>,
324        L: LockBefore<M>,
325    {
326        let (data, _) = self.read_lock_and::<M, S>(source);
327        data
328    }
329
330    /// Attempt to acquire the given write lock and a new locked context.
331    ///
332    /// For accessing state via reader/writer locks. This requires that `M` can
333    /// be locked after `L`.
334    #[inline(always)]
335    pub fn write_lock_and<'a, M, S>(
336        &'a mut self,
337        source: &'a S,
338    ) -> (S::WriteGuard<'a>, &mut Locked<M>)
339    where
340        M: 'a,
341        S: RwLockFor<M>,
342        L: LockBefore<M>,
343    {
344        let data = S::write_lock(source);
345        (data, Locked::fabricate())
346    }
347
348    /// Attempt to acquire the given write lock.
349    ///
350    /// For accessing state via reader/writer locks. This requires that `M` can
351    /// be locked after `L`.
352    #[inline(always)]
353    pub fn write_lock<'a, M, S>(&'a mut self, source: &'a S) -> S::WriteGuard<'a>
354    where
355        M: 'a,
356        S: RwLockFor<M>,
357        L: LockBefore<M>,
358    {
359        let (data, _) = self.write_lock_and::<M, S>(source);
360        data
361    }
362
363    /// Restrict locking as if a lock was acquired.
364    ///
365    /// Like `lock_and` but doesn't actually acquire the lock `M`. This is
366    /// safe because any locks that could be acquired with the lock `M` held can
367    /// also be acquired without `M` being held.
368    #[inline(always)]
369    pub fn cast_locked<M>(&mut self) -> &mut Locked<M>
370    where
371        L: LockEqualOrBefore<M>,
372    {
373        Locked::fabricate()
374    }
375
376    const CHECK_ZST: () = assert!(std::mem::size_of::<Self>() == 0, "Locked<T> must be a ZST");
377    fn fabricate<'a>() -> &'a mut Self {
378        let _ = Self::CHECK_ZST;
379        // SAFETY: As confirmed by the preceding assert, `Self`
380        // is a ZST. `NonNull::as_mut` requires that the pointer is convertible
381        // to a reference [1], which in turn requires the following [2]:
382        // - The pointer is properly aligned (guaranteed by `NonNull::dangling`)
383        // - Non-null (guaranteed by invariant on `NonNull`)
384        // - Dereferenceable (guaranteed for all zero-sized pointers [3])
385        // - Points to a valid referent (trivially true for any zero-sized referent)
386        // - Satisfies Rust's aliasing rules (trivially true for any zero-sized referent)
387        //
388        // [1] https://doc.rust-lang.org/1.87.0/std/ptr/struct.NonNull.html#method.as_mut
389        // [2] https://doc.rust-lang.org/1.87.0/std/ptr/index.html#pointer-to-reference-conversion
390        // [3] https://doc.rust-lang.org/1.87.0/std/ptr/index.html#safety
391        unsafe { std::ptr::NonNull::dangling().as_mut() }
392    }
393}
394
#[cfg(test)]
mod test {
    use std::sync::{Mutex, MutexGuard, RwLock, RwLockReadGuard, RwLockWriteGuard};

    #[test]
    fn example() {
        use crate::{lock_ordering, Unlocked};

        #[derive(Default)]
        pub struct HoldsLocks {
            a: Mutex<u8>,
            b: Mutex<u32>,
        }

        lock_ordering! {
            // LockA is the top of the lock hierarchy.
            Unlocked => LockA,
            // LockA can be acquired before LockB.
            LockA => LockB,
        }

        impl LockFor<LockA> for HoldsLocks {
            type Data = u8;
            type Guard<'l>
                = std::sync::MutexGuard<'l, u8>
            where
                Self: 'l;
            fn lock(&self) -> Self::Guard<'_> {
                self.a.lock().unwrap()
            }
        }

        impl LockFor<LockB> for HoldsLocks {
            type Data = u32;
            type Guard<'l>
                = std::sync::MutexGuard<'l, u32>
            where
                Self: 'l;
            fn lock(&self) -> Self::Guard<'_> {
                self.b.lock().unwrap()
            }
        }

        // Accessing locked state looks like this:

        let state = HoldsLocks::default();
        // Create a new lock session with the root lock level, `Unlocked`.
        let mut locked = unsafe { Unlocked::new() };
        // Access locked state.
        let (_a, locked_a) = locked.lock_and::<LockA, _>(&state);
        let _b = locked_a.lock::<LockB, _>(&state);
    }

    mod lock_levels {
        use crate::Unlocked;
        use lock_ordering_macro::lock_ordering;
        // Lock ordering tree:
        // A -> B -> {C, D, E -> F, G -> H}
        lock_ordering! {
            Unlocked => A,
            A => B,
            B => C,
            B => D,
            B => E,
            E => F,
            B => G,
            G => H,
        }
    }

    use crate::{LockFor, RwLockFor, Unlocked};
    use lock_levels::{A, B, C, D, E, F, G, H};

    /// Data type with multiple locked fields.
    #[derive(Default)]
    pub struct Data {
        a: Mutex<u8>,
        b: Mutex<u16>,
        c: Mutex<u64>,
        d: RwLock<u128>,
        e: Mutex<Mutex<u8>>,
        g: Mutex<Vec<Mutex<u8>>>,
        u: usize,
    }

    impl LockFor<A> for Data {
        type Data = u8;
        type Guard<'l> = MutexGuard<'l, u8>;
        fn lock(&self) -> Self::Guard<'_> {
            self.a.lock().unwrap()
        }
    }

    impl LockFor<B> for Data {
        type Data = u16;
        type Guard<'l> = MutexGuard<'l, u16>;
        fn lock(&self) -> Self::Guard<'_> {
            self.b.lock().unwrap()
        }
    }

    impl LockFor<C> for Data {
        type Data = u64;
        type Guard<'l> = MutexGuard<'l, u64>;
        fn lock(&self) -> Self::Guard<'_> {
            self.c.lock().unwrap()
        }
    }

    impl RwLockFor<D> for Data {
        type Data = u128;
        type ReadGuard<'l> = RwLockReadGuard<'l, u128>;
        type WriteGuard<'l> = RwLockWriteGuard<'l, u128>;
        fn read_lock(&self) -> Self::ReadGuard<'_> {
            self.d.read().unwrap()
        }
        fn write_lock(&self) -> Self::WriteGuard<'_> {
            self.d.write().unwrap()
        }
    }

    impl LockFor<E> for Data {
        type Data = Mutex<u8>;
        type Guard<'l> = MutexGuard<'l, Mutex<u8>>;
        fn lock(&self) -> Self::Guard<'_> {
            self.e.lock().unwrap()
        }
    }

    impl LockFor<F> for Mutex<u8> {
        type Data = u8;
        type Guard<'l> = MutexGuard<'l, u8>;
        fn lock(&self) -> Self::Guard<'_> {
            self.lock().unwrap()
        }
    }

    impl LockFor<G> for Data {
        type Data = Vec<Mutex<u8>>;
        type Guard<'l> = MutexGuard<'l, Vec<Mutex<u8>>>;
        fn lock(&self) -> Self::Guard<'_> {
            self.g.lock().unwrap()
        }
    }

    impl LockFor<H> for Mutex<u8> {
        type Data = u8;
        type Guard<'l> = MutexGuard<'l, u8>;
        fn lock(&self) -> Self::Guard<'_> {
            self.lock().unwrap()
        }
    }

    #[test]
    fn lock_a_then_c() {
        let data = Data::default();

        let mut w = unsafe { Unlocked::new() };
        let (_a, wa) = w.lock_and::<A, _>(&data);
        let (_c, _wc) = wa.lock_and::<C, _>(&data);
        // This won't compile!
        // let _b = _wc.lock::<B, _>(&data);
    }

    #[test]
    fn cast_a_then_c() {
        let data = Data::default();

        let mut w = unsafe { Unlocked::new() };
        let wa = w.cast_locked::<A>();
        let (_c, _wc) = wa.lock_and::<C, _>(&data);
        // This should not compile:
        // let _b = w.lock::<B, _>(&data);
    }

    #[test]
    fn unlocked_access_does_not_prevent_locking() {
        let data = Data { a: Mutex::new(15), u: 34, ..Data::default() };

        let mut locked = unsafe { Unlocked::new() };
        let u = &data.u;

        // Prove that `u` does not prevent locked state from being accessed.
        let a = locked.lock::<A, _>(&data);

        assert_eq!(u, &34);
        assert_eq!(&*a, &15);
    }

    #[test]
    fn nested_locks() {
        let data = Data { e: Mutex::new(Mutex::new(1)), ..Data::default() };

        let mut locked = unsafe { Unlocked::new() };
        let (e, next_locked) = locked.lock_and::<E, _>(&data);
        let v = next_locked.lock::<F, _>(&*e);
        assert_eq!(*v, 1);
    }

    #[test]
    fn rw_lock() {
        let data = Data { d: RwLock::new(1), ..Data::default() };

        let mut locked = unsafe { Unlocked::new() };
        {
            let mut d = locked.write_lock::<D, _>(&data);
            *d = 10;
        }
        let d = locked.read_lock::<D, _>(&data);
        assert_eq!(*d, 10);
    }

    #[test]
    fn rw_lock_after_parent() {
        let data = Data { b: Mutex::new(2), d: RwLock::new(3), ..Data::default() };

        let mut locked = unsafe { Unlocked::new() };
        let (b, locked_b) = locked.lock_and::<B, _>(&data);
        {
            let (mut d, _locked_d) = locked_b.write_lock_and::<D, _>(&data);
            *d += 1;
        }
        let (d, _locked_d) = locked_b.read_lock_and::<D, _>(&data);
        assert_eq!(*d, 4);
        assert_eq!(*b, 2);
    }

    #[test]
    fn collections() {
        let data = Data { g: Mutex::new(vec![Mutex::new(0), Mutex::new(1)]), ..Data::default() };

        let mut locked = unsafe { Unlocked::new() };
        let (g, next_locked) = locked.lock_and::<G, _>(&data);
        let v = next_locked.lock::<H, _>(&g[1]);
        assert_eq!(*v, 1);
    }

    #[test]
    fn lock_same_level() {
        let data1 = Data { a: Mutex::new(5), b: Mutex::new(15), ..Data::default() };
        let data2 = Data { a: Mutex::new(10), b: Mutex::new(20), ..Data::default() };
        let mut locked = unsafe { Unlocked::new() };
        {
            let (a1, a2, new_locked) = locked.lock_both_and::<A, _>(&data1, &data2);
            assert_eq!(*a1, 5);
            assert_eq!(*a2, 10);
            let (b1, b2) = new_locked.lock_both::<B, _>(&data1, &data2);
            assert_eq!(*b1, 15);
            assert_eq!(*b2, 20);
        }
        {
            let (a2, a1) = locked.lock_both::<A, _>(&data2, &data1);
            assert_eq!(*a1, 5);
            assert_eq!(*a2, 10);
        }
    }
}