fidl_next_protocol/fuchsia/channel.rs

// Copyright 2024 The Fuchsia Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

//! A transport implementation which uses Zircon channels.
6
7use core::mem::replace;
8use core::pin::Pin;
9use core::ptr::NonNull;
10use core::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
11use core::task::{Context, Poll};
12use std::sync::Arc;
13
14use fidl_next_codec::decoder::InternalHandleDecoder;
15use fidl_next_codec::encoder::InternalHandleEncoder;
16use fidl_next_codec::fuchsia::{HandleDecoder, HandleEncoder};
17use fidl_next_codec::{Chunk, DecodeError, Decoder, EncodeError, Encoder, CHUNK_SIZE};
18use fuchsia_async::{RWHandle, ReadableHandle as _};
19use futures::task::AtomicWaker;
20use zx::sys::{
21    zx_channel_read, zx_channel_write, zx_handle_t, ZX_ERR_BUFFER_TOO_SMALL, ZX_ERR_PEER_CLOSED,
22    ZX_ERR_SHOULD_WAIT, ZX_OK,
23};
24use zx::{AsHandleRef as _, Channel, Handle, HandleBased, Status};
25
26use crate::{NonBlockingTransport, Transport};
27
28struct Shared {
29    is_closed: AtomicBool,
30    sender_count: AtomicUsize,
31    closed_waker: AtomicWaker,
32    channel: RWHandle<Channel>,
33    // TODO: recycle send/recv buffers to reduce allocations
34}
35
36impl Shared {
37    fn new(channel: Channel) -> Self {
38        Self {
39            is_closed: AtomicBool::new(false),
40            sender_count: AtomicUsize::new(1),
41            closed_waker: AtomicWaker::new(),
42            channel: RWHandle::new(channel),
43        }
44    }
45
46    fn close(&self) {
47        self.is_closed.store(true, Ordering::Relaxed);
48        self.closed_waker.wake();
49    }
50}
51
52/// A channel sender.
53pub struct Sender {
54    shared: Arc<Shared>,
55}
56
57impl Drop for Sender {
58    fn drop(&mut self) {
59        let senders = self.shared.sender_count.fetch_sub(1, Ordering::Relaxed);
60        if senders == 1 {
61            self.shared.close();
62        }
63    }
64}
65
66impl Clone for Sender {
67    fn clone(&self) -> Self {
68        self.shared.sender_count.fetch_add(1, Ordering::Relaxed);
69        Self { shared: self.shared.clone() }
70    }
71}
72
73/// A channel buffer.
74#[derive(Default)]
75pub struct Buffer {
76    handles: Vec<Handle>,
77    chunks: Vec<Chunk>,
78}
79
80impl Buffer {
81    /// New buffer.
82    pub fn new() -> Self {
83        Self::default()
84    }
85
86    /// Retrieve the handles.
87    pub fn handles(&self) -> &[Handle] {
88        &self.handles
89    }
90
91    /// Retrieve the bytes.
92    pub fn bytes(&self) -> Vec<u8> {
93        self.chunks.iter().flat_map(|chunk| chunk.to_le_bytes()).collect()
94    }
95}
96
97impl InternalHandleEncoder for Buffer {
98    #[inline]
99    fn __internal_handle_count(&self) -> usize {
100        self.handles.len()
101    }
102}
103
104impl Encoder for Buffer {
105    #[inline]
106    fn bytes_written(&self) -> usize {
107        Encoder::bytes_written(&self.chunks)
108    }
109
110    #[inline]
111    fn write_zeroes(&mut self, len: usize) {
112        Encoder::write_zeroes(&mut self.chunks, len)
113    }
114
115    #[inline]
116    fn write(&mut self, bytes: &[u8]) {
117        Encoder::write(&mut self.chunks, bytes)
118    }
119
120    #[inline]
121    fn rewrite(&mut self, pos: usize, bytes: &[u8]) {
122        Encoder::rewrite(&mut self.chunks, pos, bytes)
123    }
124}
125
126impl HandleEncoder for Buffer {
127    fn push_handle(&mut self, handle: Handle) -> Result<(), EncodeError> {
128        self.handles.push(handle);
129        Ok(())
130    }
131
132    fn handles_pushed(&self) -> usize {
133        self.handles.len()
134    }
135}
136
137/// The state for a channel send future.
138pub struct SendFutureState {
139    buffer: Buffer,
140}
141
142/// A channel receiver.
143pub struct Receiver {
144    shared: Arc<Shared>,
145}
146
147/// The state for a channel receive future.
148pub struct RecvFutureState {
149    buffer: Option<Buffer>,
150}
151
152/// A channel receive buffer.
153pub struct RecvBuffer {
154    buffer: Buffer,
155    chunks_taken: usize,
156    handles_taken: usize,
157}
158
159unsafe impl Decoder for RecvBuffer {
160    fn take_chunks_raw(&mut self, count: usize) -> Result<NonNull<Chunk>, DecodeError> {
161        if count > self.buffer.chunks.len() - self.chunks_taken {
162            return Err(DecodeError::InsufficientData);
163        }
164
165        let chunks = unsafe { self.buffer.chunks.as_mut_ptr().add(self.chunks_taken) };
166        self.chunks_taken += count;
167
168        unsafe { Ok(NonNull::new_unchecked(chunks)) }
169    }
170
171    fn commit(&mut self) {
172        for handle in &mut self.buffer.handles[0..self.handles_taken] {
173            // This handle was taken. To commit the current changes, we need to forget it.
174            let _ = replace(handle, Handle::invalid()).into_raw();
175        }
176    }
177
178    fn finish(&self) -> Result<(), DecodeError> {
179        if self.chunks_taken != self.buffer.chunks.len() {
180            return Err(DecodeError::ExtraBytes {
181                num_extra: (self.buffer.chunks.len() - self.chunks_taken) * CHUNK_SIZE,
182            });
183        }
184
185        if self.handles_taken != self.buffer.handles.len() {
186            return Err(DecodeError::ExtraHandles {
187                num_extra: self.buffer.handles.len() - self.handles_taken,
188            });
189        }
190
191        Ok(())
192    }
193}
194
195impl InternalHandleDecoder for RecvBuffer {
196    fn __internal_take_handles(&mut self, count: usize) -> Result<(), DecodeError> {
197        if count > self.buffer.handles.len() - self.handles_taken {
198            return Err(DecodeError::InsufficientHandles);
199        }
200
201        for i in self.handles_taken..self.handles_taken + count {
202            let handle = replace(&mut self.buffer.handles[i], Handle::invalid());
203            drop(handle);
204        }
205        self.handles_taken += count;
206
207        Ok(())
208    }
209
210    fn __internal_handles_remaining(&self) -> usize {
211        self.buffer.handles.len() - self.handles_taken
212    }
213}
214
215impl HandleDecoder for RecvBuffer {
216    fn take_raw_handle(&mut self) -> Result<zx_handle_t, DecodeError> {
217        if self.handles_taken >= self.buffer.handles.len() {
218            return Err(DecodeError::InsufficientHandles);
219        }
220
221        let handle = self.buffer.handles[self.handles_taken].raw_handle();
222        self.handles_taken += 1;
223
224        Ok(handle)
225    }
226
227    fn handles_remaining(&mut self) -> usize {
228        self.buffer.handles.len() - self.handles_taken
229    }
230}
231
232impl Transport for Channel {
233    type Error = Status;
234
235    fn split(self) -> (Self::Sender, Self::Receiver) {
236        let shared = Arc::new(Shared::new(self));
237        (Sender { shared: shared.clone() }, Receiver { shared })
238    }
239
240    type Sender = Sender;
241    type SendBuffer = Buffer;
242    type SendFutureState = SendFutureState;
243
244    fn acquire(_: &Self::Sender) -> Self::SendBuffer {
245        Buffer::new()
246    }
247
248    fn begin_send(_: &Self::Sender, buffer: Self::SendBuffer) -> Self::SendFutureState {
249        SendFutureState { buffer }
250    }
251
252    fn poll_send(
253        future_state: Pin<&mut Self::SendFutureState>,
254        _: &mut Context<'_>,
255        sender: &Self::Sender,
256    ) -> Poll<Result<(), Self::Error>> {
257        Poll::Ready(Self::send_immediately(future_state.get_mut(), sender))
258    }
259
260    fn close(sender: &Self::Sender) {
261        sender.shared.close();
262    }
263
264    type Receiver = Receiver;
265    type RecvFutureState = RecvFutureState;
266    type RecvBuffer = RecvBuffer;
267
268    fn begin_recv(_: &mut Self::Receiver) -> Self::RecvFutureState {
269        RecvFutureState { buffer: Some(Buffer::new()) }
270    }
271
272    fn poll_recv(
273        mut future_state: Pin<&mut Self::RecvFutureState>,
274        cx: &mut Context<'_>,
275        receiver: &mut Self::Receiver,
276    ) -> Poll<Result<Option<Self::RecvBuffer>, Self::Error>> {
277        let buffer = future_state.buffer.as_mut().unwrap();
278
279        let mut actual_bytes = 0;
280        let mut actual_handles = 0;
281
282        loop {
283            let result = unsafe {
284                zx_channel_read(
285                    receiver.shared.channel.get_ref().raw_handle(),
286                    0,
287                    buffer.chunks.as_mut_ptr().cast(),
288                    buffer.handles.as_mut_ptr().cast(),
289                    (buffer.chunks.capacity() * CHUNK_SIZE) as u32,
290                    buffer.handles.capacity() as u32,
291                    &mut actual_bytes,
292                    &mut actual_handles,
293                )
294            };
295
296            match result {
297                ZX_OK => {
298                    unsafe {
299                        buffer.chunks.set_len(actual_bytes as usize / CHUNK_SIZE);
300                        buffer.handles.set_len(actual_handles as usize);
301                    }
302                    return Poll::Ready(Ok(Some(RecvBuffer {
303                        buffer: future_state.buffer.take().unwrap(),
304                        chunks_taken: 0,
305                        handles_taken: 0,
306                    })));
307                }
308                ZX_ERR_PEER_CLOSED => return Poll::Ready(Ok(None)),
309                ZX_ERR_BUFFER_TOO_SMALL => {
310                    let min_chunks = (actual_bytes as usize).div_ceil(CHUNK_SIZE);
311                    buffer.chunks.reserve(min_chunks - buffer.chunks.capacity());
312                    buffer.handles.reserve(actual_handles as usize - buffer.handles.capacity());
313                }
314                ZX_ERR_SHOULD_WAIT => {
315                    if matches!(receiver.shared.channel.need_readable(cx)?, Poll::Pending) {
316                        receiver.shared.closed_waker.register(cx.waker());
317                        if receiver.shared.is_closed.load(Ordering::Relaxed) {
318                            return Poll::Ready(Ok(None));
319                        }
320                        return Poll::Pending;
321                    }
322                }
323                raw => return Poll::Ready(Err(Status::from_raw(raw))),
324            }
325        }
326    }
327}
328
329impl NonBlockingTransport for Channel {
330    fn send_immediately(
331        future_state: &mut Self::SendFutureState,
332        sender: &Self::Sender,
333    ) -> Result<(), Self::Error> {
334        let result = unsafe {
335            zx_channel_write(
336                sender.shared.channel.get_ref().raw_handle(),
337                0,
338                future_state.buffer.chunks.as_ptr().cast::<u8>(),
339                (future_state.buffer.chunks.len() * CHUNK_SIZE) as u32,
340                future_state.buffer.handles.as_ptr().cast(),
341                future_state.buffer.handles.len() as u32,
342            )
343        };
344
345        if result == ZX_OK {
346            // Handles were written to the channel, so we must not drop them.
347            unsafe {
348                future_state.buffer.handles.set_len(0);
349            }
350            Ok(())
351        } else {
352            Err(Status::from_raw(result))
353        }
354    }
355}
356
#[cfg(test)]
mod tests {
    use core::mem::MaybeUninit;

    use fidl_next_codec::fuchsia::{HandleDecoder, HandleEncoder, WireHandle};
    use fidl_next_codec::{
        munge, Decode, DecodeError, DecoderExt as _, Encodable, Encode, EncodeError,
        EncoderExt as _, Slot, WireString, ZeroPadding,
    };
    use fuchsia_async as fasync;
    use zx::{AsHandleRef, Channel, Handle, HandleBased as _, Instant, Signals, WaitResult};

    use crate::fuchsia::channel::{Buffer, RecvBuffer};
    use crate::testing::{
        test_close_on_drop, test_event, test_multiple_two_way, test_one_way, test_two_way,
    };
    use crate::{Client, Responder, Server, ServerHandler, ServerSender, Transport};

    /// Checks, without blocking, whether `observer` currently sees its peer as closed.
    fn peer_closed_state(observer: &Channel) -> WaitResult {
        observer.wait_handle(Signals::CHANNEL_PEER_CLOSED, Instant::INFINITE_PAST)
    }

    #[fasync::run_singlethreaded(test)]
    async fn close_on_drop() {
        let (client, server) = Channel::create();
        test_close_on_drop(client, server).await;
    }

    #[fasync::run_singlethreaded(test)]
    async fn one_way() {
        let (client, server) = Channel::create();
        test_one_way(client, server).await;
    }

    #[fasync::run_singlethreaded(test)]
    async fn two_way() {
        let (client, server) = Channel::create();
        test_two_way(client, server).await;
    }

    #[fasync::run_singlethreaded(test)]
    async fn multiple_two_way() {
        let (client, server) = Channel::create();
        test_multiple_two_way(client, server).await;
    }

    #[fasync::run_singlethreaded(test)]
    async fn event() {
        let (client, server) = Channel::create();
        test_event(client, server).await;
    }

    /// Natural-form test message: one handle followed by one boolean.
    struct HandleAndBoolean {
        handle: Handle,
        boolean: bool,
    }

    /// Wire form of [`HandleAndBoolean`].
    #[derive(Debug)]
    #[repr(C)]
    struct WireHandleAndBoolean {
        handle: WireHandle,
        boolean: bool,
    }

    // SAFETY: `zero_padding` zeroes the whole value, which includes every padding byte.
    unsafe impl ZeroPadding for WireHandleAndBoolean {
        fn zero_padding(out: &mut MaybeUninit<Self>) {
            // SAFETY: `out` is a valid, exclusively borrowed destination for exactly one `Self`.
            unsafe { core::ptr::write_bytes(out.as_mut_ptr(), 0, 1) };
        }
    }

    impl Encodable for HandleAndBoolean {
        type Encoded = WireHandleAndBoolean;
    }

    unsafe impl<E: HandleEncoder + ?Sized> Encode<E> for HandleAndBoolean {
        fn encode(
            self,
            encoder: &mut E,
            out: &mut MaybeUninit<Self::Encoded>,
        ) -> Result<(), EncodeError> {
            munge!(let Self::Encoded { handle, boolean } = out);
            self.handle.encode(encoder, handle)?;
            self.boolean.encode(encoder, boolean)?;
            Ok(())
        }
    }

    unsafe impl<D: HandleDecoder + ?Sized> Decode<D> for WireHandleAndBoolean {
        fn decode(slot: Slot<'_, Self>, decoder: &mut D) -> Result<(), DecodeError> {
            munge!(let Self { handle, boolean } = slot);
            Decode::decode(handle, decoder)?;
            Decode::decode(boolean, decoder)?;
            Ok(())
        }
    }

    #[test]
    fn partial_decode_drops_handles() {
        let (sent_end, observer) = Channel::create();

        let mut buffer = Buffer::new();
        buffer
            .encode_next(HandleAndBoolean { handle: sent_end.into_handle(), boolean: false })
            .expect("encoding should succeed");
        // Flip a bit inside the encoded boolean so it no longer holds a valid value.
        *buffer.chunks[0] |= 0x00000002_00000000;

        let mut recv_buffer = RecvBuffer { buffer, chunks_taken: 0, handles_taken: 0 };
        (&mut recv_buffer)
            .decode_prefix::<WireHandleAndBoolean>()
            .expect_err("decoding an invalid boolean should fail");

        // The failed decode must leave the handle owned by the buffer, so the peer stays open.
        assert_eq!(peer_closed_state(&observer), WaitResult::TimedOut(Signals::CHANNEL_WRITABLE));

        // Dropping the buffer closes the handle it still owns.
        drop(recv_buffer);
        assert_eq!(peer_closed_state(&observer), WaitResult::Ok(Signals::CHANNEL_PEER_CLOSED));
    }

    #[test]
    fn complete_decode_moves_handles() {
        let (sent_end, observer) = Channel::create();

        let mut buffer = Buffer::new();
        buffer
            .encode_next(HandleAndBoolean { handle: sent_end.into_handle(), boolean: false })
            .expect("encoding should succeed");

        let recv_buffer = RecvBuffer { buffer, chunks_taken: 0, handles_taken: 0 };
        let decoded =
            recv_buffer.decode::<WireHandleAndBoolean>().expect("decoding should succeed");

        // Ownership moved into the decoded value, so the peer is still open.
        assert_eq!(peer_closed_state(&observer), WaitResult::TimedOut(Signals::CHANNEL_WRITABLE));

        // Taking the handle out of the decoded value and dropping it closes the peer.
        drop(decoded.handle.take());
        assert_eq!(peer_closed_state(&observer), WaitResult::Ok(Signals::CHANNEL_PEER_CLOSED));

        drop(decoded);
    }

    #[fasync::run_singlethreaded(test)]
    async fn one_way_nonblocking() {
        /// Server that expects a single one-way "Hello world" message with ordinal 42.
        struct TestServer;

        impl<T: Transport> ServerHandler<T> for TestServer {
            fn on_one_way(&mut self, _: &ServerSender<T>, ordinal: u64, buffer: T::RecvBuffer) {
                assert_eq!(ordinal, 42);
                let message = buffer.decode::<WireString>().expect("failed to decode request");
                assert_eq!(&**message, "Hello world");
            }

            fn on_two_way(&mut self, _: &ServerSender<T>, _: u64, _: T::RecvBuffer, _: Responder) {
                panic!("unexpected two-way message");
            }
        }

        let (client_end, server_end) = Channel::create();

        let mut client = Client::new(client_end);
        let sender = client.sender().clone();
        let client_task = fasync::Task::spawn(async move { client.run_sender().await });

        let mut server = Server::new(server_end);
        let server_task = fasync::Task::spawn(async move { server.run(TestServer).await });

        sender
            .send_one_way(42, "Hello world")
            .expect("client failed to encode request")
            .send_immediately()
            .expect("client failed to send request");
        sender.close();
        drop(sender);

        client_task.await.expect("client encountered an error");
        server_task.await.expect("server encountered an error");
    }
}