1use crate::handle::handle_type;
6use crate::responder::Responder;
7use crate::{Error, Handle, ordinals};
8use fidl_fuchsia_fdomain as proto;
9use futures::FutureExt;
10use std::future::Future;
11use std::pin::Pin;
12use std::sync::Arc;
13use std::task::{Context, Poll, ready};
14
15#[derive(PartialEq, Eq, PartialOrd, Ord, Hash, Debug)]
17pub struct Socket(pub(crate) Handle);
18
19handle_type!(Socket SOCKET peered);
20
/// Requested write disposition for one end of a socket, as passed to
/// `Socket::set_socket_disposition` (where `None` means "leave unchanged").
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SocketDisposition {
    /// Maps to `proto::SocketDisposition::WriteEnabled`.
    WriteEnabled,
    /// Maps to `proto::SocketDisposition::WriteDisabled`.
    WriteDisabled,
}
27
28impl SocketDisposition {
29 fn proto(self) -> proto::SocketDisposition {
31 match self {
32 SocketDisposition::WriteEnabled => proto::SocketDisposition::WriteEnabled,
33 SocketDisposition::WriteDisabled => proto::SocketDisposition::WriteDisabled,
34 }
35 }
36}
37
38impl Socket {
39 pub fn read<'a>(&self, buf: &'a mut [u8]) -> impl Future<Output = Result<usize, Error>> + 'a {
41 let client = Arc::downgrade(&self.0.client());
42 let handle = self.0.proto();
43
44 futures::future::poll_fn(move |ctx| {
45 client
46 .upgrade()
47 .unwrap_or_else(|| Arc::clone(&crate::DEAD_CLIENT))
48 .poll_socket(handle, ctx, buf)
49 })
50 }
51
52 pub fn write_all(&self, bytes: &[u8]) -> impl Future<Output = Result<(), Error>> {
54 let data = bytes.to_vec();
55 let len = bytes.len();
56 let hid = self.0.proto();
57
58 let client = self.0.client();
59 client
60 .transaction(
61 ordinals::WRITE_SOCKET,
62 proto::SocketWriteSocketRequest { handle: hid, data },
63 move |x| Responder::WriteSocket(x),
64 )
65 .map(move |x| x.map(|y| assert!(y.wrote as usize == len)))
66 }
67
68 pub fn set_socket_disposition(
70 &self,
71 disposition: Option<SocketDisposition>,
72 disposition_peer: Option<SocketDisposition>,
73 ) -> impl Future<Output = Result<(), Error>> {
74 let disposition =
75 disposition.map(SocketDisposition::proto).unwrap_or(proto::SocketDisposition::NoChange);
76 let disposition_peer = disposition_peer
77 .map(SocketDisposition::proto)
78 .unwrap_or(proto::SocketDisposition::NoChange);
79 let client = self.0.client();
80 let handle = self.0.proto();
81 client.transaction(
82 ordinals::SET_SOCKET_DISPOSITION,
83 proto::SocketSetSocketDispositionRequest { handle, disposition, disposition_peer },
84 Responder::SetSocketDisposition,
85 )
86 }
87
88 pub fn stream(self) -> Result<(SocketReadStream, SocketWriter), Error> {
95 self.0.client().start_socket_streaming(self.0.proto())?;
96
97 let a = Arc::new(self);
98 let b = Arc::clone(&a);
99
100 Ok((SocketReadStream(a), SocketWriter(b)))
101 }
102}
103
104pub struct SocketWriter(Arc<Socket>);
106
107impl SocketWriter {
108 pub fn write_all(&self, bytes: &[u8]) -> impl Future<Output = Result<(), Error>> {
110 self.0.write_all(bytes)
111 }
112}
113
114pub struct SocketReadStream(Arc<Socket>);
116
117impl SocketReadStream {
118 pub async fn read(&mut self, buf: &mut [u8]) -> Result<usize, Error> {
120 self.0.read(buf).await
121 }
122}
123
124impl Drop for SocketReadStream {
125 fn drop(&mut self) {
126 if let Some(client) = self.0.0.client.upgrade() {
127 client.stop_socket_streaming(self.0.0.proto());
128 }
129 }
130}
131
132fn async_read_poll_socket(
135 client: Arc<crate::Client>,
136 proto: proto::HandleId,
137 cx: &mut Context<'_>,
138 buf: &mut [u8],
139) -> Poll<std::io::Result<usize>> {
140 let res = ready!(client.poll_socket(proto, cx, buf)).or_else(|e| match e {
141 Error::FDomain(proto::Error::TargetError(e))
142 if e == zx_status::Status::PEER_CLOSED.into_raw() =>
143 {
144 Ok(0)
145 }
146 other => Err(std::io::Error::other(other)),
147 });
148 Poll::Ready(res)
149}
150
151impl futures::AsyncRead for Socket {
152 fn poll_read(
153 self: Pin<&mut Self>,
154 cx: &mut Context<'_>,
155 buf: &mut [u8],
156 ) -> Poll<std::io::Result<usize>> {
157 let client = self.0.client();
158 async_read_poll_socket(client, self.0.proto(), cx, buf)
159 }
160}
161
162impl futures::AsyncRead for &Socket {
163 fn poll_read(
164 self: Pin<&mut Self>,
165 cx: &mut Context<'_>,
166 buf: &mut [u8],
167 ) -> Poll<std::io::Result<usize>> {
168 let client = self.0.client();
169 async_read_poll_socket(client, self.0.proto(), cx, buf)
170 }
171}
172
173impl futures::AsyncWrite for Socket {
174 fn poll_write(
175 self: Pin<&mut Self>,
176 _cx: &mut Context<'_>,
177 buf: &[u8],
178 ) -> Poll<std::io::Result<usize>> {
179 let _ = self.write_all(buf);
180 Poll::Ready(Ok(buf.len()))
181 }
182
183 fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
184 Poll::Ready(Ok(()))
185 }
186
187 fn poll_close(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
188 self.0 = Handle::invalid();
189 Poll::Ready(Ok(()))
190 }
191}