1use std::num::{NonZeroU32, TryFromIntError};
6use std::time::{Duration, Instant};
7use std::u64;
8
9use flex_fuchsia_developer_ffx_speedtest as fspeedtest;
10use futures::AsyncReadExt;
11#[cfg(not(feature = "fdomain"))]
12use futures::AsyncWriteExt;
13#[cfg(feature = "fdomain")]
14use futures::StreamExt;
15use thiserror::Error;
16
/// A single speed-test transfer: the socket that data moves over plus the
/// parameters that size the transfer (total length and per-operation chunk size).
pub struct Transfer {
    /// Socket the data is written to (`send`) or read from (`receive`).
    pub socket: flex_client::Socket,
    /// Sizing for the transfer and, with the `fdomain` feature, FDomain tuning.
    pub params: TransferParams,
}
21
/// Parameters describing how a [`Transfer`] moves its data.
#[derive(Debug, Clone)]
pub struct TransferParams {
    /// Total number of bytes to move over the socket.
    pub data_len: NonZeroU32,
    /// Size of the scratch buffer, i.e. the maximum number of bytes handed to a
    /// single read or write.
    pub buffer_len: NonZeroU32,
    /// Extra knobs that only exist when built with the `fdomain` feature.
    #[cfg(feature = "fdomain")]
    pub fdomain_params: FDomainTransferParams,
}
29
/// Transfer tuning that only applies when built with the `fdomain` feature.
#[cfg(feature = "fdomain")]
#[derive(Debug, Clone)]
pub struct FDomainTransferParams {
    /// When true, the receiver reads through the socket's read stream instead of
    /// issuing plain reads on the async socket.
    pub streaming_read: bool,
    /// Maximum number of chunk writes the sender keeps pending at the same time.
    pub writes_in_flight: NonZeroU32,
}
36
37impl TryFrom<fspeedtest::TransferParams> for TransferParams {
38 type Error = TryFromIntError;
39 fn try_from(value: fspeedtest::TransferParams) -> Result<Self, Self::Error> {
40 let fspeedtest::TransferParams { len_bytes, buffer_bytes, __source_breaking } = value;
41 Ok(Self {
42 data_len: len_bytes.unwrap_or(fspeedtest::DEFAULT_TRANSFER_SIZE).try_into()?,
43 buffer_len: buffer_bytes.unwrap_or(fspeedtest::DEFAULT_BUFFER_SIZE).try_into()?,
44 #[cfg(feature = "fdomain")]
45 fdomain_params: FDomainTransferParams {
46 streaming_read: false,
47 writes_in_flight: NonZeroU32::new(1).unwrap(),
48 },
49 })
50 }
51}
52
53impl TryFrom<TransferParams> for fspeedtest::TransferParams {
54 type Error = TryFromIntError;
55 fn try_from(value: TransferParams) -> Result<Self, Self::Error> {
56 let TransferParams { data_len, buffer_len, .. } = value;
57 Ok(Self {
58 len_bytes: Some(data_len.try_into()?),
59 buffer_bytes: Some(buffer_len.try_into()?),
60 __source_breaking: fidl::marker::SourceBreaking,
61 })
62 }
63}
64
/// Result of a completed [`Transfer`].
#[derive(Debug)]
pub struct Report {
    /// Time elapsed from just before the first socket read/write until the last
    /// one completed, measured with [`Instant`].
    pub duration: Duration,
}
69
70impl From<Report> for fspeedtest::TransferReport {
71 fn from(value: Report) -> Self {
72 let Report { duration } = value;
73 Self {
74 duration_nsec: Some(duration.as_nanos().try_into().unwrap_or(u64::MAX)),
75 __source_breaking: fidl::marker::SourceBreaking,
76 }
77 }
78}
79
/// Error returned when a FIDL value lacks a field required for conversion
/// (for example, `TransferReport::duration_nsec`).
#[derive(Error, Debug)]
#[error("missing mandatory field")]
pub struct MissingFieldError;
83
84impl TryFrom<fspeedtest::TransferReport> for Report {
85 type Error = MissingFieldError;
86
87 fn try_from(value: fspeedtest::TransferReport) -> Result<Self, Self::Error> {
88 let fspeedtest::TransferReport { duration_nsec, __source_breaking } = value;
89 Ok(Self { duration: Duration::from_nanos(duration_nsec.ok_or(MissingFieldError)?) })
90 }
91}
92
/// Errors that can occur while running a [`Transfer`].
#[derive(Error, Debug)]
pub enum TransferError {
    /// An integer conversion failed (e.g. a size could not be represented in
    /// the target type).
    #[error(transparent)]
    IntConversion(#[from] TryFromIntError),
    /// An I/O error reported by a socket operation.
    #[error(transparent)]
    Io(#[from] std::io::Error),
    /// An error reported by the FDomain client.
    #[error(transparent)]
    FDomain(#[from] fdomain_client::Error),
    /// The remote end hung up before the transfer finished.
    #[error("remote hung up before terminating transfer")]
    Hangup,
}
104
/// Read side of a transfer socket. It abstracts over plain async reads and,
/// with the `fdomain` feature, reads through a socket read stream.
enum ReadSocket {
    /// Reads are issued directly on the async socket.
    Normal(flex_client::AsyncSocket),
    /// Reads go through the stream returned by `AsyncSocket::stream`.
    #[cfg(feature = "fdomain")]
    Stream(flex_client::SocketReadStream),
}
110
111impl ReadSocket {
112 fn from_socket(socket: flex_client::AsyncSocket, stream: bool) -> Result<Self, TransferError> {
113 #[cfg(feature = "fdomain")]
114 if stream {
115 let (socket, _) = socket.stream()?;
116 return Ok(ReadSocket::Stream(socket));
117 }
118
119 debug_assert!(!stream);
120 Ok(ReadSocket::Normal(socket))
121 }
122
123 async fn read(&mut self, buf: &mut [u8]) -> Result<usize, TransferError> {
124 let bytes = match self {
125 ReadSocket::Normal(s) => s.read(buf).await?,
126 #[cfg(feature = "fdomain")]
127 ReadSocket::Stream(s) => s.read(buf).await?,
128 };
129 Ok(bytes)
130 }
131}
132
133impl Transfer {
134 #[cfg(not(feature = "fdomain"))]
135 pub async fn send(self) -> Result<Report, TransferError> {
136 let Self { socket, params: TransferParams { data_len, buffer_len } } = self;
137 let mut socket = flex_client::socket_to_async(socket);
138 let buffer_len = usize::try_from(buffer_len.get())?;
139 let mut data_len = usize::try_from(data_len.get())?;
140 let buffer = vec![0xAA; buffer_len];
141 let start = Instant::now();
142 while data_len != 0 {
143 let send = buffer_len.min(data_len);
144 let written = socket.write(&buffer[..send]).await?;
145 data_len -= written;
146 }
147 let end = Instant::now();
148 Ok(Report { duration: end - start })
149 }
150
151 #[cfg(feature = "fdomain")]
152 pub async fn send(self) -> Result<Report, TransferError> {
153 let Self {
154 socket,
155 params:
156 TransferParams {
157 data_len,
158 buffer_len,
159 fdomain_params: FDomainTransferParams { writes_in_flight, .. },
160 },
161 } = self;
162 let buffer_len = usize::try_from(buffer_len.get())?;
163 let mut data_len = usize::try_from(data_len.get())?;
164 let buffer = vec![0xAA; buffer_len];
165 let start = Instant::now();
166
167 let mut stream = futures::stream::iter(std::iter::from_fn(|| {
168 if data_len == 0 {
169 return None;
170 }
171
172 let send = buffer_len.min(data_len);
173 data_len -= send;
174 Some(socket.write_all(&buffer[..send]))
175 }))
176 .buffered(writes_in_flight.get() as usize);
177
178 while let Some(res) = stream.next().await {
179 let _: () = res?;
180 }
181
182 let end = Instant::now();
183 Ok(Report { duration: end - start })
184 }
185
186 pub async fn receive(self) -> Result<Report, TransferError> {
187 let Self {
188 socket,
189 params:
190 TransferParams {
191 data_len,
192 buffer_len,
193 #[cfg(feature = "fdomain")]
194 fdomain_params: FDomainTransferParams { streaming_read, .. },
195 },
196 } = self;
197 #[cfg(not(feature = "fdomain"))]
198 let streaming_read = false;
199 let mut socket =
200 ReadSocket::from_socket(flex_client::socket_to_async(socket), streaming_read)?;
201 let buffer_len = usize::try_from(buffer_len.get())?;
202 let mut data_len = usize::try_from(data_len.get())?;
203 let mut buffer = vec![0x00; buffer_len];
204 let start = Instant::now();
205
206 while data_len != 0 {
207 let recv = buffer_len.min(data_len);
208 let recv = socket.read(&mut buffer[..recv]).await?;
209 if recv == 0 {
210 return Err(TransferError::Hangup);
211 }
212 data_len -= recv;
213 }
214 let end = Instant::now();
215 Ok(Report { duration: end - start })
216 }
217}
218
#[cfg(test)]
mod test {
    use super::*;

    use assert_matches::assert_matches;

    /// Receiving on a socket whose peer has already been dropped must fail with
    /// `TransferError::Hangup` rather than completing.
    #[fuchsia::test]
    async fn receive_hangup() {
        #[cfg(feature = "fdomain")]
        let client = fdomain_local::local_client(|| Err(zx_status::Status::NOT_SUPPORTED));
        #[cfg(not(feature = "fdomain"))]
        let client = fidl::endpoints::ZirconClient;

        // The peer is bound to `_` on purpose, so it is dropped at the end of
        // this statement, before `receive` starts reading.
        let (socket, _) = client.create_stream_socket();

        let params = TransferParams {
            data_len: NonZeroU32::new(10).unwrap(),
            buffer_len: NonZeroU32::new(100).unwrap(),
            #[cfg(feature = "fdomain")]
            fdomain_params: FDomainTransferParams {
                streaming_read: false,
                writes_in_flight: NonZeroU32::new(1).unwrap(),
            },
        };
        let transfer = Transfer { socket, params };
        let outcome = transfer.receive().await;

        assert_matches!(outcome, Err(TransferError::Hangup));
    }
}