Skip to main content

fdomain_client/
socket.rs

1// Copyright 2024 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use crate::handle::handle_type;
6use crate::responder::Responder;
7use crate::{Error, Handle, ordinals};
8use fidl_fuchsia_fdomain as proto;
9use futures::FutureExt;
10use std::future::Future;
11use std::pin::Pin;
12use std::sync::Arc;
13use std::task::{Context, Poll, ready};
14
15/// A socket in a remote FDomain.
16#[derive(PartialEq, Eq, PartialOrd, Ord, Hash, Debug)]
17pub struct Socket(pub(crate) Handle);
18
19handle_type!(Socket SOCKET peered);
20
/// Disposition of a socket, as set via `Socket::set_socket_disposition`.
///
/// Each variant maps one-to-one onto the corresponding `proto::SocketDisposition` value.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SocketDisposition {
    /// Sent to the FDomain as `proto::SocketDisposition::WriteEnabled`.
    WriteEnabled,
    /// Sent to the FDomain as `proto::SocketDisposition::WriteDisabled`.
    WriteDisabled,
}
27
28impl SocketDisposition {
29    /// Convert to a proto::SocketDisposition
30    fn proto(self) -> proto::SocketDisposition {
31        match self {
32            SocketDisposition::WriteEnabled => proto::SocketDisposition::WriteEnabled,
33            SocketDisposition::WriteDisabled => proto::SocketDisposition::WriteDisabled,
34        }
35    }
36}
37
38impl Socket {
39    /// Read up to the given buffer's length from the socket.
40    pub fn fdomain_read<'a>(
41        &self,
42        buf: &'a mut [u8],
43    ) -> impl Future<Output = Result<usize, Error>> + 'a {
44        let client = Arc::downgrade(&self.0.client());
45        let handle = self.0.proto();
46
47        futures::future::poll_fn(move |ctx| {
48            client
49                .upgrade()
50                .unwrap_or_else(|| Arc::clone(&crate::DEAD_CLIENT))
51                .poll_socket(handle, ctx, buf)
52        })
53    }
54
55    /// Polls on reading this socket. Not to be confused with `AsyncRead::poll_read` which has a
56    /// different method of error reporting. That will handle errors, whereas this will return them
57    /// directly.
58    pub fn poll_socket(&self, ctx: &mut Context<'_>, out: &mut [u8]) -> Poll<Result<usize, Error>> {
59        let client = self.0.client();
60        client.poll_socket(self.0.proto(), ctx, out)
61    }
62
63    /// Write all of the given data to the socket.
64    pub fn fdomain_write_all(
65        &self,
66        bytes: &[u8],
67    ) -> impl Future<Output = Result<(), Error>> + use<> {
68        let data = bytes.to_vec();
69        let len = bytes.len();
70        let hid = self.0.proto();
71
72        let client = self.0.client();
73        client
74            .transaction(
75                ordinals::WRITE_SOCKET,
76                proto::SocketWriteSocketRequest { handle: hid, data },
77                move |x| Responder::WriteSocket(x),
78            )
79            .map(move |x| x.map(|y| assert!(y.wrote as usize == len)))
80    }
81
82    /// Set the disposition of this socket and/or its peer.
83    pub fn set_socket_disposition(
84        &self,
85        disposition: Option<SocketDisposition>,
86        disposition_peer: Option<SocketDisposition>,
87    ) -> impl Future<Output = Result<(), Error>> {
88        let disposition =
89            disposition.map(SocketDisposition::proto).unwrap_or(proto::SocketDisposition::NoChange);
90        let disposition_peer = disposition_peer
91            .map(SocketDisposition::proto)
92            .unwrap_or(proto::SocketDisposition::NoChange);
93        let client = self.0.client();
94        let handle = self.0.proto();
95        client.transaction(
96            ordinals::SET_SOCKET_DISPOSITION,
97            proto::SocketSetSocketDispositionRequest { handle, disposition, disposition_peer },
98            Responder::SetSocketDisposition,
99        )
100    }
101
102    /// Split this socket into a streaming reader and a writer. This is more
103    /// efficient on the read side if you intend to consume all of the data from
104    /// the socket. However it will prevent you from transferring the handle in
105    /// the future. It also means data will build up in the buffer, so it may
106    /// lead to memory issues if you don't intend to use the data from the
107    /// socket as fast as it comes.
108    pub fn stream(self) -> Result<(SocketReadStream, SocketWriter), Error> {
109        self.0.client().start_socket_streaming(self.0.proto())?;
110
111        let a = Arc::new(self);
112        let b = Arc::clone(&a);
113
114        Ok((SocketReadStream(a), SocketWriter(b)))
115    }
116}
117
118/// A write-only handle to a socket.
119pub struct SocketWriter(Arc<Socket>);
120
121impl SocketWriter {
122    /// Write all of the given data to the socket.
123    pub fn write_all(&self, bytes: &[u8]) -> impl Future<Output = Result<(), Error>> {
124        self.0.fdomain_write_all(bytes)
125    }
126}
127
128/// A stream of data issuing from a socket.
129pub struct SocketReadStream(Arc<Socket>);
130
131impl SocketReadStream {
132    /// Read from the socket into the supplied buffer. Returns the number of bytes read.
133    pub async fn fdomain_read(&mut self, buf: &mut [u8]) -> Result<usize, Error> {
134        self.0.fdomain_read(buf).await
135    }
136}
137
138impl futures::AsyncRead for SocketReadStream {
139    fn poll_read(
140        self: Pin<&mut Self>,
141        cx: &mut Context<'_>,
142        buf: &mut [u8],
143    ) -> Poll<std::io::Result<usize>> {
144        convert_poll_res_to_async_read(self.0.poll_socket(cx, buf))
145    }
146}
147
148impl futures::AsyncRead for &SocketReadStream {
149    fn poll_read(
150        self: Pin<&mut Self>,
151        cx: &mut Context<'_>,
152        buf: &mut [u8],
153    ) -> Poll<std::io::Result<usize>> {
154        convert_poll_res_to_async_read(self.0.poll_socket(cx, buf))
155    }
156}
157
158impl Drop for SocketReadStream {
159    fn drop(&mut self) {
160        if let Some(client) = self.0.0.client.upgrade() {
161            client.stop_socket_streaming(self.0.0.proto());
162        }
163    }
164}
165
166/// Wrapper for [`Client::poll_socket`] that adapts the return value semantics
167/// to what Unix prescribes, and what `futures::io` thus prescribes.
168fn convert_poll_res_to_async_read(
169    poll_res: Poll<Result<usize, Error>>,
170) -> Poll<std::io::Result<usize>> {
171    let res = ready!(poll_res).or_else(|e| match e {
172        Error::FDomain(proto::Error::TargetError(e))
173            if e == zx_status::Status::PEER_CLOSED.into_raw() =>
174        {
175            Ok(0)
176        }
177        other => Err(std::io::Error::other(other)),
178    });
179    Poll::Ready(res)
180}
181
182impl futures::AsyncRead for Socket {
183    fn poll_read(
184        self: Pin<&mut Self>,
185        cx: &mut Context<'_>,
186        buf: &mut [u8],
187    ) -> Poll<std::io::Result<usize>> {
188        convert_poll_res_to_async_read(self.poll_socket(cx, buf))
189    }
190}
191
192impl futures::AsyncRead for &Socket {
193    fn poll_read(
194        self: Pin<&mut Self>,
195        cx: &mut Context<'_>,
196        buf: &mut [u8],
197    ) -> Poll<std::io::Result<usize>> {
198        convert_poll_res_to_async_read(self.poll_socket(cx, buf))
199    }
200}
201
202impl futures::AsyncWrite for Socket {
203    fn poll_write(
204        self: Pin<&mut Self>,
205        _cx: &mut Context<'_>,
206        buf: &[u8],
207    ) -> Poll<std::io::Result<usize>> {
208        let _ = self.fdomain_write_all(buf);
209        Poll::Ready(Ok(buf.len()))
210    }
211
212    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
213        Poll::Ready(Ok(()))
214    }
215
216    fn poll_close(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
217        self.0 = Handle::invalid();
218        Poll::Ready(Ok(()))
219    }
220}