Skip to main content

reachability_core/
fetch.rs

1// Copyright 2024 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use async_trait::async_trait;
6use fuchsia_async::TimeoutExt;
7use fuchsia_async::net::TcpStream;
8
9use futures::{AsyncReadExt, AsyncWriteExt, TryFutureExt};
10use log::log;
11use std::net;
12
/// Overall deadline for one [`fetch`]: connecting, writing the request and reading the
/// response all share a single deadline computed from this duration at the start of the
/// fetch, so the three phases together must finish within it.
const FETCH_TIMEOUT: zx::MonotonicDuration = zx::MonotonicDuration::from_seconds(10);
14
15#[derive(Debug, thiserror::Error)]
16pub enum FetchError {
17    #[error("failed to create socket")]
18    CreateSocket(#[source] std::io::Error),
19    #[error("failed to bind socket to device {interface_name}")]
20    BindSocket {
21        interface_name: String,
22        #[source]
23        err: std::io::Error,
24    },
25    #[error("failed to open TCP stream")]
26    ConnectTcpStream(#[source] std::io::Error),
27    #[error("timed out connecting TCP stream")]
28    ConnectTcpStreamTimeout,
29    #[error("failed to write to TCP stream")]
30    WriteTcpStream(#[source] std::io::Error),
31    #[error("timed out writing data to TCP stream")]
32    WriteTcpStreamTimeout,
33    #[error("failed to read response from TCP stream")]
34    ReadTcpStream(#[source] std::io::Error),
35    #[error("timed out reading response from TCP stream")]
36    ReadTcpStreamTimeout,
37    #[error("failed to parse string from UTF-8 bytes")]
38    ParseUtf8(#[from] std::string::FromUtf8Error),
39    #[error("response header malformed: {first_line}")]
40    MalformedHeader { first_line: String },
41    #[error("failed to parse status code")]
42    ParseStatusCode(#[source] std::num::ParseIntError),
43}
44
45impl FetchError {
46    /// Returns a short, simplified string describing the error.
47    /// Each string should only take at most 14 characters, else the row labels for the
48    /// `fetch_results` time series in internal visualization tool would be truncated.
49    pub fn short_name(&self) -> String {
50        let (name, os_err) = match self {
51            Self::CreateSocket(err) => ("CreateSock", err.raw_os_error()),
52            Self::BindSocket { err, .. } => ("BindSock", err.raw_os_error()),
53            Self::ConnectTcpStream(err) => ("ConnectTcp", err.raw_os_error()),
54            Self::ConnectTcpStreamTimeout => return "ConnTcpTimeout".to_string(),
55            Self::WriteTcpStream(err) => ("WriteTcp", err.raw_os_error()),
56            Self::WriteTcpStreamTimeout => return "WriteTcpTOut".to_string(),
57            Self::ReadTcpStream(err) => ("ReadTcp", err.raw_os_error()),
58            Self::ReadTcpStreamTimeout => return "ReadTcpTimeout".to_string(),
59            Self::ParseUtf8(_) => return "ParseUtf8".to_string(),
60            Self::MalformedHeader { .. } => return "BadHeader".to_string(),
61            Self::ParseStatusCode(_) => return "ParseStatus".to_string(),
62        };
63
64        if let Some(code) = os_err { format!("{name}_{code}") } else { name.to_string() }
65    }
66}
67
68pub(crate) fn fetch_result_short_name(result: &Result<u16, FetchError>) -> String {
69    match result {
70        Ok(code) => format!("Completed_{}", code),
71        Err(e) => format!("e_{}", e.short_name()),
72    }
73}
74
/// Builds the raw text of an HTTP/1.1 `HEAD` request for `path`, addressed to `host`.
///
/// The request asks the server to close the connection once it has answered. Every
/// line, including the final empty one that ends the header section, is CRLF-terminated.
fn http_request(path: &str, host: &str) -> String {
    let request_line = format!("HEAD {path} HTTP/1.1");
    let host_header = format!("host: {host}");
    let lines = [
        request_line.as_str(),
        host_header.as_str(),
        "connection: close",
        "user-agent: fuchsia reachability probe",
    ];

    let mut request = String::new();
    for line in lines {
        request.push_str(line);
        request.push_str("\r\n");
    }
    // Empty line terminating the header section.
    request.push_str("\r\n");
    request
}
85
86async fn fetch<FA: FetchAddr + std::marker::Sync>(
87    interface_name: &str,
88    host: &str,
89    path: &str,
90    addr: &FA,
91) -> Result<u16, FetchError> {
92    let timeout = zx::MonotonicInstant::after(FETCH_TIMEOUT);
93    let addr = addr.as_socket_addr();
94    let socket = socket2::Socket::new(
95        match addr {
96            net::SocketAddr::V4(_) => socket2::Domain::IPV4,
97            net::SocketAddr::V6(_) => socket2::Domain::IPV6,
98        },
99        socket2::Type::STREAM,
100        Some(socket2::Protocol::TCP),
101    )
102    .map_err(FetchError::CreateSocket)?;
103    socket.bind_device(Some(interface_name.as_bytes())).map_err(|err| FetchError::BindSocket {
104        interface_name: interface_name.to_string(),
105        err,
106    })?;
107    let mut stream = TcpStream::connect_from_raw(socket, addr)
108        .map_err(FetchError::ConnectTcpStream)?
109        .map_err(FetchError::ConnectTcpStream)
110        .on_timeout(timeout, || Err(FetchError::ConnectTcpStreamTimeout))
111        .await?;
112    let message = http_request(path, host);
113    stream
114        .write_all(message.as_bytes())
115        .map_err(FetchError::WriteTcpStream)
116        .on_timeout(timeout, || Err(FetchError::WriteTcpStreamTimeout))
117        .await?;
118
119    let mut bytes = Vec::new();
120    let _: usize = stream
121        .read_to_end(&mut bytes)
122        .map_err(FetchError::ReadTcpStream)
123        .on_timeout(timeout, || Err(FetchError::ReadTcpStreamTimeout))
124        .await?;
125    let resp = String::from_utf8(bytes)?;
126    let first_line = resp.split("\r\n").next().expect("split always returns at least one item");
127    if let [http, code, ..] = first_line.split(' ').collect::<Vec<_>>().as_slice() {
128        if !http.starts_with("HTTP/") {
129            return Err(FetchError::MalformedHeader { first_line: first_line.to_string() });
130        }
131        Ok(code.parse().map_err(FetchError::ParseStatusCode)?)
132    } else {
133        Err(FetchError::MalformedHeader { first_line: first_line.to_string() })
134    }
135}
136
/// A destination that can be resolved to the socket address a fetch connects to.
pub trait FetchAddr {
    /// Returns the socket address (IP and port) to connect to.
    fn as_socket_addr(&self) -> net::SocketAddr;
}

/// A complete socket address is used exactly as given.
impl FetchAddr for net::SocketAddr {
    fn as_socket_addr(&self) -> net::SocketAddr {
        *self
    }
}

/// A bare IP address is paired with the standard HTTP port.
impl FetchAddr for net::IpAddr {
    fn as_socket_addr(&self) -> net::SocketAddr {
        const HTTP_PORT: u16 = 80;
        net::SocketAddr::new(*self, HTTP_PORT)
    }
}
152
153#[async_trait]
154pub trait Fetch {
155    async fn fetch<FA: FetchAddr + std::marker::Sync>(
156        &self,
157        interface_name: &str,
158        host: &str,
159        path: &str,
160        addr: &FA,
161    ) -> Result<u16, FetchError>;
162}
163
164pub struct Fetcher;
165
166#[async_trait]
167impl Fetch for Fetcher {
168    async fn fetch<FA: FetchAddr + std::marker::Sync>(
169        &self,
170        interface_name: &str,
171        host: &str,
172        path: &str,
173        addr: &FA,
174    ) -> Result<u16, FetchError> {
175        let r = fetch(interface_name, host, path, addr).await;
176        match r {
177            Ok(code) => Ok(code),
178            Err(e) => {
179                // Check to see if the error is due to the host/network being
180                // unreachable. In that case, this error is likely unconcerning
181                // and signifies a network may not have connectivity across
182                // one of the IP protocols, which can be common for home
183                // network configurations.
184                let level = if let Some(io_error) = std::error::Error::source(&e)
185                    .and_then(|err| err.downcast_ref::<std::io::Error>())
186                    && (io_error.raw_os_error() == Some(libc::ENETUNREACH)
187                        || io_error.raw_os_error() == Some(libc::EHOSTUNREACH))
188                {
189                    log::Level::Info
190                } else {
191                    log::Level::Warn
192                };
193
194                let addr = addr.as_socket_addr();
195                log!(
196                    level,
197                    "err fetching {host}{path} from {addr} \
198                     through {interface_name}: {e:#}"
199                );
200
201                Err(e)
202            }
203        }
204    }
205}
206
#[cfg(test)]
mod test {
    use super::*;

    use anyhow::Context;
    use std::net::{Ipv4Addr, SocketAddr};
    use std::pin::pin;

    use fuchsia_async::net::TcpListener;
    use fuchsia_async::{self as fasync};
    use futures::future::Fuse;
    use futures::io::BufReader;
    use futures::{AsyncBufReadExt, FutureExt, StreamExt};
    use test_case::test_case;

    /// Starts a one-shot HTTP server listening on an ephemeral localhost port.
    ///
    /// Returns the bound address and a fused future. When polled, the future:
    /// - accepts a single connection;
    /// - reads request lines until the blank line ending the headers;
    /// - replies with `HTTP/1.1 <code> OK` and an empty header section;
    /// - resolves to the trimmed request lines.
    ///
    /// If the accept stream ends without a connection, the future resolves to an empty
    /// `Vec`. Every step panics if it does not finish within [`FETCH_TIMEOUT`] of the
    /// future first being polled.
    fn server(
        code: u16,
    ) -> anyhow::Result<(SocketAddr, Fuse<impl futures::Future<Output = Vec<String>>>)> {
        let addr = SocketAddr::new(Ipv4Addr::LOCALHOST.into(), 0);
        let listener = TcpListener::bind(&addr).context("binding TCP")?;
        // Port 0 requested an OS-assigned port; read back the address actually bound.
        let addr = listener.local_addr()?;

        let server_fut = async move {
            let timeout = zx::MonotonicInstant::after(FETCH_TIMEOUT);
            let mut incoming = listener.accept_stream();
            if let Some(result) = incoming
                .next()
                .on_timeout(timeout, || panic!("timeout waiting for connection"))
                .await
            {
                let (stream, _addr) = result.expect("accept incoming TCP connection");
                let mut stream = BufReader::new(stream);
                let mut request = Vec::new();
                // Collect header lines up to (not including) the terminating blank line.
                // NOTE(review): if the client closes before sending the blank line,
                // `read_line` keeps returning Ok(0) immediately, so this loop never ends
                // (the timeout cannot fire on an already-ready future). Consider breaking
                // on a zero-length read.
                loop {
                    let mut s = String::new();
                    let _: usize = stream
                        .read_line(&mut s)
                        .on_timeout(timeout, || panic!("timeout waiting to read data"))
                        .await
                        .expect("read data");
                    if s == "\r\n" {
                        break;
                    }
                    request.push(s.trim().to_string());
                }
                // The reason phrase is always "OK", whatever `code` is; the client only
                // parses the status code.
                let data = format!("HTTP/1.1 {} OK\r\n\r\n", code);
                stream
                    .write_all(data.as_bytes())
                    .on_timeout(timeout, || panic!("timeout waiting to write response"))
                    .await
                    .expect("reply to request");
                // `stream` is dropped at the end of this scope, closing the connection so
                // the client's read-to-EOF can complete.
                request
            } else {
                Vec::new()
            }
        }
        .fuse();

        Ok((addr, server_fut))
    }

    /// Runs `fetch` against the local test server and checks the reported status code,
    /// plus the request line and `host` header the server received.
    #[test_case("http://reachability.test/", 200; "base path 200")]
    #[test_case("http://reachability.test/path/", 200; "sub path 200")]
    #[test_case("http://reachability.test/", 400; "base path 400")]
    #[test_case("http://reachability.test/path/", 400; "sub path 400")]
    #[fasync::run_singlethreaded(test)]
    async fn test_fetch(url_str: &'static str, code: u16) -> anyhow::Result<()> {
        let url = url::Url::parse(url_str)?;
        let (addr, server_fut) = server(code)?;
        let domain = url.host().expect("no host").to_string();
        let path = url.path().to_string();

        // NOTE(review): an empty interface name is passed to `bind_device`; presumably
        // this leaves the socket unbound to any particular device — confirm.
        let mut fetch_fut = pin!(fetch("", &domain, &path, &addr).fuse());

        let mut server_fut = pin!(server_fut);

        // Drive client and server concurrently on this single-threaded executor. The
        // server normally finishes first (recording the request); because both futures
        // are fused, `select!` then keeps polling only the fetch until it completes.
        let mut request = None;
        let result = loop {
            futures::select! {
                req = server_fut => request = Some(req),
                result = fetch_fut => break result
            };
        };

        assert!(result.is_ok(), "Expected OK, got: {result:?}");
        assert_eq!(result.ok(), Some(code));
        // `request` holds the header lines the server received (not a body).
        let request = request.expect("no request body");
        assert!(request.contains(&format!("HEAD {path} HTTP/1.1")));
        assert!(request.contains(&format!("host: {domain}")));

        Ok(())
    }
}