socket_proxy/
dns_watcher.rs

1// Copyright 2024 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5//! Implements the fuchsia.netpol.socketproxy.DnsServerWatcher service.
6
7use anyhow::{Context, Error};
8use fidl::endpoints::{ControlHandle as _, RequestStream as _, Responder as _};
9use fidl_fuchsia_netpol_socketproxy::{self as fnp_socketproxy, DnsServerList};
10use fuchsia_inspect_derive::{IValue, Inspect, Unit};
11use futures::channel::mpsc;
12use futures::lock::Mutex;
13use futures::{StreamExt, TryStreamExt};
14use log::{info, warn};
15use std::sync::Arc;
16
17#[derive(Unit, Debug, Default)]
18struct DnsServerWatcherState {
19    #[inspect(skip)]
20    server_list: Vec<DnsServerList>,
21    #[inspect(skip)]
22    last_sent: Option<Vec<DnsServerList>>,
23    #[inspect(skip)]
24    queued_responder: Option<fnp_socketproxy::DnsServerWatcherWatchServersResponder>,
25
26    updates_seen: u32,
27    updates_sent: u32,
28}
29
30/// A wrapper around the fuchsia.netpol.socketproxy.DnsServerWatcher service
31/// that tracks when a DnsServerList update needs to be sent.
32#[derive(Inspect, Debug, Clone)]
33pub(crate) struct DnsServerWatcher {
34    #[inspect(forward)]
35    state: Arc<Mutex<IValue<DnsServerWatcherState>>>,
36    dns_rx: Arc<Mutex<mpsc::Receiver<Vec<fnp_socketproxy::DnsServerList>>>>,
37}
38
39impl DnsServerWatcher {
40    /// Create a new DnsServerWatcher.
41    pub(crate) fn new(
42        dns_rx: Arc<Mutex<mpsc::Receiver<Vec<fnp_socketproxy::DnsServerList>>>>,
43    ) -> Self {
44        Self { dns_rx, state: Default::default() }
45    }
46
47    /// Runs the fuchsia.netpol.socketproxy.DnsServerWatcher service.
48    pub(crate) async fn run<'a>(
49        &self,
50        stream: fnp_socketproxy::DnsServerWatcherRequestStream,
51    ) -> Result<(), Error> {
52        let mut state = match self.state.try_lock() {
53            Some(o) => o,
54            None => {
55                warn!("Only one connection to DnsServerWatcher is allowed at a time");
56                stream.control_handle().shutdown_with_epitaph(fidl::Status::ACCESS_DENIED);
57                return Ok(());
58            }
59        };
60        let mut dns_rx = self.dns_rx.lock().await;
61        info!("Starting fuchsia.netpol.socketproxy.DnsServerWatcher server");
62        let mut stream = stream.map(|result| result.context("failed request")).fuse();
63
64        loop {
65            futures::select! {
66                request = stream.try_next() => match request? {
67                    Some(fnp_socketproxy::DnsServerWatcherRequest::WatchServers { responder }) => {
68                        let mut state = state.as_mut();
69                        if state.queued_responder.is_some() {
70                            warn!("Only one call to watch server may be active at once");
71                            responder
72                                .control_handle()
73                                .shutdown_with_epitaph(fidl::Status::ACCESS_DENIED);
74                        } else {
75                            state.queued_responder = Some(responder);
76                            state.maybe_respond()?;
77                        }
78                    },
79                     Some(fnp_socketproxy::DnsServerWatcherRequest::CheckPresence { responder }) => {
80                        // This is a no-op method, so ignore any errors.
81                        let _: Result<(), fidl::Error> = responder.send();
82                    }
83                    None => {}
84                },
85                dns_update = dns_rx.select_next_some() => {
86                    let mut state = state.as_mut();
87                    state.updates_seen += 1;
88                    state.server_list = dns_update;
89                    state.maybe_respond()?;
90                }
91            }
92        }
93    }
94}
95
96impl DnsServerWatcherState {
97    fn maybe_respond(&mut self) -> Result<(), Error> {
98        if self.last_sent.as_ref() != Some(&self.server_list) {
99            if let Some(responder) = self.queued_responder.take() {
100                info!("Sending DNS update to client: {}", self.server_list.len());
101                responder.send(&self.server_list)?;
102                self.updates_sent += 1;
103                self.last_sent = Some(self.server_list.clone());
104            }
105        }
106        Ok(())
107    }
108}
109
#[cfg(test)]
mod test {
    use super::*;
    use assert_matches::assert_matches;
    use diagnostics_assertions::assert_data_tree;
    use fuchsia_component::server::ServiceFs;
    use fuchsia_component_test::{
        Capability, ChildOptions, LocalComponentHandles, RealmBuilder, RealmInstance, Ref, Route,
    };
    use fuchsia_inspect_derive::WithInspect;
    use futures::channel::mpsc::{Receiver, Sender};
    use futures::SinkExt as _;
    use pretty_assertions::assert_eq;

    /// Protocols served from the local `dns_watcher` child's outgoing directory.
    enum IncomingService {
        DnsServerWatcher(fnp_socketproxy::DnsServerWatcherRequestStream),
    }

    /// Body of the local child component.
    ///
    /// Serves `DnsServerWatcher` from `svc`, fed by `dns_rx`, with the
    /// watcher's inspect data attached as `dns_watcher` under the component
    /// inspector root.
    async fn run_registry(
        handles: LocalComponentHandles,
        dns_rx: Arc<Mutex<Receiver<Vec<fnp_socketproxy::DnsServerList>>>>,
    ) -> Result<(), Error> {
        let mut service_fs = ServiceFs::new();
        let _ = service_fs.dir("svc").add_fidl_service(IncomingService::DnsServerWatcher);
        let _ = service_fs.serve_connection(handles.outgoing_dir)?;

        let inspect_root = fuchsia_inspect::component::inspector().root();
        let shared_watcher =
            DnsServerWatcher::new(dns_rx).with_inspect(inspect_root, "dns_watcher")?;

        service_fs
            .for_each_concurrent(None, |IncomingService::DnsServerWatcher(request_stream)| {
                let connection_watcher = shared_watcher.clone();
                async move {
                    let outcome = connection_watcher
                        .run(request_stream)
                        .await
                        .context("Failed to serve request stream");
                    if let Err(e) = outcome {
                        eprintln!("Error encountered: {e:?}");
                    }
                }
            })
            .await;

        Ok(())
    }

    /// Builds a realm whose `dns_watcher` child exposes `DnsServerWatcher` to
    /// the test. Returns the realm plus the sender used to push DNS server
    /// list updates into that watcher.
    async fn setup_test(
    ) -> Result<(RealmInstance, Sender<Vec<fnp_socketproxy::DnsServerList>>), Error> {
        let builder = RealmBuilder::new().await?;
        let (update_sender, update_receiver) = mpsc::channel(1);
        let shared_receiver = Arc::new(Mutex::new(update_receiver));

        let watcher_child = builder
            .add_local_child(
                "dns_watcher",
                move |handles: LocalComponentHandles| {
                    Box::pin(run_registry(handles, Arc::clone(&shared_receiver)))
                },
                ChildOptions::new(),
            )
            .await?;

        let watcher_route = Route::new()
            .capability(Capability::protocol::<fnp_socketproxy::DnsServerWatcherMarker>())
            .from(&watcher_child)
            .to(Ref::parent());
        builder.add_route(watcher_route).await?;

        let realm = builder.build().await?;
        Ok((realm, update_sender))
    }

    /// Connects to the `DnsServerWatcher` protocol exposed by `realm`.
    fn connect_watcher(
        realm: &RealmInstance,
    ) -> Result<fnp_socketproxy::DnsServerWatcherProxy, Error> {
        realm
            .root
            .connect_to_protocol_at_exposed_dir::<fnp_socketproxy::DnsServerWatcherMarker>()
            .context("While connecting to DnsServerWatcher")
    }

    #[fuchsia::test]
    async fn test_normal_operation() -> Result<(), Error> {
        let (realm, mut update_sender) = setup_test().await?;
        let watcher = connect_watcher(&realm)?;

        // The first watch completes right away with the (empty) initial list.
        assert_eq!(watcher.watch_servers().await?, vec![]);

        // An update published while a watch is outstanding completes that watch.
        let expected = vec![DnsServerList { source_network_id: Some(0), ..Default::default() }];
        let (send_result, watch_result) =
            futures::join!(update_sender.send(expected.clone()), watcher.watch_servers());

        assert_matches!(send_result, Ok(()));
        assert_eq!(watch_result?, expected);

        assert_data_tree!(fuchsia_inspect::component::inspector(), root: {
            dns_watcher: {
                updates_seen: 1u64,
                updates_sent: 2u64,
            },
        });

        Ok(())
    }

    #[fuchsia::test]
    async fn test_duplicate_list() -> Result<(), Error> {
        let (realm, mut first_sender) = setup_test().await?;
        let watcher = connect_watcher(&realm)?;

        // The first watch completes right away with the (empty) initial list.
        assert_eq!(watcher.watch_servers().await?, vec![]);

        let repeated = vec![DnsServerList { source_network_id: Some(0), ..Default::default() }];

        let mut second_sender = first_sender.clone();
        let mut third_sender = first_sender.clone();
        let (watch_result, s1, s2, s3) = futures::join!(
            watcher.watch_servers(),
            first_sender.send(repeated.clone()),
            second_sender.send(repeated.clone()),
            third_sender.send(repeated.clone()),
        );

        for send_result in [s1, s2, s3] {
            assert_matches!(send_result, Ok(()));
        }
        assert_eq!(watch_result?, repeated);

        // A distinct update must reach the next watch rather than any of the
        // duplicated lists above.
        let distinct = vec![DnsServerList { source_network_id: Some(1), ..Default::default() }];
        let (send_result, watch_result) =
            futures::join!(first_sender.send(distinct.clone()), watcher.watch_servers());
        assert_matches!(send_result, Ok(()));
        assert_eq!(watch_result?, distinct);

        assert_data_tree!(fuchsia_inspect::component::inspector(), root: {
            dns_watcher: {
                updates_seen: 4u64,
                updates_sent: 3u64,
            },
        });

        Ok(())
    }

    #[fuchsia::test]
    async fn test_duplicate_watch() -> Result<(), Error> {
        // Keep the sender bound so the update channel stays open.
        let (realm, _update_sender) = setup_test().await?;
        let watcher = connect_watcher(&realm)?;

        // The first watch completes right away with the (empty) initial list.
        assert_eq!(watcher.watch_servers().await?, vec![]);

        // Both requests are issued before either is awaited. Two concurrent
        // watches are invalid, so the server closes the channel with
        // ACCESS_DENIED and both calls observe it.
        let first_watch = watcher.watch_servers();
        let second_watch = watcher.watch_servers();
        let outcomes = futures::future::join(first_watch, second_watch).await;
        assert_matches!(
            outcomes,
            (
                Err(fidl::Error::ClientChannelClosed { status: fidl::Status::ACCESS_DENIED, .. }),
                Err(fidl::Error::ClientChannelClosed { status: fidl::Status::ACCESS_DENIED, .. })
            )
        );

        assert_data_tree!(fuchsia_inspect::component::inspector(), root: {
            dns_watcher: {
                updates_seen: 0u64,
                updates_sent: 1u64,
            },
        });

        Ok(())
    }
}