Skip to main content

fdomain_client/
lib.rs

1// Copyright 2024 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use fidl_fuchsia_fdomain as proto;
6use fidl_message::TransactionHeader;
7use fuchsia_async as _;
8use fuchsia_sync::Mutex;
9use futures::FutureExt;
10use futures::channel::oneshot::Sender as OneshotSender;
11use futures::stream::Stream as StreamTrait;
12use std::collections::{HashMap, VecDeque};
13use std::convert::Infallible;
14use std::future::Future;
15use std::num::NonZeroU32;
16use std::pin::Pin;
17use std::sync::{Arc, LazyLock, Weak};
18use std::task::{Context, Poll, Waker, ready};
19
20mod channel;
21mod event;
22mod event_pair;
23mod handle;
24mod responder;
25mod socket;
26
27#[cfg(test)]
28mod test;
29
30pub mod fidl;
31pub mod fidl_next;
32
33use responder::Responder;
34
35pub use channel::{
36    AnyHandle, Channel, ChannelMessageStream, ChannelWriter, HandleInfo, HandleOp, MessageBuf,
37};
38pub use event::Event;
39pub use event_pair::Eventpair as EventPair;
40pub use handle::unowned::Unowned;
41pub use handle::{
42    AsHandleRef, Handle, HandleBased, HandleRef, NullableHandle, OnFDomainSignals, Peered,
43};
44pub use proto::{Error as FDomainError, WriteChannelError, WriteSocketError};
45pub use socket::{Socket, SocketDisposition, SocketReadStream, SocketWriter};
46
47// Unsupported handle types.
48#[rustfmt::skip]
49pub use Handle as Clock;
50#[rustfmt::skip]
51pub use Handle as Exception;
52#[rustfmt::skip]
53pub use Handle as Fifo;
54#[rustfmt::skip]
55pub use Handle as Iob;
56#[rustfmt::skip]
57pub use Handle as Job;
58#[rustfmt::skip]
59pub use Handle as Process;
60#[rustfmt::skip]
61pub use Handle as Resource;
62#[rustfmt::skip]
63pub use Handle as Stream;
64#[rustfmt::skip]
65pub use Handle as Thread;
66#[rustfmt::skip]
67pub use Handle as Vmar;
68#[rustfmt::skip]
69pub use Handle as Vmo;
70#[rustfmt::skip]
71pub use Handle as Counter;
72
73use proto::f_domain_ordinals as ordinals;
74
75fn write_fdomain_error(error: &FDomainError, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
76    match error {
77        FDomainError::TargetError(e) => {
78            let e = zx_status::Status::from_raw(*e);
79            write!(f, "Target-side error {e}")
80        }
81        FDomainError::BadHandleId(proto::BadHandleId { id }) => {
82            write!(f, "Tried to use invalid handle id {id}")
83        }
84        FDomainError::WrongHandleType(proto::WrongHandleType { expected, got }) => write!(
85            f,
86            "Tried to use handle as {expected:?} but target reported handle was of type {got:?}"
87        ),
88        FDomainError::StreamingReadInProgress(proto::StreamingReadInProgress {}) => {
89            write!(f, "Handle is occupied delivering streaming reads")
90        }
91        FDomainError::NoReadInProgress(proto::NoReadInProgress {}) => {
92            write!(f, "No streaming read was in progress")
93        }
94        FDomainError::NewHandleIdOutOfRange(proto::NewHandleIdOutOfRange { id }) => {
95            write!(
96                f,
97                "Tried to create a handle with id {id}, which is outside the valid range for client handles"
98            )
99        }
100        FDomainError::NewHandleIdReused(proto::NewHandleIdReused { id, same_call }) => {
101            if *same_call {
102                write!(f, "Tried to create two or more new handles with the same id {id}")
103            } else {
104                write!(
105                    f,
106                    "Tried to create a new handle with id {id}, which is already the id of an existing handle"
107                )
108            }
109        }
110        FDomainError::WroteToSelf(proto::WroteToSelf {}) => {
111            write!(f, "Tried to write a channel into itself")
112        }
113        FDomainError::ClosedDuringRead(proto::ClosedDuringRead {}) => {
114            write!(f, "Handle closed while being read")
115        }
116        _ => todo!(),
117    }
118}
119
/// Result type alias. The error type defaults to this crate's [`Error`].
pub type Result<T, E = Error> = std::result::Result<T, E>;
122
/// Error type emitted by FDomain operations.
#[derive(Clone)]
pub enum Error {
    /// Writing to a socket failed; records how many bytes went out first.
    SocketWrite(WriteSocketError),
    /// Writing to a channel failed, either wholesale or per-handle.
    ChannelWrite(WriteChannelError),
    /// The target reported an FDomain-level error.
    FDomain(FDomainError),
    /// A FIDL-level protocol error occurred while talking to the target.
    Protocol(::fidl::Error),
    /// The target sent an object type this client does not recognize.
    ProtocolObjectTypeIncompatible,
    /// The target sent handle rights this client does not recognize.
    ProtocolRightsIncompatible,
    /// The target sent signals this client does not recognize.
    ProtocolSignalsIncompatible,
    /// The target sent a streaming IO event this client does not recognize.
    ProtocolStreamEventIncompatible,
    /// The underlying transport failed; `None` means it simply closed.
    Transport(Option<Arc<std::io::Error>>),
    /// A handle from a different FDomain connection was used with this client.
    ConnectionMismatch,
    /// Streaming reads on this channel have stopped.
    StreamingAborted,
}
138
impl std::fmt::Display for Error {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::SocketWrite(proto::WriteSocketError { error, wrote }) => {
                write!(f, "While writing socket (after {wrote} bytes written successfully): ")?;
                write_fdomain_error(error, f)
            }
            Self::ChannelWrite(proto::WriteChannelError::Error(error)) => {
                write!(f, "While writing channel: ")?;
                write_fdomain_error(error, f)
            }
            Self::ChannelWrite(proto::WriteChannelError::OpErrors(errors)) => {
                write!(f, "Couldn't write all handles into a channel:")?;
                // Only report positions whose slot actually holds an error;
                // `None` entries are handles that were written successfully.
                for (pos, error) in
                    errors.iter().enumerate().filter_map(|(num, x)| x.as_ref().map(|y| (num, &**y)))
                {
                    write!(f, "\n  Handle in position {pos}: ")?;
                    write_fdomain_error(error, f)?;
                }
                Ok(())
            }
            Self::ProtocolObjectTypeIncompatible => {
                write!(f, "The FDomain protocol does not recognize an object type")
            }
            Self::ProtocolRightsIncompatible => {
                write!(f, "The FDomain protocol does not recognize some rights")
            }
            Self::ProtocolSignalsIncompatible => {
                write!(f, "The FDomain protocol does not recognize some signals")
            }
            Self::ProtocolStreamEventIncompatible => {
                write!(f, "The FDomain protocol does not recognize a received streaming IO event")
            }
            // FDomain-level errors share their text with the write-error arms above.
            Self::FDomain(e) => write_fdomain_error(e, f),
            Self::Protocol(e) => write!(f, "Protocol error: {e}"),
            Self::Transport(Some(e)) => write!(f, "Transport error: {e:?}"),
            Self::Transport(None) => write!(f, "Transport closed"),
            Self::ConnectionMismatch => {
                write!(f, "Tried to use an FDomain handle from a different connection")
            }
            Self::StreamingAborted => write!(f, "This channel is no longer streaming"),
        }
    }
}
183
184impl std::fmt::Debug for Error {
185    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
186        match self {
187            Self::SocketWrite(e) => f.debug_tuple("SocketWrite").field(e).finish(),
188            Self::ChannelWrite(e) => f.debug_tuple("ChannelWrite").field(e).finish(),
189            Self::FDomain(e) => f.debug_tuple("FDomain").field(e).finish(),
190            Self::Protocol(e) => f.debug_tuple("Protocol").field(e).finish(),
191            Self::Transport(e) => f.debug_tuple("Transport").field(e).finish(),
192            Self::ProtocolObjectTypeIncompatible => write!(f, "ProtocolObjectTypeIncompatible "),
193            Self::ProtocolRightsIncompatible => write!(f, "ProtocolRightsIncompatible "),
194            Self::ProtocolSignalsIncompatible => write!(f, "ProtocolSignalsIncompatible "),
195            Self::ProtocolStreamEventIncompatible => write!(f, "ProtocolStreamEventIncompatible"),
196            Self::ConnectionMismatch => write!(f, "ConnectionMismatch"),
197            Self::StreamingAborted => write!(f, "StreamingAborted"),
198        }
199    }
200}
201
// `Display` and `Debug` are implemented above; no underlying `source()` to expose.
impl std::error::Error for Error {}
203
// Conversions from protocol-level error types into the public `Error`,
// used pervasively via `?` and `.into()`.

impl From<FDomainError> for Error {
    fn from(other: FDomainError) -> Self {
        Self::FDomain(other)
    }
}

impl From<::fidl::Error> for Error {
    fn from(other: ::fidl::Error) -> Self {
        Self::Protocol(other)
    }
}

impl From<WriteSocketError> for Error {
    fn from(other: WriteSocketError) -> Self {
        Self::SocketWrite(other)
    }
}

impl From<WriteChannelError> for Error {
    fn from(other: WriteChannelError) -> Self {
        Self::ChannelWrite(other)
    }
}
227
/// An error emitted internally by the client. Similar to [`Error`] but does not
/// contain several variants which are irrelevant in the contexts where it is
/// used.
#[derive(Clone)]
enum InnerError {
    /// A FIDL-level protocol error.
    Protocol(::fidl::Error),
    /// A streaming IO event this client does not recognize.
    ProtocolStreamEventIncompatible,
    /// Transport failure; `None` means the transport simply closed.
    Transport(Option<Arc<std::io::Error>>),
}
237
impl From<InnerError> for Error {
    // Every `InnerError` variant maps to its direct public counterpart.
    fn from(other: InnerError) -> Self {
        match other {
            InnerError::Protocol(p) => Error::Protocol(p),
            InnerError::ProtocolStreamEventIncompatible => Error::ProtocolStreamEventIncompatible,
            InnerError::Transport(t) => Error::Transport(t),
        }
    }
}

impl From<::fidl::Error> for InnerError {
    fn from(other: ::fidl::Error) -> Self {
        InnerError::Protocol(other)
    }
}
253
// TODO(399717689) Figure out if we could just use AsyncRead/Write instead of a special trait.
/// Implemented by objects which provide a transport over which we can speak the
/// FDomain protocol.
///
/// The implementer must provide two things:
/// 1) An incoming stream of messages presented as `Vec<u8>`. This is provided
///    via the `Stream` trait, which this trait requires.
/// 2) A way to send messages. This is provided by implementing the
///    `poll_send_message` method.
pub trait FDomainTransport: StreamTrait<Item = Result<Box<[u8]>, std::io::Error>> + Send {
    /// Attempt to send a message asynchronously. Messages should be sent so
    /// that they arrive at the target in order.
    ///
    /// Returns `Ready(Ok(()))` once the message is accepted for delivery,
    /// `Ready(Err(Some(e)))` on an IO failure, and `Ready(Err(None))` when
    /// the transport has closed (see how `Transport` maps these into
    /// `InnerError::Transport`).
    fn poll_send_message(
        self: Pin<&mut Self>,
        msg: &[u8],
        ctx: &mut Context<'_>,
    ) -> Poll<Result<(), Option<std::io::Error>>>;

    /// Optional debug information outlet.
    fn debug_fmt(&self, _: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        Ok(())
    }

    /// Whether `debug_fmt` does anything.
    fn has_debug_fmt(&self) -> bool {
        false
    }
}
282
/// Wrapper for an `FDomainTransport` implementer that:
/// 1) Provides a queue for outgoing messages so we need not have an await point
///    when we submit a message.
/// 2) Drops the transport on error, then returns the last observed error for
///    all future operations.
enum Transport {
    /// A live transport, its queue of unsent outgoing messages, and the wakers
    /// to wake when new messages are enqueued.
    Transport(Pin<Box<dyn FDomainTransport>>, VecDeque<Box<[u8]>>, Vec<Waker>),
    /// The transport has failed; every operation now returns this error.
    Error(InnerError),
}
292
impl Transport {
    /// Get the failure mode of the transport if it has failed.
    fn error(&self) -> Option<InnerError> {
        match self {
            Transport::Transport(_, _, _) => None,
            Transport::Error(inner_error) => Some(inner_error.clone()),
        }
    }

    /// Enqueue a message to be sent on this transport.
    ///
    /// Wakes any task parked in `poll_send_messages` so it notices the new
    /// message. Fails immediately if the transport has already failed.
    fn push_msg(&mut self, msg: Box<[u8]>) -> Result<(), InnerError> {
        match self {
            Transport::Transport(_, v, w) => {
                v.push_back(msg);
                w.drain(..).for_each(Waker::wake);
                Ok(())
            }
            Transport::Error(e) => Err(e.clone()),
        }
    }

    /// Push messages in the send queue out through the transport.
    ///
    /// Only returns `Ready` when the transport has failed; on failure the
    /// transport is replaced with `Transport::Error` so the error is sticky.
    /// Otherwise the caller's waker is parked (to be woken by `push_msg`) and
    /// `Pending` is returned.
    fn poll_send_messages(&mut self, ctx: &mut Context<'_>) -> Poll<InnerError> {
        match self {
            Transport::Error(e) => Poll::Ready(e.clone()),
            Transport::Transport(t, v, w) => {
                while let Some(msg) = v.front() {
                    match t.as_mut().poll_send_message(msg, ctx) {
                        Poll::Ready(Ok(())) => {
                            // Sent successfully; drop it from the queue.
                            v.pop_front();
                        }
                        Poll::Ready(Err(e)) => {
                            let e = e.map(Arc::new);
                            *self = Transport::Error(InnerError::Transport(e.clone()));
                            return Poll::Ready(InnerError::Transport(e));
                        }
                        Poll::Pending => return Poll::Pending,
                    }
                }

                // NOTE(review): the loop above only exits when the queue is
                // empty (the other two arms return), so the `else` branch
                // below appears unreachable — confirm before removing.
                if v.is_empty() {
                    w.push(ctx.waker().clone());
                } else {
                    ctx.waker().wake_by_ref();
                }
                Poll::Pending
            }
        }
    }

    /// Get the next incoming message from the transport.
    ///
    /// On a stream error or end-of-stream the transport is poisoned so all
    /// later calls keep returning the same failure.
    fn poll_next(&mut self, ctx: &mut Context<'_>) -> Poll<Result<Box<[u8]>, InnerError>> {
        match self {
            Transport::Error(e) => Poll::Ready(Err(e.clone())),
            Transport::Transport(t, _, _) => match ready!(t.as_mut().poll_next(ctx)) {
                Some(Ok(x)) => Poll::Ready(Ok(x)),
                Some(Err(e)) => {
                    let e = Arc::new(e);
                    *self = Transport::Error(InnerError::Transport(Some(Arc::clone(&e))));
                    Poll::Ready(Err(InnerError::Transport(Some(e))))
                }
                // End of stream: the transport closed cleanly.
                Option::None => Poll::Ready(Err(InnerError::Transport(None))),
            },
        }
    }
}
359
360impl Drop for Transport {
361    fn drop(&mut self) {
362        if let Transport::Transport(_, _, wakers) = self {
363            wakers.drain(..).for_each(Waker::wake);
364        }
365    }
366}
367
368/// State of a socket that is or has been read from.
369struct SocketReadState {
370    wakers: Vec<Waker>,
371    queued: VecDeque<Result<proto::SocketData, Error>>,
372    read_request_pending: bool,
373    is_streaming: bool,
374}
375
376impl SocketReadState {
377    /// Handle an incoming message, which is either a channel streaming event or
378    /// response to a `ChannelRead` request.
379    fn handle_incoming_message(&mut self, msg: Result<proto::SocketData, Error>) -> Vec<Waker> {
380        self.queued.push_back(msg);
381        std::mem::replace(&mut self.wakers, Vec::new())
382    }
383}
384
385/// State of a channel that is or has been read from.
386struct ChannelReadState {
387    wakers: Vec<Waker>,
388    queued: VecDeque<Result<proto::ChannelMessage, Error>>,
389    read_request_pending: bool,
390    is_streaming: bool,
391}
392
393impl ChannelReadState {
394    /// Handle an incoming message, which is either a channel streaming event or
395    /// response to a `ChannelRead` request.
396    fn handle_incoming_message(&mut self, msg: Result<proto::ChannelMessage, Error>) -> Vec<Waker> {
397        self.queued.push_back(msg);
398        std::mem::replace(&mut self.wakers, Vec::new())
399    }
400}
401
/// Lock-protected interior of `Client`
struct ClientInner {
    // Connection to the target FDomain, or the error that killed it.
    transport: Transport,
    // Responders for transactions still awaiting a reply, keyed by tx id.
    transactions: HashMap<NonZeroU32, responder::Responder>,
    // Per-channel read queues and wakers.
    channel_read_states: HashMap<proto::HandleId, ChannelReadState>,
    // Per-socket read queues and wakers.
    socket_read_states: HashMap<proto::HandleId, SocketReadState>,
    // Next transaction id to allocate. Starts at 1; id 0 marks unsolicited
    // events (see `try_poll_transport`).
    next_tx_id: u32,
    // Handles dropped locally but not yet closed on the target.
    waiting_to_close: Vec<proto::HandleId>,
    // Waker for the transport loop so it notices newly queued closures.
    waiting_to_close_waker: Waker,

    /// There is a lock around `ClientInner`, and sometimes the FIDL bindings
    /// give us wakers that want to do handle operations synchronously on wake,
    /// which means we can double-take the lock if we wake a waker while we hold
    /// it. This is a place to store wakers that we'd like to be woken as soon
    /// as we're not holding that lock, to avoid these weird reentrancy issues.
    wakers_to_wake: Vec<Waker>,
}
419
impl ClientInner {
    /// Serialize and enqueue a new transaction, including header and transaction ID.
    fn request<S: fidl_message::Body>(&mut self, ordinal: u64, request: S, responder: Responder) {
        // Flush queued handle closures first so they are ordered before this
        // request. Skipped for close requests themselves, since
        // `process_waiting_to_close` calls back into `request`.
        if ordinal != ordinals::CLOSE {
            self.process_waiting_to_close();
        }
        let tx_id = self.next_tx_id;

        let header = TransactionHeader::new(tx_id, ordinal, fidl_message::DynamicFlags::FLEXIBLE);
        let msg = fidl_message::encode_message(header, request).expect("Could not encode request!");
        self.next_tx_id += 1;
        if let Err(e) = self.transport.push_msg(msg.into()) {
            // Transport already failed; fail the transaction immediately.
            let _ = responder.handle(self, Err(e.into()));
        } else {
            // `next_tx_id` starts at 1, so the NonZeroU32 conversion only
            // fails if the counter ever wraps to 0.
            assert!(
                self.transactions.insert(tx_id.try_into().unwrap(), responder).is_none(),
                "Allocated same tx id twice!"
            );
        }
    }

    // Issue a single `Close` request for every handle queued for closure and
    // discard their read state.
    fn process_waiting_to_close(&mut self) {
        if !self.waiting_to_close.is_empty() {
            let handles = std::mem::replace(&mut self.waiting_to_close, Vec::new());
            // We've dropped the handle object. Nobody is going to wait to read
            // the buffers anymore. This is a safe time to drop the read state.
            for handle in &handles {
                let _ = self.channel_read_states.remove(handle);
                let _ = self.socket_read_states.remove(handle);
            }
            self.request(
                ordinals::CLOSE,
                proto::FDomainCloseRequest { handles },
                Responder::Ignore,
            );
        }
    }

    /// Polls the underlying transport to ensure any incoming or outgoing
    /// messages are processed as far as possible. Errors if the transport has failed.
    fn try_poll_transport(
        &mut self,
        ctx: &mut Context<'_>,
    ) -> Poll<Result<Infallible, InnerError>> {
        self.process_waiting_to_close();

        // Record how to wake this task when handles later queue themselves
        // for closure.
        self.waiting_to_close_waker = ctx.waker().clone();

        loop {
            if let Poll::Ready(e) = self.transport.poll_send_messages(ctx) {
                // Transport failed: queue the error for every pending socket
                // and channel read. Wakers are deferred via `wakers_to_wake`
                // to avoid re-entrant locking.
                for mut state in std::mem::take(&mut self.socket_read_states).into_values() {
                    state.queued.push_back(Err(Error::from(e.clone())));
                    self.wakers_to_wake.extend(state.wakers);
                }
                for (_, mut state) in self.channel_read_states.drain() {
                    state.queued.push_back(Err(Error::from(e.clone())));
                    self.wakers_to_wake.extend(state.wakers);
                }
                return Poll::Ready(Err(e));
            }
            let Poll::Ready(result) = self.transport.poll_next(ctx) else {
                return Poll::Pending;
            };
            let data = result?;
            let (header, data) = match fidl_message::decode_transaction_header(&data) {
                Ok(x) => x,
                Err(e) => {
                    // Undecodable header: poison the transport. The error
                    // surfaces on the next loop iteration.
                    self.transport = Transport::Error(InnerError::Protocol(e));
                    continue;
                }
            };

            // A zero tx id marks an unsolicited event rather than a reply.
            let Some(tx_id) = NonZeroU32::new(header.tx_id) else {
                match self.process_event(header, data) {
                    Ok(wakers) => self.wakers_to_wake.extend(wakers),
                    Err(e) => self.transport = Transport::Error(e),
                }
                continue;
            };

            // A reply to a transaction we never sent is a protocol error.
            let tx = self.transactions.remove(&tx_id).ok_or(::fidl::Error::InvalidResponseTxid)?;
            match tx.handle(self, Ok((header, data))) {
                Ok(x) => x,
                Err(e) => {
                    self.transport = Transport::Error(InnerError::Protocol(e));
                    continue;
                }
            }
        }
    }

    /// Process an incoming message that arose from an event rather than a transaction reply.
    ///
    /// Returns the wakers that should be woken once the client lock is
    /// released.
    fn process_event(
        &mut self,
        header: TransactionHeader,
        data: &[u8],
    ) -> Result<Vec<Waker>, InnerError> {
        match header.ordinal {
            ordinals::ON_SOCKET_STREAMING_DATA => {
                let msg = fidl_message::decode_message::<proto::SocketOnSocketStreamingDataRequest>(
                    header, data,
                )?;
                // Create read state on demand; the event can arrive before any
                // local read was issued for the handle.
                let o =
                    self.socket_read_states.entry(msg.handle).or_insert_with(|| SocketReadState {
                        wakers: Vec::new(),
                        queued: VecDeque::new(),
                        is_streaming: false,
                        read_request_pending: false,
                    });
                match msg.socket_message {
                    proto::SocketMessage::Data(data) => Ok(o.handle_incoming_message(Ok(data))),
                    proto::SocketMessage::Stopped(proto::AioStopped { error }) => {
                        // A stop without an error is a clean end of streaming;
                        // only queue something when an error is attached.
                        let ret = if let Some(error) = error {
                            o.handle_incoming_message(Err(Error::FDomain(*error)))
                        } else {
                            Vec::new()
                        };
                        o.is_streaming = false;
                        Ok(ret)
                    }
                    _ => Err(InnerError::ProtocolStreamEventIncompatible),
                }
            }
            ordinals::ON_CHANNEL_STREAMING_DATA => {
                let msg = fidl_message::decode_message::<
                    proto::ChannelOnChannelStreamingDataRequest,
                >(header, data)?;
                let o = self.channel_read_states.entry(msg.handle).or_insert_with(|| {
                    ChannelReadState {
                        wakers: Vec::new(),
                        queued: VecDeque::new(),
                        is_streaming: false,
                        read_request_pending: false,
                    }
                });
                match msg.channel_sent {
                    proto::ChannelSent::Message(data) => Ok(o.handle_incoming_message(Ok(data))),
                    proto::ChannelSent::Stopped(proto::AioStopped { error }) => {
                        let ret = if let Some(error) = error {
                            o.handle_incoming_message(Err(Error::FDomain(*error)))
                        } else {
                            Vec::new()
                        };
                        o.is_streaming = false;
                        Ok(ret)
                    }
                    _ => Err(InnerError::ProtocolStreamEventIncompatible),
                }
            }
            // Any other ordinal is an event we don't understand.
            _ => Err(::fidl::Error::UnknownOrdinal {
                ordinal: header.ordinal,
                protocol_name:
                    <proto::FDomainMarker as ::fidl::endpoints::ProtocolMarker>::DEBUG_NAME,
            }
            .into()),
        }
    }

    /// Polls the underlying transport to ensure any incoming or outgoing
    /// messages are processed as far as possible. If a failure occurs, puts the
    /// transport into an error state and fails all pending transactions.
    fn poll_transport(&mut self, ctx: &mut Context<'_>) -> Poll<()> {
        if let Poll::Ready(Err(e)) = self.try_poll_transport(ctx) {
            // Fail every outstanding transaction with the transport error.
            for (_, v) in std::mem::take(&mut self.transactions) {
                let _ = v.handle(self, Err(e.clone()));
            }

            Poll::Ready(())
        } else {
            Poll::Pending
        }
    }

    /// Handles the response to a `SocketRead` protocol message.
    pub(crate) fn handle_socket_read_response(
        &mut self,
        msg: Result<proto::SocketData, Error>,
        id: proto::HandleId,
    ) {
        // Create the read state on demand if this is the first read tracked
        // for the handle.
        let state = self.socket_read_states.entry(id).or_insert_with(|| SocketReadState {
            wakers: Vec::new(),
            queued: VecDeque::new(),
            is_streaming: false,
            read_request_pending: false,
        });
        let wakers = state.handle_incoming_message(msg);
        // Defer waking until the client lock is released (see `wakers_to_wake`).
        self.wakers_to_wake.extend(wakers);
        state.read_request_pending = false;
    }

    /// Handles the response to a `ChannelRead` protocol message.
    pub(crate) fn handle_channel_read_response(
        &mut self,
        msg: Result<proto::ChannelMessage, Error>,
        id: proto::HandleId,
    ) {
        let state = self.channel_read_states.entry(id).or_insert_with(|| ChannelReadState {
            wakers: Vec::new(),
            queued: VecDeque::new(),
            is_streaming: false,
            read_request_pending: false,
        });
        let wakers = state.handle_incoming_message(msg);
        // Defer waking until the client lock is released (see `wakers_to_wake`).
        self.wakers_to_wake.extend(wakers);
        state.read_request_pending = false;
    }
}
627
628impl Drop for ClientInner {
629    fn drop(&mut self) {
630        let responders = self.transactions.drain().map(|x| x.1).collect::<Vec<_>>();
631        for responder in responders {
632            let _ = responder.handle(self, Err(InnerError::Transport(None)));
633        }
634        for state in self.channel_read_states.values_mut() {
635            state.wakers.drain(..).for_each(Waker::wake);
636        }
637        for state in self.socket_read_states.values_mut() {
638            state.wakers.drain(..).for_each(Waker::wake);
639        }
640        self.waiting_to_close_waker.wake_by_ref();
641        self.wakers_to_wake.drain(..).for_each(Waker::wake);
642    }
643}
644
/// Represents a connection to an FDomain.
///
/// The client is constructed by passing it a transport object which represents
/// the raw connection to the remote FDomain. The `Client` wrapper then allows
/// us to construct and use handles which behave similarly to their counterparts
/// on a Fuchsia device.
// All state lives behind a single mutex; wakers that might re-enter the client
// are deferred via `ClientInner::wakers_to_wake`.
pub struct Client(pub(crate) Mutex<ClientInner>);
652
653impl std::fmt::Debug for Client {
654    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
655        let inner = self.0.lock();
656        match &inner.transport {
657            Transport::Transport(transport, ..) if transport.has_debug_fmt() => {
658                write!(f, "Client(")?;
659                transport.debug_fmt(f)?;
660                write!(f, ")")
661            }
662            Transport::Error(error) => {
663                let error = Error::from(error.clone());
664                write!(f, "Client(Failed: {error})")
665            }
666            _ => f.debug_tuple("Client").field(&"<transport>").finish(),
667        }
668    }
669}
670
/// A client which is always disconnected. Handles that lose their clients
/// connect to this client instead, which always returns a "Client Lost"
/// transport failure.
pub(crate) static DEAD_CLIENT: LazyLock<Arc<Client>> = LazyLock::new(|| {
    Arc::new(Client(Mutex::new(ClientInner {
        // Permanently-failed transport: every operation reports the
        // "Transport closed" error.
        transport: Transport::Error(InnerError::Transport(None)),
        transactions: HashMap::new(),
        channel_read_states: HashMap::new(),
        socket_read_states: HashMap::new(),
        next_tx_id: 1,
        waiting_to_close: Vec::new(),
        // Nothing ever polls this client, so a no-op waker suffices.
        waiting_to_close_waker: std::task::Waker::noop().clone(),
        wakers_to_wake: Vec::new(),
    })))
});
686
/// A wrapper around the FDomain client background future that ensures
/// all pending transactions and reads are failed if the loop is dropped.
///
/// This prevents hangs when the transport is abruptly closed (e.g. during target reboot)
/// by waking up any futures waiting for responses or data on channels/sockets.
pub struct ClientLoop {
    // Weak so the loop alone does not keep the client alive.
    client: Weak<Client>,
    // The actual polling future, built in `Client::new`.
    fut: Pin<Box<dyn Future<Output = ()> + Send + 'static>>,
}
696
impl Future for ClientLoop {
    type Output = ();
    // Simply delegates to the wrapped background future; the extra behavior
    // of this type lives in its `Drop` impl.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
        self.fut.as_mut().poll(cx)
    }
}
703
704impl Drop for ClientLoop {
705    fn drop(&mut self) {
706        let Some(client) = self.client.upgrade() else {
707            return;
708        };
709
710        let (channel_read_states, socket_read_states, deferred_wakers) = {
711            let mut inner = client.0.lock();
712            let transactions = std::mem::take(&mut inner.transactions);
713            log::debug!("ClientLoop dropped, failing {} transactions", transactions.len());
714            for (_, v) in transactions {
715                let _ = v.handle(&mut *inner, Err(InnerError::Transport(None)));
716            }
717
718            let channel_read_states = std::mem::take(&mut inner.channel_read_states);
719            let socket_read_states = std::mem::take(&mut inner.socket_read_states);
720
721            let deferred_wakers = std::mem::replace(&mut inner.wakers_to_wake, Vec::new());
722
723            (channel_read_states, socket_read_states, deferred_wakers)
724        };
725
726        log::debug!("Failing reads on {} channels", channel_read_states.len());
727        for (_, mut state) in channel_read_states {
728            state.queued.push_back(Err(Error::Transport(None)));
729            state.wakers.into_iter().for_each(Waker::wake);
730        }
731
732        log::debug!("Failing reads on {} sockets", socket_read_states.len());
733        for (_, mut state) in socket_read_states {
734            state.queued.push_back(Err(Error::Transport(None)));
735            state.wakers.into_iter().for_each(Waker::wake);
736        }
737
738        deferred_wakers.into_iter().for_each(Waker::wake);
739    }
740}
741
742impl Client {
743    /// Create a new FDomain client. The `transport` argument should contain the
744    /// established connection to the target, ready to communicate the FDomain
745    /// protocol.
746    ///
747    /// The second return item is a future that must be polled to keep
748    /// transactions running.
749    pub fn new(
750        transport: impl FDomainTransport + 'static,
751    ) -> (Arc<Self>, impl Future<Output = ()> + Send + 'static) {
752        let ret = Arc::new(Client(Mutex::new(ClientInner {
753            transport: Transport::Transport(Box::pin(transport), VecDeque::new(), Vec::new()),
754            transactions: HashMap::new(),
755            socket_read_states: HashMap::new(),
756            channel_read_states: HashMap::new(),
757            next_tx_id: 1,
758            waiting_to_close: Vec::new(),
759            waiting_to_close_waker: std::task::Waker::noop().clone(),
760            wakers_to_wake: Vec::new(),
761        })));
762
763        let client_weak = Arc::downgrade(&ret);
764        let fut = futures::future::poll_fn(move |ctx| {
765            let Some(client) = client_weak.upgrade() else {
766                return Poll::Ready(());
767            };
768
769            let (ret, deferred_wakers) = {
770                let mut inner = client.0.lock();
771                let ret = inner.poll_transport(ctx);
772                let deferred_wakers = std::mem::replace(&mut inner.wakers_to_wake, Vec::new());
773                (ret, deferred_wakers)
774            };
775            deferred_wakers.into_iter().for_each(Waker::wake);
776            ret
777        });
778
779        let client_loop = ClientLoop { client: Arc::downgrade(&ret), fut: Box::pin(fut) };
780
781        (ret, client_loop)
782    }
783
    /// Get the namespace for the connected FDomain. Calling this more than once is an error.
    pub async fn namespace(self: &Arc<Self>) -> Result<Channel, Error> {
        // Allocate the handle id client-side so the handle can be constructed
        // as soon as the request succeeds.
        let new_handle = self.new_hid();
        self.transaction(
            ordinals::GET_NAMESPACE,
            proto::FDomainGetNamespaceRequest { new_handle },
            Responder::Namespace,
        )
        .await?;
        Ok(Channel(Handle { id: new_handle.id, client: Arc::downgrade(self) }))
    }
795
796    /// Create a new channel in the connected FDomain.
797    pub fn create_channel(self: &Arc<Self>) -> (Channel, Channel) {
798        let id_a = self.new_hid();
799        let id_b = self.new_hid();
800        let fut = self.transaction(
801            ordinals::CREATE_CHANNEL,
802            proto::ChannelCreateChannelRequest { handles: [id_a, id_b] },
803            Responder::CreateChannel,
804        );
805
806        fuchsia_async::Task::spawn(async move {
807            if let Err(e) = fut.await {
808                log::debug!("FDomain channel creation failed: {e}");
809            }
810        })
811        .detach();
812
813        (
814            Channel(Handle { id: id_a.id, client: Arc::downgrade(self) }),
815            Channel(Handle { id: id_b.id, client: Arc::downgrade(self) }),
816        )
817    }
818
819    /// Creates client and server endpoints connected to by a channel.
820    pub fn create_endpoints<F: crate::fidl::ProtocolMarker>(
821        self: &Arc<Self>,
822    ) -> (crate::fidl::ClientEnd<F>, crate::fidl::ServerEnd<F>) {
823        let (client, server) = self.create_channel();
824        let client_end = crate::fidl::ClientEnd::<F>::new(client);
825        let server_end = crate::fidl::ServerEnd::new(server);
826        (client_end, server_end)
827    }
828
829    /// Creates a client proxy and a server endpoint connected by a channel.
830    pub fn create_proxy<F: crate::fidl::ProtocolMarker>(
831        self: &Arc<Self>,
832    ) -> (F::Proxy, crate::fidl::ServerEnd<F>) {
833        let (client_end, server_end) = self.create_endpoints::<F>();
834        (client_end.into_proxy(), server_end)
835    }
836
837    /// Creates a client proxy and a server request stream connected by a channel.
838    pub fn create_proxy_and_stream<F: crate::fidl::ProtocolMarker>(
839        self: &Arc<Self>,
840    ) -> (F::Proxy, F::RequestStream) {
841        let (client_end, server_end) = self.create_endpoints::<F>();
842        (client_end.into_proxy(), server_end.into_stream())
843    }
844
845    /// Creates a client end and a server request stream connected by a channel.
846    pub fn create_request_stream<F: crate::fidl::ProtocolMarker>(
847        self: &Arc<Self>,
848    ) -> (crate::fidl::ClientEnd<F>, F::RequestStream) {
849        let (client_end, server_end) = self.create_endpoints::<F>();
850        (client_end, server_end.into_stream())
851    }
852
853    /// Create a new socket in the connected FDomain.
854    fn create_socket(self: &Arc<Self>, options: proto::SocketType) -> (Socket, Socket) {
855        let id_a = self.new_hid();
856        let id_b = self.new_hid();
857        let fut = self.transaction(
858            ordinals::CREATE_SOCKET,
859            proto::SocketCreateSocketRequest { handles: [id_a, id_b], options },
860            Responder::CreateSocket,
861        );
862
863        fuchsia_async::Task::spawn(async move {
864            if let Err(e) = fut.await {
865                log::debug!("FDomain socket creation failed: {e}");
866            }
867        })
868        .detach();
869
870        (
871            Socket(Handle { id: id_a.id, client: Arc::downgrade(self) }),
872            Socket(Handle { id: id_b.id, client: Arc::downgrade(self) }),
873        )
874    }
875
876    /// Create a new streaming socket in the connected FDomain.
877    pub fn create_stream_socket(self: &Arc<Self>) -> (Socket, Socket) {
878        self.create_socket(proto::SocketType::Stream)
879    }
880
881    /// Create a new datagram socket in the connected FDomain.
882    pub fn create_datagram_socket(self: &Arc<Self>) -> (Socket, Socket) {
883        self.create_socket(proto::SocketType::Datagram)
884    }
885
886    /// Create a new event pair in the connected FDomain.
887    pub fn create_event_pair(self: &Arc<Self>) -> (EventPair, EventPair) {
888        let id_a = self.new_hid();
889        let id_b = self.new_hid();
890        let fut = self.transaction(
891            ordinals::CREATE_EVENT_PAIR,
892            proto::EventPairCreateEventPairRequest { handles: [id_a, id_b] },
893            Responder::CreateEventPair,
894        );
895
896        fuchsia_async::Task::spawn(async move {
897            if let Err(e) = fut.await {
898                log::debug!("FDomain event pair creation failed: {e}");
899            }
900        })
901        .detach();
902
903        (
904            EventPair(Handle { id: id_a.id, client: Arc::downgrade(self) }),
905            EventPair(Handle { id: id_b.id, client: Arc::downgrade(self) }),
906        )
907    }
908
909    /// Create a new event handle in the connected FDomain.
910    pub fn create_event(self: &Arc<Self>) -> Event {
911        let id = self.new_hid();
912        let fut = self.transaction(
913            ordinals::CREATE_EVENT,
914            proto::EventCreateEventRequest { handle: id },
915            Responder::CreateEvent,
916        );
917
918        fuchsia_async::Task::spawn(async move {
919            if let Err(e) = fut.await {
920                log::debug!("FDomain event creation failed: {e}");
921            }
922        })
923        .detach();
924
925        Event(Handle { id: id.id, client: Arc::downgrade(self) })
926    }
927
    /// Allocate a new HID, which should be suitable for use with the connected FDomain.
    pub(crate) fn new_hid(&self) -> proto::NewHandleId {
        // TODO: On the target side we have to keep a table of these which means
        // we can automatically detect collisions in the random value. On the
        // client side we'd have to add a whole data structure just for that
        // purpose. Should we?
        //
        // The `>> 1` clears the top bit, keeping the random ID in the lower
        // half of the u32 range. NOTE(review): presumably IDs with the high
        // bit set are reserved for the other side of the protocol — confirm
        // against the fuchsia.fdomain FIDL definitions.
        proto::NewHandleId { id: rand::random::<u32>() >> 1 }
    }
936
937    /// Create a future which sends a FIDL message to the connected FDomain and
938    /// waits for a response.
939    ///
940    /// Calling this method queues the transaction synchronously. Awaiting is
941    /// only necessary to wait for the response.
942    pub(crate) fn transaction<S: fidl_message::Body, R: 'static, F>(
943        self: &Arc<Self>,
944        ordinal: u64,
945        request: S,
946        f: F,
947    ) -> impl Future<Output = Result<R, Error>> + 'static + use<S, R, F>
948    where
949        F: Fn(OneshotSender<Result<R, Error>>) -> Responder,
950    {
951        let mut inner = self.0.lock();
952
953        let (sender, receiver) = futures::channel::oneshot::channel();
954        inner.request(ordinal, request, f(sender));
955        receiver.map(|x| x.expect("Oneshot went away without reply!"))
956    }
957
958    /// Start getting streaming events for socket reads.
959    pub(crate) fn start_socket_streaming(&self, id: proto::HandleId) -> Result<(), Error> {
960        let mut inner = self.0.lock();
961        if let Some(e) = inner.transport.error() {
962            return Err(e.into());
963        }
964
965        let state = inner.socket_read_states.entry(id).or_insert_with(|| SocketReadState {
966            wakers: Vec::new(),
967            queued: VecDeque::new(),
968            is_streaming: false,
969            read_request_pending: false,
970        });
971
972        assert!(!state.is_streaming, "Initiated streaming twice!");
973        state.is_streaming = true;
974
975        inner.request(
976            ordinals::READ_SOCKET_STREAMING_START,
977            proto::SocketReadSocketStreamingStartRequest { handle: id },
978            Responder::Ignore,
979        );
980        Ok(())
981    }
982
983    /// Stop getting streaming events for socket reads. Doesn't return errors
984    /// because it's exclusively called in destructors where we have nothing to
985    /// do with them.
986    pub(crate) fn stop_socket_streaming(&self, id: proto::HandleId) {
987        let mut inner = self.0.lock();
988        if let Some(state) = inner.socket_read_states.get_mut(&id) {
989            if state.is_streaming {
990                state.is_streaming = false;
991                // TODO: Log?
992                let _ = inner.request(
993                    ordinals::READ_SOCKET_STREAMING_STOP,
994                    proto::ChannelReadChannelStreamingStopRequest { handle: id },
995                    Responder::Ignore,
996                );
997            }
998        }
999    }
1000
1001    /// Start getting streaming events for socket reads.
1002    pub(crate) fn start_channel_streaming(&self, id: proto::HandleId) -> Result<(), Error> {
1003        let mut inner = self.0.lock();
1004        if let Some(e) = inner.transport.error() {
1005            return Err(e.into());
1006        }
1007        let state = inner.channel_read_states.entry(id).or_insert_with(|| ChannelReadState {
1008            wakers: Vec::new(),
1009            queued: VecDeque::new(),
1010            is_streaming: false,
1011            read_request_pending: false,
1012        });
1013
1014        assert!(!state.is_streaming, "Initiated streaming twice!");
1015        state.is_streaming = true;
1016
1017        inner.request(
1018            ordinals::READ_CHANNEL_STREAMING_START,
1019            proto::ChannelReadChannelStreamingStartRequest { handle: id },
1020            Responder::Ignore,
1021        );
1022
1023        Ok(())
1024    }
1025
1026    /// Stop getting streaming events for socket reads. Doesn't return errors
1027    /// because it's exclusively called in destructors where we have nothing to
1028    /// do with them.
1029    pub(crate) fn stop_channel_streaming(&self, id: proto::HandleId) {
1030        let mut inner = self.0.lock();
1031        if let Some(state) = inner.channel_read_states.get_mut(&id) {
1032            if state.is_streaming {
1033                state.is_streaming = false;
1034                // TODO: Log?
1035                let _ = inner.request(
1036                    ordinals::READ_CHANNEL_STREAMING_STOP,
1037                    proto::ChannelReadChannelStreamingStopRequest { handle: id },
1038                    Responder::Ignore,
1039                );
1040            }
1041        }
1042    }
1043
    /// Execute a read from a socket.
    ///
    /// Copies queued data into `out` and returns the number of bytes copied,
    /// or a queued error. If nothing is queued, registers the caller's waker
    /// and returns `Poll::Pending`, soliciting a read from the target if
    /// needed.
    pub(crate) fn poll_socket(
        &self,
        id: proto::HandleId,
        ctx: &mut Context<'_>,
        out: &mut [u8],
    ) -> Poll<Result<usize, Error>> {
        let mut inner = self.0.lock();
        // A failed transport fails all reads immediately.
        if let Some(error) = inner.transport.error() {
            return Poll::Ready(Err(error.into()));
        }

        // Lazily create the per-handle read state on first use.
        let state = inner.socket_read_states.entry(id).or_insert_with(|| SocketReadState {
            wakers: Vec::new(),
            queued: VecDeque::new(),
            is_streaming: false,
            read_request_pending: false,
        });

        if let Some(got) = state.queued.front_mut() {
            match got.as_mut() {
                Ok(data) => {
                    let read_size = std::cmp::min(data.data.len(), out.len());
                    out[..read_size].copy_from_slice(&data.data[..read_size]);

                    // Stream data that didn't fit stays queued for the next
                    // read; a datagram is consumed whole, discarding whatever
                    // didn't fit in `out`.
                    if data.data.len() > read_size && !data.is_datagram {
                        let _ = data.data.drain(..read_size);
                    } else {
                        let _ = state.queued.pop_front();
                    }

                    return Poll::Ready(Ok(read_size));
                }
                Err(_) => {
                    // Errors are queued in-order alongside data; deliver and
                    // consume this one.
                    let err = state.queued.pop_front().unwrap().unwrap_err();
                    return Poll::Ready(Err(err));
                }
            }
        } else if !state.wakers.iter().any(|x| ctx.waker().will_wake(x)) {
            // Nothing queued: register the waker (deduplicated via will_wake).
            state.wakers.push(ctx.waker().clone());
        }

        // In streaming mode data arrives unsolicited; otherwise solicit one
        // read unless a request is already in flight.
        // NOTE(review): `read_request_pending` is presumably set/cleared by
        // the response plumbing elsewhere — confirm in the responder handling.
        if !state.read_request_pending && !state.is_streaming {
            inner.request(
                ordinals::READ_SOCKET,
                proto::SocketReadSocketRequest { handle: id, max_bytes: out.len() as u64 },
                Responder::ReadSocket(id),
            );
        }

        Poll::Pending
    }
1096
    /// Execute a read from a channel.
    ///
    /// Returns the next queued message, `Ready(None)` to signal end-of-stream
    /// (only when polled on behalf of a stream after streaming has stopped),
    /// or `Pending` after registering the caller's waker.
    pub(crate) fn poll_channel(
        &self,
        id: proto::HandleId,
        ctx: &mut Context<'_>,
        for_stream: bool,
    ) -> Poll<Option<Result<proto::ChannelMessage, Error>>> {
        let mut inner = self.0.lock();
        // A failed transport fails all reads immediately.
        if let Some(error) = inner.transport.error() {
            return Poll::Ready(Some(Err(error.into())));
        }

        // Lazily create the per-handle read state on first use.
        let state = inner.channel_read_states.entry(id).or_insert_with(|| ChannelReadState {
            wakers: Vec::new(),
            queued: VecDeque::new(),
            is_streaming: false,
            read_request_pending: false,
        });

        if let Some(got) = state.queued.pop_front() {
            return Poll::Ready(Some(got));
        } else if for_stream && !state.is_streaming {
            // Stream polling with streaming stopped and the queue drained:
            // report end of stream.
            return Poll::Ready(None);
        } else if !state.wakers.iter().any(|x| ctx.waker().will_wake(x)) {
            // Register the waker (deduplicated via will_wake).
            state.wakers.push(ctx.waker().clone());
        }

        // In streaming mode messages arrive unsolicited; otherwise solicit one
        // read unless a request is already in flight.
        if !state.read_request_pending && !state.is_streaming {
            inner.request(
                ordinals::READ_CHANNEL,
                proto::ChannelReadChannelRequest { handle: id },
                Responder::ReadChannel(id),
            );
        }

        Poll::Pending
    }
1134
1135    /// Check whether this channel is streaming
1136    pub(crate) fn channel_is_streaming(&self, id: proto::HandleId) -> bool {
1137        let inner = self.0.lock();
1138        let Some(state) = inner.channel_read_states.get(&id) else {
1139            return false;
1140        };
1141        state.is_streaming
1142    }
1143
1144    /// Check that all the given handles are safe to transfer through a channel
1145    /// e.g. that there's no chance of in-flight reads getting dropped.
1146    pub(crate) fn clear_handles_for_transfer(&self, handles: &proto::Handles) {
1147        let inner = self.0.lock();
1148        match handles {
1149            proto::Handles::Handles(handles) => {
1150                for handle in handles {
1151                    assert!(
1152                        !(inner.channel_read_states.contains_key(handle)
1153                            || inner.socket_read_states.contains_key(handle)),
1154                        "Tried to transfer handle after reading"
1155                    );
1156                }
1157            }
1158            proto::Handles::Dispositions(dispositions) => {
1159                for disposition in dispositions {
1160                    match &disposition.handle {
1161                        proto::HandleOp::Move_(handle) => assert!(
1162                            !(inner.channel_read_states.contains_key(handle)
1163                                || inner.socket_read_states.contains_key(handle)),
1164                            "Tried to transfer handle after reading"
1165                        ),
1166                        // Pretty sure this should be fine regardless of read state.
1167                        proto::HandleOp::Duplicate(_) => (),
1168                    }
1169                }
1170            }
1171        }
1172    }
1173}