fdf_channel/
channel.rs

1// Copyright 2024 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5//! Safe bindings for the driver runtime channel stable ABI
6
7use core::future::Future;
8use zx::Status;
9
10use crate::arena::{Arena, ArenaBox};
11use crate::futures::ReadMessageState;
12use crate::message::Message;
13use fdf_core::dispatcher::OnDispatcher;
14use fdf_core::handle::{DriverHandle, MixedHandle};
15use fdf_sys::*;
16
17use core::marker::PhantomData;
18use core::mem::{size_of_val, MaybeUninit};
19use core::num::NonZero;
20use core::pin::Pin;
21use core::ptr::{null_mut, NonNull};
22use core::task::{Context, Poll};
23
24pub use fdf_sys::fdf_handle_t;
25
26/// Implements a message channel through the Fuchsia Driver Runtime
27#[derive(Debug, Ord, PartialOrd, Eq, PartialEq, Hash)]
28pub struct Channel<T: ?Sized + 'static>(pub(crate) DriverHandle, PhantomData<Message<T>>);
29
30impl<T: ?Sized + 'static> Channel<T> {
31    /// Creates a new channel pair that can be used to send messages of type `T`
32    /// between threads managed by the driver runtime.
33    pub fn create() -> (Self, Self) {
34        let mut channel1 = 0;
35        let mut channel2 = 0;
36        // This call cannot fail as the only reason it would fail is due to invalid
37        // option flags, and 0 is a valid option.
38        Status::ok(unsafe { fdf_channel_create(0, &mut channel1, &mut channel2) })
39            .expect("failed to create channel pair");
40        // SAFETY: if fdf_channel_create returned ZX_OK, it will have placed
41        // valid channel handles that must be non-zero.
42        unsafe {
43            (
44                Self::from_handle_unchecked(NonZero::new_unchecked(channel1)),
45                Self::from_handle_unchecked(NonZero::new_unchecked(channel2)),
46            )
47        }
48    }
49
50    /// Returns a reference to the inner handle of the channel.
51    pub fn driver_handle(&self) -> &DriverHandle {
52        &self.0
53    }
54
55    /// Takes the inner handle to the channel. The caller is responsible for ensuring
56    /// that the handle is freed.
57    pub fn into_driver_handle(self) -> DriverHandle {
58        self.0
59    }
60
61    /// Initializes a [`Channel`] object from the given non-zero handle.
62    ///
63    /// # Safety
64    ///
65    /// The caller must ensure that the handle is not invalid and that it is
66    /// part of a driver runtime channel pair of type `T`.
67    unsafe fn from_handle_unchecked(handle: NonZero<fdf_handle_t>) -> Self {
68        // SAFETY: caller is responsible for ensuring that it is a valid channel
69        Self(unsafe { DriverHandle::new_unchecked(handle) }, PhantomData)
70    }
71
72    /// Initializes a [`Channel`] object from the given [`DriverHandle`],
73    /// assuming that it is a channel of type `T`.
74    ///
75    /// # Safety
76    ///
77    /// The caller must ensure that the handle is a [`Channel`]-based handle that is
78    /// using type `T` as its wire format.
79    pub unsafe fn from_driver_handle(handle: DriverHandle) -> Self {
80        Self(handle, PhantomData)
81    }
82
83    /// Writes the [`Message`] given to the channel. This will complete asynchronously and can't
84    /// be cancelled.
85    ///
86    /// The channel will take ownership of the data and handles passed in,
87    pub fn write(&self, message: Message<T>) -> Result<(), Status> {
88        // get the sizes while the we still have refs to the data and handles
89        let data_len = message.data().map_or(0, |data| size_of_val(&*data) as u32);
90        let handles_count = message.handles().map_or(0, |handles| handles.len() as u32);
91
92        let (arena, data, handles) = message.into_raw();
93
94        // transform the `Option<NonNull<T>>` into just `*mut T`
95        let data_ptr = data.map_or(null_mut(), |data| data.cast().as_ptr());
96        let handles_ptr = handles.map_or(null_mut(), |handles| handles.cast().as_ptr());
97
98        // SAFETY:
99        // - Normally, we could be reading uninit bytes here. However, as long as fdf_channel_write
100        //   doesn't allow cross-LTO then it won't care whether the bytes are initialized.
101        // - The `Message` will generally only construct correctly if the data and handles pointers
102        //   inside it are from the arena it holds, but just in case `fdf_channel_write` will check
103        //   that we are using the correct arena so we do not need to re-verify that they are from
104        //   the same arena.
105        Status::ok(unsafe {
106            fdf_channel_write(
107                self.0.get_raw().get(),
108                0,
109                arena.as_ptr(),
110                data_ptr,
111                data_len,
112                handles_ptr,
113                handles_count,
114            )
115        })?;
116
117        // SAFETY: this is the valid-by-contruction arena we were passed in through the [`Message`]
118        // object, and now that we have completed `fdf_channel_write` it is safe to drop our copy
119        // of it.
120        unsafe { fdf_arena_drop_ref(arena.as_ptr()) };
121        Ok(())
122    }
123
124    /// Shorthand for calling [`Self::write`] with the result of [`Message::new_with`]
125    pub fn write_with<F>(&self, arena: Arena, f: F) -> Result<(), Status>
126    where
127        F: for<'a> FnOnce(
128            &'a Arena,
129        )
130            -> (Option<ArenaBox<'a, T>>, Option<ArenaBox<'a, [Option<MixedHandle>]>>),
131    {
132        self.write(Message::new_with(arena, f))
133    }
134
135    /// Shorthand for calling [`Self::write`] with the result of [`Message::new_with`]
136    pub fn write_with_data<F>(&self, arena: Arena, f: F) -> Result<(), Status>
137    where
138        F: for<'a> FnOnce(&'a Arena) -> ArenaBox<'a, T>,
139    {
140        self.write(Message::new_with_data(arena, f))
141    }
142}
143
144/// Attempts to read from the channel, returning a [`Message`] object that can be used to
145/// access or take the data received if there was any. This is the basic building block
146/// on which the other `try_read_*` methods are built.
147pub(crate) fn try_read_raw(
148    channel: &DriverHandle,
149) -> Result<Option<Message<[MaybeUninit<u8>]>>, Status> {
150    let mut out_arena = null_mut();
151    let mut out_data = null_mut();
152    let mut out_num_bytes = 0;
153    let mut out_handles = null_mut();
154    let mut out_num_handles = 0;
155    Status::ok(unsafe {
156        fdf_channel_read(
157            channel.get_raw().get(),
158            0,
159            &mut out_arena,
160            &mut out_data,
161            &mut out_num_bytes,
162            &mut out_handles,
163            &mut out_num_handles,
164        )
165    })?;
166    // if no arena was returned, that means no data was returned.
167    if out_arena == null_mut() {
168        return Ok(None);
169    }
170    // SAFETY: we just checked that the `out_arena` is non-null
171    let arena = Arena(unsafe { NonNull::new_unchecked(out_arena) });
172    let data_ptr = if !out_data.is_null() {
173        let ptr = core::ptr::slice_from_raw_parts_mut(out_data.cast(), out_num_bytes as usize);
174        // SAFETY: we just checked that the pointer was non-null, the slice version of it should
175        // be too.
176        Some(unsafe { ArenaBox::new(NonNull::new_unchecked(ptr)) })
177    } else {
178        None
179    };
180    let handles_ptr = if !out_handles.is_null() {
181        let ptr = core::ptr::slice_from_raw_parts_mut(out_handles.cast(), out_num_handles as usize);
182        // SAFETY: we just checked that the pointer was non-null, the slice version of it should
183        // be too.
184        Some(unsafe { ArenaBox::new(NonNull::new_unchecked(ptr)) })
185    } else {
186        None
187    };
188    Ok(Some(unsafe { Message::new_unchecked(arena, data_ptr, handles_ptr) }))
189}
190
191/// Reads a message from the channel asynchronously
192///
193/// # Panic
194///
195/// Panics if this is not run from a driver framework dispatcher.
196pub(crate) fn read_raw<'a, D>(channel: &'a DriverHandle, dispatcher: D) -> ReadMessageRawFut<D> {
197    // SAFETY: Since the future's lifetime is bound to the original driver handle and it
198    // holds the message state, the message state object can't outlive the handle.
199    ReadMessageRawFut { raw_fut: unsafe { ReadMessageState::new(channel) }, dispatcher }
200}
201
202impl<T> Channel<T> {
203    /// Attempts to read an object of type `T` and a handle set from the channel
204    pub fn try_read<'a>(&self) -> Result<Option<Message<T>>, Status> {
205        // read a message from the channel
206        let Some(message) = try_read_raw(&self.0)? else {
207            return Ok(None);
208        };
209        // SAFETY: It is an invariant of Channel<T> that messages sent or received are always of
210        // type T.
211        Ok(Some(unsafe { message.cast_unchecked() }))
212    }
213
214    /// Reads an object of type `T` and a handle set from the channel asynchronously
215    pub async fn read<D: OnDispatcher>(&self, dispatcher: D) -> Result<Option<Message<T>>, Status> {
216        let Some(message) = read_raw(&self.0, dispatcher).await? else {
217            return Ok(None);
218        };
219        // SAFETY: It is an invariant of Channel<T> that messages sent or received are always of
220        // type T.
221        Ok(Some(unsafe { message.cast_unchecked() }))
222    }
223}
224
225impl Channel<[u8]> {
226    /// Attempts to read an object of type `T` and a handle set from the channel
227    pub fn try_read_bytes<'a>(&self) -> Result<Option<Message<[u8]>>, Status> {
228        // read a message from the channel
229        let Some(message) = try_read_raw(&self.0)? else {
230            return Ok(None);
231        };
232        // SAFETY: It is an invariant of Channel<[u8]> that messages sent or received are always of
233        // type [u8].
234        Ok(Some(unsafe { message.assume_init() }))
235    }
236
237    /// Reads a slice of type `T` and a handle set from the channel asynchronously
238    pub async fn read_bytes<D: OnDispatcher>(
239        &self,
240        dispatcher: D,
241    ) -> Result<Option<Message<[u8]>>, Status> {
242        // read a message from the channel
243        let Some(message) = read_raw(&self.0, dispatcher).await? else {
244            return Ok(None);
245        };
246        // SAFETY: It is an invariant of Channel<[u8]> that messages sent or received are always of
247        // type [u8].
248        Ok(Some(unsafe { message.assume_init() }))
249    }
250}
251
252impl<T> From<Channel<T>> for MixedHandle {
253    fn from(value: Channel<T>) -> Self {
254        MixedHandle::from(value.0)
255    }
256}
257
258pub(crate) struct ReadMessageRawFut<D> {
259    pub(crate) raw_fut: ReadMessageState,
260    dispatcher: D,
261}
262
263impl<D: OnDispatcher> Future for ReadMessageRawFut<D> {
264    type Output = Result<Option<Message<[MaybeUninit<u8>]>>, Status>;
265
266    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
267        let dispatcher = self.dispatcher.clone();
268        self.as_mut().raw_fut.poll_with_dispatcher(cx, dispatcher)
269    }
270}
271
#[cfg(test)]
mod tests {
    use std::pin::pin;
    use std::sync::mpsc;

    use fdf_core::dispatcher::{CurrentDispatcher, Dispatcher, DispatcherBuilder, OnDispatcher};
    use fdf_core::handle::MixedHandleType;
    use fdf_env::test::spawn_in_driver;
    use futures::poll;

    use super::*;
    use crate::test_utils::*;

    /// Round-trips byte messages in both directions with non-blocking reads. Checks that an
    /// empty channel reports `ZX_ERR_SHOULD_WAIT` and that writing to a dropped peer reports
    /// `ZX_ERR_PEER_CLOSED`.
    #[test]
    fn send_and_receive_bytes_synchronously() {
        let (first, second) = Channel::create();
        let arena = Arena::new();
        assert_eq!(first.try_read_bytes().unwrap_err(), Status::from_raw(ZX_ERR_SHOULD_WAIT));
        first.write_with_data(arena.clone(), |arena| arena.insert_slice(&[1, 2, 3, 4])).unwrap();
        assert_eq!(&*second.try_read_bytes().unwrap().unwrap().data().unwrap(), &[1, 2, 3, 4]);
        assert_eq!(second.try_read_bytes().unwrap_err(), Status::from_raw(ZX_ERR_SHOULD_WAIT));
        second.write_with_data(arena.clone(), |arena| arena.insert_slice(&[5, 6, 7, 8])).unwrap();
        assert_eq!(&*first.try_read_bytes().unwrap().unwrap().data().unwrap(), &[5, 6, 7, 8]);
        assert_eq!(first.try_read_bytes().unwrap_err(), Status::from_raw(ZX_ERR_SHOULD_WAIT));
        assert_eq!(second.try_read_bytes().unwrap_err(), Status::from_raw(ZX_ERR_SHOULD_WAIT));
        drop(second);
        assert_eq!(
            first.write_with_data(arena.clone(), |arena| arena.insert_slice(&[9, 10, 11, 12])),
            Err(Status::from_raw(ZX_ERR_PEER_CLOSED))
        );
    }

    /// An async byte read stays pending until the peer writes, then resolves to that message.
    #[test]
    fn send_and_receive_bytes_asynchronously() {
        spawn_in_driver("channel async", async {
            let arena = Arena::new();
            let (first, second) = Channel::create();

            assert!(poll!(pin!(first.read_bytes(CurrentDispatcher))).is_pending());
            second.write_with_data(arena, |arena| arena.insert_slice(&[1, 2, 3, 4])).unwrap();
            assert_eq!(
                first.read_bytes(CurrentDispatcher).await.unwrap().unwrap().data().unwrap(),
                &[1, 2, 3, 4]
            );
        });
    }

    /// A typed object is neither dropped when sent nor when received; it is dropped only when
    /// the received message is dropped.
    #[test]
    fn send_and_receive_objects_synchronously() {
        let arena = Arena::new();
        let (first, second) = Channel::create();
        let (tx, rx) = mpsc::channel();
        first
            .write_with_data(arena.clone(), |arena| arena.insert(DropSender::new(1, tx.clone())))
            .unwrap();
        rx.try_recv().expect_err("should not drop the object when sent");
        let message = second.try_read().unwrap().unwrap();
        assert_eq!(message.data().unwrap().0, 1);
        rx.try_recv().expect_err("should not drop the object when received");
        drop(message);
        rx.try_recv().expect("dropped when received");
    }

    /// Sends one end of a channel pair over another channel, then uses the received end to
    /// talk to its original peer.
    #[test]
    fn send_and_receive_handles_synchronously() {
        println!("Create channels and write one end of one of the channel pairs to the other");
        let (first, second) = Channel::<()>::create();
        let (inner_first, inner_second) = Channel::<String>::create();
        let message = Message::new_with(Arena::new(), |arena| {
            (None, Some(arena.insert_boxed_slice(Box::new([Some(inner_first.into())]))))
        });
        first.write(message).unwrap();

        println!("Receive the channel back on the other end of the first channel pair.");
        let mut arena = None;
        let message =
            second.try_read().unwrap().expect("Expected a message with contents to be received");
        let (_, received_handles) = message.into_arena_boxes(&mut arena);
        let mut first_handle_received =
            ArenaBox::take_boxed_slice(received_handles.expect("expected handles in the message"));
        let first_handle_received = first_handle_received
            .first_mut()
            .expect("expected one handle in the handle set")
            .take()
            .expect("expected the first handle to be non-null");
        let first_handle_received = first_handle_received.resolve();
        let MixedHandleType::Driver(driver_handle) = first_handle_received else {
            panic!("Got a non-driver handle when we sent a driver handle");
        };
        // SAFETY: this handle is `inner_first`, created above as one end of a
        // `Channel::<String>` pair, and it is used below only as a `Channel<String>`.
        let inner_first_received = unsafe { Channel::from_driver_handle(driver_handle) };

        println!("Send and receive a string across the now-transmitted channel pair.");
        inner_first_received
            .write_with_data(Arena::new(), |arena| arena.insert("boom".to_string()))
            .unwrap();
        assert_eq!(inner_second.try_read().unwrap().unwrap().data().unwrap(), &"boom".to_string());
    }

    /// Sends `0`, then echoes back each received value plus one. Stops when a read fails or
    /// returns no message; the `Ok(Some(..))` pattern ends the loop once `pong` drops its end.
    async fn ping(chan: Channel<u8>) {
        println!("starting ping!");
        chan.write_with_data(Arena::new(), |arena| arena.insert(0)).unwrap();
        while let Ok(Some(msg)) = chan.read(CurrentDispatcher).await {
            let next = *msg.data().unwrap();
            println!("ping! {next}");
            chan.write_with_data(msg.take_arena(), |arena| arena.insert(next + 1)).unwrap();
        }
    }

    /// Echoes back each received value plus one, reusing the received message's arena. Returns
    /// once a value greater than 10 arrives.
    async fn pong(chan: Channel<u8>) {
        println!("starting pong!");
        while let Some(msg) = chan.read(CurrentDispatcher).await.unwrap() {
            let next = *msg.data().unwrap();
            println!("pong! {next}");
            if next > 10 {
                println!("bye!");
                break;
            }
            chan.write_with_data(msg.take_arena(), |arena| arena.insert(next + 1)).unwrap();
        }
    }

    /// Runs `ping` as a spawned task and `pong` inline, both on the current driver dispatcher.
    #[test]
    fn async_ping_pong() {
        spawn_in_driver("async ping pong", async {
            let (ping_chan, pong_chan) = Channel::create();
            CurrentDispatcher.spawn_task(ping(ping_chan)).unwrap();
            pong(pong_chan).await;
        });
    }

    /// Runs `ping` on a fuchsia-async `LocalExecutor` hosted on a blocking driver dispatcher,
    /// with a separate non-blocking driver dispatcher installed as current for the fdf reads.
    #[test]
    fn async_ping_pong_on_fuchsia_async() {
        spawn_in_driver("async ping pong", async {
            let (ping_chan, pong_chan) = Channel::create();

            let fdf_dispatcher = DispatcherBuilder::new()
                .name("fdf-async")
                .create()
                .expect("failure creating non-blocking dispatcher for fdf operations on rust-async dispatcher")
                .release();

            let rust_async_dispatcher = DispatcherBuilder::new()
                .name("fuchsia-async")
                .allow_thread_blocking()
                .create()
                .expect("failure creating blocking dispatcher for rust async")
                .release();

            rust_async_dispatcher
                .post_task_sync(move |_| {
                    Dispatcher::override_current(fdf_dispatcher, || {
                        let mut executor = fuchsia_async::LocalExecutor::new();
                        executor.run_singlethreaded(ping(ping_chan));
                    });
                })
                .unwrap();

            pong(pong_chan).await
        });
    }
}