Skip to main content

fidl_next_protocol/endpoints/
client.rs

1// Copyright 2024 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5//! FIDL protocol clients.
6
7use core::future::Future;
8use core::mem::ManuallyDrop;
9use core::pin::Pin;
10use core::ptr;
11use core::task::{Context, Poll, ready};
12
13use fidl_constants::EPITAPH_ORDINAL;
14use fidl_next_codec::{
15    AsDecoder as _, DecoderExt as _, Encode, EncodeError, EncoderExt, Wire, wire,
16};
17use fuchsia_loom::sync::{Arc, Mutex};
18use pin_project::{pin_project, pinned_drop};
19
20use crate::endpoints::connection::{Connection, SendFutureState};
21use crate::endpoints::lockers::{LockerError, Lockers};
22use crate::wire::{Epitaph, MessageHeader};
23use crate::{
24    Flexibility, FrameworkError, Message, NonBlockingTransport, ProtocolError, SendFuture,
25    Transport,
26};
27
28const ORD_FRAMEWORK_ERR: u64 = 3;
29const ERR_UNKNOWN_METHOD: i32 = FrameworkError::UnknownMethod as i32;
30
31struct ClientInner<T: Transport> {
32    connection: Connection<T>,
33    responses: Mutex<Lockers<Message<T>>>,
34}
35
36impl<T: Transport> ClientInner<T> {
37    fn new(shared: T::Shared) -> Self {
38        Self { connection: Connection::new(shared), responses: Mutex::new(Lockers::new()) }
39    }
40}
41
42/// A client endpoint.
43pub struct Client<T: Transport> {
44    inner: Arc<ClientInner<T>>,
45}
46
47impl<T: Transport> Drop for Client<T> {
48    fn drop(&mut self) {
49        if Arc::strong_count(&self.inner) == 2 {
50            // This was the last reference to the connection other than the one
51            // in the dispatcher itself. Stop the connection.
52            self.close();
53        }
54    }
55}
56
57impl<T: Transport> Client<T> {
58    /// Closes the channel from the client end.
59    pub fn close(&self) {
60        self.inner.connection.stop();
61    }
62
63    /// Send a request.
64    pub fn send_one_way<W>(
65        &self,
66        ordinal: u64,
67        flexibility: Flexibility,
68        request: impl Encode<W, T::SendBuffer>,
69    ) -> Result<SendFuture<'_, T>, EncodeError>
70    where
71        W: Wire<Constraint = ()>,
72    {
73        Ok(SendFuture::from_raw_parts(
74            &self.inner.connection,
75            self.send_message_raw(0, ordinal, flexibility, request)?,
76        ))
77    }
78
79    /// Send a request and await for a response.
80    pub fn send_two_way<W>(
81        &self,
82        ordinal: u64,
83        flexibility: Flexibility,
84        request: impl Encode<W, T::SendBuffer>,
85    ) -> Result<TwoWayRequestFuture<'_, T>, EncodeError>
86    where
87        W: Wire<Constraint = ()>,
88    {
89        let index = self.inner.responses.lock().unwrap().alloc(ordinal);
90
91        // Send with txid = index + 1 because indices start at 0.
92        match self.send_message_raw(index + 1, ordinal, flexibility, request) {
93            Ok(state) => Ok(TwoWayRequestFuture { inner: &self.inner, index: Some(index), state }),
94            Err(e) => {
95                self.inner.responses.lock().unwrap().free(index);
96                Err(e)
97            }
98        }
99    }
100
101    fn send_message_raw<W>(
102        &self,
103        txid: u32,
104        ordinal: u64,
105        flexibility: Flexibility,
106        message: impl Encode<W, T::SendBuffer>,
107    ) -> Result<SendFutureState<T>, EncodeError>
108    where
109        W: Wire<Constraint = ()>,
110    {
111        self.inner.connection.send_message_raw(|buffer| {
112            buffer.encode_next(MessageHeader::new(txid, ordinal, flexibility))?;
113            buffer.encode_next(message)
114        })
115    }
116}
117
118impl<T: Transport> Clone for Client<T> {
119    fn clone(&self) -> Self {
120        Self { inner: self.inner.clone() }
121    }
122}
123
124/// A future for a pending response to a two-way message.
125pub struct TwoWayResponseFuture<'a, T: Transport> {
126    inner: &'a ClientInner<T>,
127    index: Option<u32>,
128}
129
130impl<T: Transport> Drop for TwoWayResponseFuture<'_, T> {
131    fn drop(&mut self) {
132        // If `index` is `Some`, then we still need to free our locker.
133        if let Some(index) = self.index {
134            let mut responses = self.inner.responses.lock().unwrap();
135            if responses.get(index).unwrap().cancel() {
136                responses.free(index);
137            }
138        }
139    }
140}
141
142fn handle_flexible_response<T: Transport>(
143    mut message: Message<T>,
144) -> Result<Message<T>, ProtocolError<T::Error>> {
145    if matches!(message.header().flexibility(), Flexibility::Flexible) {
146        let mut decoder = message.as_decoder();
147        let mut union = decoder
148            .take_slot::<wire::Union>()
149            .map_err(|_| ProtocolError::InvalidFlexibleResponse)?;
150        let ordinal = wire::Union::encoded_ordinal(union.as_mut());
151        if ordinal == ORD_FRAMEWORK_ERR {
152            wire::Union::decode_as::<_, wire::Int32>(union.as_mut(), &mut decoder, ())
153                .map_err(|_| ProtocolError::InvalidFlexibleResponse)?;
154            let union = unsafe { union.deref_unchecked() };
155            let error = unsafe { union.get().read_unchecked::<wire::Int32>() };
156            match error.0 {
157                ERR_UNKNOWN_METHOD => {
158                    return Err(ProtocolError::FrameworkError(FrameworkError::UnknownMethod));
159                }
160                ordinal => return Err(ProtocolError::UnknownFrameworkError { ordinal }),
161            }
162        }
163    }
164
165    Ok(message)
166}
167
168impl<T: Transport> Future for TwoWayResponseFuture<'_, T> {
169    type Output = Result<Message<T>, ProtocolError<T::Error>>;
170
171    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
172        let this = Pin::into_inner(self);
173        let Some(index) = this.index else {
174            panic!("TwoWayResponseFuture polled after returning `Poll::Ready`");
175        };
176
177        let mut responses = this.inner.responses.lock().unwrap();
178        let message = if let Some(message) = responses.get(index).unwrap().read(cx.waker()) {
179            handle_flexible_response(message)
180        } else if let Some(termination_reason) = this.inner.connection.get_termination_reason() {
181            Err(termination_reason)
182        } else {
183            return Poll::Pending;
184        };
185
186        responses.free(index);
187        this.index = None;
188
189        Poll::Ready(message)
190    }
191}
192
193/// A future for a sending a two-way FIDL message.
194#[pin_project(PinnedDrop)]
195pub struct TwoWayRequestFuture<'a, T: Transport> {
196    inner: &'a ClientInner<T>,
197    index: Option<u32>,
198    #[pin]
199    state: SendFutureState<T>,
200}
201
202#[pinned_drop]
203impl<T: Transport> PinnedDrop for TwoWayRequestFuture<'_, T> {
204    fn drop(self: Pin<&mut Self>) {
205        if let Some(index) = self.index {
206            let mut responses = self.inner.responses.lock().unwrap();
207
208            // The future was canceled before it could be sent. The transaction
209            // ID was never used, so it's safe to immediately reuse.
210            responses.free(index);
211        }
212    }
213}
214
215impl<'a, T: Transport> Future for TwoWayRequestFuture<'a, T> {
216    type Output = Result<TwoWayResponseFuture<'a, T>, ProtocolError<T::Error>>;
217
218    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
219        let this = self.project();
220
221        let Some(index) = *this.index else {
222            panic!("TwoWayRequestFuture polled after returning `Poll::Ready`");
223        };
224
225        let result = ready!(this.state.poll_send(cx, &this.inner.connection));
226        *this.index = None;
227        if let Err(error) = result {
228            // The send failed. Free the locker and return an error.
229            this.inner.responses.lock().unwrap().free(index);
230            Poll::Ready(Err(error))
231        } else {
232            Poll::Ready(Ok(TwoWayResponseFuture { inner: this.inner, index: Some(index) }))
233        }
234    }
235}
236
237impl<'a, T: NonBlockingTransport> TwoWayRequestFuture<'a, T> {
238    /// Completes the send operation synchronously and without blocking.
239    ///
240    /// Using this method prevents transports from applying backpressure. Prefer
241    /// awaiting when possible to allow for backpressure.
242    ///
243    /// Because failed sends return immediately, `send_immediately` may observe
244    /// transport closure prematurely. This can manifest as this method
245    /// returning `Err(PeerClosed)` or `Err(Stopped)` when it should have
246    /// returned `Err(PeerClosedWithEpitaph)`. Prefer awaiting when possible for
247    /// correctness.
248    pub fn send_immediately(self) -> Result<TwoWayResponseFuture<'a, T>, ProtocolError<T::Error>> {
249        let inner = self.inner;
250        let index = self.index;
251        let state = unsafe { ptr::read(&ManuallyDrop::new(self).state) };
252        if let Err(e) = state.send_immediately(&inner.connection) {
253            inner.responses.lock().unwrap().free(index.unwrap());
254            return Err(e);
255        }
256
257        Ok(TwoWayResponseFuture { inner, index })
258    }
259}
260
261/// A type which handles incoming events for a local client.
262///
263/// This is a variant of [`ClientHandler`] that does not require implementing
264/// `Send` and only supports local-thread executors.
265pub trait LocalClientHandler<T: Transport> {
266    /// Handles a received client event.
267    ///
268    /// See [`ClientHandler::on_event`] for more information.
269    fn on_event(
270        &mut self,
271        message: Message<T>,
272    ) -> impl Future<Output = Result<(), ProtocolError<T::Error>>>;
273}
274
275/// A type which handles incoming events for a client.
276pub trait ClientHandler<T: Transport>: Send {
277    /// Handles a received client event.
278    ///
279    /// The client cannot handle more messages until `on_event` completes. If
280    /// `on_event` may block, or would perform asynchronous work that takes a
281    /// long time, it should offload work to an async task and return.
282    fn on_event(
283        &mut self,
284        message: Message<T>,
285    ) -> impl Future<Output = Result<(), ProtocolError<T::Error>>> + Send;
286}
287
288/// An adapter for a [`ClientHandler`] which implements [`LocalClientHandler`].
289#[repr(transparent)]
290pub struct ClientHandlerToLocalAdapter<H>(H);
291
292impl<T, H> LocalClientHandler<T> for ClientHandlerToLocalAdapter<H>
293where
294    T: Transport,
295    H: ClientHandler<T>,
296{
297    #[inline]
298    fn on_event(
299        &mut self,
300        message: Message<T>,
301    ) -> impl Future<Output = Result<(), ProtocolError<T::Error>>> {
302        self.0.on_event(message)
303    }
304}
305
306/// A dispatcher for a client endpoint.
307///
308/// A client dispatcher receives all of the incoming messages and dispatches them to the client
309/// handler and two-way futures. It acts as the message pump for the client.
310///
311/// The dispatcher must be actively polled to receive events and two-way message responses. If the
312/// dispatcher is not [`run`](ClientDispatcher::run) concurrently, then events will not be received
313/// and two-way message futures will not receive their responses.
314pub struct ClientDispatcher<T: Transport> {
315    inner: Arc<ClientInner<T>>,
316    exclusive: T::Exclusive,
317    is_terminated: bool,
318}
319
320impl<T: Transport> Drop for ClientDispatcher<T> {
321    fn drop(&mut self) {
322        if !self.is_terminated {
323            // SAFETY: We checked that the connection has not been terminated.
324            unsafe {
325                self.terminate(ProtocolError::Stopped);
326            }
327        }
328    }
329}
330
331impl<T: Transport> ClientDispatcher<T> {
332    /// Creates a new client from a transport.
333    pub fn new(transport: T) -> Self {
334        let (shared, exclusive) = transport.split();
335        Self { inner: Arc::new(ClientInner::new(shared)), exclusive, is_terminated: false }
336    }
337
338    /// # Safety
339    ///
340    /// The connection must not yet be terminated.
341    unsafe fn terminate(&mut self, error: ProtocolError<T::Error>) {
342        // SAFETY: We checked that the connection has not been terminated.
343        unsafe {
344            self.inner.connection.terminate(error);
345        }
346        self.inner.responses.lock().unwrap().wake_all();
347    }
348
349    /// Returns a client for the dispatcher.
350    ///
351    /// When the last `Client` is dropped, the dispatcher will be stopped.
352    pub fn client(&self) -> Client<T> {
353        Client { inner: self.inner.clone() }
354    }
355
356    /// Runs the client with the provided handler.
357    pub async fn run<H>(self, handler: H) -> Result<H, ProtocolError<T::Error>>
358    where
359        H: ClientHandler<T>,
360    {
361        // The bounds on `H` prove that the future returned by `run_local` is
362        // `Send`.
363        self.run_local(ClientHandlerToLocalAdapter(handler)).await.map(|adapter| adapter.0)
364    }
365
366    /// Runs the client with the provided handler.
367    pub async fn run_local<H>(mut self, mut handler: H) -> Result<H, ProtocolError<T::Error>>
368    where
369        H: LocalClientHandler<T>,
370    {
371        // We may assume that the connection has not been terminated because
372        // connections are only terminated by `run` and `drop`. Neither of those
373        // could have been called before this method because `run` consumes
374        // `self` and `drop` is only ever called once.
375
376        let error = loop {
377            // SAFETY: The connection has not been terminated.
378            let result = unsafe { self.run_one(&mut handler).await };
379            if let Err(error) = result {
380                break error;
381            }
382        };
383
384        // SAFETY: The connection has not been terminated.
385        unsafe {
386            self.terminate(error.clone());
387        }
388        self.is_terminated = true;
389
390        match error {
391            // We consider clients to have finished successfully only if they
392            // stop themselves manually.
393            ProtocolError::Stopped => Ok(handler),
394
395            // Otherwise, the client finished with an error.
396            _ => Err(error),
397        }
398    }
399
400    /// # Safety
401    ///
402    /// The connection must not be terminated.
403    async unsafe fn run_one<H>(&mut self, handler: &mut H) -> Result<(), ProtocolError<T::Error>>
404    where
405        H: LocalClientHandler<T>,
406    {
407        // SAFETY: The caller guaranteed that the connection is not terminated.
408        let buffer = unsafe { self.inner.connection.recv(&mut self.exclusive).await? };
409
410        // This expression is really awkward due to a limitation in rustc's
411        // liveness analysis for local variables. We need to avoid holding
412        // `decoder` across `.await`s because it may not be `Send` and tasks may
413        // migrate threads between polls. We should be able to just
414        // `drop(decoder)` before any `.await`, but rustc is overly conservative
415        // and still considers `decoder` as live at the `.await` for that
416        // analysis. The only way to convince rustc that `decoder` is not live
417        // at that await point is to keep the lexical scope containing `decoder`
418        // free of `.await`s.
419        //
420        // See https://github.com/rust-lang/rust/issues/63768 for more details.
421        let mut message =
422            Message::<T>::decode(buffer).map_err(ProtocolError::InvalidMessageHeader)?;
423        let txid = *message.header().txid;
424        let ordinal = *message.header().ordinal;
425
426        // Check if the ordinal is the epitaph so we can immediately decode and
427        // return it.
428        if ordinal == EPITAPH_ORDINAL {
429            let mut decoder = message.as_decoder();
430            let epitaph = decoder.decode::<Epitaph>().map_err(ProtocolError::InvalidEpitaphBody)?;
431            return Err(ProtocolError::PeerClosedWithEpitaph(*epitaph.error));
432        }
433
434        if txid == 0 {
435            handler.on_event(message).await?;
436        } else {
437            let mut responses = self.inner.responses.lock().unwrap();
438            let locker = responses
439                .get(txid - 1)
440                .ok_or_else(|| ProtocolError::UnrequestedResponse { txid })?;
441
442            match locker.write(ordinal, message) {
443                // Reader didn't cancel
444                Ok(false) => (),
445                // Reader canceled, we can drop the entry
446                Ok(true) => responses.free(txid - 1),
447                Err(LockerError::NotWriteable) => {
448                    return Err(ProtocolError::UnrequestedResponse { txid });
449                }
450                Err(LockerError::MismatchedOrdinal { expected, actual }) => {
451                    return Err(ProtocolError::InvalidResponseOrdinal { expected, actual });
452                }
453            }
454        }
455
456        Ok(())
457    }
458}