fdf_channel/futures.rs
1// Copyright 2025 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5//! Internal helpers for implementing futures against channel objects
6
7use std::mem::ManuallyDrop;
8use std::task::Waker;
9use zx::Status;
10
11use crate::channel::{Channel, try_read_raw};
12use crate::message::Message;
13use fdf_core::dispatcher::OnDispatcher;
14use fdf_core::handle::DriverHandle;
15use fdf_sys::*;
16
17use core::mem::MaybeUninit;
18use core::task::{Context, Poll};
19use std::sync::{Arc, Mutex};
20
21pub use fdf_sys::fdf_handle_t;
22
/// The mutable part of a [`ReadMessageStateOp`], guarded by its mutex.
///
/// It is accessed by the read futures, by the runtime callback, and by the owning channel
/// when it is dropped.
#[derive(Default, Debug)]
struct ReadMessageStateOpLocked {
    /// The currently active waker for this read operation. Only set while a read wait has
    /// been registered with the runtime and its callback has not fired yet, so `is_some()`
    /// also serves as the "a callback is still outstanding" flag.
    waker: Option<Waker>,
    /// Set if the channel was dropped while a callback was still outstanding, in which case
    /// the callback is responsible for closing the driver handle when it fires.
    channel_dropped: bool,
    /// Whether cancelation of the wait is reported asynchronously through the callback
    /// (`true`) or takes effect immediately when [`fdf_channel_cancel_wait`] returns
    /// (`false`). This decides what is responsible for releasing the runtime's reference to
    /// the op when a future is dropped before its wait completes.
    cancelation_is_async: bool,
}
38
39/// This struct is shared between the future and the driver runtime, with the first field
40/// being managed by the driver runtime and the second by the future. It will be held by two
41/// [`Arc`]s, one for each of the future and the runtime.
42///
43/// The future's [`Arc`] will be dropped when the future is either fulfilled or cancelled through
44/// normal [`Drop`] of the future.
45///
46/// The runtime's [`Arc`]'s dropping varies depending on whether the dispatcher it was registered on
47/// was synchronized or not, and whether it was cancelled or not. The callback will only ever be
48/// called *up to* one time.
49///
50/// If the dispatcher is synchronized, then the callback will *only* be called on fulfillment of the
51/// read wait.
52#[repr(C)]
53#[derive(Debug)]
54pub(crate) struct ReadMessageStateOp {
55 /// This must be at the start of the struct so that `ReadMessageStateOp` can be cast to and from `fdf_channel_read`.
56 read_op: fdf_channel_read,
57 state: Mutex<ReadMessageStateOpLocked>,
58}
59
60impl ReadMessageStateOp {
61 unsafe extern "C" fn handler(
62 _dispatcher: *mut fdf_dispatcher,
63 read_op: *mut fdf_channel_read,
64 _status: i32,
65 ) {
66 // Note: we don't really do anything different based on whether the callback
67 // says canceled. If the future was canceled by being dropped, it won't poll
68 // again since it was dropped.
69 // The only unusual case is when the dispatcher is shutting down, and in that
70 // case we will wake the future and it will try to read and get a more useful
71 // error.
72 // Meanwhile, since we use the same state object across multiple
73 // futures due to needing to handle async cancelation, trying to track the
74 // underlying reason for the cancelation becomes more tricky than it's worth.
75
76 // SAFETY: When setting up the read op, we incremented the refcount of the `Arc` to allow
77 // for this handler to reconstitute it.
78 let op: Arc<Self> = unsafe { Arc::from_raw(read_op.cast()) };
79
80 let mut state = op.state.lock().unwrap();
81 if state.channel_dropped {
82 // SAFETY: since the channel dropped we are the only outstanding owner of the
83 // channel object.
84 unsafe { fdf_handle_close(op.read_op.channel) };
85 }
86 let Some(waker) = state.waker.take() else {
87 // the waker was already taken, presumably because the future was dropped.
88 return;
89 };
90 // make sure to drop the lock before calling the waker.
91 drop(state);
92 waker.wake()
93 }
94
95 /// Called by the channel on drop to indicate that the channel has been dropped and
96 /// find out whether it needs to defer dropping the handle until the callback is called.
97 pub fn set_channel_dropped(&self) -> bool {
98 let mut state = self.state.lock().unwrap();
99 if state.waker.is_some() {
100 state.channel_dropped = true;
101 false
102 } else {
103 true
104 }
105 }
106}
107
108/// An object for managing the state of an async channel read message operation that can be used to
109/// implement futures.
110pub struct ReadMessageState {
111 op: Arc<ReadMessageStateOp>,
112 channel: ManuallyDrop<DriverHandle>,
113}
114
115impl ReadMessageState {
116 /// Creates a new raw read message state that can be used to implement a [`Future`] that reads
117 /// data from a channel and then converts it to the appropriate type. It also allows for
118 /// different ways of storing and managing the dispatcher we wait on by deferring the
119 /// dispatcher used to poll time. This state is registered with the given [`Channel`]
120 /// so that dropping the channel will correctly free resources.
121 ///
122 /// # Safety
123 ///
124 /// The caller is responsible for ensuring that the handle inside `channel` outlives this
125 /// object.
126 pub unsafe fn register_read_wait<T: ?Sized>(channel: &mut Channel<T>) -> Self {
127 // SAFETY: The caller is responsible for ensuring that the handle is a correct channel handle
128 // and that the handle will outlive the created [`ReadMessageState`].
129 let channel_handle = unsafe { channel.handle.get_raw() };
130 let op = channel
131 .wait_state
132 .get_or_insert_with(|| {
133 Arc::new(ReadMessageStateOp {
134 read_op: fdf_channel_read {
135 channel: channel_handle.get(),
136 handler: Some(ReadMessageStateOp::handler),
137 ..Default::default()
138 },
139 state: Mutex::new(ReadMessageStateOpLocked::default()),
140 })
141 })
142 .clone();
143 Self {
144 op,
145 // SAFETY: We know this is a valid driver handle by construction and we are
146 // storing this handle in a [`ManuallyDrop`] to prevent it from being double-dropped.
147 // The caller is responsible for ensuring that the handle outlives this object.
148 channel: ManuallyDrop::new(unsafe { DriverHandle::new_unchecked(channel_handle) }),
149 }
150 }
151
152 /// Polls this channel read operation against the given dispatcher.
153 #[expect(clippy::type_complexity)]
154 pub fn poll_with_dispatcher<D: OnDispatcher>(
155 &mut self,
156 cx: &mut Context<'_>,
157 dispatcher: D,
158 ) -> Poll<Result<Option<Message<[MaybeUninit<u8>]>>, Status>> {
159 let mut state = self.op.state.lock().unwrap();
160
161 match try_read_raw(&self.channel) {
162 Ok(res) => Poll::Ready(Ok(res)),
163 Err(Status::SHOULD_WAIT) => {
164 // if we haven't yet set a waker, that means we haven't started the wait operation
165 // yet.
166 if state.waker.is_none() {
167 // increment the reference count of the read op to account for the copy that will be given to
168 // `fdf_channel_wait_async`.
169 let op = Arc::into_raw(self.op.clone());
170 let res = dispatcher.on_maybe_dispatcher(|dispatcher| {
171 // if we're not running on the same dispatcher as we're waiting from, we
172 // want to force async cancellation
173 let options = if !dispatcher.is_current_dispatcher() {
174 FDF_CHANNEL_WAIT_OPTION_FORCE_ASYNC_CANCEL
175 } else {
176 0
177 };
178 // SAFETY: the `ReadMessageStateOp` starts with an `fdf_channel_read` struct and
179 // has `repr(C)` layout, so is safe to be cast to the latter.
180 let res = Status::ok(unsafe {
181 fdf_channel_wait_async(
182 dispatcher.inner().as_ptr(),
183 op.cast_mut().cast(),
184 options,
185 )
186 });
187 if res.is_ok() {
188 // only replace the waker if we succeeded, so we'll try again next time
189 // otherwise.
190 state.waker.replace(cx.waker().clone());
191 } else {
192 // reconstitute the arc we made for the callback so it can be dropped
193 // since the async wait didn't succeed.
194 drop(unsafe { Arc::from_raw(op) });
195 }
196 // if the dispatcher we're waiting on is unsynchronized, the callback
197 // will drop the Arc and we need to indicate to our own Drop impl
198 // that it should not.
199 res.map(|_| {
200 options == FDF_CHANNEL_WAIT_OPTION_FORCE_ASYNC_CANCEL
201 || dispatcher.is_unsynchronized()
202 })
203 });
204
205 // the default state should be that `drop` will free the arc.
206 state.cancelation_is_async = false;
207 match res {
208 Err(Status::BAD_STATE) => {
209 return Poll::Pending; // a pending await is being cancelled
210 }
211 Ok(cancelation_is_async) => {
212 state.cancelation_is_async = cancelation_is_async;
213 }
214 Err(e) => return Poll::Ready(Err(e)),
215 }
216 }
217 Poll::Pending
218 }
219 Err(e) => Poll::Ready(Err(e)),
220 }
221 }
222}
223
224impl Drop for ReadMessageState {
225 fn drop(&mut self) {
226 let mut state = self.op.state.lock().unwrap();
227 if state.waker.is_none() {
228 // if there's no waker either the callback has already fired or we never waited on this
229 // future in the first place, so just leave it be.
230 return;
231 }
232
233 // SAFETY: since we hold a lifetimed-reference to the channel object here, the channel must
234 // be valid.
235 let res = Status::ok(unsafe { fdf_channel_cancel_wait(self.channel.get_raw().get()) });
236 match res {
237 Ok(_) => {}
238 Err(Status::NOT_FOUND) => {
239 // the callback is already being called or the wait was already cancelled, so just
240 // return and leave it.
241 return;
242 }
243 Err(e) => panic!("Unexpected error {e:?} cancelling driver channel read wait"),
244 }
245 // SAFETY: if the channel was waited on by a synchronized dispatcher, and the cancel was
246 // successful, the callback will not be called and we will have to free the `Arc` that the
247 // callback would have consumed.
248 if !state.cancelation_is_async {
249 // steal the waker so it doesn't get called, if there is one.
250 state.waker.take();
251 unsafe { Arc::decrement_strong_count(Arc::as_ptr(&self.op)) };
252 }
253 }
254}
255
#[cfg(test)]
mod test {
    use std::pin::pin;
    use std::sync::Weak;

    use fdf_core::dispatcher::{CurrentDispatcher, OnDispatcher};
    use fdf_env::test::{spawn_in_driver, spawn_in_driver_etc};

    use crate::arena::Arena;
    use crate::channel::{Channel, read_raw};

    use super::*;

    /// Asserts that the allocation behind `arc` currently has exactly `count` strong
    /// references.
    #[track_caller]
    fn assert_strong_count<T>(arc: &Weak<T>, count: usize) {
        assert_eq!(Weak::strong_count(arc), count, "unexpected strong count on arc");
    }

    /// Creates a read future for `channel`, polls it once, and then drops it, checking the
    /// internal op arc's strong count along the way:
    ///
    /// - 2 right after creation (the channel's wait state plus the future), and
    /// - 3 after the first poll returns `Pending` (plus the reference handed to the runtime
    ///   for the registered wait).
    ///
    /// Returns a weak reference to the op so the caller can verify how the count settles
    /// once the future has been dropped.
    async fn read_and_drop<T: ?Sized + 'static, D: OnDispatcher>(
        channel: &mut Channel<T>,
        dispatcher: D,
    ) -> Weak<ReadMessageStateOp> {
        let fut = unsafe { read_raw(channel, dispatcher) };
        let op_arc = Arc::downgrade(&fut.raw_fut.op);
        assert_strong_count(&op_arc, 2);
        let mut fut = pin!(fut);
        let Poll::Pending = futures::poll!(fut.as_mut()) else {
            panic!("expected pending state after polling channel read once");
        };
        assert_strong_count(&op_arc, 3);
        // The pinned future is dropped when this function returns, which cancels its wait.
        op_arc
    }

    /// Dropping a polled read future must cancel its wait cleanly, so that a later read on
    /// the same channel still receives the message.
    #[test]
    fn early_cancel_future() {
        spawn_in_driver("early cancellation", async {
            let (mut a, b) = Channel::create();

            // create, poll, and then immediately drop a read future for channel `a`
            // so that it properly sets up the wait.
            read_and_drop(&mut a, CurrentDispatcher).await;
            b.write_with_data(Arena::new(), |arena| arena.insert(1)).unwrap();
            assert_eq!(a.read(CurrentDispatcher).await.unwrap().unwrap().data(), Some(&1));
        })
    }

    /// Dropping a future that was never polled only releases the future's own reference,
    /// since no wait (and so no runtime reference) was ever registered.
    #[test]
    fn very_early_cancel_state_drops_correctly() {
        spawn_in_driver("early cancellation drop correctness", async {
            let (mut a, _b) = Channel::<[u8]>::create();

            // drop before even polling it should drop the arc correctly
            let fut = unsafe { read_raw(&mut a, CurrentDispatcher) };
            let op_arc = Arc::downgrade(&fut.raw_fut.op);
            assert_strong_count(&op_arc, 2);
            drop(fut);
            assert_strong_count(&op_arc, 1);
        })
    }

    /// When the wait is cancelled synchronously, dropping the future itself must release the
    /// runtime's reference, leaving only the channel's reference to the op.
    #[test]
    fn synchronized_early_cancel_state_drops_correctly() {
        spawn_in_driver("early cancellation drop correctness", async {
            let (mut a, _b) = Channel::<[u8]>::create();

            assert_strong_count(&read_and_drop(&mut a, CurrentDispatcher).await, 1);
        });
    }

    /// With asynchronous cancelation, the runtime's reference is released by the callback.
    /// Once the dispatcher has shut down, no owners of the op should remain.
    #[test]
    fn unsynchronized_early_cancel_state_drops_correctly() {
        // The channel pair is created outside the driver. Note that `a` itself is moved into
        // the `async move` block below and dropped when that block finishes.
        // NOTE(review): the handle therefore stays open only because closing is deferred to
        // the cancelation callback via `ReadMessageStateOp::set_channel_dropped` — confirm.
        let (mut a, _b) = Channel::<[u8]>::create();
        // NOTE(review): the `false, true` arguments presumably select an unsynchronized
        // dispatcher (per the test name) — confirm against `spawn_in_driver_etc`.
        let unsync_op =
            spawn_in_driver_etc("early cancellation drop correctness", false, true, async move {
                // We send the weak op reference out to be checked after the dispatcher has
                // shut down, so that we can be sure the callback has had a chance to run.
                read_and_drop(&mut a, CurrentDispatcher).await
            });

        // check that there are no more owners of the inner op for the unsynchronized dispatcher.
        assert_strong_count(&unsync_op, 0);
    }

    /// Stress test: repeatedly registers and cancels read waits on the same channel, so that
    /// misuse of the shared op's reference count is likely to surface as a crash.
    #[test]
    fn unsynchronized_early_cancel_state_drops_repeatedly_correctly() {
        // As above, `a` is moved into the `async move` block below rather than outliving the
        // dispatcher.
        let (mut a, _b) = Channel::<[u8]>::create();
        spawn_in_driver_etc("early cancellation drop correctness", false, true, async move {
            for _ in 0..10000 {
                let mut fut = unsafe { read_raw(&mut a, CurrentDispatcher) };
                let Poll::Pending = futures::poll!(&mut fut) else {
                    panic!("expected pending state after polling channel read once");
                };
                drop(fut);
            }
        });
    }
}