fidl_next_protocol/
client.rs

1// Copyright 2024 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5//! FIDL protocol clients.
6
7use core::future::Future;
8use core::pin::Pin;
9use core::sync::atomic::{AtomicBool, Ordering};
10use core::task::{Context, Poll};
11use std::sync::{Arc, Mutex};
12
13use fidl_next_codec::{Encode, EncodeError, EncoderExt};
14
15use crate::lockers::Lockers;
16use crate::{decode_header, encode_header, ProtocolError, SendFuture, Transport, TransportExt};
17
18use super::lockers::LockerError;
19
/// State shared between a client's sender handles and its receive loop.
struct Shared<T: Transport> {
    // Set to `true` by `Client::run` once the receive loop finishes, so
    // pending `ResponseFuture`s can observe that no response will arrive.
    is_closed: AtomicBool,
    // Lockers holding in-flight two-way responses, indexed by `txid - 1`.
    responses: Mutex<Lockers<T::RecvBuffer>>,
}
24
25impl<T: Transport> Shared<T> {
26    fn new() -> Self {
27        Self { is_closed: AtomicBool::new(false), responses: Mutex::new(Lockers::new()) }
28    }
29}
30
/// A sender for a client endpoint.
pub struct ClientSender<T: Transport> {
    // Response-routing state shared with the `Client` that owns the receiver.
    shared: Arc<Shared<T>>,
    // The send half produced by `Transport::split`.
    sender: T::Sender,
}
36
37impl<T: Transport> ClientSender<T> {
38    /// Closes the channel from the client end.
39    pub fn close(&self) {
40        T::close(&self.sender);
41    }
42
43    /// Send a request.
44    pub fn send_one_way<M>(
45        &self,
46        ordinal: u64,
47        request: M,
48    ) -> Result<SendFuture<'_, T>, EncodeError>
49    where
50        M: Encode<T::SendBuffer>,
51    {
52        self.send_message(0, ordinal, request)
53    }
54
55    /// Send a request and await for a response.
56    pub fn send_two_way<M>(
57        &self,
58        ordinal: u64,
59        request: M,
60    ) -> Result<ResponseFuture<'_, T>, EncodeError>
61    where
62        M: Encode<T::SendBuffer>,
63    {
64        let index = self.shared.responses.lock().unwrap().alloc(ordinal);
65
66        // Send with txid = index + 1 because indices start at 0.
67        match self.send_message(index + 1, ordinal, request) {
68            Ok(future) => Ok(ResponseFuture {
69                shared: &self.shared,
70                index,
71                state: ResponseFutureState::Sending(future),
72            }),
73            Err(e) => {
74                self.shared.responses.lock().unwrap().free(index);
75                Err(e)
76            }
77        }
78    }
79
80    fn send_message<M>(
81        &self,
82        txid: u32,
83        ordinal: u64,
84        message: M,
85    ) -> Result<SendFuture<'_, T>, EncodeError>
86    where
87        M: Encode<T::SendBuffer>,
88    {
89        let mut buffer = T::acquire(&self.sender);
90        encode_header::<T>(&mut buffer, txid, ordinal)?;
91        buffer.encode_next(message)?;
92        Ok(T::send(&self.sender, buffer))
93    }
94}
95
96impl<T: Transport> Clone for ClientSender<T> {
97    fn clone(&self) -> Self {
98        Self { shared: self.shared.clone(), sender: self.sender.clone() }
99    }
100}
101
/// The lifecycle state of a [`ResponseFuture`].
enum ResponseFutureState<'a, T: Transport> {
    // The request is still being sent over the transport.
    Sending(SendFuture<'a, T>),
    // The request was sent; waiting for the matching response to arrive.
    Receiving,
    // We store the completion state locally so that we can free the locker during poll, instead of
    // waiting until the future is dropped.
    Completed,
}
109
/// A future for a request pending a response.
pub struct ResponseFuture<'a, T: Transport> {
    // State shared with the `ClientSender` which issued the request.
    shared: &'a Shared<T>,
    // The locker index reserved for this request; the wire txid is `index + 1`.
    index: u32,
    // Where this request currently is in its send/receive lifecycle.
    state: ResponseFutureState<'a, T>,
}
116
117impl<T: Transport> Drop for ResponseFuture<'_, T> {
118    fn drop(&mut self) {
119        let mut responses = self.shared.responses.lock().unwrap();
120        match self.state {
121            // SAFETY: The future was canceled before it could be sent. The transaction ID was never
122            // used, so it's safe to immediately reuse.
123            ResponseFutureState::Sending(_) => responses.free(self.index),
124            ResponseFutureState::Receiving => {
125                if responses.get(self.index).unwrap().cancel() {
126                    responses.free(self.index);
127                }
128            }
129            // We already freed the slot when we completed.
130            ResponseFutureState::Completed => (),
131        }
132    }
133}
134
135impl<T: Transport> ResponseFuture<'_, T> {
136    fn poll_receiving(&mut self, cx: &mut Context<'_>) -> Poll<<Self as Future>::Output> {
137        if self.shared.is_closed.load(Ordering::Relaxed) {
138            self.state = ResponseFutureState::Completed;
139            return Poll::Ready(Err(None));
140        }
141
142        let mut responses = self.shared.responses.lock().unwrap();
143        if let Some(ready) = responses.get(self.index).unwrap().read(cx.waker()) {
144            responses.free(self.index);
145            self.state = ResponseFutureState::Completed;
146            Poll::Ready(Ok(ready))
147        } else {
148            Poll::Pending
149        }
150    }
151}
152
impl<T: Transport> Future for ResponseFuture<'_, T> {
    // `Ok` carries the raw response buffer. `Err(Some(e))` is a transport
    // error from sending the request; `Err(None)` means the client closed
    // before a response arrived (see `poll_receiving`).
    type Output = Result<T::RecvBuffer, Option<T::Error>>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // SAFETY: We treat the state as pinned as long as it is sending.
        let this = unsafe { Pin::into_inner_unchecked(self) };

        match &mut this.state {
            ResponseFutureState::Sending(future) => {
                // SAFETY: Because the state is sending, we always treat its future as pinned.
                let pinned = unsafe { Pin::new_unchecked(future) };
                match pinned.poll(cx) {
                    // The send has not completed yet. Leave the state as sending.
                    Poll::Pending => Poll::Pending,
                    Poll::Ready(Ok(())) => {
                        // The send completed successfully. Change the state to receiving and poll
                        // for receiving.
                        this.state = ResponseFutureState::Receiving;
                        this.poll_receiving(cx)
                    }
                    Poll::Ready(Err(e)) => {
                        // The send completed unsuccessfully. We can safely free the cell and set
                        // our state to completed.

                        this.shared.responses.lock().unwrap().free(this.index);
                        this.state = ResponseFutureState::Completed;
                        Poll::Ready(Err(Some(e)))
                    }
                }
            }
            ResponseFutureState::Receiving => this.poll_receiving(cx),
            // We could reach here if this future is polled after completion, but that's not
            // supposed to happen.
            ResponseFutureState::Completed => unreachable!(),
        }
    }
}
190
/// A type which handles incoming events for a client.
pub trait ClientHandler<T: Transport> {
    /// Handles a received client event.
    ///
    /// Events are messages whose transaction ID is zero; `buffer` has already
    /// had its message header decoded, yielding `ordinal`.
    ///
    /// The client cannot handle more messages until `on_event` completes. If `on_event` should
    /// handle requests in parallel, it should spawn a new async task and return.
    fn on_event(
        &mut self,
        sender: &ClientSender<T>,
        ordinal: u64,
        buffer: T::RecvBuffer,
    ) -> impl Future<Output = ()> + Send;
}
204
/// A client for an endpoint.
///
/// It must be actively polled to receive events and two-way message responses.
pub struct Client<T: Transport> {
    // The sender handed out to callers; also holds the shared response state.
    sender: ClientSender<T>,
    // The receive half produced by `Transport::split`, drained by `run`.
    receiver: T::Receiver,
}
212
213impl<T: Transport> Client<T> {
214    /// Creates a new client from a transport.
215    pub fn new(transport: T) -> Self {
216        let (sender, receiver) = transport.split();
217        let shared = Arc::new(Shared::new());
218        Self { sender: ClientSender { shared, sender }, receiver }
219    }
220
221    /// Returns the sender for the client.
222    pub fn sender(&self) -> &ClientSender<T> {
223        &self.sender
224    }
225
226    /// Runs the client with the provided handler.
227    pub async fn run<H>(&mut self, mut handler: H) -> Result<(), ProtocolError<T::Error>>
228    where
229        H: ClientHandler<T>,
230    {
231        let result = self.run_to_completion(&mut handler).await;
232        self.sender.shared.is_closed.store(true, Ordering::Relaxed);
233        self.sender.shared.responses.lock().unwrap().wake_all();
234
235        result
236    }
237
238    /// Runs the client with the [`IgnoreEvents`] handler.
239    pub async fn run_sender(&mut self) -> Result<(), ProtocolError<T::Error>> {
240        self.run(IgnoreEvents).await
241    }
242
243    async fn run_to_completion<H>(&mut self, handler: &mut H) -> Result<(), ProtocolError<T::Error>>
244    where
245        H: ClientHandler<T>,
246    {
247        while let Some(mut buffer) =
248            T::recv(&mut self.receiver).await.map_err(ProtocolError::TransportError)?
249        {
250            let (txid, ordinal) =
251                decode_header::<T>(&mut buffer).map_err(ProtocolError::InvalidMessageHeader)?;
252            if txid == 0 {
253                handler.on_event(&self.sender, ordinal, buffer).await;
254            } else {
255                let mut responses = self.sender.shared.responses.lock().unwrap();
256                let locker = responses
257                    .get(txid - 1)
258                    .ok_or_else(|| ProtocolError::UnrequestedResponse(txid))?;
259
260                match locker.write(ordinal, buffer) {
261                    // Reader didn't cancel
262                    Ok(false) => (),
263                    // Reader canceled, we can drop the entry
264                    Ok(true) => responses.free(txid - 1),
265                    Err(LockerError::NotWriteable) => {
266                        return Err(ProtocolError::UnrequestedResponse(txid));
267                    }
268                    Err(LockerError::MismatchedOrdinal { expected, actual }) => {
269                        return Err(ProtocolError::InvalidResponseOrdinal { expected, actual });
270                    }
271                }
272            }
273        }
274
275        self.sender.close();
276
277        Ok(())
278    }
279}
280
/// A client handler which ignores any incoming events.
///
/// Useful with [`Client::run`] (see [`Client::run_sender`]) when a protocol's
/// events are irrelevant to the caller.
// Derives added: a public API type should be `Debug`, and a stateless unit
// struct can freely be `Clone`, `Copy`, and `Default`.
#[derive(Clone, Copy, Debug, Default)]
pub struct IgnoreEvents;
283
284impl<T: Transport> ClientHandler<T> for IgnoreEvents {
285    async fn on_event(&mut self, _: &ClientSender<T>, _: u64, _: T::RecvBuffer) {}
286}