Skip to main content

fdf_core/
dispatcher.rs

1// Copyright 2024 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5//! Safe bindings for the driver runtime dispatcher stable ABI
6
7use fdf_sys::*;
8
9use core::ffi;
10use core::marker::PhantomData;
11use core::mem::ManuallyDrop;
12use core::ptr::{NonNull, null_mut};
13use std::sync::atomic::{AtomicPtr, Ordering};
14use std::sync::{Arc, Weak};
15
16use zx::Status;
17
18use crate::shutdown_observer::ShutdownObserver;
19
20pub use fdf_sys::fdf_dispatcher_t;
21pub use libasync::{
22    AfterDeadline, AsyncDispatcher, AsyncDispatcherRef, CurrentDispatcher, JoinHandle,
23    OnDispatcher, Task,
24};
25
26/// A marker trait for a function type that can be used as a shutdown observer for [`Dispatcher`].
27pub trait ShutdownObserverFn: FnOnce(DispatcherRef<'_>) + Send + 'static {}
28impl<T> ShutdownObserverFn for T where T: FnOnce(DispatcherRef<'_>) + Send + 'static {}
29
/// A builder for [`Dispatcher`]s
///
/// Configure it with the chained setter methods and finish with
/// [`DispatcherBuilder::create`] or [`DispatcherBuilder::create_released`].
#[derive(Default)]
pub struct DispatcherBuilder {
    // Bitwise OR of the `FDF_DISPATCHER_OPTION_*` flags selected so far.
    #[doc(hidden)]
    pub options: u32,
    // Name passed (pointer + length) to `fdf_dispatcher_create`.
    #[doc(hidden)]
    pub name: String,
    // Scheduler role passed (pointer + length) to `fdf_dispatcher_create`.
    #[doc(hidden)]
    pub scheduler_role: String,
    // Optional shutdown callback; `create` substitutes a no-op closure when this is `None`.
    #[doc(hidden)]
    pub shutdown_observer: Option<Box<dyn ShutdownObserverFn>>,
}
42
43impl DispatcherBuilder {
44    /// See `FDF_DISPATCHER_OPTION_UNSYNCHRONIZED` in the C API
45    pub(crate) const UNSYNCHRONIZED: u32 = fdf_sys::FDF_DISPATCHER_OPTION_UNSYNCHRONIZED;
46    /// See `FDF_DISPATCHER_OPTION_ALLOW_SYNC_CALLS` in the C API
47    pub(crate) const ALLOW_THREAD_BLOCKING: u32 = fdf_sys::FDF_DISPATCHER_OPTION_ALLOW_SYNC_CALLS;
48    /// See `FDF_DISPATCHER_OPTION_NO_THREAD_MIGRATION` in the C API
49    pub(crate) const NO_THREAD_MIGRATION: u32 = fdf_sys::FDF_DISPATCHER_OPTION_NO_THREAD_MIGRATION;
50
51    /// Creates a new [`DispatcherBuilder`] that can be used to configure a new dispatcher.
52    /// For more information on the threading-related flags for the dispatcher, see
53    /// https://fuchsia.dev/fuchsia-src/concepts/drivers/driver-dispatcher-and-threads
54    pub fn new() -> Self {
55        Self::default()
56    }
57
58    /// Sets whether parallel callbacks in the callbacks set in the dispatcher are allowed. May
59    /// not be set with [`Self::allow_thread_blocking`].
60    ///
61    /// See https://fuchsia.dev/fuchsia-src/concepts/drivers/driver-dispatcher-and-threads
62    /// for more information on the threading model of driver dispatchers.
63    pub fn unsynchronized(mut self) -> Self {
64        assert!(
65            !self.allows_thread_blocking(),
66            "you may not create an unsynchronized dispatcher that allows synchronous calls"
67        );
68        self.options |= Self::UNSYNCHRONIZED;
69        self
70    }
71
72    /// Whether or not this is an unsynchronized dispatcher
73    pub fn is_unsynchronized(&self) -> bool {
74        (self.options & Self::UNSYNCHRONIZED) == Self::UNSYNCHRONIZED
75    }
76
77    /// This dispatcher may not share zircon threads with other drivers. May not be set with
78    /// [`Self::unsynchronized`].
79    ///
80    /// See https://fuchsia.dev/fuchsia-src/concepts/drivers/driver-dispatcher-and-threads
81    /// for more information on the threading model of driver dispatchers.
82    pub fn allow_thread_blocking(mut self) -> Self {
83        assert!(
84            !self.is_unsynchronized(),
85            "you may not create an unsynchronized dispatcher that allows synchronous calls"
86        );
87        self.options |= Self::ALLOW_THREAD_BLOCKING;
88        self
89    }
90
91    /// Whether or not this dispatcher allows synchronous calls
92    pub fn allows_thread_blocking(&self) -> bool {
93        (self.options & Self::ALLOW_THREAD_BLOCKING) == Self::ALLOW_THREAD_BLOCKING
94    }
95
96    /// This dispatcher may not run on more than one thread. This can only be set if the
97    /// dispatcher is being run on a scheduler role that does not allow sync calls on
98    /// any of its dispatchers.
99    ///
100    /// See https://fuchsia.dev/fuchsia-src/concepts/drivers/driver-dispatcher-and-threads
101    /// for more information on the threading model of driver dispatchers.
102    pub fn no_thread_migration(mut self) -> Self {
103        self.options |= Self::NO_THREAD_MIGRATION;
104        self
105    }
106
107    /// Whether or not this dispatcher is allowed to run on multiple threads
108    pub fn allows_thread_migration(&self) -> bool {
109        (self.options & Self::NO_THREAD_MIGRATION) == 0
110    }
111
112    /// A descriptive name for this dispatcher that is used in debug output and process
113    /// lists.
114    pub fn name(mut self, name: &str) -> Self {
115        self.name = name.to_string();
116        self
117    }
118
119    /// A hint string for the runtime that may or may not impact the priority the work scheduled
120    /// by this dispatcher is handled at. It may or may not impact the ability for other drivers
121    /// to share zircon threads with the dispatcher.
122    pub fn scheduler_role(mut self, role: &str) -> Self {
123        self.scheduler_role = role.to_string();
124        self
125    }
126
127    /// A callback to be called before after the dispatcher has completed asynchronous shutdown.
128    pub fn shutdown_observer<F: ShutdownObserverFn>(mut self, shutdown_observer: F) -> Self {
129        self.shutdown_observer = Some(Box::new(shutdown_observer));
130        self
131    }
132
133    /// Create the dispatcher as configured by this object. This must be called from a
134    /// thread managed by the driver runtime. The dispatcher returned is owned by the caller,
135    /// and will initiate asynchronous shutdown when the object is dropped unless
136    /// [`Dispatcher::release`] is called on it to convert it into an unowned [`DispatcherRef`].
137    pub fn create(self) -> Result<Dispatcher, Status> {
138        let mut out_dispatcher = null_mut();
139        let options = self.options;
140        let name = self.name.as_ptr() as *mut ffi::c_char;
141        let name_len = self.name.len();
142        let scheduler_role = self.scheduler_role.as_ptr() as *mut ffi::c_char;
143        let scheduler_role_len = self.scheduler_role.len();
144        let observer =
145            ShutdownObserver::new(self.shutdown_observer.unwrap_or_else(|| Box::new(|_| {})))
146                .into_ptr();
147        // SAFETY: all arguments point to memory that will be available for the duration
148        // of the call, except `observer`, which will be available until it is unallocated
149        // by the dispatcher exit handler.
150        Status::ok(unsafe {
151            fdf_dispatcher_create(
152                options,
153                name,
154                name_len,
155                scheduler_role,
156                scheduler_role_len,
157                observer,
158                &mut out_dispatcher,
159            )
160        })?;
161        // SAFETY: `out_dispatcher` is valid by construction if `fdf_dispatcher_create` returns
162        // ZX_OK.
163        Ok(Dispatcher(unsafe { NonNull::new_unchecked(out_dispatcher) }))
164    }
165
166    /// As with [`Self::create`], this creates a new dispatcher as configured by this object, but
167    /// instead of returning an owned reference it immediately releases the reference to be
168    /// managed by the driver runtime.
169    pub fn create_released(self) -> Result<DispatcherRef<'static>, Status> {
170        self.create().map(Dispatcher::release)
171    }
172}
173
/// An owned handle for a dispatcher managed by the driver runtime.
///
/// Dropping the handle calls `fdf_dispatcher_shutdown_async` on the wrapped pointer (see the
/// `Drop` impl in this file).
#[derive(Debug)]
pub struct Dispatcher(pub(crate) NonNull<fdf_dispatcher_t>);

// SAFETY: The api of fdf_dispatcher_t is thread safe.
unsafe impl Send for Dispatcher {}
// SAFETY: Same justification as for `Send` above.
unsafe impl Sync for Dispatcher {}
181
182impl Dispatcher {
183    /// Creates a dispatcher ref from a raw handle.
184    ///
185    /// # Safety
186    ///
187    /// Caller is responsible for ensuring that the given handle is valid and
188    /// not owned by any other wrapper that will free it at an arbitrary
189    /// time.
190    pub unsafe fn from_raw(handle: NonNull<fdf_dispatcher_t>) -> Self {
191        Self(handle)
192    }
193
194    fn get_raw_flags(&self) -> u32 {
195        // SAFETY: the inner fdf_dispatcher_t is valid by construction
196        unsafe { fdf_dispatcher_get_options(self.0.as_ptr()) }
197    }
198
199    /// Whether this dispatcher's tasks and futures can run on multiple threads at the same time.
200    pub fn is_unsynchronized(&self) -> bool {
201        (self.get_raw_flags() & DispatcherBuilder::UNSYNCHRONIZED) != 0
202    }
203
204    /// Whether this dispatcher is allowed to call blocking functions or not
205    pub fn allows_thread_blocking(&self) -> bool {
206        (self.get_raw_flags() & DispatcherBuilder::ALLOW_THREAD_BLOCKING) != 0
207    }
208
209    /// Whether this dispatcher is allowed to migrate threads, in which case it can't
210    /// be used for non-[`Send`] tasks.
211    pub fn allows_thread_migration(&self) -> bool {
212        (self.get_raw_flags() & DispatcherBuilder::NO_THREAD_MIGRATION) == 0
213    }
214
215    /// Whether this is the dispatcher the current thread is running on
216    pub fn is_current_dispatcher(&self) -> bool {
217        // SAFETY: we don't do anything with the dispatcher pointer, and NULL is returned if this
218        // isn't a dispatcher-managed thread.
219        self.0.as_ptr() == unsafe { fdf_dispatcher_get_current_dispatcher() }
220    }
221
222    /// Releases ownership over this dispatcher and returns a [`DispatcherRef`]
223    /// that can be used to access it. The lifetime of this reference is static because it will
224    /// exist so long as this current driver is loaded, but the driver runtime will shut it down
225    /// when the driver is unloaded.
226    pub fn release(self) -> DispatcherRef<'static> {
227        DispatcherRef(ManuallyDrop::new(self), PhantomData)
228    }
229
230    /// Returns a [`DispatcherRef`] that references this dispatcher with a lifetime constrained by
231    /// `self`.
232    pub fn as_dispatcher_ref(&self) -> DispatcherRef<'_> {
233        DispatcherRef(ManuallyDrop::new(Dispatcher(self.0)), PhantomData)
234    }
235}
236
237impl AsyncDispatcher for Dispatcher {
238    fn as_async_dispatcher_ref(&self) -> AsyncDispatcherRef<'_> {
239        let async_dispatcher =
240            NonNull::new(unsafe { fdf_dispatcher_get_async_dispatcher(self.0.as_ptr()) })
241                .expect("No async dispatcher on driver dispatcher");
242        unsafe { AsyncDispatcherRef::from_raw(async_dispatcher) }
243    }
244}
245
246impl Drop for Dispatcher {
247    fn drop(&mut self) {
248        // SAFETY: we only ever provide an owned `Dispatcher` to one owner, so when
249        // that one is dropped we can invoke the shutdown of the dispatcher
250        unsafe { fdf_dispatcher_shutdown_async(self.0.as_mut()) }
251    }
252}
253
/// An owned reference to a driver runtime dispatcher that auto-releases when dropped. This gives
/// you the best of both worlds of having an `Arc<Dispatcher>` and a `DispatcherRef<'static>`
/// created by [`Dispatcher::release`]:
///
/// - You can vend [`Weak`]-like pointers to it that will not cause memory access errors if used
///   after the dispatcher has shut down, like an [`Arc`].
/// - You can tie its terminal lifetime to that of the driver itself.
///
/// This is particularly useful in tests.
// The shared slot holds the raw dispatcher pointer; the `Drop` impl in this file stores null
// into it so that `WeakDispatcher`s observe that the owner is gone.
#[derive(Debug)]
pub struct AutoReleaseDispatcher(Arc<AtomicPtr<fdf_dispatcher>>);
265
266impl AutoReleaseDispatcher {
267    /// Returns a weakened reference to this dispatcher. This weak reference will only be valid so
268    /// long as the [`AutoReleaseDispatcher`] object that spawned it is alive, after which it will
269    /// no longer be usable to spawn tasks on.
270    pub fn downgrade(&self) -> WeakDispatcher {
271        WeakDispatcher::from(self)
272    }
273}
274
275impl AsyncDispatcher for AutoReleaseDispatcher {
276    fn as_async_dispatcher_ref(&self) -> AsyncDispatcherRef<'_> {
277        let dispatcher = NonNull::new(self.0.load(Ordering::Relaxed))
278            .expect("tried to obtain async dispatcher after drop");
279        // SAFETY: the validity of this dispatcher is ensured by use of NonNull above and this
280        // object's exclusive ownership over the dispatcher while it's alive.
281        unsafe {
282            AsyncDispatcherRef::from_raw(
283                NonNull::new(fdf_dispatcher_get_async_dispatcher(dispatcher.as_ptr())).unwrap(),
284            )
285        }
286    }
287}
288
289impl From<Dispatcher> for AutoReleaseDispatcher {
290    fn from(dispatcher: Dispatcher) -> Self {
291        let dispatcher_ptr = dispatcher.release().0.0.as_ptr();
292        Self(Arc::new(AtomicPtr::new(dispatcher_ptr)))
293    }
294}
295
296impl Drop for AutoReleaseDispatcher {
297    fn drop(&mut self) {
298        // Store nullptr into the atomic so that any future attempts to obtain a strong reference
299        // through a WeakDispatcher will not successfully upgrade.
300        self.0.store(null_mut(), Ordering::Relaxed);
301        // We want to allow for any outstanding `on_dispatcher` calls to finish before returning
302        // from drop, so we're going to loop until the strong reference count goes down to zero,
303        // after which any future attempts to call `on_dispatcher` on a `WeakDispatcher` will fail.
304        while Arc::strong_count(&self.0) > 1 {
305            // This sleep is kind of gross, but it should happen extremely rarely and
306            // `on_dispatcher` calls should not be performing any blocking work.
307            std::thread::sleep(std::time::Duration::from_nanos(100))
308        }
309    }
310}
311
/// An unowned but reference counted reference to a dispatcher. This would usually come from
/// an [`AutoReleaseDispatcher`] reference to a dispatcher.
///
/// The advantage to using this instead of using [`Weak`] directly is that it controls the lifetime
/// of any given strong reference to the dispatcher, since the only way to access that strong
/// reference is through [`OnDispatcher::on_dispatcher`]. This makes it much easier to be sure
/// that you aren't leaving any dangling strong references to the dispatcher object around.
// Weak handle onto the pointer slot owned by an `AutoReleaseDispatcher`.
#[derive(Clone, Debug)]
pub struct WeakDispatcher(Weak<AtomicPtr<fdf_dispatcher>>);
321
322impl From<&AutoReleaseDispatcher> for WeakDispatcher {
323    fn from(value: &AutoReleaseDispatcher) -> Self {
324        Self(Arc::downgrade(&value.0))
325    }
326}
327
328impl OnDispatcher for WeakDispatcher {
329    fn on_dispatcher<R>(&self, f: impl FnOnce(Option<AsyncDispatcherRef<'_>>) -> R) -> R {
330        let Some(dispatcher_ptr) = self.0.upgrade() else {
331            return f(None);
332        };
333        let Some(dispatcher) = NonNull::new(dispatcher_ptr.load(Ordering::Relaxed)) else {
334            return f(None);
335        };
336        // SAFETY: As long as we hold the strong reference in dispatcher_ptr, the
337        // AutoReleaseDispatcher will not allow its drop to finish and the dispatcher should still
338        // be valid.
339        f(Some(unsafe { DispatcherRef::from_raw(dispatcher) }.as_async_dispatcher_ref()))
340    }
341}
342
// The empty body means `WeakDispatcher` uses the default `spawn_local` / `compute_local`
// provided by `OnDriverDispatcher`.
impl OnDriverDispatcher for WeakDispatcher {}
344
/// An unowned reference to a driver runtime dispatcher such as is produced by calling
/// [`Dispatcher::release`]. When this object goes out of scope it won't shut down the dispatcher,
/// leaving that up to the driver runtime or another owner.
// The `ManuallyDrop` wrapper keeps `Drop for Dispatcher` from running for this alias; the
// `PhantomData` only carries the lifetime `'a`.
#[derive(Debug)]
pub struct DispatcherRef<'a>(ManuallyDrop<Dispatcher>, PhantomData<&'a Dispatcher>);
350
351impl<'a> DispatcherRef<'a> {
352    /// Creates a dispatcher ref from a raw handle.
353    ///
354    /// # Safety
355    ///
356    /// Caller is responsible for ensuring that the given handle is valid for
357    /// the lifetime `'a`.
358    pub unsafe fn from_raw(handle: NonNull<fdf_dispatcher_t>) -> Self {
359        // SAFETY: Caller promises the handle is valid.
360        Self(ManuallyDrop::new(unsafe { Dispatcher::from_raw(handle) }), PhantomData)
361    }
362
363    /// Creates a dispatcher ref from an [`AsyncDispatcherRef`].
364    ///
365    /// # Panics
366    ///
367    /// Note that this will cause an assert if the [`AsyncDispatcherRef`] was not created from a
368    /// driver dispatcher in the first place.
369    pub fn from_async_dispatcher(dispatcher: AsyncDispatcherRef<'a>) -> Self {
370        let handle = NonNull::new(unsafe {
371            fdf_dispatcher_downcast_async_dispatcher(dispatcher.inner().as_ptr())
372        })
373        .unwrap();
374        unsafe { Self::from_raw(handle) }
375    }
376
377    /// Creates an [`AsyncDispatcherRef`] to this dispatcher with the same lifetime.
378    pub fn as_async_dispatcher(&self) -> AsyncDispatcherRef<'a> {
379        // SAFETY: The dispatcher referenced is valid by construction, and it should always
380        // be possible to get an async dispatcher from a valid driver dispatcher.
381        let handle = unsafe { fdf_dispatcher_get_async_dispatcher(self.0.0.as_ptr()) };
382        unsafe { AsyncDispatcherRef::from_raw(NonNull::new(handle).unwrap()) }
383    }
384
385    /// Gets the raw handle from this dispatcher ref.
386    ///
387    /// # Safety
388    ///
389    /// Caller is responsible for ensuring that the dispatcher handle is used safely.
390    pub unsafe fn as_raw(&mut self) -> *mut fdf_dispatcher_t {
391        unsafe { self.0.0.as_mut() }
392    }
393}
394
/// Wrapper that presents a non-[`Send`] future as [`Send`] once we have dynamically checked
/// that the dispatcher it is going to be spawned on is safe for non-[`Send`] work.
///
/// Only construct this after validating that the dispatcher is the currently running one and
/// that the dispatcher does not migrate threads.
///
/// This is an internal implementation detail and should never be made public.
struct AddSendFuture<T>(T);

impl<T: Future> Future for AddSendFuture<T> {
    type Output = T::Output;

    /// Forwards the poll to the wrapped future.
    fn poll(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Self::Output> {
        // SAFETY: the wrapped future is a field of a pinned value and is never moved out of
        // it, so it is pinned whenever `self` is.
        let inner = unsafe { std::pin::Pin::new_unchecked(&mut self.get_unchecked_mut().0) };
        inner.poll(cx)
    }
}

// SAFETY: We are forcing this future to be [`Send`] even though the inner future is not because
// we validate at runtime before spawning the task that the dispatcher is correctly configured to
// do the right thing with it.
unsafe impl<T> Send for AddSendFuture<T> {}
421
422/// Makes available additional functionality available on driver dispatchers on top of what's
423/// available on [`OnDispatcher`].
424pub trait OnDriverDispatcher: OnDispatcher {
425    /// Spawn an asynchronous local task on this dispatcher. If this returns [`Ok`] then the task
426    /// has successfully been scheduled and will run or be cancelled and dropped when the dispatcher
427    /// shuts down. The returned future's result will be [`Ok`] if the future completed
428    /// successfully, or an [`Err`] if the task did not complete for some reason (like the
429    /// dispatcher shut down).
430    ///
431    /// Unlike [`OnDispatcher::spawn`], this will accept a future that does not implement [`Send`]. If
432    /// called from a thread other than the one the dispatcher is running on or the dispatcher
433    /// is not guaranteed to always poll from the same thread, this will return
434    /// [`Status::BAD_STATE`].
435    ///
436    /// Returns a [`JoinHandle`] that will detach the future when dropped.
437    fn spawn_local(
438        &self,
439        future: impl Future<Output = ()> + 'static,
440    ) -> Result<JoinHandle<()>, Status>
441    where
442        Self: 'static,
443    {
444        self.on_maybe_dispatcher(|dispatcher| {
445            let dispatcher = DispatcherRef::from_async_dispatcher(dispatcher);
446            if dispatcher.0.is_current_dispatcher() && !dispatcher.0.allows_thread_migration() {
447                OnDispatcher::spawn(self, AddSendFuture(future))
448            } else {
449                Err(Status::BAD_STATE)
450            }
451        })
452    }
453
454    /// Spawn a local asynchronous task that outputs type 'T' on this dispatcher. The returned future's
455    /// result will be [`Ok`] if the task was started and completed successfully, or an [`Err`] if
456    /// the task couldn't be started or failed to complete (for example because the dispatcher was
457    /// shutting down).
458    ///
459    /// Returns a [`Task`] that will cancel the future when dropped.
460    ///
461    /// Unlike [`OnDispatcher::compute`], this will accept a future that does not implement [`Send`]. If
462    /// called from a thread other than the one the dispatcher is running on or the dispatcher
463    /// is not guaranteed to always poll from the same thread, this will return
464    /// [`Status::BAD_STATE`].
465    ///
466    /// TODO(470088116): This may be the cause of some flakes, so care should be used with it
467    /// in critical paths for now.
468    fn compute_local<T: Send + 'static>(
469        &self,
470        future: impl Future<Output = T> + 'static,
471    ) -> Result<Task<T>, Status>
472    where
473        Self: 'static,
474    {
475        self.on_maybe_dispatcher(|dispatcher| {
476            let dispatcher = DispatcherRef::from_async_dispatcher(dispatcher);
477            if dispatcher.0.is_current_dispatcher() && !dispatcher.0.allows_thread_migration() {
478                Ok(OnDispatcher::compute(self, AddSendFuture(future)))
479            } else {
480                Err(Status::BAD_STATE)
481            }
482        })
483    }
484}
485
// The empty bodies mean `Arc<Dispatcher>` and `Weak<Dispatcher>` use the default
// `spawn_local` / `compute_local` provided by `OnDriverDispatcher`.
impl OnDriverDispatcher for Arc<Dispatcher> {}
impl OnDriverDispatcher for Weak<Dispatcher> {}
488
489impl<'a> AsyncDispatcher for DispatcherRef<'a> {
490    fn as_async_dispatcher_ref(&self) -> AsyncDispatcherRef<'_> {
491        self.0.as_async_dispatcher_ref()
492    }
493}
494
495impl<'a> Clone for DispatcherRef<'a> {
496    fn clone(&self) -> Self {
497        Self(ManuallyDrop::new(Dispatcher(self.0.0)), PhantomData)
498    }
499}
500
501impl<'a> core::ops::Deref for DispatcherRef<'a> {
502    type Target = Dispatcher;
503    fn deref(&self) -> &Self::Target {
504        &self.0
505    }
506}
507
508impl<'a> core::ops::DerefMut for DispatcherRef<'a> {
509    fn deref_mut(&mut self) -> &mut Self::Target {
510        &mut self.0
511    }
512}
513
514impl<'a> OnDispatcher for DispatcherRef<'a> {
515    fn on_dispatcher<R>(&self, f: impl FnOnce(Option<AsyncDispatcherRef<'_>>) -> R) -> R {
516        f(Some(self.as_async_dispatcher_ref()))
517    }
518}
519
// The empty bodies mean both types use the default `spawn_local` / `compute_local` provided
// by `OnDriverDispatcher`.
impl<'a> OnDriverDispatcher for DispatcherRef<'a> {}

impl OnDriverDispatcher for CurrentDispatcher {}
523
#[cfg(test)]
mod tests {
    use super::*;

    use std::sync::{Arc, Once, Weak, mpsc};

    use futures::channel::mpsc as async_mpsc;
    use futures::{SinkExt, StreamExt};
    use zx::sys::ZX_OK;

    use core::ffi::{c_char, c_void};
    use core::ptr::null_mut;

    static GLOBAL_DRIVER_ENV: Once = Once::new();
    const NO_SYNC_CALLS_ROLE: &str = "no sync calls role";

    /// Starts the driver runtime environment exactly once per test process and registers
    /// [`NO_SYNC_CALLS_ROLE`] as a scheduler role with the no-sync-calls option.
    pub fn ensure_driver_env() {
        GLOBAL_DRIVER_ENV.call_once(|| {
            let role = NO_SYNC_CALLS_ROLE;
            // SAFETY: calling fdf_env_start, which does not have any soundness
            // concerns for rust code, and this is only used in tests.
            let started = unsafe { fdf_env_start(0) };
            assert_eq!(started, ZX_OK);
            // SAFETY: `role` is a static string, so pointer and length stay valid for the call.
            let role_set = unsafe {
                fdf_env_set_scheduler_role_opts(
                    role.as_ptr() as *const c_char,
                    role.len(),
                    FDF_SCHEDULER_ROLE_OPTION_NO_SYNC_CALLS,
                )
            };
            assert_eq!(role_set, ZX_OK);
        });
    }

    /// Runs `p` against a freshly created dispatcher that allows thread blocking and uses an
    /// empty scheduler role.
    pub fn with_raw_dispatcher<T>(name: &str, p: impl for<'a> FnOnce(Weak<Dispatcher>) -> T) -> T {
        let flags = DispatcherBuilder::ALLOW_THREAD_BLOCKING;
        let scheduler_role = "";
        with_raw_dispatcher_flags(name, flags, scheduler_role, p)
    }

    /// Creates a dispatcher with the given flags and scheduler role, runs `p` with a weak
    /// reference to it, then drops the dispatcher and blocks until its shutdown observer fires.
    pub(crate) fn with_raw_dispatcher_flags<T>(
        name: &str,
        flags: u32,
        scheduler_role: &str,
        p: impl for<'a> FnOnce(Weak<Dispatcher>) -> T,
    ) -> T {
        ensure_driver_env();

        let (shutdown_tx, shutdown_rx) = mpsc::channel();
        let mut raw_dispatcher = null_mut();
        let mut observer = ShutdownObserver::new(move |shut_down| {
            // SAFETY: we verify that the dispatcher has no tasks left queued in it,
            // just because this is testing code.
            let has_queued =
                unsafe { fdf_env_dispatcher_has_queued_tasks(shut_down.0.0.as_ptr()) };
            assert!(!has_queued);
            shutdown_tx.send(()).unwrap();
        })
        .into_ptr();
        // The address of the local `observer` variable is what gets passed as the owner.
        let driver_ptr = &mut observer as *mut _ as *mut c_void;
        // SAFETY: The pointers we pass to this function are all stable for the
        // duration of this function, and are not available to copy or clone to
        // client code (only through a ref to the non-`Clone` `Dispatcher`
        // wrapper).
        let status = unsafe {
            fdf_env_dispatcher_create_with_owner(
                driver_ptr,
                flags,
                name.as_ptr() as *const c_char,
                name.len(),
                scheduler_role.as_ptr() as *const c_char,
                scheduler_role.len(),
                observer,
                &mut raw_dispatcher,
            )
        };
        assert_eq!(status, ZX_OK);
        let owned = Arc::new(Dispatcher(NonNull::new(raw_dispatcher).unwrap()));

        let outcome = p(Arc::downgrade(&owned));

        // Dropping the last strong reference kicks off dispatcher shutdown on a driver runtime
        // thread. Once every task on the dispatcher has completed, the shutdown observer
        // signals `shutdown_rx` and we can finish tearing down.
        let leak_probe = Arc::downgrade(&owned);
        drop(owned);
        shutdown_rx.recv().unwrap();
        assert_eq!(
            0,
            leak_probe.strong_count(),
            "a dispatcher reference escaped the test body"
        );

        outcome
    }

    #[test]
    fn start_test_dispatcher() {
        with_raw_dispatcher("testing", |weak| {
            println!("hello {weak:?}");
        })
    }

    #[test]
    fn post_task_on_dispatcher() {
        with_raw_dispatcher("testing task", |weak| {
            let (tx, rx) = mpsc::channel();
            let strong = weak.upgrade().unwrap();
            let posted = strong.post_task_sync(move |status| {
                assert_eq!(status, Status::from_raw(ZX_OK));
                tx.send(status).unwrap();
            });
            posted.unwrap();
            assert_eq!(rx.recv().unwrap(), Status::from_raw(ZX_OK));
        });
    }

    #[test]
    fn post_task_on_subdispatcher() {
        let (shutdown_tx, shutdown_rx) = mpsc::channel();
        with_raw_dispatcher("testing task top level", move |weak| {
            let (tx, rx) = mpsc::channel();
            let (inner_tx, inner_rx) = mpsc::channel();
            let outer = weak.upgrade().unwrap();
            outer
                .post_task_sync(move |outer_status| {
                    assert_eq!(outer_status, Status::from_raw(ZX_OK));
                    let builder = DispatcherBuilder::new()
                        .name("testing task second level")
                        .scheduler_role("")
                        .allow_thread_blocking()
                        .shutdown_observer(move |_dispatcher| {
                            println!("shutdown observer called");
                            shutdown_tx.send(1).unwrap();
                        });
                    let inner = builder.create().unwrap();
                    inner
                        .post_task_sync(move |inner_status| {
                            assert_eq!(inner_status, Status::from_raw(ZX_OK));
                            tx.send(inner_status).unwrap();
                        })
                        .unwrap();
                    // The inner dispatcher has to outlive its task, so it is handed to the
                    // outer closure instead of being dropped here.
                    inner_tx.send(inner).unwrap();
                })
                .unwrap();
            assert_eq!(rx.recv().unwrap(), Status::from_raw(ZX_OK));
            inner_rx.recv().unwrap();
        });
        assert_eq!(shutdown_rx.recv().unwrap(), 1);
    }

    #[test]
    fn spawn_local_fails_on_normal_dispatcher() {
        let (done_tx, done_rx) = mpsc::channel();
        with_raw_dispatcher("spawn local failures", move |weak| {
            let inside = weak.clone();
            weak.spawn(async move {
                let spawn_err = inside.spawn_local(futures::future::ready(())).unwrap_err();
                assert_eq!(spawn_err, Status::BAD_STATE);
                let compute_err = inside.compute_local(futures::future::ready(())).unwrap_err();
                assert_eq!(compute_err, Status::BAD_STATE);
                done_tx.send(()).unwrap();
            })
            .unwrap();
            done_rx.recv().unwrap();
        });
    }

    #[test]
    #[ignore = "Pending resolution of b/488397193"]
    fn spawn_local_succeeds_on_no_thread_migration_dispatcher() {
        let (tx, rx) = mpsc::channel();
        with_raw_dispatcher_flags(
            "spawn local success",
            FDF_DISPATCHER_OPTION_NO_THREAD_MIGRATION,
            NO_SYNC_CALLS_ROLE,
            move |weak| {
                let inside = weak.clone();
                weak.spawn(async move {
                    let spawn_tx = tx.clone();
                    inside
                        .spawn_local(async move {
                            spawn_tx.send(()).unwrap();
                        })
                        .unwrap();
                    inside
                        .compute_local(async move {
                            tx.send(()).unwrap();
                        })
                        .unwrap()
                        .await
                        .unwrap();
                })
                .unwrap();
                // Two unit messages are expected: one from spawn_local, one from compute_local.
                rx.recv().unwrap();
                rx.recv().unwrap();
            },
        );
    }

    #[test]
    #[ignore = "Pending resolution of b/488397193"]
    fn spawn_local_fails_on_no_thread_migration_dispatcher_from_different_thread() {
        with_raw_dispatcher_flags(
            "spawn local success",
            FDF_DISPATCHER_OPTION_NO_THREAD_MIGRATION,
            NO_SYNC_CALLS_ROLE,
            move |weak| {
                // This closure does not run on any dispatcher, so the 'current dispatcher' is
                // definitely not the one under test.
                let spawn_err = weak.spawn_local(futures::future::ready(())).unwrap_err();
                assert_eq!(spawn_err, Status::BAD_STATE);
                let compute_err = weak.compute_local(futures::future::ready(())).unwrap_err();
                assert_eq!(compute_err, Status::BAD_STATE);
            },
        );
    }

    /// Sends an initial `0`, then answers every received value with that value plus one until
    /// the incoming channel closes.
    async fn ping(mut tx: async_mpsc::Sender<u8>, mut rx: async_mpsc::Receiver<u8>) {
        println!("starting ping!");
        tx.send(0).await.unwrap();
        loop {
            let Some(next) = rx.next().await else {
                break;
            };
            println!("ping! {next}");
            tx.send(next + 1).await.unwrap();
        }
    }

    /// Answers every received value with that value plus one, stops once a value above 10
    /// arrives (or the channel closes), then signals `fin_tx`.
    async fn pong(
        fin_tx: std::sync::mpsc::Sender<()>,
        mut tx: async_mpsc::Sender<u8>,
        mut rx: async_mpsc::Receiver<u8>,
    ) {
        println!("starting pong!");
        loop {
            let Some(next) = rx.next().await else {
                break;
            };
            println!("pong! {next}");
            if next > 10 {
                println!("bye!");
                break;
            }
            tx.send(next + 1).await.unwrap();
        }
        fin_tx.send(()).unwrap();
    }

    #[test]
    fn async_ping_pong() {
        with_raw_dispatcher("async ping pong", |weak| {
            let (fin_tx, fin_rx) = mpsc::channel();
            let (ping_tx, pong_rx) = async_mpsc::channel(10);
            let (pong_tx, ping_rx) = async_mpsc::channel(10);
            weak.spawn(ping(ping_tx, ping_rx)).unwrap();
            weak.spawn(pong(fin_tx, pong_tx, pong_rx)).unwrap();

            fin_rx.recv().expect("to receive final value");
        });
    }

    /// Same protocol as [`pong`], but waits on a one second timer after each received value.
    async fn slow_pong(
        fin_tx: std::sync::mpsc::Sender<()>,
        mut tx: async_mpsc::Sender<u8>,
        mut rx: async_mpsc::Receiver<u8>,
    ) {
        use zx::MonotonicDuration;
        println!("starting pong!");
        loop {
            let Some(next) = rx.next().await else {
                break;
            };
            println!("pong! {next}");
            let deadline =
                fuchsia_async::MonotonicInstant::after(MonotonicDuration::from_seconds(1));
            fuchsia_async::Timer::new(deadline).await;
            if next > 10 {
                println!("bye!");
                break;
            }
            tx.send(next + 1).await.unwrap();
        }
        fin_tx.send(()).unwrap();
    }

    #[test]
    fn mixed_executor_async_ping_pong() {
        with_raw_dispatcher("async ping pong", |weak| {
            let (fin_tx, fin_rx) = mpsc::channel();
            let (ping_tx, pong_rx) = async_mpsc::channel(10);
            let (pong_tx, ping_rx) = async_mpsc::channel(10);

            // ping runs on the driver dispatcher...
            weak.spawn(ping(ping_tx, ping_rx)).unwrap();

            // ...while pong is driven by a fuchsia_async executor on this thread.
            let mut executor = fuchsia_async::LocalExecutor::default();
            executor.run_singlethreaded(slow_pong(fin_tx, pong_tx, pong_rx));

            fin_rx.recv().expect("to receive final value");
        });
    }
}
831}