1use fdf_sys::*;
8
9use core::cell::RefCell;
10use core::ffi;
11use core::marker::PhantomData;
12use core::mem::ManuallyDrop;
13use core::ptr::{NonNull, null_mut};
14use std::sync::atomic::{AtomicPtr, Ordering};
15use std::sync::{Arc, Weak};
16
17use zx::Status;
18
19use crate::shutdown_observer::ShutdownObserver;
20
21pub use fdf_sys::fdf_dispatcher_t;
22pub use libasync::{
23 AfterDeadline, AsyncDispatcher, AsyncDispatcherRef, JoinHandle, OnDispatcher, Task,
24};
25
26pub trait ShutdownObserverFn: FnOnce(DispatcherRef<'_>) + Send + 'static {}
28impl<T> ShutdownObserverFn for T where T: FnOnce(DispatcherRef<'_>) + Send + 'static {}
29
/// Collects the settings (option bits, name, scheduler role and an optional shutdown
/// observer) that [`DispatcherBuilder::create`] hands to `fdf_dispatcher_create`.
///
/// The fields are public but hidden from the docs; the builder methods are the intended
/// way to set them.
#[derive(Default)]
pub struct DispatcherBuilder {
    /// Raw `FDF_DISPATCHER_OPTION_*` bits passed through to the runtime.
    #[doc(hidden)]
    pub options: u32,
    /// Dispatcher name; passed to the runtime as pointer + length.
    #[doc(hidden)]
    pub name: String,
    /// Scheduler role; passed to the runtime as pointer + length.
    #[doc(hidden)]
    pub scheduler_role: String,
    /// Callback wrapped in a `ShutdownObserver` at creation time. When `None`, a no-op
    /// callback is substituted.
    #[doc(hidden)]
    pub shutdown_observer: Option<Box<dyn ShutdownObserverFn>>,
}
42
43impl DispatcherBuilder {
44 pub(crate) const UNSYNCHRONIZED: u32 = fdf_sys::FDF_DISPATCHER_OPTION_UNSYNCHRONIZED;
46 pub(crate) const ALLOW_THREAD_BLOCKING: u32 = fdf_sys::FDF_DISPATCHER_OPTION_ALLOW_SYNC_CALLS;
48
49 pub fn new() -> Self {
53 Self::default()
54 }
55
56 pub fn unsynchronized(mut self) -> Self {
62 assert!(
63 !self.allows_thread_blocking(),
64 "you may not create an unsynchronized dispatcher that allows synchronous calls"
65 );
66 self.options |= Self::UNSYNCHRONIZED;
67 self
68 }
69
70 pub fn is_unsynchronized(&self) -> bool {
72 (self.options & Self::UNSYNCHRONIZED) == Self::UNSYNCHRONIZED
73 }
74
75 pub fn allow_thread_blocking(mut self) -> Self {
81 assert!(
82 !self.is_unsynchronized(),
83 "you may not create an unsynchronized dispatcher that allows synchronous calls"
84 );
85 self.options |= Self::ALLOW_THREAD_BLOCKING;
86 self
87 }
88
89 pub fn allows_thread_blocking(&self) -> bool {
91 (self.options & Self::ALLOW_THREAD_BLOCKING) == Self::ALLOW_THREAD_BLOCKING
92 }
93
94 pub fn name(mut self, name: &str) -> Self {
97 self.name = name.to_string();
98 self
99 }
100
101 pub fn scheduler_role(mut self, role: &str) -> Self {
105 self.scheduler_role = role.to_string();
106 self
107 }
108
109 pub fn shutdown_observer<F: ShutdownObserverFn>(mut self, shutdown_observer: F) -> Self {
111 self.shutdown_observer = Some(Box::new(shutdown_observer));
112 self
113 }
114
115 pub fn create(self) -> Result<Dispatcher, Status> {
120 let mut out_dispatcher = null_mut();
121 let options = self.options;
122 let name = self.name.as_ptr() as *mut ffi::c_char;
123 let name_len = self.name.len();
124 let scheduler_role = self.scheduler_role.as_ptr() as *mut ffi::c_char;
125 let scheduler_role_len = self.scheduler_role.len();
126 let observer =
127 ShutdownObserver::new(self.shutdown_observer.unwrap_or_else(|| Box::new(|_| {})))
128 .into_ptr();
129 Status::ok(unsafe {
133 fdf_dispatcher_create(
134 options,
135 name,
136 name_len,
137 scheduler_role,
138 scheduler_role_len,
139 observer,
140 &mut out_dispatcher,
141 )
142 })?;
143 Ok(Dispatcher(unsafe { NonNull::new_unchecked(out_dispatcher) }))
146 }
147
148 pub fn create_released(self) -> Result<DispatcherRef<'static>, Status> {
152 self.create().map(Dispatcher::release)
153 }
154}
155
/// An owning handle to a driver runtime dispatcher, stored as a non-null raw pointer.
///
/// Dropping it calls `fdf_dispatcher_shutdown_async` on the pointer; use
/// [`Dispatcher::release`] to give up ownership without triggering that.
#[derive(Debug)]
pub struct Dispatcher(pub(crate) NonNull<fdf_dispatcher_t>);

// NOTE(review): these impls assert that the underlying runtime dispatcher may be used and
// shared across threads — presumably guaranteed by the driver runtime; confirm against its
// documentation.
unsafe impl Send for Dispatcher {}
unsafe impl Sync for Dispatcher {}
thread_local! {
    // Per-thread dispatcher override. `CurrentDispatcher::on_dispatcher` consults this first
    // and only asks the runtime for the current dispatcher when it is `None`.
    pub(crate) static OVERRIDE_DISPATCHER: RefCell<Option<NonNull<fdf_dispatcher_t>>> = const { RefCell::new(None) };
}
166
167impl Dispatcher {
168 pub unsafe fn from_raw(handle: NonNull<fdf_dispatcher_t>) -> Self {
176 Self(handle)
177 }
178
179 fn get_raw_flags(&self) -> u32 {
180 unsafe { fdf_dispatcher_get_options(self.0.as_ptr()) }
182 }
183
184 pub fn is_unsynchronized(&self) -> bool {
186 (self.get_raw_flags() & DispatcherBuilder::UNSYNCHRONIZED) != 0
187 }
188
189 pub fn allows_thread_blocking(&self) -> bool {
191 (self.get_raw_flags() & DispatcherBuilder::ALLOW_THREAD_BLOCKING) != 0
192 }
193
194 pub fn is_current_dispatcher(&self) -> bool {
196 self.0.as_ptr() == unsafe { fdf_dispatcher_get_current_dispatcher() }
199 }
200
201 pub fn release(self) -> DispatcherRef<'static> {
206 DispatcherRef(ManuallyDrop::new(self), PhantomData)
207 }
208
209 pub fn as_dispatcher_ref(&self) -> DispatcherRef<'_> {
212 DispatcherRef(ManuallyDrop::new(Dispatcher(self.0)), PhantomData)
213 }
214}
215
216impl AsyncDispatcher for Dispatcher {
217 fn as_async_dispatcher_ref(&self) -> AsyncDispatcherRef<'_> {
218 let async_dispatcher =
219 NonNull::new(unsafe { fdf_dispatcher_get_async_dispatcher(self.0.as_ptr()) })
220 .expect("No async dispatcher on driver dispatcher");
221 unsafe { AsyncDispatcherRef::from_raw(async_dispatcher) }
222 }
223}
224
225impl Drop for Dispatcher {
226 fn drop(&mut self) {
227 unsafe { fdf_dispatcher_shutdown_async(self.0.as_mut()) }
230 }
231}
232
233#[derive(Debug)]
243pub struct AutoReleaseDispatcher(Arc<AtomicPtr<fdf_dispatcher>>);
244
245impl AutoReleaseDispatcher {
246 pub fn downgrade(&self) -> WeakDispatcher {
250 WeakDispatcher::from(self)
251 }
252}
253
254impl AsyncDispatcher for AutoReleaseDispatcher {
255 fn as_async_dispatcher_ref(&self) -> AsyncDispatcherRef<'_> {
256 let dispatcher = NonNull::new(self.0.load(Ordering::Relaxed))
257 .expect("tried to obtain async dispatcher after drop");
258 unsafe {
261 AsyncDispatcherRef::from_raw(
262 NonNull::new(fdf_dispatcher_get_async_dispatcher(dispatcher.as_ptr())).unwrap(),
263 )
264 }
265 }
266}
267
268impl From<Dispatcher> for AutoReleaseDispatcher {
269 fn from(dispatcher: Dispatcher) -> Self {
270 let dispatcher_ptr = dispatcher.release().0.0.as_ptr();
271 Self(Arc::new(AtomicPtr::new(dispatcher_ptr)))
272 }
273}
274
275impl Drop for AutoReleaseDispatcher {
276 fn drop(&mut self) {
277 self.0.store(null_mut(), Ordering::Relaxed);
280 while Arc::strong_count(&self.0) > 1 {
284 std::thread::sleep(std::time::Duration::from_nanos(100))
287 }
288 }
289}
290
291#[derive(Clone, Debug)]
299pub struct WeakDispatcher(Weak<AtomicPtr<fdf_dispatcher>>);
300
301impl From<&AutoReleaseDispatcher> for WeakDispatcher {
302 fn from(value: &AutoReleaseDispatcher) -> Self {
303 Self(Arc::downgrade(&value.0))
304 }
305}
306
307impl OnDispatcher for WeakDispatcher {
308 fn on_dispatcher<R>(&self, f: impl FnOnce(Option<AsyncDispatcherRef<'_>>) -> R) -> R {
309 let Some(dispatcher_ptr) = self.0.upgrade() else {
310 return f(None);
311 };
312 let Some(dispatcher) = NonNull::new(dispatcher_ptr.load(Ordering::Relaxed)) else {
313 return f(None);
314 };
315 f(Some(unsafe { DispatcherRef::from_raw(dispatcher) }.as_async_dispatcher_ref()))
319 }
320}
321
/// A non-owning reference to a driver runtime dispatcher.
///
/// It wraps a [`Dispatcher`] in `ManuallyDrop` so that dropping the reference never runs
/// `Dispatcher`'s `Drop`; the `PhantomData` ties it to the lifetime `'a`.
#[derive(Debug)]
pub struct DispatcherRef<'a>(ManuallyDrop<Dispatcher>, PhantomData<&'a Dispatcher>);
327
328impl<'a> DispatcherRef<'a> {
329 pub unsafe fn from_raw(handle: NonNull<fdf_dispatcher_t>) -> Self {
336 Self(ManuallyDrop::new(unsafe { Dispatcher::from_raw(handle) }), PhantomData)
338 }
339
340 pub fn from_async_dispatcher(dispatcher: AsyncDispatcherRef<'a>) -> Self {
347 let handle = NonNull::new(unsafe {
348 fdf_dispatcher_downcast_async_dispatcher(dispatcher.inner().as_ptr())
349 })
350 .unwrap();
351 unsafe { Self::from_raw(handle) }
352 }
353}
354
355impl<'a> AsyncDispatcher for DispatcherRef<'a> {
356 fn as_async_dispatcher_ref(&self) -> AsyncDispatcherRef<'_> {
357 self.0.as_async_dispatcher_ref()
358 }
359}
360
361impl<'a> Clone for DispatcherRef<'a> {
362 fn clone(&self) -> Self {
363 Self(ManuallyDrop::new(Dispatcher(self.0.0)), PhantomData)
364 }
365}
366
367impl<'a> core::ops::Deref for DispatcherRef<'a> {
368 type Target = Dispatcher;
369 fn deref(&self) -> &Self::Target {
370 &self.0
371 }
372}
373
374impl<'a> core::ops::DerefMut for DispatcherRef<'a> {
375 fn deref_mut(&mut self) -> &mut Self::Target {
376 &mut self.0
377 }
378}
379
380impl<'a> OnDispatcher for DispatcherRef<'a> {
381 fn on_dispatcher<R>(&self, f: impl FnOnce(Option<AsyncDispatcherRef<'_>>) -> R) -> R {
382 f(Some(self.as_async_dispatcher_ref()))
383 }
384}
385
/// Zero-sized handle standing for "whatever dispatcher is current on the calling thread".
///
/// Its `OnDispatcher` impl resolves the dispatcher at call time, preferring the
/// thread-local `OVERRIDE_DISPATCHER` over the one the runtime reports as current.
#[derive(Clone, Copy, Debug, Default, PartialEq)]
pub struct CurrentDispatcher;
390
391impl OnDispatcher for CurrentDispatcher {
392 fn on_dispatcher<R>(&self, f: impl FnOnce(Option<AsyncDispatcherRef<'_>>) -> R) -> R {
393 let dispatcher = OVERRIDE_DISPATCHER
394 .with(|global| *global.borrow())
395 .or_else(|| {
396 NonNull::new(unsafe { fdf_dispatcher_get_current_dispatcher() })
398 })
399 .map(|dispatcher| {
400 let async_dispatcher = NonNull::new(unsafe {
406 fdf_dispatcher_get_async_dispatcher(dispatcher.as_ptr())
407 })
408 .expect("No async dispatcher on driver dispatcher");
409 unsafe { AsyncDispatcherRef::from_raw(async_dispatcher) }
410 });
411 f(dispatcher)
412 }
413}
414
#[cfg(test)]
mod tests {
    use super::*;

    use std::sync::{Arc, Once, Weak, mpsc};

    use futures::channel::mpsc as async_mpsc;
    use futures::{SinkExt, StreamExt};
    use zx::sys::ZX_OK;

    use core::ffi::{c_char, c_void};
    use core::ptr::null_mut;

    /// Ensures `fdf_env_start` is called at most once per test process.
    static GLOBAL_DRIVER_ENV: Once = Once::new();

    /// Starts the driver runtime environment the first time it is called; later calls are
    /// no-ops.
    pub fn ensure_driver_env() {
        GLOBAL_DRIVER_ENV.call_once(|| {
            let status = unsafe { fdf_env_start(0) };
            assert_eq!(status, ZX_OK);
        });
    }

    /// Runs `p` against a fresh dispatcher created with the `ALLOW_THREAD_BLOCKING` option.
    /// See [`with_raw_dispatcher_flags`].
    pub fn with_raw_dispatcher<T>(name: &str, p: impl for<'a> FnOnce(Weak<Dispatcher>) -> T) -> T {
        with_raw_dispatcher_flags(name, DispatcherBuilder::ALLOW_THREAD_BLOCKING, p)
    }

    /// Creates a dispatcher with the given option `flags`, hands a `Weak` to it to `p`, then
    /// drops the dispatcher and waits for its shutdown observer to fire before returning
    /// `p`'s result.
    pub(crate) fn with_raw_dispatcher_flags<T>(
        name: &str,
        flags: u32,
        p: impl for<'a> FnOnce(Weak<Dispatcher>) -> T,
    ) -> T {
        ensure_driver_env();

        let (shutdown_tx, shutdown_rx) = mpsc::channel();
        let mut raw_dispatcher = null_mut();
        // The observer checks that nothing is left queued and then signals the channel.
        let mut observer = ShutdownObserver::new(move |dispatcher| {
            assert!(!unsafe { fdf_env_dispatcher_has_queued_tasks(dispatcher.0.0.as_ptr()) });
            shutdown_tx.send(()).unwrap();
        })
        .into_ptr();
        // The address of the local `observer` variable is passed as the owner pointer.
        let driver_ptr = &mut observer as *mut _ as *mut c_void;
        let status = unsafe {
            fdf_env_dispatcher_create_with_owner(
                driver_ptr,
                flags,
                name.as_ptr() as *const c_char,
                name.len(),
                "".as_ptr() as *const c_char,
                0_usize,
                observer,
                &mut raw_dispatcher,
            )
        };
        assert_eq!(status, ZX_OK);
        let owned = Arc::new(Dispatcher(NonNull::new(raw_dispatcher).unwrap()));

        let output = p(Arc::downgrade(&owned));

        // Keep a weak handle so we can verify afterwards that `p` did not retain a strong one.
        let leak_probe = Arc::downgrade(&owned);
        drop(owned);
        shutdown_rx.recv().unwrap();
        assert_eq!(
            0,
            leak_probe.strong_count(),
            "a dispatcher reference escaped the test body"
        );

        output
    }

    /// Smoke test: a dispatcher can be created and torn down.
    #[test]
    fn start_test_dispatcher() {
        with_raw_dispatcher("testing", |dispatcher| {
            println!("hello {dispatcher:?}");
        })
    }

    /// A task posted to the dispatcher runs and receives an OK status.
    #[test]
    fn post_task_on_dispatcher() {
        with_raw_dispatcher("testing task", |dispatcher| {
            let (tx, rx) = mpsc::channel();
            let dispatcher = Weak::upgrade(&dispatcher).unwrap();
            let posted = dispatcher.post_task_sync(move |status| {
                assert_eq!(status, Status::from_raw(ZX_OK));
                tx.send(status).unwrap();
            });
            posted.unwrap();
            let observed = rx.recv().unwrap();
            assert_eq!(observed, Status::from_raw(ZX_OK));
        });
    }

    /// A dispatcher created from inside a task can itself run tasks, and its shutdown
    /// observer fires once it has been dropped.
    #[test]
    fn post_task_on_subdispatcher() {
        let (shutdown_tx, shutdown_rx) = mpsc::channel();
        with_raw_dispatcher("testing task top level", move |dispatcher| {
            let (tx, rx) = mpsc::channel();
            let (inner_tx, inner_rx) = mpsc::channel();
            let dispatcher = Weak::upgrade(&dispatcher).unwrap();
            let posted = dispatcher.post_task_sync(move |status| {
                assert_eq!(status, Status::from_raw(ZX_OK));
                let builder = DispatcherBuilder::new()
                    .name("testing task second level")
                    .scheduler_role("")
                    .allow_thread_blocking()
                    .shutdown_observer(move |_dispatcher| {
                        println!("shutdown observer called");
                        shutdown_tx.send(1).unwrap();
                    });
                let inner = builder.create().unwrap();
                let inner_posted = inner.post_task_sync(move |status| {
                    assert_eq!(status, Status::from_raw(ZX_OK));
                    tx.send(status).unwrap();
                });
                inner_posted.unwrap();
                // Send the inner dispatcher to the test body instead of dropping it inside
                // the task; the body drops it after the inner task has reported back.
                inner_tx.send(inner).unwrap();
            });
            posted.unwrap();
            assert_eq!(rx.recv().unwrap(), Status::from_raw(ZX_OK));
            inner_rx.recv().unwrap();
        });
        assert_eq!(shutdown_rx.recv().unwrap(), 1);
    }

    /// Sends `0`, then answers every received value with that value plus one until the
    /// incoming channel closes.
    async fn ping(mut tx: async_mpsc::Sender<u8>, mut rx: async_mpsc::Receiver<u8>) {
        println!("starting ping!");
        tx.send(0).await.unwrap();
        loop {
            let Some(next) = rx.next().await else {
                break;
            };
            println!("ping! {next}");
            tx.send(next + 1).await.unwrap();
        }
    }

    /// Answers every received value with that value plus one until it sees a value above
    /// 10 (or the channel closes), then signals `fin_tx`.
    async fn pong(
        fin_tx: std::sync::mpsc::Sender<()>,
        mut tx: async_mpsc::Sender<u8>,
        mut rx: async_mpsc::Receiver<u8>,
    ) {
        println!("starting pong!");
        loop {
            let Some(next) = rx.next().await else {
                break;
            };
            println!("pong! {next}");
            if next > 10 {
                println!("bye!");
                break;
            }
            tx.send(next + 1).await.unwrap();
        }
        fin_tx.send(()).unwrap();
    }

    /// Both halves of the ping-pong exchange run as tasks spawned on the dispatcher.
    #[test]
    fn async_ping_pong() {
        with_raw_dispatcher("async ping pong", |dispatcher| {
            let (fin_tx, fin_rx) = mpsc::channel();
            let (ping_tx, pong_rx) = async_mpsc::channel(10);
            let (pong_tx, ping_rx) = async_mpsc::channel(10);
            let ping_task = ping(ping_tx, ping_rx);
            dispatcher.spawn(ping_task).unwrap();
            let pong_task = pong(fin_tx, pong_tx, pong_rx);
            dispatcher.spawn(pong_task).unwrap();

            fin_rx.recv().expect("to receive final value");
        });
    }

    /// Like [`pong`], but waits on a one-second `fuchsia_async` timer after each received
    /// value before reacting to it.
    async fn slow_pong(
        fin_tx: std::sync::mpsc::Sender<()>,
        mut tx: async_mpsc::Sender<u8>,
        mut rx: async_mpsc::Receiver<u8>,
    ) {
        use zx::MonotonicDuration;
        println!("starting pong!");
        loop {
            let Some(next) = rx.next().await else {
                break;
            };
            println!("pong! {next}");
            let deadline =
                fuchsia_async::MonotonicInstant::after(MonotonicDuration::from_seconds(1));
            fuchsia_async::Timer::new(deadline).await;
            if next > 10 {
                println!("bye!");
                break;
            }
            tx.send(next + 1).await.unwrap();
        }
        fin_tx.send(()).unwrap();
    }

    /// `ping` runs on the driver dispatcher while `slow_pong` runs on a `fuchsia_async`
    /// local executor on the test thread.
    #[test]
    fn mixed_executor_async_ping_pong() {
        with_raw_dispatcher("async ping pong", |dispatcher| {
            let (fin_tx, fin_rx) = mpsc::channel();
            let (ping_tx, pong_rx) = async_mpsc::channel(10);
            let (pong_tx, ping_rx) = async_mpsc::channel(10);

            let ping_task = ping(ping_tx, ping_rx);
            dispatcher.spawn(ping_task).unwrap();

            let mut executor = fuchsia_async::LocalExecutor::default();
            let pong_task = slow_pong(fin_tx, pong_tx, pong_rx);
            executor.run_singlethreaded(pong_task);

            fin_rx.recv().expect("to receive final value");
        });
    }
}