tracing_mutex/stdsync/tracing.rs

1use std::fmt;
2use std::ops::Deref;
3use std::ops::DerefMut;
4use std::sync;
5use std::sync::LockResult;
6use std::sync::OnceState;
7use std::sync::PoisonError;
8use std::sync::TryLockError;
9use std::sync::TryLockResult;
10use std::sync::WaitTimeoutResult;
11use std::time::Duration;
12
13use crate::BorrowedMutex;
14use crate::LazyMutexId;
15use crate::util::PrivateTraced;
16
17#[cfg(has_std__sync__LazyLock)]
18pub use lazy_lock::LazyLock;
19
20#[cfg(has_std__sync__LazyLock)]
21mod lazy_lock;
22
23/// Wrapper for [`std::sync::Mutex`].
24///
25/// Refer to the [crate-level][`crate`] documentation for the differences between this struct and
26/// the one it wraps.
27#[derive(Debug, Default)]
28pub struct Mutex<T> {
29    inner: sync::Mutex<T>,
30    id: LazyMutexId,
31}
32
33/// Wrapper for [`std::sync::MutexGuard`].
34///
35/// Refer to the [crate-level][`crate`] documentation for the differences between this struct and
36/// the one it wraps.
37#[derive(Debug)]
38pub struct MutexGuard<'a, T> {
39    inner: sync::MutexGuard<'a, T>,
40    _mutex: BorrowedMutex<'a>,
41}
42
/// Applies `mapper` to the value carried by a [`LockResult`] while preserving its poison state.
///
/// A poisoned input produces a poisoned output wrapping the mapped value, so callers still get
/// the data exactly as the standard library would hand it out.
fn map_lockresult<T, I, F>(result: LockResult<I>, mapper: F) -> LockResult<T>
where
    F: FnOnce(I) -> T,
{
    // Pull the payload out regardless of poisoning, remembering which case we were in.
    let (poisoned, payload) = match result {
        Ok(payload) => (false, payload),
        Err(err) => (true, err.into_inner()),
    };

    let mapped = mapper(payload);
    if poisoned {
        Err(PoisonError::new(mapped))
    } else {
        Ok(mapped)
    }
}
52
/// Applies `mapper` to the value carried by a [`TryLockResult`].
///
/// `WouldBlock` passes through untouched because there is no value to map. A poisoned result is
/// mapped and re-wrapped as poisoned, and a successful result is mapped and returned as `Ok`.
fn map_trylockresult<T, I, F>(result: TryLockResult<I>, mapper: F) -> TryLockResult<T>
where
    F: FnOnce(I) -> T,
{
    let (poisoned, payload) = match result {
        Ok(payload) => (false, payload),
        Err(TryLockError::Poisoned(err)) => (true, err.into_inner()),
        // Nothing was acquired, so `mapper` is never invoked.
        Err(TryLockError::WouldBlock) => return Err(TryLockError::WouldBlock),
    };

    let mapped = mapper(payload);
    if poisoned {
        Err(TryLockError::Poisoned(PoisonError::new(mapped)))
    } else {
        Ok(mapped)
    }
}
65
66impl<T> Mutex<T> {
67    /// Create a new tracing mutex with the provided value.
68    pub const fn new(t: T) -> Self {
69        Self {
70            inner: sync::Mutex::new(t),
71            id: LazyMutexId::new(),
72        }
73    }
74
75    /// Wrapper for [`std::sync::Mutex::lock`].
76    ///
77    /// # Panics
78    ///
79    /// This method participates in lock dependency tracking. If acquiring this lock introduces a
80    /// dependency cycle, this method will panic.
81    #[track_caller]
82    pub fn lock(&self) -> LockResult<MutexGuard<T>> {
83        let mutex = self.id.get_borrowed();
84        let result = self.inner.lock();
85
86        let mapper = |guard| MutexGuard {
87            _mutex: mutex,
88            inner: guard,
89        };
90
91        map_lockresult(result, mapper)
92    }
93
94    /// Wrapper for [`std::sync::Mutex::try_lock`].
95    ///
96    /// # Panics
97    ///
98    /// This method participates in lock dependency tracking. If acquiring this lock introduces a
99    /// dependency cycle, this method will panic.
100    #[track_caller]
101    pub fn try_lock(&self) -> TryLockResult<MutexGuard<T>> {
102        let mutex = self.id.get_borrowed();
103        let result = self.inner.try_lock();
104
105        let mapper = |guard| MutexGuard {
106            _mutex: mutex,
107            inner: guard,
108        };
109
110        map_trylockresult(result, mapper)
111    }
112
113    /// Wrapper for [`std::sync::Mutex::is_poisoned`].
114    pub fn is_poisoned(&self) -> bool {
115        self.inner.is_poisoned()
116    }
117
118    /// Return a mutable reference to the underlying data.
119    ///
120    /// This method does not block as the locking is handled compile-time by the type system.
121    pub fn get_mut(&mut self) -> LockResult<&mut T> {
122        self.inner.get_mut()
123    }
124
125    /// Unwrap the mutex and return its inner value.
126    pub fn into_inner(self) -> LockResult<T> {
127        self.inner.into_inner()
128    }
129}
130
131impl<T> PrivateTraced for Mutex<T> {
132    fn get_id(&self) -> &crate::MutexId {
133        &self.id
134    }
135}
136
137impl<T> From<T> for Mutex<T> {
138    fn from(t: T) -> Self {
139        Self::new(t)
140    }
141}
142
143impl<T> Deref for MutexGuard<'_, T> {
144    type Target = T;
145
146    fn deref(&self) -> &Self::Target {
147        &self.inner
148    }
149}
150
151impl<T> DerefMut for MutexGuard<'_, T> {
152    fn deref_mut(&mut self) -> &mut Self::Target {
153        &mut self.inner
154    }
155}
156
157impl<T: fmt::Display> fmt::Display for MutexGuard<'_, T> {
158    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
159        self.inner.fmt(f)
160    }
161}
162
/// Wrapper around [`std::sync::Condvar`].
///
/// Allows this module's [`MutexGuard`] to be used with a condition variable. Unlike other structs
/// in this module, this wrapper does not add any additional dependency tracking or other overhead
/// on top of the primitive it wraps. All dependency tracking happens through the mutexes
/// themselves.
///
/// While waiting, the standard-library condvar releases the underlying lock. The guard's
/// tracking handle, however, is carried through the wait and re-attached to the guard that is
/// returned.
///
/// # Panics
///
/// This struct does not add any panics over the base implementation of `Condvar`, but panics due to
/// dependency tracking may poison associated mutexes.
///
/// # Examples
///
/// ```
/// use std::sync::Arc;
/// use std::thread;
///
/// use tracing_mutex::stdsync::tracing::{Condvar, Mutex};
///
/// let pair = Arc::new((Mutex::new(false), Condvar::new()));
/// let pair2 = Arc::clone(&pair);
///
/// // Spawn a thread that sets the flag and wakes up the waiting thread
/// thread::spawn(move || {
///     let (lock, condvar) = &*pair2;
///     *lock.lock().unwrap() = true;
///     condvar.notify_one();
/// });
///
/// // Wait until the other thread has set the flag
/// let (lock, condvar) = &*pair;
/// let guard = lock.lock().unwrap();
/// let guard = condvar.wait_while(guard, |started| !*started).unwrap();
///
/// // Guard should read true now
/// assert!(*guard);
/// ```
#[derive(Debug, Default)]
pub struct Condvar(sync::Condvar);
202
203impl Condvar {
204    /// Creates a new condition variable which is ready to be waited on and notified.
205    pub const fn new() -> Self {
206        Self(sync::Condvar::new())
207    }
208
209    /// Wrapper for [`std::sync::Condvar::wait`].
210    pub fn wait<'a, T>(&self, guard: MutexGuard<'a, T>) -> LockResult<MutexGuard<'a, T>> {
211        let MutexGuard { _mutex, inner } = guard;
212
213        map_lockresult(self.0.wait(inner), |inner| MutexGuard { _mutex, inner })
214    }
215
216    /// Wrapper for [`std::sync::Condvar::wait_while`].
217    pub fn wait_while<'a, T, F>(
218        &self,
219        guard: MutexGuard<'a, T>,
220        condition: F,
221    ) -> LockResult<MutexGuard<'a, T>>
222    where
223        F: FnMut(&mut T) -> bool,
224    {
225        let MutexGuard { _mutex, inner } = guard;
226
227        map_lockresult(self.0.wait_while(inner, condition), |inner| MutexGuard {
228            _mutex,
229            inner,
230        })
231    }
232
233    /// Wrapper for [`std::sync::Condvar::wait_timeout`].
234    pub fn wait_timeout<'a, T>(
235        &self,
236        guard: MutexGuard<'a, T>,
237        dur: Duration,
238    ) -> LockResult<(MutexGuard<'a, T>, WaitTimeoutResult)> {
239        let MutexGuard { _mutex, inner } = guard;
240
241        map_lockresult(self.0.wait_timeout(inner, dur), |(inner, result)| {
242            (MutexGuard { _mutex, inner }, result)
243        })
244    }
245
246    /// Wrapper for [`std::sync::Condvar::wait_timeout_while`].
247    pub fn wait_timeout_while<'a, T, F>(
248        &self,
249        guard: MutexGuard<'a, T>,
250        dur: Duration,
251        condition: F,
252    ) -> LockResult<(MutexGuard<'a, T>, WaitTimeoutResult)>
253    where
254        F: FnMut(&mut T) -> bool,
255    {
256        let MutexGuard { _mutex, inner } = guard;
257
258        map_lockresult(
259            self.0.wait_timeout_while(inner, dur, condition),
260            |(inner, result)| (MutexGuard { _mutex, inner }, result),
261        )
262    }
263
264    /// Wrapper for [`std::sync::Condvar::notify_one`].
265    pub fn notify_one(&self) {
266        self.0.notify_one();
267    }
268
269    /// Wrapper for [`std::sync::Condvar::notify_all`].
270    pub fn notify_all(&self) {
271        self.0.notify_all();
272    }
273}
274
275/// Wrapper for [`std::sync::RwLock`].
276#[derive(Debug, Default)]
277pub struct RwLock<T> {
278    inner: sync::RwLock<T>,
279    id: LazyMutexId,
280}
281
282/// Hybrid wrapper for both [`std::sync::RwLockReadGuard`] and [`std::sync::RwLockWriteGuard`].
283///
284/// Please refer to [`RwLockReadGuard`] and [`RwLockWriteGuard`] for usable types.
285#[derive(Debug)]
286pub struct TracingRwLockGuard<'a, L> {
287    inner: L,
288    _mutex: BorrowedMutex<'a>,
289}
290
291/// Wrapper around [`std::sync::RwLockReadGuard`].
292pub type RwLockReadGuard<'a, T> = TracingRwLockGuard<'a, sync::RwLockReadGuard<'a, T>>;
293/// Wrapper around [`std::sync::RwLockWriteGuard`].
294pub type RwLockWriteGuard<'a, T> = TracingRwLockGuard<'a, sync::RwLockWriteGuard<'a, T>>;
295
296impl<T> RwLock<T> {
297    pub const fn new(t: T) -> Self {
298        Self {
299            inner: sync::RwLock::new(t),
300            id: LazyMutexId::new(),
301        }
302    }
303
304    /// Wrapper for [`std::sync::RwLock::read`].
305    ///
306    /// # Panics
307    ///
308    /// This method participates in lock dependency tracking. If acquiring this lock introduces a
309    /// dependency cycle, this method will panic.
310    #[track_caller]
311    pub fn read(&self) -> LockResult<RwLockReadGuard<T>> {
312        let mutex = self.id.get_borrowed();
313        let result = self.inner.read();
314
315        map_lockresult(result, |inner| TracingRwLockGuard {
316            inner,
317            _mutex: mutex,
318        })
319    }
320
321    /// Wrapper for [`std::sync::RwLock::write`].
322    ///
323    /// # Panics
324    ///
325    /// This method participates in lock dependency tracking. If acquiring this lock introduces a
326    /// dependency cycle, this method will panic.
327    #[track_caller]
328    pub fn write(&self) -> LockResult<RwLockWriteGuard<T>> {
329        let mutex = self.id.get_borrowed();
330        let result = self.inner.write();
331
332        map_lockresult(result, |inner| TracingRwLockGuard {
333            inner,
334            _mutex: mutex,
335        })
336    }
337
338    /// Wrapper for [`std::sync::RwLock::try_read`].
339    ///
340    /// # Panics
341    ///
342    /// This method participates in lock dependency tracking. If acquiring this lock introduces a
343    /// dependency cycle, this method will panic.
344    #[track_caller]
345    pub fn try_read(&self) -> TryLockResult<RwLockReadGuard<T>> {
346        let mutex = self.id.get_borrowed();
347        let result = self.inner.try_read();
348
349        map_trylockresult(result, |inner| TracingRwLockGuard {
350            inner,
351            _mutex: mutex,
352        })
353    }
354
355    /// Wrapper for [`std::sync::RwLock::try_write`].
356    ///
357    /// # Panics
358    ///
359    /// This method participates in lock dependency tracking. If acquiring this lock introduces a
360    /// dependency cycle, this method will panic.
361    #[track_caller]
362    pub fn try_write(&self) -> TryLockResult<RwLockWriteGuard<T>> {
363        let mutex = self.id.get_borrowed();
364        let result = self.inner.try_write();
365
366        map_trylockresult(result, |inner| TracingRwLockGuard {
367            inner,
368            _mutex: mutex,
369        })
370    }
371
372    /// Return a mutable reference to the underlying data.
373    ///
374    /// This method does not block as the locking is handled compile-time by the type system.
375    pub fn get_mut(&mut self) -> LockResult<&mut T> {
376        self.inner.get_mut()
377    }
378
379    /// Unwrap the mutex and return its inner value.
380    pub fn into_inner(self) -> LockResult<T> {
381        self.inner.into_inner()
382    }
383}
384
385impl<T> PrivateTraced for RwLock<T> {
386    fn get_id(&self) -> &crate::MutexId {
387        &self.id
388    }
389}
390
391impl<T> From<T> for RwLock<T> {
392    fn from(t: T) -> Self {
393        Self::new(t)
394    }
395}
396
397impl<L, T> Deref for TracingRwLockGuard<'_, L>
398where
399    L: Deref<Target = T>,
400{
401    type Target = T;
402
403    fn deref(&self) -> &Self::Target {
404        self.inner.deref()
405    }
406}
407
408impl<T, L> DerefMut for TracingRwLockGuard<'_, L>
409where
410    L: Deref<Target = T> + DerefMut,
411{
412    fn deref_mut(&mut self) -> &mut Self::Target {
413        self.inner.deref_mut()
414    }
415}
416
417/// Wrapper around [`std::sync::Once`].
418///
419/// Refer to the [crate-level][`crate`] documentaiton for the differences between this struct
420/// and the one it wraps.
421#[derive(Debug)]
422pub struct Once {
423    inner: sync::Once,
424    mutex_id: LazyMutexId,
425}
426
427// New without default is intentional, `std::sync::Once` doesn't implement it either
428#[allow(clippy::new_without_default)]
429impl Once {
430    /// Create a new `Once` value.
431    pub const fn new() -> Self {
432        Self {
433            inner: sync::Once::new(),
434            mutex_id: LazyMutexId::new(),
435        }
436    }
437
438    /// Wrapper for [`std::sync::Once::call_once`].
439    ///
440    /// # Panics
441    ///
442    /// In addition to the panics that `Once` can cause, this method will panic if calling it
443    /// introduces a cycle in the lock dependency graph.
444    pub fn call_once<F>(&self, f: F)
445    where
446        F: FnOnce(),
447    {
448        self.mutex_id.with_held(|| self.inner.call_once(f))
449    }
450
451    /// Performs the same operation as [`call_once`][Once::call_once] except it ignores
452    /// poisoning.
453    ///
454    /// # Panics
455    ///
456    /// This method participates in lock dependency tracking. If acquiring this lock introduces a
457    /// dependency cycle, this method will panic.
458    pub fn call_once_force<F>(&self, f: F)
459    where
460        F: FnOnce(&OnceState),
461    {
462        self.mutex_id.with_held(|| self.inner.call_once_force(f))
463    }
464
465    /// Returns true if some `call_once` has completed successfully.
466    pub fn is_completed(&self) -> bool {
467        self.inner.is_completed()
468    }
469}
470
471impl PrivateTraced for Once {
472    fn get_id(&self) -> &crate::MutexId {
473        &self.mutex_id
474    }
475}
476
477/// Wrapper for [`std::sync::OnceLock`]
478///
479/// The exact locking behaviour of [`std::sync::OnceLock`] is currently undefined, but may
480/// deadlock in the event of reentrant initialization attempts. This wrapper participates in
481/// cycle detection as normal and will therefore panic in the event of reentrancy.
482///
483/// Most of this primitive's methods do not involve locking and as such are simply passed
484/// through to the inner implementation.
485///
486/// # Examples
487///
488/// ```
489/// use tracing_mutex::stdsync::tracing::OnceLock;
490///
491/// static LOCK: OnceLock<i32> = OnceLock::new();
492/// assert!(LOCK.get().is_none());
493///
494/// std::thread::spawn(|| {
495///    let value: &i32 = LOCK.get_or_init(|| 42);
496///    assert_eq!(value, &42);
497/// }).join().unwrap();
498///
499/// let value: Option<&i32> = LOCK.get();
500/// assert_eq!(value, Some(&42));
501/// ```
502#[derive(Debug)]
503pub struct OnceLock<T> {
504    id: LazyMutexId,
505    inner: sync::OnceLock<T>,
506}
507
508// N.B. this impl inlines everything that directly calls the inner implementation as there
509// should be 0 overhead to doing so.
510impl<T> OnceLock<T> {
511    /// Creates a new empty cell
512    pub const fn new() -> Self {
513        Self {
514            id: LazyMutexId::new(),
515            inner: sync::OnceLock::new(),
516        }
517    }
518
519    /// Gets a reference to the underlying value.
520    ///
521    /// This method does not attempt to lock and therefore does not participate in cycle
522    /// detection.
523    #[inline]
524    pub fn get(&self) -> Option<&T> {
525        self.inner.get()
526    }
527
528    /// Gets a mutable reference to the underlying value.
529    ///
530    /// This method does not attempt to lock and therefore does not participate in cycle
531    /// detection.
532    #[inline]
533    pub fn get_mut(&mut self) -> Option<&mut T> {
534        self.inner.get_mut()
535    }
536
537    /// Sets the contents of this cell to the underlying value
538    ///
539    /// As this method may block until initialization is complete, it participates in cycle
540    /// detection.
541    pub fn set(&self, value: T) -> Result<(), T> {
542        self.id.with_held(|| self.inner.set(value))
543    }
544
545    /// Gets the contents of the cell, initializing it with `f` if the cell was empty.
546    ///
547    /// This method participates in cycle detection. Reentrancy is considered a cycle.
548    pub fn get_or_init<F>(&self, f: F) -> &T
549    where
550        F: FnOnce() -> T,
551    {
552        self.id.with_held(|| self.inner.get_or_init(f))
553    }
554
555    /// Takes the value out of this `OnceLock`, moving it back to an uninitialized state.
556    ///
557    /// This method does not attempt to lock and therefore does not participate in cycle
558    /// detection.
559    #[inline]
560    pub fn take(&mut self) -> Option<T> {
561        self.inner.take()
562    }
563
564    /// Consumes the `OnceLock`, returning the wrapped value. Returns None if the cell was
565    /// empty.
566    ///
567    /// This method does not attempt to lock and therefore does not participate in cycle
568    /// detection.
569    #[inline]
570    pub fn into_inner(mut self) -> Option<T> {
571        self.take()
572    }
573}
574
575impl<T> PrivateTraced for OnceLock<T> {
576    fn get_id(&self) -> &crate::MutexId {
577        &self.id
578    }
579}
580
581impl<T> Default for OnceLock<T> {
582    #[inline]
583    fn default() -> Self {
584        Self::new()
585    }
586}
587
588impl<T: PartialEq> PartialEq for OnceLock<T> {
589    #[inline]
590    fn eq(&self, other: &Self) -> bool {
591        self.inner == other.inner
592    }
593}
594
595impl<T: Eq> Eq for OnceLock<T> {}
596
597impl<T: Clone> Clone for OnceLock<T> {
598    fn clone(&self) -> Self {
599        Self {
600            id: LazyMutexId::new(),
601            inner: self.inner.clone(),
602        }
603    }
604}
605
606impl<T> From<T> for OnceLock<T> {
607    #[inline]
608    fn from(value: T) -> Self {
609        Self {
610            id: LazyMutexId::new(),
611            inner: sync::OnceLock::from(value),
612        }
613    }
614}
615
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::thread;

    use super::*;

    #[test]
    fn test_mutex_usage() {
        let mutex = Arc::new(Mutex::new(0));

        assert_eq!(*mutex.lock().unwrap(), 0);
        *mutex.lock().unwrap() = 1;
        assert_eq!(*mutex.lock().unwrap(), 1);

        let mutex_clone = mutex.clone();

        let _guard = mutex.lock().unwrap();

        // Now try to cause a blocking exception in another thread
        let handle = thread::spawn(move || {
            let result = mutex_clone.try_lock().unwrap_err();

            assert!(matches!(result, TryLockError::WouldBlock));
        });

        handle.join().unwrap();
    }

    #[test]
    fn test_rwlock_usage() {
        let rwlock = Arc::new(RwLock::new(0));

        assert_eq!(*rwlock.read().unwrap(), 0);
        assert_eq!(*rwlock.write().unwrap(), 0);
        *rwlock.write().unwrap() = 1;
        assert_eq!(*rwlock.read().unwrap(), 1);
        assert_eq!(*rwlock.write().unwrap(), 1);

        let rwlock_clone = rwlock.clone();

        let _read_lock = rwlock.read().unwrap();

        // Now try to cause a blocking exception in another thread
        let handle = thread::spawn(move || {
            let write_result = rwlock_clone.try_write().unwrap_err();

            assert!(matches!(write_result, TryLockError::WouldBlock));

            // Should be able to get a read lock just fine.
            let _read_lock = rwlock_clone.read().unwrap();
        });

        handle.join().unwrap();
    }

    #[test]
    fn test_once_usage() {
        let once = Arc::new(Once::new());
        let once_clone = once.clone();

        assert!(!once.is_completed());

        let handle = thread::spawn(move || {
            assert!(!once_clone.is_completed());

            once_clone.call_once(|| {});

            assert!(once_clone.is_completed());
        });

        handle.join().unwrap();

        assert!(once.is_completed());
    }

    #[test]
    fn test_condvar_usage() {
        let pair = Arc::new((Mutex::new(false), Condvar::new()));
        let pair2 = Arc::clone(&pair);

        // Another thread sets the flag and wakes the waiter.
        let handle = thread::spawn(move || {
            let (lock, condvar) = &*pair2;
            *lock.lock().unwrap() = true;
            condvar.notify_all();
        });

        let (lock, condvar) = &*pair;
        let guard = lock.lock().unwrap();
        let guard = condvar.wait_while(guard, |ready| !*ready).unwrap();

        assert!(*guard);
        drop(guard);

        handle.join().unwrap();
    }

    #[test]
    fn test_condvar_timeout() {
        let lock = Mutex::new(7);
        let condvar = Condvar::new();

        let guard = lock.lock().unwrap();
        // The condition never becomes false, so this must run until the timeout; using the
        // `_while` variant makes the result immune to spurious wakeups.
        let (guard, result) = condvar
            .wait_timeout_while(guard, Duration::from_millis(1), |_| true)
            .unwrap();

        assert!(result.timed_out());
        assert_eq!(*guard, 7);
    }

    #[test]
    fn test_oncelock_usage() {
        let mut cell = OnceLock::new();
        assert!(cell.get().is_none());

        assert_eq!(*cell.get_or_init(|| 1), 1);
        // Already initialized: the new value is handed back.
        assert_eq!(cell.set(2), Err(2));
        assert_eq!(cell.get(), Some(&1));

        assert_eq!(cell.take(), Some(1));
        assert!(cell.get().is_none());
    }

    #[test]
    #[should_panic(expected = "Found cycle in mutex dependency graph")]
    fn test_detect_cycle() {
        let a = Mutex::new(());
        let b = Mutex::new(());

        let hold_a = a.lock().unwrap();
        let _ = b.lock();

        drop(hold_a);

        let _hold_b = b.lock().unwrap();
        let _ = a.lock();
    }
}