reachability_core/
fetch.rs

1// Copyright 2024 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use anyhow::{format_err, Context};
6use async_trait::async_trait;
7use fuchsia_async::net::TcpStream;
8use fuchsia_async::TimeoutExt;
9
10use futures::{AsyncReadExt, AsyncWriteExt, TryFutureExt};
11use log::warn;
12use std::net;
13
/// Overall time budget for a single reachability fetch.
///
/// `fetch` derives one deadline from this duration before it connects and then
/// applies that same deadline to the connect, request-write and response-read
/// steps. It therefore bounds the whole exchange rather than each step separately.
const FETCH_TIMEOUT: zx::MonotonicDuration = zx::MonotonicDuration::from_seconds(10);
15
/// Builds the HTTP/1.1 `HEAD` request text sent by `fetch`.
///
/// The request carries a `host` header and asks the server to close the
/// connection after responding (`connection: close`), so the caller can read
/// the response until EOF. Every header line is CRLF-terminated, and the
/// header section ends with an empty line.
///
/// An empty `path` is sent as `/`. RFC 9112 §3.2.1 requires a client to send
/// `/` when the target URI's path is empty. Interpolating an empty path
/// verbatim would produce the malformed request line `HEAD  HTTP/1.1`.
fn http_request(path: &str, host: &str) -> String {
    let path = if path.is_empty() { "/" } else { path };
    format!(
        "HEAD {path} HTTP/1.1\r\n\
         host: {host}\r\n\
         connection: close\r\n\
         user-agent: fuchsia reachability probe\r\n\
         \r\n"
    )
}
26
27async fn fetch<FA: FetchAddr + std::marker::Sync>(
28    interface_name: &str,
29    host: &str,
30    path: &str,
31    addr: &FA,
32) -> anyhow::Result<u16> {
33    let timeout = zx::MonotonicInstant::after(FETCH_TIMEOUT);
34    let addr = addr.as_socket_addr();
35    let socket = socket2::Socket::new(
36        match addr {
37            net::SocketAddr::V4(_) => socket2::Domain::IPV4,
38            net::SocketAddr::V6(_) => socket2::Domain::IPV6,
39        },
40        socket2::Type::STREAM,
41        Some(socket2::Protocol::TCP),
42    )
43    .context("while constructing socket")?;
44    socket.bind_device(Some(interface_name.as_bytes()))?;
45    let mut stream = TcpStream::connect_from_raw(socket, addr)
46        .context("while constructing tcp stream")?
47        .map_err(|e| format_err!("Opening TcpStream connection failed: {e:?}"))
48        .on_timeout(timeout, || Err(format_err!("Opening TcpStream timed out")))
49        .await?;
50    let message = http_request(path, host);
51    stream
52        .write_all(message.as_bytes())
53        .map_err(|e| format_err!("Writing to TcpStream failed: {e:?}"))
54        .on_timeout(timeout, || Err(format_err!("Writing data to TcpStream timed out")))
55        .await?;
56
57    let mut bytes = Vec::new();
58    let _: usize = stream
59        .read_to_end(&mut bytes)
60        .map_err(|e| format_err!("Reading response from TcpStream failed: {e:?}"))
61        .on_timeout(timeout, || Err(format_err!("Reading response from TcpStream timed out")))
62        .await?;
63    let resp = String::from_utf8(bytes)?;
64    let first_line = resp.split("\r\n").next().expect("split always returns at least one item");
65    if let [http, code, ..] = first_line.split(' ').collect::<Vec<_>>().as_slice() {
66        if !http.starts_with("HTTP/") {
67            return Err(format_err!("Response header malformed: {first_line}"));
68        }
69        Ok(code.parse().map_err(|e| format_err!("While parsing status code: {e:?}"))?)
70    } else {
71        Err(format_err!("Response header malformed: {first_line}"))
72    }
73}
74
/// Conversion of an address-like value into the socket address that `fetch`
/// connects to.
pub trait FetchAddr {
    /// Returns the socket address a fetch should connect to.
    fn as_socket_addr(&self) -> net::SocketAddr;
}

/// A complete socket address is used exactly as given.
impl FetchAddr for net::SocketAddr {
    fn as_socket_addr(&self) -> Self {
        *self
    }
}

/// A bare IP address is paired with port 80, the default HTTP port.
impl FetchAddr for net::IpAddr {
    fn as_socket_addr(&self) -> net::SocketAddr {
        const HTTP_PORT: u16 = 80;
        net::SocketAddr::new(*self, HTTP_PORT)
    }
}
90
91#[async_trait]
92pub trait Fetch {
93    async fn fetch<FA: FetchAddr + std::marker::Sync>(
94        &self,
95        interface_name: &str,
96        host: &str,
97        path: &str,
98        addr: &FA,
99    ) -> Option<u16>;
100}
101
102pub struct Fetcher;
103
104#[async_trait]
105impl Fetch for Fetcher {
106    async fn fetch<FA: FetchAddr + std::marker::Sync>(
107        &self,
108        interface_name: &str,
109        host: &str,
110        path: &str,
111        addr: &FA,
112    ) -> Option<u16> {
113        let r = fetch(interface_name, host, path, addr).await;
114        match r {
115            Ok(code) => Some(code),
116            Err(e) => {
117                warn!("error while fetching {host}{path}: {e:?}");
118                None
119            }
120        }
121    }
122}
123
#[cfg(test)]
mod test {
    use super::*;

    use std::net::{Ipv4Addr, SocketAddr};
    use std::pin::pin;

    use fuchsia_async as fasync;
    use fuchsia_async::net::TcpListener;
    use futures::future::Fuse;
    use futures::io::BufReader;
    use futures::{AsyncBufReadExt, FutureExt, StreamExt};
    use test_case::test_case;

    /// Starts a one-shot HTTP server on an ephemeral localhost port.
    ///
    /// Returns the bound address and a future that, when polled, does the
    /// following:
    /// - accepts one connection;
    /// - reads request lines up to the terminating empty line;
    /// - replies with a status line carrying `code`;
    /// - resolves to the request lines it read, with whitespace trimmed.
    ///
    /// The future panics if a step misses the `FETCH_TIMEOUT` deadline, or if
    /// the client closes the connection before sending the empty line. If no
    /// connection arrives, it resolves to an empty list.
    fn server(
        code: u16,
    ) -> anyhow::Result<(SocketAddr, Fuse<impl futures::Future<Output = Vec<String>>>)> {
        let addr = SocketAddr::new(Ipv4Addr::LOCALHOST.into(), 0);
        let listener = TcpListener::bind(&addr).context("binding TCP")?;
        let addr = listener.local_addr()?;

        let server_fut = async move {
            let timeout = zx::MonotonicInstant::after(FETCH_TIMEOUT);
            let mut incoming = listener.accept_stream();
            if let Some(result) = incoming
                .next()
                .on_timeout(timeout, || panic!("timeout waiting for connection"))
                .await
            {
                let (stream, _addr) = result.expect("accept incoming TCP connection");
                let mut stream = BufReader::new(stream);
                let mut request = Vec::new();
                loop {
                    let mut s = String::new();
                    let bytes_read: usize = stream
                        .read_line(&mut s)
                        .on_timeout(timeout, || panic!("timeout waiting to read data"))
                        .await
                        .expect("read data");
                    // At EOF `read_line` returns 0 immediately on every call and
                    // leaves `s` empty. Without this check the loop would busy-spin
                    // instead of failing with a clear message.
                    assert_ne!(
                        bytes_read, 0,
                        "client closed the connection before finishing its request"
                    );
                    if s == "\r\n" {
                        break;
                    }
                    request.push(s.trim().to_string());
                }
                let data = format!("HTTP/1.1 {code} OK\r\n\r\n");
                stream
                    .write_all(data.as_bytes())
                    .on_timeout(timeout, || panic!("timeout waiting to write response"))
                    .await
                    .expect("reply to request");
                request
            } else {
                Vec::new()
            }
        }
        .fuse();

        Ok((addr, server_fut))
    }

    /// Runs `fetch` against the local server and checks three things:
    /// - the status code is returned;
    /// - the request line carries the URL's path;
    /// - the `host` header carries the URL's host.
    #[test_case("http://reachability.test/", 200; "base path 200")]
    #[test_case("http://reachability.test/path/", 200; "sub path 200")]
    #[test_case("http://reachability.test/", 400; "base path 400")]
    #[test_case("http://reachability.test/path/", 400; "sub path 400")]
    #[fasync::run_singlethreaded(test)]
    async fn test_fetch(url_str: &'static str, code: u16) -> anyhow::Result<()> {
        let url = url::Url::parse(url_str)?;
        let (addr, server_fut) = server(code)?;
        let domain = url.host().expect("no host").to_string();
        let path = url.path().to_string();

        let mut fetch_fut = pin!(fetch("", &domain, &path, &addr).fuse());

        let mut server_fut = pin!(server_fut);

        // Drive the client and the server concurrently. The server future is
        // fused, so after it completes `select!` only waits on the client.
        let mut request = None;
        let result = loop {
            futures::select! {
                req = server_fut => request = Some(req),
                result = fetch_fut => break result
            };
        };

        assert!(result.is_ok(), "Expected OK, got: {result:?}");
        assert_eq!(result.ok(), Some(code));
        let request = request.expect("no request body");
        assert!(request.contains(&format!("HEAD {path} HTTP/1.1")));
        assert!(request.contains(&format!("host: {domain}")));

        Ok(())
    }
}