fuchsia_async/
condition.rs

1// Copyright 2024 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
//! Implements a combined mutex and condition.
//!
//! # Example
//!
//! ```no_run
//!     let condition = Condition::new(0);
//!     condition.when(|state| if *state == 1 { Poll::Ready(()) } else { Poll::Pending }).await;
//!
//!     // Elsewhere...
//!     let mut guard = condition.lock();
//!     *guard = 1;
//!     for waker in guard.drain_wakers() {
//!         waker.wake();
//!     }
//! ```
20
21use std::future::poll_fn;
22use std::marker::PhantomPinned;
23use std::ops::{Deref, DerefMut};
24use std::pin::{pin, Pin};
25use std::ptr::NonNull;
26use std::sync::{Arc, Mutex, MutexGuard};
27use std::task::{Poll, Waker};
28
29/// An async condition which combines a mutex and a condition variable.
30// Condition is implemented as an intrusive doubly linked list.  Typical use should avoid any
31// additional heap allocations after creation, as the nodes of the list are stored as part of the
32// caller's future.
33#[derive(Default)]
34pub struct Condition<T>(Arc<Mutex<Inner<T>>>);
35
36impl<T> Condition<T> {
37    /// Returns a new condition.
38    pub fn new(data: T) -> Self {
39        Self(Arc::new(Mutex::new(Inner { head: None, count: 0, data })))
40    }
41
42    /// Returns the number of wakers waiting on the condition.
43    pub fn waker_count(&self) -> usize {
44        self.0.lock().unwrap().count
45    }
46
47    /// Same as `Mutex::lock`.
48    pub fn lock(&self) -> ConditionGuard<'_, T> {
49        ConditionGuard(self.0.lock().unwrap())
50    }
51
52    /// Returns when `poll` resolves.
53    pub async fn when<R>(&self, poll: impl Fn(&mut T) -> Poll<R>) -> R {
54        let mut entry = pin!(self.waker_entry());
55        poll_fn(|cx| {
56            let mut guard = self.0.lock().unwrap();
57            // SAFETY: We uphold the pin guarantee.
58            let entry = unsafe { entry.as_mut().get_unchecked_mut() };
59            let result = poll(&mut guard.data);
60            if result.is_pending() {
61                // SAFETY: We set list correctly above.
62                unsafe {
63                    entry.node.add(&mut *guard, cx.waker().clone());
64                }
65            }
66            result
67        })
68        .await
69    }
70
71    /// Returns a new waker entry.
72    pub fn waker_entry(&self) -> WakerEntry<T> {
73        WakerEntry {
74            list: self.0.clone(),
75            node: Node { next: None, prev: None, waker: None, _pinned: PhantomPinned },
76        }
77    }
78}
79
80#[derive(Default)]
81struct Inner<T> {
82    head: Option<NonNull<Node>>,
83    count: usize,
84    data: T,
85}
86
87// SAFETY: Safe because we always access `head` whilst holding the list lock.
88unsafe impl<T: Send> Send for Inner<T> {}
89
90/// Guard returned by `lock`.
91pub struct ConditionGuard<'a, T>(MutexGuard<'a, Inner<T>>);
92
93impl<'a, T> ConditionGuard<'a, T> {
94    /// Adds the waker entry to the condition's list of wakers.
95    pub fn add_waker(&mut self, waker_entry: Pin<&mut WakerEntry<T>>, waker: Waker) {
96        // SAFETY: We never move the data out.
97        let waker_entry = unsafe { waker_entry.get_unchecked_mut() };
98        // SAFETY: We set list correctly above.
99        unsafe {
100            waker_entry.node.add(&mut *self.0, waker);
101        }
102    }
103
104    /// Returns an iterator that will drain all wakers.  Whilst the drainer exists, a lock is held
105    /// which will prevent new wakers from being added to the list, so depending on your use case,
106    /// you might wish to collect the wakers before calling `wake` on each waker.  NOTE: If the
107    /// drainer is dropped, this will *not* drain elements not visited.
108    pub fn drain_wakers<'b>(&'b mut self) -> Drainer<'b, 'a, T> {
109        Drainer(self)
110    }
111
112    /// Returns the number of wakers registered with the condition.
113    pub fn waker_count(&self) -> usize {
114        self.0.count
115    }
116}
117
118impl<T> Deref for ConditionGuard<'_, T> {
119    type Target = T;
120
121    fn deref(&self) -> &Self::Target {
122        &self.0.data
123    }
124}
125
126impl<T> DerefMut for ConditionGuard<'_, T> {
127    fn deref_mut(&mut self) -> &mut Self::Target {
128        &mut self.0.data
129    }
130}
131
132/// A waker entry that can be added to a list.
133pub struct WakerEntry<T> {
134    list: Arc<Mutex<Inner<T>>>,
135    node: Node,
136}
137
138impl<T> Drop for WakerEntry<T> {
139    fn drop(&mut self) {
140        self.node.remove(&mut *self.list.lock().unwrap());
141    }
142}
143
144// The members here must only be accessed whilst holding the mutex on the list.
145struct Node {
146    next: Option<NonNull<Node>>,
147    prev: Option<NonNull<Node>>,
148    waker: Option<Waker>,
149    _pinned: PhantomPinned,
150}
151
152// SAFETY: Safe because we always access all mebers of `Node` whilst holding the list lock.
153unsafe impl Send for Node {}
154
155impl Node {
156    // # Safety
157    //
158    // The waker *must* have `list` set correctly.
159    unsafe fn add<T>(&mut self, inner: &mut Inner<T>, waker: Waker) {
160        if self.waker.is_none() {
161            self.prev = None;
162            self.next = inner.head;
163            inner.head = Some(self.into());
164            if let Some(mut next) = self.next {
165                // SAFETY: Safe because we have exclusive access to `Inner` and `head` is set
166                // correctly above.
167                unsafe {
168                    next.as_mut().prev = Some(self.into());
169                }
170            }
171            inner.count += 1;
172        }
173        self.waker = Some(waker);
174    }
175
176    fn remove<T>(&mut self, inner: &mut Inner<T>) -> Option<Waker> {
177        if self.waker.is_none() {
178            debug_assert!(self.prev.is_none() && self.next.is_none());
179            return None;
180        }
181        if let Some(mut next) = self.next {
182            // SAFETY: Safe because we have exclusive access to `Inner` and `head` is set correctly.
183            unsafe { next.as_mut().prev = self.prev };
184        }
185        if let Some(mut prev) = self.prev {
186            // SAFETY: Safe because we have exclusive access to `Inner` and `head` is set correctly.
187            unsafe { prev.as_mut().next = self.next };
188        } else {
189            inner.head = self.next;
190        }
191        self.prev = None;
192        self.next = None;
193        inner.count -= 1;
194        self.waker.take()
195    }
196}
197
198/// An iterator that will drain waiters.
199pub struct Drainer<'a, 'b, T>(&'a mut ConditionGuard<'b, T>);
200
201impl<T> Iterator for Drainer<'_, '_, T> {
202    type Item = Waker;
203    fn next(&mut self) -> Option<Self::Item> {
204        if let Some(mut head) = self.0 .0.head {
205            // SAFETY: Safe because we have exclusive access to `Inner` and `head is set correctly.
206            unsafe { head.as_mut().remove(&mut self.0 .0) }
207        } else {
208            None
209        }
210    }
211
212    fn size_hint(&self) -> (usize, Option<usize>) {
213        (self.0 .0.count, Some(self.0 .0.count))
214    }
215}
216
217impl<T> ExactSizeIterator for Drainer<'_, '_, T> {
218    fn len(&self) -> usize {
219        self.0 .0.count
220    }
221}
222
#[cfg(all(target_os = "fuchsia", test))]
mod tests {
    use super::Condition;
    use crate::TestExecutor;
    use futures::stream::FuturesUnordered;
    use futures::task::noop_waker;
    use futures::StreamExt;
    use std::pin::pin;
    use std::sync::atomic::{AtomicU64, Ordering};
    use std::task::Poll;

    #[test]
    fn test_condition_can_waker_multiple_wakers() {
        const WAITERS: u64 = 10;

        let mut executor = TestExecutor::new();
        let condition = Condition::new(());
        let polls = AtomicU64::new(0);

        // `FuturesUnordered` only re-polls a future once it has been woken, so every waiter is
        // polled once before the drain below and once after it.
        let mut waiters = FuturesUnordered::new();
        for _ in 0..WAITERS {
            waiters.push(condition.when(|()| {
                let previous_polls = polls.fetch_add(1, Ordering::Relaxed);
                if previous_polls < WAITERS {
                    Poll::Pending
                } else {
                    Poll::Ready(())
                }
            }));
        }

        assert!(executor.run_until_stalled(&mut waiters.next()).is_pending());
        assert_eq!(polls.load(Ordering::Relaxed), WAITERS);
        assert_eq!(condition.waker_count(), WAITERS as usize);

        {
            let mut guard = condition.lock();
            let wakers = guard.drain_wakers();
            assert_eq!(wakers.len(), WAITERS as usize);
            wakers.for_each(|waker| waker.wake());
        }

        assert!(executor.run_until_stalled(&mut waiters.collect::<Vec<_>>()).is_ready());
        assert_eq!(polls.load(Ordering::Relaxed), 2 * WAITERS);
    }

    #[test]
    fn test_dropping_waker_entry_removes_from_list() {
        let condition = Condition::new(());

        let first = pin!(condition.waker_entry());
        condition.lock().add_waker(first, noop_waker());

        {
            let second = pin!(condition.waker_entry());
            condition.lock().add_waker(second, noop_waker());
            assert_eq!(condition.waker_count(), 2);
            // `second` goes out of scope here, which must unlink it.
        }
        assert_eq!(condition.waker_count(), 1);

        let drained = condition.lock().drain_wakers().count();
        assert_eq!(drained, 1);
        assert_eq!(condition.waker_count(), 0);

        let third = pin!(condition.waker_entry());
        condition.lock().add_waker(third, noop_waker());
        assert_eq!(condition.waker_count(), 1);
    }

    #[test]
    fn test_waker_can_be_added_multiple_times() {
        let condition = Condition::new(());
        let mut first = pin!(condition.waker_entry());
        let mut second = pin!(condition.waker_entry());

        // Register, drain, and then register the very same entries again.
        for _ in 0..2 {
            condition.lock().add_waker(first.as_mut(), noop_waker());
            condition.lock().add_waker(second.as_mut(), noop_waker());
            assert_eq!(condition.waker_count(), 2);

            let drained = condition.lock().drain_wakers().count();
            assert_eq!(drained, 2);
            assert_eq!(condition.waker_count(), 0);
        }
    }
}