fdf_fidl/
lib.rs

1// Copyright 2025 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use futures::task::AtomicWaker;
6use std::num::NonZero;
7use std::ptr::NonNull;
8use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
9use std::sync::Arc;
10use std::task::Poll;
11
12use fidl_next::Chunk;
13use zx::Status;
14
15use fdf_channel::arena::{Arena, ArenaBox};
16use fdf_channel::channel::Channel;
17use fdf_channel::futures::ReadMessageState;
18use fdf_channel::message::Message;
19use fdf_core::dispatcher::{CurrentDispatcher, OnDispatcher};
20use fdf_core::handle::{MixedHandle, MixedHandleType};
21
22/// A fidl-compatible driver channel that also holds a reference to the
23/// dispatcher. Defaults to using [`CurrentDispatcher`].
24pub struct DriverChannel<D = CurrentDispatcher> {
25    dispatcher: D,
26    channel: Channel<[Chunk]>,
27}
28
29impl<D> DriverChannel<D> {
30    /// Create a new driver fidl channel that will perform its operations on the given
31    /// dispatcher handle.
32    pub fn new_with_dispatcher(dispatcher: D, channel: Channel<[Chunk]>) -> Self {
33        Self { dispatcher, channel }
34    }
35}
36
37impl DriverChannel<CurrentDispatcher> {
38    /// Create a new driver fidl channel that will perform its operations on the
39    /// [`CurrentDispatcher`].
40    pub fn new(channel: Channel<[Chunk]>) -> Self {
41        Self::new_with_dispatcher(CurrentDispatcher, channel)
42    }
43}
44
45/// A channel buffer.
46pub struct SendBuffer {
47    handles: Vec<Option<MixedHandle>>,
48    data: Vec<Chunk>,
49}
50
51impl SendBuffer {
52    fn new() -> Self {
53        Self { handles: Vec::new(), data: Vec::new() }
54    }
55}
56
57impl fidl_next::Encoder for SendBuffer {
58    #[inline]
59    fn bytes_written(&self) -> usize {
60        fidl_next::Encoder::bytes_written(&self.data)
61    }
62
63    #[inline]
64    fn write(&mut self, bytes: &[u8]) {
65        fidl_next::Encoder::write(&mut self.data, bytes)
66    }
67
68    #[inline]
69    fn rewrite(&mut self, pos: usize, bytes: &[u8]) {
70        fidl_next::Encoder::rewrite(&mut self.data, pos, bytes)
71    }
72
73    fn write_zeroes(&mut self, len: usize) {
74        fidl_next::Encoder::write_zeroes(&mut self.data, len);
75    }
76}
77
78impl fidl_next::encoder::InternalHandleEncoder for SendBuffer {
79    #[inline]
80    fn __internal_handle_count(&self) -> usize {
81        self.handles.len()
82    }
83}
84
85impl fidl_next::fuchsia::HandleEncoder for SendBuffer {
86    fn push_handle(&mut self, handle: zx::Handle) -> Result<(), fidl_next::EncodeError> {
87        if let Some(handle) = MixedHandle::from_zircon_handle(handle) {
88            if handle.is_driver() {
89                return Err(fidl_next::EncodeError::ExpectedZirconHandle);
90            }
91            self.handles.push(Some(handle));
92        } else {
93            self.handles.push(None);
94        }
95        Ok(())
96    }
97
98    fn push_raw_driver_handle(&mut self, handle: u32) -> Result<(), fidl_next::EncodeError> {
99        if let Some(handle) = NonZero::new(handle) {
100            // SAFETY: the fidl framework is responsible for providing us with a valid, otherwise
101            // unowned handle.
102            let handle = unsafe { MixedHandle::from_raw(handle) };
103            if !handle.is_driver() {
104                return Err(fidl_next::EncodeError::ExpectedDriverHandle);
105            }
106            self.handles.push(Some(handle));
107        } else {
108            self.handles.push(None);
109        }
110        Ok(())
111    }
112
113    fn handles_pushed(&self) -> usize {
114        self.handles.len()
115    }
116}
117
118pub struct RecvBuffer {
119    buffer: Message<[Chunk]>,
120    data_offset: usize,
121    handle_offset: usize,
122}
123
124impl RecvBuffer {
125    fn next_handle(&self) -> Result<&MixedHandle, fidl_next::DecodeError> {
126        let Some(handles) = self.buffer.handles() else {
127            return Err(fidl_next::DecodeError::InsufficientHandles);
128        };
129        if handles.len() < self.handle_offset + 1 {
130            return Err(fidl_next::DecodeError::InsufficientHandles);
131        }
132        handles[self.handle_offset].as_ref().ok_or(fidl_next::DecodeError::RequiredHandleAbsent)
133    }
134}
135
136// SAFETY: The decoder implementation stores the data buffer in a [`Message`] tied to an [`Arena`],
137// and the memory in an [`Arena`] is guaranteed not to move while the arena is valid.
138// Also, since we own the [`Message`] and nothing else can, it is ok to treat its contents
139// as mutable through an `&mut self` reference to the struct.
140unsafe impl fidl_next::Decoder for RecvBuffer {
141    // SAFETY: if the caller requests a number of [`Chunk`]s that we can't supply, we return
142    // `InsufficientData`.
143    fn take_chunks_raw(&mut self, count: usize) -> Result<NonNull<Chunk>, fidl_next::DecodeError> {
144        let Some(data) = self.buffer.data_mut() else {
145            return Err(fidl_next::DecodeError::InsufficientData);
146        };
147        if data.len() < self.data_offset + count {
148            return Err(fidl_next::DecodeError::InsufficientData);
149        }
150        let pos = self.data_offset;
151        self.data_offset += count;
152        Ok(unsafe { NonNull::new_unchecked((&mut data[pos..(pos + count)]).as_mut_ptr()) })
153    }
154
155    fn commit(&mut self) {
156        if let Some(handles) = self.buffer.handles_mut() {
157            for i in 0..self.handle_offset {
158                core::mem::forget(handles[i].take());
159            }
160        }
161    }
162
163    fn finish(&self) -> Result<(), fidl_next::DecodeError> {
164        let data_len = self.buffer.data().unwrap_or(&[]).len();
165        if self.data_offset != data_len {
166            return Err(fidl_next::DecodeError::ExtraBytes {
167                num_extra: data_len - self.data_offset,
168            });
169        }
170        let handle_len = self.buffer.handles().unwrap_or(&[]).len();
171        if self.handle_offset != handle_len {
172            return Err(fidl_next::DecodeError::ExtraHandles {
173                num_extra: handle_len - self.handle_offset,
174            });
175        }
176        Ok(())
177    }
178}
179
180impl fidl_next::decoder::InternalHandleDecoder for RecvBuffer {
181    fn __internal_take_handles(&mut self, count: usize) -> Result<(), fidl_next::DecodeError> {
182        let Some(handles) = self.buffer.handles_mut() else {
183            return Err(fidl_next::DecodeError::InsufficientHandles);
184        };
185        if handles.len() < self.handle_offset + count {
186            return Err(fidl_next::DecodeError::InsufficientHandles);
187        }
188        let pos = self.handle_offset;
189        self.handle_offset = pos + count;
190        Ok(())
191    }
192
193    fn __internal_handles_remaining(&self) -> usize {
194        self.buffer.handles().unwrap_or(&[]).len() - self.handle_offset
195    }
196}
197
198impl fidl_next::fuchsia::HandleDecoder for RecvBuffer {
199    fn take_raw_handle(&mut self) -> Result<zx::sys::zx_handle_t, fidl_next::DecodeError> {
200        let result = {
201            let handle = self.next_handle()?.resolve_ref();
202            let MixedHandleType::Zircon(handle) = handle else {
203                return Err(fidl_next::DecodeError::ExpectedZirconHandle);
204            };
205            handle.raw_handle()
206        };
207        let pos = self.handle_offset;
208        self.handle_offset = pos + 1;
209        Ok(result)
210    }
211
212    fn take_raw_driver_handle(&mut self) -> Result<u32, fidl_next::DecodeError> {
213        let result = {
214            let handle = self.next_handle()?.resolve_ref();
215            let MixedHandleType::Driver(handle) = handle else {
216                return Err(fidl_next::DecodeError::ExpectedDriverHandle);
217            };
218            unsafe { handle.get_raw().get() }
219        };
220        let pos = self.handle_offset;
221        self.handle_offset = pos + 1;
222        Ok(result)
223    }
224
225    fn handles_remaining(&mut self) -> usize {
226        self.buffer.handles().unwrap_or(&[]).len() - self.handle_offset
227    }
228}
229
230/// The inner state of a receive future used by [`fidl_next::protocol::Transport`].
231pub struct DriverRecvState(ReadMessageState);
232
233struct Shared<D> {
234    is_closed: AtomicBool,
235    sender_count: AtomicUsize,
236    closed_waker: AtomicWaker,
237    channel: DriverChannel<D>,
238}
239
240impl<D> Shared<D> {
241    fn new(channel: DriverChannel<D>) -> Self {
242        Self {
243            is_closed: AtomicBool::new(false),
244            sender_count: AtomicUsize::new(1),
245            closed_waker: AtomicWaker::new(),
246            channel,
247        }
248    }
249
250    fn close(&self) {
251        self.is_closed.store(true, Ordering::Relaxed);
252        self.closed_waker.wake();
253    }
254}
255/// The sender side of a [`DriverChannel`].
256pub struct DriverSender<D> {
257    shared: Arc<Shared<D>>,
258}
259
260impl<D> Drop for DriverSender<D> {
261    fn drop(&mut self) {
262        let senders = self.shared.sender_count.fetch_sub(1, Ordering::Relaxed);
263        if senders == 1 {
264            self.shared.close();
265        }
266    }
267}
268
269impl<D> Clone for DriverSender<D> {
270    fn clone(&self) -> Self {
271        self.shared.sender_count.fetch_add(1, Ordering::Relaxed);
272        Self { shared: self.shared.clone() }
273    }
274}
275
276/// The receiver side of a [`DriverChannel`].
277pub struct DriverReceiver<D> {
278    shared: Arc<Shared<D>>,
279}
280
281impl<D: OnDispatcher> fidl_next::protocol::Transport for DriverChannel<D> {
282    type Error = Status;
283
284    fn split(self) -> (Self::Sender, Self::Receiver) {
285        let shared = Arc::new(Shared::new(self));
286        let sender = DriverSender { shared: shared.clone() };
287        let receiver = DriverReceiver { shared };
288        (sender, receiver)
289    }
290
291    type Sender = DriverSender<D>;
292
293    type SendBuffer = SendBuffer;
294
295    type SendFutureState = SendBuffer;
296
297    fn acquire(_sender: &Self::Sender) -> Self::SendBuffer {
298        SendBuffer::new()
299    }
300
301    fn close(sender: &Self::Sender) {
302        sender.shared.close();
303    }
304
305    type Receiver = DriverReceiver<D>;
306
307    type RecvFutureState = DriverRecvState;
308
309    type RecvBuffer = RecvBuffer;
310
311    fn begin_send(_sender: &Self::Sender, buffer: Self::SendBuffer) -> Self::SendFutureState {
312        buffer
313    }
314
315    fn poll_send(
316        mut buffer: std::pin::Pin<&mut Self::SendFutureState>,
317        _cx: &mut std::task::Context<'_>,
318        sender: &Self::Sender,
319    ) -> std::task::Poll<Result<(), Self::Error>> {
320        let arena = Arena::new();
321        let message = Message::new_with(arena, |arena| {
322            let data = arena.insert_slice(&buffer.data);
323            let handles = buffer.handles.split_off(0);
324            let handles = arena.insert_from_iter(handles.into_iter());
325            (Some(data), Some(handles))
326        });
327        Poll::Ready(sender.shared.channel.channel.write(message))
328    }
329
330    fn begin_recv(receiver: &mut Self::Receiver) -> Self::RecvFutureState {
331        // SAFETY: The `receiver` owns the channel we're using here and will be the same
332        // receiver given to `poll_recv`, so must outlive the state object we're constructing.
333        let state =
334            unsafe { ReadMessageState::new(receiver.shared.channel.channel.driver_handle()) };
335        DriverRecvState(state)
336    }
337
338    fn poll_recv(
339        mut future: std::pin::Pin<&mut Self::RecvFutureState>,
340        cx: &mut std::task::Context<'_>,
341        receiver: &mut Self::Receiver,
342    ) -> std::task::Poll<Result<Option<Self::RecvBuffer>, Self::Error>> {
343        use std::task::Poll::*;
344        match future.as_mut().0.poll_with_dispatcher(cx, receiver.shared.channel.dispatcher.clone())
345        {
346            Ready(Ok(Some(buffer))) => {
347                let buffer = buffer.map_data(|_, data| {
348                    let bytes = data.len();
349                    assert_eq!(
350                        0,
351                        bytes % size_of::<Chunk>(),
352                        "Received driver channel buffer was not a multiple of {} bytes",
353                        size_of::<Chunk>()
354                    );
355                    // SAFETY: we verified that the size of the message we received was the correct
356                    // multiple of chunks and we know that the data pointer is otherwise valid and
357                    // from the correct arena by construction.
358                    let new_box = unsafe {
359                        let ptr = ArenaBox::into_ptr(data).cast();
360                        ArenaBox::new(NonNull::slice_from_raw_parts(
361                            ptr,
362                            bytes / size_of::<Chunk>(),
363                        ))
364                    };
365                    new_box
366                });
367
368                Ready(Ok(Some(RecvBuffer { buffer, data_offset: 0, handle_offset: 0 })))
369            }
370            Ready(Ok(None)) => Ready(Ok(None)),
371            Ready(Err(err)) => Ready(Err(err)),
372            Pending => {
373                receiver.shared.closed_waker.register(cx.waker());
374                if receiver.shared.is_closed.load(Ordering::Relaxed) {
375                    return Poll::Ready(Ok(None));
376                }
377                Pending
378            }
379        }
380    }
381}
382
#[cfg(test)]
mod test {
    use fidl_next::{Client, ClientEnd, Responder, Server, ServerEnd, ServerSender};
    use fidl_next_fuchsia_examples_gizmo::device::{GetEvent, GetHardwareId};
    use fidl_next_fuchsia_examples_gizmo::{
        Device, DeviceClientHandler, DeviceClientSender, DeviceGetEventResponse,
        DeviceGetHardwareIdResponse, DeviceServerHandler,
    };
    use fuchsia_async::OnSignals;
    use zx::{AsHandleRef, Event, HandleBased, Signals};

    use super::*;
    use fdf_core::dispatcher::{CurrentDispatcher, OnDispatcher};
    use fdf_env::test::spawn_in_driver;

    /// A `Device` server that sends every reply from its own task on the current dispatcher.
    struct DeviceServer;

    impl DeviceServerHandler<DriverChannel> for DeviceServer {
        fn get_hardware_id(
            &mut self,
            sender: &ServerSender<DriverChannel, Device>,
            responder: Responder<GetHardwareId>,
        ) {
            let reply_sender = sender.clone();
            let reply = async move {
                responder
                    .respond(
                        &reply_sender,
                        Result::<_, i32>::Ok(DeviceGetHardwareIdResponse { response: 4004 }),
                    )
                    .unwrap()
                    .await
                    .unwrap();
            };
            CurrentDispatcher.spawn_task(reply).unwrap();
        }

        fn get_event(
            &mut self,
            sender: &ServerSender<DriverChannel, Device>,
            responder: Responder<GetEvent>,
        ) {
            // Pre-signal USER_0 so the client can observe it on the transferred handle.
            let event = Event::create();
            event.signal_handle(Signals::empty(), Signals::USER_0).unwrap();
            let response = DeviceGetEventResponse { event: event.into_handle() };
            let reply_sender = sender.clone();
            let reply = async move {
                responder.respond(&reply_sender, response).unwrap().await.unwrap();
            };
            CurrentDispatcher.spawn_task(reply).unwrap();
        }
    }

    /// A client handler that handles no events itself.
    struct DeviceClient;
    impl DeviceClientHandler<DriverChannel> for DeviceClient {}

    #[test]
    fn driver_fidl_server() {
        spawn_in_driver("driver fidl server", async {
            let (server_raw, client_raw) = Channel::<[Chunk]>::create();
            let client_end = ClientEnd::from_untyped(DriverChannel::new(client_raw));
            let server_end: ServerEnd<_, Device> =
                ServerEnd::from_untyped(DriverChannel::new(server_raw));
            let mut client = Client::new(client_end);
            let mut server = Server::new(server_end);
            let requester = client.sender().clone();

            // Both ends of the connection are driven by tasks on the current dispatcher.
            let server_task = async move {
                server.run(DeviceServer).await.unwrap();
                println!("server task finished");
            };
            CurrentDispatcher.spawn_task(server_task).unwrap();
            let client_task = async move {
                client.run(DeviceClient).await.unwrap();
                println!("client task finished");
            };
            CurrentDispatcher.spawn_task(client_task).unwrap();

            {
                // Plain request/response round trip.
                let reply = requester.get_hardware_id().unwrap().await.unwrap();
                let hardware_id = reply.unwrap();
                assert_eq!(hardware_id.response, 4004);
            }

            {
                // Handle transfer: the received event must still carry the USER_0 signal.
                let reply = requester.get_event().unwrap().await.unwrap();
                let event = Event::from_handle(reply.event.take());

                // The wait itself runs on a separate fuchsia_async executor.
                let mut executor = fuchsia_async::LocalExecutor::new();
                let wait = OnSignals::new(event, Signals::USER_0);
                let observed = executor.run_singlethreaded(wait).unwrap();
                assert_eq!(Signals::USER_0, observed);
            }
        });
    }
}