fdomain_client/
socket.rs

1// Copyright 2024 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use crate::handle::handle_type;
6use crate::responder::Responder;
7use crate::{Error, Handle, ordinals};
8use fidl_fuchsia_fdomain as proto;
9use futures::FutureExt;
10use std::future::Future;
11use std::pin::Pin;
12use std::sync::Arc;
13use std::task::{Context, Poll, ready};
14
15/// A socket in a remote FDomain.
16#[derive(PartialEq, Eq, PartialOrd, Ord, Hash, Debug)]
17pub struct Socket(pub(crate) Handle);
18
19handle_type!(Socket SOCKET peered);
20
/// Requested disposition for one end of a socket, as passed to
/// [`Socket::set_socket_disposition`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SocketDisposition {
    /// Writes enabled; sent as `proto::SocketDisposition::WriteEnabled`.
    WriteEnabled,
    /// Writes disabled; sent as `proto::SocketDisposition::WriteDisabled`.
    WriteDisabled,
}
27
28impl SocketDisposition {
29    /// Convert to a proto::SocketDisposition
30    fn proto(self) -> proto::SocketDisposition {
31        match self {
32            SocketDisposition::WriteEnabled => proto::SocketDisposition::WriteEnabled,
33            SocketDisposition::WriteDisabled => proto::SocketDisposition::WriteDisabled,
34        }
35    }
36}
37
38impl Socket {
39    /// Read up to the given buffer's length from the socket.
40    pub fn read<'a>(&self, buf: &'a mut [u8]) -> impl Future<Output = Result<usize, Error>> + 'a {
41        let client = Arc::downgrade(&self.0.client());
42        let handle = self.0.proto();
43
44        futures::future::poll_fn(move |ctx| {
45            client
46                .upgrade()
47                .unwrap_or_else(|| Arc::clone(&crate::DEAD_CLIENT))
48                .poll_socket(handle, ctx, buf)
49        })
50    }
51
52    /// Write all of the given data to the socket.
53    pub fn write_all(&self, bytes: &[u8]) -> impl Future<Output = Result<(), Error>> {
54        let data = bytes.to_vec();
55        let len = bytes.len();
56        let hid = self.0.proto();
57
58        let client = self.0.client();
59        client
60            .transaction(
61                ordinals::WRITE_SOCKET,
62                proto::SocketWriteSocketRequest { handle: hid, data },
63                move |x| Responder::WriteSocket(x),
64            )
65            .map(move |x| x.map(|y| assert!(y.wrote as usize == len)))
66    }
67
68    /// Set the disposition of this socket and/or its peer.
69    pub fn set_socket_disposition(
70        &self,
71        disposition: Option<SocketDisposition>,
72        disposition_peer: Option<SocketDisposition>,
73    ) -> impl Future<Output = Result<(), Error>> {
74        let disposition =
75            disposition.map(SocketDisposition::proto).unwrap_or(proto::SocketDisposition::NoChange);
76        let disposition_peer = disposition_peer
77            .map(SocketDisposition::proto)
78            .unwrap_or(proto::SocketDisposition::NoChange);
79        let client = self.0.client();
80        let handle = self.0.proto();
81        client.transaction(
82            ordinals::SET_SOCKET_DISPOSITION,
83            proto::SocketSetSocketDispositionRequest { handle, disposition, disposition_peer },
84            Responder::SetSocketDisposition,
85        )
86    }
87
88    /// Split this socket into a streaming reader and a writer. This is more
89    /// efficient on the read side if you intend to consume all of the data from
90    /// the socket. However it will prevent you from transferring the handle in
91    /// the future. It also means data will build up in the buffer, so it may
92    /// lead to memory issues if you don't intend to use the data from the
93    /// socket as fast as it comes.
94    pub fn stream(self) -> Result<(SocketReadStream, SocketWriter), Error> {
95        self.0.client().start_socket_streaming(self.0.proto())?;
96
97        let a = Arc::new(self);
98        let b = Arc::clone(&a);
99
100        Ok((SocketReadStream(a), SocketWriter(b)))
101    }
102}
103
104/// A write-only handle to a socket.
105pub struct SocketWriter(Arc<Socket>);
106
107impl SocketWriter {
108    /// Write all of the given data to the socket.
109    pub fn write_all(&self, bytes: &[u8]) -> impl Future<Output = Result<(), Error>> {
110        self.0.write_all(bytes)
111    }
112}
113
114/// A stream of data issuing from a socket.
115pub struct SocketReadStream(Arc<Socket>);
116
117impl SocketReadStream {
118    /// Read from the socket into the supplied buffer. Returns the number of bytes read.
119    pub async fn read(&mut self, buf: &mut [u8]) -> Result<usize, Error> {
120        self.0.read(buf).await
121    }
122}
123
124impl Drop for SocketReadStream {
125    fn drop(&mut self) {
126        if let Some(client) = self.0.0.client.upgrade() {
127            client.stop_socket_streaming(self.0.0.proto());
128        }
129    }
130}
131
132/// Wrapper for [`Client::poll_socket`] that adapts the return value semantics
133/// to what Unix prescribes, and what `futures::io` thus prescribes.
134fn async_read_poll_socket(
135    client: Arc<crate::Client>,
136    proto: proto::HandleId,
137    cx: &mut Context<'_>,
138    buf: &mut [u8],
139) -> Poll<std::io::Result<usize>> {
140    let res = ready!(client.poll_socket(proto, cx, buf)).or_else(|e| match e {
141        Error::FDomain(proto::Error::TargetError(e))
142            if e == zx_status::Status::PEER_CLOSED.into_raw() =>
143        {
144            Ok(0)
145        }
146        other => Err(std::io::Error::other(other)),
147    });
148    Poll::Ready(res)
149}
150
151impl futures::AsyncRead for Socket {
152    fn poll_read(
153        self: Pin<&mut Self>,
154        cx: &mut Context<'_>,
155        buf: &mut [u8],
156    ) -> Poll<std::io::Result<usize>> {
157        let client = self.0.client();
158        async_read_poll_socket(client, self.0.proto(), cx, buf)
159    }
160}
161
162impl futures::AsyncRead for &Socket {
163    fn poll_read(
164        self: Pin<&mut Self>,
165        cx: &mut Context<'_>,
166        buf: &mut [u8],
167    ) -> Poll<std::io::Result<usize>> {
168        let client = self.0.client();
169        async_read_poll_socket(client, self.0.proto(), cx, buf)
170    }
171}
172
/// `futures::io` write support for a [`Socket`].
///
/// NOTE(review): writes are fire-and-forget. `poll_write` creates the
/// `write_all` future and drops it without ever polling it, then immediately
/// reports the whole buffer as written. Data is only transmitted if
/// `Client::transaction` dispatches the request when it is called, rather than
/// when its future is first polled — confirm. Because the future is never
/// polled, a failed remote write (and `write_all`'s write-count check) is never
/// observed by the caller.
impl futures::AsyncWrite for Socket {
    fn poll_write(
        self: Pin<&mut Self>,
        _cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<std::io::Result<usize>> {
        // The returned future is intentionally discarded (see the NOTE above).
        let _ = self.write_all(buf);
        Poll::Ready(Ok(buf.len()))
    }

    /// Always ready: `poll_write` keeps no local buffer, so there is nothing
    /// to flush here.
    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
        Poll::Ready(Ok(()))
    }

    /// Replaces the wrapped handle with `Handle::invalid()`, dropping the
    /// previous handle (presumably closing it remotely — confirm in `Handle`'s
    /// drop behavior).
    fn poll_close(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
        self.0 = Handle::invalid();
        Poll::Ready(Ok(()))
    }
}