1mod stream;
8#[cfg(test)]
9mod test_util;
10
11use std::cmp::Ordering;
12use std::collections::{HashMap, HashSet};
13
14use fidl_fuchsia_net::SocketAddress;
15use fidl_fuchsia_net_name::{
16 DhcpDnsServerSource, Dhcpv6DnsServerSource, DnsServerSource, DnsServer_, NdpDnsServerSource,
17 SocketProxyDnsServerSource, StaticDnsServerSource,
18};
19
20pub use self::stream::*;
21
/// The well-known port DNS servers listen on (RFC 1035, section 4.2).
pub const DEFAULT_DNS_PORT: u16 = 53;
24
25#[derive(Default)]
27pub struct DnsServers {
28 default: Vec<DnsServer_>,
32
33 netstack: Vec<DnsServer_>,
35
36 dhcpv4: HashMap<u64, Vec<DnsServer_>>,
38
39 dhcpv6: HashMap<u64, Vec<DnsServer_>>,
41
42 ndp: HashMap<u64, Vec<DnsServer_>>,
44
45 socketproxy: Vec<DnsServer_>,
47}
48
49impl DnsServers {
50 pub fn set_servers_from_source(
54 &mut self,
55 source: DnsServersUpdateSource,
56 servers: Vec<DnsServer_>,
57 ) {
58 let Self { default, netstack, dhcpv4, dhcpv6, ndp, socketproxy } = self;
59
60 match source {
61 DnsServersUpdateSource::Default => *default = servers,
62 DnsServersUpdateSource::Netstack => *netstack = servers,
63 DnsServersUpdateSource::Dhcpv4 { interface_id } => {
64 let _: Option<Vec<DnsServer_>> = if servers.is_empty() {
67 dhcpv4.remove(&interface_id)
68 } else {
69 dhcpv4.insert(interface_id, servers)
70 };
71 }
72 DnsServersUpdateSource::Dhcpv6 { interface_id } => {
73 let _: Option<Vec<DnsServer_>> = if servers.is_empty() {
76 dhcpv6.remove(&interface_id)
77 } else {
78 dhcpv6.insert(interface_id, servers)
79 };
80 }
81 DnsServersUpdateSource::Ndp { interface_id } => {
82 let _: Option<Vec<DnsServer_>> = if servers.is_empty() {
85 ndp.remove(&interface_id)
86 } else {
87 ndp.insert(interface_id, servers)
88 };
89 }
90 DnsServersUpdateSource::SocketProxy => *socketproxy = servers,
91 }
92 }
93
94 pub fn consolidated(&self) -> Vec<SocketAddress> {
115 self.consolidate_filter_map(|x| x.address)
116 }
117
118 pub fn consolidated_dns_servers(&self) -> Vec<DnsServer_> {
137 self.consolidate_filter_map(|x| Some(x))
138 }
139
140 fn consolidate_filter_map<T, F: Fn(DnsServer_) -> Option<T>>(&self, f: F) -> Vec<T> {
144 let Self { default, netstack, dhcpv4, dhcpv6, ndp, socketproxy } = self;
145 let mut servers = netstack
146 .iter()
147 .chain(socketproxy)
148 .chain(dhcpv4.values().flatten())
149 .chain(ndp.values().flatten())
150 .chain(dhcpv6.values().flatten())
151 .cloned()
152 .collect::<Vec<_>>();
153 let () = servers.sort_by(Self::ordering);
159 let () = servers.extend(default.clone());
162 let mut addresses = HashSet::new();
163 let () = servers.retain(move |s| addresses.insert(s.address));
164 servers.into_iter().filter_map(f).collect()
165 }
166
167 fn ordering(a: &DnsServer_, b: &DnsServer_) -> Ordering {
183 let ordering = |source| match source {
184 Some(&DnsServerSource::SocketProxy(SocketProxyDnsServerSource {
185 source_interface: _,
186 ..
187 })) => 0,
188 Some(&DnsServerSource::Dhcp(DhcpDnsServerSource { source_interface: _, .. })) => 1,
189 Some(&DnsServerSource::Ndp(NdpDnsServerSource { source_interface: _, .. })) => 2,
190 Some(&DnsServerSource::Dhcpv6(Dhcpv6DnsServerSource {
191 source_interface: _, ..
192 })) => 3,
193 Some(&DnsServerSource::StaticSource(StaticDnsServerSource { .. })) => 4,
194 Some(&DnsServerSource::__SourceBreaking { .. }) | None => 5,
195 };
196 let a = ordering(a.source.as_ref());
197 let b = ordering(b.source.as_ref());
198 std::cmp::Ord::cmp(&a, &b)
199 }
200}
201
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_util::constants::*;

    #[test]
    fn deduplicate_within_source() {
        let servers = DnsServers {
            default: vec![ndp_server(), ndp_server()],
            netstack: vec![ndp_server(), static_server(), ndp_server(), static_server()],
            dhcpv4: [
                (DHCPV4_SERVER1_INTERFACE_ID, vec![dhcpv4_server1(), dhcpv4_server2()]),
                (DHCPV4_SERVER2_INTERFACE_ID, vec![dhcpv4_server1(), dhcpv4_server2()]),
            ]
            .into_iter()
            .collect(),
            dhcpv6: [
                (DHCPV6_SERVER1_INTERFACE_ID, vec![dhcpv6_server1(), dhcpv6_server2()]),
                (DHCPV6_SERVER2_INTERFACE_ID, vec![dhcpv6_server1(), dhcpv6_server2()]),
            ]
            .into_iter()
            .collect(),
            ndp: [(NDP_SERVER_INTERFACE_ID, vec![ndp_server(), ndp_server()])]
                .into_iter()
                .collect(),
            socketproxy: vec![socketproxy_server1(), socketproxy_server2(), socketproxy_server1()],
        };
        assert_eq!(
            servers.consolidated(),
            vec![
                SOCKETPROXY_SOURCE_SOCKADDR1,
                SOCKETPROXY_SOURCE_SOCKADDR2,
                DHCPV4_SOURCE_SOCKADDR1,
                DHCPV4_SOURCE_SOCKADDR2,
                NDP_SOURCE_SOCKADDR,
                DHCPV6_SOURCE_SOCKADDR1,
                DHCPV6_SOURCE_SOCKADDR2,
                STATIC_SOURCE_SOCKADDR,
            ],
        );
    }

    #[test]
    fn default_low_prio() {
        let servers = DnsServers {
            default: vec![
                static_server(),
                dhcpv4_server1(),
                dhcpv6_server1(),
                socketproxy_server1(),
            ],
            netstack: vec![static_server()],
            dhcpv4: [
                (DHCPV4_SERVER1_INTERFACE_ID, vec![dhcpv4_server1()]),
                (DHCPV4_SERVER2_INTERFACE_ID, vec![dhcpv4_server1()]),
            ]
            .into_iter()
            .collect(),
            dhcpv6: [
                (DHCPV6_SERVER1_INTERFACE_ID, vec![dhcpv6_server1()]),
                (DHCPV6_SERVER2_INTERFACE_ID, vec![dhcpv6_server2()]),
            ]
            .into_iter()
            .collect(),
            ndp: [(NDP_SERVER_INTERFACE_ID, vec![ndp_server()])].into_iter().collect(),
            socketproxy: vec![socketproxy_server1()],
        };
        let mut got = servers.consolidated();
        // Servers from different interfaces of the same source have no
        // guaranteed relative order, so each source is compared as a set.
        let mut got = got.drain(..);
        let want_socketproxy = [SOCKETPROXY_SOURCE_SOCKADDR1];
        assert_eq!(
            HashSet::from_iter(got.by_ref().take(want_socketproxy.len())),
            HashSet::from(want_socketproxy),
        );

        let want_dhcpv4 = [DHCPV4_SOURCE_SOCKADDR1];
        assert_eq!(
            HashSet::from_iter(got.by_ref().take(want_dhcpv4.len())),
            HashSet::from(want_dhcpv4),
        );

        let want_ndp = [NDP_SOURCE_SOCKADDR];
        assert_eq!(HashSet::from_iter(got.by_ref().take(want_ndp.len())), HashSet::from(want_ndp));

        let want_dhcpv6 = [DHCPV6_SOURCE_SOCKADDR1, DHCPV6_SOURCE_SOCKADDR2];
        assert_eq!(
            HashSet::from_iter(got.by_ref().take(want_dhcpv6.len())),
            HashSet::from(want_dhcpv6),
        );

        let want_rest = [STATIC_SOURCE_SOCKADDR];
        assert_eq!(got.as_slice(), want_rest);
    }

    #[test]
    fn deduplicate_across_sources() {
        let dhcpv6_with_ndp_address = || DnsServer_ {
            address: Some(NDP_SOURCE_SOCKADDR),
            source: Some(DnsServerSource::Dhcpv6(Dhcpv6DnsServerSource {
                source_interface: Some(DHCPV6_SERVER1_INTERFACE_ID),
                ..Default::default()
            })),
            ..Default::default()
        };
        let mut dhcpv6 = HashMap::new();
        assert_matches::assert_matches!(
            dhcpv6.insert(
                DHCPV6_SERVER1_INTERFACE_ID,
                vec![dhcpv6_with_ndp_address(), dhcpv6_server1()]
            ),
            None
        );
        let mut servers = DnsServers {
            default: vec![],
            netstack: vec![dhcpv4_server1(), static_server()],
            dhcpv4: [(DHCPV4_SERVER1_INTERFACE_ID, vec![dhcpv4_server1()])].into_iter().collect(),
            dhcpv6,
            ndp: [(NDP_SERVER_INTERFACE_ID, vec![ndp_server(), dhcpv6_with_ndp_address()])]
                .into_iter()
                .collect(),
            socketproxy: vec![],
        };
        let expected_servers =
            vec![dhcpv4_server1(), ndp_server(), dhcpv6_server1(), static_server()];
        assert_eq!(servers.consolidate_filter_map(Some), expected_servers);
        let expected_sockaddrs = vec![
            DHCPV4_SOURCE_SOCKADDR1,
            NDP_SOURCE_SOCKADDR,
            DHCPV6_SOURCE_SOCKADDR1,
            STATIC_SOURCE_SOCKADDR,
        ];
        assert_eq!(servers.consolidated(), expected_sockaddrs);
        servers.netstack = vec![dhcpv4_server1(), static_server(), dhcpv6_with_ndp_address()];
        assert_eq!(servers.consolidate_filter_map(Some), expected_servers);
        assert_eq!(servers.consolidated(), expected_sockaddrs);

        let ndp_with_dhcpv6_sockaddr1 = || DnsServer_ {
            address: Some(DHCPV6_SOURCE_SOCKADDR1),
            source: Some(DnsServerSource::Ndp(NdpDnsServerSource {
                source_interface: Some(NDP_SERVER_INTERFACE_ID),
                ..Default::default()
            })),
            ..Default::default()
        };

        let mut dhcpv6 = HashMap::new();
        assert_matches::assert_matches!(
            dhcpv6.insert(DHCPV6_SERVER1_INTERFACE_ID, vec![dhcpv6_server1()]),
            None
        );
        assert_matches::assert_matches!(
            dhcpv6.insert(DHCPV6_SERVER2_INTERFACE_ID, vec![dhcpv6_server2()]),
            None
        );
        let mut servers = DnsServers {
            default: vec![],
            netstack: vec![static_server()],
            dhcpv4: Default::default(),
            dhcpv6,
            ndp: [(NDP_SERVER_INTERFACE_ID, vec![ndp_with_dhcpv6_sockaddr1()])]
                .into_iter()
                .collect(),
            socketproxy: vec![],
        };
        let expected_servers = vec![ndp_with_dhcpv6_sockaddr1(), dhcpv6_server2(), static_server()];
        assert_eq!(servers.consolidate_filter_map(Some), expected_servers);
        let expected_sockaddrs =
            vec![DHCPV6_SOURCE_SOCKADDR1, DHCPV6_SOURCE_SOCKADDR2, STATIC_SOURCE_SOCKADDR];
        assert_eq!(servers.consolidated(), expected_sockaddrs);
        servers.netstack = vec![static_server(), ndp_with_dhcpv6_sockaddr1()];
        assert_eq!(servers.consolidate_filter_map(Some), expected_servers);
        assert_eq!(servers.consolidated(), expected_sockaddrs);
    }

    #[test]
    fn set_servers_from_source_replaces_and_removes_interface_entries() {
        let mut servers = DnsServers::default();
        let source =
            || DnsServersUpdateSource::Dhcpv4 { interface_id: DHCPV4_SERVER1_INTERFACE_ID };

        let () =
            servers.set_servers_from_source(source(), vec![dhcpv4_server1(), dhcpv4_server2()]);
        assert_eq!(servers.consolidated(), vec![DHCPV4_SOURCE_SOCKADDR1, DHCPV4_SOURCE_SOCKADDR2]);

        // A non-empty update replaces the interface's previous servers.
        let () = servers.set_servers_from_source(source(), vec![dhcpv4_server2()]);
        assert_eq!(servers.consolidated(), vec![DHCPV4_SOURCE_SOCKADDR2]);

        // An empty update removes the interface's entry instead of storing an
        // empty list.
        let () = servers.set_servers_from_source(source(), vec![]);
        assert!(servers.dhcpv4.is_empty());
        assert!(servers.consolidated().is_empty());
    }

    #[test]
    fn test_dns_servers_ordering() {
        assert_eq!(
            DnsServers::ordering(&socketproxy_server1(), &socketproxy_server1()),
            Ordering::Equal
        );
        assert_eq!(DnsServers::ordering(&ndp_server(), &ndp_server()), Ordering::Equal);
        assert_eq!(DnsServers::ordering(&dhcpv4_server1(), &dhcpv4_server1()), Ordering::Equal);
        assert_eq!(DnsServers::ordering(&dhcpv6_server1(), &dhcpv6_server1()), Ordering::Equal);
        assert_eq!(DnsServers::ordering(&static_server(), &static_server()), Ordering::Equal);
        assert_eq!(
            DnsServers::ordering(&unspecified_source_server(), &unspecified_source_server()),
            Ordering::Equal
        );

        let mut servers = vec![
            unspecified_source_server(),
            dhcpv6_server1(),
            dhcpv4_server1(),
            static_server(),
            ndp_server(),
            socketproxy_server1(),
        ];
        servers.sort_by(DnsServers::ordering);
        assert_eq!(
            servers,
            vec![
                socketproxy_server1(),
                dhcpv4_server1(),
                ndp_server(),
                dhcpv6_server1(),
                static_server(),
                unspecified_source_server(),
            ]
        );
    }
}