reachability_core/
fetch.rs1use anyhow::{format_err, Context};
6use async_trait::async_trait;
7use fuchsia_async::net::TcpStream;
8use fuchsia_async::TimeoutExt;
9
10use futures::{AsyncReadExt, AsyncWriteExt, TryFutureExt};
11use log::warn;
12use std::net;
13
/// Overall deadline for one probe.
///
/// `fetch` computes a single deadline this far in the future and applies it to connecting,
/// writing the request and reading the response. The test server applies the same duration.
const FETCH_TIMEOUT: zx::MonotonicDuration = zx::MonotonicDuration::from_seconds(10);
15
/// Builds the raw HTTP/1.1 `HEAD` request text sent by the reachability probe.
///
/// The request asks the server to close the connection (`connection: close`) once it has
/// answered. It ends with an empty CRLF line that terminates the header section.
fn http_request(path: &str, host: &str) -> String {
    format!(
        "HEAD {path} HTTP/1.1\r\n\
         host: {host}\r\n\
         connection: close\r\n\
         user-agent: fuchsia reachability probe\r\n\
         \r\n"
    )
}
26
27async fn fetch<FA: FetchAddr + std::marker::Sync>(
28 interface_name: &str,
29 host: &str,
30 path: &str,
31 addr: &FA,
32) -> anyhow::Result<u16> {
33 let timeout = zx::MonotonicInstant::after(FETCH_TIMEOUT);
34 let addr = addr.as_socket_addr();
35 let socket = socket2::Socket::new(
36 match addr {
37 net::SocketAddr::V4(_) => socket2::Domain::IPV4,
38 net::SocketAddr::V6(_) => socket2::Domain::IPV6,
39 },
40 socket2::Type::STREAM,
41 Some(socket2::Protocol::TCP),
42 )
43 .context("while constructing socket")?;
44 socket.bind_device(Some(interface_name.as_bytes()))?;
45 let mut stream = TcpStream::connect_from_raw(socket, addr)
46 .context("while constructing tcp stream")?
47 .map_err(|e| format_err!("Opening TcpStream connection failed: {e:?}"))
48 .on_timeout(timeout, || Err(format_err!("Opening TcpStream timed out")))
49 .await?;
50 let message = http_request(path, host);
51 stream
52 .write_all(message.as_bytes())
53 .map_err(|e| format_err!("Writing to TcpStream failed: {e:?}"))
54 .on_timeout(timeout, || Err(format_err!("Writing data to TcpStream timed out")))
55 .await?;
56
57 let mut bytes = Vec::new();
58 let _: usize = stream
59 .read_to_end(&mut bytes)
60 .map_err(|e| format_err!("Reading response from TcpStream failed: {e:?}"))
61 .on_timeout(timeout, || Err(format_err!("Reading response from TcpStream timed out")))
62 .await?;
63 let resp = String::from_utf8(bytes)?;
64 let first_line = resp.split("\r\n").next().expect("split always returns at least one item");
65 if let [http, code, ..] = first_line.split(' ').collect::<Vec<_>>().as_slice() {
66 if !http.starts_with("HTTP/") {
67 return Err(format_err!("Response header malformed: {first_line}"));
68 }
69 Ok(code.parse().map_err(|e| format_err!("While parsing status code: {e:?}"))?)
70 } else {
71 Err(format_err!("Response header malformed: {first_line}"))
72 }
73}
74
/// An address the reachability probe can connect to.
///
/// Implementors turn themselves into the concrete socket address (IP and port) used for
/// the TCP connection.
pub trait FetchAddr {
    /// Returns the socket address to connect to.
    fn as_socket_addr(&self) -> net::SocketAddr;
}

/// A full socket address is used verbatim, including its port.
impl FetchAddr for net::SocketAddr {
    fn as_socket_addr(&self) -> net::SocketAddr {
        *self
    }
}

/// A bare IP address is paired with the standard HTTP port.
impl FetchAddr for net::IpAddr {
    fn as_socket_addr(&self) -> net::SocketAddr {
        const HTTP_PORT: u16 = 80;
        net::SocketAddr::new(*self, HTTP_PORT)
    }
}
90
91#[async_trait]
92pub trait Fetch {
93 async fn fetch<FA: FetchAddr + std::marker::Sync>(
94 &self,
95 interface_name: &str,
96 host: &str,
97 path: &str,
98 addr: &FA,
99 ) -> Option<u16>;
100}
101
102pub struct Fetcher;
103
104#[async_trait]
105impl Fetch for Fetcher {
106 async fn fetch<FA: FetchAddr + std::marker::Sync>(
107 &self,
108 interface_name: &str,
109 host: &str,
110 path: &str,
111 addr: &FA,
112 ) -> Option<u16> {
113 let r = fetch(interface_name, host, path, addr).await;
114 match r {
115 Ok(code) => Some(code),
116 Err(e) => {
117 warn!("error while fetching {host}{path}: {e:?}");
118 None
119 }
120 }
121 }
122}
123
#[cfg(test)]
mod test {
    use super::*;

    use std::net::{Ipv4Addr, SocketAddr};
    use std::pin::pin;

    use fuchsia_async::net::TcpListener;
    use fuchsia_async::{self as fasync};
    use futures::future::Fuse;
    use futures::io::BufReader;
    use futures::{AsyncBufReadExt, FutureExt, StreamExt};
    use test_case::test_case;

    /// Starts a one-shot HTTP server on an ephemeral localhost port.
    ///
    /// Returns the bound address and a fused future. The future:
    /// - accepts a single connection;
    /// - collects the request's header lines (trimmed) up to the blank line that ends them;
    /// - answers with status `code`;
    /// - resolves to the collected lines.
    ///
    /// The future panics if waiting for the connection, the request or the reply outlives a
    /// deadline of `FETCH_TIMEOUT`. It also panics if the client closes the connection before
    /// finishing its headers.
    fn server(
        code: u16,
    ) -> anyhow::Result<(SocketAddr, Fuse<impl futures::Future<Output = Vec<String>>>)> {
        let addr = SocketAddr::new(Ipv4Addr::LOCALHOST.into(), 0);
        let listener = TcpListener::bind(&addr).context("binding TCP")?;
        let addr = listener.local_addr()?;

        let server_fut = async move {
            let timeout = zx::MonotonicInstant::after(FETCH_TIMEOUT);
            let mut incoming = listener.accept_stream();
            if let Some(result) = incoming
                .next()
                .on_timeout(timeout, || panic!("timeout waiting for connection"))
                .await
            {
                let (stream, _addr) = result.expect("accept incoming TCP connection");
                let mut stream = BufReader::new(stream);
                let mut request = Vec::new();
                loop {
                    let mut s = String::new();
                    let bytes_read = stream
                        .read_line(&mut s)
                        .on_timeout(timeout, || panic!("timeout waiting to read data"))
                        .await
                        .expect("read data");
                    // `read_line` reports EOF as a zero-length read that completes immediately.
                    // Without this check, a client that disconnects early would make this loop
                    // spin forever without yielding, so the timeout could never fire.
                    assert_ne!(bytes_read, 0, "connection closed before end of request headers");
                    if s == "\r\n" {
                        break;
                    }
                    request.push(s.trim().to_string());
                }
                let data = format!("HTTP/1.1 {} OK\r\n\r\n", code);
                stream
                    .write_all(data.as_bytes())
                    .on_timeout(timeout, || panic!("timeout waiting to write response"))
                    .await
                    .expect("reply to request");
                request
            } else {
                Vec::new()
            }
        }
        .fuse();

        Ok((addr, server_fut))
    }

    /// Checks that `fetch` reports the status code served by the test server. Also checks
    /// that the server received the expected `HEAD` request line and `host` header.
    #[test_case("http://reachability.test/", 200; "base path 200")]
    #[test_case("http://reachability.test/path/", 200; "sub path 200")]
    #[test_case("http://reachability.test/", 400; "base path 400")]
    #[test_case("http://reachability.test/path/", 400; "sub path 400")]
    #[fasync::run_singlethreaded(test)]
    async fn test_fetch(url_str: &'static str, code: u16) -> anyhow::Result<()> {
        let url = url::Url::parse(url_str)?;
        let (addr, server_fut) = server(code)?;
        let domain = url.host().expect("no host").to_string();
        let path = url.path().to_string();

        let mut fetch_fut = pin!(fetch("", &domain, &path, &addr).fuse());

        let mut server_fut = pin!(server_fut);

        // Drive both sides until the client finishes, recording the request if the server
        // completes first.
        let mut request = None;
        let result = loop {
            futures::select! {
                req = server_fut => request = Some(req),
                result = fetch_fut => break result
            };
        };

        assert!(result.is_ok(), "Expected OK, got: {result:?}");
        assert_eq!(result.ok(), Some(code));
        let request = request.expect("no request body");
        assert!(request.contains(&format!("HEAD {path} HTTP/1.1")));
        assert!(request.contains(&format!("host: {domain}")));

        Ok(())
    }
}