reachability_core/
fetch.rs1use anyhow::{format_err, Context};
6use async_trait::async_trait;
7use fuchsia_async::net::TcpStream;
8use fuchsia_async::TimeoutExt;
9
10use futures::{AsyncReadExt, AsyncWriteExt, TryFutureExt};
11use log::{info, warn};
12use std::net;
13
/// Overall time budget for one [`fetch`]: a single deadline is computed from this when the
/// fetch starts, and connecting, writing the request and reading the response must all
/// complete before that deadline.
const FETCH_TIMEOUT: zx::MonotonicDuration = zx::MonotonicDuration::from_seconds(10);
15
/// Builds the raw text of an HTTP/1.1 `HEAD` request for `path` on `host`.
///
/// The request carries `connection: close` plus a fixed user-agent, and ends with the empty
/// line that terminates the header section.
fn http_request(path: &str, host: &str) -> String {
    let mut request = format!("HEAD {path} HTTP/1.1\r\nhost: {host}\r\n");
    request.push_str("connection: close\r\n");
    request.push_str("user-agent: fuchsia reachability probe\r\n");
    // Blank line marking the end of the headers.
    request.push_str("\r\n");
    request
}
26
27async fn fetch<FA: FetchAddr + std::marker::Sync>(
28 interface_name: &str,
29 host: &str,
30 path: &str,
31 addr: &FA,
32) -> anyhow::Result<u16> {
33 let timeout = zx::MonotonicInstant::after(FETCH_TIMEOUT);
34 let addr = addr.as_socket_addr();
35 let socket = socket2::Socket::new(
36 match addr {
37 net::SocketAddr::V4(_) => socket2::Domain::IPV4,
38 net::SocketAddr::V6(_) => socket2::Domain::IPV6,
39 },
40 socket2::Type::STREAM,
41 Some(socket2::Protocol::TCP),
42 )
43 .context("while constructing socket")?;
44 socket.bind_device(Some(interface_name.as_bytes()))?;
45 let mut stream = TcpStream::connect_from_raw(socket, addr)
46 .context("while constructing tcp stream")?
47 .map_err(|e| format_err!("Opening TcpStream connection failed: {e:?}"))
48 .on_timeout(timeout, || Err(format_err!("Opening TcpStream timed out")))
49 .await?;
50 let message = http_request(path, host);
51 stream
52 .write_all(message.as_bytes())
53 .map_err(|e| format_err!("Writing to TcpStream failed: {e:?}"))
54 .on_timeout(timeout, || Err(format_err!("Writing data to TcpStream timed out")))
55 .await?;
56
57 let mut bytes = Vec::new();
58 let _: usize = stream
59 .read_to_end(&mut bytes)
60 .map_err(|e| format_err!("Reading response from TcpStream failed: {e:?}"))
61 .on_timeout(timeout, || Err(format_err!("Reading response from TcpStream timed out")))
62 .await?;
63 let resp = String::from_utf8(bytes)?;
64 let first_line = resp.split("\r\n").next().expect("split always returns at least one item");
65 if let [http, code, ..] = first_line.split(' ').collect::<Vec<_>>().as_slice() {
66 if !http.starts_with("HTTP/") {
67 return Err(format_err!("Response header malformed: {first_line}"));
68 }
69 Ok(code.parse().map_err(|e| format_err!("While parsing status code: {e:?}"))?)
70 } else {
71 Err(format_err!("Response header malformed: {first_line}"))
72 }
73}
74
/// A destination that a fetch can connect to.
pub trait FetchAddr {
    /// Returns the IP address and port to connect to.
    fn as_socket_addr(&self) -> net::SocketAddr;
}

/// A complete socket address is used unchanged.
impl FetchAddr for net::SocketAddr {
    fn as_socket_addr(&self) -> net::SocketAddr {
        *self
    }
}

/// A bare IP address is paired with the default HTTP port.
impl FetchAddr for net::IpAddr {
    fn as_socket_addr(&self) -> net::SocketAddr {
        const HTTP_PORT: u16 = 80;
        net::SocketAddr::new(*self, HTTP_PORT)
    }
}
90
91#[async_trait]
92pub trait Fetch {
93 async fn fetch<FA: FetchAddr + std::marker::Sync>(
94 &self,
95 interface_name: &str,
96 host: &str,
97 path: &str,
98 addr: &FA,
99 ) -> Option<u16>;
100}
101
102pub struct Fetcher;
103
104#[async_trait]
105impl Fetch for Fetcher {
106 async fn fetch<FA: FetchAddr + std::marker::Sync>(
107 &self,
108 interface_name: &str,
109 host: &str,
110 path: &str,
111 addr: &FA,
112 ) -> Option<u16> {
113 let r = fetch(interface_name, host, path, addr).await;
114 match r {
115 Ok(code) => Some(code),
116 Err(e) => {
117 if let Some(io_error) = e.downcast_ref::<std::io::Error>() {
123 if io_error.raw_os_error() == Some(libc::ENETUNREACH)
124 || io_error.raw_os_error() == Some(libc::EHOSTUNREACH)
125 {
126 info!("error while fetching {host}{path}: {e:?}");
127 return None;
128 }
129 }
130 warn!("error while fetching {host}{path}: {e:?}");
131 None
132 }
133 }
134 }
135}
136
#[cfg(test)]
mod test {
    use super::*;

    use std::net::{Ipv4Addr, SocketAddr};
    use std::pin::pin;

    use fuchsia_async::net::TcpListener;
    use fuchsia_async::{self as fasync};
    use futures::future::Fuse;
    use futures::io::BufReader;
    use futures::{AsyncBufReadExt, FutureExt, StreamExt};
    use test_case::test_case;

    /// Starts a one-shot HTTP server on an ephemeral localhost port.
    ///
    /// Returns the bound address and a future that accepts one connection, collects the
    /// request's header lines (trimmed, excluding the terminating blank line), replies with
    /// `HTTP/1.1 {code} OK`, and resolves to the collected lines. Each step panics if it does
    /// not complete before a deadline of [`FETCH_TIMEOUT`].
    fn server(
        code: u16,
    ) -> anyhow::Result<(SocketAddr, Fuse<impl futures::Future<Output = Vec<String>>>)> {
        let addr = SocketAddr::new(Ipv4Addr::LOCALHOST.into(), 0);
        let listener = TcpListener::bind(&addr).context("binding TCP")?;
        let addr = listener.local_addr()?;

        let server_fut = async move {
            let timeout = zx::MonotonicInstant::after(FETCH_TIMEOUT);
            let mut incoming = listener.accept_stream();
            if let Some(result) = incoming
                .next()
                .on_timeout(timeout, || panic!("timeout waiting for connection"))
                .await
            {
                let (stream, _addr) = result.expect("accept incoming TCP connection");
                let mut stream = BufReader::new(stream);
                let mut request = Vec::new();
                loop {
                    let mut s = String::new();
                    let read = stream
                        .read_line(&mut s)
                        .on_timeout(timeout, || panic!("timeout waiting to read data"))
                        .await
                        .expect("read data");
                    // Stop at the blank line ending the headers, and also at EOF: once the
                    // peer has closed, `read_line` keeps returning Ok(0) immediately, which
                    // would otherwise spin forever without yielding.
                    if read == 0 || s == "\r\n" {
                        break;
                    }
                    request.push(s.trim().to_string());
                }
                let data = format!("HTTP/1.1 {code} OK\r\n\r\n");
                stream
                    .write_all(data.as_bytes())
                    .on_timeout(timeout, || panic!("timeout waiting to write response"))
                    .await
                    .expect("reply to request");
                request
            } else {
                Vec::new()
            }
        }
        .fuse();

        Ok((addr, server_fut))
    }

    #[test_case("http://reachability.test/", 200; "base path 200")]
    #[test_case("http://reachability.test/path/", 200; "sub path 200")]
    #[test_case("http://reachability.test/", 400; "base path 400")]
    #[test_case("http://reachability.test/path/", 400; "sub path 400")]
    #[fasync::run_singlethreaded(test)]
    async fn test_fetch(url_str: &'static str, code: u16) -> anyhow::Result<()> {
        let url = url::Url::parse(url_str)?;
        let (addr, server_fut) = server(code)?;
        let domain = url.host().expect("no host").to_string();
        let path = url.path().to_string();

        let mut fetch_fut = pin!(fetch("", &domain, &path, &addr).fuse());

        let mut server_fut = pin!(server_fut);

        // Drive both sides until the fetch finishes; the server normally completes first
        // (closing the connection is what ends the client's read).
        let mut request = None;
        let result = loop {
            futures::select! {
                req = server_fut => request = Some(req),
                result = fetch_fut => break result
            };
        };

        assert!(result.is_ok(), "Expected OK, got: {result:?}");
        assert_eq!(result.ok(), Some(code));
        let request = request.expect("no request body");
        assert!(request.contains(&format!("HEAD {path} HTTP/1.1")));
        assert!(request.contains(&format!("host: {domain}")));

        Ok(())
    }
}