fidl_next_protocol/
mpsc.rs

1// Copyright 2024 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5//! A basic [`Transport`] implementation based on MPSC channels.
6
7use core::fmt;
8use core::marker::PhantomData;
9use core::mem::take;
10use core::pin::Pin;
11use core::ptr::NonNull;
12use core::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
13use core::task::{Context, Poll};
14use std::sync::{mpsc, Arc};
15
16use fidl_next_codec::decoder::InternalHandleDecoder;
17use fidl_next_codec::{Chunk, DecodeError, Decoder, CHUNK_SIZE};
18use futures::task::AtomicWaker;
19
20use crate::Transport;
21
22struct SharedEnd {
23    sender_count: AtomicUsize,
24    send_waker: AtomicWaker,
25}
26
27struct Shared {
28    is_closed: AtomicBool,
29    ends: [SharedEnd; 2],
30}
31
32impl Shared {
33    fn close(&self) {
34        let was_closed = self.is_closed.swap(true, Ordering::Relaxed);
35        if !was_closed {
36            for end in &self.ends {
37                end.send_waker.wake();
38            }
39        }
40    }
41}
42
43/// A paired mpsc transport.
44pub struct Mpsc {
45    sender: Sender,
46    receiver: mpsc::Receiver<Vec<Chunk>>,
47}
48
49impl Mpsc {
50    /// Creates two mpscs which can communicate with each other.
51    pub fn new() -> (Self, Self) {
52        let shared = Arc::new(Shared {
53            is_closed: AtomicBool::new(false),
54            ends: [
55                SharedEnd { sender_count: AtomicUsize::new(1), send_waker: AtomicWaker::new() },
56                SharedEnd { sender_count: AtomicUsize::new(1), send_waker: AtomicWaker::new() },
57            ],
58        });
59        let (a_send, a_recv) = mpsc::channel();
60        let (b_send, b_recv) = mpsc::channel();
61        (
62            Mpsc {
63                sender: Sender { shared: shared.clone(), end: 0, sender: a_send },
64                receiver: b_recv,
65            },
66            Mpsc { sender: Sender { shared, end: 1, sender: b_send }, receiver: a_recv },
67        )
68    }
69}
70
/// The error type for paired mpsc transports.
///
/// Derives the common value traits so callers can compare, copy and hash
/// errors (for example `assert_eq!(err, Error::Closed)`).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Error {
    /// The mpsc was closed, either explicitly or because every sender of one
    /// end was dropped.
    Closed,
}

impl fmt::Display for Error {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            // A fixed message needs no formatting machinery.
            Self::Closed => f.write_str("the mpsc was closed"),
        }
    }
}

impl core::error::Error for Error {}
87
88/// The send end of a paired mpsc transport.
89pub struct Sender {
90    shared: Arc<Shared>,
91    end: usize,
92    sender: mpsc::Sender<Vec<Chunk>>,
93}
94
95impl Clone for Sender {
96    fn clone(&self) -> Self {
97        self.shared.ends[self.end].sender_count.fetch_add(1, Ordering::Relaxed);
98        Self { shared: self.shared.clone(), end: self.end, sender: self.sender.clone() }
99    }
100}
101
102impl Drop for Sender {
103    fn drop(&mut self) {
104        let senders = self.shared.ends[self.end].sender_count.fetch_sub(1, Ordering::Relaxed);
105        if senders == 1 {
106            self.shared.close();
107        }
108    }
109}
110
111/// The send future for a paired mpsc transport.
112pub struct SendFutureState {
113    buffer: Vec<Chunk>,
114}
115
116/// The receive end of a paired mpsc transport.
117pub struct Receiver {
118    shared: Arc<Shared>,
119    end: usize,
120    receiver: mpsc::Receiver<Vec<Chunk>>,
121}
122
/// In-flight state of a receive on an [`Mpsc`] transport.
///
/// Receiving needs no per-future state, so this only holds a marker field.
pub struct RecvFutureState {
    _phantom: PhantomData<()>,
}
127
128/// A received message buffer.
129pub struct RecvBuffer {
130    chunks: Vec<Chunk>,
131    chunks_taken: usize,
132}
133
134impl InternalHandleDecoder for RecvBuffer {
135    fn __internal_take_handles(&mut self, _: usize) -> Result<(), DecodeError> {
136        Err(DecodeError::InsufficientHandles)
137    }
138
139    fn __internal_handles_remaining(&self) -> usize {
140        0
141    }
142}
143
144unsafe impl Decoder for RecvBuffer {
145    fn take_chunks_raw(&mut self, count: usize) -> Result<NonNull<Chunk>, DecodeError> {
146        if count > self.chunks.len() - self.chunks_taken {
147            return Err(DecodeError::InsufficientData);
148        }
149
150        let chunks = unsafe { self.chunks.as_mut_ptr().add(self.chunks_taken) };
151        self.chunks_taken += count;
152
153        unsafe { Ok(NonNull::new_unchecked(chunks)) }
154    }
155
156    fn commit(&mut self) {
157        // No resources to take, so commit is a no-op
158    }
159
160    fn finish(&self) -> Result<(), DecodeError> {
161        if self.chunks_taken != self.chunks.len() {
162            return Err(DecodeError::ExtraBytes {
163                num_extra: (self.chunks.len() - self.chunks_taken) * CHUNK_SIZE,
164            });
165        }
166
167        Ok(())
168    }
169}
170
171impl Transport for Mpsc {
172    type Error = Error;
173
174    fn split(self) -> (Self::Sender, Self::Receiver) {
175        let receiver = Receiver {
176            shared: self.sender.shared.clone(),
177            end: self.sender.end,
178            receiver: self.receiver,
179        };
180        (self.sender, receiver)
181    }
182
183    type Sender = Sender;
184    type SendBuffer = Vec<Chunk>;
185    type SendFutureState = SendFutureState;
186
187    fn acquire(_: &Self::Sender) -> Self::SendBuffer {
188        Vec::new()
189    }
190
191    fn begin_send(_: &Self::Sender, buffer: Self::SendBuffer) -> Self::SendFutureState {
192        SendFutureState { buffer }
193    }
194
195    fn poll_send(
196        mut future_state: Pin<&mut SendFutureState>,
197        _: &mut Context<'_>,
198        sender: &Self::Sender,
199    ) -> Poll<Result<(), Error>> {
200        if sender.shared.is_closed.load(Ordering::Relaxed) {
201            return Poll::Ready(Err(Error::Closed));
202        }
203
204        let chunks = take(&mut future_state.buffer);
205        match sender.sender.send(chunks) {
206            Ok(()) => {
207                sender.shared.ends[sender.end].send_waker.wake();
208                Poll::Ready(Ok(()))
209            }
210            Err(_) => Poll::Ready(Err(Error::Closed)),
211        }
212    }
213
214    fn close(sender: &Self::Sender) {
215        sender.shared.close();
216    }
217
218    type Receiver = Receiver;
219    type RecvFutureState = RecvFutureState;
220    type RecvBuffer = RecvBuffer;
221
222    fn begin_recv(_: &mut Self::Receiver) -> Self::RecvFutureState {
223        RecvFutureState { _phantom: PhantomData }
224    }
225
226    fn poll_recv(
227        _: Pin<&mut Self::RecvFutureState>,
228        cx: &mut Context<'_>,
229        receiver: &mut Self::Receiver,
230    ) -> Poll<Result<Option<Self::RecvBuffer>, Self::Error>> {
231        if receiver.shared.is_closed.load(Ordering::Relaxed) {
232            return Poll::Ready(Ok(None));
233        }
234
235        receiver.shared.ends[1 - receiver.end].send_waker.register(cx.waker());
236        match receiver.receiver.try_recv() {
237            Ok(chunks) => Poll::Ready(Ok(Some(RecvBuffer { chunks, chunks_taken: 0 }))),
238            Err(mpsc::TryRecvError::Empty) => Poll::Pending,
239            Err(mpsc::TryRecvError::Disconnected) => Poll::Ready(Ok(None)),
240        }
241    }
242}
243
#[cfg(test)]
mod tests {
    // Runs the generic transport tests from `crate::testing` against a fresh
    // `Mpsc` pair each time.
    use fuchsia_async as fasync;

    use super::Mpsc;
    use crate::testing::*;

    #[fasync::run_singlethreaded(test)]
    async fn close_on_drop() {
        let (client, server) = Mpsc::new();
        test_close_on_drop(client, server).await;
    }

    #[fasync::run_singlethreaded(test)]
    async fn send_receive() {
        let (client, server) = Mpsc::new();
        test_one_way(client, server).await;
    }

    #[fasync::run_singlethreaded(test)]
    async fn two_way() {
        let (client, server) = Mpsc::new();
        test_two_way(client, server).await;
    }

    #[fasync::run_singlethreaded(test)]
    async fn multiple_two_way() {
        let (client, server) = Mpsc::new();
        test_multiple_two_way(client, server).await;
    }

    #[fasync::run_singlethreaded(test)]
    async fn event() {
        let (client, server) = Mpsc::new();
        test_event(client, server).await;
    }
}