1use anyhow::{Context, Error};
8use fidl::endpoints::{ControlHandle as _, RequestStream as _, Responder as _};
9use fidl_fuchsia_netpol_socketproxy::{self as fnp_socketproxy, DnsServerList};
10use fuchsia_inspect_derive::{IValue, Inspect, Unit};
11use futures::channel::mpsc;
12use futures::lock::Mutex;
13use futures::{StreamExt, TryStreamExt};
14use log::{info, warn};
15use std::sync::Arc;
16
/// Mutable state backing `DnsServerWatcher`.
///
/// Fields not marked `#[inspect(skip)]` (`updates_seen`, `updates_sent`) are
/// exported to Inspect through the `Unit` derive.
#[derive(Unit, Debug, Default)]
struct DnsServerWatcherState {
    /// Most recent DNS server list received on the update channel (empty until
    /// the first update arrives).
    #[inspect(skip)]
    server_list: Vec<DnsServerList>,
    /// The list most recently sent in a `WatchServers` response, or `None` if
    /// nothing has been sent yet. Used to avoid sending an unchanged list.
    #[inspect(skip)]
    last_sent: Option<Vec<DnsServerList>>,
    /// The pending `WatchServers` call, held until there is a list that differs
    /// from `last_sent`.
    #[inspect(skip)]
    queued_responder: Option<fnp_socketproxy::DnsServerWatcherWatchServersResponder>,

    /// Number of DNS server list updates received on the update channel.
    updates_seen: u32,
    /// Number of `WatchServers` responses successfully sent.
    updates_sent: u32,
}
29
/// Server for the `fuchsia.netpol.socketproxy.DnsServerWatcher` protocol.
///
/// Both fields are behind `Arc`, so clones share the same state and the same
/// update receiver.
#[derive(Inspect, Debug, Clone)]
pub(crate) struct DnsServerWatcher {
    /// Shared watcher state. `forward` places its inspect properties directly on
    /// this watcher's inspect node. The mutex is also what limits the server to a
    /// single active connection (see `run`).
    #[inspect(forward)]
    state: Arc<Mutex<IValue<DnsServerWatcherState>>>,
    /// Receiver of new DNS server lists to report to the client.
    dns_rx: Arc<Mutex<mpsc::Receiver<Vec<fnp_socketproxy::DnsServerList>>>>,
}
38
39impl DnsServerWatcher {
40 pub(crate) fn new(
42 dns_rx: Arc<Mutex<mpsc::Receiver<Vec<fnp_socketproxy::DnsServerList>>>>,
43 ) -> Self {
44 Self { dns_rx, state: Default::default() }
45 }
46
47 pub(crate) async fn run<'a>(
49 &self,
50 stream: fnp_socketproxy::DnsServerWatcherRequestStream,
51 ) -> Result<(), Error> {
52 let mut state = match self.state.try_lock() {
53 Some(o) => o,
54 None => {
55 warn!("Only one connection to DnsServerWatcher is allowed at a time");
56 stream.control_handle().shutdown_with_epitaph(fidl::Status::ACCESS_DENIED);
57 return Ok(());
58 }
59 };
60 let mut dns_rx = self.dns_rx.lock().await;
61 info!("Starting fuchsia.netpol.socketproxy.DnsServerWatcher server");
62 let mut stream = stream.map(|result| result.context("failed request")).fuse();
63
64 loop {
65 futures::select! {
66 request = stream.try_next() => match request? {
67 Some(fnp_socketproxy::DnsServerWatcherRequest::WatchServers { responder }) => {
68 let mut state = state.as_mut();
69 if state.queued_responder.is_some() {
70 warn!("Only one call to watch server may be active at once");
71 responder
72 .control_handle()
73 .shutdown_with_epitaph(fidl::Status::ACCESS_DENIED);
74 } else {
75 state.queued_responder = Some(responder);
76 state.maybe_respond()?;
77 }
78 },
79 Some(fnp_socketproxy::DnsServerWatcherRequest::CheckPresence { responder }) => {
80 let _: Result<(), fidl::Error> = responder.send();
82 }
83 None => {}
84 },
85 dns_update = dns_rx.select_next_some() => {
86 let mut state = state.as_mut();
87 state.updates_seen += 1;
88 state.server_list = dns_update;
89 state.maybe_respond()?;
90 }
91 }
92 }
93 }
94}
95
96impl DnsServerWatcherState {
97 fn maybe_respond(&mut self) -> Result<(), Error> {
98 if self.last_sent.as_ref() != Some(&self.server_list) {
99 if let Some(responder) = self.queued_responder.take() {
100 info!("Sending DNS update to client: {}", self.server_list.len());
101 responder.send(&self.server_list)?;
102 self.updates_sent += 1;
103 self.last_sent = Some(self.server_list.clone());
104 }
105 }
106 Ok(())
107 }
108}
109
#[cfg(test)]
mod test {
    use super::*;
    use assert_matches::assert_matches;
    use diagnostics_assertions::assert_data_tree;
    use fuchsia_component::server::ServiceFs;
    use fuchsia_component_test::{
        Capability, ChildOptions, LocalComponentHandles, RealmBuilder, RealmInstance, Ref, Route,
    };
    use fuchsia_inspect_derive::WithInspect;
    use futures::channel::mpsc::{Receiver, Sender};
    use futures::SinkExt as _;
    use pretty_assertions::assert_eq;

    /// Protocols served from the local `dns_watcher` child's outgoing directory.
    enum IncomingService {
        DnsServerWatcher(fnp_socketproxy::DnsServerWatcherRequestStream),
    }

    /// Implementation of the local `dns_watcher` child component.
    ///
    /// Serves `DnsServerWatcher` from the child's outgoing directory. The backing
    /// watcher is attached to the component inspector under `dns_watcher` and is
    /// cloned for every incoming connection. Per-connection errors are printed
    /// rather than propagated.
    async fn run_registry(
        handles: LocalComponentHandles,
        dns_rx: Arc<Mutex<Receiver<Vec<fnp_socketproxy::DnsServerList>>>>,
    ) -> Result<(), Error> {
        let mut fs = ServiceFs::new();
        let _ = fs.dir("svc").add_fidl_service(IncomingService::DnsServerWatcher);
        let _ = fs.serve_connection(handles.outgoing_dir)?;

        let watcher = DnsServerWatcher::new(dns_rx)
            .with_inspect(fuchsia_inspect::component::inspector().root(), "dns_watcher")?;

        // A limit of `0` means connections are handled with no concurrency limit.
        fs.for_each_concurrent(0, |IncomingService::DnsServerWatcher(stream)| {
            let watcher = watcher.clone();
            async move {
                watcher
                    .run(stream)
                    .await
                    .context("Failed to serve request stream")
                    .unwrap_or_else(|e| eprintln!("Error encountered: {e:?}"))
            }
        })
        .await;

        Ok(())
    }

    /// Builds a realm with a local `dns_watcher` child that exposes
    /// `DnsServerWatcher` to the test.
    ///
    /// Returns the realm and the sender used to push DNS server list updates to
    /// the watcher.
    async fn setup_test(
    ) -> Result<(RealmInstance, Sender<Vec<fnp_socketproxy::DnsServerList>>), Error> {
        let builder = RealmBuilder::new().await?;
        let (dns_tx, dns_rx) = mpsc::channel(1);
        let dns_rx = Arc::new(Mutex::new(dns_rx));
        let registry = builder
            .add_local_child(
                "dns_watcher",
                {
                    let dns_rx = dns_rx.clone();
                    // The closure may be invoked more than once, so each start
                    // gets its own handle to the shared receiver.
                    move |handles: LocalComponentHandles| {
                        Box::pin(run_registry(handles, dns_rx.clone()))
                    }
                },
                ChildOptions::new(),
            )
            .await?;

        builder
            .add_route(
                Route::new()
                    .capability(Capability::protocol::<fnp_socketproxy::DnsServerWatcherMarker>())
                    .from(&registry)
                    .to(Ref::parent()),
            )
            .await?;

        let realm = builder.build().await?;

        Ok((realm, dns_tx))
    }

    /// The first `WatchServers` returns the initial (empty) list immediately.
    /// The next one completes once a new list is sent.
    #[fuchsia::test]
    async fn test_normal_operation() -> Result<(), Error> {
        let (realm, mut dns_tx) = setup_test().await?;

        let dns_server_watcher = realm
            .root
            .connect_to_protocol_at_exposed_dir::<fnp_socketproxy::DnsServerWatcherMarker>()
            .context("While connecting to DnsServerWatcher")?;

        assert_eq!(dns_server_watcher.watch_servers().await?, vec![]);

        let (send_result, watch_result) = futures::future::join(
            dns_tx.send(vec![DnsServerList { source_network_id: Some(0), ..Default::default() }]),
            dns_server_watcher.watch_servers(),
        )
        .await;

        assert_matches!(send_result, Ok(()));
        assert_eq!(
            watch_result?,
            vec![DnsServerList { source_network_id: Some(0), ..Default::default() }]
        );

        assert_data_tree!(fuchsia_inspect::component::inspector(), root: {
            dns_watcher: {
                updates_seen: 1u64,
                updates_sent: 2u64,
            },
        });

        Ok(())
    }

    /// Identical updates are counted as seen, but the list is only sent once.
    #[fuchsia::test]
    async fn test_duplicate_list() -> Result<(), Error> {
        let (realm, mut dns_tx) = setup_test().await?;
        let dns_server_watcher = realm
            .root
            .connect_to_protocol_at_exposed_dir::<fnp_socketproxy::DnsServerWatcherMarker>()
            .context("While connecting to DnsServerWatcher")?;

        assert_eq!(dns_server_watcher.watch_servers().await?, vec![]);

        let server_list = vec![DnsServerList { source_network_id: Some(0), ..Default::default() }];

        let mut dns_tx2 = dns_tx.clone();
        let mut dns_tx3 = dns_tx.clone();
        let (watch_result, s1, s2, s3) = futures::join!(
            dns_server_watcher.watch_servers(),
            dns_tx.send(server_list.clone()),
            dns_tx2.send(server_list.clone()),
            dns_tx3.send(server_list.clone()),
        );

        assert_matches!(s1, Ok(()));
        assert_matches!(s2, Ok(()));
        assert_matches!(s3, Ok(()));
        assert_eq!(watch_result?, server_list);

        let (send_result, watch_result) = futures::future::join(
            dns_tx.send(vec![DnsServerList { source_network_id: Some(1), ..Default::default() }]),
            dns_server_watcher.watch_servers(),
        )
        .await;
        assert_matches!(send_result, Ok(()));

        assert_eq!(
            watch_result?,
            vec![DnsServerList { source_network_id: Some(1), ..Default::default() }]
        );

        assert_data_tree!(fuchsia_inspect::component::inspector(), root: {
            dns_watcher: {
                updates_seen: 4u64,
                updates_sent: 3u64,
            },
        });

        Ok(())
    }

    /// A second concurrent `WatchServers` call closes the channel with
    /// `ACCESS_DENIED`, failing both pending calls.
    #[fuchsia::test]
    async fn test_duplicate_watch() -> Result<(), Error> {
        let (realm, _dns_tx) = setup_test().await?;

        let dns_server_watcher = realm
            .root
            .connect_to_protocol_at_exposed_dir::<fnp_socketproxy::DnsServerWatcherMarker>()
            .context("While connecting to DnsServerWatcher")?;

        assert_eq!(dns_server_watcher.watch_servers().await?, vec![]);

        let watch1 = dns_server_watcher.watch_servers();
        let watch2 = dns_server_watcher.watch_servers();

        assert_matches!(
            futures::future::join(watch1, watch2).await,
            (
                Err(fidl::Error::ClientChannelClosed { status: fidl::Status::ACCESS_DENIED, .. }),
                Err(fidl::Error::ClientChannelClosed { status: fidl::Status::ACCESS_DENIED, .. })
            )
        );

        assert_data_tree!(fuchsia_inspect::component::inspector(), root: {
            dns_watcher: {
                updates_seen: 0u64,
                updates_sent: 1u64,
            },
        });

        Ok(())
    }
}