fdomain_client/
channel.rs

1// Copyright 2024 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use crate::handle::handle_type;
6use crate::responder::Responder;
7use crate::{Error, Event, EventPair, Handle, OnFDomainSignals, Socket, ordinals};
8use fidl_fuchsia_fdomain as proto;
9use futures::future::Either;
10use futures::stream::Stream;
11use std::future::Future;
12use std::pin::Pin;
13use std::sync::{Arc, Weak};
14use std::task::{Context, Poll, ready};
15
/// A channel in a remote FDomain.
///
/// This is a thin newtype around a [`Handle`]; the derived comparison and
/// hashing traits are all computed from that wrapped handle.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Channel(pub(crate) Handle);

// NOTE(review): `handle_type!` comes from `crate::handle`. Presumably it
// generates the boilerplate shared by all typed handles (for example the
// conversion into `Handle` that other code in this file relies on) — confirm
// against the macro definition.
handle_type!(Channel CHANNEL peered);
21
22/// A message which has been read from a channel.
23#[derive(Debug)]
24pub struct MessageBuf {
25    pub bytes: Vec<u8>,
26    pub handles: Vec<HandleInfo>,
27}
28
29impl MessageBuf {
30    /// Create a new [`MessageBuf`]
31    pub fn new() -> Self {
32        MessageBuf { bytes: Vec::new(), handles: Vec::new() }
33    }
34
35    /// Get the components of this buffer separately.
36    pub fn split(self) -> (Vec<u8>, Vec<HandleInfo>) {
37        (self.bytes, self.handles)
38    }
39
40    /// Make sure this buffer has room for a certain number of bytes.
41    pub fn ensure_capacity_bytes(&mut self, bytes: usize) {
42        self.bytes.reserve(bytes);
43    }
44
45    /// Clear out the contents of this buffer.
46    pub fn clear(&mut self) {
47        self.bytes.clear();
48        self.handles.clear();
49    }
50
51    /// Get the byte content of this buffer.
52    pub fn bytes(&self) -> &[u8] {
53        self.bytes.as_slice()
54    }
55
56    /// Convert a proto ChannelMessage to a MessageBuf.
57    fn from_proto(client: &Arc<crate::Client>, message: proto::ChannelMessage) -> MessageBuf {
58        let proto::ChannelMessage { data, handles } = message;
59        MessageBuf {
60            bytes: data,
61            handles: handles
62                .into_iter()
63                .map(|info| {
64                    let handle = Handle { id: info.handle.id, client: Arc::downgrade(client) };
65                    HandleInfo {
66                        rights: info.rights,
67                        handle: AnyHandle::from_handle(handle, info.type_),
68                    }
69                })
70                .collect(),
71        }
72    }
73}
74
/// A handle which has been read from a channel.
#[derive(Debug)]
pub struct HandleInfo {
    /// The handle itself, wrapped according to the object type that was
    /// reported alongside it.
    pub handle: AnyHandle,
    /// The rights that were reported for the handle in the received message.
    pub rights: fidl::Rights,
}
81
82/// Sum type of all the handle types which can be read from a channel. Allows
83/// the user to learn the type of a handle after it has been read.
84#[derive(Debug)]
85pub enum AnyHandle {
86    Channel(Channel),
87    Socket(Socket),
88    Event(Event),
89    EventPair(EventPair),
90    Unknown(Handle, fidl::ObjectType),
91}
92
93impl AnyHandle {
94    /// Construct an `AnyHandle` from a `Handle` and an object type.
95    pub fn from_handle(handle: Handle, ty: fidl::ObjectType) -> AnyHandle {
96        match ty {
97            fidl::ObjectType::CHANNEL => AnyHandle::Channel(Channel(handle)),
98            fidl::ObjectType::SOCKET => AnyHandle::Socket(Socket(handle)),
99            fidl::ObjectType::EVENT => AnyHandle::Event(Event(handle)),
100            fidl::ObjectType::EVENTPAIR => AnyHandle::EventPair(EventPair(handle)),
101            _ => AnyHandle::Unknown(handle, ty),
102        }
103    }
104
105    /// Get an `AnyHandle` wrapping an invalid handle.
106    pub fn invalid() -> AnyHandle {
107        AnyHandle::Unknown(Handle::invalid(), fidl::ObjectType::NONE)
108    }
109
110    /// Get the object type for a handle.
111    pub fn object_type(&self) -> fidl::ObjectType {
112        match self {
113            AnyHandle::Channel(_) => fidl::ObjectType::CHANNEL,
114            AnyHandle::Socket(_) => fidl::ObjectType::SOCKET,
115            AnyHandle::Event(_) => fidl::ObjectType::EVENT,
116            AnyHandle::EventPair(_) => fidl::ObjectType::EVENTPAIR,
117            AnyHandle::Unknown(_, t) => *t,
118        }
119    }
120}
121
122impl From<AnyHandle> for Handle {
123    fn from(item: AnyHandle) -> Handle {
124        match item {
125            AnyHandle::Channel(h) => h.into(),
126            AnyHandle::Socket(h) => h.into(),
127            AnyHandle::Event(h) => h.into(),
128            AnyHandle::EventPair(h) => h.into(),
129            AnyHandle::Unknown(h, _) => h,
130        }
131    }
132}
133
134/// Operation to perform on a handle when writing it to a channel.
135pub enum HandleOp<'h> {
136    Move(Handle, fidl::Rights),
137    Duplicate(&'h Handle, fidl::Rights),
138}
139
140impl Channel {
141    /// Reads a message from the channel.
142    pub fn recv_msg(&self) -> impl Future<Output = Result<MessageBuf, Error>> + use<> {
143        let client = Arc::downgrade(&self.0.client());
144        let handle = self.0.proto();
145
146        futures::future::poll_fn(move |ctx| {
147            let client = client.upgrade().unwrap_or_else(|| Arc::clone(&crate::DEAD_CLIENT));
148            client.poll_channel(handle, ctx, false).map(|x| {
149                x.expect("Got stream termination indication from non-streaming read!")
150                    .map(|x| MessageBuf::from_proto(&client, x))
151            })
152        })
153    }
154
155    /// Poll to try and read a channel message.
156    pub fn poll_read(&self, cx: &mut Context<'_>) -> Poll<Result<MessageBuf, Error>> {
157        let client = self.0.client();
158        let handle = self.0.proto();
159
160        client.poll_channel(handle, cx, false).map(|x| {
161            x.expect("Got stream termination indication from non-streaming read!")
162                .map(|x| MessageBuf::from_proto(&client, x))
163        })
164    }
165
166    /// Poll a channel for a message to read.
167    pub fn recv_from(&self, cx: &mut Context<'_>, buf: &mut MessageBuf) -> Poll<Result<(), Error>> {
168        let client = self.0.client();
169        match ready!(client.poll_channel(self.0.proto(), cx, false))
170            .expect("Got stream termination indication from non-streaming read!")
171        {
172            Ok(msg) => {
173                *buf = MessageBuf::from_proto(&client, msg);
174                Poll::Ready(Ok(()))
175            }
176            Err(e) => Poll::Ready(Err(e)),
177        }
178    }
179
180    /// Writes a message into the channel.
181    pub fn write(&self, bytes: &[u8], handles: Vec<Handle>) -> Result<(), Error> {
182        if bytes.len() > zx_types::ZX_CHANNEL_MAX_MSG_BYTES as usize
183            || handles.len() > zx_types::ZX_CHANNEL_MAX_MSG_HANDLES as usize
184        {
185            return Err(Error::FDomain(proto::Error::TargetError(
186                fidl::Status::OUT_OF_RANGE.into_raw(),
187            )));
188        }
189
190        let _ = self.write_inner(
191            bytes,
192            proto::Handles::Handles(handles.into_iter().map(|x| x.take_proto()).collect()),
193        );
194        Ok(())
195    }
196
197    /// Writes a message into the channel. Returns a future that will allow you
198    /// to wait for the write to move across the FDomain connection and return
199    /// with the result of the actual write call on the target.
200    pub fn fdomain_write(
201        &self,
202        bytes: &[u8],
203        handles: Vec<Handle>,
204    ) -> impl Future<Output = Result<(), Error>> + use<> {
205        if bytes.len() > zx_types::ZX_CHANNEL_MAX_MSG_BYTES as usize
206            || handles.len() > zx_types::ZX_CHANNEL_MAX_MSG_HANDLES as usize
207        {
208            Either::Left(async {
209                Err(Error::FDomain(proto::Error::TargetError(
210                    fidl::Status::OUT_OF_RANGE.into_raw(),
211                )))
212            })
213        } else {
214            Either::Right(self.write_inner(
215                bytes,
216                proto::Handles::Handles(handles.into_iter().map(|x| x.take_proto()).collect()),
217            ))
218        }
219    }
220
221    /// A future that returns when the channel is closed.
222    pub fn on_closed(&self) -> OnFDomainSignals {
223        OnFDomainSignals::new(&self.0, fidl::Signals::OBJECT_PEER_CLOSED)
224    }
225
226    /// Whether this handle is closed.
227    pub fn is_closed(&self) -> bool {
228        self.0.client.upgrade().is_none()
229    }
230
231    /// Writes a message into the channel. Optionally duplicates some of the
232    /// handles rather than consuming them, and can update the handle's rights
233    /// before sending.
234    pub fn fdomain_write_etc<'b>(
235        &self,
236        bytes: &[u8],
237        handles: Vec<HandleOp<'b>>,
238    ) -> impl Future<Output = Result<(), Error>> + use<'b> {
239        let handles = handles
240            .into_iter()
241            .map(|handle| match handle {
242                HandleOp::Move(x, rights) => {
243                    if Weak::ptr_eq(&x.client, &self.0.client) {
244                        Ok(proto::HandleDisposition {
245                            handle: proto::HandleOp::Move_(x.take_proto()),
246                            rights,
247                        })
248                    } else {
249                        Err(Error::ConnectionMismatch)
250                    }
251                }
252                HandleOp::Duplicate(x, rights) => {
253                    if Weak::ptr_eq(&x.client, &self.0.client) {
254                        Ok(proto::HandleDisposition {
255                            handle: proto::HandleOp::Duplicate(x.proto()),
256                            rights,
257                        })
258                    } else {
259                        Err(Error::ConnectionMismatch)
260                    }
261                }
262            })
263            .collect::<Result<Vec<_>, Error>>();
264
265        let handles = if handles
266            .as_ref()
267            .map(|x| x.len() > zx_types::ZX_CHANNEL_MAX_MSG_HANDLES as usize)
268            .unwrap_or(false)
269            || bytes.len() > zx_types::ZX_CHANNEL_MAX_MSG_BYTES as usize
270        {
271            Err(Error::FDomain(proto::Error::TargetError(fidl::Status::OUT_OF_RANGE.into_raw())))
272        } else {
273            handles
274        };
275
276        match handles {
277            Ok(handles) => {
278                Either::Left(self.write_inner(bytes, proto::Handles::Dispositions(handles)))
279            }
280            Err(e) => Either::Right(async move { Err(e) }),
281        }
282    }
283
284    /// Writes a message into the channel.
285    fn write_inner(
286        &self,
287        bytes: &[u8],
288        handles: proto::Handles,
289    ) -> impl Future<Output = Result<(), Error>> + use<> {
290        let data = bytes.to_vec();
291        let client = self.0.client();
292        let handle = self.0.proto();
293
294        client.clear_handles_for_transfer(&handles);
295        client.transaction(
296            ordinals::WRITE_CHANNEL,
297            proto::ChannelWriteChannelRequest { handle, data, handles },
298            move |x| Responder::WriteChannel(x),
299        )
300    }
301
302    /// Split this channel into a streaming reader and a writer. This is more
303    /// efficient on the read side if you intend to consume all of the messages
304    /// from the channel. However it will prevent you from transferring the
305    /// handle in the future. It also means messages will build up in the
306    /// buffer, so it may lead to memory issues if you don't intend to use the
307    /// messages from the channel as fast as they come.
308    pub fn stream(self) -> Result<(ChannelMessageStream, ChannelWriter), Error> {
309        self.0.client().start_channel_streaming(self.0.proto())?;
310
311        let a = Arc::new(self);
312        let b = Arc::clone(&a);
313
314        Ok((ChannelMessageStream(a), ChannelWriter(b)))
315    }
316}
317
318/// A write-only handle to a socket.
319#[derive(Debug, Clone)]
320pub struct ChannelWriter(Arc<Channel>);
321
322impl ChannelWriter {
323    /// Writes a message into the channel.
324    pub fn write(&self, bytes: &[u8], handles: Vec<Handle>) -> Result<(), Error> {
325        self.0.write(bytes, handles)
326    }
327
328    /// Writes a message into the channel. Returns a future that will allow you
329    /// to wait for the write to move across the FDomain connection and return
330    /// with the result of the actual write call on the target.
331    pub fn fdomain_write(
332        &self,
333        bytes: &[u8],
334        handles: Vec<Handle>,
335    ) -> impl Future<Output = Result<(), Error>> {
336        self.0.fdomain_write(bytes, handles)
337    }
338
339    /// Writes a message into the channel.
340    pub fn fdomain_write_etc<'b>(
341        &self,
342        bytes: &[u8],
343        handles: Vec<HandleOp<'b>>,
344    ) -> impl Future<Output = Result<(), Error>> + 'b {
345        self.0.fdomain_write_etc(bytes, handles)
346    }
347
348    /// Get a reference to the inner channel.
349    pub fn as_channel(&self) -> &Channel {
350        &*self.0
351    }
352}
353
354/// A stream of data issuing from a socket.
355#[derive(Debug)]
356pub struct ChannelMessageStream(Arc<Channel>);
357
358impl ChannelMessageStream {
359    /// Turn a `ChannelMessageStream` and its accompanying `ChannelWriter` back
360    /// into a `Channel`.
361    ///
362    /// # Panics
363    /// If this stream and the writer passed didn't come from the same call to
364    /// `Channel::stream`, or if there is more than one writer.
365    pub fn rejoin(mut self, writer: ChannelWriter) -> Channel {
366        assert!(Arc::ptr_eq(&self.0, &writer.0), "Tried to join stream with wrong writer!");
367        if let Some(client) = self.0.0.client.upgrade() {
368            client.stop_channel_streaming(self.0.0.proto())
369        }
370        std::mem::drop(writer);
371        let channel = std::mem::replace(&mut self.0, Arc::new(Channel(Handle::invalid())));
372        Arc::try_unwrap(channel).expect("Stream pointer no longer unique!")
373    }
374
375    /// Whether this stream is closed.
376    pub fn is_closed(&self) -> bool {
377        let client = self.0.0.client();
378
379        !client.channel_is_streaming(self.0.0.proto())
380    }
381
382    /// Get a reference to the inner channel.
383    pub fn as_channel(&self) -> &Channel {
384        &*self.0
385    }
386}
387
388impl Stream for ChannelMessageStream {
389    type Item = Result<MessageBuf, Error>;
390    fn poll_next(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
391        let client = self.0.0.client();
392        client
393            .poll_channel(self.0.0.proto(), ctx, true)
394            .map(|x| x.map(|x| x.map(|x| MessageBuf::from_proto(&client, x))))
395    }
396}
397
398impl Drop for ChannelMessageStream {
399    fn drop(&mut self) {
400        if let Some(client) = self.0.0.client.upgrade() {
401            client.stop_channel_streaming(self.0.0.proto());
402        }
403    }
404}