1use async_trait::async_trait;
6use fuchsia_async::TimeoutExt;
7use fuchsia_async::net::TcpStream;
8
9use futures::{AsyncReadExt, AsyncWriteExt, TryFutureExt};
10use std::net;
11
12const FETCH_TIMEOUT: zx::MonotonicDuration = zx::MonotonicDuration::from_seconds(10);
13
/// Errors produced while performing a reachability fetch.
///
/// Each variant identifies the step of the fetch that failed; [`FetchError::short_name`]
/// maps every variant to a compact label.
#[derive(Debug, thiserror::Error)]
pub enum FetchError {
    /// Creating the TCP socket failed.
    #[error("failed to create socket")]
    CreateSocket(#[source] std::io::Error),
    /// Binding the socket to the network interface `interface_name` failed.
    #[error("failed to bind socket to device {interface_name}")]
    BindSocket {
        /// Name of the interface the socket was being bound to.
        interface_name: String,
        /// The underlying I/O error.
        #[source]
        err: std::io::Error,
    },
    /// Establishing the TCP connection failed.
    #[error("failed to open TCP stream")]
    ConnectTcpStream(#[source] std::io::Error),
    /// The fetch deadline expired before the TCP connection was established.
    #[error("timed out connecting TCP stream")]
    ConnectTcpStreamTimeout,
    /// Writing the HTTP request to the stream failed.
    #[error("failed to write to TCP stream")]
    WriteTcpStream(#[source] std::io::Error),
    /// The fetch deadline expired while writing the HTTP request.
    #[error("timed out writing data to TCP stream")]
    WriteTcpStreamTimeout,
    /// Reading the HTTP response from the stream failed.
    #[error("failed to read response from TCP stream")]
    ReadTcpStream(#[source] std::io::Error),
    /// The fetch deadline expired while reading the HTTP response.
    #[error("timed out reading response from TCP stream")]
    ReadTcpStreamTimeout,
    /// Response bytes could not be decoded as UTF-8.
    #[error("failed to parse string from UTF-8 bytes")]
    ParseUtf8(#[from] std::string::FromUtf8Error),
    /// The first line of the response did not have the shape `HTTP/<version> <code> ...`.
    #[error("response header malformed: {first_line}")]
    MalformedHeader {
        /// The offending first response line, as received.
        first_line: String,
    },
    /// The status-code field of the first response line was not a valid `u16`.
    #[error("failed to parse status code")]
    ParseStatusCode(#[source] std::num::ParseIntError),
}
43
44impl FetchError {
45 pub fn short_name(&self) -> String {
49 let (name, os_err) = match self {
50 Self::CreateSocket(err) => ("CreateSock", err.raw_os_error()),
51 Self::BindSocket { err, .. } => ("BindSock", err.raw_os_error()),
52 Self::ConnectTcpStream(err) => ("ConnectTcp", err.raw_os_error()),
53 Self::ConnectTcpStreamTimeout => return "ConnTcpTimeout".to_string(),
54 Self::WriteTcpStream(err) => ("WriteTcp", err.raw_os_error()),
55 Self::WriteTcpStreamTimeout => return "WriteTcpTOut".to_string(),
56 Self::ReadTcpStream(err) => ("ReadTcp", err.raw_os_error()),
57 Self::ReadTcpStreamTimeout => return "ReadTcpTimeout".to_string(),
58 Self::ParseUtf8(_) => return "ParseUtf8".to_string(),
59 Self::MalformedHeader { .. } => return "BadHeader".to_string(),
60 Self::ParseStatusCode(_) => return "ParseStatus".to_string(),
61 };
62
63 if let Some(code) = os_err { format!("{name}_{code}") } else { name.to_string() }
64 }
65}
66
67pub(crate) fn fetch_result_short_name(result: &Result<u16, FetchError>) -> String {
68 match result {
69 Ok(code) => format!("Completed_{}", code),
70 Err(e) => format!("e_{}", e.short_name()),
71 }
72}
73
/// Builds the HTTP/1.1 `HEAD` request sent by the probe.
///
/// The request targets `path`, carries `host` as the host header, asks the server to close
/// the connection after responding, and identifies the client with a fixed user agent.
/// `path` and `host` are inserted verbatim, without validation or escaping.
fn http_request(path: &str, host: &str) -> String {
    // Every line ends in CRLF. The final bare CRLF is the empty line that terminates the
    // header section.
    format!(
        "HEAD {path} HTTP/1.1\r\n\
         host: {host}\r\n\
         connection: close\r\n\
         user-agent: fuchsia reachability probe\r\n\
         \r\n"
    )
}
84
85async fn fetch<FA: FetchAddr + std::marker::Sync>(
86 interface_name: &str,
87 host: &str,
88 path: &str,
89 addr: &FA,
90) -> Result<u16, FetchError> {
91 let timeout = zx::MonotonicInstant::after(FETCH_TIMEOUT);
92 let addr = addr.as_socket_addr();
93 let socket = socket2::Socket::new(
94 match addr {
95 net::SocketAddr::V4(_) => socket2::Domain::IPV4,
96 net::SocketAddr::V6(_) => socket2::Domain::IPV6,
97 },
98 socket2::Type::STREAM,
99 Some(socket2::Protocol::TCP),
100 )
101 .map_err(FetchError::CreateSocket)?;
102 socket.bind_device(Some(interface_name.as_bytes())).map_err(|err| FetchError::BindSocket {
103 interface_name: interface_name.to_string(),
104 err,
105 })?;
106 let mut stream = TcpStream::connect_from_raw(socket, addr)
107 .map_err(FetchError::ConnectTcpStream)?
108 .map_err(FetchError::ConnectTcpStream)
109 .on_timeout(timeout, || Err(FetchError::ConnectTcpStreamTimeout))
110 .await?;
111 let message = http_request(path, host);
112 stream
113 .write_all(message.as_bytes())
114 .map_err(FetchError::WriteTcpStream)
115 .on_timeout(timeout, || Err(FetchError::WriteTcpStreamTimeout))
116 .await?;
117
118 let mut bytes = Vec::new();
119 let _: usize = stream
120 .read_to_end(&mut bytes)
121 .map_err(FetchError::ReadTcpStream)
122 .on_timeout(timeout, || Err(FetchError::ReadTcpStreamTimeout))
123 .await?;
124 let resp = String::from_utf8(bytes)?;
125 let first_line = resp.split("\r\n").next().expect("split always returns at least one item");
126 if let [http, code, ..] = first_line.split(' ').collect::<Vec<_>>().as_slice() {
127 if !http.starts_with("HTTP/") {
128 return Err(FetchError::MalformedHeader { first_line: first_line.to_string() });
129 }
130 Ok(code.parse().map_err(FetchError::ParseStatusCode)?)
131 } else {
132 Err(FetchError::MalformedHeader { first_line: first_line.to_string() })
133 }
134}
135
/// A destination that can be resolved to the socket address a fetch connects to.
pub trait FetchAddr {
    /// Returns the socket address (IP and port) to connect to.
    fn as_socket_addr(&self) -> net::SocketAddr;
}

/// A complete socket address is used exactly as given.
impl FetchAddr for net::SocketAddr {
    fn as_socket_addr(&self) -> net::SocketAddr {
        *self
    }
}

/// A bare IP address is paired with port 80.
impl FetchAddr for net::IpAddr {
    fn as_socket_addr(&self) -> net::SocketAddr {
        const HTTP_PORT: u16 = 80;
        net::SocketAddr::new(*self, HTTP_PORT)
    }
}
151
/// Performs a single HTTP probe and reports the resulting status code.
///
/// [`Fetcher`] is the implementation in this module that uses real sockets.
#[async_trait]
pub trait Fetch {
    /// Probes `path` on `host` at `addr` using the interface named `interface_name`, and
    /// returns the HTTP status code of the response.
    ///
    /// In [`Fetcher`], this sends an HTTP `HEAD` request over a socket bound to
    /// `interface_name`. Other implementors may differ.
    ///
    /// # Errors
    ///
    /// Returns a [`FetchError`] describing the step that failed.
    async fn fetch<FA: FetchAddr + std::marker::Sync>(
        &self,
        interface_name: &str,
        host: &str,
        path: &str,
        addr: &FA,
    ) -> Result<u16, FetchError>;
}
162
163pub struct Fetcher;
164
165#[async_trait]
166impl Fetch for Fetcher {
167 async fn fetch<FA: FetchAddr + std::marker::Sync>(
168 &self,
169 interface_name: &str,
170 host: &str,
171 path: &str,
172 addr: &FA,
173 ) -> Result<u16, FetchError> {
174 fetch(interface_name, host, path, addr).await
175 }
176}
177
#[cfg(test)]
mod test {
    use super::*;

    use anyhow::Context;
    use std::net::{Ipv4Addr, SocketAddr};
    use std::pin::pin;

    use fuchsia_async::net::TcpListener;
    use fuchsia_async::{self as fasync};
    use futures::future::Fuse;
    use futures::io::BufReader;
    use futures::{AsyncBufReadExt, FutureExt, StreamExt};
    use test_case::test_case;

    /// Starts a one-shot HTTP server on an ephemeral localhost port.
    ///
    /// Returns the bound address and a fused future. When polled, the future:
    /// - accepts a single connection;
    /// - collects the request header lines, trimmed and excluding the terminating blank line;
    /// - replies with `HTTP/1.1 <code> OK` followed by a blank line;
    /// - resolves to the collected lines.
    ///
    /// It resolves to an empty `Vec` if the listener yields no connection. It panics if a step
    /// exceeds [`FETCH_TIMEOUT`] or if the client closes the connection before finishing its
    /// request headers.
    fn server(
        code: u16,
    ) -> anyhow::Result<(SocketAddr, Fuse<impl futures::Future<Output = Vec<String>>>)> {
        let addr = SocketAddr::new(Ipv4Addr::LOCALHOST.into(), 0);
        let listener = TcpListener::bind(&addr).context("binding TCP")?;
        let addr = listener.local_addr()?;

        let server_fut = async move {
            let timeout = zx::MonotonicInstant::after(FETCH_TIMEOUT);
            let mut incoming = listener.accept_stream();
            if let Some(result) = incoming
                .next()
                .on_timeout(timeout, || panic!("timeout waiting for connection"))
                .await
            {
                let (stream, _addr) = result.expect("accept incoming TCP connection");
                let mut stream = BufReader::new(stream);
                let mut request = Vec::new();
                loop {
                    let mut s = String::new();
                    let read: usize = stream
                        .read_line(&mut s)
                        .on_timeout(timeout, || panic!("timeout waiting to read data"))
                        .await
                        .expect("read data");
                    // `read_line` returns 0 at EOF and keeps doing so immediately. The timeout
                    // only fires while a read is pending, so without this check the loop would
                    // spin forever pushing empty strings.
                    if read == 0 {
                        panic!("connection closed before end of request headers");
                    }
                    if s == "\r\n" {
                        break;
                    }
                    request.push(s.trim().to_string());
                }
                let data = format!("HTTP/1.1 {} OK\r\n\r\n", code);
                stream
                    .write_all(data.as_bytes())
                    .on_timeout(timeout, || panic!("timeout waiting to write response"))
                    .await
                    .expect("reply to request");
                request
            } else {
                Vec::new()
            }
        }
        .fuse();

        Ok((addr, server_fut))
    }

    #[test_case("http://reachability.test/", 200; "base path 200")]
    #[test_case("http://reachability.test/path/", 200; "sub path 200")]
    #[test_case("http://reachability.test/", 400; "base path 400")]
    #[test_case("http://reachability.test/path/", 400; "sub path 400")]
    #[fasync::run_singlethreaded(test)]
    async fn test_fetch(url_str: &'static str, code: u16) -> anyhow::Result<()> {
        let url = url::Url::parse(url_str)?;
        let (addr, server_fut) = server(code)?;
        let domain = url.host().expect("no host").to_string();
        let path = url.path().to_string();

        let mut fetch_fut = pin!(fetch("", &domain, &path, &addr).fuse());

        let mut server_fut = pin!(server_fut);

        // Drive client and server together. The server finishes first (its reply and the
        // connection close let the client's read-to-EOF complete), and its captured request is
        // kept for the assertions below.
        let mut request = None;
        let result = loop {
            futures::select! {
                req = server_fut => request = Some(req),
                result = fetch_fut => break result
            };
        };

        assert!(result.is_ok(), "Expected OK, got: {result:?}");
        assert_eq!(result.ok(), Some(code));
        let request = request.expect("no request body");
        assert!(request.contains(&format!("HEAD {path} HTTP/1.1")));
        assert!(request.contains(&format!("host: {domain}")));

        Ok(())
    }
}