1use async_trait::async_trait;
6use fuchsia_async::TimeoutExt;
7use fuchsia_async::net::TcpStream;
8
9use futures::{AsyncReadExt, AsyncWriteExt, TryFutureExt};
10use log::log;
11use std::net;
12
/// Overall time budget for a single fetch.
///
/// `fetch` derives one deadline from this value when it starts; connecting, writing the request
/// and reading the response all share that deadline rather than each getting their own budget.
const FETCH_TIMEOUT: zx::MonotonicDuration = zx::MonotonicDuration::from_seconds(10);
14
/// Errors that can occur while performing a reachability fetch.
///
/// Variants wrapping an I/O or parse error expose it through `std::error::Error::source`
/// (via `#[source]` / `#[from]`).
#[derive(Debug, thiserror::Error)]
pub enum FetchError {
    /// Creating the TCP socket failed.
    #[error("failed to create socket")]
    CreateSocket(#[source] std::io::Error),
    /// Binding the socket to the network device named `interface_name` failed.
    #[error("failed to bind socket to device {interface_name}")]
    BindSocket {
        /// Name of the device the socket was being bound to.
        interface_name: String,
        /// Underlying I/O error from the bind attempt.
        #[source]
        err: std::io::Error,
    },
    /// Establishing the TCP connection failed.
    #[error("failed to open TCP stream")]
    ConnectTcpStream(#[source] std::io::Error),
    /// The TCP connection was not established before the fetch deadline.
    #[error("timed out connecting TCP stream")]
    ConnectTcpStreamTimeout,
    /// Writing the HTTP request to the stream failed.
    #[error("failed to write to TCP stream")]
    WriteTcpStream(#[source] std::io::Error),
    /// Writing the HTTP request did not finish before the fetch deadline.
    #[error("timed out writing data to TCP stream")]
    WriteTcpStreamTimeout,
    /// Reading the HTTP response from the stream failed.
    #[error("failed to read response from TCP stream")]
    ReadTcpStream(#[source] std::io::Error),
    /// Reading the HTTP response did not finish before the fetch deadline.
    #[error("timed out reading response from TCP stream")]
    ReadTcpStreamTimeout,
    /// Response bytes could not be decoded as UTF-8 (`#[from]` lets `?` convert directly).
    #[error("failed to parse string from UTF-8 bytes")]
    ParseUtf8(#[from] std::string::FromUtf8Error),
    /// The response's first line (carried in `first_line`) did not look like
    /// `HTTP/<version> <code> ...`.
    #[error("response header malformed: {first_line}")]
    MalformedHeader { first_line: String },
    /// The status-code field of the response could not be parsed as an integer.
    #[error("failed to parse status code")]
    ParseStatusCode(#[source] std::num::ParseIntError),
}
44
45impl FetchError {
46 pub fn short_name(&self) -> String {
50 let (name, os_err) = match self {
51 Self::CreateSocket(err) => ("CreateSock", err.raw_os_error()),
52 Self::BindSocket { err, .. } => ("BindSock", err.raw_os_error()),
53 Self::ConnectTcpStream(err) => ("ConnectTcp", err.raw_os_error()),
54 Self::ConnectTcpStreamTimeout => return "ConnTcpTimeout".to_string(),
55 Self::WriteTcpStream(err) => ("WriteTcp", err.raw_os_error()),
56 Self::WriteTcpStreamTimeout => return "WriteTcpTOut".to_string(),
57 Self::ReadTcpStream(err) => ("ReadTcp", err.raw_os_error()),
58 Self::ReadTcpStreamTimeout => return "ReadTcpTimeout".to_string(),
59 Self::ParseUtf8(_) => return "ParseUtf8".to_string(),
60 Self::MalformedHeader { .. } => return "BadHeader".to_string(),
61 Self::ParseStatusCode(_) => return "ParseStatus".to_string(),
62 };
63
64 if let Some(code) = os_err { format!("{name}_{code}") } else { name.to_string() }
65 }
66}
67
68pub(crate) fn fetch_result_short_name(result: &Result<u16, FetchError>) -> String {
69 match result {
70 Ok(code) => format!("Completed_{}", code),
71 Err(e) => format!("e_{}", e.short_name()),
72 }
73}
74
/// Builds the HTTP/1.1 `HEAD` request sent by the reachability probe.
///
/// The request targets `path` on `host`, asks the server to close the connection once it has
/// responded, and identifies itself with a fixed user agent. Lines are CRLF-terminated and the
/// header section ends with an empty line.
///
/// An empty `path` is sent as `/`: an origin-form request target must not be empty
/// (RFC 9112 §3.2.1), and `HEAD  HTTP/1.1` would be a malformed request line.
fn http_request(path: &str, host: &str) -> String {
    let path = if path.is_empty() { "/" } else { path };
    [
        &format!("HEAD {path} HTTP/1.1"),
        &format!("host: {host}"),
        "connection: close",
        "user-agent: fuchsia reachability probe",
    ]
    .join("\r\n")
        + "\r\n\r\n"
}
85
86async fn fetch<FA: FetchAddr + std::marker::Sync>(
87 interface_name: &str,
88 host: &str,
89 path: &str,
90 addr: &FA,
91) -> Result<u16, FetchError> {
92 let timeout = zx::MonotonicInstant::after(FETCH_TIMEOUT);
93 let addr = addr.as_socket_addr();
94 let socket = socket2::Socket::new(
95 match addr {
96 net::SocketAddr::V4(_) => socket2::Domain::IPV4,
97 net::SocketAddr::V6(_) => socket2::Domain::IPV6,
98 },
99 socket2::Type::STREAM,
100 Some(socket2::Protocol::TCP),
101 )
102 .map_err(FetchError::CreateSocket)?;
103 socket.bind_device(Some(interface_name.as_bytes())).map_err(|err| FetchError::BindSocket {
104 interface_name: interface_name.to_string(),
105 err,
106 })?;
107 let mut stream = TcpStream::connect_from_raw(socket, addr)
108 .map_err(FetchError::ConnectTcpStream)?
109 .map_err(FetchError::ConnectTcpStream)
110 .on_timeout(timeout, || Err(FetchError::ConnectTcpStreamTimeout))
111 .await?;
112 let message = http_request(path, host);
113 stream
114 .write_all(message.as_bytes())
115 .map_err(FetchError::WriteTcpStream)
116 .on_timeout(timeout, || Err(FetchError::WriteTcpStreamTimeout))
117 .await?;
118
119 let mut bytes = Vec::new();
120 let _: usize = stream
121 .read_to_end(&mut bytes)
122 .map_err(FetchError::ReadTcpStream)
123 .on_timeout(timeout, || Err(FetchError::ReadTcpStreamTimeout))
124 .await?;
125 let resp = String::from_utf8(bytes)?;
126 let first_line = resp.split("\r\n").next().expect("split always returns at least one item");
127 if let [http, code, ..] = first_line.split(' ').collect::<Vec<_>>().as_slice() {
128 if !http.starts_with("HTTP/") {
129 return Err(FetchError::MalformedHeader { first_line: first_line.to_string() });
130 }
131 Ok(code.parse().map_err(FetchError::ParseStatusCode)?)
132 } else {
133 Err(FetchError::MalformedHeader { first_line: first_line.to_string() })
134 }
135}
136
/// A probe destination that can be turned into a concrete socket address.
pub trait FetchAddr {
    /// Returns the socket address to connect to.
    fn as_socket_addr(&self) -> net::SocketAddr;
}

/// A complete socket address is used exactly as given.
impl FetchAddr for net::SocketAddr {
    fn as_socket_addr(&self) -> net::SocketAddr {
        *self
    }
}

/// A bare IP address is paired with the default HTTP port.
impl FetchAddr for net::IpAddr {
    fn as_socket_addr(&self) -> net::SocketAddr {
        const HTTP_PORT: u16 = 80;
        net::SocketAddr::new(*self, HTTP_PORT)
    }
}
152
/// Performs a reachability fetch over a specific network interface.
///
/// Abstracted as a trait so another implementation can be substituted (presumably a fake in
/// tests — confirm against callers); `Fetcher` is the implementation in this file.
#[async_trait]
pub trait Fetch {
    /// Fetches `path` from `host` at `addr` through the network interface `interface_name`,
    /// returning the HTTP status code on success.
    async fn fetch<FA: FetchAddr + std::marker::Sync>(
        &self,
        interface_name: &str,
        host: &str,
        path: &str,
        addr: &FA,
    ) -> Result<u16, FetchError>;
}
163
164pub struct Fetcher;
165
166#[async_trait]
167impl Fetch for Fetcher {
168 async fn fetch<FA: FetchAddr + std::marker::Sync>(
169 &self,
170 interface_name: &str,
171 host: &str,
172 path: &str,
173 addr: &FA,
174 ) -> Result<u16, FetchError> {
175 let r = fetch(interface_name, host, path, addr).await;
176 match r {
177 Ok(code) => Ok(code),
178 Err(e) => {
179 let level = if let Some(io_error) = std::error::Error::source(&e)
185 .and_then(|err| err.downcast_ref::<std::io::Error>())
186 && (io_error.raw_os_error() == Some(libc::ENETUNREACH)
187 || io_error.raw_os_error() == Some(libc::EHOSTUNREACH))
188 {
189 log::Level::Info
190 } else {
191 log::Level::Warn
192 };
193
194 let addr = addr.as_socket_addr();
195 log!(
196 level,
197 "err fetching {host}{path} from {addr} \
198 through {interface_name}: {e:#}"
199 );
200
201 Err(e)
202 }
203 }
204 }
205}
206
#[cfg(test)]
mod test {
    use super::*;

    use anyhow::Context;
    use std::net::{Ipv4Addr, SocketAddr};
    use std::pin::pin;

    use fuchsia_async::net::TcpListener;
    use fuchsia_async::{self as fasync};
    use futures::future::Fuse;
    use futures::io::BufReader;
    use futures::{AsyncBufReadExt, FutureExt, StreamExt};
    use test_case::test_case;

    /// Binds a TCP listener on an OS-assigned localhost port and returns its address together
    /// with a fused future acting as a one-shot HTTP server.
    ///
    /// The future accepts one connection, reads request lines up to the empty `\r\n` line (or
    /// EOF), replies with a status line carrying `code`, and resolves to the trimmed request
    /// lines. It resolves to an empty `Vec` if the listener stream ends without a connection.
    /// Waiting steps panic if they are still pending `FETCH_TIMEOUT` after the future first runs.
    fn server(
        code: u16,
    ) -> anyhow::Result<(SocketAddr, Fuse<impl futures::Future<Output = Vec<String>>>)> {
        let addr = SocketAddr::new(Ipv4Addr::LOCALHOST.into(), 0);
        let listener = TcpListener::bind(&addr).context("binding TCP")?;
        let addr = listener.local_addr()?;

        let server_fut = async move {
            let timeout = zx::MonotonicInstant::after(FETCH_TIMEOUT);
            let mut incoming = listener.accept_stream();
            if let Some(result) = incoming
                .next()
                .on_timeout(timeout, || panic!("timeout waiting for connection"))
                .await
            {
                let (stream, _addr) = result.expect("accept incoming TCP connection");
                let mut stream = BufReader::new(stream);
                let mut request = Vec::new();
                loop {
                    let mut s = String::new();
                    let bytes_read = stream
                        .read_line(&mut s)
                        .on_timeout(timeout, || panic!("timeout waiting to read data"))
                        .await
                        .expect("read data");
                    // At EOF `read_line` returns 0 immediately on every call, so the timeout
                    // never fires; without this check a client that hangs up before the blank
                    // line would make this loop spin forever instead of failing the test.
                    if bytes_read == 0 || s == "\r\n" {
                        break;
                    }
                    request.push(s.trim().to_string());
                }
                let data = format!("HTTP/1.1 {} OK\r\n\r\n", code);
                stream
                    .write_all(data.as_bytes())
                    .on_timeout(timeout, || panic!("timeout waiting to write response"))
                    .await
                    .expect("reply to request");
                request
            } else {
                Vec::new()
            }
        }
        .fuse();

        Ok((addr, server_fut))
    }

    /// Runs `fetch` against the local one-shot server, driving both futures on the same
    /// executor, and checks the returned status code plus the request line and `host` header
    /// the server received.
    #[test_case("http://reachability.test/", 200; "base path 200")]
    #[test_case("http://reachability.test/path/", 200; "sub path 200")]
    #[test_case("http://reachability.test/", 400; "base path 400")]
    #[test_case("http://reachability.test/path/", 400; "sub path 400")]
    #[fasync::run_singlethreaded(test)]
    async fn test_fetch(url_str: &'static str, code: u16) -> anyhow::Result<()> {
        let url = url::Url::parse(url_str)?;
        let (addr, server_fut) = server(code)?;
        let domain = url.host().expect("no host").to_string();
        let path = url.path().to_string();

        // NOTE(review): an empty interface name is passed; presumably this leaves the socket
        // unbound to any device — confirm against `bind_device` semantics on Fuchsia.
        let mut fetch_fut = pin!(fetch("", &domain, &path, &addr).fuse());

        let mut server_fut = pin!(server_fut);

        // Poll both futures until the fetch finishes, capturing the request lines if the server
        // completes first.
        let mut request = None;
        let result = loop {
            futures::select! {
                req = server_fut => request = Some(req),
                result = fetch_fut => break result
            };
        };

        assert!(result.is_ok(), "Expected OK, got: {result:?}");
        assert_eq!(result.ok(), Some(code));
        let request = request.expect("no request body");
        assert!(request.contains(&format!("HEAD {path} HTTP/1.1")));
        assert!(request.contains(&format!("host: {domain}")));

        Ok(())
    }
}
299}