1use std::fmt;
2use std::ops::Deref;
3use std::ops::DerefMut;
4use std::sync;
5use std::sync::LockResult;
6use std::sync::OnceState;
7use std::sync::PoisonError;
8use std::sync::TryLockError;
9use std::sync::TryLockResult;
10use std::sync::WaitTimeoutResult;
11use std::time::Duration;
12
13use crate::BorrowedMutex;
14use crate::LazyMutexId;
15use crate::util::PrivateTraced;
16
17#[cfg(has_std__sync__LazyLock)]
18pub use lazy_lock::LazyLock;
19
20#[cfg(has_std__sync__LazyLock)]
21mod lazy_lock;
22
/// Wrapper around [`std::sync::Mutex`] whose acquisitions are reported to this crate's
/// lock-order tracker.
///
/// Each instance carries a [`LazyMutexId`] identifying it to the tracker; otherwise the API
/// mirrors the std type. Acquiring mutexes in an order that forms a cycle with previously
/// observed orders panics with "Found cycle in mutex dependency graph".
#[derive(Debug, Default)]
pub struct Mutex<T> {
    // The std mutex that provides the actual mutual exclusion.
    inner: sync::Mutex<T>,
    // Identity of this mutex in the dependency tracker.
    id: LazyMutexId,
}
32
/// Guard returned by [`Mutex::lock`] and [`Mutex::try_lock`].
///
/// Dereferences to the protected data. Dropping it drops the wrapped std guard (releasing the
/// lock) and then the tracker borrow.
#[derive(Debug)]
pub struct MutexGuard<'a, T> {
    // Declared first, so it is dropped first: the std lock is released before `_mutex`.
    inner: sync::MutexGuard<'a, T>,
    // Tracker borrow obtained via `get_borrowed`; presumably marks this mutex as held by the
    // current thread until dropped (see `BorrowedMutex`).
    _mutex: BorrowedMutex<'a>,
}
42
/// Converts the payload of a [`LockResult`] with `mapper`, preserving whether it was poisoned.
///
/// `mapper` runs exactly once: either on the `Ok` value, or on the value recovered from the
/// [`PoisonError`], which is then re-wrapped in a fresh `PoisonError`.
fn map_lockresult<T, I, F>(result: LockResult<I>, mapper: F) -> LockResult<T>
where
    F: FnOnce(I) -> T,
{
    let (payload, was_poisoned) = match result {
        Ok(payload) => (payload, false),
        Err(poison) => (poison.into_inner(), true),
    };

    let mapped = mapper(payload);
    if was_poisoned {
        Err(PoisonError::new(mapped))
    } else {
        Ok(mapped)
    }
}
52
/// Converts the payload of a [`TryLockResult`] with `mapper`.
///
/// `WouldBlock` passes through untouched and `mapper` is not called. For `Ok` and `Poisoned`,
/// `mapper` runs once on the carried value, and a poisoned input yields a poisoned output.
fn map_trylockresult<T, I, F>(result: TryLockResult<I>, mapper: F) -> TryLockResult<T>
where
    F: FnOnce(I) -> T,
{
    let (payload, was_poisoned) = match result {
        Ok(payload) => (payload, false),
        Err(TryLockError::Poisoned(poison)) => (poison.into_inner(), true),
        Err(TryLockError::WouldBlock) => return Err(TryLockError::WouldBlock),
    };

    let mapped = mapper(payload);
    if was_poisoned {
        Err(TryLockError::Poisoned(PoisonError::new(mapped)))
    } else {
        Ok(mapped)
    }
}
65
66impl<T> Mutex<T> {
67 pub const fn new(t: T) -> Self {
69 Self {
70 inner: sync::Mutex::new(t),
71 id: LazyMutexId::new(),
72 }
73 }
74
75 #[track_caller]
82 pub fn lock(&self) -> LockResult<MutexGuard<T>> {
83 let mutex = self.id.get_borrowed();
84 let result = self.inner.lock();
85
86 let mapper = |guard| MutexGuard {
87 _mutex: mutex,
88 inner: guard,
89 };
90
91 map_lockresult(result, mapper)
92 }
93
94 #[track_caller]
101 pub fn try_lock(&self) -> TryLockResult<MutexGuard<T>> {
102 let mutex = self.id.get_borrowed();
103 let result = self.inner.try_lock();
104
105 let mapper = |guard| MutexGuard {
106 _mutex: mutex,
107 inner: guard,
108 };
109
110 map_trylockresult(result, mapper)
111 }
112
113 pub fn is_poisoned(&self) -> bool {
115 self.inner.is_poisoned()
116 }
117
118 pub fn get_mut(&mut self) -> LockResult<&mut T> {
122 self.inner.get_mut()
123 }
124
125 pub fn into_inner(self) -> LockResult<T> {
127 self.inner.into_inner()
128 }
129}
130
/// Exposes this mutex's tracker id through [`PrivateTraced`] (defined in `crate::util`).
impl<T> PrivateTraced for Mutex<T> {
    fn get_id(&self) -> &crate::MutexId {
        // `self.id` is a `LazyMutexId`; returning it as `&crate::MutexId` relies on a conversion
        // defined elsewhere in the crate (presumably `Deref`).
        &self.id
    }
}
136
137impl<T> From<T> for Mutex<T> {
138 fn from(t: T) -> Self {
139 Self::new(t)
140 }
141}
142
143impl<T> Deref for MutexGuard<'_, T> {
144 type Target = T;
145
146 fn deref(&self) -> &Self::Target {
147 &self.inner
148 }
149}
150
151impl<T> DerefMut for MutexGuard<'_, T> {
152 fn deref_mut(&mut self) -> &mut Self::Target {
153 &mut self.inner
154 }
155}
156
157impl<T: fmt::Display> fmt::Display for MutexGuard<'_, T> {
158 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
159 self.inner.fmt(f)
160 }
161}
162
/// Wrapper around [`std::sync::Condvar`] that works with this module's [`MutexGuard`].
///
/// Waiting unpacks the guard, passes the std guard to the std condvar, and re-attaches the same
/// tracker borrow to the guard that comes back. No new acquisition is registered with the tracker.
#[derive(Debug, Default)]
pub struct Condvar(sync::Condvar);
202
203impl Condvar {
204 pub const fn new() -> Self {
206 Self(sync::Condvar::new())
207 }
208
209 pub fn wait<'a, T>(&self, guard: MutexGuard<'a, T>) -> LockResult<MutexGuard<'a, T>> {
211 let MutexGuard { _mutex, inner } = guard;
212
213 map_lockresult(self.0.wait(inner), |inner| MutexGuard { _mutex, inner })
214 }
215
216 pub fn wait_while<'a, T, F>(
218 &self,
219 guard: MutexGuard<'a, T>,
220 condition: F,
221 ) -> LockResult<MutexGuard<'a, T>>
222 where
223 F: FnMut(&mut T) -> bool,
224 {
225 let MutexGuard { _mutex, inner } = guard;
226
227 map_lockresult(self.0.wait_while(inner, condition), |inner| MutexGuard {
228 _mutex,
229 inner,
230 })
231 }
232
233 pub fn wait_timeout<'a, T>(
235 &self,
236 guard: MutexGuard<'a, T>,
237 dur: Duration,
238 ) -> LockResult<(MutexGuard<'a, T>, WaitTimeoutResult)> {
239 let MutexGuard { _mutex, inner } = guard;
240
241 map_lockresult(self.0.wait_timeout(inner, dur), |(inner, result)| {
242 (MutexGuard { _mutex, inner }, result)
243 })
244 }
245
246 pub fn wait_timeout_while<'a, T, F>(
248 &self,
249 guard: MutexGuard<'a, T>,
250 dur: Duration,
251 condition: F,
252 ) -> LockResult<(MutexGuard<'a, T>, WaitTimeoutResult)>
253 where
254 F: FnMut(&mut T) -> bool,
255 {
256 let MutexGuard { _mutex, inner } = guard;
257
258 map_lockresult(
259 self.0.wait_timeout_while(inner, dur, condition),
260 |(inner, result)| (MutexGuard { _mutex, inner }, result),
261 )
262 }
263
264 pub fn notify_one(&self) {
266 self.0.notify_one();
267 }
268
269 pub fn notify_all(&self) {
271 self.0.notify_all();
272 }
273}
274
/// Wrapper around [`std::sync::RwLock`] whose read and write acquisitions are reported to this
/// crate's lock-order tracker through a [`LazyMutexId`].
#[derive(Debug, Default)]
pub struct RwLock<T> {
    // The std lock that provides the actual reader/writer exclusion.
    inner: sync::RwLock<T>,
    // Identity of this lock in the dependency tracker.
    id: LazyMutexId,
}
281
/// Guard type behind [`RwLockReadGuard`] and [`RwLockWriteGuard`].
///
/// Pairs a std guard `L` with the tracker borrow taken when the lock was acquired.
#[derive(Debug)]
pub struct TracingRwLockGuard<'a, L> {
    // Declared first, so it is dropped first: the std lock is released before `_mutex`.
    inner: L,
    // Tracker borrow obtained via `get_borrowed` when the lock was acquired.
    _mutex: BorrowedMutex<'a>,
}
290
/// Guard returned by [`RwLock::read`] and [`RwLock::try_read`]; dereferences to `T`.
pub type RwLockReadGuard<'a, T> = TracingRwLockGuard<'a, sync::RwLockReadGuard<'a, T>>;
/// Guard returned by [`RwLock::write`] and [`RwLock::try_write`]; dereferences (also mutably)
/// to `T`.
pub type RwLockWriteGuard<'a, T> = TracingRwLockGuard<'a, sync::RwLockWriteGuard<'a, T>>;
295
296impl<T> RwLock<T> {
297 pub const fn new(t: T) -> Self {
298 Self {
299 inner: sync::RwLock::new(t),
300 id: LazyMutexId::new(),
301 }
302 }
303
304 #[track_caller]
311 pub fn read(&self) -> LockResult<RwLockReadGuard<T>> {
312 let mutex = self.id.get_borrowed();
313 let result = self.inner.read();
314
315 map_lockresult(result, |inner| TracingRwLockGuard {
316 inner,
317 _mutex: mutex,
318 })
319 }
320
321 #[track_caller]
328 pub fn write(&self) -> LockResult<RwLockWriteGuard<T>> {
329 let mutex = self.id.get_borrowed();
330 let result = self.inner.write();
331
332 map_lockresult(result, |inner| TracingRwLockGuard {
333 inner,
334 _mutex: mutex,
335 })
336 }
337
338 #[track_caller]
345 pub fn try_read(&self) -> TryLockResult<RwLockReadGuard<T>> {
346 let mutex = self.id.get_borrowed();
347 let result = self.inner.try_read();
348
349 map_trylockresult(result, |inner| TracingRwLockGuard {
350 inner,
351 _mutex: mutex,
352 })
353 }
354
355 #[track_caller]
362 pub fn try_write(&self) -> TryLockResult<RwLockWriteGuard<T>> {
363 let mutex = self.id.get_borrowed();
364 let result = self.inner.try_write();
365
366 map_trylockresult(result, |inner| TracingRwLockGuard {
367 inner,
368 _mutex: mutex,
369 })
370 }
371
372 pub fn get_mut(&mut self) -> LockResult<&mut T> {
376 self.inner.get_mut()
377 }
378
379 pub fn into_inner(self) -> LockResult<T> {
381 self.inner.into_inner()
382 }
383}
384
/// Exposes this lock's tracker id through [`PrivateTraced`] (defined in `crate::util`).
impl<T> PrivateTraced for RwLock<T> {
    fn get_id(&self) -> &crate::MutexId {
        // `self.id` is a `LazyMutexId`; returning it as `&crate::MutexId` relies on a conversion
        // defined elsewhere in the crate (presumably `Deref`).
        &self.id
    }
}
390
391impl<T> From<T> for RwLock<T> {
392 fn from(t: T) -> Self {
393 Self::new(t)
394 }
395}
396
397impl<L, T> Deref for TracingRwLockGuard<'_, L>
398where
399 L: Deref<Target = T>,
400{
401 type Target = T;
402
403 fn deref(&self) -> &Self::Target {
404 self.inner.deref()
405 }
406}
407
408impl<T, L> DerefMut for TracingRwLockGuard<'_, L>
409where
410 L: Deref<Target = T> + DerefMut,
411{
412 fn deref_mut(&mut self) -> &mut Self::Target {
413 self.inner.deref_mut()
414 }
415}
416
/// Wrapper around [`std::sync::Once`] whose calls run inside the tracker's `with_held`.
///
/// Presumably this makes locks acquired inside the initialisation closure get recorded as
/// depending on this `Once`. NOTE(review): confirm against the tracker docs.
#[derive(Debug)]
pub struct Once {
    // The std `Once` that provides the run-exactly-once guarantee.
    inner: sync::Once,
    // Identity of this `Once` in the dependency tracker.
    mutex_id: LazyMutexId,
}
426
427#[allow(clippy::new_without_default)]
429impl Once {
430 pub const fn new() -> Self {
432 Self {
433 inner: sync::Once::new(),
434 mutex_id: LazyMutexId::new(),
435 }
436 }
437
438 pub fn call_once<F>(&self, f: F)
445 where
446 F: FnOnce(),
447 {
448 self.mutex_id.with_held(|| self.inner.call_once(f))
449 }
450
451 pub fn call_once_force<F>(&self, f: F)
459 where
460 F: FnOnce(&OnceState),
461 {
462 self.mutex_id.with_held(|| self.inner.call_once_force(f))
463 }
464
465 pub fn is_completed(&self) -> bool {
467 self.inner.is_completed()
468 }
469}
470
/// Exposes this `Once`'s tracker id through [`PrivateTraced`] (defined in `crate::util`).
impl PrivateTraced for Once {
    fn get_id(&self) -> &crate::MutexId {
        // `self.mutex_id` is a `LazyMutexId`; returning it as `&crate::MutexId` relies on a
        // conversion defined elsewhere in the crate (presumably `Deref`).
        &self.mutex_id
    }
}
476
/// Wrapper around [`std::sync::OnceLock`] whose writes (`set`, `get_or_init`) run inside the
/// tracker's `with_held`.
///
/// Reads (`get`) and `&mut self` / by-value accessors bypass the tracker.
#[derive(Debug)]
pub struct OnceLock<T> {
    // Identity of this cell in the dependency tracker.
    id: LazyMutexId,
    // The std cell that stores the value.
    inner: sync::OnceLock<T>,
}
507
508impl<T> OnceLock<T> {
511 pub const fn new() -> Self {
513 Self {
514 id: LazyMutexId::new(),
515 inner: sync::OnceLock::new(),
516 }
517 }
518
519 #[inline]
524 pub fn get(&self) -> Option<&T> {
525 self.inner.get()
526 }
527
528 #[inline]
533 pub fn get_mut(&mut self) -> Option<&mut T> {
534 self.inner.get_mut()
535 }
536
537 pub fn set(&self, value: T) -> Result<(), T> {
542 self.id.with_held(|| self.inner.set(value))
543 }
544
545 pub fn get_or_init<F>(&self, f: F) -> &T
549 where
550 F: FnOnce() -> T,
551 {
552 self.id.with_held(|| self.inner.get_or_init(f))
553 }
554
555 #[inline]
560 pub fn take(&mut self) -> Option<T> {
561 self.inner.take()
562 }
563
564 #[inline]
570 pub fn into_inner(mut self) -> Option<T> {
571 self.take()
572 }
573}
574
/// Exposes this cell's tracker id through [`PrivateTraced`] (defined in `crate::util`).
impl<T> PrivateTraced for OnceLock<T> {
    fn get_id(&self) -> &crate::MutexId {
        // `self.id` is a `LazyMutexId`; returning it as `&crate::MutexId` relies on a conversion
        // defined elsewhere in the crate (presumably `Deref`).
        &self.id
    }
}
580
581impl<T> Default for OnceLock<T> {
582 #[inline]
583 fn default() -> Self {
584 Self::new()
585 }
586}
587
588impl<T: PartialEq> PartialEq for OnceLock<T> {
589 #[inline]
590 fn eq(&self, other: &Self) -> bool {
591 self.inner == other.inner
592 }
593}
594
/// Equality is total whenever `T`'s is; the `PartialEq` impl compares only the stored values.
impl<T: Eq> Eq for OnceLock<T> {}
596
597impl<T: Clone> Clone for OnceLock<T> {
598 fn clone(&self) -> Self {
599 Self {
600 id: LazyMutexId::new(),
601 inner: self.inner.clone(),
602 }
603 }
604}
605
606impl<T> From<T> for OnceLock<T> {
607 #[inline]
608 fn from(value: T) -> Self {
609 Self {
610 id: LazyMutexId::new(),
611 inner: sync::OnceLock::from(value),
612 }
613 }
614}
615
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::thread;

    use super::*;

    /// Lock/unlock round-trips, plus `try_lock` contention from another thread.
    #[test]
    fn test_mutex_usage() {
        let mutex = Arc::new(Mutex::new(0));

        assert_eq!(*mutex.lock().unwrap(), 0);
        *mutex.lock().unwrap() = 1;
        assert_eq!(*mutex.lock().unwrap(), 1);

        let mutex_clone = mutex.clone();

        // Held by the main thread until the end of the test.
        let _guard = mutex.lock().unwrap();

        let handle = thread::spawn(move || {
            // The main thread holds the mutex, so a non-blocking attempt must fail.
            let result = mutex_clone.try_lock().unwrap_err();

            assert!(matches!(result, TryLockError::WouldBlock));
        });

        handle.join().unwrap();
    }

    /// Read/write round-trips, plus writer exclusion and reader sharing across threads.
    #[test]
    fn test_rwlock_usage() {
        let rwlock = Arc::new(RwLock::new(0));

        assert_eq!(*rwlock.read().unwrap(), 0);
        assert_eq!(*rwlock.write().unwrap(), 0);
        *rwlock.write().unwrap() = 1;
        assert_eq!(*rwlock.read().unwrap(), 1);
        assert_eq!(*rwlock.write().unwrap(), 1);

        let rwlock_clone = rwlock.clone();

        // Shared read lock held by the main thread until the end of the test.
        let _read_lock = rwlock.read().unwrap();

        let handle = thread::spawn(move || {
            // A writer cannot get in while a reader holds the lock...
            let write_result = rwlock_clone.try_write().unwrap_err();

            assert!(matches!(write_result, TryLockError::WouldBlock));

            // ...but another reader can.
            let _read_lock = rwlock_clone.read().unwrap();
        });

        handle.join().unwrap();
    }

    /// `call_once` on another thread marks the shared `Once` completed for every thread.
    #[test]
    fn test_once_usage() {
        let once = Arc::new(Once::new());
        let once_clone = once.clone();

        assert!(!once.is_completed());

        let handle = thread::spawn(move || {
            assert!(!once_clone.is_completed());

            once_clone.call_once(|| {});

            assert!(once_clone.is_completed());
        });

        handle.join().unwrap();

        assert!(once.is_completed());
    }

    /// Acquiring `a` then `b`, and later `b` then `a`, must be reported as a cycle (panic).
    #[test]
    #[should_panic(expected = "Found cycle in mutex dependency graph")]
    fn test_detect_cycle() {
        let a = Mutex::new(());
        let b = Mutex::new(());

        // Presumably records the ordering a -> b. `let _ =` drops b's guard immediately.
        let hold_a = a.lock().unwrap();
        let _ = b.lock();

        drop(hold_a);

        // b -> a contradicts a -> b; the tracker is expected to panic in one of these calls.
        let _hold_b = b.lock().unwrap();
        let _ = a.lock();
    }
}