speedtest/
socket.rs

1// Copyright 2025 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use std::num::{NonZeroU32, TryFromIntError};
6use std::time::{Duration, Instant};
7use std::u64;
8
9use flex_fuchsia_developer_ffx_speedtest as fspeedtest;
10use futures::AsyncReadExt;
11#[cfg(not(feature = "fdomain"))]
12use futures::AsyncWriteExt;
13#[cfg(feature = "fdomain")]
14use futures::StreamExt;
15use thiserror::Error;
16
17pub struct Transfer {
18    pub socket: flex_client::Socket,
19    pub params: TransferParams,
20}
21
/// Parameters controlling the size and chunking of a single transfer.
#[derive(Clone, Debug)]
pub struct TransferParams {
    /// Total number of bytes moved across the socket.
    pub data_len: NonZeroU32,
    /// Size of the buffer handed to each individual read or write request.
    pub buffer_len: NonZeroU32,
    /// Options that exist only when built with the `fdomain` feature.
    #[cfg(feature = "fdomain")]
    pub fdomain_params: FDomainTransferParams,
}
29
/// Tuning options that are only available with the `fdomain` feature.
#[cfg(feature = "fdomain")]
#[derive(Clone, Debug)]
pub struct FDomainTransferParams {
    /// When set, `Transfer::receive` reads through a socket read stream
    /// instead of the plain async socket.
    pub streaming_read: bool,
    /// Maximum number of chunk writes `Transfer::send` keeps outstanding at once.
    pub writes_in_flight: NonZeroU32,
}
36
37impl TryFrom<fspeedtest::TransferParams> for TransferParams {
38    type Error = TryFromIntError;
39    fn try_from(value: fspeedtest::TransferParams) -> Result<Self, Self::Error> {
40        let fspeedtest::TransferParams { len_bytes, buffer_bytes, __source_breaking } = value;
41        Ok(Self {
42            data_len: len_bytes.unwrap_or(fspeedtest::DEFAULT_TRANSFER_SIZE).try_into()?,
43            buffer_len: buffer_bytes.unwrap_or(fspeedtest::DEFAULT_BUFFER_SIZE).try_into()?,
44            #[cfg(feature = "fdomain")]
45            fdomain_params: FDomainTransferParams {
46                streaming_read: false,
47                writes_in_flight: NonZeroU32::new(1).unwrap(),
48            },
49        })
50    }
51}
52
53impl TryFrom<TransferParams> for fspeedtest::TransferParams {
54    type Error = TryFromIntError;
55    fn try_from(value: TransferParams) -> Result<Self, Self::Error> {
56        let TransferParams { data_len, buffer_len, .. } = value;
57        Ok(Self {
58            len_bytes: Some(data_len.try_into()?),
59            buffer_bytes: Some(buffer_len.try_into()?),
60            __source_breaking: fidl::marker::SourceBreaking,
61        })
62    }
63}
64
/// Outcome of a completed transfer.
#[derive(Debug)]
pub struct Report {
    /// Time spent moving the data, measured with the monotonic `Instant` clock.
    pub duration: Duration,
}
69
70impl From<Report> for fspeedtest::TransferReport {
71    fn from(value: Report) -> Self {
72        let Report { duration } = value;
73        Self {
74            duration_nsec: Some(duration.as_nanos().try_into().unwrap_or(u64::MAX)),
75            __source_breaking: fidl::marker::SourceBreaking,
76        }
77    }
78}
79
80#[derive(Error, Debug)]
81#[error("missing mandatory field")]
82pub struct MissingFieldError;
83
84impl TryFrom<fspeedtest::TransferReport> for Report {
85    type Error = MissingFieldError;
86
87    fn try_from(value: fspeedtest::TransferReport) -> Result<Self, Self::Error> {
88        let fspeedtest::TransferReport { duration_nsec, __source_breaking } = value;
89        Ok(Self { duration: Duration::from_nanos(duration_nsec.ok_or(MissingFieldError)?) })
90    }
91}
92
93#[derive(Error, Debug)]
94pub enum TransferError {
95    #[error(transparent)]
96    IntConversion(#[from] TryFromIntError),
97    #[error(transparent)]
98    Io(#[from] std::io::Error),
99    #[error(transparent)]
100    FDomain(#[from] fdomain_client::Error),
101    #[error("remote hung up before terminating transfer")]
102    Hangup,
103}
104
105enum ReadSocket {
106    Normal(flex_client::AsyncSocket),
107    #[cfg(feature = "fdomain")]
108    Stream(flex_client::SocketReadStream),
109}
110
111impl ReadSocket {
112    fn from_socket(socket: flex_client::AsyncSocket, stream: bool) -> Result<Self, TransferError> {
113        #[cfg(feature = "fdomain")]
114        if stream {
115            let (socket, _) = socket.stream()?;
116            return Ok(ReadSocket::Stream(socket));
117        }
118
119        debug_assert!(!stream);
120        Ok(ReadSocket::Normal(socket))
121    }
122
123    async fn read(&mut self, buf: &mut [u8]) -> Result<usize, TransferError> {
124        let bytes = match self {
125            ReadSocket::Normal(s) => s.read(buf).await?,
126            #[cfg(feature = "fdomain")]
127            ReadSocket::Stream(s) => s.read(buf).await?,
128        };
129        Ok(bytes)
130    }
131}
132
133impl Transfer {
134    #[cfg(not(feature = "fdomain"))]
135    pub async fn send(self) -> Result<Report, TransferError> {
136        let Self { socket, params: TransferParams { data_len, buffer_len } } = self;
137        let mut socket = flex_client::socket_to_async(socket);
138        let buffer_len = usize::try_from(buffer_len.get())?;
139        let mut data_len = usize::try_from(data_len.get())?;
140        let buffer = vec![0xAA; buffer_len];
141        let start = Instant::now();
142        while data_len != 0 {
143            let send = buffer_len.min(data_len);
144            let written = socket.write(&buffer[..send]).await?;
145            data_len -= written;
146        }
147        let end = Instant::now();
148        Ok(Report { duration: end - start })
149    }
150
151    #[cfg(feature = "fdomain")]
152    pub async fn send(self) -> Result<Report, TransferError> {
153        let Self {
154            socket,
155            params:
156                TransferParams {
157                    data_len,
158                    buffer_len,
159                    fdomain_params: FDomainTransferParams { writes_in_flight, .. },
160                },
161        } = self;
162        let buffer_len = usize::try_from(buffer_len.get())?;
163        let mut data_len = usize::try_from(data_len.get())?;
164        let buffer = vec![0xAA; buffer_len];
165        let start = Instant::now();
166
167        let mut stream = futures::stream::iter(std::iter::from_fn(|| {
168            if data_len == 0 {
169                return None;
170            }
171
172            let send = buffer_len.min(data_len);
173            data_len -= send;
174            Some(socket.write_all(&buffer[..send]))
175        }))
176        .buffered(writes_in_flight.get() as usize);
177
178        while let Some(res) = stream.next().await {
179            let _: () = res?;
180        }
181
182        let end = Instant::now();
183        Ok(Report { duration: end - start })
184    }
185
186    pub async fn receive(self) -> Result<Report, TransferError> {
187        let Self {
188            socket,
189            params:
190                TransferParams {
191                    data_len,
192                    buffer_len,
193                    #[cfg(feature = "fdomain")]
194                        fdomain_params: FDomainTransferParams { streaming_read, .. },
195                },
196        } = self;
197        #[cfg(not(feature = "fdomain"))]
198        let streaming_read = false;
199        let mut socket =
200            ReadSocket::from_socket(flex_client::socket_to_async(socket), streaming_read)?;
201        let buffer_len = usize::try_from(buffer_len.get())?;
202        let mut data_len = usize::try_from(data_len.get())?;
203        let mut buffer = vec![0x00; buffer_len];
204        let start = Instant::now();
205
206        while data_len != 0 {
207            let recv = buffer_len.min(data_len);
208            let recv = socket.read(&mut buffer[..recv]).await?;
209            if recv == 0 {
210                return Err(TransferError::Hangup);
211            }
212            data_len -= recv;
213        }
214        let end = Instant::now();
215        Ok(Report { duration: end - start })
216    }
217}
218
#[cfg(test)]
mod test {
    use super::*;

    use assert_matches::assert_matches;

    /// Builds transfer parameters with the given lengths. When the `fdomain`
    /// feature is on, it also sets non-streaming reads and a single write in flight.
    fn params(data_len: u32, buffer_len: u32) -> TransferParams {
        TransferParams {
            data_len: NonZeroU32::new(data_len).unwrap(),
            buffer_len: NonZeroU32::new(buffer_len).unwrap(),
            #[cfg(feature = "fdomain")]
            fdomain_params: FDomainTransferParams {
                streaming_read: false,
                writes_in_flight: NonZeroU32::new(1).unwrap(),
            },
        }
    }

    #[fuchsia::test]
    async fn receive_hangup() {
        #[cfg(feature = "fdomain")]
        let client = fdomain_local::local_client(|| Err(zx_status::Status::NOT_SUPPORTED));
        #[cfg(not(feature = "fdomain"))]
        let client = fidl::endpoints::ZirconClient;
        let (socket, peer) = client.create_stream_socket();
        // Close the remote end before receiving, so the first read sees end-of-stream.
        drop(peer);

        let result = Transfer { socket, params: params(10, 100) }.receive().await;

        assert_matches!(result, Err(TransferError::Hangup));
    }
}
249}