Skip to main content

reachability_core/
fetch.rs

1// Copyright 2024 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use async_trait::async_trait;
6use fuchsia_async::TimeoutExt;
7use fuchsia_async::net::TcpStream;
8
9use futures::{AsyncReadExt, AsyncWriteExt, TryFutureExt};
10use std::net;
11
/// Time budget for one fetch attempt. `fetch` turns this into a single deadline
/// when it starts, and that same deadline is applied to the connect, write, and
/// read phases, so the phases share the budget rather than each getting their own.
const FETCH_TIMEOUT: zx::MonotonicDuration = zx::MonotonicDuration::from_seconds(10);
13
/// Failure modes of a single HTTP reachability fetch, one variant per step that
/// can go wrong in `fetch`.
#[derive(Debug, thiserror::Error)]
pub enum FetchError {
    /// Creating the TCP socket failed.
    #[error("failed to create socket")]
    CreateSocket(#[source] std::io::Error),
    /// Binding the socket to the named network device failed.
    #[error("failed to bind socket to device {interface_name}")]
    BindSocket {
        /// Name of the device the socket was being bound to.
        interface_name: String,
        /// Underlying I/O error reported by the bind call.
        #[source]
        err: std::io::Error,
    },
    /// Starting or completing the TCP connection failed.
    #[error("failed to open TCP stream")]
    ConnectTcpStream(#[source] std::io::Error),
    /// The fetch deadline passed before the TCP connection was established.
    #[error("timed out connecting TCP stream")]
    ConnectTcpStreamTimeout,
    /// Writing the request to the stream failed.
    #[error("failed to write to TCP stream")]
    WriteTcpStream(#[source] std::io::Error),
    /// The fetch deadline passed before the request was fully written.
    #[error("timed out writing data to TCP stream")]
    WriteTcpStreamTimeout,
    /// Reading the response from the stream failed.
    #[error("failed to read response from TCP stream")]
    ReadTcpStream(#[source] std::io::Error),
    /// The fetch deadline passed before the response was fully read.
    #[error("timed out reading response from TCP stream")]
    ReadTcpStreamTimeout,
    /// The response bytes were not valid UTF-8.
    #[error("failed to parse string from UTF-8 bytes")]
    ParseUtf8(#[from] std::string::FromUtf8Error),
    /// The first response line did not have the shape `HTTP/<...> <code> ...`:
    /// it had fewer than two space-separated fields or did not start with `HTTP/`.
    #[error("response header malformed: {first_line}")]
    MalformedHeader {
        /// The offending first line of the response.
        first_line: String,
    },
    /// The status-code field of the first response line was not a valid integer
    /// for the status-code type.
    #[error("failed to parse status code")]
    ParseStatusCode(#[source] std::num::ParseIntError),
}
43
44impl FetchError {
45    /// Returns a short, simplified string describing the error.
46    /// Each string should only take at most 14 characters, else the row labels for the
47    /// `fetch_results` time series in internal visualization tool would be truncated.
48    pub fn short_name(&self) -> String {
49        let (name, os_err) = match self {
50            Self::CreateSocket(err) => ("CreateSock", err.raw_os_error()),
51            Self::BindSocket { err, .. } => ("BindSock", err.raw_os_error()),
52            Self::ConnectTcpStream(err) => ("ConnectTcp", err.raw_os_error()),
53            Self::ConnectTcpStreamTimeout => return "ConnTcpTimeout".to_string(),
54            Self::WriteTcpStream(err) => ("WriteTcp", err.raw_os_error()),
55            Self::WriteTcpStreamTimeout => return "WriteTcpTOut".to_string(),
56            Self::ReadTcpStream(err) => ("ReadTcp", err.raw_os_error()),
57            Self::ReadTcpStreamTimeout => return "ReadTcpTimeout".to_string(),
58            Self::ParseUtf8(_) => return "ParseUtf8".to_string(),
59            Self::MalformedHeader { .. } => return "BadHeader".to_string(),
60            Self::ParseStatusCode(_) => return "ParseStatus".to_string(),
61        };
62
63        if let Some(code) = os_err { format!("{name}_{code}") } else { name.to_string() }
64    }
65}
66
67pub(crate) fn fetch_result_short_name(result: &Result<u16, FetchError>) -> String {
68    match result {
69        Ok(code) => format!("Completed_{}", code),
70        Err(e) => format!("e_{}", e.short_name()),
71    }
72}
73
/// Builds the text of an HTTP/1.1 `HEAD` request for `path`, addressed to `host`.
///
/// The request asks the server to close the connection after responding and
/// identifies itself with a fixed user-agent. Lines are CRLF-terminated and the
/// header section ends with an empty line.
fn http_request(path: &str, host: &str) -> String {
    // The trailing `\` on each line continues the literal and skips the
    // indentation of the next source line, so only the explicit `\r\n` sequences
    // end up in the output.
    format!(
        "HEAD {path} HTTP/1.1\r\n\
         host: {host}\r\n\
         connection: close\r\n\
         user-agent: fuchsia reachability probe\r\n\
         \r\n"
    )
}
84
85async fn fetch<FA: FetchAddr + std::marker::Sync>(
86    interface_name: &str,
87    host: &str,
88    path: &str,
89    addr: &FA,
90) -> Result<u16, FetchError> {
91    let timeout = zx::MonotonicInstant::after(FETCH_TIMEOUT);
92    let addr = addr.as_socket_addr();
93    let socket = socket2::Socket::new(
94        match addr {
95            net::SocketAddr::V4(_) => socket2::Domain::IPV4,
96            net::SocketAddr::V6(_) => socket2::Domain::IPV6,
97        },
98        socket2::Type::STREAM,
99        Some(socket2::Protocol::TCP),
100    )
101    .map_err(FetchError::CreateSocket)?;
102    socket.bind_device(Some(interface_name.as_bytes())).map_err(|err| FetchError::BindSocket {
103        interface_name: interface_name.to_string(),
104        err,
105    })?;
106    let mut stream = TcpStream::connect_from_raw(socket, addr)
107        .map_err(FetchError::ConnectTcpStream)?
108        .map_err(FetchError::ConnectTcpStream)
109        .on_timeout(timeout, || Err(FetchError::ConnectTcpStreamTimeout))
110        .await?;
111    let message = http_request(path, host);
112    stream
113        .write_all(message.as_bytes())
114        .map_err(FetchError::WriteTcpStream)
115        .on_timeout(timeout, || Err(FetchError::WriteTcpStreamTimeout))
116        .await?;
117
118    let mut bytes = Vec::new();
119    let _: usize = stream
120        .read_to_end(&mut bytes)
121        .map_err(FetchError::ReadTcpStream)
122        .on_timeout(timeout, || Err(FetchError::ReadTcpStreamTimeout))
123        .await?;
124    let resp = String::from_utf8(bytes)?;
125    let first_line = resp.split("\r\n").next().expect("split always returns at least one item");
126    if let [http, code, ..] = first_line.split(' ').collect::<Vec<_>>().as_slice() {
127        if !http.starts_with("HTTP/") {
128            return Err(FetchError::MalformedHeader { first_line: first_line.to_string() });
129        }
130        Ok(code.parse().map_err(FetchError::ParseStatusCode)?)
131    } else {
132        Err(FetchError::MalformedHeader { first_line: first_line.to_string() })
133    }
134}
135
/// A destination that can be turned into a concrete socket address to fetch from.
pub trait FetchAddr {
    /// Returns the socket address (IP and port) to connect to.
    fn as_socket_addr(&self) -> net::SocketAddr;
}

/// A full socket address is used as-is.
impl FetchAddr for net::SocketAddr {
    fn as_socket_addr(&self) -> net::SocketAddr {
        *self
    }
}

/// A bare IP address is paired with port 80.
impl FetchAddr for net::IpAddr {
    fn as_socket_addr(&self) -> net::SocketAddr {
        const HTTP_PORT: u16 = 80;
        net::SocketAddr::new(*self, HTTP_PORT)
    }
}
151
152#[async_trait]
153pub trait Fetch {
154    async fn fetch<FA: FetchAddr + std::marker::Sync>(
155        &self,
156        interface_name: &str,
157        host: &str,
158        path: &str,
159        addr: &FA,
160    ) -> Result<u16, FetchError>;
161}
162
163pub struct Fetcher;
164
165#[async_trait]
166impl Fetch for Fetcher {
167    async fn fetch<FA: FetchAddr + std::marker::Sync>(
168        &self,
169        interface_name: &str,
170        host: &str,
171        path: &str,
172        addr: &FA,
173    ) -> Result<u16, FetchError> {
174        fetch(interface_name, host, path, addr).await
175    }
176}
177
#[cfg(test)]
mod test {
    use super::*;

    use anyhow::Context;
    use std::net::{Ipv4Addr, SocketAddr};
    use std::pin::pin;

    use fuchsia_async::net::TcpListener;
    use fuchsia_async::{self as fasync};
    use futures::future::Fuse;
    use futures::io::BufReader;
    use futures::{AsyncBufReadExt, FutureExt, StreamExt};
    use test_case::test_case;

    /// Starts a one-shot HTTP server on an OS-assigned localhost port.
    ///
    /// Returns the bound address together with a fused future that accepts a single
    /// connection, reads request lines up to the blank line that ends the headers,
    /// replies with `HTTP/1.1 {code} OK`, and resolves to the trimmed request lines.
    /// If the listener yields no connection, the future resolves to an empty list.
    ///
    /// The future panics if any step exceeds `FETCH_TIMEOUT`, or if the client closes
    /// the connection before finishing the request headers.
    fn server(
        code: u16,
    ) -> anyhow::Result<(SocketAddr, Fuse<impl futures::Future<Output = Vec<String>>>)> {
        // Port 0 lets the OS pick a free port; the real address is read back below.
        let addr = SocketAddr::new(Ipv4Addr::LOCALHOST.into(), 0);
        let listener = TcpListener::bind(&addr).context("binding TCP")?;
        let addr = listener.local_addr()?;

        let server_fut = async move {
            let timeout = zx::MonotonicInstant::after(FETCH_TIMEOUT);
            let mut incoming = listener.accept_stream();
            if let Some(result) = incoming
                .next()
                .on_timeout(timeout, || panic!("timeout waiting for connection"))
                .await
            {
                let (stream, _addr) = result.expect("accept incoming TCP connection");
                let mut stream = BufReader::new(stream);
                let mut request = Vec::new();
                loop {
                    let mut s = String::new();
                    let bytes_read: usize = stream
                        .read_line(&mut s)
                        .on_timeout(timeout, || panic!("timeout waiting to read data"))
                        .await
                        .expect("read data");
                    // Zero bytes means the peer closed the stream. Without this check
                    // the loop would spin forever on an endless run of empty reads
                    // instead of failing the test.
                    assert_ne!(bytes_read, 0, "connection closed before end of request headers");
                    if s == "\r\n" {
                        break;
                    }
                    request.push(s.trim().to_string());
                }
                let data = format!("HTTP/1.1 {} OK\r\n\r\n", code);
                stream
                    .write_all(data.as_bytes())
                    .on_timeout(timeout, || panic!("timeout waiting to write response"))
                    .await
                    .expect("reply to request");
                request
            } else {
                Vec::new()
            }
        }
        .fuse();

        Ok((addr, server_fut))
    }

    /// Runs `fetch` against the local test server and checks both sides: the
    /// status code `fetch` returns, and the request line and host header the
    /// server received.
    #[test_case("http://reachability.test/", 200; "base path 200")]
    #[test_case("http://reachability.test/path/", 200; "sub path 200")]
    #[test_case("http://reachability.test/", 400; "base path 400")]
    #[test_case("http://reachability.test/path/", 400; "sub path 400")]
    #[fasync::run_singlethreaded(test)]
    async fn test_fetch(url_str: &'static str, code: u16) -> anyhow::Result<()> {
        let url = url::Url::parse(url_str)?;
        let (addr, server_fut) = server(code)?;
        let domain = url.host().expect("no host").to_string();
        let path = url.path().to_string();

        let mut fetch_fut = pin!(fetch("", &domain, &path, &addr).fuse());

        let mut server_fut = pin!(server_fut);

        // Drive both futures on the single-threaded executor: record the request
        // when the server finishes, and stop once the fetch produces its result.
        let mut request = None;
        let result = loop {
            futures::select! {
                req = server_fut => request = Some(req),
                result = fetch_fut => break result
            };
        };

        assert!(result.is_ok(), "Expected OK, got: {result:?}");
        assert_eq!(result.ok(), Some(code));
        let request = request.expect("no request body");
        assert!(request.contains(&format!("HEAD {path} HTTP/1.1")));
        assert!(request.contains(&format!("host: {domain}")));

        Ok(())
    }
}