1use crate::handle::handle_type;
6use crate::responder::Responder;
7use crate::{Error, Handle, ordinals};
8use fidl_fuchsia_fdomain as proto;
9use futures::FutureExt;
10use std::future::Future;
11use std::pin::Pin;
12use std::sync::Arc;
13use std::task::{Context, Poll, ready};
14
/// A socket handle. Operations on it are forwarded to the FDomain client obtained from the
/// wrapped [`Handle`].
#[derive(PartialEq, Eq, PartialOrd, Ord, Hash, Debug)]
pub struct Socket(pub(crate) Handle);

// Generates the shared handle-type plumbing for `Socket` (declared as a `SOCKET`, `peered`
// handle). NOTE(review): exact generated items live in `crate::handle::handle_type` — confirm
// there.
handle_type!(Socket SOCKET peered);
20
/// Requested write state for one end of a socket, passed to
/// [`Socket::set_socket_disposition`].
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum SocketDisposition {
    /// Sent on the wire as `proto::SocketDisposition::WriteEnabled`.
    WriteEnabled,
    /// Sent on the wire as `proto::SocketDisposition::WriteDisabled`.
    WriteDisabled,
}
27
28impl SocketDisposition {
29 fn proto(self) -> proto::SocketDisposition {
31 match self {
32 SocketDisposition::WriteEnabled => proto::SocketDisposition::WriteEnabled,
33 SocketDisposition::WriteDisabled => proto::SocketDisposition::WriteDisabled,
34 }
35 }
36}
37
38impl Socket {
39 pub fn fdomain_read<'a>(
41 &self,
42 buf: &'a mut [u8],
43 ) -> impl Future<Output = Result<usize, Error>> + 'a {
44 let client = Arc::downgrade(&self.0.client());
45 let handle = self.0.proto();
46
47 futures::future::poll_fn(move |ctx| {
48 client
49 .upgrade()
50 .unwrap_or_else(|| Arc::clone(&crate::DEAD_CLIENT))
51 .poll_socket(handle, ctx, buf)
52 })
53 }
54
55 pub fn poll_socket(&self, ctx: &mut Context<'_>, out: &mut [u8]) -> Poll<Result<usize, Error>> {
59 let client = self.0.client();
60 client.poll_socket(self.0.proto(), ctx, out)
61 }
62
63 pub fn fdomain_write_all(
65 &self,
66 bytes: &[u8],
67 ) -> impl Future<Output = Result<(), Error>> + use<> {
68 let data = bytes.to_vec();
69 let len = bytes.len();
70 let hid = self.0.proto();
71
72 let client = self.0.client();
73 client
74 .transaction(
75 ordinals::WRITE_SOCKET,
76 proto::SocketWriteSocketRequest { handle: hid, data },
77 move |x| Responder::WriteSocket(x),
78 )
79 .map(move |x| x.map(|y| assert!(y.wrote as usize == len)))
80 }
81
82 pub fn set_socket_disposition(
84 &self,
85 disposition: Option<SocketDisposition>,
86 disposition_peer: Option<SocketDisposition>,
87 ) -> impl Future<Output = Result<(), Error>> {
88 let disposition =
89 disposition.map(SocketDisposition::proto).unwrap_or(proto::SocketDisposition::NoChange);
90 let disposition_peer = disposition_peer
91 .map(SocketDisposition::proto)
92 .unwrap_or(proto::SocketDisposition::NoChange);
93 let client = self.0.client();
94 let handle = self.0.proto();
95 client.transaction(
96 ordinals::SET_SOCKET_DISPOSITION,
97 proto::SocketSetSocketDispositionRequest { handle, disposition, disposition_peer },
98 Responder::SetSocketDisposition,
99 )
100 }
101
102 pub fn stream(self) -> Result<(SocketReadStream, SocketWriter), Error> {
109 self.0.client().start_socket_streaming(self.0.proto())?;
110
111 let a = Arc::new(self);
112 let b = Arc::clone(&a);
113
114 Ok((SocketReadStream(a), SocketWriter(b)))
115 }
116}
117
/// Write half produced by [`Socket::stream`]. It shares the underlying [`Socket`] with the
/// matching [`SocketReadStream`].
pub struct SocketWriter(Arc<Socket>);
120
121impl SocketWriter {
122 pub fn write_all(&self, bytes: &[u8]) -> impl Future<Output = Result<(), Error>> {
124 self.0.fdomain_write_all(bytes)
125 }
126}
127
/// Read half produced by [`Socket::stream`]. Dropping it asks the client (if still alive) to
/// stop streaming this socket.
pub struct SocketReadStream(Arc<Socket>);
130
131impl SocketReadStream {
132 pub async fn fdomain_read(&mut self, buf: &mut [u8]) -> Result<usize, Error> {
134 self.0.fdomain_read(buf).await
135 }
136}
137
138impl futures::AsyncRead for SocketReadStream {
139 fn poll_read(
140 self: Pin<&mut Self>,
141 cx: &mut Context<'_>,
142 buf: &mut [u8],
143 ) -> Poll<std::io::Result<usize>> {
144 convert_poll_res_to_async_read(self.0.poll_socket(cx, buf))
145 }
146}
147
148impl futures::AsyncRead for &SocketReadStream {
149 fn poll_read(
150 self: Pin<&mut Self>,
151 cx: &mut Context<'_>,
152 buf: &mut [u8],
153 ) -> Poll<std::io::Result<usize>> {
154 convert_poll_res_to_async_read(self.0.poll_socket(cx, buf))
155 }
156}
157
158impl Drop for SocketReadStream {
159 fn drop(&mut self) {
160 if let Some(client) = self.0.0.client.upgrade() {
161 client.stop_socket_streaming(self.0.0.proto());
162 }
163 }
164}
165
166fn convert_poll_res_to_async_read(
169 poll_res: Poll<Result<usize, Error>>,
170) -> Poll<std::io::Result<usize>> {
171 let res = ready!(poll_res).or_else(|e| match e {
172 Error::FDomain(proto::Error::TargetError(e))
173 if e == zx_status::Status::PEER_CLOSED.into_raw() =>
174 {
175 Ok(0)
176 }
177 other => Err(std::io::Error::other(other)),
178 });
179 Poll::Ready(res)
180}
181
182impl futures::AsyncRead for Socket {
183 fn poll_read(
184 self: Pin<&mut Self>,
185 cx: &mut Context<'_>,
186 buf: &mut [u8],
187 ) -> Poll<std::io::Result<usize>> {
188 convert_poll_res_to_async_read(self.poll_socket(cx, buf))
189 }
190}
191
192impl futures::AsyncRead for &Socket {
193 fn poll_read(
194 self: Pin<&mut Self>,
195 cx: &mut Context<'_>,
196 buf: &mut [u8],
197 ) -> Poll<std::io::Result<usize>> {
198 convert_poll_res_to_async_read(self.poll_socket(cx, buf))
199 }
200}
201
202impl futures::AsyncWrite for Socket {
203 fn poll_write(
204 self: Pin<&mut Self>,
205 _cx: &mut Context<'_>,
206 buf: &[u8],
207 ) -> Poll<std::io::Result<usize>> {
208 let _ = self.fdomain_write_all(buf);
209 Poll::Ready(Ok(buf.len()))
210 }
211
212 fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
213 Poll::Ready(Ok(()))
214 }
215
216 fn poll_close(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
217 self.0 = Handle::invalid();
218 Poll::Ready(Ok(()))
219 }
220}