fuchsia_async/
condition.rs1use std::future::poll_fn;
22use std::marker::PhantomPinned;
23use std::ops::{Deref, DerefMut};
24use std::pin::{pin, Pin};
25use std::ptr::NonNull;
26use std::sync::{Arc, Mutex, MutexGuard};
27use std::task::{Poll, Waker};
28
29#[derive(Default)]
34pub struct Condition<T>(Arc<Mutex<Inner<T>>>);
35
36impl<T> Condition<T> {
37 pub fn new(data: T) -> Self {
39 Self(Arc::new(Mutex::new(Inner { head: None, count: 0, data })))
40 }
41
42 pub fn waker_count(&self) -> usize {
44 self.0.lock().unwrap().count
45 }
46
47 pub fn lock(&self) -> ConditionGuard<'_, T> {
49 ConditionGuard(self.0.lock().unwrap())
50 }
51
52 pub async fn when<R>(&self, poll: impl Fn(&mut T) -> Poll<R>) -> R {
54 let mut entry = pin!(self.waker_entry());
55 poll_fn(|cx| {
56 let mut guard = self.0.lock().unwrap();
57 let entry = unsafe { entry.as_mut().get_unchecked_mut() };
59 let result = poll(&mut guard.data);
60 if result.is_pending() {
61 unsafe {
63 entry.node.add(&mut *guard, cx.waker().clone());
64 }
65 }
66 result
67 })
68 .await
69 }
70
71 pub fn waker_entry(&self) -> WakerEntry<T> {
73 WakerEntry {
74 list: self.0.clone(),
75 node: Node { next: None, prev: None, waker: None, _pinned: PhantomPinned },
76 }
77 }
78}
79
80#[derive(Default)]
81struct Inner<T> {
82 head: Option<NonNull<Node>>,
83 count: usize,
84 data: T,
85}
86
87unsafe impl<T: Send> Send for Inner<T> {}
89
90pub struct ConditionGuard<'a, T>(MutexGuard<'a, Inner<T>>);
92
93impl<'a, T> ConditionGuard<'a, T> {
94 pub fn add_waker(&mut self, waker_entry: Pin<&mut WakerEntry<T>>, waker: Waker) {
96 let waker_entry = unsafe { waker_entry.get_unchecked_mut() };
98 unsafe {
100 waker_entry.node.add(&mut *self.0, waker);
101 }
102 }
103
104 pub fn drain_wakers<'b>(&'b mut self) -> Drainer<'b, 'a, T> {
109 Drainer(self)
110 }
111
112 pub fn waker_count(&self) -> usize {
114 self.0.count
115 }
116}
117
118impl<T> Deref for ConditionGuard<'_, T> {
119 type Target = T;
120
121 fn deref(&self) -> &Self::Target {
122 &self.0.data
123 }
124}
125
126impl<T> DerefMut for ConditionGuard<'_, T> {
127 fn deref_mut(&mut self) -> &mut Self::Target {
128 &mut self.0.data
129 }
130}
131
132pub struct WakerEntry<T> {
134 list: Arc<Mutex<Inner<T>>>,
135 node: Node,
136}
137
138impl<T> Drop for WakerEntry<T> {
139 fn drop(&mut self) {
140 self.node.remove(&mut *self.list.lock().unwrap());
141 }
142}
143
144struct Node {
146 next: Option<NonNull<Node>>,
147 prev: Option<NonNull<Node>>,
148 waker: Option<Waker>,
149 _pinned: PhantomPinned,
150}
151
152unsafe impl Send for Node {}
154
155impl Node {
156 unsafe fn add<T>(&mut self, inner: &mut Inner<T>, waker: Waker) {
160 if self.waker.is_none() {
161 self.prev = None;
162 self.next = inner.head;
163 inner.head = Some(self.into());
164 if let Some(mut next) = self.next {
165 unsafe {
168 next.as_mut().prev = Some(self.into());
169 }
170 }
171 inner.count += 1;
172 }
173 self.waker = Some(waker);
174 }
175
176 fn remove<T>(&mut self, inner: &mut Inner<T>) -> Option<Waker> {
177 if self.waker.is_none() {
178 debug_assert!(self.prev.is_none() && self.next.is_none());
179 return None;
180 }
181 if let Some(mut next) = self.next {
182 unsafe { next.as_mut().prev = self.prev };
184 }
185 if let Some(mut prev) = self.prev {
186 unsafe { prev.as_mut().next = self.next };
188 } else {
189 inner.head = self.next;
190 }
191 self.prev = None;
192 self.next = None;
193 inner.count -= 1;
194 self.waker.take()
195 }
196}
197
198pub struct Drainer<'a, 'b, T>(&'a mut ConditionGuard<'b, T>);
200
201impl<T> Iterator for Drainer<'_, '_, T> {
202 type Item = Waker;
203 fn next(&mut self) -> Option<Self::Item> {
204 if let Some(mut head) = self.0 .0.head {
205 unsafe { head.as_mut().remove(&mut self.0 .0) }
207 } else {
208 None
209 }
210 }
211
212 fn size_hint(&self) -> (usize, Option<usize>) {
213 (self.0 .0.count, Some(self.0 .0.count))
214 }
215}
216
217impl<T> ExactSizeIterator for Drainer<'_, '_, T> {
218 fn len(&self) -> usize {
219 self.0 .0.count
220 }
221}
222
#[cfg(all(target_os = "fuchsia", test))]
mod tests {
    use super::Condition;
    use crate::TestExecutor;
    use futures::stream::FuturesUnordered;
    use futures::task::noop_waker;
    use futures::StreamExt;
    use std::pin::pin;
    use std::sync::atomic::{AtomicU64, Ordering};
    use std::task::Poll;

    #[test]
    fn test_condition_can_wake_multiple_wakers() {
        const COUNT: u64 = 10;

        let mut executor = TestExecutor::new();
        let condition = Condition::new(());

        // Counts predicate evaluations across all futures: the first `COUNT` evaluations report
        // pending, every later one reports ready.
        let counter = AtomicU64::new(0);

        let mut futures = FuturesUnordered::new();

        for _ in 0..COUNT {
            futures.push(condition.when(|()| {
                if counter.fetch_add(1, Ordering::Relaxed) >= COUNT {
                    Poll::Ready(())
                } else {
                    Poll::Pending
                }
            }));
        }

        // Every future is polled once, sees a pending predicate and registers its waker.
        assert!(executor.run_until_stalled(&mut futures.next()).is_pending());

        assert_eq!(counter.load(Ordering::Relaxed), COUNT);
        assert_eq!(condition.waker_count(), COUNT as usize);

        {
            let mut guard = condition.lock();
            let drainer = guard.drain_wakers();
            assert_eq!(drainer.len(), COUNT as usize);
            for waker in drainer {
                waker.wake();
            }
        }

        // Each woken future re-evaluates its predicate once more and completes.
        assert!(executor.run_until_stalled(&mut futures.collect::<Vec<_>>()).is_ready());
        assert_eq!(counter.load(Ordering::Relaxed), COUNT * 2);
    }

    #[test]
    fn test_dropping_waker_entry_removes_from_list() {
        let condition = Condition::new(());

        let entry1 = pin!(condition.waker_entry());
        condition.lock().add_waker(entry1, noop_waker());

        {
            let entry2 = pin!(condition.waker_entry());
            condition.lock().add_waker(entry2, noop_waker());

            assert_eq!(condition.waker_count(), 2);
        }

        assert_eq!(condition.waker_count(), 1);
        {
            let mut guard = condition.lock();
            assert_eq!(guard.drain_wakers().count(), 1);
        }

        assert_eq!(condition.waker_count(), 0);

        let entry3 = pin!(condition.waker_entry());
        condition.lock().add_waker(entry3, noop_waker());

        assert_eq!(condition.waker_count(), 1);
    }

    #[test]
    fn test_waker_can_be_added_multiple_times() {
        let condition = Condition::new(());

        let mut entry1 = pin!(condition.waker_entry());
        condition.lock().add_waker(entry1.as_mut(), noop_waker());

        let mut entry2 = pin!(condition.waker_entry());
        condition.lock().add_waker(entry2.as_mut(), noop_waker());

        assert_eq!(condition.waker_count(), 2);
        {
            let mut guard = condition.lock();
            assert_eq!(guard.drain_wakers().count(), 2);
        }
        assert_eq!(condition.waker_count(), 0);

        condition.lock().add_waker(entry1, noop_waker());
        condition.lock().add_waker(entry2, noop_waker());

        assert_eq!(condition.waker_count(), 2);

        {
            let mut guard = condition.lock();
            assert_eq!(guard.drain_wakers().count(), 2);
        }
        assert_eq!(condition.waker_count(), 0);
    }

    #[test]
    fn test_re_adding_linked_entry_replaces_waker() {
        let condition = Condition::new(());

        let mut entry = pin!(condition.waker_entry());
        condition.lock().add_waker(entry.as_mut(), noop_waker());
        // Registering an entry that is still linked must neither link nor count it again.
        condition.lock().add_waker(entry.as_mut(), noop_waker());
        assert_eq!(condition.waker_count(), 1);

        {
            let mut guard = condition.lock();
            assert_eq!(guard.drain_wakers().count(), 1);
        }
        assert_eq!(condition.waker_count(), 0);
    }
}