Skip to main content

fidl_next_protocol/endpoints/
client.rs

1// Copyright 2024 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5//! FIDL protocol clients.
6
7use core::future::Future;
8use core::mem::ManuallyDrop;
9use core::pin::Pin;
10use core::ptr;
11use core::task::{Context, Poll, ready};
12
13use fidl_constants::EPITAPH_ORDINAL;
14use fidl_next_codec::{AsDecoder as _, DecoderExt as _, Encode, EncodeError, EncoderExt, Wire};
15use pin_project::{pin_project, pinned_drop};
16
17use crate::concurrency::sync::{Arc, Mutex};
18use crate::endpoints::connection::{Connection, SendFutureState};
19use crate::endpoints::lockers::{LockerError, Lockers};
20use crate::wire::{Epitaph, MessageHeader};
21use crate::{Body, Flexibility, NonBlockingTransport, ProtocolError, SendFuture, Transport};
22
23struct ClientInner<T: Transport> {
24    connection: Connection<T>,
25    responses: Mutex<Lockers<Body<T>>>,
26}
27
28impl<T: Transport> ClientInner<T> {
29    fn new(shared: T::Shared) -> Self {
30        Self { connection: Connection::new(shared), responses: Mutex::new(Lockers::new()) }
31    }
32}
33
34/// A client endpoint.
35pub struct Client<T: Transport> {
36    inner: Arc<ClientInner<T>>,
37}
38
39impl<T: Transport> Drop for Client<T> {
40    fn drop(&mut self) {
41        if Arc::strong_count(&self.inner) == 2 {
42            // This was the last reference to the connection other than the one
43            // in the dispatcher itself. Stop the connection.
44            self.close();
45        }
46    }
47}
48
49impl<T: Transport> Client<T> {
50    /// Closes the channel from the client end.
51    pub fn close(&self) {
52        self.inner.connection.stop();
53    }
54
55    /// Send a request.
56    pub fn send_one_way<W>(
57        &self,
58        ordinal: u64,
59        flexibility: Flexibility,
60        request: impl Encode<W, T::SendBuffer>,
61    ) -> Result<SendFuture<'_, T>, EncodeError>
62    where
63        W: Wire<Constraint = ()>,
64    {
65        Ok(SendFuture::from_raw_parts(
66            &self.inner.connection,
67            self.send_message_raw(0, ordinal, flexibility, request)?,
68        ))
69    }
70
71    /// Send a request and await for a response.
72    pub fn send_two_way<W>(
73        &self,
74        ordinal: u64,
75        flexibility: Flexibility,
76        request: impl Encode<W, T::SendBuffer>,
77    ) -> Result<TwoWayRequestFuture<'_, T>, EncodeError>
78    where
79        W: Wire<Constraint = ()>,
80    {
81        let index = self.inner.responses.lock().unwrap().alloc(ordinal);
82
83        // Send with txid = index + 1 because indices start at 0.
84        match self.send_message_raw(index + 1, ordinal, flexibility, request) {
85            Ok(state) => Ok(TwoWayRequestFuture { inner: &self.inner, index: Some(index), state }),
86            Err(e) => {
87                self.inner.responses.lock().unwrap().free(index);
88                Err(e)
89            }
90        }
91    }
92
93    fn send_message_raw<W>(
94        &self,
95        txid: u32,
96        ordinal: u64,
97        flexibility: Flexibility,
98        message: impl Encode<W, T::SendBuffer>,
99    ) -> Result<SendFutureState<T>, EncodeError>
100    where
101        W: Wire<Constraint = ()>,
102    {
103        self.inner.connection.send_message_raw(|buffer| {
104            buffer.encode_next(MessageHeader::new(txid, ordinal, flexibility))?;
105            buffer.encode_next(message)
106        })
107    }
108}
109
110impl<T: Transport> Clone for Client<T> {
111    fn clone(&self) -> Self {
112        Self { inner: self.inner.clone() }
113    }
114}
115
116/// A future for a pending response to a two-way message.
117pub struct TwoWayResponseFuture<'a, T: Transport> {
118    inner: &'a ClientInner<T>,
119    index: Option<u32>,
120}
121
122impl<T: Transport> Drop for TwoWayResponseFuture<'_, T> {
123    fn drop(&mut self) {
124        // If `index` is `Some`, then we still need to free our locker.
125        if let Some(index) = self.index {
126            let mut responses = self.inner.responses.lock().unwrap();
127            if responses.get(index).unwrap().cancel() {
128                responses.free(index);
129            }
130        }
131    }
132}
133
134impl<T: Transport> Future for TwoWayResponseFuture<'_, T> {
135    type Output = Result<Body<T>, ProtocolError<T::Error>>;
136
137    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
138        let this = Pin::into_inner(self);
139        let Some(index) = this.index else {
140            panic!("TwoWayResponseFuture polled after returning `Poll::Ready`");
141        };
142
143        let mut responses = this.inner.responses.lock().unwrap();
144        let ready = if let Some(ready) = responses.get(index).unwrap().read(cx.waker()) {
145            Ok(ready)
146        } else if let Some(termination_reason) = this.inner.connection.get_termination_reason() {
147            Err(termination_reason)
148        } else {
149            return Poll::Pending;
150        };
151
152        responses.free(index);
153        this.index = None;
154        Poll::Ready(ready)
155    }
156}
157
158/// A future for a sending a two-way FIDL message.
159#[pin_project(PinnedDrop)]
160pub struct TwoWayRequestFuture<'a, T: Transport> {
161    inner: &'a ClientInner<T>,
162    index: Option<u32>,
163    #[pin]
164    state: SendFutureState<T>,
165}
166
167#[pinned_drop]
168impl<T: Transport> PinnedDrop for TwoWayRequestFuture<'_, T> {
169    fn drop(self: Pin<&mut Self>) {
170        if let Some(index) = self.index {
171            let mut responses = self.inner.responses.lock().unwrap();
172
173            // The future was canceled before it could be sent. The transaction
174            // ID was never used, so it's safe to immediately reuse.
175            responses.free(index);
176        }
177    }
178}
179
180impl<'a, T: Transport> Future for TwoWayRequestFuture<'a, T> {
181    type Output = Result<TwoWayResponseFuture<'a, T>, ProtocolError<T::Error>>;
182
183    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
184        let this = self.project();
185
186        let Some(index) = *this.index else {
187            panic!("TwoWayRequestFuture polled after returning `Poll::Ready`");
188        };
189
190        let result = ready!(this.state.poll_send(cx, &this.inner.connection));
191        *this.index = None;
192        if let Err(error) = result {
193            // The send failed. Free the locker and return an error.
194            this.inner.responses.lock().unwrap().free(index);
195            Poll::Ready(Err(error))
196        } else {
197            Poll::Ready(Ok(TwoWayResponseFuture { inner: this.inner, index: Some(index) }))
198        }
199    }
200}
201
202impl<'a, T: NonBlockingTransport> TwoWayRequestFuture<'a, T> {
203    /// Completes the send operation synchronously and without blocking.
204    ///
205    /// Using this method prevents transports from applying backpressure. Prefer
206    /// awaiting when possible to allow for backpressure.
207    ///
208    /// Because failed sends return immediately, `send_immediately` may observe
209    /// transport closure prematurely. This can manifest as this method
210    /// returning `Err(PeerClosed)` or `Err(Stopped)` when it should have
211    /// returned `Err(PeerClosedWithEpitaph)`. Prefer awaiting when possible for
212    /// correctness.
213    pub fn send_immediately(self) -> Result<TwoWayResponseFuture<'a, T>, ProtocolError<T::Error>> {
214        let inner = self.inner;
215        let index = self.index;
216        let state = unsafe { ptr::read(&ManuallyDrop::new(self).state) };
217        if let Err(e) = state.send_immediately(&inner.connection) {
218            inner.responses.lock().unwrap().free(index.unwrap());
219            return Err(e);
220        }
221
222        Ok(TwoWayResponseFuture { inner, index })
223    }
224}
225
226/// A type which handles incoming events for a local client.
227///
228/// This is a variant of [`ClientHandler`] that does not require implementing
229/// `Send` and only supports local-thread executors.
230pub trait LocalClientHandler<T: Transport> {
231    /// Handles a received client event.
232    ///
233    /// See [`ClientHandler::on_event`] for more information.
234    fn on_event(
235        &mut self,
236        ordinal: u64,
237        flexibility: Flexibility,
238        body: Body<T>,
239    ) -> impl Future<Output = Result<(), ProtocolError<T::Error>>>;
240}
241
242/// A type which handles incoming events for a client.
243pub trait ClientHandler<T: Transport>: Send {
244    /// Handles a received client event.
245    ///
246    /// The client cannot handle more messages until `on_event` completes. If
247    /// `on_event` may block, or would perform asynchronous work that takes a
248    /// long time, it should offload work to an async task and return.
249    fn on_event(
250        &mut self,
251        ordinal: u64,
252        flexibility: Flexibility,
253        body: Body<T>,
254    ) -> impl Future<Output = Result<(), ProtocolError<T::Error>>> + Send;
255}
256
257/// An adapter for a [`ClientHandler`] which implements [`LocalClientHandler`].
258#[repr(transparent)]
259pub struct ClientHandlerToLocalAdapter<H>(H);
260
261impl<T, H> LocalClientHandler<T> for ClientHandlerToLocalAdapter<H>
262where
263    T: Transport,
264    H: ClientHandler<T>,
265{
266    #[inline]
267    fn on_event(
268        &mut self,
269        ordinal: u64,
270        flexibility: Flexibility,
271        body: Body<T>,
272    ) -> impl Future<Output = Result<(), ProtocolError<T::Error>>> {
273        self.0.on_event(ordinal, flexibility, body)
274    }
275}
276
277/// A dispatcher for a client endpoint.
278///
279/// A client dispatcher receives all of the incoming messages and dispatches them to the client
280/// handler and two-way futures. It acts as the message pump for the client.
281///
282/// The dispatcher must be actively polled to receive events and two-way message responses. If the
283/// dispatcher is not [`run`](ClientDispatcher::run) concurrently, then events will not be received
284/// and two-way message futures will not receive their responses.
285pub struct ClientDispatcher<T: Transport> {
286    inner: Arc<ClientInner<T>>,
287    exclusive: T::Exclusive,
288    is_terminated: bool,
289}
290
291impl<T: Transport> Drop for ClientDispatcher<T> {
292    fn drop(&mut self) {
293        if !self.is_terminated {
294            // SAFETY: We checked that the connection has not been terminated.
295            unsafe {
296                self.terminate(ProtocolError::Stopped);
297            }
298        }
299    }
300}
301
302impl<T: Transport> ClientDispatcher<T> {
303    /// Creates a new client from a transport.
304    pub fn new(transport: T) -> Self {
305        let (shared, exclusive) = transport.split();
306        Self { inner: Arc::new(ClientInner::new(shared)), exclusive, is_terminated: false }
307    }
308
309    /// # Safety
310    ///
311    /// The connection must not yet be terminated.
312    unsafe fn terminate(&mut self, error: ProtocolError<T::Error>) {
313        // SAFETY: We checked that the connection has not been terminated.
314        unsafe {
315            self.inner.connection.terminate(error);
316        }
317        self.inner.responses.lock().unwrap().wake_all();
318    }
319
320    /// Returns a client for the dispatcher.
321    ///
322    /// When the last `Client` is dropped, the dispatcher will be stopped.
323    pub fn client(&self) -> Client<T> {
324        Client { inner: self.inner.clone() }
325    }
326
327    /// Runs the client with the provided handler.
328    pub async fn run<H>(self, handler: H) -> Result<H, ProtocolError<T::Error>>
329    where
330        H: ClientHandler<T>,
331    {
332        // The bounds on `H` prove that the future returned by `run_local` is
333        // `Send`.
334        self.run_local(ClientHandlerToLocalAdapter(handler)).await.map(|adapter| adapter.0)
335    }
336
337    /// Runs the client with the provided handler.
338    pub async fn run_local<H>(mut self, mut handler: H) -> Result<H, ProtocolError<T::Error>>
339    where
340        H: LocalClientHandler<T>,
341    {
342        // We may assume that the connection has not been terminated because
343        // connections are only terminated by `run` and `drop`. Neither of those
344        // could have been called before this method because `run` consumes
345        // `self` and `drop` is only ever called once.
346
347        let error = loop {
348            // SAFETY: The connection has not been terminated.
349            let result = unsafe { self.run_one(&mut handler).await };
350            if let Err(error) = result {
351                break error;
352            }
353        };
354
355        // SAFETY: The connection has not been terminated.
356        unsafe {
357            self.terminate(error.clone());
358        }
359        self.is_terminated = true;
360
361        match error {
362            // We consider clients to have finished successfully only if they
363            // stop themselves manually.
364            ProtocolError::Stopped => Ok(handler),
365
366            // Otherwise, the client finished with an error.
367            _ => Err(error),
368        }
369    }
370
371    /// # Safety
372    ///
373    /// The connection must not be terminated.
374    async unsafe fn run_one<H>(&mut self, handler: &mut H) -> Result<(), ProtocolError<T::Error>>
375    where
376        H: LocalClientHandler<T>,
377    {
378        // SAFETY: The caller guaranteed that the connection is not terminated.
379        let mut buffer = unsafe { self.inner.connection.recv(&mut self.exclusive).await? };
380
381        // This expression is really awkward due to a limitation in rustc's
382        // liveness analysis for local variables. We need to avoid holding
383        // `decoder` across `.await`s because it may not be `Send` and tasks may
384        // migrate threads between polls. We should be able to just
385        // `drop(decoder)` before any `.await`, but rustc is overly conservative
386        // and still considers `decoder` as live at the `.await` for that
387        // analysis. The only way to convince rustc that `decoder` is not live
388        // at that await point is to keep the lexical scope containing `decoder`
389        // free of `.await`s.
390        //
391        // See https://github.com/rust-lang/rust/issues/63768 for more details.
392        let header = {
393            let mut decoder = buffer.as_decoder();
394
395            let header = decoder
396                .decode_prefix::<MessageHeader>()
397                .map_err(ProtocolError::InvalidMessageHeader)?;
398
399            // Check if the ordinal is the epitaph so we can immediately decode
400            // and return it. We do this before dropping `decoder` so that we
401            // don't have to re-acquire it and wrap it in `Body`.
402            if header.ordinal == EPITAPH_ORDINAL {
403                let epitaph =
404                    decoder.decode::<Epitaph>().map_err(ProtocolError::InvalidEpitaphBody)?;
405                return Err(ProtocolError::PeerClosedWithEpitaph(*epitaph.error));
406            }
407
408            header
409        };
410
411        if header.txid == 0 {
412            handler.on_event(*header.ordinal, header.flexibility(), Body::new(buffer)).await?;
413        } else {
414            let mut responses = self.inner.responses.lock().unwrap();
415            let locker = responses
416                .get(*header.txid - 1)
417                .ok_or_else(|| ProtocolError::UnrequestedResponse { txid: *header.txid })?;
418
419            match locker.write(*header.ordinal, Body::new(buffer)) {
420                // Reader didn't cancel
421                Ok(false) => (),
422                // Reader canceled, we can drop the entry
423                Ok(true) => responses.free(*header.txid - 1),
424                Err(LockerError::NotWriteable) => {
425                    return Err(ProtocolError::UnrequestedResponse { txid: *header.txid });
426                }
427                Err(LockerError::MismatchedOrdinal { expected, actual }) => {
428                    return Err(ProtocolError::InvalidResponseOrdinal { expected, actual });
429                }
430            }
431        }
432
433        Ok(())
434    }
435}