1use fdf_sys::*;
8
9use core::ffi;
10use core::marker::PhantomData;
11use core::mem::ManuallyDrop;
12use core::ptr::{NonNull, null_mut};
13use std::sync::atomic::{AtomicPtr, Ordering};
14use std::sync::{Arc, Weak};
15
16use zx::Status;
17
18use crate::shutdown_observer::ShutdownObserver;
19
20pub use fdf_sys::fdf_dispatcher_t;
21pub use libasync::{
22 AfterDeadline, AsyncDispatcher, AsyncDispatcherRef, CurrentDispatcher, JoinHandle,
23 OnDispatcher, Task,
24};
25
/// Marker trait for closures that can observe dispatcher shutdown: the
/// closure receives the dispatcher being shut down, runs at most once, and
/// must be sendable across threads with no borrowed data.
pub trait ShutdownObserverFn: FnOnce(DispatcherRef<'_>) + Send + 'static {}
/// Blanket implementation so any conforming closure is usable directly.
impl<T> ShutdownObserverFn for T where T: FnOnce(DispatcherRef<'_>) + Send + 'static {}
29
/// Builder for configuring and creating a driver [`Dispatcher`].
///
/// Fields are public but doc-hidden; typical callers use the chained setter
/// methods rather than constructing the struct directly.
#[derive(Default)]
pub struct DispatcherBuilder {
    // Raw option bits passed through to `fdf_dispatcher_create`.
    #[doc(hidden)]
    pub options: u32,
    // Human-readable dispatcher name, passed as (pointer, length).
    #[doc(hidden)]
    pub name: String,
    // Scheduler role hint; empty string means no specific role.
    #[doc(hidden)]
    pub scheduler_role: String,
    // Optional callback invoked when the dispatcher finishes shutting down.
    #[doc(hidden)]
    pub shutdown_observer: Option<Box<dyn ShutdownObserverFn>>,
}
42
43impl DispatcherBuilder {
44 pub(crate) const UNSYNCHRONIZED: u32 = fdf_sys::FDF_DISPATCHER_OPTION_UNSYNCHRONIZED;
46 pub(crate) const ALLOW_THREAD_BLOCKING: u32 = fdf_sys::FDF_DISPATCHER_OPTION_ALLOW_SYNC_CALLS;
48 pub(crate) const NO_THREAD_MIGRATION: u32 = fdf_sys::FDF_DISPATCHER_OPTION_NO_THREAD_MIGRATION;
50
51 pub fn new() -> Self {
55 Self::default()
56 }
57
58 pub fn unsynchronized(mut self) -> Self {
64 assert!(
65 !self.allows_thread_blocking(),
66 "you may not create an unsynchronized dispatcher that allows synchronous calls"
67 );
68 self.options |= Self::UNSYNCHRONIZED;
69 self
70 }
71
72 pub fn is_unsynchronized(&self) -> bool {
74 (self.options & Self::UNSYNCHRONIZED) == Self::UNSYNCHRONIZED
75 }
76
77 pub fn allow_thread_blocking(mut self) -> Self {
83 assert!(
84 !self.is_unsynchronized(),
85 "you may not create an unsynchronized dispatcher that allows synchronous calls"
86 );
87 self.options |= Self::ALLOW_THREAD_BLOCKING;
88 self
89 }
90
91 pub fn allows_thread_blocking(&self) -> bool {
93 (self.options & Self::ALLOW_THREAD_BLOCKING) == Self::ALLOW_THREAD_BLOCKING
94 }
95
96 pub fn no_thread_migration(mut self) -> Self {
103 self.options |= Self::NO_THREAD_MIGRATION;
104 self
105 }
106
107 pub fn allows_thread_migration(&self) -> bool {
109 (self.options & Self::NO_THREAD_MIGRATION) == 0
110 }
111
112 pub fn name(mut self, name: &str) -> Self {
115 self.name = name.to_string();
116 self
117 }
118
119 pub fn scheduler_role(mut self, role: &str) -> Self {
123 self.scheduler_role = role.to_string();
124 self
125 }
126
127 pub fn shutdown_observer<F: ShutdownObserverFn>(mut self, shutdown_observer: F) -> Self {
129 self.shutdown_observer = Some(Box::new(shutdown_observer));
130 self
131 }
132
133 pub fn create(self) -> Result<Dispatcher, Status> {
138 let mut out_dispatcher = null_mut();
139 let options = self.options;
140 let name = self.name.as_ptr() as *mut ffi::c_char;
141 let name_len = self.name.len();
142 let scheduler_role = self.scheduler_role.as_ptr() as *mut ffi::c_char;
143 let scheduler_role_len = self.scheduler_role.len();
144 let observer =
145 ShutdownObserver::new(self.shutdown_observer.unwrap_or_else(|| Box::new(|_| {})))
146 .into_ptr();
147 Status::ok(unsafe {
151 fdf_dispatcher_create(
152 options,
153 name,
154 name_len,
155 scheduler_role,
156 scheduler_role_len,
157 observer,
158 &mut out_dispatcher,
159 )
160 })?;
161 Ok(Dispatcher(unsafe { NonNull::new_unchecked(out_dispatcher) }))
164 }
165
166 pub fn create_released(self) -> Result<DispatcherRef<'static>, Status> {
170 self.create().map(Dispatcher::release)
171 }
172}
173
/// An owned handle to a driver runtime dispatcher.
///
/// Dropping a `Dispatcher` *starts* an asynchronous shutdown of the
/// underlying dispatcher (see the `Drop` impl); it does not block until
/// shutdown completes.
#[derive(Debug)]
pub struct Dispatcher(pub(crate) NonNull<fdf_dispatcher_t>);
177
// SAFETY: the wrapped value is an opaque handle owned by the driver runtime,
// and the only operations performed on it here are FFI calls.
// NOTE(review): this assumes the fdf C API is callable from any thread —
// confirm against the driver runtime's threading documentation.
unsafe impl Send for Dispatcher {}
unsafe impl Sync for Dispatcher {}
181
182impl Dispatcher {
183 pub unsafe fn from_raw(handle: NonNull<fdf_dispatcher_t>) -> Self {
191 Self(handle)
192 }
193
194 fn get_raw_flags(&self) -> u32 {
195 unsafe { fdf_dispatcher_get_options(self.0.as_ptr()) }
197 }
198
199 pub fn is_unsynchronized(&self) -> bool {
201 (self.get_raw_flags() & DispatcherBuilder::UNSYNCHRONIZED) != 0
202 }
203
204 pub fn allows_thread_blocking(&self) -> bool {
206 (self.get_raw_flags() & DispatcherBuilder::ALLOW_THREAD_BLOCKING) != 0
207 }
208
209 pub fn allows_thread_migration(&self) -> bool {
212 (self.get_raw_flags() & DispatcherBuilder::NO_THREAD_MIGRATION) == 0
213 }
214
215 pub fn is_current_dispatcher(&self) -> bool {
217 self.0.as_ptr() == unsafe { fdf_dispatcher_get_current_dispatcher() }
220 }
221
222 pub fn release(self) -> DispatcherRef<'static> {
227 DispatcherRef(ManuallyDrop::new(self), PhantomData)
228 }
229
230 pub fn as_dispatcher_ref(&self) -> DispatcherRef<'_> {
233 DispatcherRef(ManuallyDrop::new(Dispatcher(self.0)), PhantomData)
234 }
235}
236
237impl AsyncDispatcher for Dispatcher {
238 fn as_async_dispatcher_ref(&self) -> AsyncDispatcherRef<'_> {
239 let async_dispatcher =
240 NonNull::new(unsafe { fdf_dispatcher_get_async_dispatcher(self.0.as_ptr()) })
241 .expect("No async dispatcher on driver dispatcher");
242 unsafe { AsyncDispatcherRef::from_raw(async_dispatcher) }
243 }
244}
245
impl Drop for Dispatcher {
    // Dropping the owned handle begins an *asynchronous* shutdown; completion
    // is reported via the shutdown observer registered at creation time.
    fn drop(&mut self) {
        // SAFETY: `self.0` is a valid dispatcher handle owned by this object.
        unsafe { fdf_dispatcher_shutdown_async(self.0.as_mut()) }
    }
}
253
/// A shared dispatcher handle whose raw pointer slot is nulled out when the
/// owner is dropped.
///
/// Weak handles ([`WeakDispatcher`]) temporarily upgrade the inner `Arc` to
/// use the pointer; `Drop` spins until all such upgrades are gone so no
/// caller can observe the pointer after invalidation.
#[derive(Debug)]
pub struct AutoReleaseDispatcher(Arc<AtomicPtr<fdf_dispatcher>>);
265
266impl AutoReleaseDispatcher {
267 pub fn downgrade(&self) -> WeakDispatcher {
271 WeakDispatcher::from(self)
272 }
273}
274
275impl AsyncDispatcher for AutoReleaseDispatcher {
276 fn as_async_dispatcher_ref(&self) -> AsyncDispatcherRef<'_> {
277 let dispatcher = NonNull::new(self.0.load(Ordering::Relaxed))
278 .expect("tried to obtain async dispatcher after drop");
279 unsafe {
282 AsyncDispatcherRef::from_raw(
283 NonNull::new(fdf_dispatcher_get_async_dispatcher(dispatcher.as_ptr())).unwrap(),
284 )
285 }
286 }
287}
288
289impl From<Dispatcher> for AutoReleaseDispatcher {
290 fn from(dispatcher: Dispatcher) -> Self {
291 let dispatcher_ptr = dispatcher.release().0.0.as_ptr();
292 Self(Arc::new(AtomicPtr::new(dispatcher_ptr)))
293 }
294}
295
impl Drop for AutoReleaseDispatcher {
    fn drop(&mut self) {
        // Invalidate the slot so `WeakDispatcher` upgrades observe null and
        // report "no dispatcher" rather than using a dangling handle.
        self.0.store(null_mut(), Ordering::Relaxed);
        // Spin until every temporarily-upgraded Arc (taken inside
        // `WeakDispatcher::on_dispatcher`) has been released, so no in-flight
        // caller still holds the pointer it loaded before the store above.
        // NOTE(review): busy-waiting with Relaxed ordering — confirm this is
        // adequate synchronization for all consumers. Also note the fdf
        // dispatcher itself is not shut down here (its shutdown-on-drop was
        // disabled by `release` in `From<Dispatcher>`); verify shutdown is
        // handled elsewhere.
        while Arc::strong_count(&self.0) > 1 {
            std::thread::sleep(std::time::Duration::from_nanos(100))
        }
    }
}
311
/// A weak, cloneable reference to an [`AutoReleaseDispatcher`].
///
/// Access can fail two ways: the strong handle was dropped (the `Weak` no
/// longer upgrades), or the pointer slot was already nulled during drop.
#[derive(Clone, Debug)]
pub struct WeakDispatcher(Weak<AtomicPtr<fdf_dispatcher>>);
321
322impl From<&AutoReleaseDispatcher> for WeakDispatcher {
323 fn from(value: &AutoReleaseDispatcher) -> Self {
324 Self(Arc::downgrade(&value.0))
325 }
326}
327
328impl OnDispatcher for WeakDispatcher {
329 fn on_dispatcher<R>(&self, f: impl FnOnce(Option<AsyncDispatcherRef<'_>>) -> R) -> R {
330 let Some(dispatcher_ptr) = self.0.upgrade() else {
331 return f(None);
332 };
333 let Some(dispatcher) = NonNull::new(dispatcher_ptr.load(Ordering::Relaxed)) else {
334 return f(None);
335 };
336 f(Some(unsafe { DispatcherRef::from_raw(dispatcher) }.as_async_dispatcher_ref()))
340 }
341}
342
343impl OnDriverDispatcher for WeakDispatcher {}
344
/// A borrowed, non-owning reference to a [`Dispatcher`].
///
/// `ManuallyDrop` suppresses the inner dispatcher's shutdown-on-drop;
/// `PhantomData` ties the reference to the lifetime `'a` of the dispatcher
/// it was borrowed from.
#[derive(Debug)]
pub struct DispatcherRef<'a>(ManuallyDrop<Dispatcher>, PhantomData<&'a Dispatcher>);
350
impl<'a> DispatcherRef<'a> {
    /// Wraps a raw dispatcher handle without taking ownership.
    ///
    /// # Safety
    ///
    /// `handle` must point to a valid dispatcher that outlives the chosen
    /// lifetime `'a`; the returned value never shuts the dispatcher down.
    pub unsafe fn from_raw(handle: NonNull<fdf_dispatcher_t>) -> Self {
        Self(ManuallyDrop::new(unsafe { Dispatcher::from_raw(handle) }), PhantomData)
    }

    /// Recovers a driver dispatcher reference from its embedded async
    /// dispatcher.
    ///
    /// Panics (via `unwrap`) if the async dispatcher is not backed by a
    /// driver dispatcher.
    pub fn from_async_dispatcher(dispatcher: AsyncDispatcherRef<'a>) -> Self {
        let handle = NonNull::new(unsafe {
            fdf_dispatcher_downcast_async_dispatcher(dispatcher.inner().as_ptr())
        })
        .unwrap();
        // SAFETY: the downcast pointer refers to the same dispatcher the
        // async reference borrows, so it is valid for the same lifetime 'a.
        unsafe { Self::from_raw(handle) }
    }

    /// Returns the async dispatcher embedded in this driver dispatcher, with
    /// the same borrow lifetime `'a`.
    pub fn as_async_dispatcher(&self) -> AsyncDispatcherRef<'a> {
        // SAFETY: `self.0.0` is a valid dispatcher handle for 'a.
        let handle = unsafe { fdf_dispatcher_get_async_dispatcher(self.0.0.as_ptr()) };
        unsafe { AsyncDispatcherRef::from_raw(NonNull::new(handle).unwrap()) }
    }

    /// Returns the raw dispatcher pointer.
    ///
    /// # Safety
    ///
    /// The pointer is only valid for the lifetime `'a`; the caller must not
    /// use it past that or treat it as owned.
    pub unsafe fn as_raw(&mut self) -> *mut fdf_dispatcher_t {
        unsafe { self.0.0.as_mut() }
    }
}
394
/// Wrapper that marks the wrapped future as `Send` (see the unsafe impl
/// below). Constructed only by `spawn_local`/`compute_local`, which first
/// verify the future will be polled on the current, non-migrating
/// dispatcher, so it never actually crosses threads.
struct AddSendFuture<T>(T);
impl<T: Future> Future for AddSendFuture<T> {
    type Output = T::Output;

    // Transparent forwarding: polling the wrapper polls the inner future.
    fn poll(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Self::Output> {
        // SAFETY: structural pin projection — the inner future is never moved
        // out of the wrapper, so pinning the wrapper pins the field.
        let fut = unsafe { self.map_unchecked_mut(|fut| &mut fut.0) };
        fut.poll(cx)
    }
}
416
417unsafe impl<T> Send for AddSendFuture<T> {}
421
/// Driver-dispatcher extensions to [`OnDispatcher`]: spawning non-`Send`
/// futures on dispatchers that are pinned to the calling thread.
pub trait OnDriverDispatcher: OnDispatcher {
    /// Spawns a non-`Send` future on this dispatcher.
    ///
    /// Succeeds only when called from the dispatcher's own thread *and* the
    /// dispatcher disallows thread migration — the two conditions that make
    /// wrapping the future in `AddSendFuture` sound. Otherwise returns
    /// `Err(Status::BAD_STATE)`.
    fn spawn_local(
        &self,
        future: impl Future<Output = ()> + 'static,
    ) -> Result<JoinHandle<()>, Status>
    where
        Self: 'static,
    {
        self.on_maybe_dispatcher(|dispatcher| {
            let dispatcher = DispatcherRef::from_async_dispatcher(dispatcher);
            // NOTE(review): the check and the spawn are not one atomic step;
            // this appears to rely on the no-thread-migration option keeping
            // the thread association stable between them — confirm (see the
            // ignored tests referencing b/488397193).
            if dispatcher.0.is_current_dispatcher() && !dispatcher.0.allows_thread_migration() {
                OnDispatcher::spawn(self, AddSendFuture(future))
            } else {
                Err(Status::BAD_STATE)
            }
        })
    }

    /// Like [`Self::spawn_local`], but returns a [`Task`] yielding the
    /// future's output.
    ///
    /// Same preconditions: must run on the dispatcher's own thread with
    /// thread migration disallowed, otherwise `Err(Status::BAD_STATE)`.
    fn compute_local<T: Send + 'static>(
        &self,
        future: impl Future<Output = T> + 'static,
    ) -> Result<Task<T>, Status>
    where
        Self: 'static,
    {
        self.on_maybe_dispatcher(|dispatcher| {
            let dispatcher = DispatcherRef::from_async_dispatcher(dispatcher);
            if dispatcher.0.is_current_dispatcher() && !dispatcher.0.allows_thread_migration() {
                Ok(OnDispatcher::compute(self, AddSendFuture(future)))
            } else {
                Err(Status::BAD_STATE)
            }
        })
    }
}
485
// Shared and weak `Dispatcher` handles get the local-spawning helpers too.
impl OnDriverDispatcher for Arc<Dispatcher> {}
impl OnDriverDispatcher for Weak<Dispatcher> {}
488
489impl<'a> AsyncDispatcher for DispatcherRef<'a> {
490 fn as_async_dispatcher_ref(&self) -> AsyncDispatcherRef<'_> {
491 self.0.as_async_dispatcher_ref()
492 }
493}
494
495impl<'a> Clone for DispatcherRef<'a> {
496 fn clone(&self) -> Self {
497 Self(ManuallyDrop::new(Dispatcher(self.0.0)), PhantomData)
498 }
499}
500
501impl<'a> core::ops::Deref for DispatcherRef<'a> {
502 type Target = Dispatcher;
503 fn deref(&self) -> &Self::Target {
504 &self.0
505 }
506}
507
508impl<'a> core::ops::DerefMut for DispatcherRef<'a> {
509 fn deref_mut(&mut self) -> &mut Self::Target {
510 &mut self.0
511 }
512}
513
514impl<'a> OnDispatcher for DispatcherRef<'a> {
515 fn on_dispatcher<R>(&self, f: impl FnOnce(Option<AsyncDispatcherRef<'_>>) -> R) -> R {
516 f(Some(self.as_async_dispatcher_ref()))
517 }
518}
519
// Borrowed dispatcher references support the local-spawning helpers.
impl<'a> OnDriverDispatcher for DispatcherRef<'a> {}

// As does the implicit current dispatcher.
impl OnDriverDispatcher for CurrentDispatcher {}
523
#[cfg(test)]
mod tests {
    use super::*;

    use std::sync::{Arc, Once, Weak, mpsc};

    use futures::channel::mpsc as async_mpsc;
    use futures::{SinkExt, StreamExt};
    use zx::sys::ZX_OK;

    use core::ffi::{c_char, c_void};
    use core::ptr::null_mut;

    // One-time initialization of the process-wide driver test environment.
    static GLOBAL_DRIVER_ENV: Once = Once::new();
    // Scheduler role registered with the no-sync-calls option, used by the
    // thread-migration tests below.
    const NO_SYNC_CALLS_ROLE: &str = "no sync calls role";

    /// Starts the driver environment exactly once per process and registers
    /// the no-sync-calls scheduler role.
    pub fn ensure_driver_env() {
        GLOBAL_DRIVER_ENV.call_once(|| {
            // SAFETY: called at most once; the role string outlives the call.
            unsafe {
                assert_eq!(fdf_env_start(0), ZX_OK);
                assert_eq!(
                    fdf_env_set_scheduler_role_opts(
                        NO_SYNC_CALLS_ROLE.as_ptr() as *const c_char,
                        NO_SYNC_CALLS_ROLE.len(),
                        FDF_SCHEDULER_ROLE_OPTION_NO_SYNC_CALLS
                    ),
                    ZX_OK
                );
            }
        });
    }
    /// Runs `p` with a weak handle to a fresh blocking-allowed dispatcher,
    /// then tears the dispatcher down.
    pub fn with_raw_dispatcher<T>(name: &str, p: impl for<'a> FnOnce(Weak<Dispatcher>) -> T) -> T {
        with_raw_dispatcher_flags(name, DispatcherBuilder::ALLOW_THREAD_BLOCKING, "", p)
    }

    /// Creates an env-owned dispatcher with the given flags and role, runs
    /// `p` with a weak handle, then shuts the dispatcher down and asserts
    /// that no strong reference escaped the test body.
    pub(crate) fn with_raw_dispatcher_flags<T>(
        name: &str,
        flags: u32,
        scheduler_role: &str,
        p: impl for<'a> FnOnce(Weak<Dispatcher>) -> T,
    ) -> T {
        ensure_driver_env();

        let (shutdown_tx, shutdown_rx) = mpsc::channel();
        let mut dispatcher = null_mut();
        // The observer confirms no tasks were left queued, then signals the
        // channel once shutdown completes.
        let mut observer = ShutdownObserver::new(move |dispatcher| {
            assert!(!unsafe { fdf_env_dispatcher_has_queued_tasks(dispatcher.0.0.as_ptr()) });
            shutdown_tx.send(()).unwrap();
        })
        .into_ptr();
        let driver_ptr = &mut observer as *mut _ as *mut c_void;
        // SAFETY: name/role buffers outlive the call; the observer pointer
        // stays valid until it is invoked on shutdown.
        let res = unsafe {
            fdf_env_dispatcher_create_with_owner(
                driver_ptr,
                flags,
                name.as_ptr() as *const c_char,
                name.len(),
                scheduler_role.as_ptr() as *const c_char,
                scheduler_role.len(),
                observer,
                &mut dispatcher,
            )
        };
        assert_eq!(res, ZX_OK);
        let dispatcher = Arc::new(Dispatcher(NonNull::new(dispatcher).unwrap()));

        let res = p(Arc::downgrade(&dispatcher));

        // Drop our strong reference to begin shutdown, then wait for the
        // observer to confirm it finished.
        let weak_dispatcher = Arc::downgrade(&dispatcher);
        drop(dispatcher);
        shutdown_rx.recv().unwrap();
        assert_eq!(
            0,
            weak_dispatcher.strong_count(),
            "a dispatcher reference escaped the test body"
        );

        res
    }

    #[test]
    fn start_test_dispatcher() {
        with_raw_dispatcher("testing", |dispatcher| {
            println!("hello {dispatcher:?}");
        })
    }

    #[test]
    fn post_task_on_dispatcher() {
        with_raw_dispatcher("testing task", |dispatcher| {
            let (tx, rx) = mpsc::channel();
            let dispatcher = Weak::upgrade(&dispatcher).unwrap();
            dispatcher
                .post_task_sync(move |status| {
                    assert_eq!(status, Status::from_raw(ZX_OK));
                    tx.send(status).unwrap();
                })
                .unwrap();
            assert_eq!(rx.recv().unwrap(), Status::from_raw(ZX_OK));
        });
    }

    #[test]
    fn post_task_on_subdispatcher() {
        let (shutdown_tx, shutdown_rx) = mpsc::channel();
        with_raw_dispatcher("testing task top level", move |dispatcher| {
            let (tx, rx) = mpsc::channel();
            let (inner_tx, inner_rx) = mpsc::channel();
            let dispatcher = Weak::upgrade(&dispatcher).unwrap();
            dispatcher
                .post_task_sync(move |status| {
                    assert_eq!(status, Status::from_raw(ZX_OK));
                    // Create a second-level dispatcher from inside a task on
                    // the first one.
                    let inner = DispatcherBuilder::new()
                        .name("testing task second level")
                        .scheduler_role("")
                        .allow_thread_blocking()
                        .shutdown_observer(move |_dispatcher| {
                            println!("shutdown observer called");
                            shutdown_tx.send(1).unwrap();
                        })
                        .create()
                        .unwrap();
                    inner
                        .post_task_sync(move |status| {
                            assert_eq!(status, Status::from_raw(ZX_OK));
                            tx.send(status).unwrap();
                        })
                        .unwrap();
                    // Hand the inner dispatcher out so it is dropped (and
                    // shut down) outside this posted task.
                    inner_tx.send(inner).unwrap();
                })
                .unwrap();
            assert_eq!(rx.recv().unwrap(), Status::from_raw(ZX_OK));
            inner_rx.recv().unwrap();
        });
        assert_eq!(shutdown_rx.recv().unwrap(), 1);
    }

    #[test]
    fn spawn_local_fails_on_normal_dispatcher() {
        let (shutdown_tx, shutdown_rx) = mpsc::channel();
        with_raw_dispatcher("spawn local failures", move |dispatcher| {
            let inside_dispatcher = dispatcher.clone();
            dispatcher
                .spawn(async move {
                    // A migration-allowing dispatcher must refuse local
                    // spawns even from its own thread.
                    assert_eq!(
                        inside_dispatcher.spawn_local(futures::future::ready(())).unwrap_err(),
                        Status::BAD_STATE
                    );
                    assert_eq!(
                        inside_dispatcher.compute_local(futures::future::ready(())).unwrap_err(),
                        Status::BAD_STATE
                    );
                    shutdown_tx.send(()).unwrap();
                })
                .unwrap();
            shutdown_rx.recv().unwrap();
        });
    }

    #[test]
    #[ignore = "Pending resolution of b/488397193"]
    fn spawn_local_succeeds_on_no_thread_migration_dispatcher() {
        let (tx, rx) = mpsc::channel();
        with_raw_dispatcher_flags(
            "spawn local success",
            FDF_DISPATCHER_OPTION_NO_THREAD_MIGRATION,
            NO_SYNC_CALLS_ROLE,
            move |dispatcher| {
                let inside_dispatcher = dispatcher.clone();
                dispatcher
                    .spawn(async move {
                        let tx_clone = tx.clone();
                        inside_dispatcher
                            .spawn_local(async move {
                                tx_clone.send(()).unwrap();
                            })
                            .unwrap();
                        inside_dispatcher
                            .compute_local(async move {
                                tx.send(()).unwrap();
                            })
                            .unwrap()
                            .await
                            .unwrap();
                    })
                    .unwrap();
                // One receive per successful local spawn above.
                rx.recv().unwrap();
                rx.recv().unwrap();
            },
        );
    }

    #[test]
    #[ignore = "Pending resolution of b/488397193"]
    fn spawn_local_fails_on_no_thread_migration_dispatcher_from_different_thread() {
        with_raw_dispatcher_flags(
            "spawn local success",
            FDF_DISPATCHER_OPTION_NO_THREAD_MIGRATION,
            NO_SYNC_CALLS_ROLE,
            move |dispatcher| {
                // Called from the test thread, not the dispatcher's thread,
                // so local spawning must be refused.
                assert_eq!(
                    dispatcher.spawn_local(futures::future::ready(())).unwrap_err(),
                    Status::BAD_STATE
                );
                assert_eq!(
                    dispatcher.compute_local(futures::future::ready(())).unwrap_err(),
                    Status::BAD_STATE
                );
            },
        );
    }

    // Echo side of the ping-pong pair: forwards each received value + 1.
    async fn ping(mut tx: async_mpsc::Sender<u8>, mut rx: async_mpsc::Receiver<u8>) {
        println!("starting ping!");
        tx.send(0).await.unwrap();
        while let Some(next) = rx.next().await {
            println!("ping! {next}");
            tx.send(next + 1).await.unwrap();
        }
    }

    // Counterpart to `ping`: stops once the value exceeds 10, then signals
    // completion on `fin_tx`.
    async fn pong(
        fin_tx: std::sync::mpsc::Sender<()>,
        mut tx: async_mpsc::Sender<u8>,
        mut rx: async_mpsc::Receiver<u8>,
    ) {
        println!("starting pong!");
        while let Some(next) = rx.next().await {
            println!("pong! {next}");
            if next > 10 {
                println!("bye!");
                break;
            }
            tx.send(next + 1).await.unwrap();
        }
        fin_tx.send(()).unwrap();
    }

    #[test]
    fn async_ping_pong() {
        with_raw_dispatcher("async ping pong", |dispatcher| {
            let (fin_tx, fin_rx) = mpsc::channel();
            let (ping_tx, pong_rx) = async_mpsc::channel(10);
            let (pong_tx, ping_rx) = async_mpsc::channel(10);
            dispatcher.spawn(ping(ping_tx, ping_rx)).unwrap();
            dispatcher.spawn(pong(fin_tx, pong_tx, pong_rx)).unwrap();

            fin_rx.recv().expect("to receive final value");
        });
    }

    // Like `pong`, but sleeps between messages to exercise interleaving with
    // a separate fuchsia-async executor.
    async fn slow_pong(
        fin_tx: std::sync::mpsc::Sender<()>,
        mut tx: async_mpsc::Sender<u8>,
        mut rx: async_mpsc::Receiver<u8>,
    ) {
        use zx::MonotonicDuration;
        println!("starting pong!");
        while let Some(next) = rx.next().await {
            println!("pong! {next}");
            fuchsia_async::Timer::new(fuchsia_async::MonotonicInstant::after(
                MonotonicDuration::from_seconds(1),
            ))
            .await;
            if next > 10 {
                println!("bye!");
                break;
            }
            tx.send(next + 1).await.unwrap();
        }
        fin_tx.send(()).unwrap();
    }

    #[test]
    fn mixed_executor_async_ping_pong() {
        with_raw_dispatcher("async ping pong", |dispatcher| {
            let (fin_tx, fin_rx) = mpsc::channel();
            let (ping_tx, pong_rx) = async_mpsc::channel(10);
            let (pong_tx, ping_rx) = async_mpsc::channel(10);

            // Ping runs on the driver dispatcher...
            dispatcher.spawn(ping(ping_tx, ping_rx)).unwrap();

            // ...while pong runs on a local fuchsia-async executor.
            let mut executor = fuchsia_async::LocalExecutor::default();
            executor.run_singlethreaded(slow_pong(fin_tx, pong_tx, pong_rx));

            fin_rx.recv().expect("to receive final value");
        });
    }
}