fidl_next_protocol/
client.rs

1// Copyright 2024 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5//! FIDL protocol clients.
6
7use core::future::Future;
8use core::pin::Pin;
9use core::sync::atomic::{AtomicBool, Ordering};
10use core::task::{Context, Poll};
11use std::sync::{Arc, Mutex};
12
13use fidl_next_codec::{Encode, EncodeError, EncoderExt};
14
15use crate::lockers::Lockers;
16use crate::{decode_header, encode_header, ProtocolError, SendFuture, Transport, TransportExt};
17
18use super::lockers::LockerError;
19
20struct Shared<T: Transport> {
21    is_closed: AtomicBool,
22    responses: Mutex<Lockers<T::RecvBuffer>>,
23}
24
25impl<T: Transport> Shared<T> {
26    fn new() -> Self {
27        Self { is_closed: AtomicBool::new(false), responses: Mutex::new(Lockers::new()) }
28    }
29}
30
31/// A sender for a client endpoint.
32pub struct ClientSender<T: Transport> {
33    shared: Arc<Shared<T>>,
34    sender: T::Sender,
35}
36
37impl<T: Transport> ClientSender<T> {
38    /// Closes the channel from the client end.
39    pub fn close(&self) {
40        T::close(&self.sender);
41    }
42
43    /// Send a request.
44    pub fn send_one_way<M>(
45        &self,
46        ordinal: u64,
47        request: M,
48    ) -> Result<SendFuture<'_, T>, EncodeError>
49    where
50        M: Encode<T::SendBuffer>,
51    {
52        self.send_message(0, ordinal, request)
53    }
54
55    /// Send a request and await for a response.
56    pub fn send_two_way<M>(
57        &self,
58        ordinal: u64,
59        request: M,
60    ) -> Result<ResponseFuture<'_, T>, EncodeError>
61    where
62        M: Encode<T::SendBuffer>,
63    {
64        let index = self.shared.responses.lock().unwrap().alloc(ordinal);
65
66        // Send with txid = index + 1 because indices start at 0.
67        match self.send_message(index + 1, ordinal, request) {
68            Ok(future) => Ok(ResponseFuture {
69                shared: &self.shared,
70                index,
71                state: ResponseFutureState::Sending(future),
72            }),
73            Err(e) => {
74                self.shared.responses.lock().unwrap().free(index);
75                Err(e)
76            }
77        }
78    }
79
80    fn send_message<M>(
81        &self,
82        txid: u32,
83        ordinal: u64,
84        message: M,
85    ) -> Result<SendFuture<'_, T>, EncodeError>
86    where
87        M: Encode<T::SendBuffer>,
88    {
89        let mut buffer = T::acquire(&self.sender);
90        encode_header::<T>(&mut buffer, txid, ordinal)?;
91        buffer.encode_next(message)?;
92        Ok(T::send(&self.sender, buffer))
93    }
94}
95
96impl<T: Transport> Clone for ClientSender<T> {
97    fn clone(&self) -> Self {
98        Self { shared: self.shared.clone(), sender: self.sender.clone() }
99    }
100}
101
102enum ResponseFutureState<'a, T: Transport> {
103    Sending(SendFuture<'a, T>),
104    Receiving,
105    // We store the completion state locally so that we can free the locker during poll, instead of
106    // waiting until the future is dropped.
107    Completed,
108}
109
110/// A future for a request pending a response.
111pub struct ResponseFuture<'a, T: Transport> {
112    shared: &'a Shared<T>,
113    index: u32,
114    state: ResponseFutureState<'a, T>,
115}
116
117impl<T: Transport> Drop for ResponseFuture<'_, T> {
118    fn drop(&mut self) {
119        let mut responses = self.shared.responses.lock().unwrap();
120        match self.state {
121            // SAFETY: The future was canceled before it could be sent. The transaction ID was never
122            // used, so it's safe to immediately reuse.
123            ResponseFutureState::Sending(_) => responses.free(self.index),
124            ResponseFutureState::Receiving => {
125                if responses.get(self.index).unwrap().cancel() {
126                    responses.free(self.index);
127                }
128            }
129            // We already freed the slot when we completed.
130            ResponseFutureState::Completed => (),
131        }
132    }
133}
134
135impl<T: Transport> ResponseFuture<'_, T> {
136    fn poll_receiving(&mut self, cx: &mut Context<'_>) -> Poll<<Self as Future>::Output> {
137        if self.shared.is_closed.load(Ordering::Relaxed) {
138            self.state = ResponseFutureState::Completed;
139            return Poll::Ready(Err(None));
140        }
141
142        let mut responses = self.shared.responses.lock().unwrap();
143        if let Some(ready) = responses.get(self.index).unwrap().read(cx.waker()) {
144            responses.free(self.index);
145            self.state = ResponseFutureState::Completed;
146            Poll::Ready(Ok(ready))
147        } else {
148            Poll::Pending
149        }
150    }
151}
152
153impl<T: Transport> Future for ResponseFuture<'_, T> {
154    type Output = Result<T::RecvBuffer, Option<T::Error>>;
155
156    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
157        // SAFETY: We treat the state as pinned as long as it is sending.
158        let this = unsafe { Pin::into_inner_unchecked(self) };
159
160        match &mut this.state {
161            ResponseFutureState::Sending(future) => {
162                // SAFETY: Because the state is sending, we always treat its future as pinned.
163                let pinned = unsafe { Pin::new_unchecked(future) };
164                match pinned.poll(cx) {
165                    // The send has not completed yet. Leave the state as sending.
166                    Poll::Pending => Poll::Pending,
167                    Poll::Ready(Ok(())) => {
168                        // The send completed successfully. Change the state to receiving and poll
169                        // for receiving.
170                        this.state = ResponseFutureState::Receiving;
171                        this.poll_receiving(cx)
172                    }
173                    Poll::Ready(Err(e)) => {
174                        // The send completed unsuccessfully. We can safely free the cell and set
175                        // our state to completed.
176
177                        this.shared.responses.lock().unwrap().free(this.index);
178                        this.state = ResponseFutureState::Completed;
179                        Poll::Ready(Err(Some(e)))
180                    }
181                }
182            }
183            ResponseFutureState::Receiving => this.poll_receiving(cx),
184            // We could reach here if this future is polled after completion, but that's not
185            // supposed to happen.
186            ResponseFutureState::Completed => unreachable!(),
187        }
188    }
189}
190
/// A type which handles incoming events for a client.
///
/// Events are received messages whose transaction ID is zero. Messages with a non-zero
/// transaction ID are responses to two-way requests. Those are delivered to the matching
/// [`ResponseFuture`] and never reach the handler.
pub trait ClientHandler<T: Transport> {
    /// Handles a received client event.
    ///
    /// - `sender`: the client's sender, which can be used to send further messages.
    /// - `ordinal`: the ordinal decoded from the message header.
    /// - `buffer`: the received buffer, passed on after its header has been decoded.
    ///
    /// The client cannot handle more messages until `on_event` completes. If `on_event` may block,
    /// perform asynchronous work, or take a long time to process a message, it should offload work
    /// to an async task.
    fn on_event(&mut self, sender: &ClientSender<T>, ordinal: u64, buffer: T::RecvBuffer);
}
200
201/// A client for an endpoint.
202///
203/// It must be actively polled to receive events and two-way message responses.
204pub struct Client<T: Transport> {
205    sender: ClientSender<T>,
206    receiver: T::Receiver,
207}
208
209impl<T: Transport> Client<T> {
210    /// Creates a new client from a transport.
211    pub fn new(transport: T) -> Self {
212        let (sender, receiver) = transport.split();
213        let shared = Arc::new(Shared::new());
214        Self { sender: ClientSender { shared, sender }, receiver }
215    }
216
217    /// Returns the sender for the client.
218    pub fn sender(&self) -> &ClientSender<T> {
219        &self.sender
220    }
221
222    /// Runs the client with the provided handler.
223    pub async fn run<H>(&mut self, mut handler: H) -> Result<(), ProtocolError<T::Error>>
224    where
225        H: ClientHandler<T>,
226    {
227        let result = self.run_to_completion(&mut handler).await;
228        self.sender.shared.is_closed.store(true, Ordering::Relaxed);
229        self.sender.shared.responses.lock().unwrap().wake_all();
230
231        result
232    }
233
234    /// Runs the client with the [`IgnoreEvents`] handler.
235    pub async fn run_sender(&mut self) -> Result<(), ProtocolError<T::Error>> {
236        self.run(IgnoreEvents).await
237    }
238
239    async fn run_to_completion<H>(&mut self, handler: &mut H) -> Result<(), ProtocolError<T::Error>>
240    where
241        H: ClientHandler<T>,
242    {
243        while let Some(mut buffer) =
244            T::recv(&mut self.receiver).await.map_err(ProtocolError::TransportError)?
245        {
246            let (txid, ordinal) =
247                decode_header::<T>(&mut buffer).map_err(ProtocolError::InvalidMessageHeader)?;
248            if txid == 0 {
249                handler.on_event(&self.sender, ordinal, buffer);
250            } else {
251                let mut responses = self.sender.shared.responses.lock().unwrap();
252                let locker = responses
253                    .get(txid - 1)
254                    .ok_or_else(|| ProtocolError::UnrequestedResponse(txid))?;
255
256                match locker.write(ordinal, buffer) {
257                    // Reader didn't cancel
258                    Ok(false) => (),
259                    // Reader canceled, we can drop the entry
260                    Ok(true) => responses.free(txid - 1),
261                    Err(LockerError::NotWriteable) => {
262                        return Err(ProtocolError::UnrequestedResponse(txid));
263                    }
264                    Err(LockerError::MismatchedOrdinal { expected, actual }) => {
265                        return Err(ProtocolError::InvalidResponseOrdinal { expected, actual });
266                    }
267                }
268            }
269        }
270
271        self.sender.close();
272
273        Ok(())
274    }
275}
276
277/// A client handler which ignores any incoming events.
278pub struct IgnoreEvents;
279
280impl<T: Transport> ClientHandler<T> for IgnoreEvents {
281    fn on_event(&mut self, _: &ClientSender<T>, _: u64, _: T::RecvBuffer) {}
282}