1use crate::token_registry::TokenRegistry;
20
21use fuchsia_async::{self as fasync, JoinHandle, Scope, SpawnableFuture};
22use fuchsia_sync::Mutex;
23use futures::task::{self, Poll};
24use futures::Future;
25use pin_project::pin_project;
26use std::future::{pending, poll_fn};
27use std::pin::{pin, Pin};
28use std::sync::{Arc, OnceLock};
29use std::task::{ready, Context};
30
31#[cfg(target_os = "fuchsia")]
32use fuchsia_async::EHandle;
33
/// Alias for [`futures::task::SpawnError`], exposed under this module's name.
pub type SpawnError = task::SpawnError;
35
/// A handle to a group of tasks that are spawned, shut down and waited on together.
///
/// Cloning is cheap: every clone shares the same underlying `Executor`, and clones compare
/// equal to each other (equality is identity of the shared state).
#[derive(Clone)]
pub struct ExecutionScope {
    executor: Arc<Executor>,
}
50
/// State shared by an [`ExecutionScope`] and all of its clones, guards and tasks.
///
/// NOTE: fields are dropped in declaration order; keep the order unless that is intended.
struct Executor {
    /// Shutdown state and active-guard bookkeeping.
    inner: Mutex<Inner>,
    /// Registry handed out by `ExecutionScope::token_registry`.
    token_registry: TokenRegistry,
    /// The `fuchsia_async` scope tasks are spawned into. It is initialized eagerly when the
    /// builder was given an executor, otherwise lazily by `Executor::scope`.
    scope: OnceLock<Scope>,
}
56
/// Mutable scope state, always accessed through `Executor::inner`.
struct Inner {
    /// Whether the scope is running, gracefully shutting down, or force shut down.
    shutdown_state: ShutdownState,

    /// Number of live `ActiveGuard`s. While non-zero, a graceful shutdown does not complete
    /// pending tasks and `ExecutionScope::wait` does not resolve.
    active_count: usize,

    /// Never-completing placeholder task installed by `ExecutionScope::wait` while guards
    /// are still alive, so the underlying scope does not look empty. `ActiveGuard::drop`
    /// cancels it when the last guard goes away.
    fake_active_task: Option<fasync::Task<()>>,
}
68
/// Lifecycle state of a scope.
///
/// Derives `Debug` so the state can appear in assertions and logs, and `Eq` because
/// equality on this fieldless enum is a total equivalence.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
enum ShutdownState {
    /// Tasks run normally and `ExecutionScope::try_active_guard` succeeds.
    Active,
    /// Graceful shutdown: a task that returns `Poll::Pending` is completed early once no
    /// `ActiveGuard`s remain.
    Shutdown,
    /// Forced shutdown: a task that returns `Poll::Pending` is completed early regardless
    /// of outstanding `ActiveGuard`s.
    ForceShutdown,
}
75
76impl ExecutionScope {
77 pub fn new() -> Self {
80 Self::build().new()
81 }
82
83 pub fn build() -> ExecutionScopeParams {
87 ExecutionScopeParams::default()
88 }
89
90 pub fn active_count(&self) -> usize {
92 self.executor.inner.lock().active_count
93 }
94
95 pub fn spawn(&self, task: impl Future<Output = ()> + Send + 'static) -> JoinHandle<()> {
107 self.executor.scope().spawn(FutureWithShutdown { executor: self.executor.clone(), task })
108 }
109
110 pub fn new_task(self, task: impl Future<Output = ()> + Send + 'static) -> Task {
112 Task(
113 self.executor.clone(),
114 SpawnableFuture::new(FutureWithShutdown { executor: self.executor, task }),
115 )
116 }
117
118 pub fn token_registry(&self) -> &TokenRegistry {
119 &self.executor.token_registry
120 }
121
122 pub fn shutdown(&self) {
123 self.executor.shutdown();
124 }
125
126 pub fn force_shutdown(&self) {
128 let mut inner = self.executor.inner.lock();
129 inner.shutdown_state = ShutdownState::ForceShutdown;
130 self.executor.scope().wake_all();
131 }
132
133 pub fn resurrect(&self) {
136 self.executor.inner.lock().shutdown_state = ShutdownState::Active;
137 }
138
139 pub async fn wait(&self) {
141 let mut on_no_tasks = pin!(self.executor.scope().on_no_tasks());
142 poll_fn(|cx| {
143 let mut inner = self.executor.inner.lock();
145 ready!(on_no_tasks.as_mut().poll(cx));
146 if inner.active_count == 0 {
147 Poll::Ready(())
148 } else {
149 let scope = self.executor.scope();
154 inner.fake_active_task = Some(scope.compute(pending::<()>()));
155 on_no_tasks.set(scope.on_no_tasks());
156 assert!(on_no_tasks.as_mut().poll(cx).is_pending());
157 Poll::Pending
158 }
159 })
160 .await;
161 }
162
163 pub fn try_active_guard(&self) -> Option<ActiveGuard> {
166 let mut inner = self.executor.inner.lock();
167 if inner.shutdown_state != ShutdownState::Active {
168 return None;
169 }
170 inner.active_count += 1;
171 Some(ActiveGuard(self.executor.clone()))
172 }
173
174 pub fn active_guard(&self) -> ActiveGuard {
177 self.executor.inner.lock().active_count += 1;
178 ActiveGuard(self.executor.clone())
179 }
180}
181
182impl PartialEq for ExecutionScope {
183 fn eq(&self, other: &Self) -> bool {
184 Arc::as_ptr(&self.executor) == Arc::as_ptr(&other.executor)
185 }
186}
187
// Equality is pointer identity of the shared executor. That is reflexive, symmetric and
// transitive, so the `PartialEq` impl is a full equivalence relation.
impl Eq for ExecutionScope {}
189
190impl std::fmt::Debug for ExecutionScope {
191 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
192 f.write_fmt(format_args!("ExecutionScope {:?}", Arc::as_ptr(&self.executor)))
193 }
194}
195
/// Builder for [`ExecutionScope`], obtained from [`ExecutionScope::build`].
#[derive(Default)]
pub struct ExecutionScopeParams {
    /// Executor whose global scope becomes the parent of the new scope. When unset, the
    /// scope is created lazily on first use.
    #[cfg(target_os = "fuchsia")]
    async_executor: Option<EHandle>,
}
201
202impl ExecutionScopeParams {
203 #[cfg(target_os = "fuchsia")]
204 pub fn executor(mut self, value: EHandle) -> Self {
205 assert!(self.async_executor.is_none(), "`executor` is already set");
206 self.async_executor = Some(value);
207 self
208 }
209
210 pub fn new(self) -> ExecutionScope {
211 ExecutionScope {
212 executor: Arc::new(Executor {
213 token_registry: TokenRegistry::new(),
214 inner: Mutex::new(Inner {
215 shutdown_state: ShutdownState::Active,
216 active_count: 0,
217 fake_active_task: None,
218 }),
219 #[cfg(target_os = "fuchsia")]
220 scope: self
221 .async_executor
222 .map_or_else(|| OnceLock::new(), |e| e.global_scope().new_child().into()),
223 #[cfg(not(target_os = "fuchsia"))]
224 scope: OnceLock::new(),
225 }),
226 }
227 }
228}
229
230impl Executor {
231 fn scope(&self) -> &Scope {
232 self.scope.get_or_init(|| {
236 #[cfg(target_os = "fuchsia")]
237 return Scope::global().new_child();
238 #[cfg(not(target_os = "fuchsia"))]
239 return Scope::new();
240 })
241 }
242
243 fn shutdown(&self) {
244 let wake_all = {
245 let mut inner = self.inner.lock();
246 inner.shutdown_state = ShutdownState::Shutdown;
247 inner.active_count == 0
248 };
249 if wake_all {
250 if let Some(scope) = self.scope.get() {
251 scope.wake_all();
252 }
253 }
254 }
255}
256
257impl Drop for Executor {
258 fn drop(&mut self) {
259 self.shutdown();
260 }
261}
262
/// While alive, keeps a gracefully shut-down scope's tasks running and holds off
/// `ExecutionScope::wait`.
///
/// Obtained from [`ExecutionScope::active_guard`] or [`ExecutionScope::try_active_guard`].
pub struct ActiveGuard(Arc<Executor>);
265
266impl Drop for ActiveGuard {
267 fn drop(&mut self) {
268 let wake_all = {
269 let mut inner = self.0.inner.lock();
270 inner.active_count -= 1;
271 if inner.active_count == 0 {
272 if let Some(task) = inner.fake_active_task.take() {
273 let _ = task.cancel();
274 }
275 }
276 inner.active_count == 0 && inner.shutdown_state == ShutdownState::Shutdown
277 };
278 if wake_all {
279 self.0.scope().wake_all();
280 }
281 }
282}
283
/// Yields control back to the executor exactly once.
///
/// The first poll wakes the current task and returns `Poll::Pending`; the next poll
/// completes.
pub async fn yield_to_executor() {
    let mut first_poll = true;
    poll_fn(|cx| {
        if !first_poll {
            return Poll::Ready(());
        }
        first_poll = false;
        cx.waker().wake_by_ref();
        Poll::Pending
    })
    .await;
}
298
/// Wraps a task so it can be completed early once its scope shuts down.
#[pin_project]
struct FutureWithShutdown<Task: Future<Output = ()> + Send + 'static> {
    /// Shared scope state, consulted on every poll.
    executor: Arc<Executor>,
    /// The wrapped task (structurally pinned).
    #[pin]
    task: Task,
}
306
307impl<Task: Future<Output = ()> + Send + 'static> Future for FutureWithShutdown<Task> {
308 type Output = ();
309
310 fn poll(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
311 let this = self.project();
312 let shutdown_state = this.executor.inner.lock().shutdown_state;
313 match this.task.poll(cx) {
314 Poll::Ready(()) => Poll::Ready(()),
315 Poll::Pending => match shutdown_state {
316 ShutdownState::Active => Poll::Pending,
317 ShutdownState::Shutdown if this.executor.inner.lock().active_count > 0 => {
318 Poll::Pending
319 }
320 _ => Poll::Ready(()),
321 },
322 }
323 }
324}
325
/// A task created by [`ExecutionScope::new_task`] that has not been spawned yet.
///
/// Either call [`Task::spawn`] to run it on its scope, or poll it directly as a future.
pub struct Task(Arc<Executor>, SpawnableFuture<'static, ()>);
327
328impl Task {
329 pub fn spawn(self) {
331 self.0.scope().spawn(self.1);
332 }
333}
334
335impl Future for Task {
336 type Output = ();
337
338 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
339 Pin::new(&mut &mut self.1).poll(cx)
340 }
341}
342
#[cfg(test)]
mod tests {
    use super::{yield_to_executor, ExecutionScope};

    use fuchsia_async::{Task, TestExecutor, Timer};
    use futures::channel::oneshot;
    use futures::stream::FuturesUnordered;
    use futures::task::Poll;
    use futures::{Future, StreamExt};
    use std::pin::pin;
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::Arc;
    use std::time::Duration;

    /// Runs the future built by `get_test` (given a fresh scope) on a new `TestExecutor`.
    /// Fails if the future stalls before completing.
    #[cfg(target_os = "fuchsia")]
    fn run_test<GetTest, GetTestRes>(get_test: GetTest)
    where
        GetTest: FnOnce(ExecutionScope) -> GetTestRes,
        GetTestRes: Future<Output = ()>,
    {
        let mut exec = TestExecutor::new();

        let scope = ExecutionScope::new();

        let test = get_test(scope);

        assert_eq!(
            exec.run_until_stalled(&mut pin!(test)),
            Poll::Ready(()),
            "Test did not complete"
        );
    }

    /// Host variant of `run_test`. It runs the test future to completion with
    /// `run_singlethreaded` and relies on `on_stalled` with a 30 second duration to panic
    /// if the test hangs.
    #[cfg(not(target_os = "fuchsia"))]
    fn run_test<GetTest, GetTestRes>(get_test: GetTest)
    where
        GetTest: FnOnce(ExecutionScope) -> GetTestRes,
        GetTestRes: Future<Output = ()>,
    {
        use fuchsia_async::TimeoutExt;
        let mut exec = TestExecutor::new();

        let scope = ExecutionScope::new();

        // Unlike the Fuchsia variant, stalling is not detected with `run_until_stalled`.
        // `on_stalled` (presumably fires after 30s without progress) panics instead.
        let test =
            get_test(scope).on_stalled(Duration::from_secs(30), || panic!("Test did not complete"));

        exec.run_singlethreaded(&mut pin!(test));
    }

    /// A task that completes on its first poll is polled once and dropped once.
    #[test]
    fn simple() {
        run_test(|scope| {
            async move {
                let (sender, receiver) = oneshot::channel();
                let (counters, task) = mocks::ImmediateTask::new(sender);

                scope.spawn(task);

                // Wait until the task has run; it sends on `sender` from inside its poll.
                receiver.await.unwrap();

                assert_eq!(counters.drop_call(), 1);
                assert_eq!(counters.poll_call(), 1);
            }
        });
    }

    /// A task is dropped exactly once after the scope is shut down.
    #[test]
    fn simple_drop() {
        run_test(|scope| {
            async move {
                let (poll_sender, poll_receiver) = oneshot::channel();
                let (processing_done_sender, processing_done_receiver) = oneshot::channel();
                let (drop_sender, drop_receiver) = oneshot::channel();
                let (counters, task) =
                    mocks::ControlledTask::new(poll_sender, processing_done_receiver, drop_sender);

                scope.spawn(task);

                poll_receiver.await.unwrap();

                processing_done_sender.send(()).unwrap();

                scope.shutdown();

                drop_receiver.await.unwrap();

                // The exact poll count depends on scheduling between the task completing
                // and the shutdown, so only a lower bound is checked.
                let poll_count = counters.poll_call();
                assert!(poll_count >= 1, "poll was not called");

                assert_eq!(counters.drop_call(), 1);
            }
        });
    }

    /// `wait` does not resolve while a spawned task is still running.
    #[test]
    fn test_wait_waits_for_tasks_to_finish() {
        let mut executor = TestExecutor::new();
        let scope = ExecutionScope::new();
        executor.run_singlethreaded(async {
            let (poll_sender, poll_receiver) = oneshot::channel();
            let (processing_done_sender, processing_done_receiver) = oneshot::channel();
            let (drop_sender, _drop_receiver) = oneshot::channel();
            let (_, task) =
                mocks::ControlledTask::new(poll_sender, processing_done_receiver, drop_sender);

            scope.spawn(task);

            poll_receiver.await.unwrap();

            // `done` is set just before the task is allowed to finish, so `wait` resolving
            // with `done == false` would mean it returned too early.
            let done = fuchsia_sync::Mutex::new(false);
            futures::join!(
                async {
                    scope.wait().await;
                    assert_eq!(*done.lock(), true);
                },
                async {
                    Timer::new(Duration::from_millis(100)).await;
                    *done.lock() = true;
                    processing_done_sender.send(()).unwrap();
                }
            );
        });
    }

    /// Active guards keep a shut-down scope alive, including a guard taken in a `Drop` impl
    /// that runs while the task is being torn down.
    #[fuchsia::test]
    async fn test_active_guard() {
        let scope = ExecutionScope::new();
        let (guard_taken_tx, guard_taken_rx) = oneshot::channel();
        let (shutdown_triggered_tx, shutdown_triggered_rx) = oneshot::channel();
        let (drop_task_tx, drop_task_rx) = oneshot::channel();
        let scope_clone = scope.clone();
        let done = Arc::new(AtomicBool::new(false));
        let done_clone = done.clone();
        scope.spawn(async move {
            {
                // On drop, takes a fresh guard and hands it to a detached task that holds it
                // until `drop_task_tx` fires.
                struct OnDrop((ExecutionScope, Option<oneshot::Receiver<()>>));
                impl Drop for OnDrop {
                    fn drop(&mut self) {
                        let guard = self.0 .0.active_guard();
                        let rx = self.0 .1.take().unwrap();
                        Task::spawn(async move {
                            rx.await.unwrap();
                            std::mem::drop(guard);
                        })
                        .detach();
                    }
                }
                let _guard = scope_clone.try_active_guard().unwrap();
                let _on_drop = OnDrop((scope_clone, Some(drop_task_rx)));
                guard_taken_tx.send(()).unwrap();
                shutdown_triggered_rx.await.unwrap();
                // Still holding `_guard`, so this task keeps running after shutdown and
                // can await the timer.
                Timer::new(std::time::Duration::from_millis(100)).await;
                done_clone.store(true, Ordering::SeqCst);
            }
        });
        guard_taken_rx.await.unwrap();
        scope.shutdown();

        // Give the executor time to run; the guard must keep `wait` pending.
        Timer::new(std::time::Duration::from_millis(100)).await;
        let mut shutdown_wait = std::pin::pin!(scope.wait());
        assert_eq!(futures::poll!(shutdown_wait.as_mut()), Poll::Pending);

        shutdown_triggered_tx.send(()).unwrap();

        // The task has finished, but the guard taken in `OnDrop::drop` is still alive.
        Timer::new(std::time::Duration::from_millis(100)).await;
        assert_eq!(futures::poll!(shutdown_wait.as_mut()), Poll::Pending);

        drop_task_tx.send(()).unwrap();

        shutdown_wait.await;

        assert!(done.load(Ordering::SeqCst));
    }

    /// A message written before shutdown is still received. Shutdown wakes the task, and
    /// presumably its next poll finds the message already queued and completes normally.
    #[cfg(target_os = "fuchsia")]
    #[fuchsia::test]
    async fn test_shutdown_waits_for_channels() {
        use fuchsia_async as fasync;

        let scope = ExecutionScope::new();
        let (rx, tx) = zx::Channel::create();
        let received_msg = Arc::new(AtomicBool::new(false));
        let (sender, receiver) = futures::channel::oneshot::channel();
        {
            let received_msg = received_msg.clone();
            scope.spawn(async move {
                let mut msg_buf = zx::MessageBuf::new();
                msg_buf.ensure_capacity_bytes(64);
                let _ = sender.send(());
                let _ = fasync::Channel::from_channel(rx).recv_msg(&mut msg_buf).await;
                received_msg.store(true, Ordering::Relaxed);
            });
        }
        // Make sure the task has started before writing and shutting down.
        let _ = receiver.await;

        tx.write(b"hello", &mut []).expect("write failed");
        scope.shutdown();
        scope.wait().await;
        assert!(received_msg.load(Ordering::Relaxed));
    }

    /// `force_shutdown` ends tasks despite outstanding guards, and `resurrect` lets new
    /// tasks run normally afterwards.
    #[fuchsia::test]
    async fn test_force_shutdown() {
        let scope = ExecutionScope::new();
        let scope_clone = scope.clone();
        let ref_count = Arc::new(());
        let ref_count_clone = ref_count.clone();

        scope.spawn(async move {
            // Dropped together with the task, which lets the test observe the drop through
            // the strong count.
            let _ref_count_clone = ref_count_clone;

            let _guard = scope_clone.active_guard();

            let _: () = std::future::pending().await;
        });

        scope.force_shutdown();
        scope.wait().await;

        // The task (and its clone of `ref_count`) has been dropped despite the guard.
        assert_eq!(Arc::strong_count(&ref_count), 1);

        scope.resurrect();

        let ref_count_clone = ref_count.clone();
        scope.spawn(async move {
            // Yield once so the task has to be polled again after returning `Pending`.
            yield_to_executor().await;

            let _ref_count = ref_count_clone.clone();

            let _: () = std::future::pending().await;
        });

        // Wait until the task has passed its first yield and taken the extra reference.
        while Arc::strong_count(&ref_count) != 3 {
            yield_to_executor().await;
        }

        // The resurrected scope keeps the pending task alive.
        for _ in 0..5 {
            yield_to_executor().await;
            assert_eq!(Arc::strong_count(&ref_count), 3);
        }
    }

    /// A task spawned after shutdown is still polled once, and `wait` stays pending while
    /// that task exists.
    #[fuchsia::test]
    async fn test_task_runs_once() {
        let scope = ExecutionScope::new();

        // Ensure the underlying scope has been created before shutting down.
        scope.spawn(async {});

        scope.shutdown();

        let polled = Arc::new(AtomicBool::new(false));
        let polled_clone = polled.clone();

        let scope_clone = scope.clone();

        // Drive a `wait` future by hand so it can be polled from inside a scope task.
        let mut futures = FuturesUnordered::new();
        futures.push(async move { scope_clone.wait().await });

        assert_eq!(futures::poll!(futures.next()), Poll::Pending);

        scope.spawn(async move {
            // This task is itself in the scope, so `wait` must still be pending.
            assert_eq!(futures::poll!(futures.next()), Poll::Pending);
            polled_clone.store(true, Ordering::Relaxed);
        });

        scope.wait().await;

        assert!(polled.load(Ordering::Relaxed));
    }

    /// Hand-written futures that count how often they are polled and dropped.
    mod mocks {
        use futures::channel::oneshot;
        use futures::task::{Context, Poll};
        use futures::Future;
        use std::pin::Pin;
        use std::sync::atomic::{AtomicUsize, Ordering};
        use std::sync::Arc;

        /// Read side of a mock task's poll and drop counters.
        pub(super) struct TaskCounters {
            poll_call_count: Arc<AtomicUsize>,
            drop_call_count: Arc<AtomicUsize>,
        }

        impl TaskCounters {
            /// Returns the poll counter, the drop counter, and a `TaskCounters` sharing them.
            fn new() -> (Arc<AtomicUsize>, Arc<AtomicUsize>, Self) {
                let poll_call_count = Arc::new(AtomicUsize::new(0));
                let drop_call_count = Arc::new(AtomicUsize::new(0));

                (
                    poll_call_count.clone(),
                    drop_call_count.clone(),
                    Self { poll_call_count, drop_call_count },
                )
            }

            pub(super) fn poll_call(&self) -> usize {
                self.poll_call_count.load(Ordering::Relaxed)
            }

            pub(super) fn drop_call(&self) -> usize {
                self.drop_call_count.load(Ordering::Relaxed)
            }
        }

        /// A future that completes on its first poll, signalling `done_sender`.
        pub(super) struct ImmediateTask {
            poll_call_count: Arc<AtomicUsize>,
            drop_call_count: Arc<AtomicUsize>,
            done_sender: Option<oneshot::Sender<()>>,
        }

        impl ImmediateTask {
            pub(super) fn new(done_sender: oneshot::Sender<()>) -> (TaskCounters, Self) {
                let (poll_call_count, drop_call_count, counters) = TaskCounters::new();
                (
                    counters,
                    Self { poll_call_count, drop_call_count, done_sender: Some(done_sender) },
                )
            }
        }

        impl Future for ImmediateTask {
            type Output = ();

            fn poll(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Self::Output> {
                self.poll_call_count.fetch_add(1, Ordering::Relaxed);

                if let Some(sender) = self.done_sender.take() {
                    sender.send(()).unwrap();
                }

                Poll::Ready(())
            }
        }

        impl Drop for ImmediateTask {
            fn drop(&mut self) {
                self.drop_call_count.fetch_add(1, Ordering::Relaxed);
            }
        }

        impl Unpin for ImmediateTask {}

        /// A future that signals `poll_sender` when first run, completes once
        /// `processing_complete` fires, and signals `drop_sender` when dropped.
        pub(super) struct ControlledTask {
            poll_call_count: Arc<AtomicUsize>,
            drop_call_count: Arc<AtomicUsize>,

            drop_sender: Option<oneshot::Sender<()>>,
            future: Pin<Box<dyn Future<Output = ()> + Send>>,
        }

        impl ControlledTask {
            pub(super) fn new(
                poll_sender: oneshot::Sender<()>,
                processing_complete: oneshot::Receiver<()>,
                drop_sender: oneshot::Sender<()>,
            ) -> (TaskCounters, Self) {
                let (poll_call_count, drop_call_count, counters) = TaskCounters::new();
                (
                    counters,
                    Self {
                        poll_call_count,
                        drop_call_count,
                        drop_sender: Some(drop_sender),
                        future: Box::pin(async move {
                            poll_sender.send(()).unwrap();
                            processing_complete.await.unwrap();
                        }),
                    },
                )
            }
        }

        impl Future for ControlledTask {
            type Output = ();

            fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
                self.poll_call_count.fetch_add(1, Ordering::Relaxed);
                self.future.as_mut().poll(cx)
            }
        }

        impl Drop for ControlledTask {
            fn drop(&mut self) {
                self.drop_call_count.fetch_add(1, Ordering::Relaxed);
                self.drop_sender.take().unwrap().send(()).unwrap();
            }
        }
    }
}