1use fdf_sys::*;
8
9use core::cell::{RefCell, UnsafeCell};
10use core::ffi;
11use core::future::Future;
12use core::marker::PhantomData;
13use core::mem::ManuallyDrop;
14use core::ptr::{addr_of_mut, null_mut, NonNull};
15use core::task::Context;
16use std::sync::{Arc, Mutex, Weak};
17
18use zx::Status;
19
20use futures::future::{BoxFuture, FutureExt};
21use futures::task::{waker_ref, ArcWake};
22
23pub use fdf_sys::fdf_dispatcher_t;
24
/// A marker trait for a callback that may be invoked when a dispatcher finishes
/// shutting down.
pub trait ShutdownObserverFn: FnOnce(DispatcherRef<'_>) + Send + 'static {}
/// Blanket implementation: any `FnOnce(DispatcherRef<'_>) + Send + 'static`
/// closure or function is a [`ShutdownObserverFn`].
impl<T> ShutdownObserverFn for T where T: FnOnce(DispatcherRef<'_>) + Send + 'static {}
28
/// A builder for [`Dispatcher`]s.
///
/// The fields are `#[doc(hidden)]` implementation details; use the builder
/// methods rather than setting them directly.
#[derive(Default)]
pub struct DispatcherBuilder {
    // Raw option bits forwarded to `fdf_dispatcher_create`; see
    // `DispatcherBuilder::UNSYNCHRONIZED` and
    // `DispatcherBuilder::ALLOW_THREAD_BLOCKING`.
    #[doc(hidden)]
    pub options: u32,
    // A diagnostic name for the dispatcher, forwarded to the driver runtime.
    #[doc(hidden)]
    pub name: String,
    // The scheduler role requested for the dispatcher, if any.
    #[doc(hidden)]
    pub scheduler_role: String,
    // An optional observer invoked when the dispatcher completes shutdown.
    #[doc(hidden)]
    pub shutdown_observer: Option<ShutdownObserver>,
}
41
impl DispatcherBuilder {
    /// Option bit for an unsynchronized dispatcher.
    pub(crate) const UNSYNCHRONIZED: u32 = 0b01;
    /// Option bit allowing tasks on the dispatcher to make blocking calls.
    pub(crate) const ALLOW_THREAD_BLOCKING: u32 = 0b10;

    /// Creates a builder with all-default settings: synchronized, non-blocking,
    /// empty name and scheduler role, and no shutdown observer.
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the [`Self::UNSYNCHRONIZED`] option bit.
    ///
    /// # Panics
    ///
    /// Panics if [`Self::allow_thread_blocking`] was already requested; the two
    /// options are mutually exclusive.
    pub fn unsynchronized(mut self) -> Self {
        assert!(
            !self.allows_thread_blocking(),
            "you may not create an unsynchronized dispatcher that allows synchronous calls"
        );
        self.options = self.options | Self::UNSYNCHRONIZED;
        self
    }

    /// Returns whether the [`Self::UNSYNCHRONIZED`] option bit is set.
    pub fn is_unsynchronized(&self) -> bool {
        (self.options & Self::UNSYNCHRONIZED) == Self::UNSYNCHRONIZED
    }

    /// Sets the [`Self::ALLOW_THREAD_BLOCKING`] option bit.
    ///
    /// # Panics
    ///
    /// Panics if [`Self::unsynchronized`] was already requested; the two
    /// options are mutually exclusive.
    pub fn allow_thread_blocking(mut self) -> Self {
        assert!(
            !self.is_unsynchronized(),
            "you may not create an unsynchronized dispatcher that allows synchronous calls"
        );
        self.options = self.options | Self::ALLOW_THREAD_BLOCKING;
        self
    }

    /// Returns whether the [`Self::ALLOW_THREAD_BLOCKING`] option bit is set.
    pub fn allows_thread_blocking(&self) -> bool {
        (self.options & Self::ALLOW_THREAD_BLOCKING) == Self::ALLOW_THREAD_BLOCKING
    }

    /// Sets a diagnostic name for the dispatcher.
    pub fn name(mut self, name: &str) -> Self {
        self.name = name.to_string();
        self
    }

    /// Sets the scheduler role to request for the dispatcher.
    pub fn scheduler_role(mut self, role: &str) -> Self {
        self.scheduler_role = role.to_string();
        self
    }

    /// Registers a callback to run when the dispatcher finishes shutting down.
    pub fn shutdown_observer<F: ShutdownObserverFn>(mut self, shutdown_observer: F) -> Self {
        self.shutdown_observer = Some(ShutdownObserver::new(shutdown_observer));
        self
    }

    /// Creates the [`Dispatcher`] described by this builder.
    ///
    /// # Errors
    ///
    /// Returns the status reported by `fdf_dispatcher_create` on failure.
    pub fn create(self) -> Result<Dispatcher, Status> {
        let mut out_dispatcher = null_mut();
        let options = self.options;
        // Name and role are passed as (pointer, length) pairs, so no NUL
        // terminator is needed. `self` remains alive across the call, keeping
        // both pointers valid.
        let name = self.name.as_ptr() as *mut ffi::c_char;
        let name_len = self.name.len();
        let scheduler_role = self.scheduler_role.as_ptr() as *mut ffi::c_char;
        let scheduler_role_len = self.scheduler_role.len();
        // Ownership of the boxed observer is handed to the driver runtime,
        // which returns it through the shutdown handler.
        // NOTE(review): if `fdf_dispatcher_create` fails, the observer box
        // leaked by `into_ptr` does not appear to be reclaimed here — confirm
        // the C API's ownership semantics on the error path.
        let observer =
            self.shutdown_observer.unwrap_or_else(|| ShutdownObserver::new(|_| {})).into_ptr();
        Status::ok(unsafe {
            fdf_dispatcher_create(
                options,
                name,
                name_len,
                scheduler_role,
                scheduler_role_len,
                observer,
                &mut out_dispatcher,
            )
        })?;
        // SAFETY: `Status::ok` returned success above; presumably the runtime
        // guarantees `out_dispatcher` is non-null in that case — TODO(review)
        // confirm against the fdf C API contract.
        Ok(Dispatcher(unsafe { NonNull::new_unchecked(out_dispatcher) }))
    }

    /// As [`Self::create`], but returns a `'static` [`DispatcherRef`] whose
    /// drop will never shut the dispatcher down (the owned handle is leaked via
    /// [`Dispatcher::release`]).
    pub fn create_released(self) -> Result<DispatcherRef<'static>, Status> {
        self.create().map(Dispatcher::release)
    }
}
153
/// An owned handle to a driver runtime dispatcher. Dropping it initiates an
/// asynchronous shutdown of the dispatcher (see the [`Drop`] impl below).
#[derive(Debug)]
pub struct Dispatcher(pub(crate) NonNull<fdf_dispatcher_t>);
157
// SAFETY: presumably the underlying `fdf_dispatcher_t` handle may be used and
// shared across threads, with synchronization handled by the driver runtime —
// TODO(review): confirm against the fdf C API's threading guarantees.
unsafe impl Send for Dispatcher {}
unsafe impl Sync for Dispatcher {}
thread_local! {
    // Thread-local override of the "current" dispatcher; written by
    // `Dispatcher::override_current` and read by `CurrentDispatcher`.
    static OVERRIDE_DISPATCHER: RefCell<Option<NonNull<fdf_dispatcher_t>>> = const { RefCell::new(None) };
}
164
impl Dispatcher {
    /// Takes ownership of a raw dispatcher handle.
    ///
    /// # Safety
    ///
    /// The caller must ensure `handle` is a valid dispatcher handle and that
    /// ownership is genuinely transferred: the returned object will shut the
    /// dispatcher down asynchronously when dropped.
    pub unsafe fn from_raw(handle: NonNull<fdf_dispatcher_t>) -> Self {
        Self(handle)
    }

    /// Returns the raw handle backing this dispatcher.
    #[doc(hidden)]
    pub fn inner<'a>(&'a self) -> &'a NonNull<fdf_dispatcher_t> {
        &self.0
    }

    /// Reads the dispatcher's raw option bits back from the runtime.
    fn get_raw_flags(&self) -> u32 {
        // SAFETY: `self.0` is a valid dispatcher handle for the lifetime of
        // `self`.
        unsafe { fdf_dispatcher_get_options(self.0.as_ptr()) }
    }

    /// Returns whether this dispatcher was created with the
    /// [`DispatcherBuilder::UNSYNCHRONIZED`] option.
    pub fn is_unsynchronized(&self) -> bool {
        (self.get_raw_flags() & DispatcherBuilder::UNSYNCHRONIZED) != 0
    }

    /// Returns whether this dispatcher was created with the
    /// [`DispatcherBuilder::ALLOW_THREAD_BLOCKING`] option.
    pub fn allows_thread_blocking(&self) -> bool {
        (self.get_raw_flags() & DispatcherBuilder::ALLOW_THREAD_BLOCKING) != 0
    }

    /// Queues the callback `p` to run on this dispatcher. The callback receives
    /// the status the task ran with (`ZX_OK`, or an error status, e.g. if the
    /// task was cancelled by shutdown).
    ///
    /// # Errors
    ///
    /// Returns the status from `async_post_task` if the task could not be
    /// queued.
    pub fn post_task_sync(&self, p: impl TaskCallback) -> Result<(), Status> {
        // SAFETY: `self.0` is a valid dispatcher handle.
        let async_dispatcher = unsafe { fdf_dispatcher_get_async_dispatcher(self.0.as_ptr()) };
        let task_arc = Arc::new(UnsafeCell::new(TaskFunc {
            task: async_task { handler: Some(TaskFunc::call), ..Default::default() },
            func: Box::new(p),
        }));

        // One strong count is transferred to the runtime here; `TaskFunc::call`
        // reclaims it via `Arc::from_raw` when the task runs or is cancelled.
        let task_cell = Arc::into_raw(task_arc);
        let res = unsafe {
            // `TaskFunc` is `repr(C)` with `task` as its first field, so the
            // runtime can round-trip this pointer back to the containing
            // allocation inside `TaskFunc::call`.
            let task_ptr = addr_of_mut!((*UnsafeCell::raw_get(task_cell)).task);
            async_post_task(async_dispatcher, task_ptr)
        };
        if res != ZX_OK {
            // The runtime did not take ownership; release the transferred
            // strong count so the task allocation is freed.
            unsafe { Arc::decrement_strong_count(task_cell) }
            Err(Status::from_raw(res))
        } else {
            Ok(())
        }
    }

    /// Leaks this owned dispatcher into a `'static` [`DispatcherRef`] whose
    /// drop will never trigger shutdown.
    pub fn release(self) -> DispatcherRef<'static> {
        DispatcherRef(ManuallyDrop::new(self), PhantomData)
    }

    /// Borrows this dispatcher as a non-owning [`DispatcherRef`].
    pub fn as_dispatcher_ref(&self) -> DispatcherRef<'_> {
        DispatcherRef(ManuallyDrop::new(Dispatcher(self.0)), PhantomData)
    }

    /// Runs `f` with the thread-local "current dispatcher" overridden to
    /// `dispatcher`, restoring the previous value afterwards.
    ///
    /// NOTE(review): if `f` panics, the previous override is not restored
    /// (restoration happens after `f` returns) — confirm whether unwinding
    /// through here is possible/acceptable.
    #[doc(hidden)]
    pub fn override_current<R>(dispatcher: DispatcherRef<'_>, f: impl FnOnce() -> R) -> R {
        OVERRIDE_DISPATCHER.with(|global| {
            let previous = global.replace(Some(dispatcher.0 .0));
            let res = f();
            global.replace(previous);
            res
        })
    }
}
253
impl Drop for Dispatcher {
    fn drop(&mut self) {
        // Shutdown is asynchronous; the handle itself is destroyed later by the
        // shutdown observer's handler (`ShutdownObserver::handler` calls
        // `fdf_dispatcher_destroy`).
        // SAFETY: `self.0` is a valid, owned dispatcher handle.
        unsafe { fdf_dispatcher_shutdown_async(self.0.as_mut()) }
    }
}
261
/// A borrowed, non-owning reference to a [`Dispatcher`]. The inner value is
/// wrapped in [`ManuallyDrop`] so that dropping the reference never triggers
/// the dispatcher's shutdown-on-drop behavior.
#[derive(Debug)]
pub struct DispatcherRef<'a>(ManuallyDrop<Dispatcher>, PhantomData<&'a Dispatcher>);
267
impl<'a> DispatcherRef<'a> {
    /// Wraps a raw dispatcher handle in a non-owning [`DispatcherRef`].
    ///
    /// # Safety
    ///
    /// The caller must ensure `handle` is a valid dispatcher handle for the
    /// lifetime `'a`.
    pub unsafe fn from_raw(handle: NonNull<fdf_dispatcher_t>) -> Self {
        Self(ManuallyDrop::new(unsafe { Dispatcher::from_raw(handle) }), PhantomData)
    }
}
280
281impl<'a> Clone for DispatcherRef<'a> {
282 fn clone(&self) -> Self {
283 Self(ManuallyDrop::new(Dispatcher(self.0 .0)), PhantomData)
284 }
285}
286
// Allow a `DispatcherRef` to be used anywhere a `&Dispatcher` is expected.
impl<'a> core::ops::Deref for DispatcherRef<'a> {
    type Target = Dispatcher;
    fn deref(&self) -> &Self::Target {
        &self.0
    }
}
293
// Mutable counterpart of the `Deref` impl above.
impl<'a> core::ops::DerefMut for DispatcherRef<'a> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.0
    }
}
299
/// A handle-like object that can provide access to a dispatcher on demand.
pub trait OnDispatcher: Clone + Send + Sync + Unpin {
    /// Runs `f` with a borrowed reference to this object's dispatcher, or with
    /// `None` if the dispatcher is no longer available.
    fn on_dispatcher<R>(&self, f: impl FnOnce(Option<DispatcherRef<'_>>) -> R) -> R;

    /// Like [`Self::on_dispatcher`], but reports an unavailable dispatcher as
    /// `Err(Status::BAD_STATE)` (converted into `E`).
    fn on_maybe_dispatcher<R, E: From<Status>>(
        &self,
        f: impl FnOnce(DispatcherRef<'_>) -> Result<R, E>,
    ) -> Result<R, E> {
        self.on_dispatcher(|dispatcher| {
            let dispatcher = dispatcher.ok_or(Status::BAD_STATE)?;
            f(dispatcher)
        })
    }

    /// Spawns `future` as a task on this dispatcher; the task re-queues itself
    /// whenever the future is woken.
    ///
    /// # Errors
    ///
    /// Returns an error if the initial queueing fails (e.g. the dispatcher is
    /// unavailable or shutting down).
    fn spawn_task(&self, future: impl Future<Output = ()> + Send + 'static) -> Result<(), Status>
    where
        Self: 'static,
    {
        let task =
            Arc::new(Task { future: Mutex::new(Some(future.boxed())), dispatcher: self.clone() });
        task.queue()
    }
}
330
// Delegate through a reference so `&D` works wherever `D: OnDispatcher` does.
impl<'a, D: OnDispatcher> OnDispatcher for &'a D {
    fn on_dispatcher<R>(&self, f: impl FnOnce(Option<DispatcherRef<'_>>) -> R) -> R {
        D::on_dispatcher(*self, f)
    }
}
336
// A borrowed dispatcher is always available for the duration of the borrow.
impl<'a> OnDispatcher for &'a Dispatcher {
    fn on_dispatcher<R>(&self, f: impl FnOnce(Option<DispatcherRef<'_>>) -> R) -> R {
        f(Some(self.as_dispatcher_ref()))
    }
}
342
// A `DispatcherRef` is itself a live borrow, so the dispatcher is always
// available.
impl<'a> OnDispatcher for DispatcherRef<'a> {
    fn on_dispatcher<R>(&self, f: impl FnOnce(Option<DispatcherRef<'_>>) -> R) -> R {
        f(Some(self.as_dispatcher_ref()))
    }
}
348
// A strong `Arc` keeps the dispatcher alive, so it is always available.
impl OnDispatcher for Arc<Dispatcher> {
    fn on_dispatcher<R>(&self, f: impl FnOnce(Option<DispatcherRef<'_>>) -> R) -> R {
        f(Some(self.as_dispatcher_ref()))
    }
}
354
355impl OnDispatcher for Weak<Dispatcher> {
356 fn on_dispatcher<R>(&self, f: impl FnOnce(Option<DispatcherRef<'_>>) -> R) -> R {
357 let dispatcher = Weak::upgrade(self);
358 match dispatcher {
359 Some(dispatcher) => f(Some(dispatcher.as_dispatcher_ref())),
360 None => f(None),
361 }
362 }
363}
364
/// A zero-sized [`OnDispatcher`] implementor that resolves, at call time, to
/// the dispatcher the current thread is running on — or to the thread-local
/// override installed by [`Dispatcher::override_current`].
#[derive(Clone, Copy)]
pub struct CurrentDispatcher;
369
impl OnDispatcher for CurrentDispatcher {
    fn on_dispatcher<R>(&self, f: impl FnOnce(Option<DispatcherRef<'_>>) -> R) -> R {
        // Prefer the thread-local override (set by
        // `Dispatcher::override_current`), falling back to the runtime's notion
        // of the current dispatcher.
        let dispatcher = OVERRIDE_DISPATCHER
            .with(|global| global.borrow().clone())
            .or_else(|| {
                // `NonNull::new` maps a null return (not on a driver dispatcher
                // thread) to `None`.
                NonNull::new(unsafe { fdf_dispatcher_get_current_dispatcher() })
            })
            .map(|dispatcher| {
                // Wrap as a non-owning ref: `ManuallyDrop` ensures dropping it
                // never shuts down the (borrowed) current dispatcher.
                DispatcherRef(
                    ManuallyDrop::new(unsafe { Dispatcher::from_raw(dispatcher) }),
                    Default::default(),
                )
            });
        f(dispatcher)
    }
}
392
/// A marker trait for callbacks that can be queued as dispatcher tasks.
pub trait TaskCallback: FnOnce(Status) + 'static + Send {}
// Blanket implementation for any compatible closure or function.
impl<T> TaskCallback for T where T: FnOnce(Status) + 'static + Send {}
396
/// State of a future spawned via [`OnDispatcher::spawn_task`]: the pending
/// future plus the dispatcher handle used to re-queue it on wake.
struct Task<D> {
    // `None` once the future has completed or been cancelled.
    future: Mutex<Option<BoxFuture<'static, ()>>>,
    dispatcher: D,
}
401
402impl<D: OnDispatcher + 'static> ArcWake for Task<D> {
403 fn wake_by_ref(arc_self: &Arc<Self>) {
404 match arc_self.queue() {
405 Err(e) if e == Status::from_raw(ZX_ERR_BAD_STATE) => {
406 let future_slot = arc_self.future.lock().unwrap().take();
409 core::mem::drop(future_slot);
410 }
411 res => res.expect("Unexpected error waking dispatcher task"),
412 }
413 }
414}
415
impl<D: OnDispatcher + 'static> Task<D> {
    /// Posts one poll of this task's future onto its dispatcher.
    ///
    /// # Errors
    ///
    /// Returns `Status::BAD_STATE` if the dispatcher is unavailable, or the
    /// error from `post_task_sync` if queueing fails.
    fn queue(self: &Arc<Self>) -> Result<(), Status> {
        let arc_self = self.clone();
        self.dispatcher.on_maybe_dispatcher(move |dispatcher| {
            dispatcher
                .post_task_sync(move |status| {
                    // The future lock is held for the entire poll; wakes from
                    // other threads will briefly block on it.
                    let mut future_slot = arc_self.future.lock().unwrap();
                    if status != Status::from_raw(ZX_OK) {
                        // The task was cancelled (e.g. dispatcher shutdown):
                        // drop the future instead of polling it.
                        core::mem::drop(future_slot.take());
                        return;
                    }

                    // An empty slot means the future already completed or was
                    // dropped by a concurrent wake; nothing to do.
                    let Some(mut future) = future_slot.take() else {
                        return;
                    };
                    let waker = waker_ref(&arc_self);
                    let context = &mut Context::from_waker(&waker);
                    // Park the future back in the slot only while it is still
                    // pending; completion leaves the slot empty.
                    if future.as_mut().poll(context).is_pending() {
                        *future_slot = Some(future);
                    }
                })
                .map(|_| ())
        })
    }
}
445
/// State for a callback queued via [`Dispatcher::post_task_sync`].
///
/// `repr(C)` with `task` as the first field so that the pointer to `task`
/// handed to the C runtime can be converted back into a pointer to the whole
/// struct in [`TaskFunc::call`].
#[repr(C)]
struct TaskFunc {
    task: async_task,
    func: Box<dyn TaskCallback>,
}
451
452impl TaskFunc {
453 extern "C" fn call(_dispatcher: *mut async_dispatcher, task: *mut async_task, status: i32) {
454 let task = unsafe { Arc::from_raw(task as *const UnsafeCell<Self>) };
457 if let Some(task) = Arc::try_unwrap(task).ok() {
460 (task.into_inner().func)(Status::from_raw(status));
461 }
462 }
463}
464
/// A shutdown observer registered with `fdf_dispatcher_create`.
///
/// `repr(C)` with `observer` as the first field so that the C runtime's pointer
/// to the inner `fdf_dispatcher_shutdown_observer` is also a valid pointer to
/// this whole struct (relied upon by [`ShutdownObserver::handler`]).
#[repr(C)]
#[doc(hidden)]
pub struct ShutdownObserver {
    observer: fdf_dispatcher_shutdown_observer,
    shutdown_fn: Box<dyn ShutdownObserverFn>,
}
479
impl ShutdownObserver {
    /// Wraps `f` in a [`ShutdownObserver`] whose C handler trampolines to it.
    pub fn new<F: ShutdownObserverFn>(f: F) -> Self {
        let shutdown_fn = Box::new(f);
        Self {
            observer: fdf_dispatcher_shutdown_observer { handler: Some(Self::handler) },
            shutdown_fn,
        }
    }

    /// Boxes and leaks `self`, returning a pointer suitable for passing to
    /// `fdf_dispatcher_create`; ownership is reclaimed by [`Self::handler`]
    /// when the dispatcher finishes shutting down. The double cast relies on
    /// `observer` being the first field of this `repr(C)` struct.
    pub fn into_ptr(self) -> *mut fdf_dispatcher_shutdown_observer {
        Box::leak(Box::new(self)) as *mut _ as *mut _
    }

    /// C-ABI shutdown handler: reclaims the leaked observer, invokes the user
    /// callback with a borrowed dispatcher ref, then destroys the dispatcher
    /// handle.
    unsafe extern "C" fn handler(
        dispatcher: *mut fdf_dispatcher_t,
        observer: *mut fdf_dispatcher_shutdown_observer_t,
    ) {
        // SAFETY: `observer` is the pointer produced by `into_ptr`, and —
        // presumably — the runtime calls this handler exactly once, making the
        // box reconstruction sound; TODO(review) confirm the single-call
        // guarantee against the fdf C API.
        let observer = unsafe { Box::from_raw(observer as *mut ShutdownObserver) };
        // Hand the callback a non-owning ref: `ManuallyDrop` prevents a
        // recursive shutdown when the ref is dropped.
        let dispatcher_ref = DispatcherRef(
            ManuallyDrop::new(Dispatcher(unsafe { NonNull::new_unchecked(dispatcher) })),
            PhantomData,
        );
        (observer.shutdown_fn)(dispatcher_ref);
        // Shutdown has completed, so destroying the handle is the final step.
        unsafe { fdf_dispatcher_destroy(dispatcher) };
    }
}
527
#[cfg(test)]
mod tests {
    use super::*;

    use std::sync::{mpsc, Once};

    use futures::channel::mpsc as async_mpsc;
    use futures::{SinkExt, StreamExt};

    use core::ffi::{c_char, c_void};
    use core::ptr::null_mut;

    // Ensures the driver runtime environment is started at most once per
    // process, no matter how many tests run.
    static GLOBAL_DRIVER_ENV: Once = Once::new();

    pub fn ensure_driver_env() {
        GLOBAL_DRIVER_ENV.call_once(|| {
            // SAFETY: starts the driver environment exactly once, before any
            // dispatcher is created.
            unsafe {
                assert_eq!(fdf_env_start(0), ZX_OK);
            }
        });
    }

    /// Runs `p` with a fresh blocking-allowed dispatcher, then shuts the
    /// dispatcher down and waits for the shutdown to complete.
    pub fn with_raw_dispatcher<T>(name: &str, p: impl for<'a> FnOnce(Weak<Dispatcher>) -> T) -> T {
        with_raw_dispatcher_flags(name, DispatcherBuilder::ALLOW_THREAD_BLOCKING, p)
    }

    /// As [`with_raw_dispatcher`], but with caller-specified option flags.
    pub(crate) fn with_raw_dispatcher_flags<T>(
        name: &str,
        flags: u32,
        p: impl for<'a> FnOnce(Weak<Dispatcher>) -> T,
    ) -> T {
        ensure_driver_env();

        let (shutdown_tx, shutdown_rx) = mpsc::channel();
        let mut dispatcher = null_mut();
        let mut observer = ShutdownObserver::new(move |dispatcher| {
            // By the time the shutdown observer fires there must be no queued
            // tasks left on the dispatcher.
            assert!(!unsafe { fdf_env_dispatcher_has_queued_tasks(dispatcher.0 .0.as_ptr()) });
            shutdown_tx.send(()).unwrap();
        })
        .into_ptr();
        // The owner pointer is presumably only used as a unique identifier by
        // the env API; the address of the local `observer` pointer serves that
        // purpose here.
        let driver_ptr = &mut observer as *mut _ as *mut c_void;
        let res = unsafe {
            fdf_env_dispatcher_create_with_owner(
                driver_ptr,
                flags,
                name.as_ptr() as *const c_char,
                name.len(),
                "".as_ptr() as *const c_char,
                0 as usize,
                observer,
                &mut dispatcher,
            )
        };
        assert_eq!(res, ZX_OK);
        let dispatcher = Arc::new(Dispatcher(NonNull::new(dispatcher).unwrap()));

        let res = p(Arc::downgrade(&dispatcher));

        // Dropping the only strong reference starts an async shutdown; wait
        // for the shutdown observer to fire, then verify no strong reference
        // escaped the test body.
        let weak_dispatcher = Arc::downgrade(&dispatcher);
        drop(dispatcher);
        shutdown_rx.recv().unwrap();
        assert_eq!(
            0,
            weak_dispatcher.strong_count(),
            "a dispatcher reference escaped the test body"
        );

        res
    }

    #[test]
    fn start_test_dispatcher() {
        with_raw_dispatcher("testing", |dispatcher| {
            println!("hello {dispatcher:?}");
        })
    }

    #[test]
    fn post_task_on_dispatcher() {
        with_raw_dispatcher("testing task", |dispatcher| {
            let (tx, rx) = mpsc::channel();
            let dispatcher = Weak::upgrade(&dispatcher).unwrap();
            dispatcher
                .post_task_sync(move |status| {
                    assert_eq!(status, Status::from_raw(ZX_OK));
                    tx.send(status).unwrap();
                })
                .unwrap();
            assert_eq!(rx.recv().unwrap(), Status::from_raw(ZX_OK));
        });
    }

    #[test]
    fn post_task_on_subdispatcher() {
        let (shutdown_tx, shutdown_rx) = mpsc::channel();
        with_raw_dispatcher("testing task top level", move |dispatcher| {
            let (tx, rx) = mpsc::channel();
            let (inner_tx, inner_rx) = mpsc::channel();
            let dispatcher = Weak::upgrade(&dispatcher).unwrap();
            dispatcher
                .post_task_sync(move |status| {
                    assert_eq!(status, Status::from_raw(ZX_OK));
                    // Build a second dispatcher from inside a task on the first
                    // one.
                    let inner = DispatcherBuilder::new()
                        .name("testing task second level")
                        .scheduler_role("")
                        .allow_thread_blocking()
                        .shutdown_observer(move |_dispatcher| {
                            println!("shutdown observer called");
                            shutdown_tx.send(1).unwrap();
                        })
                        .create()
                        .unwrap();
                    inner
                        .post_task_sync(move |status| {
                            assert_eq!(status, Status::from_raw(ZX_OK));
                            tx.send(status).unwrap();
                        })
                        .unwrap();
                    // Hand the inner dispatcher out of the task so it is
                    // dropped (and shut down) on the test thread instead.
                    inner_tx.send(inner).unwrap();
                })
                .unwrap();
            assert_eq!(rx.recv().unwrap(), Status::from_raw(ZX_OK));
            inner_rx.recv().unwrap();
        });
        assert_eq!(shutdown_rx.recv().unwrap(), 1);
    }

    // Sends an initial 0, then echoes each received value incremented by one.
    async fn ping(mut tx: async_mpsc::Sender<u8>, mut rx: async_mpsc::Receiver<u8>) {
        println!("starting ping!");
        tx.send(0).await.unwrap();
        while let Some(next) = rx.next().await {
            println!("ping! {next}");
            tx.send(next + 1).await.unwrap();
        }
    }

    // Echoes incremented values back until one exceeds 10, then signals
    // completion over `fin_tx`.
    async fn pong(
        fin_tx: std::sync::mpsc::Sender<()>,
        mut tx: async_mpsc::Sender<u8>,
        mut rx: async_mpsc::Receiver<u8>,
    ) {
        println!("starting pong!");
        while let Some(next) = rx.next().await {
            println!("pong! {next}");
            if next > 10 {
                println!("bye!");
                break;
            }
            tx.send(next + 1).await.unwrap();
        }
        fin_tx.send(()).unwrap();
    }

    #[test]
    fn async_ping_pong() {
        with_raw_dispatcher("async ping pong", |dispatcher| {
            let (fin_tx, fin_rx) = mpsc::channel();
            let (ping_tx, pong_rx) = async_mpsc::channel(10);
            let (pong_tx, ping_rx) = async_mpsc::channel(10);
            dispatcher.spawn_task(ping(ping_tx, ping_rx)).unwrap();
            dispatcher.spawn_task(pong(fin_tx, pong_tx, pong_rx)).unwrap();

            fin_rx.recv().expect("to receive final value");
        });
    }

    // Like `pong`, but waits on a fuchsia-async timer between messages to
    // exercise interaction with a separate executor.
    async fn slow_pong(
        fin_tx: std::sync::mpsc::Sender<()>,
        mut tx: async_mpsc::Sender<u8>,
        mut rx: async_mpsc::Receiver<u8>,
    ) {
        use zx::MonotonicDuration;
        println!("starting pong!");
        while let Some(next) = rx.next().await {
            println!("pong! {next}");
            fuchsia_async::Timer::new(fuchsia_async::MonotonicInstant::after(
                MonotonicDuration::from_seconds(1),
            ))
            .await;
            if next > 10 {
                println!("bye!");
                break;
            }
            tx.send(next + 1).await.unwrap();
        }
        fin_tx.send(()).unwrap();
    }

    #[test]
    fn mixed_executor_async_ping_pong() {
        with_raw_dispatcher("async ping pong", |dispatcher| {
            let (fin_tx, fin_rx) = mpsc::channel();
            let (ping_tx, pong_rx) = async_mpsc::channel(10);
            let (pong_tx, ping_rx) = async_mpsc::channel(10);

            // `ping` runs on the driver dispatcher...
            dispatcher.spawn_task(ping(ping_tx, ping_rx)).unwrap();

            // ...while `slow_pong` runs on a local fuchsia-async executor.
            let mut executor = fuchsia_async::LocalExecutor::new();
            executor.run_singlethreaded(slow_pong(fin_tx, pong_tx, pong_rx));

            fin_rx.recv().expect("to receive final value");
        });
    }
}