fdf_channel/
futures.rs

1// Copyright 2025 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5//! Internal helpers for implementing futures against channel objects
6
7use std::mem::ManuallyDrop;
8use std::task::Waker;
9use zx::Status;
10
11use crate::channel::{Channel, try_read_raw};
12use crate::message::Message;
13use fdf_core::dispatcher::OnDispatcher;
14use fdf_core::handle::DriverHandle;
15use fdf_sys::*;
16
17use core::mem::MaybeUninit;
18use core::task::{Context, Poll};
19use std::sync::{Arc, Mutex};
20
21pub use fdf_sys::fdf_handle_t;
22
/// Mutable portion of [`ReadMessageStateOp`], only accessed while holding its mutex.
#[derive(Debug, Default)]
struct ReadMessageStateOpLocked {
    /// Waker to notify when the runtime callback fires. `Some` while a registered read wait's
    /// callback is still outstanding; `None` once the callback has taken it or the wait was
    /// cancelled synchronously.
    waker: Option<Waker>,
    /// Set when the owning channel was dropped while a callback was still outstanding. In that
    /// case the callback is responsible for closing the driver handle when it runs.
    channel_dropped: bool,
    /// Whether cancelling the wait completes later, through the callback, rather than
    /// immediately when [`fdf_channel_cancel_wait`] returns. This decides who releases the
    /// runtime's reference to the shared op when a future is dropped mid-wait.
    cancelation_is_async: bool,
}
38
39/// This struct is shared between the future and the driver runtime, with the first field
40/// being managed by the driver runtime and the second by the future. It will be held by two
41/// [`Arc`]s, one for each of the future and the runtime.
42///
43/// The future's [`Arc`] will be dropped when the future is either fulfilled or cancelled through
44/// normal [`Drop`] of the future.
45///
46/// The runtime's [`Arc`]'s dropping varies depending on whether the dispatcher it was registered on
47/// was synchronized or not, and whether it was cancelled or not. The callback will only ever be
48/// called *up to* one time.
49///
50/// If the dispatcher is synchronized, then the callback will *only* be called on fulfillment of the
51/// read wait.
52#[repr(C)]
53#[derive(Debug)]
54pub(crate) struct ReadMessageStateOp {
55    /// This must be at the start of the struct so that `ReadMessageStateOp` can be cast to and from `fdf_channel_read`.
56    read_op: fdf_channel_read,
57    state: Mutex<ReadMessageStateOpLocked>,
58}
59
60impl ReadMessageStateOp {
61    unsafe extern "C" fn handler(
62        _dispatcher: *mut fdf_dispatcher,
63        read_op: *mut fdf_channel_read,
64        _status: i32,
65    ) {
66        // Note: we don't really do anything different based on whether the callback
67        // says canceled. If the future was canceled by being dropped, it won't poll
68        // again since it was dropped.
69        // The only unusual case is when the dispatcher is shutting down, and in that
70        // case we will wake the future and it will try to read and get a more useful
71        // error.
72        // Meanwhile, since we use the same state object across multiple
73        // futures due to needing to handle async cancelation, trying to track the
74        // underlying reason for the cancelation becomes more tricky than it's worth.
75
76        // SAFETY: When setting up the read op, we incremented the refcount of the `Arc` to allow
77        // for this handler to reconstitute it.
78        let op: Arc<Self> = unsafe { Arc::from_raw(read_op.cast()) };
79
80        let mut state = op.state.lock().unwrap();
81        if state.channel_dropped {
82            // SAFETY: since the channel dropped we are the only outstanding owner of the
83            // channel object.
84            unsafe { fdf_handle_close(op.read_op.channel) };
85        }
86        let Some(waker) = state.waker.take() else {
87            // the waker was already taken, presumably because the future was dropped.
88            return;
89        };
90        // make sure to drop the lock before calling the waker.
91        drop(state);
92        waker.wake()
93    }
94
95    /// Called by the channel on drop to indicate that the channel has been dropped and
96    /// find out whether it needs to defer dropping the handle until the callback is called.
97    pub fn set_channel_dropped(&self) -> bool {
98        let mut state = self.state.lock().unwrap();
99        if state.waker.is_some() {
100            state.channel_dropped = true;
101            false
102        } else {
103            true
104        }
105    }
106}
107
108/// An object for managing the state of an async channel read message operation that can be used to
109/// implement futures.
110pub struct ReadMessageState {
111    op: Arc<ReadMessageStateOp>,
112    channel: ManuallyDrop<DriverHandle>,
113}
114
115impl ReadMessageState {
116    /// Creates a new raw read message state that can be used to implement a [`Future`] that reads
117    /// data from a channel and then converts it to the appropriate type. It also allows for
118    /// different ways of storing and managing the dispatcher we wait on by deferring the
119    /// dispatcher used to poll time. This state is registered with the given [`Channel`]
120    /// so that dropping the channel will correctly free resources.
121    ///
122    /// # Safety
123    ///
124    /// The caller is responsible for ensuring that the handle inside `channel` outlives this
125    /// object.
126    pub unsafe fn register_read_wait<T: ?Sized>(channel: &mut Channel<T>) -> Self {
127        // SAFETY: The caller is responsible for ensuring that the handle is a correct channel handle
128        // and that the handle will outlive the created [`ReadMessageState`].
129        let channel_handle = unsafe { channel.handle.get_raw() };
130        let op = channel
131            .wait_state
132            .get_or_insert_with(|| {
133                Arc::new(ReadMessageStateOp {
134                    read_op: fdf_channel_read {
135                        channel: channel_handle.get(),
136                        handler: Some(ReadMessageStateOp::handler),
137                        ..Default::default()
138                    },
139                    state: Mutex::new(ReadMessageStateOpLocked::default()),
140                })
141            })
142            .clone();
143        Self {
144            op,
145            // SAFETY: We know this is a valid driver handle by construction and we are
146            // storing this handle in a [`ManuallyDrop`] to prevent it from being double-dropped.
147            // The caller is responsible for ensuring that the handle outlives this object.
148            channel: ManuallyDrop::new(unsafe { DriverHandle::new_unchecked(channel_handle) }),
149        }
150    }
151
152    /// Polls this channel read operation against the given dispatcher.
153    #[expect(clippy::type_complexity)]
154    pub fn poll_with_dispatcher<D: OnDispatcher>(
155        &mut self,
156        cx: &mut Context<'_>,
157        dispatcher: D,
158    ) -> Poll<Result<Option<Message<[MaybeUninit<u8>]>>, Status>> {
159        let mut state = self.op.state.lock().unwrap();
160
161        match try_read_raw(&self.channel) {
162            Ok(res) => Poll::Ready(Ok(res)),
163            Err(Status::SHOULD_WAIT) => {
164                // if we haven't yet set a waker, that means we haven't started the wait operation
165                // yet.
166                if state.waker.is_none() {
167                    // increment the reference count of the read op to account for the copy that will be given to
168                    // `fdf_channel_wait_async`.
169                    let op = Arc::into_raw(self.op.clone());
170                    let res = dispatcher.on_maybe_dispatcher(|dispatcher| {
171                        // if we're not running on the same dispatcher as we're waiting from, we
172                        // want to force async cancellation
173                        let options = if !dispatcher.is_current_dispatcher() {
174                            FDF_CHANNEL_WAIT_OPTION_FORCE_ASYNC_CANCEL
175                        } else {
176                            0
177                        };
178                        // SAFETY: the `ReadMessageStateOp` starts with an `fdf_channel_read` struct and
179                        // has `repr(C)` layout, so is safe to be cast to the latter.
180                        let res = Status::ok(unsafe {
181                            fdf_channel_wait_async(
182                                dispatcher.inner().as_ptr(),
183                                op.cast_mut().cast(),
184                                options,
185                            )
186                        });
187                        if res.is_ok() {
188                            // only replace the waker if we succeeded, so we'll try again next time
189                            // otherwise.
190                            state.waker.replace(cx.waker().clone());
191                        } else {
192                            // reconstitute the arc we made for the callback so it can be dropped
193                            // since the async wait didn't succeed.
194                            drop(unsafe { Arc::from_raw(op) });
195                        }
196                        // if the dispatcher we're waiting on is unsynchronized, the callback
197                        // will drop the Arc and we need to indicate to our own Drop impl
198                        // that it should not.
199                        res.map(|_| {
200                            options == FDF_CHANNEL_WAIT_OPTION_FORCE_ASYNC_CANCEL
201                                || dispatcher.is_unsynchronized()
202                        })
203                    });
204
205                    // the default state should be that `drop` will free the arc.
206                    state.cancelation_is_async = false;
207                    match res {
208                        Err(Status::BAD_STATE) => {
209                            return Poll::Pending; // a pending await is being cancelled
210                        }
211                        Ok(cancelation_is_async) => {
212                            state.cancelation_is_async = cancelation_is_async;
213                        }
214                        Err(e) => return Poll::Ready(Err(e)),
215                    }
216                }
217                Poll::Pending
218            }
219            Err(e) => Poll::Ready(Err(e)),
220        }
221    }
222}
223
224impl Drop for ReadMessageState {
225    fn drop(&mut self) {
226        let mut state = self.op.state.lock().unwrap();
227        if state.waker.is_none() {
228            // if there's no waker either the callback has already fired or we never waited on this
229            // future in the first place, so just leave it be.
230            return;
231        }
232
233        // SAFETY: since we hold a lifetimed-reference to the channel object here, the channel must
234        // be valid.
235        let res = Status::ok(unsafe { fdf_channel_cancel_wait(self.channel.get_raw().get()) });
236        match res {
237            Ok(_) => {}
238            Err(Status::NOT_FOUND) => {
239                // the callback is already being called or the wait was already cancelled, so just
240                // return and leave it.
241                return;
242            }
243            Err(e) => panic!("Unexpected error {e:?} cancelling driver channel read wait"),
244        }
245        // SAFETY: if the channel was waited on by a synchronized dispatcher, and the cancel was
246        // successful, the callback will not be called and we will have to free the `Arc` that the
247        // callback would have consumed.
248        if !state.cancelation_is_async {
249            // steal the waker so it doesn't get called, if there is one.
250            state.waker.take();
251            unsafe { Arc::decrement_strong_count(Arc::as_ptr(&self.op)) };
252        }
253    }
254}
255
#[cfg(test)]
mod test {
    //! Refcount and cancellation tests for the channel read-wait state.
    //!
    //! The shared op is referenced by:
    //! - the channel's cached `wait_state`;
    //! - each live read future;
    //! - the runtime, for each outstanding wait.
    //!
    //! These tests check that every path releases the right references.

    use std::pin::pin;
    use std::sync::Weak;

    use fdf_core::dispatcher::{CurrentDispatcher, OnDispatcher};
    use fdf_env::test::{spawn_in_driver, spawn_in_driver_etc};

    use crate::arena::Arena;
    use crate::channel::{Channel, read_raw};

    use super::*;

    /// Asserts that the `Arc` behind `arc` currently has exactly `count` strong references.
    #[track_caller]
    fn assert_strong_count<T>(arc: &Weak<T>, count: usize) {
        assert_eq!(Weak::strong_count(arc), count, "unexpected strong count on arc");
    }

    /// Creates a raw read future for `channel`, polls it once, and drops it. Checks the op's
    /// strong count along the way:
    /// - 2 before polling: presumably the channel's cached `wait_state` plus the future's own
    ///   reference;
    /// - 3 after a pending poll: the extra reference is the one handed to the runtime for the
    ///   registered wait.
    ///
    /// Returns a [`Weak`] to the op so callers can check how the count settles once the future
    /// (dropped at the end of this function) is gone.
    async fn read_and_drop<T: ?Sized + 'static, D: OnDispatcher>(
        channel: &mut Channel<T>,
        dispatcher: D,
    ) -> Weak<ReadMessageStateOp> {
        // SAFETY: the future is dropped at the end of this function, while `channel` is still
        // borrowed, so the channel's handle outlives it.
        let fut = unsafe { read_raw(channel, dispatcher) };
        let op_arc = Arc::downgrade(&fut.raw_fut.op);
        assert_strong_count(&op_arc, 2);
        let mut fut = pin!(fut);
        let Poll::Pending = futures::poll!(fut.as_mut()) else {
            panic!("expected pending state after polling channel read once");
        };
        assert_strong_count(&op_arc, 3);
        op_arc
    }

    /// A read wait that was set up and then cancelled must not prevent a later read on the same
    /// channel from receiving a message.
    #[test]
    fn early_cancel_future() {
        spawn_in_driver("early cancellation", async {
            let (mut a, b) = Channel::create();

            // create, poll, and then immediately drop a read future for channel `a`
            // so that it properly sets up the wait.
            read_and_drop(&mut a, CurrentDispatcher).await;
            b.write_with_data(Arena::new(), |arena| arena.insert(1)).unwrap();
            assert_eq!(a.read(CurrentDispatcher).await.unwrap().unwrap().data(), Some(&1));
        })
    }

    /// Dropping a future that was never polled releases only its own reference. The channel's
    /// cached reference remains.
    #[test]
    fn very_early_cancel_state_drops_correctly() {
        spawn_in_driver("early cancellation drop correctness", async {
            let (mut a, _b) = Channel::<[u8]>::create();

            // drop before even polling it should drop the arc correctly
            // SAFETY: `a` outlives `fut`, which is dropped explicitly below.
            let fut = unsafe { read_raw(&mut a, CurrentDispatcher) };
            let op_arc = Arc::downgrade(&fut.raw_fut.op);
            assert_strong_count(&op_arc, 2);
            drop(fut);
            assert_strong_count(&op_arc, 1);
        })
    }

    /// On the current (synchronized) dispatcher, cancelling is synchronous. Drop releases both
    /// the future's reference and the runtime's, leaving only the channel's cached reference.
    #[test]
    fn synchronized_early_cancel_state_drops_correctly() {
        spawn_in_driver("early cancellation drop correctness", async {
            let (mut a, _b) = Channel::<[u8]>::create();

            assert_strong_count(&read_and_drop(&mut a, CurrentDispatcher).await, 1);
        });
    }

    /// On an unsynchronized dispatcher (presumably selected by the `false, true` arguments),
    /// cancellation is asynchronous. The runtime's reference is released by the callback instead
    /// of by Drop.
    #[test]
    fn unsynchronized_early_cancel_state_drops_correctly() {
        // NOTE(review): `a` is moved into the `async move` block below, so it is dropped when that
        // task finishes rather than outliving the dispatcher. Dropping it also releases the
        // channel's cached reference to the op, which the final count of 0 requires. Closing the
        // handle while a callback is still outstanding is deferred through
        // `ReadMessageStateOp::set_channel_dropped`.
        let (mut a, _b) = Channel::<[u8]>::create();
        let unsync_op =
            spawn_in_driver_etc("early cancellation drop correctness", false, true, async move {
                // Only a `Weak` to the op is sent out of the task. It is checked after the
                // dispatcher has shut down, by which point the callback has had a chance to run.
                read_and_drop(&mut a, CurrentDispatcher).await
            });

        // check that there are no more owners of the inner op for the unsynchronized dispatcher.
        assert_strong_count(&unsync_op, 0);
    }

    /// Repeatedly creates, polls, and drops read futures on an unsynchronized dispatcher. Later
    /// futures can observe a wait whose async cancellation is still in flight; this must not
    /// panic or return anything other than `Poll::Pending`.
    #[test]
    fn unsynchronized_early_cancel_state_drops_repeatedly_correctly() {
        // NOTE(review): as in the test above, `a` is moved into the task and dropped when it
        // finishes, not after the dispatcher shuts down.
        let (mut a, _b) = Channel::<[u8]>::create();
        spawn_in_driver_etc("early cancellation drop correctness", false, true, async move {
            for _ in 0..10000 {
                // SAFETY: `a` outlives `fut`, which is dropped at the end of each iteration.
                let mut fut = unsafe { read_raw(&mut a, CurrentDispatcher) };
                let Poll::Pending = futures::poll!(&mut fut) else {
                    panic!("expected pending state after polling channel read once");
                };
                drop(fut);
            }
        });
    }
}