1use fdf_sys::*;
8
9use core::cell::RefCell;
10use core::ffi;
11use core::marker::PhantomData;
12use core::mem::ManuallyDrop;
13use core::ptr::{NonNull, null_mut};
14
15use zx::Status;
16
17use crate::shutdown_observer::ShutdownObserver;
18
19pub use fdf_sys::fdf_dispatcher_t;
20pub use libasync::{
21 AfterDeadline, AsAsyncDispatcherRef, AsyncDispatcher, AsyncDispatcherRef, DispatcherTimerExt,
22 GetAsyncDispatcher, JoinHandle, OnDispatcher, Task,
23};
24
25pub trait ShutdownObserverFn: FnOnce(DriverDispatcherRef<'_>) + Send + 'static {}
27impl<T> ShutdownObserverFn for T where T: FnOnce(DriverDispatcherRef<'_>) + Send + 'static {}
28
/// Builder for a new driver runtime [`Dispatcher`].
///
/// Configure it with the chained setter methods, then finish with
/// [`DispatcherBuilder::create`] or [`DispatcherBuilder::create_released`].
#[derive(Default)]
pub struct DispatcherBuilder {
    /// Bitmask of `FDF_DISPATCHER_OPTION_*` flags handed to `fdf_dispatcher_create`.
    #[doc(hidden)]
    pub options: u32,
    /// Dispatcher name handed to `fdf_dispatcher_create` (empty by default).
    #[doc(hidden)]
    pub name: String,
    /// Scheduler role handed to `fdf_dispatcher_create` (empty by default).
    #[doc(hidden)]
    pub scheduler_role: String,
    /// Shutdown callback; `create` substitutes a no-op callback when this is `None`.
    #[doc(hidden)]
    pub shutdown_observer: Option<Box<dyn ShutdownObserverFn>>,
}
41
42impl DispatcherBuilder {
43 pub(crate) const UNSYNCHRONIZED: u32 = fdf_sys::FDF_DISPATCHER_OPTION_UNSYNCHRONIZED;
45 pub(crate) const ALLOW_THREAD_BLOCKING: u32 = fdf_sys::FDF_DISPATCHER_OPTION_ALLOW_SYNC_CALLS;
47 pub(crate) const NO_THREAD_MIGRATION: u32 = fdf_sys::FDF_DISPATCHER_OPTION_NO_THREAD_MIGRATION;
49
50 pub fn new() -> Self {
54 Self::default()
55 }
56
57 pub fn unsynchronized(mut self) -> Self {
63 assert!(
64 !self.allows_thread_blocking(),
65 "you may not create an unsynchronized dispatcher that allows synchronous calls"
66 );
67 self.options |= Self::UNSYNCHRONIZED;
68 self
69 }
70
71 pub fn is_unsynchronized(&self) -> bool {
73 (self.options & Self::UNSYNCHRONIZED) == Self::UNSYNCHRONIZED
74 }
75
76 pub fn allow_thread_blocking(mut self) -> Self {
82 assert!(
83 !self.is_unsynchronized(),
84 "you may not create an unsynchronized dispatcher that allows synchronous calls"
85 );
86 self.options |= Self::ALLOW_THREAD_BLOCKING;
87 self
88 }
89
90 pub fn allows_thread_blocking(&self) -> bool {
92 (self.options & Self::ALLOW_THREAD_BLOCKING) == Self::ALLOW_THREAD_BLOCKING
93 }
94
95 pub fn no_thread_migration(mut self) -> Self {
102 self.options |= Self::NO_THREAD_MIGRATION;
103 self
104 }
105
106 pub fn allows_thread_migration(&self) -> bool {
108 (self.options & Self::NO_THREAD_MIGRATION) == 0
109 }
110
111 pub fn name(mut self, name: &str) -> Self {
114 self.name = name.to_string();
115 self
116 }
117
118 pub fn scheduler_role(mut self, role: &str) -> Self {
122 self.scheduler_role = role.to_string();
123 self
124 }
125
126 pub fn shutdown_observer<F: ShutdownObserverFn>(mut self, shutdown_observer: F) -> Self {
128 self.shutdown_observer = Some(Box::new(shutdown_observer));
129 self
130 }
131
132 pub fn create(self) -> Result<Dispatcher, Status> {
137 let mut out_dispatcher = null_mut();
138 let options = self.options;
139 let name = self.name.as_ptr() as *mut ffi::c_char;
140 let name_len = self.name.len();
141 let scheduler_role = self.scheduler_role.as_ptr() as *mut ffi::c_char;
142 let scheduler_role_len = self.scheduler_role.len();
143 let observer =
144 ShutdownObserver::new(self.shutdown_observer.unwrap_or_else(|| Box::new(|_| {})))
145 .into_ptr();
146 Status::ok(unsafe {
150 fdf_dispatcher_create(
151 options,
152 name,
153 name_len,
154 scheduler_role,
155 scheduler_role_len,
156 observer,
157 &mut out_dispatcher,
158 )
159 })?;
160 Ok(Dispatcher(unsafe { NonNull::new_unchecked(out_dispatcher) }))
163 }
164
165 pub fn create_released(self) -> Result<AutoReleaseDispatcher, Status> {
169 self.create().map(Dispatcher::release)
170 }
171}
172
/// An owned handle to a driver runtime dispatcher.
///
/// Dropping it calls `fdf_dispatcher_shutdown_async` on the underlying dispatcher. Use
/// [`Dispatcher::release`] to get a handle that leaves the dispatcher running when dropped.
#[derive(Debug)]
pub struct Dispatcher(pub(crate) NonNull<fdf_dispatcher_t>);

// SAFETY: NOTE(review): assumes the driver runtime permits a dispatcher handle to be moved to,
// and shut down from, any thread. Confirm against the fdf threading documentation.
unsafe impl Send for Dispatcher {}
// SAFETY: NOTE(review): assumes the `fdf_dispatcher_*` queries made through `&Dispatcher` are
// safe to issue concurrently from multiple threads.
unsafe impl Sync for Dispatcher {}
thread_local! {
    /// Per-thread override that `CurrentDispatcher` consults before asking the runtime for the
    /// current dispatcher. The initial `None` means no override is active on this thread.
    pub(crate) static OVERRIDE_DISPATCHER: RefCell<Option<NonNull<fdf_dispatcher_t>>> = const { RefCell::new(None) };
}
183
184impl Dispatcher {
185 pub unsafe fn from_raw(handle: NonNull<fdf_dispatcher_t>) -> Self {
193 Self(handle)
194 }
195
196 fn get_raw_flags(&self) -> u32 {
197 unsafe { fdf_dispatcher_get_options(self.0.as_ptr()) }
199 }
200
201 pub fn is_unsynchronized(&self) -> bool {
203 (self.get_raw_flags() & DispatcherBuilder::UNSYNCHRONIZED) != 0
204 }
205
206 pub fn allows_thread_blocking(&self) -> bool {
208 (self.get_raw_flags() & DispatcherBuilder::ALLOW_THREAD_BLOCKING) != 0
209 }
210
211 pub fn allows_thread_migration(&self) -> bool {
214 (self.get_raw_flags() & DispatcherBuilder::NO_THREAD_MIGRATION) == 0
215 }
216
217 pub fn is_current_dispatcher(&self) -> bool {
219 self.0.as_ptr() == unsafe { fdf_dispatcher_get_current_dispatcher() }
222 }
223
224 pub fn release(self) -> AutoReleaseDispatcher {
229 AutoReleaseDispatcher { dispatcher: ManuallyDrop::new(self) }
230 }
231
232 pub fn as_dispatcher_ref(&self) -> DriverDispatcherRef<'_> {
235 DriverDispatcherRef(ManuallyDrop::new(Dispatcher(self.0)), PhantomData)
236 }
237}
238
239impl AsAsyncDispatcherRef for Dispatcher {
240 fn as_async_dispatcher_ref(&self) -> AsyncDispatcherRef<'_> {
241 let async_dispatcher =
242 NonNull::new(unsafe { fdf_dispatcher_get_async_dispatcher(self.0.as_ptr()) })
243 .expect("No async dispatcher on driver dispatcher");
244 unsafe { AsyncDispatcherRef::from_raw(async_dispatcher) }
245 }
246}
247
248impl Drop for Dispatcher {
249 fn drop(&mut self) {
250 unsafe { fdf_dispatcher_shutdown_async(self.0.as_mut()) }
253 }
254}
255
/// A dispatcher handle that does not shut the dispatcher down when dropped.
///
/// Obtained via [`Dispatcher::release`], [`DispatcherBuilder::create_released`], or
/// `From<Dispatcher>`. The inner [`Dispatcher`] is held in [`ManuallyDrop`], so dropping this
/// value does not run `Dispatcher`'s shutdown-on-drop.
#[derive(Debug)]
pub struct AutoReleaseDispatcher {
    dispatcher: ManuallyDrop<Dispatcher>,
}
269
270impl AutoReleaseDispatcher {
271 pub unsafe fn from_raw(dispatcher: NonNull<fdf_dispatcher_t>) -> Self {
279 let dispatcher = ManuallyDrop::new(Dispatcher(dispatcher));
280 Self { dispatcher }
281 }
282
283 pub fn as_async_dispatcher(&self) -> AsyncDispatcher {
287 AsyncDispatcher::new(self)
288 }
289
290 pub fn as_dispatcher_ref(&self) -> DriverDispatcherRef<'_> {
293 DriverDispatcherRef(ManuallyDrop::new(Dispatcher(self.dispatcher.0)), PhantomData)
294 }
295
296 pub fn always_on_dispatcher(&self) -> AutoReleaseDispatcher {
298 let dispatcher_ref = unsafe { DriverDispatcherRef::from_raw(self.dispatcher.0) };
301 let dispatcher = unsafe { Dispatcher::from_raw(dispatcher_ref.always_on_dispatcher().0.0) };
306 Self { dispatcher: ManuallyDrop::new(dispatcher) }
307 }
308}
309
310impl AsAsyncDispatcherRef for AutoReleaseDispatcher {
311 fn as_async_dispatcher_ref(&self) -> AsyncDispatcherRef<'_> {
312 self.dispatcher.as_async_dispatcher_ref()
313 }
314}
315
316impl From<Dispatcher> for AutoReleaseDispatcher {
317 fn from(dispatcher: Dispatcher) -> Self {
318 Self { dispatcher: ManuallyDrop::new(dispatcher) }
319 }
320}
321
/// A non-owning reference to a driver dispatcher.
///
/// The wrapped [`Dispatcher`] sits in [`ManuallyDrop`], so dropping this reference never shuts
/// the dispatcher down. `'a` ties the reference to whatever it was borrowed from.
#[derive(Debug)]
pub struct DriverDispatcherRef<'a>(ManuallyDrop<Dispatcher>, PhantomData<&'a Dispatcher>);
327
328impl<'a> DriverDispatcherRef<'a> {
329 pub unsafe fn from_raw(handle: NonNull<fdf_dispatcher_t>) -> Self {
336 Self(ManuallyDrop::new(unsafe { Dispatcher::from_raw(handle) }), PhantomData)
338 }
339
340 pub fn from_async_dispatcher(dispatcher: AsyncDispatcherRef<'a>) -> Self {
347 let handle = NonNull::new(unsafe {
348 fdf_dispatcher_downcast_async_dispatcher(dispatcher.inner().as_ptr())
349 })
350 .unwrap();
351 unsafe { Self::from_raw(handle) }
352 }
353
354 pub unsafe fn as_raw(&mut self) -> *mut fdf_dispatcher_t {
360 unsafe { self.0.0.as_mut() }
361 }
362
363 pub fn always_on_dispatcher(&self) -> DriverDispatcherRef<'a> {
366 let ptr = unsafe { fdf_dispatcher_get_always_on_dispatcher(self.0.0.as_ptr()) };
368 DriverDispatcherRef(
369 ManuallyDrop::new(Dispatcher(NonNull::new(ptr).expect("Always-on dispatcher is NULL"))),
370 PhantomData,
371 )
372 }
373}
374
/// Wraps a future so that the wrapper is `Send` even when the inner future is not.
///
/// `OnDriverDispatcher::compute_local` uses it to hand a non-`Send` future to
/// `OnDispatcher::compute`. It is only constructed there, after checking that the caller is on
/// the current dispatcher and that the dispatcher does not allow thread migration. Using it
/// anywhere else is unsound.
struct AddSendFuture<T>(T);

impl<T: Future> Future for AddSendFuture<T> {
    type Output = T::Output;

    /// Polls the wrapped future in place.
    fn poll(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Self::Output> {
        // SAFETY: structural pinning. The inner future is never moved out of `self`, and
        // `AddSendFuture` has no `Drop` impl here that could move it.
        let fut = unsafe { self.map_unchecked_mut(|fut| &mut fut.0) };
        fut.poll(cx)
    }
}

// SAFETY: relies on `compute_local`'s guard (current dispatcher, thread migration disabled) to
// keep the wrapped future on the thread that created it. NOTE(review): this depends on the
// runtime honoring `NO_THREAD_MIGRATION`; the tests exercising it are currently `#[ignore]`d.
unsafe impl<T> Send for AddSendFuture<T> {}
401
402pub trait OnDriverDispatcher: OnDispatcher {
405 fn spawn_local(&self, future: impl Future<Output = ()> + 'static) -> JoinHandle<()>
418 where
419 Self: 'static,
420 {
421 self.compute_local(future).detach_on_drop()
422 }
423
424 fn compute_local<T: Send + 'static>(&self, future: impl Future<Output = T> + 'static) -> Task<T>
439 where
440 Self: 'static,
441 {
442 let Some(dispatcher) = self.try_get_async_dispatcher() else {
443 return Task::new_failed(Status::BAD_STATE);
444 };
445 let dispatcher =
446 DriverDispatcherRef::from_async_dispatcher(dispatcher.as_async_dispatcher_ref());
447 if dispatcher.0.is_current_dispatcher() && !dispatcher.0.allows_thread_migration() {
448 OnDispatcher::compute(self, AddSendFuture(future))
449 } else {
450 Task::new_failed(Status::BAD_STATE)
451 }
452 }
453}
454
455impl<'a> AsAsyncDispatcherRef for DriverDispatcherRef<'a> {
456 fn as_async_dispatcher_ref(&self) -> AsyncDispatcherRef<'_> {
457 self.0.as_async_dispatcher_ref()
458 }
459}
460
461impl<'a> Clone for DriverDispatcherRef<'a> {
462 fn clone(&self) -> Self {
463 Self(ManuallyDrop::new(Dispatcher(self.0.0)), PhantomData)
464 }
465}
466
467impl<'a> core::ops::Deref for DriverDispatcherRef<'a> {
468 type Target = Dispatcher;
469 fn deref(&self) -> &Self::Target {
470 &self.0
471 }
472}
473
474impl<'a> core::ops::DerefMut for DriverDispatcherRef<'a> {
475 fn deref_mut(&mut self) -> &mut Self::Target {
476 &mut self.0
477 }
478}
479
480impl<T> OnDriverDispatcher for T where T: AsAsyncDispatcherRef + Clone {}
483
/// A zero-sized handle meaning "the driver dispatcher running the calling code".
///
/// It is resolved on every lookup: the thread-local `OVERRIDE_DISPATCHER` is checked first,
/// then `fdf_dispatcher_get_current_dispatcher`.
#[derive(Clone, Copy, Debug, Default, PartialEq)]
pub struct CurrentDispatcher;
487
488impl GetAsyncDispatcher for CurrentDispatcher {
489 fn try_get_async_dispatcher(&self) -> Option<AsyncDispatcher> {
490 OVERRIDE_DISPATCHER
491 .with(|global| *global.borrow())
492 .or_else(|| {
493 NonNull::new(unsafe { fdf_dispatcher_get_current_dispatcher() })
495 })
496 .map(|dispatcher| {
497 let async_dispatcher = NonNull::new(unsafe {
503 fdf_dispatcher_get_async_dispatcher(dispatcher.as_ptr())
504 })
505 .expect("No async dispatcher on driver dispatcher");
506 AsyncDispatcher::new(&unsafe { AsyncDispatcherRef::from_raw(async_dispatcher) })
507 })
508 }
509}
510
// `CurrentDispatcher` is not `AsAsyncDispatcherRef` (otherwise this impl would overlap the
// blanket `impl<T: AsAsyncDispatcherRef + Clone> OnDriverDispatcher for T`). It therefore opts
// in explicitly and uses the default `spawn_local`/`compute_local`.
impl OnDriverDispatcher for CurrentDispatcher {}
512
#[cfg(test)]
mod tests {
    use super::*;

    use std::sync::{Once, mpsc};

    use futures::channel::mpsc as async_mpsc;
    use futures::{SinkExt, StreamExt};
    use zx::sys::ZX_OK;

    use core::ffi::{c_char, c_void};
    use core::ptr::null_mut;

    /// Guards one-time initialization of the driver runtime environment for this test binary.
    static GLOBAL_DRIVER_ENV: Once = Once::new();
    /// Scheduler role registered with `FDF_SCHEDULER_ROLE_OPTION_NO_SYNC_CALLS`.
    const NO_SYNC_CALLS_ROLE: &str = "no sync calls role";

    /// Starts the driver runtime environment and registers [`NO_SYNC_CALLS_ROLE`], once per
    /// process.
    pub fn ensure_driver_env() {
        GLOBAL_DRIVER_ENV.call_once(|| {
            // SAFETY: runs once per process (guarded by `GLOBAL_DRIVER_ENV`), and the role
            // pointer/length pair describes a `'static` string.
            unsafe {
                assert_eq!(fdf_env_start(0), ZX_OK);
                assert_eq!(
                    fdf_env_set_scheduler_role_opts(
                        NO_SYNC_CALLS_ROLE.as_ptr() as *const c_char,
                        NO_SYNC_CALLS_ROLE.len(),
                        FDF_SCHEDULER_ROLE_OPTION_NO_SYNC_CALLS
                    ),
                    ZX_OK
                );
            }
        });
    }

    /// Runs `p` with a fresh thread-blocking-capable dispatcher that has no scheduler role, then
    /// shuts the dispatcher down and waits for shutdown to complete.
    pub fn with_raw_dispatcher<T>(name: &str, p: impl FnOnce(AsyncDispatcher) -> T) -> T {
        with_raw_dispatcher_flags(name, DispatcherBuilder::ALLOW_THREAD_BLOCKING, "", p)
    }

    /// Creates a dispatcher with option `flags` and `scheduler_role`, and runs `p` with it. It
    /// then drops the dispatcher and blocks until its shutdown observer has run.
    ///
    /// # Panics
    ///
    /// Panics if creation fails, or if the dispatcher still has queued tasks when its shutdown
    /// observer runs.
    pub(crate) fn with_raw_dispatcher_flags<T>(
        name: &str,
        flags: u32,
        scheduler_role: &str,
        p: impl FnOnce(AsyncDispatcher) -> T,
    ) -> T {
        ensure_driver_env();

        let (shutdown_tx, shutdown_rx) = mpsc::channel();
        let mut dispatcher = null_mut();
        let mut observer = ShutdownObserver::new(move |dispatcher| {
            assert!(!unsafe { fdf_env_dispatcher_has_queued_tasks(dispatcher.0.0.as_ptr()) });
            shutdown_tx.send(()).unwrap();
        })
        .into_ptr();
        // NOTE(review): the address of the local `observer` slot is presumably only used as an
        // opaque owner token for the runtime. It stays valid because the dispatcher is fully
        // shut down before this function returns.
        let driver_ptr = &mut observer as *mut _ as *mut c_void;
        // SAFETY: the name and role pointers describe `name.len()`/`scheduler_role.len()` valid
        // bytes for the duration of the call, and `dispatcher` is a valid out-pointer.
        let res = unsafe {
            fdf_env_dispatcher_create_with_owner(
                driver_ptr,
                flags,
                name.as_ptr() as *const c_char,
                name.len(),
                scheduler_role.as_ptr() as *const c_char,
                scheduler_role.len(),
                observer,
                &mut dispatcher,
            )
        };
        assert_eq!(res, ZX_OK);
        let dispatcher = Dispatcher(NonNull::new(dispatcher).unwrap());

        let res = p(AsyncDispatcher::new(&dispatcher));

        // Dropping the owned handle starts shutdown; wait for the observer to confirm it.
        drop(dispatcher);
        shutdown_rx.recv().unwrap();

        res
    }

    #[test]
    fn start_test_dispatcher() {
        with_raw_dispatcher("testing", |dispatcher| {
            println!("hello {dispatcher:?}");
        })
    }

    #[test]
    fn post_task_on_dispatcher() {
        with_raw_dispatcher("testing task", |dispatcher| {
            let (tx, rx) = mpsc::channel();
            dispatcher
                .post_task_sync(move |status| {
                    assert_eq!(status, Status::from_raw(ZX_OK));
                    tx.send(status).unwrap();
                })
                .unwrap();
            assert_eq!(rx.recv().unwrap(), Status::from_raw(ZX_OK));
        });
    }

    #[test]
    fn post_task_on_subdispatcher() {
        let (shutdown_tx, shutdown_rx) = mpsc::channel();
        with_raw_dispatcher("testing task top level", move |dispatcher| {
            let (tx, rx) = mpsc::channel();
            let (inner_tx, inner_rx) = mpsc::channel();
            dispatcher
                .post_task_sync(move |status| {
                    assert_eq!(status, Status::from_raw(ZX_OK));
                    let inner = DispatcherBuilder::new()
                        .name("testing task second level")
                        .scheduler_role("")
                        .allow_thread_blocking()
                        .shutdown_observer(move |_dispatcher| {
                            println!("shutdown observer called");
                            shutdown_tx.send(1).unwrap();
                        })
                        .create()
                        .unwrap();
                    inner
                        .post_task_sync(move |status| {
                            assert_eq!(status, Status::from_raw(ZX_OK));
                            tx.send(status).unwrap();
                        })
                        .unwrap();
                    // Hand the inner dispatcher back to the test thread, where it is dropped
                    // (and therefore shut down) below.
                    inner_tx.send(inner).unwrap();
                })
                .unwrap();
            assert_eq!(rx.recv().unwrap(), Status::from_raw(ZX_OK));
            // The received `Dispatcher` is a temporary, so it is dropped here.
            inner_rx.recv().unwrap();
        });
        assert_eq!(shutdown_rx.recv().unwrap(), 1);
    }

    #[test]
    fn spawn_local_fails_on_normal_dispatcher() {
        let (shutdown_tx, shutdown_rx) = mpsc::channel();
        with_raw_dispatcher("spawn local failures", move |dispatcher| {
            let inside_dispatcher = dispatcher.clone();
            dispatcher.spawn(async move {
                assert_eq!(
                    inside_dispatcher.spawn_local(futures::future::ready(())).await.unwrap_err(),
                    Status::BAD_STATE
                );
                assert_eq!(
                    inside_dispatcher.compute_local(futures::future::ready(())).await.unwrap_err(),
                    Status::BAD_STATE
                );
                shutdown_tx.send(()).unwrap();
            });
            shutdown_rx.recv().unwrap();
        });
    }

    #[test]
    #[ignore = "Pending resolution of b/488397193"]
    fn spawn_local_succeeds_on_no_thread_migration_dispatcher() {
        let (tx, rx) = mpsc::channel();
        with_raw_dispatcher_flags(
            "spawn local success",
            FDF_DISPATCHER_OPTION_NO_THREAD_MIGRATION,
            NO_SYNC_CALLS_ROLE,
            move |dispatcher| {
                let inside_dispatcher = dispatcher.clone();
                dispatcher.spawn(async move {
                    let tx_clone = tx.clone();
                    inside_dispatcher.spawn_local(async move {
                        tx_clone.send(()).unwrap();
                    });
                    inside_dispatcher
                        .compute_local(async move {
                            tx.send(()).unwrap();
                        })
                        .await
                        .unwrap();
                });
                // One message from the `spawn_local` future, one from the `compute_local` future.
                rx.recv().unwrap();
                rx.recv().unwrap();
            },
        );
    }

    #[test]
    #[ignore = "Pending resolution of b/488397193"]
    fn spawn_local_fails_on_no_thread_migration_dispatcher_from_different_thread() {
        with_raw_dispatcher_flags(
            "spawn local failure from different thread",
            FDF_DISPATCHER_OPTION_NO_THREAD_MIGRATION,
            NO_SYNC_CALLS_ROLE,
            move |dispatcher| {
                // Run on the test thread's own executor, i.e. not on `dispatcher`, so the
                // "current dispatcher" check in `compute_local` must reject the work.
                let mut executor = fuchsia_async::LocalExecutor::default();
                executor.run_singlethreaded(async {
                    assert_eq!(
                        dispatcher.spawn_local(futures::future::ready(())).await.unwrap_err(),
                        Status::BAD_STATE
                    );
                    assert_eq!(
                        dispatcher.compute_local(futures::future::ready(())).await.unwrap_err(),
                        Status::BAD_STATE
                    );
                });
            },
        );
    }

    /// Sends 0, then answers every received value with that value plus one.
    async fn ping(mut tx: async_mpsc::Sender<u8>, mut rx: async_mpsc::Receiver<u8>) {
        println!("starting ping!");
        tx.send(0).await.unwrap();
        while let Some(next) = rx.next().await {
            println!("ping! {next}");
            tx.send(next + 1).await.unwrap();
        }
    }

    /// Answers every received value with that value plus one until a value above 10 arrives,
    /// then signals `fin_tx`.
    async fn pong(
        fin_tx: std::sync::mpsc::Sender<()>,
        mut tx: async_mpsc::Sender<u8>,
        mut rx: async_mpsc::Receiver<u8>,
    ) {
        println!("starting pong!");
        while let Some(next) = rx.next().await {
            println!("pong! {next}");
            if next > 10 {
                println!("bye!");
                break;
            }
            tx.send(next + 1).await.unwrap();
        }
        fin_tx.send(()).unwrap();
    }

    #[test]
    fn async_ping_pong() {
        with_raw_dispatcher("async ping pong", |dispatcher| {
            let (fin_tx, fin_rx) = mpsc::channel();
            let (ping_tx, pong_rx) = async_mpsc::channel(10);
            let (pong_tx, ping_rx) = async_mpsc::channel(10);
            dispatcher.spawn(ping(ping_tx, ping_rx));
            dispatcher.spawn(pong(fin_tx, pong_tx, pong_rx));

            fin_rx.recv().expect("to receive final value");
        });
    }

    /// Like `pong`, but waits one second on a `fuchsia_async` timer before handling each value.
    async fn slow_pong(
        fin_tx: std::sync::mpsc::Sender<()>,
        mut tx: async_mpsc::Sender<u8>,
        mut rx: async_mpsc::Receiver<u8>,
    ) {
        use zx::MonotonicDuration;
        println!("starting pong!");
        while let Some(next) = rx.next().await {
            println!("pong! {next}");
            fuchsia_async::Timer::new(fuchsia_async::MonotonicInstant::after(
                MonotonicDuration::from_seconds(1),
            ))
            .await;
            if next > 10 {
                println!("bye!");
                break;
            }
            tx.send(next + 1).await.unwrap();
        }
        fin_tx.send(()).unwrap();
    }

    #[test]
    fn mixed_executor_async_ping_pong() {
        with_raw_dispatcher("async ping pong", |dispatcher| {
            let (fin_tx, fin_rx) = mpsc::channel();
            let (ping_tx, pong_rx) = async_mpsc::channel(10);
            let (pong_tx, ping_rx) = async_mpsc::channel(10);

            // `ping` runs on the driver dispatcher...
            dispatcher.spawn(ping(ping_tx, ping_rx));

            // ...while `slow_pong` runs on a local executor on the test thread.
            let mut executor = fuchsia_async::LocalExecutor::default();
            executor.run_singlethreaded(slow_pong(fin_tx, pong_tx, pong_rx));

            fin_rx.recv().expect("to receive final value");
        });
    }
}