1use fdf_sys::*;
8
9use core::cell::RefCell;
10use core::ffi;
11use core::marker::PhantomData;
12use core::mem::ManuallyDrop;
13use core::ptr::{NonNull, null_mut};
14use std::sync::{Arc, Weak};
15
16use zx::Status;
17
18use crate::shutdown_observer::ShutdownObserver;
19
20pub use fdf_sys::fdf_dispatcher_t;
21pub use libasync::{
22 AfterDeadline, AsyncDispatcher, AsyncDispatcherRef, DispatcherTimerExt, JoinHandle,
23 OnDispatcher, Task, WeakDispatcher,
24};
25
26pub trait ShutdownObserverFn: FnOnce(DispatcherRef<'_>) + Send + 'static {}
28impl<T> ShutdownObserverFn for T where T: FnOnce(DispatcherRef<'_>) + Send + 'static {}
29
/// Builder for a driver runtime [`Dispatcher`].
///
/// Configure it with the chained setter methods, then call
/// [`create`](DispatcherBuilder::create) or
/// [`create_released`](DispatcherBuilder::create_released). The `Default` value has no options,
/// an empty name and scheduler role, and no shutdown observer.
#[derive(Default)]
pub struct DispatcherBuilder {
    /// Bitwise-OR of `FDF_DISPATCHER_OPTION_*` flags passed to `fdf_dispatcher_create`.
    #[doc(hidden)]
    pub options: u32,
    /// Dispatcher name, passed to the runtime as a pointer/length pair (not NUL-terminated).
    #[doc(hidden)]
    pub name: String,
    /// Scheduler role, passed to the runtime as a pointer/length pair; empty by default.
    #[doc(hidden)]
    pub scheduler_role: String,
    /// Callback wrapped in the shutdown observer; `create` substitutes a no-op when `None`.
    #[doc(hidden)]
    pub shutdown_observer: Option<Box<dyn ShutdownObserverFn>>,
}
42
43impl DispatcherBuilder {
44 pub(crate) const UNSYNCHRONIZED: u32 = fdf_sys::FDF_DISPATCHER_OPTION_UNSYNCHRONIZED;
46 pub(crate) const ALLOW_THREAD_BLOCKING: u32 = fdf_sys::FDF_DISPATCHER_OPTION_ALLOW_SYNC_CALLS;
48 pub(crate) const NO_THREAD_MIGRATION: u32 = fdf_sys::FDF_DISPATCHER_OPTION_NO_THREAD_MIGRATION;
50
51 pub fn new() -> Self {
55 Self::default()
56 }
57
58 pub fn unsynchronized(mut self) -> Self {
64 assert!(
65 !self.allows_thread_blocking(),
66 "you may not create an unsynchronized dispatcher that allows synchronous calls"
67 );
68 self.options |= Self::UNSYNCHRONIZED;
69 self
70 }
71
72 pub fn is_unsynchronized(&self) -> bool {
74 (self.options & Self::UNSYNCHRONIZED) == Self::UNSYNCHRONIZED
75 }
76
77 pub fn allow_thread_blocking(mut self) -> Self {
83 assert!(
84 !self.is_unsynchronized(),
85 "you may not create an unsynchronized dispatcher that allows synchronous calls"
86 );
87 self.options |= Self::ALLOW_THREAD_BLOCKING;
88 self
89 }
90
91 pub fn allows_thread_blocking(&self) -> bool {
93 (self.options & Self::ALLOW_THREAD_BLOCKING) == Self::ALLOW_THREAD_BLOCKING
94 }
95
96 pub fn no_thread_migration(mut self) -> Self {
103 self.options |= Self::NO_THREAD_MIGRATION;
104 self
105 }
106
107 pub fn allows_thread_migration(&self) -> bool {
109 (self.options & Self::NO_THREAD_MIGRATION) == 0
110 }
111
112 pub fn name(mut self, name: &str) -> Self {
115 self.name = name.to_string();
116 self
117 }
118
119 pub fn scheduler_role(mut self, role: &str) -> Self {
123 self.scheduler_role = role.to_string();
124 self
125 }
126
127 pub fn shutdown_observer<F: ShutdownObserverFn>(mut self, shutdown_observer: F) -> Self {
129 self.shutdown_observer = Some(Box::new(shutdown_observer));
130 self
131 }
132
133 pub fn create(self) -> Result<Dispatcher, Status> {
138 let mut out_dispatcher = null_mut();
139 let options = self.options;
140 let name = self.name.as_ptr() as *mut ffi::c_char;
141 let name_len = self.name.len();
142 let scheduler_role = self.scheduler_role.as_ptr() as *mut ffi::c_char;
143 let scheduler_role_len = self.scheduler_role.len();
144 let observer =
145 ShutdownObserver::new(self.shutdown_observer.unwrap_or_else(|| Box::new(|_| {})))
146 .into_ptr();
147 Status::ok(unsafe {
151 fdf_dispatcher_create(
152 options,
153 name,
154 name_len,
155 scheduler_role,
156 scheduler_role_len,
157 observer,
158 &mut out_dispatcher,
159 )
160 })?;
161 Ok(Dispatcher(unsafe { NonNull::new_unchecked(out_dispatcher) }))
164 }
165
166 pub fn create_released(self) -> Result<DispatcherRef<'static>, Status> {
170 self.create().map(Dispatcher::release)
171 }
172}
173
/// An owned handle to a driver runtime dispatcher.
///
/// Dropping a `Dispatcher` calls `fdf_dispatcher_shutdown_async` on the handle. Use
/// [`Dispatcher::release`] or [`Dispatcher::as_dispatcher_ref`] to get a non-owning
/// [`DispatcherRef`] that never triggers shutdown.
#[derive(Debug)]
pub struct Dispatcher(pub(crate) NonNull<fdf_dispatcher_t>);

// SAFETY: NOTE(review) this assumes the driver runtime allows a dispatcher handle to be used
// from, and shut down on, any thread — confirm against the fdf runtime documentation.
unsafe impl Send for Dispatcher {}
// SAFETY: the `&self` methods in this file only pass the handle to `fdf_dispatcher_get_*`
// functions. NOTE(review) this assumes those runtime entry points are thread-safe.
unsafe impl Sync for Dispatcher {}
thread_local! {
    /// Per-thread dispatcher override consulted by [`CurrentDispatcher`].
    ///
    /// When it holds `Some(handle)`, `CurrentDispatcher` uses that handle instead of asking the
    /// runtime via `fdf_dispatcher_get_current_dispatcher`. NOTE(review): no writer is visible in
    /// this file; confirm which crate-internal code installs and clears it.
    pub(crate) static OVERRIDE_DISPATCHER: RefCell<Option<NonNull<fdf_dispatcher_t>>> = const { RefCell::new(None) };
}
184
185impl Dispatcher {
186 pub unsafe fn from_raw(handle: NonNull<fdf_dispatcher_t>) -> Self {
194 Self(handle)
195 }
196
197 fn get_raw_flags(&self) -> u32 {
198 unsafe { fdf_dispatcher_get_options(self.0.as_ptr()) }
200 }
201
202 pub fn is_unsynchronized(&self) -> bool {
204 (self.get_raw_flags() & DispatcherBuilder::UNSYNCHRONIZED) != 0
205 }
206
207 pub fn allows_thread_blocking(&self) -> bool {
209 (self.get_raw_flags() & DispatcherBuilder::ALLOW_THREAD_BLOCKING) != 0
210 }
211
212 pub fn allows_thread_migration(&self) -> bool {
215 (self.get_raw_flags() & DispatcherBuilder::NO_THREAD_MIGRATION) == 0
216 }
217
218 pub fn is_current_dispatcher(&self) -> bool {
220 self.0.as_ptr() == unsafe { fdf_dispatcher_get_current_dispatcher() }
223 }
224
225 pub fn release(self) -> DispatcherRef<'static> {
230 DispatcherRef(ManuallyDrop::new(self), PhantomData)
231 }
232
233 pub fn as_dispatcher_ref(&self) -> DispatcherRef<'_> {
236 DispatcherRef(ManuallyDrop::new(Dispatcher(self.0)), PhantomData)
237 }
238}
239
240impl AsyncDispatcher for Dispatcher {
241 fn as_async_dispatcher_ref(&self) -> AsyncDispatcherRef<'_> {
242 let async_dispatcher =
243 NonNull::new(unsafe { fdf_dispatcher_get_async_dispatcher(self.0.as_ptr()) })
244 .expect("No async dispatcher on driver dispatcher");
245 unsafe { AsyncDispatcherRef::from_raw(async_dispatcher) }
246 }
247}
248
249impl Drop for Dispatcher {
250 fn drop(&mut self) {
251 unsafe { fdf_dispatcher_shutdown_async(self.0.as_mut()) }
254 }
255}
256
/// A driver dispatcher wrapper that does not shut the dispatcher down when dropped.
///
/// The inner [`Dispatcher`] is held in [`ManuallyDrop`], so dropping an `AutoReleaseDispatcher`
/// never runs `Dispatcher`'s shutdown-on-drop.
#[derive(Debug)]
pub struct AutoReleaseDispatcher {
    // The wrapped dispatcher; deliberately never dropped.
    dispatcher: ManuallyDrop<Dispatcher>,
    // Weak handle created on first call to `as_weak` and cloned on every later call.
    weak_ref: std::sync::OnceLock<WeakDispatcher>,
}
271
272impl AutoReleaseDispatcher {
273 pub fn as_weak(&self) -> WeakDispatcher {
277 self.weak_ref.get_or_init(|| WeakDispatcher::new(self.as_async_dispatcher_ref())).clone()
278 }
279
280 pub fn always_on_dispatcher(&self) -> AutoReleaseDispatcher {
282 let dispatcher_ref = unsafe { DispatcherRef::from_raw(self.dispatcher.0) };
285 let dispatcher = unsafe { Dispatcher::from_raw(dispatcher_ref.always_on_dispatcher().0.0) };
290 Self { dispatcher: ManuallyDrop::new(dispatcher), weak_ref: Default::default() }
291 }
292}
293
294impl AsyncDispatcher for AutoReleaseDispatcher {
295 fn as_async_dispatcher_ref(&self) -> AsyncDispatcherRef<'_> {
296 self.dispatcher.as_async_dispatcher_ref()
297 }
298}
299
300impl From<Dispatcher> for AutoReleaseDispatcher {
301 fn from(dispatcher: Dispatcher) -> Self {
302 Self { dispatcher: ManuallyDrop::new(dispatcher), weak_ref: Default::default() }
303 }
304}
305
/// A non-owning reference to a driver dispatcher, valid for `'a`.
///
/// The inner [`Dispatcher`] is wrapped in [`ManuallyDrop`], so dropping a `DispatcherRef` never
/// shuts the dispatcher down.
#[derive(Debug)]
pub struct DispatcherRef<'a>(ManuallyDrop<Dispatcher>, PhantomData<&'a Dispatcher>);
311
312impl<'a> DispatcherRef<'a> {
313 pub unsafe fn from_raw(handle: NonNull<fdf_dispatcher_t>) -> Self {
320 Self(ManuallyDrop::new(unsafe { Dispatcher::from_raw(handle) }), PhantomData)
322 }
323
324 pub fn from_async_dispatcher(dispatcher: AsyncDispatcherRef<'a>) -> Self {
331 let handle = NonNull::new(unsafe {
332 fdf_dispatcher_downcast_async_dispatcher(dispatcher.inner().as_ptr())
333 })
334 .unwrap();
335 unsafe { Self::from_raw(handle) }
336 }
337
338 pub unsafe fn as_raw(&mut self) -> *mut fdf_dispatcher_t {
344 unsafe { self.0.0.as_mut() }
345 }
346
347 pub fn always_on_dispatcher(&self) -> DispatcherRef<'a> {
350 let ptr = unsafe { fdf_dispatcher_get_always_on_dispatcher(self.0.0.as_ptr()) };
352 DispatcherRef(
353 ManuallyDrop::new(Dispatcher(NonNull::new(ptr).expect("Always-on dispatcher is NULL"))),
354 PhantomData,
355 )
356 }
357}
358
/// Wraps a future and unconditionally marks it `Send`.
///
/// This is only sound when the wrapped future is polled (and dropped) on the thread that created
/// it. `OnDriverDispatcher::spawn_local`/`compute_local` only wrap futures after checking that the
/// target is the current dispatcher and disallows thread migration. NOTE(review): soundness also
/// relies on the runtime dropping the task on that same thread.
struct AddSendFuture<T>(T);

impl<T: Future> Future for AddSendFuture<T> {
    type Output = T::Output;

    /// Forwards the poll to the wrapped future.
    fn poll(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Self::Output> {
        // SAFETY: structural pin projection. The wrapper never moves its field out, and the field
        // is pinned whenever the wrapper is.
        let wrapper = unsafe { self.get_unchecked_mut() };
        let inner = unsafe { std::pin::Pin::new_unchecked(&mut wrapper.0) };
        inner.poll(cx)
    }
}

// SAFETY: see the type-level docs; callers must guarantee single-thread use.
unsafe impl<T> Send for AddSendFuture<T> {}
385
386pub trait OnDriverDispatcher: OnDispatcher {
389 fn spawn_local(
402 &self,
403 future: impl Future<Output = ()> + 'static,
404 ) -> Result<JoinHandle<()>, Status>
405 where
406 Self: 'static,
407 {
408 self.on_maybe_dispatcher(|dispatcher| {
409 let dispatcher = DispatcherRef::from_async_dispatcher(dispatcher);
410 if dispatcher.0.is_current_dispatcher() && !dispatcher.0.allows_thread_migration() {
411 OnDispatcher::spawn(self, AddSendFuture(future))
412 } else {
413 Err(Status::BAD_STATE)
414 }
415 })
416 }
417
418 fn compute_local<T: Send + 'static>(
433 &self,
434 future: impl Future<Output = T> + 'static,
435 ) -> Result<Task<T>, Status>
436 where
437 Self: 'static,
438 {
439 self.on_maybe_dispatcher(|dispatcher| {
440 let dispatcher = DispatcherRef::from_async_dispatcher(dispatcher);
441 if dispatcher.0.is_current_dispatcher() && !dispatcher.0.allows_thread_migration() {
442 Ok(OnDispatcher::compute(self, AddSendFuture(future)))
443 } else {
444 Err(Status::BAD_STATE)
445 }
446 })
447 }
448}
449
// Shared (`Arc`) and weak (`Weak`) dispatcher handles get `spawn_local`/`compute_local` from the
// trait's default methods.
impl OnDriverDispatcher for Arc<Dispatcher> {}
impl OnDriverDispatcher for Weak<Dispatcher> {}
452
453impl<'a> AsyncDispatcher for DispatcherRef<'a> {
454 fn as_async_dispatcher_ref(&self) -> AsyncDispatcherRef<'_> {
455 self.0.as_async_dispatcher_ref()
456 }
457}
458
459impl<'a> Clone for DispatcherRef<'a> {
460 fn clone(&self) -> Self {
461 Self(ManuallyDrop::new(Dispatcher(self.0.0)), PhantomData)
462 }
463}
464
465impl<'a> core::ops::Deref for DispatcherRef<'a> {
466 type Target = Dispatcher;
467 fn deref(&self) -> &Self::Target {
468 &self.0
469 }
470}
471
472impl<'a> core::ops::DerefMut for DispatcherRef<'a> {
473 fn deref_mut(&mut self) -> &mut Self::Target {
474 &mut self.0
475 }
476}
477
478impl<'a> OnDispatcher for DispatcherRef<'a> {
479 fn on_dispatcher<R>(&self, f: impl FnOnce(Option<AsyncDispatcherRef<'_>>) -> R) -> R {
480 f(Some(self.as_async_dispatcher_ref()))
481 }
482}
483
484impl<'a> OnDriverDispatcher for DispatcherRef<'a> {}
485
/// Zero-sized handle meaning "whatever dispatcher is current" at the point of use.
///
/// Each use consults the thread-local `OVERRIDE_DISPATCHER` first and otherwise asks the driver
/// runtime via `fdf_dispatcher_get_current_dispatcher`; see its `OnDispatcher` impl.
#[derive(Clone, Copy, Debug, Default, PartialEq)]
pub struct CurrentDispatcher;
490
491impl OnDispatcher for CurrentDispatcher {
492 fn on_dispatcher<R>(&self, f: impl FnOnce(Option<AsyncDispatcherRef<'_>>) -> R) -> R {
493 let dispatcher = OVERRIDE_DISPATCHER
494 .with(|global| *global.borrow())
495 .or_else(|| {
496 NonNull::new(unsafe { fdf_dispatcher_get_current_dispatcher() })
498 })
499 .map(|dispatcher| {
500 let async_dispatcher = NonNull::new(unsafe {
506 fdf_dispatcher_get_async_dispatcher(dispatcher.as_ptr())
507 })
508 .expect("No async dispatcher on driver dispatcher");
509 unsafe { AsyncDispatcherRef::from_raw(async_dispatcher) }
510 });
511 f(dispatcher)
512 }
513}
514
515impl OnDriverDispatcher for CurrentDispatcher {}
516
#[cfg(test)]
mod tests {
    use super::*;

    use std::sync::{Arc, Once, Weak, mpsc};

    use futures::channel::mpsc as async_mpsc;
    use futures::{SinkExt, StreamExt};
    use zx::sys::ZX_OK;

    use core::ffi::{c_char, c_void};
    use core::ptr::null_mut;

    /// Guards one-time initialization of the driver environment for the whole test process.
    static GLOBAL_DRIVER_ENV: Once = Once::new();
    /// Scheduler role that `ensure_driver_env` registers with
    /// `FDF_SCHEDULER_ROLE_OPTION_NO_SYNC_CALLS`.
    const NO_SYNC_CALLS_ROLE: &str = "no sync calls role";

    /// Starts the driver environment with `fdf_env_start(0)` and registers `NO_SYNC_CALLS_ROLE`,
    /// at most once per process.
    ///
    /// # Panics
    ///
    /// Panics if either runtime call does not return `ZX_OK`.
    pub fn ensure_driver_env() {
        GLOBAL_DRIVER_ENV.call_once(|| {
            unsafe {
                // SAFETY: `GLOBAL_DRIVER_ENV` makes this run once per process. The role
                // pointer/length pair comes from a `'static` string.
                assert_eq!(fdf_env_start(0), ZX_OK);
                assert_eq!(
                    fdf_env_set_scheduler_role_opts(
                        NO_SYNC_CALLS_ROLE.as_ptr() as *const c_char,
                        NO_SYNC_CALLS_ROLE.len(),
                        FDF_SCHEDULER_ROLE_OPTION_NO_SYNC_CALLS
                    ),
                    ZX_OK
                );
            }
        });
    }

    /// Runs `p` against a new dispatcher created with `ALLOW_THREAD_BLOCKING` and no scheduler
    /// role. See `with_raw_dispatcher_flags` for the lifecycle and the checks performed.
    pub fn with_raw_dispatcher<T>(name: &str, p: impl for<'a> FnOnce(Weak<Dispatcher>) -> T) -> T {
        with_raw_dispatcher_flags(name, DispatcherBuilder::ALLOW_THREAD_BLOCKING, "", p)
    }

    /// Creates a dispatcher with `flags` and `scheduler_role`, hands `p` a weak reference to it,
    /// then drops the dispatcher and blocks until its shutdown observer has run.
    ///
    /// # Panics
    ///
    /// Panics if creation fails, or if the runtime reports queued tasks when the shutdown
    /// observer runs. Also panics if a strong reference to the dispatcher is observed after
    /// shutdown.
    pub(crate) fn with_raw_dispatcher_flags<T>(
        name: &str,
        flags: u32,
        scheduler_role: &str,
        p: impl for<'a> FnOnce(Weak<Dispatcher>) -> T,
    ) -> T {
        ensure_driver_env();

        let (shutdown_tx, shutdown_rx) = mpsc::channel();
        let mut dispatcher = null_mut();
        let mut observer = ShutdownObserver::new(move |dispatcher| {
            // By the time the observer runs, nothing should still be queued on the dispatcher.
            assert!(!unsafe { fdf_env_dispatcher_has_queued_tasks(dispatcher.0.0.as_ptr()) });
            // Wake the test thread blocked in `shutdown_rx.recv()` below.
            shutdown_tx.send(()).unwrap();
        })
        .into_ptr();
        // The address of the local `observer` slot is passed as the dispatcher's owner token.
        // NOTE(review): presumably used only as an opaque identity; it stays valid because this
        // function blocks until shutdown completes.
        let driver_ptr = &mut observer as *mut _ as *mut c_void;
        let res = unsafe {
            // SAFETY: `name` and `scheduler_role` outlive the call and are passed with their
            // exact lengths. `observer` stays allocated until the runtime's shutdown handling
            // consumes it. `dispatcher` is a valid out-pointer.
            fdf_env_dispatcher_create_with_owner(
                driver_ptr,
                flags,
                name.as_ptr() as *const c_char,
                name.len(),
                scheduler_role.as_ptr() as *const c_char,
                scheduler_role.len(),
                observer,
                &mut dispatcher,
            )
        };
        assert_eq!(res, ZX_OK);
        let dispatcher = Arc::new(Dispatcher(NonNull::new(dispatcher).unwrap()));

        let res = p(Arc::downgrade(&dispatcher));

        // Dropping the last strong reference runs `Dispatcher::drop`, which starts shutdown.
        // Wait for the observer to confirm, then check that no strong reference escaped `p`.
        // NOTE(review): shutdown is normally only triggered by that final drop. A reference that
        // escapes `p` indefinitely therefore makes `recv()` block rather than trip the assertion
        // below. Consider checking `Arc::strong_count` before waiting.
        let weak_dispatcher = Arc::downgrade(&dispatcher);
        drop(dispatcher);
        shutdown_rx.recv().unwrap();
        assert_eq!(
            0,
            weak_dispatcher.strong_count(),
            "a dispatcher reference escaped the test body"
        );

        res
    }

    /// Smoke test: a dispatcher can be created and shut down.
    #[test]
    fn start_test_dispatcher() {
        with_raw_dispatcher("testing", |dispatcher| {
            println!("hello {dispatcher:?}");
        })
    }

    /// A synchronously-posted task runs with `ZX_OK`.
    #[test]
    fn post_task_on_dispatcher() {
        with_raw_dispatcher("testing task", |dispatcher| {
            let (tx, rx) = mpsc::channel();
            let dispatcher = Weak::upgrade(&dispatcher).unwrap();
            dispatcher
                .post_task_sync(move |status| {
                    assert_eq!(status, Status::from_raw(ZX_OK));
                    tx.send(status).unwrap();
                })
                .unwrap();
            assert_eq!(rx.recv().unwrap(), Status::from_raw(ZX_OK));
        });
    }

    /// A dispatcher built with `DispatcherBuilder` from inside a task can run tasks, and its
    /// shutdown observer fires once it is dropped.
    #[test]
    fn post_task_on_subdispatcher() {
        let (shutdown_tx, shutdown_rx) = mpsc::channel();
        with_raw_dispatcher("testing task top level", move |dispatcher| {
            let (tx, rx) = mpsc::channel();
            let (inner_tx, inner_rx) = mpsc::channel();
            let dispatcher = Weak::upgrade(&dispatcher).unwrap();
            dispatcher
                .post_task_sync(move |status| {
                    assert_eq!(status, Status::from_raw(ZX_OK));
                    let inner = DispatcherBuilder::new()
                        .name("testing task second level")
                        .scheduler_role("")
                        .allow_thread_blocking()
                        .shutdown_observer(move |_dispatcher| {
                            println!("shutdown observer called");
                            shutdown_tx.send(1).unwrap();
                        })
                        .create()
                        .unwrap();
                    inner
                        .post_task_sync(move |status| {
                            assert_eq!(status, Status::from_raw(ZX_OK));
                            tx.send(status).unwrap();
                        })
                        .unwrap();
                    // Hand the inner dispatcher to the test thread, which drops it (starting
                    // its shutdown) right after receiving it.
                    inner_tx.send(inner).unwrap();
                })
                .unwrap();
            assert_eq!(rx.recv().unwrap(), Status::from_raw(ZX_OK));
            inner_rx.recv().unwrap();
        });
        assert_eq!(shutdown_rx.recv().unwrap(), 1);
    }

    /// `spawn_local`/`compute_local` refuse a dispatcher that allows thread migration, even
    /// when called from a task running on it.
    #[test]
    fn spawn_local_fails_on_normal_dispatcher() {
        let (shutdown_tx, shutdown_rx) = mpsc::channel();
        with_raw_dispatcher("spawn local failures", move |dispatcher| {
            let inside_dispatcher = dispatcher.clone();
            dispatcher
                .spawn(async move {
                    assert_eq!(
                        inside_dispatcher.spawn_local(futures::future::ready(())).unwrap_err(),
                        Status::BAD_STATE
                    );
                    assert_eq!(
                        inside_dispatcher.compute_local(futures::future::ready(())).unwrap_err(),
                        Status::BAD_STATE
                    );
                    shutdown_tx.send(()).unwrap();
                })
                .unwrap();
            shutdown_rx.recv().unwrap();
        });
    }

    /// `spawn_local`/`compute_local` succeed from a task on a no-thread-migration dispatcher.
    #[test]
    #[ignore = "Pending resolution of b/488397193"]
    fn spawn_local_succeeds_on_no_thread_migration_dispatcher() {
        let (tx, rx) = mpsc::channel();
        with_raw_dispatcher_flags(
            "spawn local success",
            FDF_DISPATCHER_OPTION_NO_THREAD_MIGRATION,
            NO_SYNC_CALLS_ROLE,
            move |dispatcher| {
                let inside_dispatcher = dispatcher.clone();
                dispatcher
                    .spawn(async move {
                        let tx_clone = tx.clone();
                        inside_dispatcher
                            .spawn_local(async move {
                                tx_clone.send(()).unwrap();
                            })
                            .unwrap();
                        inside_dispatcher
                            .compute_local(async move {
                                tx.send(()).unwrap();
                            })
                            .unwrap()
                            .await
                            .unwrap();
                    })
                    .unwrap();
                rx.recv().unwrap();
                // One message from the `spawn_local` future and one from `compute_local`.
                rx.recv().unwrap();
            },
        );
    }

    /// Even on a no-thread-migration dispatcher, calls made from outside it are refused.
    #[test]
    #[ignore = "Pending resolution of b/488397193"]
    fn spawn_local_fails_on_no_thread_migration_dispatcher_from_different_thread() {
        with_raw_dispatcher_flags(
            "spawn local success",
            FDF_DISPATCHER_OPTION_NO_THREAD_MIGRATION,
            NO_SYNC_CALLS_ROLE,
            move |dispatcher| {
                assert_eq!(
                    // The test body runs on the test thread, not on the dispatcher, so the
                    // "is current dispatcher" check fails.
                    dispatcher.spawn_local(futures::future::ready(())).unwrap_err(),
                    Status::BAD_STATE
                );
                assert_eq!(
                    dispatcher.compute_local(futures::future::ready(())).unwrap_err(),
                    Status::BAD_STATE
                );
            },
        );
    }

    /// Sends 0, then echoes every received value plus one until the channel closes.
    async fn ping(mut tx: async_mpsc::Sender<u8>, mut rx: async_mpsc::Receiver<u8>) {
        println!("starting ping!");
        tx.send(0).await.unwrap();
        while let Some(next) = rx.next().await {
            println!("ping! {next}");
            tx.send(next + 1).await.unwrap();
        }
    }

    /// Echoes received values plus one until one exceeds 10, then signals `fin_tx`.
    async fn pong(
        fin_tx: std::sync::mpsc::Sender<()>,
        mut tx: async_mpsc::Sender<u8>,
        mut rx: async_mpsc::Receiver<u8>,
    ) {
        println!("starting pong!");
        while let Some(next) = rx.next().await {
            println!("pong! {next}");
            if next > 10 {
                println!("bye!");
                break;
            }
            tx.send(next + 1).await.unwrap();
        }
        fin_tx.send(()).unwrap();
    }

    /// Two futures spawned on the same driver dispatcher exchange messages.
    #[test]
    fn async_ping_pong() {
        with_raw_dispatcher("async ping pong", |dispatcher| {
            let (fin_tx, fin_rx) = mpsc::channel();
            let (ping_tx, pong_rx) = async_mpsc::channel(10);
            let (pong_tx, ping_rx) = async_mpsc::channel(10);
            dispatcher.spawn(ping(ping_tx, ping_rx)).unwrap();
            dispatcher.spawn(pong(fin_tx, pong_tx, pong_rx)).unwrap();

            fin_rx.recv().expect("to receive final value");
        });
    }

    /// Like `pong`, but waits one second on a `fuchsia_async` timer before handling each value.
    async fn slow_pong(
        fin_tx: std::sync::mpsc::Sender<()>,
        mut tx: async_mpsc::Sender<u8>,
        mut rx: async_mpsc::Receiver<u8>,
    ) {
        use zx::MonotonicDuration;
        println!("starting pong!");
        while let Some(next) = rx.next().await {
            println!("pong! {next}");
            fuchsia_async::Timer::new(fuchsia_async::MonotonicInstant::after(
                MonotonicDuration::from_seconds(1),
            ))
            .await;
            if next > 10 {
                println!("bye!");
                break;
            }
            tx.send(next + 1).await.unwrap();
        }
        fin_tx.send(()).unwrap();
    }

    /// `ping` runs on the driver dispatcher while `slow_pong` runs on a `fuchsia_async` executor
    /// on the test thread; the two communicate over async channels.
    #[test]
    fn mixed_executor_async_ping_pong() {
        with_raw_dispatcher("async ping pong", |dispatcher| {
            let (fin_tx, fin_rx) = mpsc::channel();
            let (ping_tx, pong_rx) = async_mpsc::channel(10);
            let (pong_tx, ping_rx) = async_mpsc::channel(10);

            // The ping side runs on the driver dispatcher.
            dispatcher.spawn(ping(ping_tx, ping_rx)).unwrap();

            let mut executor = fuchsia_async::LocalExecutor::default();
            // The pong side runs to completion on the local executor, blocking this thread.
            executor.run_singlethreaded(slow_pong(fin_tx, pong_tx, pong_rx));

            fin_rx.recv().expect("to receive final value");
        });
    }
}