1use crate::token_registry::TokenRegistry;
20
21use fuchsia_async::{JoinHandle, Scope, ScopeHandle, SpawnableFuture};
22use fuchsia_sync::{MappedMutexGuard, Mutex, MutexGuard};
23use futures::task::{self, Poll};
24use futures::Future;
25use std::future::poll_fn;
26use std::pin::Pin;
27use std::sync::Arc;
28use std::task::Context;
29
30#[cfg(target_os = "fuchsia")]
31use fuchsia_async::EHandle;
32
33pub use fuchsia_async::scope::ScopeActiveGuard as ActiveGuard;
34
/// Alias for `futures::task::SpawnError`, re-exported so users of this module can name it without
/// depending on `futures` directly.
pub type SpawnError = task::SpawnError;
36
/// A cloneable handle to a shared scope in which tasks are spawned.
///
/// All clones share the same underlying state (held behind an `Arc`), so spawning, shutting down,
/// or waiting through any clone affects every clone. Equality compares the identity of that shared
/// state, not its contents.
#[derive(Clone)]
pub struct ExecutionScope {
    // Shared state; cloning an `ExecutionScope` only bumps this reference count.
    executor: Arc<Executor>,
}
51
/// State shared by every clone of an [`ExecutionScope`].
struct Executor {
    /// Registry handed out by `ExecutionScope::token_registry`.
    token_registry: TokenRegistry,
    /// The underlying async scope. May be `None`, in which case `Executor::scope` creates one on
    /// demand; `ExecutionScope::resurrect` resets it back to `None`.
    scope: Mutex<Option<Scope>>,
}
56
57impl ExecutionScope {
58 pub fn new() -> Self {
61 Self::build().new()
62 }
63
64 pub fn build() -> ExecutionScopeParams {
68 ExecutionScopeParams::default()
69 }
70
71 pub fn spawn(&self, task: impl Future<Output = ()> + Send + 'static) -> JoinHandle<()> {
83 self.executor.scope().spawn(task)
84 }
85
86 pub fn new_task(self, task: impl Future<Output = ()> + Send + 'static) -> Task {
88 Task(self.executor, SpawnableFuture::new(task))
89 }
90
91 pub fn token_registry(&self) -> &TokenRegistry {
92 &self.executor.token_registry
93 }
94
95 pub fn shutdown(&self) {
96 self.executor.shutdown();
97 }
98
99 pub fn force_shutdown(&self) {
101 let _ = self.executor.scope().clone().abort();
102 }
103
104 pub fn resurrect(&self) {
107 *self.executor.scope.lock() = None;
110 }
111
112 pub async fn wait(&self) {
114 let scope = self.executor.scope().clone();
115 scope.on_no_tasks_and_guards().await;
116 }
117
118 pub fn try_active_guard(&self) -> Option<ActiveGuard> {
121 self.executor.scope().active_guard()
122 }
123}
124
125impl PartialEq for ExecutionScope {
126 fn eq(&self, other: &Self) -> bool {
127 Arc::as_ptr(&self.executor) == Arc::as_ptr(&other.executor)
128 }
129}
130
131impl Eq for ExecutionScope {}
132
133impl std::fmt::Debug for ExecutionScope {
134 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
135 f.write_fmt(format_args!("ExecutionScope {:?}", Arc::as_ptr(&self.executor)))
136 }
137}
138
/// Parameters for building an [`ExecutionScope`].
///
/// Obtain one from [`ExecutionScope::build`] and finish with [`ExecutionScopeParams::new`].
#[derive(Default)]
pub struct ExecutionScopeParams {
    // Executor whose global scope should parent the new scope. When unset, the scope is created
    // lazily on first use instead.
    #[cfg(target_os = "fuchsia")]
    async_executor: Option<EHandle>,
}
144
145impl ExecutionScopeParams {
146 #[cfg(target_os = "fuchsia")]
147 pub fn executor(mut self, value: EHandle) -> Self {
148 assert!(self.async_executor.is_none(), "`executor` is already set");
149 self.async_executor = Some(value);
150 self
151 }
152
153 pub fn new(self) -> ExecutionScope {
154 ExecutionScope {
155 executor: Arc::new(Executor {
156 token_registry: TokenRegistry::new(),
157 #[cfg(target_os = "fuchsia")]
158 scope: self.async_executor.map_or_else(
159 || Mutex::new(None),
160 |e| Mutex::new(Some(e.global_scope().new_child())),
161 ),
162 #[cfg(not(target_os = "fuchsia"))]
163 scope: Mutex::new(None),
164 }),
165 }
166 }
167}
168
169impl Executor {
170 fn scope(&self) -> MappedMutexGuard<'_, Scope> {
171 MutexGuard::map(self.scope.lock(), |s| {
175 s.get_or_insert_with(|| {
176 #[cfg(target_os = "fuchsia")]
177 return Scope::global().new_child();
178 #[cfg(not(target_os = "fuchsia"))]
179 return Scope::new();
180 })
181 })
182 }
183
184 fn shutdown(&self) {
185 if let Some(scope) = &*self.scope.lock() {
186 scope.wake_all_with_active_guard();
187 let _ = ScopeHandle::clone(&*scope).cancel();
188 }
189 }
190}
191
192impl Drop for Executor {
193 fn drop(&mut self) {
194 self.shutdown();
195 if let Some(scope) = self.scope.get_mut().take() {
198 scope.detach();
199 }
200 }
201}
202
/// Yields control back to the executor exactly once.
///
/// The first poll wakes the current task and returns `Poll::Pending`, so other ready tasks get a
/// chance to run. The next poll completes.
pub async fn yield_to_executor() {
    let mut yielded = false;
    poll_fn(|cx| {
        if yielded {
            return Poll::Ready(());
        }
        yielded = true;
        // Re-schedule ourselves immediately so we are polled again after others have run.
        cx.waker().wake_by_ref();
        Poll::Pending
    })
    .await
}
217
/// A future tied to an [`ExecutionScope`]'s shared state, created by
/// [`ExecutionScope::new_task`].
///
/// It can be spawned onto that scope with [`Task::spawn`], or awaited directly as a future.
pub struct Task(Arc<Executor>, SpawnableFuture<'static, ()>);
219
220impl Task {
221 pub fn spawn(self) {
223 self.0.scope().spawn(self.1);
224 }
225}
226
227impl Future for Task {
228 type Output = ();
229
230 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
231 Pin::new(&mut &mut self.1).poll(cx)
232 }
233}
234
#[cfg(test)]
mod tests {
    //! Tests for `ExecutionScope`: spawning, cooperative and forced shutdown, waiting, and
    //! resurrection after a forced shutdown.

    use super::{yield_to_executor, ExecutionScope};

    use fuchsia_async::{TestExecutor, Timer};
    use futures::channel::oneshot;
    use futures::Future;
    use std::pin::pin;
    #[cfg(target_os = "fuchsia")]
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::Arc;
    #[cfg(target_os = "fuchsia")]
    use std::task::Poll;
    use std::time::Duration;

    /// Builds a test future from a fresh `ExecutionScope` and runs it on a new `TestExecutor`,
    /// asserting that it finishes before the executor stalls.
    #[cfg(target_os = "fuchsia")]
    fn run_test<GetTest, GetTestRes>(get_test: GetTest)
    where
        GetTest: FnOnce(ExecutionScope) -> GetTestRes,
        GetTestRes: Future<Output = ()>,
    {
        let mut exec = TestExecutor::new();

        let scope = ExecutionScope::new();

        let test = get_test(scope);

        // Anything other than `Ready` means the test future stopped making progress.
        assert_eq!(
            exec.run_until_stalled(&mut pin!(test)),
            Poll::Ready(()),
            "Test did not complete"
        );
    }

    /// Host variant of `run_test`: runs the test future to completion, panicking via the
    /// `on_stalled` callback if it stops making progress (30 second timeout).
    #[cfg(not(target_os = "fuchsia"))]
    fn run_test<GetTest, GetTestRes>(get_test: GetTest)
    where
        GetTest: FnOnce(ExecutionScope) -> GetTestRes,
        GetTestRes: Future<Output = ()>,
    {
        use fuchsia_async::TimeoutExt;
        let mut exec = TestExecutor::new();

        let scope = ExecutionScope::new();

        // NOTE(review): presumably `run_until_stalled` is unavailable off-Fuchsia, so a stall
        // timeout stands in for it — confirm.
        let test =
            get_test(scope).on_stalled(Duration::from_secs(30), || panic!("Test did not complete"));

        exec.run_singlethreaded(&mut pin!(test));
    }

    /// A task that completes on its first poll is polled once and dropped once.
    #[test]
    fn simple() {
        run_test(|scope| {
            async move {
                let (sender, receiver) = oneshot::channel();
                let (counters, task) = mocks::ImmediateTask::new(sender);

                scope.spawn(task);

                receiver.await.unwrap();
                // NOTE(review): relies on the executor dropping the completed task before this
                // future is polled again, so the drop counter is already updated here.

                assert_eq!(counters.drop_call(), 1);
                assert_eq!(counters.poll_call(), 1);
            }
        });
    }

    /// Shutting down the scope while a task is running still results in the task being dropped
    /// exactly once.
    #[test]
    fn simple_drop() {
        run_test(|scope| {
            async move {
                let (poll_sender, poll_receiver) = oneshot::channel();
                let (processing_done_sender, processing_done_receiver) = oneshot::channel();
                let (drop_sender, drop_receiver) = oneshot::channel();
                let (counters, task) =
                    mocks::ControlledTask::new(poll_sender, processing_done_receiver, drop_sender);

                scope.spawn(task);

                // The task has been polled at least once.
                poll_receiver.await.unwrap();

                processing_done_sender.send(()).unwrap();

                scope.shutdown();

                // The task's `Drop` impl signals this channel.
                drop_receiver.await.unwrap();

                let poll_count = counters.poll_call();
                // Whether the task is polled again after `processing_done_sender` fires depends
                // on how the executor interleaves it with the shutdown, so only a lower bound is
                // asserted.
                assert!(poll_count >= 1, "poll was not called");

                assert_eq!(counters.drop_call(), 1);
            }
        });
    }

    /// `wait` does not resolve until the spawned task has finished.
    #[test]
    fn test_wait_waits_for_tasks_to_finish() {
        let mut executor = TestExecutor::new();
        let scope = ExecutionScope::new();
        executor.run_singlethreaded(async {
            let (poll_sender, poll_receiver) = oneshot::channel();
            let (processing_done_sender, processing_done_receiver) = oneshot::channel();
            let (drop_sender, _drop_receiver) = oneshot::channel();
            let (_, task) =
                mocks::ControlledTask::new(poll_sender, processing_done_receiver, drop_sender);

            scope.spawn(task);

            poll_receiver.await.unwrap();

            // Set just before the task is allowed to finish; `wait` must not return earlier.
            let done = fuchsia_sync::Mutex::new(false);
            futures::join!(
                async {
                    scope.wait().await;
                    assert_eq!(*done.lock(), true);
                },
                async {
                    Timer::new(Duration::from_millis(100)).await;
                    // Mark completion before releasing the task so the waiter observes it.
                    *done.lock() = true;
                    processing_done_sender.send(()).unwrap();
                }
            );
        });
    }

    /// A task blocked reading a channel still observes a message written before `shutdown`:
    /// `received_msg` must be set once `wait` returns.
    #[cfg(target_os = "fuchsia")]
    #[fuchsia::test]
    async fn test_shutdown_waits_for_channels() {
        use fuchsia_async as fasync;

        let scope = ExecutionScope::new();
        let (rx, tx) = zx::Channel::create();
        let received_msg = Arc::new(AtomicBool::new(false));
        let (sender, receiver) = futures::channel::oneshot::channel();
        {
            let received_msg = received_msg.clone();
            scope.spawn(async move {
                let mut msg_buf = zx::MessageBuf::new();
                msg_buf.ensure_capacity_bytes(64);
                // Tell the test the task has started before blocking on the channel.
                let _ = sender.send(());
                let _ = fasync::Channel::from_channel(rx).recv_msg(&mut msg_buf).await;
                received_msg.store(true, Ordering::Relaxed);
            });
        }
        let _ = receiver.await;
        // The task is now running and about to wait on `rx`.

        tx.write(b"hello", &mut []).expect("write failed");
        scope.shutdown();
        scope.wait().await;
        assert!(received_msg.load(Ordering::Relaxed));
    }

    /// `force_shutdown` drops tasks even while they hold an active guard, and `resurrect` makes
    /// the scope usable again afterwards.
    #[fuchsia::test]
    async fn test_force_shutdown() {
        let scope = ExecutionScope::new();
        let scope_clone = scope.clone();
        let ref_count = Arc::new(());
        let ref_count_clone = ref_count.clone();

        scope.spawn(async move {
            // Move the clone into the task body so it is released only when the task's future is
            // dropped.
            let _ref_count_clone = ref_count_clone;

            // Hold an active guard; the forced shutdown must drop the task regardless.
            let _guard = scope_clone.try_active_guard().unwrap();

            let _: () = std::future::pending().await;
        });

        scope.force_shutdown();
        scope.wait().await;

        // The task's clone of `ref_count` is gone, so the task has been dropped.
        assert_eq!(Arc::strong_count(&ref_count), 1);

        // Replace the aborted scope so new tasks can be spawned.
        scope.resurrect();

        let ref_count_clone = ref_count.clone();
        scope.spawn(async move {
            yield_to_executor().await;

            // Take an extra reference only once the task has actually run past its first yield.
            let _ref_count = ref_count_clone.clone();

            let _: () = std::future::pending().await;
        });

        // Three references: `ref_count`, the clone moved into the task, and `_ref_count`.
        while Arc::strong_count(&ref_count) != 3 {
            yield_to_executor().await;
        }

        for _ in 0..5 {
            // The new task must stay alive; nothing from the earlier forced shutdown may tear it
            // down.
            yield_to_executor().await;
            assert_eq!(Arc::strong_count(&ref_count), 3);
        }
    }

    /// Instrumented futures used by the tests above.
    mod mocks {
        use futures::channel::oneshot;
        use futures::task::{Context, Poll};
        use futures::Future;
        use std::pin::Pin;
        use std::sync::atomic::{AtomicUsize, Ordering};
        use std::sync::Arc;

        /// Test-side view of how many times a mock task was polled and dropped.
        pub(super) struct TaskCounters {
            poll_call_count: Arc<AtomicUsize>,
            drop_call_count: Arc<AtomicUsize>,
        }

        impl TaskCounters {
            /// Creates zeroed counters and returns clones of the poll and drop counters for the
            /// mock task to increment, plus the `TaskCounters` that reads them.
            fn new() -> (Arc<AtomicUsize>, Arc<AtomicUsize>, Self) {
                let poll_call_count = Arc::new(AtomicUsize::new(0));
                let drop_call_count = Arc::new(AtomicUsize::new(0));

                (
                    poll_call_count.clone(),
                    drop_call_count.clone(),
                    Self { poll_call_count, drop_call_count },
                )
            }

            /// Number of times the task's `poll` has been called.
            pub(super) fn poll_call(&self) -> usize {
                self.poll_call_count.load(Ordering::Relaxed)
            }

            /// Number of times the task has been dropped.
            pub(super) fn drop_call(&self) -> usize {
                self.drop_call_count.load(Ordering::Relaxed)
            }
        }

        /// A future that completes on its first poll, signalling `done_sender` when it does.
        pub(super) struct ImmediateTask {
            poll_call_count: Arc<AtomicUsize>,
            drop_call_count: Arc<AtomicUsize>,
            done_sender: Option<oneshot::Sender<()>>,
        }

        impl ImmediateTask {
            /// Creates the task together with the counters the test inspects.
            pub(super) fn new(done_sender: oneshot::Sender<()>) -> (TaskCounters, Self) {
                let (poll_call_count, drop_call_count, counters) = TaskCounters::new();
                (
                    counters,
                    Self { poll_call_count, drop_call_count, done_sender: Some(done_sender) },
                )
            }
        }

        impl Future for ImmediateTask {
            type Output = ();

            /// Counts the poll, signals completion on the first poll only, and returns `Ready`.
            fn poll(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Self::Output> {
                self.poll_call_count.fetch_add(1, Ordering::Relaxed);

                if let Some(sender) = self.done_sender.take() {
                    sender.send(()).unwrap();
                }

                Poll::Ready(())
            }
        }

        impl Drop for ImmediateTask {
            fn drop(&mut self) {
                self.drop_call_count.fetch_add(1, Ordering::Relaxed);
            }
        }

        // NOTE(review): likely redundant, since the fields appear to be `Unpin` already.
        impl Unpin for ImmediateTask {}

        /// A future driven by the test: on its first poll it signals `poll_sender`, then waits for
        /// `processing_complete` before finishing. Dropping it signals `drop_sender`.
        pub(super) struct ControlledTask {
            poll_call_count: Arc<AtomicUsize>,
            drop_call_count: Arc<AtomicUsize>,

            drop_sender: Option<oneshot::Sender<()>>,
            future: Pin<Box<dyn Future<Output = ()> + Send>>,
        }

        impl ControlledTask {
            /// Creates the task together with the counters the test inspects.
            pub(super) fn new(
                poll_sender: oneshot::Sender<()>,
                processing_complete: oneshot::Receiver<()>,
                drop_sender: oneshot::Sender<()>,
            ) -> (TaskCounters, Self) {
                let (poll_call_count, drop_call_count, counters) = TaskCounters::new();
                (
                    counters,
                    Self {
                        poll_call_count,
                        drop_call_count,
                        drop_sender: Some(drop_sender),
                        future: Box::pin(async move {
                            poll_sender.send(()).unwrap();
                            processing_complete.await.unwrap();
                        }),
                    },
                )
            }
        }

        impl Future for ControlledTask {
            type Output = ();

            /// Counts the poll and delegates to the inner scripted future.
            fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
                self.poll_call_count.fetch_add(1, Ordering::Relaxed);
                self.future.as_mut().poll(cx)
            }
        }

        impl Drop for ControlledTask {
            /// Counts the drop and signals `drop_sender`; panics if the matching receiver has
            /// already been dropped.
            fn drop(&mut self) {
                self.drop_call_count.fetch_add(1, Ordering::Relaxed);
                self.drop_sender.take().unwrap().send(()).unwrap();
            }
        }
    }
}