reachability_core/
fetch.rs

1// Copyright 2024 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use anyhow::{format_err, Context};
6use async_trait::async_trait;
7use fuchsia_async::net::TcpStream;
8use fuchsia_async::TimeoutExt;
9
10use futures::{AsyncReadExt, AsyncWriteExt, TryFutureExt};
11use log::{info, warn};
12use std::net;
13
/// Deadline budget for a single probe in `fetch`: a deadline is computed from this once at the
/// start of the call, and connecting, writing the request, and reading the response must all
/// finish before it. The test server reuses the same budget for its own steps.
const FETCH_TIMEOUT: zx::MonotonicDuration = zx::MonotonicDuration::from_seconds(10);
15
/// Builds the raw HTTP/1.1 `HEAD` request sent by the reachability probe.
///
/// The request carries a `host` header for `host`, asks the server to close the connection
/// after responding (`connection: close`), and identifies the probe via `user-agent`. Every
/// line is CRLF-terminated, and an extra CRLF ends the header section.
fn http_request(path: &str, host: &str) -> String {
    let request_line = format!("HEAD {path} HTTP/1.1");
    let host_header = format!("host: {host}");
    let head_lines = [
        request_line.as_str(),
        host_header.as_str(),
        "connection: close",
        "user-agent: fuchsia reachability probe",
    ];

    let mut request = String::new();
    for line in head_lines {
        request.push_str(line);
        request.push_str("\r\n");
    }
    // The empty line that terminates the request head.
    request.push_str("\r\n");
    request
}
26
27async fn fetch<FA: FetchAddr + std::marker::Sync>(
28    interface_name: &str,
29    host: &str,
30    path: &str,
31    addr: &FA,
32) -> anyhow::Result<u16> {
33    let timeout = zx::MonotonicInstant::after(FETCH_TIMEOUT);
34    let addr = addr.as_socket_addr();
35    let socket = socket2::Socket::new(
36        match addr {
37            net::SocketAddr::V4(_) => socket2::Domain::IPV4,
38            net::SocketAddr::V6(_) => socket2::Domain::IPV6,
39        },
40        socket2::Type::STREAM,
41        Some(socket2::Protocol::TCP),
42    )
43    .context("while constructing socket")?;
44    socket.bind_device(Some(interface_name.as_bytes()))?;
45    let mut stream = TcpStream::connect_from_raw(socket, addr)
46        .context("while constructing tcp stream")?
47        .map_err(|e| format_err!("Opening TcpStream connection failed: {e:?}"))
48        .on_timeout(timeout, || Err(format_err!("Opening TcpStream timed out")))
49        .await?;
50    let message = http_request(path, host);
51    stream
52        .write_all(message.as_bytes())
53        .map_err(|e| format_err!("Writing to TcpStream failed: {e:?}"))
54        .on_timeout(timeout, || Err(format_err!("Writing data to TcpStream timed out")))
55        .await?;
56
57    let mut bytes = Vec::new();
58    let _: usize = stream
59        .read_to_end(&mut bytes)
60        .map_err(|e| format_err!("Reading response from TcpStream failed: {e:?}"))
61        .on_timeout(timeout, || Err(format_err!("Reading response from TcpStream timed out")))
62        .await?;
63    let resp = String::from_utf8(bytes)?;
64    let first_line = resp.split("\r\n").next().expect("split always returns at least one item");
65    if let [http, code, ..] = first_line.split(' ').collect::<Vec<_>>().as_slice() {
66        if !http.starts_with("HTTP/") {
67            return Err(format_err!("Response header malformed: {first_line}"));
68        }
69        Ok(code.parse().map_err(|e| format_err!("While parsing status code: {e:?}"))?)
70    } else {
71        Err(format_err!("Response header malformed: {first_line}"))
72    }
73}
74
/// Something that can name the socket address a reachability probe should connect to.
pub trait FetchAddr {
    /// Returns the full socket address (IP and port) to connect to.
    fn as_socket_addr(&self) -> net::SocketAddr;
}

/// A complete socket address is used exactly as given, port included.
impl FetchAddr for net::SocketAddr {
    fn as_socket_addr(&self) -> net::SocketAddr {
        *self
    }
}

/// A bare IP address is probed on the standard HTTP port.
impl FetchAddr for net::IpAddr {
    fn as_socket_addr(&self) -> net::SocketAddr {
        const HTTP_PORT: u16 = 80;
        net::SocketAddr::new(*self, HTTP_PORT)
    }
}
90
/// Performs an HTTP reachability probe. Callers can be generic over how probes are carried
/// out; [`Fetcher`] is the implementation that goes over the network.
#[async_trait]
pub trait Fetch {
    /// Fetches `path` from `host`, connecting to `addr` over the interface named
    /// `interface_name`.
    ///
    /// Returns the HTTP status code of the response, or `None` if the probe failed.
    async fn fetch<FA: FetchAddr + std::marker::Sync>(
        &self,
        interface_name: &str,
        host: &str,
        path: &str,
        addr: &FA,
    ) -> Option<u16>;
}
101
102pub struct Fetcher;
103
104#[async_trait]
105impl Fetch for Fetcher {
106    async fn fetch<FA: FetchAddr + std::marker::Sync>(
107        &self,
108        interface_name: &str,
109        host: &str,
110        path: &str,
111        addr: &FA,
112    ) -> Option<u16> {
113        let r = fetch(interface_name, host, path, addr).await;
114        match r {
115            Ok(code) => Some(code),
116            Err(e) => {
117                // Check to see if the error is due to the host/network being
118                // unreachable. In that case, this error is likely unconcerning
119                // and signifies a network may not have connectivity across
120                // one of the IP protocols, which can be common for home
121                // network configurations.
122                if let Some(io_error) = e.downcast_ref::<std::io::Error>() {
123                    if io_error.raw_os_error() == Some(libc::ENETUNREACH)
124                        || io_error.raw_os_error() == Some(libc::EHOSTUNREACH)
125                    {
126                        info!("error while fetching {host}{path}: {e:?}");
127                        return None;
128                    }
129                }
130                warn!("error while fetching {host}{path}: {e:?}");
131                None
132            }
133        }
134    }
135}
136
#[cfg(test)]
mod test {
    use super::*;

    use std::net::{Ipv4Addr, SocketAddr};
    use std::pin::pin;

    use fuchsia_async::net::TcpListener;
    use fuchsia_async::{self as fasync};
    use futures::future::Fuse;
    use futures::io::BufReader;
    use futures::{AsyncBufReadExt, FutureExt, StreamExt};
    use test_case::test_case;

    /// Starts a one-shot HTTP server on an ephemeral localhost port.
    ///
    /// Returns the bound address and a fused future. When polled, the future accepts one
    /// connection and records the request's header lines (trimmed, up to the blank line). It
    /// then replies with status `code` and resolves to the recorded lines. If the incoming
    /// connection stream ends without yielding a connection, it resolves to an empty `Vec`.
    fn server(
        code: u16,
    ) -> anyhow::Result<(SocketAddr, Fuse<impl futures::Future<Output = Vec<String>>>)> {
        // Port 0 asks the OS for a free port; the actual port is read back from the listener.
        let addr = SocketAddr::new(Ipv4Addr::LOCALHOST.into(), 0);
        let listener = TcpListener::bind(&addr).context("binding TCP")?;
        let addr = listener.local_addr()?;

        let server_fut = async move {
            // One deadline, computed on first poll, bounds every server step.
            let timeout = zx::MonotonicInstant::after(FETCH_TIMEOUT);
            let mut incoming = listener.accept_stream();
            if let Some(result) = incoming
                .next()
                .on_timeout(timeout, || panic!("timeout waiting for connection"))
                .await
            {
                let (stream, _addr) = result.expect("accept incoming TCP connection");
                let mut stream = BufReader::new(stream);
                let mut request = Vec::new();
                // Collect request-head lines until the empty line that terminates them.
                // NOTE(review): the byte count from `read_line` is discarded, so EOF (`Ok(0)`)
                // is not detected; a client that closes before sending "\r\n" would make this
                // loop spin.
                loop {
                    let mut s = String::new();
                    let _: usize = stream
                        .read_line(&mut s)
                        .on_timeout(timeout, || panic!("timeout waiting to read data"))
                        .await
                        .expect("read data");
                    if s == "\r\n" {
                        break;
                    }
                    request.push(s.trim().to_string());
                }
                // Minimal response: status line plus the blank line ending the header section.
                let data = format!("HTTP/1.1 {} OK\r\n\r\n", code);
                stream
                    .write_all(data.as_bytes())
                    .on_timeout(timeout, || panic!("timeout waiting to write response"))
                    .await
                    .expect("reply to request");
                request
            } else {
                Vec::new()
            }
        }
        .fuse();

        Ok((addr, server_fut))
    }

    /// Runs `fetch` against the local server and checks the parsed status code and the
    /// request line and `host` header the server received.
    #[test_case("http://reachability.test/", 200; "base path 200")]
    #[test_case("http://reachability.test/path/", 200; "sub path 200")]
    #[test_case("http://reachability.test/", 400; "base path 400")]
    #[test_case("http://reachability.test/path/", 400; "sub path 400")]
    #[fasync::run_singlethreaded(test)]
    async fn test_fetch(url_str: &'static str, code: u16) -> anyhow::Result<()> {
        let url = url::Url::parse(url_str)?;
        let (addr, server_fut) = server(code)?;
        let domain = url.host().expect("no host").to_string();
        let path = url.path().to_string();

        // An empty interface name is passed to `bind_device`; presumably this leaves the socket
        // unbound from any specific device — TODO confirm on the target platform.
        let mut fetch_fut = pin!(fetch("", &domain, &path, &addr).fuse());

        let mut server_fut = pin!(server_fut);

        // Drive both futures on this single thread. The server usually finishes first,
        // recording the request; the loop ends once the client's result is available.
        let mut request = None;
        let result = loop {
            futures::select! {
                req = server_fut => request = Some(req),
                result = fetch_fut => break result
            };
        };

        assert!(result.is_ok(), "Expected OK, got: {result:?}");
        assert_eq!(result.ok(), Some(code));
        // `request` is only set if the server future completed before `fetch` did.
        let request = request.expect("no request body");
        assert!(request.contains(&format!("HEAD {path} HTTP/1.1")));
        assert!(request.contains(&format!("host: {domain}")));

        Ok(())
    }
}