fdf_channel/
futures.rs

1// Copyright 2025 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5//! Internal helpers for implementing futures against channel objects
6
7use std::mem::ManuallyDrop;
8use std::sync::atomic::AtomicBool;
9use std::task::Waker;
10use zx::Status;
11
12use crate::channel::try_read_raw;
13use crate::message::Message;
14use fdf_core::dispatcher::OnDispatcher;
15use fdf_core::handle::DriverHandle;
16use fdf_sys::*;
17
18use core::mem::MaybeUninit;
19use core::task::{Context, Poll};
20use std::sync::{Arc, Mutex};
21
22pub use fdf_sys::fdf_handle_t;
23
/// State for one channel read wait, shared between the future and the driver runtime.
///
/// The `read_op` field is the part handed to the driver runtime when a wait is registered; the
/// remaining fields are used to coordinate between the runtime's callback and the future. The
/// struct is held by up to two [`Arc`] references: one owned by the future and, while a wait is
/// registered, one set aside for the runtime's callback.
///
/// The future's [`Arc`] will be dropped when the future is either fulfilled or cancelled through
/// normal [`Drop`] of the future.
///
/// The runtime's [`Arc`]'s dropping varies depending on whether the dispatcher it was registered on
/// was synchronized or not, and whether it was cancelled or not. The callback will only ever be
/// called *up to* one time.
///
/// If the dispatcher is synchronized, then the callback will *only* be called on fulfillment of the
/// read wait.
#[repr(C)]
struct ReadMessageStateOp {
    /// This must be at the start of the struct so that `ReadMessageStateOp` can be cast to and from `fdf_channel_read`.
    read_op: fdf_channel_read,
    /// The waker to notify when the wait completes. It is stored when a poll has to wait, and
    /// taken by the callback when it fires or by the future's [`Drop`] after a successful
    /// cancellation.
    waker: Mutex<Option<Waker>>,
    /// Set by the callback when the wait completes with [`Status::CANCELED`]; once set, polling
    /// reports [`Status::UNAVAILABLE`] instead of trying to wait again.
    cancelled: AtomicBool,
}
44
45impl ReadMessageStateOp {
46    unsafe extern "C" fn handler(
47        _dispatcher: *mut fdf_dispatcher,
48        read_op: *mut fdf_channel_read,
49        status: i32,
50    ) {
51        // SAFETY: When setting up the read op, we incremented the refcount of the `Arc` to allow
52        // for this handler to reconstitute it.
53        let op: Arc<Self> = unsafe { Arc::from_raw(read_op.cast()) };
54        if Status::from_raw(status) == Status::CANCELED {
55            op.cancelled.store(true, std::sync::atomic::Ordering::Release);
56        }
57
58        let Some(waker) = op.waker.lock().unwrap().take() else {
59            // the waker was already taken, presumably because the future was dropped.
60            return;
61        };
62        waker.wake()
63    }
64}
65
/// An object for managing the state of an async channel read message operation that can be used to
/// implement futures.
pub struct ReadMessageState {
    /// The state shared with the driver runtime's wait callback.
    op: Arc<ReadMessageStateOp>,
    /// A second handle object for the caller's channel, built from the same raw handle. It is
    /// wrapped in [`ManuallyDrop`] so that dropping this state never drops the handle a second
    /// time.
    channel: ManuallyDrop<DriverHandle>,
    /// Whether the runtime callback is the one that will release the extra [`Arc`] reference
    /// created for the registered wait. When `false`, this object's [`Drop`] releases it after a
    /// successful cancellation.
    callback_drops_arc: bool,
}
73
74impl ReadMessageState {
75    /// Creates a new raw read message state that can be used to implement a [`Future`] that reads
76    /// data from a channel and then converts it to the appropriate type. It also allows for
77    /// different ways of storing and managing the dispatcher we wait on by deferring the
78    /// dispatcher used to poll time.
79    ///
80    /// # Safety
81    ///
82    /// The caller is responsible for ensuring that `channel` outlives this object.
83    pub unsafe fn new(channel: &DriverHandle) -> Self {
84        // SAFETY: The caller is responsible for ensuring that the handle is a correct channel handle
85        // and that the handle will outlive the created [`ReadMessageState`].
86        let channel = unsafe { channel.get_raw() };
87        Self {
88            op: Arc::new(ReadMessageStateOp {
89                read_op: fdf_channel_read {
90                    channel: channel.get(),
91                    handler: Some(ReadMessageStateOp::handler),
92                    ..Default::default()
93                },
94                waker: Mutex::new(None),
95                cancelled: AtomicBool::new(false),
96            }),
97            // SAFETY: We know this is a valid driver handle by construction and we are
98            // storing this handle in a [`ManuallyDrop`] to prevent it from being double-dropped.
99            // The caller is responsible for ensuring that the handle outlives this object.
100            channel: ManuallyDrop::new(unsafe { DriverHandle::new_unchecked(channel) }),
101            // We haven't waited on it yet so we are responsible for dropping the arc for now,
102            // regardless of what kind of dispatcher it's intended to be used with.
103            callback_drops_arc: false,
104        }
105    }
106
107    /// Polls this channel read operation against the given dispatcher.
108    pub fn poll_with_dispatcher<D: OnDispatcher>(
109        self: &mut Self,
110        cx: &mut Context<'_>,
111        dispatcher: D,
112    ) -> Poll<Result<Option<Message<[MaybeUninit<u8>]>>, Status>> {
113        let mut waker_lock = self.op.waker.lock().unwrap();
114
115        if self.op.cancelled.load(std::sync::atomic::Ordering::Relaxed) {
116            // if the dispatcher we were waiting on is shutting down then when we try to go
117            // to wait again we'll get ZX_ERR_UNAVAILABLE anyways, so just short circuit that and
118            // return it right away.
119            return Poll::Ready(Err(Status::UNAVAILABLE));
120        }
121
122        match try_read_raw(&self.channel) {
123            Ok(res) => Poll::Ready(Ok(res)),
124            Err(Status::SHOULD_WAIT) => {
125                // if we haven't yet set a waker, that means we haven't started the wait operation
126                // yet.
127                if waker_lock.replace(cx.waker().clone()).is_none() {
128                    // increment the reference count of the read op to account for the copy that will be given to
129                    // `fdf_channel_wait_async`.
130                    let op = Arc::into_raw(self.op.clone());
131                    let res = dispatcher.on_maybe_dispatcher(|dispatcher| {
132                        // SAFETY: the `ReadMessageStateOp` starts with an `fdf_channel_read` struct and
133                        // has `repr(C)` layout, so is safe to be cast to the latter.
134                        let res = Status::ok(unsafe {
135                            fdf_channel_wait_async(
136                                dispatcher.inner().as_ptr(),
137                                op.cast_mut().cast(),
138                                0,
139                            )
140                        });
141                        // if the dispatcher we're waiting on is unsynchronized, the callback
142                        // will drop the Arc and we need to indicate to our own Drop impl
143                        // that it should not.
144                        let callback_drops_arc = res.is_ok() && dispatcher.is_unsynchronized();
145                        Ok(callback_drops_arc)
146                    });
147
148                    match res {
149                        Ok(callback_drops_arc) => {
150                            self.callback_drops_arc = callback_drops_arc;
151                        }
152                        Err(e) => return Poll::Ready(Err(e)),
153                    }
154                }
155                Poll::Pending
156            }
157            Err(e) => Poll::Ready(Err(e)),
158        }
159    }
160}
161
162impl Drop for ReadMessageState {
163    fn drop(&mut self) {
164        let mut waker_lock = self.op.waker.lock().unwrap();
165        if waker_lock.is_none() {
166            // if there's no waker either the callback has already fired or we never waited on this
167            // future in the first place, so just leave it be.
168            return;
169        }
170
171        // SAFETY: since we hold a lifetimed-reference to the channel object here, the channel must
172        // be valid.
173        let res = Status::ok(unsafe { fdf_channel_cancel_wait(self.channel.get_raw().get()) });
174        match res {
175            Ok(_) => {}
176            Err(Status::NOT_FOUND) => {
177                // the callback is already being called or the wait was already cancelled, so just
178                // return and leave it.
179                return;
180            }
181            Err(e) => panic!("Unexpected error {e:?} cancelling driver channel read wait"),
182        }
183        // steal the waker so it doesn't get called, if there is one.
184        waker_lock.take();
185        // SAFETY: if the channel was waited on by a synchronized dispatcher, and the cancel was
186        // successful, the callback will not be called and we will have to free the `Arc` that the
187        // callback would have consumed.
188        if !self.callback_drops_arc {
189            unsafe { Arc::decrement_strong_count(Arc::as_ptr(&self.op)) };
190        }
191    }
192}
193
// Tests covering the reference counting of the shared read op when a read future is dropped
// before it completes.
#[cfg(test)]
mod test {
    use std::pin::pin;
    use std::sync::Weak;

    use fdf_core::dispatcher::{CurrentDispatcher, OnDispatcher};
    use fdf_env::test::{spawn_in_driver, spawn_in_driver_etc};

    use crate::arena::Arena;
    use crate::channel::{read_raw, Channel};

    use super::*;

    /// Asserts that the number of strong references behind the weak handle `arc` is `count`.
    fn assert_strong_count<T>(arc: &Weak<T>, count: usize) {
        assert_eq!(Weak::strong_count(arc), count, "unexpected strong count on arc");
    }

    /// Create, poll, and then immediately drop a read future for a channel and verify
    /// that the internal op arc has the right refcount at all steps. Returns a weak
    /// reference to the op so the caller can verify that the strong count goes down
    /// to zero correctly.
    async fn read_and_drop<T: ?Sized + 'static, D: OnDispatcher>(
        channel: &Channel<T>,
        dispatcher: D,
    ) -> Weak<ReadMessageStateOp> {
        let fut = read_raw(&channel.0, dispatcher);
        let op_arc = Arc::downgrade(&fut.raw_fut.op);
        // Before the first poll only the future itself holds the op.
        assert_strong_count(&op_arc, 1);
        let mut fut = pin!(fut);
        let Poll::Pending = futures::poll!(fut.as_mut()) else {
            panic!("expected pending state after polling channel read once");
        };
        // Polling registered the wait, which holds a second strong reference.
        assert_strong_count(&op_arc, 2);
        // The pinned future is a local, so it is dropped as this function returns.
        op_arc
    }

    /// A read that was started and then dropped must not disturb a later read on the same
    /// channel.
    #[test]
    fn early_cancel_future() {
        spawn_in_driver("early cancellation", async {
            let (a, b) = Channel::create();

            // create, poll, and then immediately drop a read future for channel `a`
            // so that it properly sets up the wait.
            read_and_drop(&a, CurrentDispatcher).await;
            b.write_with_data(Arena::new(), |arena| arena.insert(1)).unwrap();
            assert_eq!(a.read(CurrentDispatcher).await.unwrap().unwrap().data(), Some(&1));
        })
    }

    /// Dropping a future that was never polled must release the op.
    #[test]
    fn very_early_cancel_state_drops_correctly() {
        spawn_in_driver("early cancellation drop correctness", async {
            let (a, _b) = Channel::<[u8]>::create();

            // drop before even polling it should drop the arc correctly
            let fut = read_raw(&a.0, CurrentDispatcher);
            let op_arc = Arc::downgrade(&fut.raw_fut.op);
            assert_strong_count(&op_arc, 1);
            drop(fut);
            assert_strong_count(&op_arc, 0);
        })
    }

    /// Dropping a polled future must release both references to the op.
    /// NOTE(review): presumably `spawn_in_driver` runs on a synchronized dispatcher, as the
    /// test name suggests -- confirm against `fdf_env::test`.
    #[test]
    fn synchronized_early_cancel_state_drops_correctly() {
        spawn_in_driver("early cancellation drop correctness", async {
            let (a, _b) = Channel::<[u8]>::create();

            assert_strong_count(&read_and_drop(&a, CurrentDispatcher).await, 0);
        });
    }

    /// Same check as the synchronized case, but the strong count is only inspected after
    /// `spawn_in_driver_etc` has returned.
    /// NOTE(review): the `false, true` arguments presumably select an unsynchronized
    /// dispatcher -- confirm against `fdf_env::test::spawn_in_driver_etc`.
    #[test]
    fn unsynchronized_early_cancel_state_drops_correctly() {
        // the channel needs to outlive the dispatcher for this test because the channel shouldn't
        // be closed before the read wait has been cancelled.
        let (a, _b) = Channel::<[u8]>::create();
        let (unsync_op, _a) =
            spawn_in_driver_etc("early cancellation drop correctness", false, true, async move {
                // We send the arc out to be checked after the dispatcher has shut down so
                // that we can be sure that the callback has had a chance to be called.
                // We send the channel back out so that it lives long enough for the
                // cancellation to be called on it.
                let res = read_and_drop(&a, CurrentDispatcher).await;
                (res, a)
            });

        // check that there are no more owners of the inner op for the unsynchronized dispatcher.
        assert_strong_count(&unsync_op, 0);
    }
}