fidl_next_protocol/endpoints/
client.rs

// Copyright 2024 The Fuchsia Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

//! FIDL protocol clients.
6
7use core::future::Future;
8use core::pin::Pin;
9use core::task::{Context, Poll, ready};
10
11use fidl_next_codec::{Constrained, Encode, EncodeError, EncoderExt};
12use pin_project::{pin_project, pinned_drop};
13
14use crate::concurrency::sync::{Arc, Mutex};
15use crate::endpoints::connection::{Connection, ORDINAL_EPITAPH};
16use crate::endpoints::lockers::{LockerError, Lockers};
17use crate::{ProtocolError, SendFuture, Transport, decode_epitaph, decode_header, encode_header};
18
/// State shared between a [`ClientDispatcher`] and every [`Client`] handle
/// created from it.
struct ClientInner<T: Transport> {
    /// The connection built from the transport's shared half. It is used to
    /// send messages, stop the connection, and query its termination reason.
    connection: Connection<T>,
    /// Lockers for in-flight two-way transactions. The locker at index `i`
    /// corresponds to the transaction sent with txid `i + 1`.
    responses: Mutex<Lockers<T::RecvBuffer>>,
}
23
24impl<T: Transport> ClientInner<T> {
25    fn new(shared: T::Shared) -> Self {
26        Self { connection: Connection::new(shared), responses: Mutex::new(Lockers::new()) }
27    }
28}
29
/// A client endpoint.
///
/// A `Client` is a cheaply cloneable handle used to send one-way and two-way
/// messages over a connection. Handles are obtained from
/// [`ClientDispatcher::client`] or by cloning an existing `Client`.
pub struct Client<T: Transport> {
    /// State shared with the dispatcher and with every other clone of this
    /// client.
    inner: Arc<ClientInner<T>>,
}
34
35impl<T: Transport> Drop for Client<T> {
36    fn drop(&mut self) {
37        if Arc::strong_count(&self.inner) == 2 {
38            // This was the last reference to the connection other than the one
39            // in the dispatcher itself. Stop the connection.
40            self.close();
41        }
42    }
43}
44
45impl<T: Transport> Client<T> {
46    /// Closes the channel from the client end.
47    pub fn close(&self) {
48        self.inner.connection.stop();
49    }
50
51    /// Send a request.
52    pub fn send_one_way<M>(
53        &self,
54        ordinal: u64,
55        request: M,
56    ) -> Result<SendFuture<'_, T>, EncodeError>
57    where
58        M: Encode<T::SendBuffer>,
59        M::Encoded: Constrained<Constraint = ()>,
60    {
61        self.send_message(0, ordinal, request)
62    }
63
64    /// Send a request and await for a response.
65    pub fn send_two_way<M>(
66        &self,
67        ordinal: u64,
68        request: M,
69    ) -> Result<TwoWayRequestFuture<'_, T>, EncodeError>
70    where
71        M: Encode<T::SendBuffer>,
72        M::Encoded: Constrained<Constraint = ()>,
73    {
74        let index = self.inner.responses.lock().unwrap().alloc(ordinal);
75
76        // Send with txid = index + 1 because indices start at 0.
77        match self.send_message(index + 1, ordinal, request) {
78            Ok(send_future) => {
79                Ok(TwoWayRequestFuture { inner: &self.inner, index: Some(index), send_future })
80            }
81            Err(e) => {
82                self.inner.responses.lock().unwrap().free(index);
83                Err(e)
84            }
85        }
86    }
87
88    fn send_message<M>(
89        &self,
90        txid: u32,
91        ordinal: u64,
92        message: M,
93    ) -> Result<SendFuture<'_, T>, EncodeError>
94    where
95        M: Encode<T::SendBuffer>,
96        M::Encoded: Constrained<Constraint = ()>,
97    {
98        self.inner.connection.send_message(|buffer| {
99            encode_header::<T>(buffer, txid, ordinal)?;
100            buffer.encode_next(message, ())
101        })
102    }
103}
104
105impl<T: Transport> Clone for Client<T> {
106    fn clone(&self) -> Self {
107        Self { inner: self.inner.clone() }
108    }
109}
110
/// A future for a pending response to a two-way message.
///
/// Resolves to the received response buffer, or to the connection's
/// termination reason if the connection terminated before a response was
/// read.
pub struct TwoWayResponseFuture<'a, T: Transport> {
    /// Shared client state holding the response lockers and the connection.
    inner: &'a ClientInner<T>,
    /// Index of the locker reserved for this transaction. Set to `None` once
    /// the future has returned `Poll::Ready` and released the locker.
    index: Option<u32>,
}
116
117impl<T: Transport> Drop for TwoWayResponseFuture<'_, T> {
118    fn drop(&mut self) {
119        // If `index` is `Some`, then we still need to free our locker.
120        if let Some(index) = self.index {
121            let mut responses = self.inner.responses.lock().unwrap();
122            if responses.get(index).unwrap().cancel() {
123                responses.free(index);
124            }
125        }
126    }
127}
128
129impl<T: Transport> Future for TwoWayResponseFuture<'_, T> {
130    type Output = Result<T::RecvBuffer, ProtocolError<T::Error>>;
131
132    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
133        let this = Pin::into_inner(self);
134        let Some(index) = this.index else {
135            panic!("TwoWayResponseFuture polled after returning `Poll::Ready`");
136        };
137
138        let mut responses = this.inner.responses.lock().unwrap();
139        let ready = if let Some(ready) = responses.get(index).unwrap().read(cx.waker()) {
140            Ok(ready)
141        } else if let Some(termination_reason) = this.inner.connection.get_termination_reason() {
142            Err(termination_reason)
143        } else {
144            return Poll::Pending;
145        };
146
147        responses.free(index);
148        this.index = None;
149        Poll::Ready(ready)
150    }
151}
152
/// A future for sending a two-way FIDL message.
///
/// Resolves to a [`TwoWayResponseFuture`] once the send has completed
/// successfully, or to an error if the send failed.
#[pin_project(PinnedDrop)]
pub struct TwoWayRequestFuture<'a, T: Transport> {
    /// Shared client state holding the response lockers.
    inner: &'a ClientInner<T>,
    /// Index of the locker reserved for this transaction. Set to `None` once
    /// the send has completed and responsibility for the locker has moved on.
    index: Option<u32>,
    /// The in-progress send of the already-encoded request.
    #[pin]
    send_future: SendFuture<'a, T>,
}
161
162#[pinned_drop]
163impl<T: Transport> PinnedDrop for TwoWayRequestFuture<'_, T> {
164    fn drop(self: Pin<&mut Self>) {
165        if let Some(index) = self.index {
166            let mut responses = self.inner.responses.lock().unwrap();
167
168            // The future was canceled before it could be sent. The transaction
169            // ID was never used, so it's safe to immediately reuse.
170            responses.free(index);
171        }
172    }
173}
174
175impl<'a, T: Transport> Future for TwoWayRequestFuture<'a, T> {
176    type Output = Result<TwoWayResponseFuture<'a, T>, ProtocolError<T::Error>>;
177
178    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
179        let this = self.project();
180
181        let Some(index) = *this.index else {
182            panic!("TwoWayRequestFuture polled after returning `Poll::Ready`");
183        };
184
185        let result = ready!(this.send_future.poll(cx));
186        *this.index = None;
187        if let Err(error) = result {
188            // The send failed. Free the locker and return an error.
189            this.inner.responses.lock().unwrap().free(index);
190            Poll::Ready(Err(error))
191        } else {
192            Poll::Ready(Ok(TwoWayResponseFuture { inner: this.inner, index: Some(index) }))
193        }
194    }
195}
196
/// A type which handles incoming events for a client.
///
/// The [`ClientDispatcher`] hands every received non-epitaph message whose
/// transaction ID is zero to [`on_event`](ClientHandler::on_event).
pub trait ClientHandler<T: Transport> {
    /// Handles a received client event.
    ///
    /// `ordinal` is the ordinal decoded from the message header and `buffer`
    /// is the receive buffer after the header has been decoded from it.
    ///
    /// Returning an error stops the dispatcher's run loop, which then
    /// terminates the connection with that error.
    ///
    /// The client cannot handle more messages until `on_event` completes. If
    /// `on_event` should handle requests in parallel, it should spawn a new
    /// async task and return.
    fn on_event(
        &mut self,
        ordinal: u64,
        buffer: T::RecvBuffer,
    ) -> impl Future<Output = Result<(), ProtocolError<T::Error>>> + Send;
}
211
/// A dispatcher for a client endpoint.
///
/// A client dispatcher receives all of the incoming messages and dispatches them to the client
/// handler and two-way futures. It acts as the message pump for the client.
///
/// The dispatcher must be actively polled to receive events and two-way message responses. If the
/// dispatcher is not [`run`](ClientDispatcher::run) concurrently, then events will not be received
/// and two-way message futures will not receive their responses.
pub struct ClientDispatcher<T: Transport> {
    /// State shared with every [`Client`] handle created from this dispatcher.
    inner: Arc<ClientInner<T>>,
    /// The exclusive half of the split transport, used for receiving.
    exclusive: T::Exclusive,
    /// Whether the connection has already been terminated. Checked on drop so
    /// that the connection is terminated at most once.
    is_terminated: bool,
}
225
226impl<T: Transport> Drop for ClientDispatcher<T> {
227    fn drop(&mut self) {
228        if !self.is_terminated {
229            // SAFETY: We checked that the connection has not been terminated.
230            unsafe {
231                self.terminate(ProtocolError::Stopped);
232            }
233        }
234    }
235}
236
237impl<T: Transport> ClientDispatcher<T> {
238    /// Creates a new client from a transport.
239    pub fn new(transport: T) -> Self {
240        let (shared, exclusive) = transport.split();
241        Self { inner: Arc::new(ClientInner::new(shared)), exclusive, is_terminated: false }
242    }
243
244    /// # Safety
245    ///
246    /// The connection must not yet be terminated.
247    unsafe fn terminate(&mut self, error: ProtocolError<T::Error>) {
248        // SAFETY: We checked that the connection has not been terminated.
249        unsafe {
250            self.inner.connection.terminate(error);
251        }
252        self.inner.responses.lock().unwrap().wake_all();
253    }
254
255    /// Returns a client for the dispatcher.
256    ///
257    /// When the last `Client` is dropped, the dispatcher will be stopped.
258    pub fn client(&self) -> Client<T> {
259        Client { inner: self.inner.clone() }
260    }
261
262    /// Runs the client with the provided handler.
263    pub async fn run<H>(mut self, mut handler: H) -> Result<H, ProtocolError<T::Error>>
264    where
265        H: ClientHandler<T>,
266    {
267        // We may assume that the connection has not been terminated because
268        // connections are only terminated by `run` and `drop`. Neither of those
269        // could have been called before this method because `run` consumes
270        // `self` and `drop` is only ever called once.
271
272        let error = loop {
273            // SAFETY: The connection has not been terminated.
274            let result = unsafe { self.run_one(&mut handler).await };
275            if let Err(error) = result {
276                break error;
277            }
278        };
279
280        // SAFETY: The connection has not been terminated.
281        unsafe {
282            self.terminate(error.clone());
283        }
284        self.is_terminated = true;
285
286        match error {
287            // We consider clients to have finished successfully only if they
288            // stop themselves manually.
289            ProtocolError::Stopped => Ok(handler),
290
291            // Otherwise, the client finished with an error.
292            _ => Err(error),
293        }
294    }
295
296    /// # Safety
297    ///
298    /// The connection must not be terminated.
299    async unsafe fn run_one<H>(&mut self, handler: &mut H) -> Result<(), ProtocolError<T::Error>>
300    where
301        H: ClientHandler<T>,
302    {
303        // SAFETY: The caller guaranteed that the connection is not terminated.
304        let mut buffer = unsafe { self.inner.connection.recv(&mut self.exclusive).await? };
305
306        let (txid, ordinal) =
307            decode_header::<T>(&mut buffer).map_err(ProtocolError::InvalidMessageHeader)?;
308
309        if ordinal == ORDINAL_EPITAPH {
310            let epitaph =
311                decode_epitaph::<T>(&mut buffer).map_err(ProtocolError::InvalidEpitaphBody)?;
312            return Err(ProtocolError::PeerClosedWithEpitaph(epitaph));
313        } else if txid == 0 {
314            handler.on_event(ordinal, buffer).await?;
315        } else {
316            let mut responses = self.inner.responses.lock().unwrap();
317            let locker = responses
318                .get(txid - 1)
319                .ok_or_else(|| ProtocolError::UnrequestedResponse { txid })?;
320
321            match locker.write(ordinal, buffer) {
322                // Reader didn't cancel
323                Ok(false) => (),
324                // Reader canceled, we can drop the entry
325                Ok(true) => responses.free(txid - 1),
326                Err(LockerError::NotWriteable) => {
327                    return Err(ProtocolError::UnrequestedResponse { txid });
328                }
329                Err(LockerError::MismatchedOrdinal { expected, actual }) => {
330                    return Err(ProtocolError::InvalidResponseOrdinal { expected, actual });
331                }
332            }
333        }
334
335        Ok(())
336    }
337
338    /// Runs the client with the [`IgnoreEvents`] handler.
339    pub async fn run_client(self) -> Result<(), ProtocolError<T::Error>> {
340        self.run(IgnoreEvents).await.map(|_| ())
341    }
342}
343
/// A client handler which ignores any incoming events.
///
/// Every event is accepted and discarded without inspection.
pub struct IgnoreEvents;
346
347impl<T: Transport> ClientHandler<T> for IgnoreEvents {
348    async fn on_event(&mut self, _: u64, _: T::RecvBuffer) -> Result<(), ProtocolError<T::Error>> {
349        Ok(())
350    }
351}