fdf_channel/futures.rs
1// Copyright 2025 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5//! Internal helpers for implementing futures against channel objects
6
7use std::mem::ManuallyDrop;
8use std::sync::atomic::AtomicBool;
9use std::task::Waker;
10use zx::Status;
11
12use crate::channel::try_read_raw;
13use crate::message::Message;
14use fdf_core::dispatcher::OnDispatcher;
15use fdf_core::handle::DriverHandle;
16use fdf_sys::*;
17
18use core::mem::MaybeUninit;
19use core::task::{Context, Poll};
20use std::sync::{Arc, Mutex};
21
22pub use fdf_sys::fdf_handle_t;
23
24/// This struct is shared between the future and the driver runtime, with the first field
25/// being managed by the driver runtime and the second by the future. It will be held by two
26/// [`Arc`]s, one for each of the future and the runtime.
27///
28/// The future's [`Arc`] will be dropped when the future is either fulfilled or cancelled through
29/// normal [`Drop`] of the future.
30///
31/// The runtime's [`Arc`]'s dropping varies depending on whether the dispatcher it was registered on
32/// was synchronized or not, and whether it was cancelled or not. The callback will only ever be
33/// called *up to* one time.
34///
35/// If the dispatcher is synchronized, then the callback will *only* be called on fulfillment of the
36/// read wait.
37#[repr(C)]
38struct ReadMessageStateOp {
39 /// This must be at the start of the struct so that `ReadMessageStateOp` can be cast to and from `fdf_channel_read`.
40 read_op: fdf_channel_read,
41 waker: Mutex<Option<Waker>>,
42 cancelled: AtomicBool,
43}
44
45impl ReadMessageStateOp {
46 unsafe extern "C" fn handler(
47 _dispatcher: *mut fdf_dispatcher,
48 read_op: *mut fdf_channel_read,
49 status: i32,
50 ) {
51 // SAFETY: When setting up the read op, we incremented the refcount of the `Arc` to allow
52 // for this handler to reconstitute it.
53 let op: Arc<Self> = unsafe { Arc::from_raw(read_op.cast()) };
54 if Status::from_raw(status) == Status::CANCELED {
55 op.cancelled.store(true, std::sync::atomic::Ordering::Release);
56 }
57
58 let Some(waker) = op.waker.lock().unwrap().take() else {
59 // the waker was already taken, presumably because the future was dropped.
60 return;
61 };
62 waker.wake()
63 }
64}
65
66/// An object for managing the state of an async channel read message operation that can be used to
67/// implement futures.
68pub struct ReadMessageState {
69 op: Arc<ReadMessageStateOp>,
70 channel: ManuallyDrop<DriverHandle>,
71 callback_drops_arc: bool,
72}
73
74impl ReadMessageState {
75 /// Creates a new raw read message state that can be used to implement a [`Future`] that reads
76 /// data from a channel and then converts it to the appropriate type. It also allows for
77 /// different ways of storing and managing the dispatcher we wait on by deferring the
78 /// dispatcher used to poll time.
79 ///
80 /// # Safety
81 ///
82 /// The caller is responsible for ensuring that `channel` outlives this object.
83 pub unsafe fn new(channel: &DriverHandle) -> Self {
84 // SAFETY: The caller is responsible for ensuring that the handle is a correct channel handle
85 // and that the handle will outlive the created [`ReadMessageState`].
86 let channel = unsafe { channel.get_raw() };
87 Self {
88 op: Arc::new(ReadMessageStateOp {
89 read_op: fdf_channel_read {
90 channel: channel.get(),
91 handler: Some(ReadMessageStateOp::handler),
92 ..Default::default()
93 },
94 waker: Mutex::new(None),
95 cancelled: AtomicBool::new(false),
96 }),
97 // SAFETY: We know this is a valid driver handle by construction and we are
98 // storing this handle in a [`ManuallyDrop`] to prevent it from being double-dropped.
99 // The caller is responsible for ensuring that the handle outlives this object.
100 channel: ManuallyDrop::new(unsafe { DriverHandle::new_unchecked(channel) }),
101 // We haven't waited on it yet so we are responsible for dropping the arc for now,
102 // regardless of what kind of dispatcher it's intended to be used with.
103 callback_drops_arc: false,
104 }
105 }
106
107 /// Polls this channel read operation against the given dispatcher.
108 pub fn poll_with_dispatcher<D: OnDispatcher>(
109 self: &mut Self,
110 cx: &mut Context<'_>,
111 dispatcher: D,
112 ) -> Poll<Result<Option<Message<[MaybeUninit<u8>]>>, Status>> {
113 let mut waker_lock = self.op.waker.lock().unwrap();
114
115 if self.op.cancelled.load(std::sync::atomic::Ordering::Relaxed) {
116 // if the dispatcher we were waiting on is shutting down then when we try to go
117 // to wait again we'll get ZX_ERR_UNAVAILABLE anyways, so just short circuit that and
118 // return it right away.
119 return Poll::Ready(Err(Status::UNAVAILABLE));
120 }
121
122 match try_read_raw(&self.channel) {
123 Ok(res) => Poll::Ready(Ok(res)),
124 Err(Status::SHOULD_WAIT) => {
125 // if we haven't yet set a waker, that means we haven't started the wait operation
126 // yet.
127 if waker_lock.replace(cx.waker().clone()).is_none() {
128 // increment the reference count of the read op to account for the copy that will be given to
129 // `fdf_channel_wait_async`.
130 let op = Arc::into_raw(self.op.clone());
131 let res = dispatcher.on_maybe_dispatcher(|dispatcher| {
132 // SAFETY: the `ReadMessageStateOp` starts with an `fdf_channel_read` struct and
133 // has `repr(C)` layout, so is safe to be cast to the latter.
134 let res = Status::ok(unsafe {
135 fdf_channel_wait_async(
136 dispatcher.inner().as_ptr(),
137 op.cast_mut().cast(),
138 0,
139 )
140 });
141 // if the dispatcher we're waiting on is unsynchronized, the callback
142 // will drop the Arc and we need to indicate to our own Drop impl
143 // that it should not.
144 let callback_drops_arc = res.is_ok() && dispatcher.is_unsynchronized();
145 Ok(callback_drops_arc)
146 });
147
148 match res {
149 Ok(callback_drops_arc) => {
150 self.callback_drops_arc = callback_drops_arc;
151 }
152 Err(e) => return Poll::Ready(Err(e)),
153 }
154 }
155 Poll::Pending
156 }
157 Err(e) => Poll::Ready(Err(e)),
158 }
159 }
160}
161
162impl Drop for ReadMessageState {
163 fn drop(&mut self) {
164 let mut waker_lock = self.op.waker.lock().unwrap();
165 if waker_lock.is_none() {
166 // if there's no waker either the callback has already fired or we never waited on this
167 // future in the first place, so just leave it be.
168 return;
169 }
170
171 // SAFETY: since we hold a lifetimed-reference to the channel object here, the channel must
172 // be valid.
173 let res = Status::ok(unsafe { fdf_channel_cancel_wait(self.channel.get_raw().get()) });
174 match res {
175 Ok(_) => {}
176 Err(Status::NOT_FOUND) => {
177 // the callback is already being called or the wait was already cancelled, so just
178 // return and leave it.
179 return;
180 }
181 Err(e) => panic!("Unexpected error {e:?} cancelling driver channel read wait"),
182 }
183 // steal the waker so it doesn't get called, if there is one.
184 waker_lock.take();
185 // SAFETY: if the channel was waited on by a synchronized dispatcher, and the cancel was
186 // successful, the callback will not be called and we will have to free the `Arc` that the
187 // callback would have consumed.
188 if !self.callback_drops_arc {
189 unsafe { Arc::decrement_strong_count(Arc::as_ptr(&self.op)) };
190 }
191 }
192}
193
#[cfg(test)]
mod test {
    use std::pin::pin;
    use std::sync::Weak;

    use fdf_core::dispatcher::{CurrentDispatcher, OnDispatcher};
    use fdf_env::test::{spawn_in_driver, spawn_in_driver_etc};

    use crate::arena::Arena;
    use crate::channel::{read_raw, Channel};

    use super::*;

    /// Asserts that the allocation behind `arc` currently has exactly `count` strong references.
    fn assert_strong_count<T>(arc: &Weak<T>, count: usize) {
        assert_eq!(Weak::strong_count(arc), count, "unexpected strong count on arc");
    }

    /// Creates a read future for `channel`, polls it once, and then drops it. Along the way it
    /// verifies that the internal op [`Arc`] has the expected strong count at each step:
    /// - one before polling (held by the future only);
    /// - two after polling (the future plus the reference handed to the runtime).
    ///
    /// Returns a [`Weak`] reference to the op so callers can check that its strong count goes
    /// down to zero correctly. A `Weak` is used so that the returned handle does not itself keep
    /// the op alive.
    async fn read_and_drop<T: ?Sized + 'static, D: OnDispatcher>(
        channel: &Channel<T>,
        dispatcher: D,
    ) -> Weak<ReadMessageStateOp> {
        let fut = read_raw(&channel.0, dispatcher);
        let op_arc = Arc::downgrade(&fut.raw_fut.op);
        assert_strong_count(&op_arc, 1);
        let mut fut = pin!(fut);
        let Poll::Pending = futures::poll!(fut.as_mut()) else {
            panic!("expected pending state after polling channel read once");
        };
        assert_strong_count(&op_arc, 2);
        op_arc
    }

    /// A read future dropped after setting up its wait must not interfere with a later read on
    /// the same channel.
    #[test]
    fn early_cancel_future() {
        spawn_in_driver("early cancellation", async {
            let (a, b) = Channel::create();

            // create, poll, and then immediately drop a read future for channel `a`
            // so that it properly sets up the wait.
            read_and_drop(&a, CurrentDispatcher).await;
            b.write_with_data(Arena::new(), |arena| arena.insert(1)).unwrap();
            assert_eq!(a.read(CurrentDispatcher).await.unwrap().unwrap().data(), Some(&1));
        })
    }

    /// Dropping a read future that was never polled releases the op immediately.
    #[test]
    fn very_early_cancel_state_drops_correctly() {
        spawn_in_driver("early cancellation drop correctness", async {
            let (a, _b) = Channel::<[u8]>::create();

            // drop before even polling it should drop the arc correctly
            let fut = read_raw(&a.0, CurrentDispatcher);
            let op_arc = Arc::downgrade(&fut.raw_fut.op);
            assert_strong_count(&op_arc, 1);
            drop(fut);
            assert_strong_count(&op_arc, 0);
        })
    }

    /// Cancelling an outstanding wait on the dispatcher used by `spawn_in_driver` releases both
    /// references to the op by the time the future has been dropped.
    #[test]
    fn synchronized_early_cancel_state_drops_correctly() {
        spawn_in_driver("early cancellation drop correctness", async {
            let (a, _b) = Channel::<[u8]>::create();

            assert_strong_count(&read_and_drop(&a, CurrentDispatcher).await, 0);
        });
    }

    /// On an unsynchronized dispatcher, the cancellation callback releases the runtime's
    /// reference to the op once the dispatcher has shut down.
    #[test]
    fn unsynchronized_early_cancel_state_drops_correctly() {
        // the channel needs to outlive the dispatcher for this test because the channel shouldn't
        // be closed before the read wait has been cancelled.
        let (a, _b) = Channel::<[u8]>::create();
        let (unsync_op, _a) =
            spawn_in_driver_etc("early cancellation drop correctness", false, true, async move {
                // We send the arc out to be checked after the dispatcher has shut down so
                // that we can be sure that the callback has had a chance to be called.
                // We send the channel back out so that it lives long enough for the
                // cancellation to be called on it.
                let res = read_and_drop(&a, CurrentDispatcher).await;
                (res, a)
            });

        // check that there are no more owners of the inner op for the unsynchronized dispatcher.
        assert_strong_count(&unsync_op, 0);
    }
}
285}