fdomain_client/
socket.rs

1// Copyright 2024 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use crate::handle::handle_type;
6use crate::responder::Responder;
7use crate::{ordinals, Error, Handle};
8use fidl_fuchsia_fdomain as proto;
9use futures::FutureExt;
10use std::future::Future;
11use std::sync::Arc;
12
13/// A socket in a remote FDomain.
14#[derive(PartialEq, Eq, PartialOrd, Ord, Hash, Debug)]
15pub struct Socket(pub(crate) Handle);
16
17handle_type!(Socket SOCKET peered);
18
19/// Disposition of a socket.
20#[derive(Copy, Clone, Debug, PartialEq, Eq)]
21pub enum SocketDisposition {
22    WriteEnabled,
23    WriteDisabled,
24}
25
26impl SocketDisposition {
27    /// Convert to a proto::SocketDisposition
28    fn proto(self) -> proto::SocketDisposition {
29        match self {
30            SocketDisposition::WriteEnabled => proto::SocketDisposition::WriteEnabled,
31            SocketDisposition::WriteDisabled => proto::SocketDisposition::WriteDisabled,
32        }
33    }
34}
35
36impl Socket {
37    /// Read up to the given buffer's length from the socket.
38    pub fn read<'a>(&self, buf: &'a mut [u8]) -> impl Future<Output = Result<usize, Error>> + 'a {
39        let client = self.0.client();
40        let handle = self.0.proto();
41
42        futures::future::poll_fn(move |ctx| client.poll_socket(handle, ctx, buf))
43    }
44
45    /// Write all of the given data to the socket.
46    pub fn write_all(&self, bytes: &[u8]) -> impl Future<Output = Result<(), Error>> {
47        let data = bytes.to_vec();
48        let len = bytes.len();
49        let hid = self.0.proto();
50
51        let client = self.0.client();
52        client
53            .transaction(
54                ordinals::WRITE_SOCKET,
55                proto::SocketWriteSocketRequest { handle: hid, data },
56                move |x| Responder::WriteSocket(x),
57            )
58            .map(move |x| x.map(|y| assert!(y.wrote as usize == len)))
59    }
60
61    /// Set the disposition of this socket and/or its peer.
62    pub fn set_socket_disposition(
63        &self,
64        disposition: Option<SocketDisposition>,
65        disposition_peer: Option<SocketDisposition>,
66    ) -> impl Future<Output = Result<(), Error>> {
67        let disposition =
68            disposition.map(SocketDisposition::proto).unwrap_or(proto::SocketDisposition::NoChange);
69        let disposition_peer = disposition_peer
70            .map(SocketDisposition::proto)
71            .unwrap_or(proto::SocketDisposition::NoChange);
72        let client = self.0.client();
73        let handle = self.0.proto();
74        client.transaction(
75            ordinals::SET_SOCKET_DISPOSITION,
76            proto::SocketSetSocketDispositionRequest { handle, disposition, disposition_peer },
77            Responder::SetSocketDisposition,
78        )
79    }
80
81    /// Split this socket into a streaming reader and a writer. This is more
82    /// efficient on the read side if you intend to consume all of the data from
83    /// the socket. However it will prevent you from transferring the handle in
84    /// the future. It also means data will build up in the buffer, so it may
85    /// lead to memory issues if you don't intend to use the data from the
86    /// socket as fast as it comes.
87    pub fn stream(self) -> Result<(SocketReadStream, SocketWriter), Error> {
88        self.0.client().start_socket_streaming(self.0.proto())?;
89
90        let a = Arc::new(self);
91        let b = Arc::clone(&a);
92
93        Ok((SocketReadStream(a), SocketWriter(b)))
94    }
95}
96
97/// A write-only handle to a socket.
98pub struct SocketWriter(Arc<Socket>);
99
100impl SocketWriter {
101    /// Write all of the given data to the socket.
102    pub fn write_all(&self, bytes: &[u8]) -> impl Future<Output = Result<(), Error>> {
103        self.0.write_all(bytes)
104    }
105}
106
107/// A stream of data issuing from a socket.
108pub struct SocketReadStream(Arc<Socket>);
109
110impl SocketReadStream {
111    /// Read from the socket into the supplied buffer. Returns the number of bytes read.
112    pub async fn read(&mut self, buf: &mut [u8]) -> Result<usize, Error> {
113        self.0.read(buf).await
114    }
115}
116
117impl Drop for SocketReadStream {
118    fn drop(&mut self) {
119        if let Some(client) = self.0 .0.client.upgrade() {
120            client.stop_socket_streaming(self.0 .0.proto());
121        }
122    }
123}
124
125impl futures::AsyncRead for Socket {
126    fn poll_read(
127        self: std::pin::Pin<&mut Self>,
128        cx: &mut std::task::Context<'_>,
129        buf: &mut [u8],
130    ) -> std::task::Poll<std::io::Result<usize>> {
131        let client = self.0.client();
132        client.poll_socket(self.0.proto(), cx, buf).map_err(std::io::Error::other)
133    }
134}
135
136impl futures::AsyncWrite for Socket {
137    fn poll_write(
138        self: std::pin::Pin<&mut Self>,
139        _cx: &mut std::task::Context<'_>,
140        buf: &[u8],
141    ) -> std::task::Poll<std::io::Result<usize>> {
142        let _ = self.write_all(buf);
143        std::task::Poll::Ready(Ok(buf.len()))
144    }
145
146    fn poll_flush(
147        self: std::pin::Pin<&mut Self>,
148        _cx: &mut std::task::Context<'_>,
149    ) -> std::task::Poll<std::io::Result<()>> {
150        std::task::Poll::Ready(Ok(()))
151    }
152
153    fn poll_close(
154        mut self: std::pin::Pin<&mut Self>,
155        _cx: &mut std::task::Context<'_>,
156    ) -> std::task::Poll<std::io::Result<()>> {
157        self.0 = Handle::invalid();
158        std::task::Poll::Ready(Ok(()))
159    }
160}