// openthread_fuchsia/backing/resolver.rs

// Copyright 2024 The Fuchsia Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

use super::*;
use anyhow::Context as _;
use fidl_fuchsia_net_name::DnsServerWatcherMarker;
use fuchsia_sync::Mutex;
use openthread::ot::DnsUpstream;
use openthread_sys::*;
use std::collections::HashMap;
use std::hash::{Hash, Hasher};
use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV6};
use std::sync::Arc;
use std::task::{Context, Poll, Waker};

/// Size in bytes (2 KiB) of the scratch buffer each upstream DNS response datagram is
/// received into; bytes beyond this size are not captured.
const MAX_DNS_RESPONSE_SIZE: usize = 2 * 1024;

/// Handle for one OpenThread upstream DNS query, usable as a `HashMap` key.
///
/// Equality and hashing are by the address of the referenced `ot::PlatDnsUpstreamQuery`,
/// not by its contents, so two wrappers are equal only if they refer to the same query.
struct DnsUpstreamQueryRefWrapper(&'static ot::PlatDnsUpstreamQuery);

21impl DnsUpstreamQueryRefWrapper {
22    fn as_ptr(&self) -> *const ot::PlatDnsUpstreamQuery {
23        std::ptr::from_ref(self.0)
24    }
25}
26
27impl PartialEq for DnsUpstreamQueryRefWrapper {
28    fn eq(&self, other: &Self) -> bool {
29        self.as_ptr().eq(&other.as_ptr())
30    }
31}
32
// Pointer-address equality is reflexive, symmetric and transitive, so the full `Eq`
// contract holds (required for use as a `HashMap` key).
impl Eq for DnsUpstreamQueryRefWrapper {}

35impl Hash for DnsUpstreamQueryRefWrapper {
36    fn hash<H>(&self, state: &mut H)
37    where
38        H: Hasher,
39    {
40        self.as_ptr().hash(state);
41    }
42}
43
/// State kept for one in-flight upstream DNS query.
struct Transaction {
    // The task that polls the socket and forwards the DNS reply from it.
    // This field is set but never accessed directly, so we need to silence this warning
    // so that we can compile. It is stored so the task lives as long as the transaction.
    // NOTE(review): presumably dropping a `fasync::Task` cancels it — confirm.
    #[allow(unused)]
    task: fasync::Task<Result<(), anyhow::Error>>,
    // Receives the DNS reply from `task`, which holds the corresponding sender.
    receiver: fmpsc::UnboundedReceiver<(DnsUpstreamQueryRefWrapper, Vec<u8>)>,
}

/// Local copy of the system DNS server list, kept up to date by a watcher task.
struct LocalDnsServerList {
    // The most recently received DNS server list; replaced wholesale on every update.
    dns_server_list: Vec<fidl_fuchsia_net_name::DnsServer_>,
    // The task that awaits DNS server list changes.
    // This field is set but never accessed directly, so we need to silence this warning
    // so that we can compile.
    #[allow(unused)]
    task: fasync::Task<Result<(), anyhow::Error>>,
    // Receives DNS server lists from `task`, which holds the corresponding sender.
    receiver: fmpsc::UnboundedReceiver<Vec<fidl_fuchsia_net_name::DnsServer_>>,
}

/// Platform backing for OpenThread's upstream DNS query hooks
/// (`otPlatDnsStartUpstreamQuery` / `otPlatDnsCancelUpstreamQuery`).
pub(crate) struct Resolver {
    // Pending queries, keyed by OpenThread's query handle so that a reply or a
    // cancellation can quickly locate its `Transaction`.
    transactions_map: Arc<Mutex<HashMap<DnsUpstreamQueryRefWrapper, Transaction>>>,
    // Local DNS server list used to send out upstream queries immediately. `None` when the
    // DNS server watcher could not be connected; upstream queries are then ignored.
    local_dns_record: RefCell<Option<LocalDnsServerList>>,
    // Waker from the most recent `process_poll_resolver` call; woken when a new query is
    // started so the resolver gets polled again.
    waker: Cell<Option<Waker>>,
}

74impl Resolver {
75    pub fn new() -> Resolver {
76        if let Ok(proxy) =
77            fuchsia_component::client::connect_to_protocol::<DnsServerWatcherMarker>()
78        {
79            let (mut sender, receiver) = fmpsc::unbounded();
80
81            // Create a future that await for the latest DNS server list, and forward it to the
82            // corresponding receiver. The future is executed in the task in `LocalDnsServerList`.
83            let dns_list_watcher_fut = async move {
84                loop {
85                    let vec = proxy.watch_servers().await?;
86                    info!(tag = "resolver"; "getting latest DNS server list: {:?}", vec);
87                    if let Err(e) = sender.send(vec).await {
88                        warn!(
89                            tag = "resolver";
90                            "error when sending out latest dns list to process_poll, {:?}", e
91                        );
92                    }
93                }
94            };
95            Resolver {
96                transactions_map: Default::default(),
97                waker: Default::default(),
98                local_dns_record: RefCell::new(Some(LocalDnsServerList {
99                    dns_server_list: Vec::new(),
100                    task: fuchsia_async::Task::spawn(dns_list_watcher_fut),
101                    receiver,
102                })),
103            }
104        } else {
105            warn!(
106                tag = "resolver";
107                "failed to connect to `DnsServerWatcherMarker`, \
108                         DNS upstream query will not be supported"
109            );
110            Resolver {
111                transactions_map: Arc::new(Mutex::new(HashMap::new())),
112                waker: Cell::new(None),
113                local_dns_record: RefCell::new(None),
114            }
115        }
116    }
117
118    pub fn process_poll_resolver(&self, instance: &ot::Instance, cx: &mut Context<'_>) {
119        // Update the waker so that we can later signal when we need to be polled again
120        self.waker.replace(Some(cx.waker().clone()));
121
122        // Poll the DNS server list task
123        if let Some(local_dns_record) = self.local_dns_record.borrow_mut().as_mut() {
124            while let Poll::Ready(Some(dns_server_list)) =
125                local_dns_record.receiver.poll_next_unpin(cx)
126            {
127                // DNS server watcher proxy returns the new DNS server list when something changed
128                // in netstack. The outdated list should be replaced by the new one.
129                local_dns_record.dns_server_list = dns_server_list;
130            }
131        }
132
133        let mut remove_key_vec = Vec::new();
134        // Poll the socket in each transaction. If a response is ready, forward it to the OpenThread
135        // and remove the corresponding transaction.
136        for (_, transaction) in self.transactions_map.lock().iter_mut() {
137            while let Poll::Ready(Some((context, message_vec))) =
138                transaction.receiver.poll_next_unpin(cx)
139            {
140                if let Ok(mut message) =
141                    ot::Message::udp_new(instance, None).context("cannot create UDP message")
142                {
143                    match message.append(&message_vec) {
144                        Ok(_) => {
145                            instance.plat_dns_upstream_query_done(context.0, message);
146                        }
147                        Err(e) => {
148                            warn!(tag = "resolver"; "failed to append to `ot::Message`: {}", e);
149                        }
150                    }
151                } else {
152                    warn!(
153                        tag = "resolver";
154                        "failed to create `ot::Message`, drop the upstream DNS response"
155                    );
156                }
157                remove_key_vec.push(context);
158            }
159        }
160
161        // cancel the transaction
162        for key in remove_key_vec {
163            self.transactions_map.lock().remove(&key);
164        }
165    }
166
167    fn on_start_dns_upstream_query<'a>(
168        &self,
169        _instance: &ot::Instance,
170        thread_context: &'static ot::PlatDnsUpstreamQuery,
171        dns_query: &ot::Message<'_>,
172    ) {
173        let sockaddr = SocketAddr::new(Ipv6Addr::UNSPECIFIED.into(), 53);
174        let socket = match fuchsia_async::net::UdpSocket::bind(&sockaddr) {
175            Ok(socket) => socket,
176            Err(_) => {
177                warn!(
178                    tag = "resolver";
179                    "on_start_dns_upstream_query() failed to create UDP socket, ignoring the query"
180                );
181                return;
182            }
183        };
184
185        let query_bytes = dns_query.to_vec();
186
187        // Get the DNS server list, and send out the query to all the available DNS servers.
188        if let Some(local_dns_record) = self.local_dns_record.borrow().as_ref() {
189            for dns_server in &local_dns_record.dns_server_list {
190                if let Some(address) = dns_server.address {
191                    match address {
192                        fidl_fuchsia_net::SocketAddress::Ipv4(ipv4_sock_addr) => {
193                            let sock_addr = SocketAddr::new(
194                                std::net::IpAddr::V4(std::net::Ipv4Addr::from(
195                                    ipv4_sock_addr.address.addr,
196                                )),
197                                ipv4_sock_addr.port,
198                            );
199                            info!(
200                                tag = "resolver";
201                                "sending DNS query to IPv4 server {}", sock_addr
202                            );
203                            if let Some(Err(e)) =
204                                socket.send_to(&query_bytes, sock_addr).now_or_never()
205                            {
206                                warn!(
207                                    tag = "resolver";
208                                    "Failed to send DNS query to IPv4 server {}: {}", sock_addr, e
209                                );
210                            }
211                        }
212                        fidl_fuchsia_net::SocketAddress::Ipv6(ipv6_sock_addr) => {
213                            let sock_addr = SocketAddr::new(
214                                std::net::IpAddr::V6(std::net::Ipv6Addr::from(
215                                    ipv6_sock_addr.address.addr,
216                                )),
217                                ipv6_sock_addr.port,
218                            );
219
220                            info!(
221                                tag = "resolver";
222                                "sending DNS query to IPv6 server {}", sock_addr
223                            );
224                            if let Some(Err(e)) =
225                                socket.send_to(&query_bytes, sock_addr).now_or_never()
226                            {
227                                warn!(
228                                    tag = "resolver";
229                                    "Failed to send DNS query to IPv6 server {}: {}", sock_addr, e
230                                );
231                            }
232                        }
233                    }
234                }
235            }
236
237            let (mut sender, receiver) = fmpsc::unbounded();
238
239            // Create a poll_fn for the socket that can be await on
240            let receive_from_fut = futures::future::poll_fn(move |cx| {
241                let mut buffer = [0u8; MAX_DNS_RESPONSE_SIZE];
242                match socket.async_recv_from(&mut buffer, cx) {
243                    Poll::Ready(Ok((len, sockaddr))) => {
244                        let message = buffer[..len].to_vec();
245                        Poll::Ready(Ok((message, sockaddr)))
246                    }
247                    Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
248                    Poll::Pending => Poll::Pending,
249                }
250            });
251
252            // Create a future that forward the DNS reply from socket to process_poll
253            let fut = async move {
254                let (message_vec, sockaddr) =
255                    receive_from_fut.await.context("error receiving from dns upstream socket")?;
256
257                info!(
258                    tag = "resolver";
259                    "Incoming {} bytes DNS response from {:?}",
260                    message_vec.len(),
261                    sockaddr
262                );
263                if let Err(e) =
264                    sender.send((DnsUpstreamQueryRefWrapper(thread_context), message_vec)).await
265                {
266                    warn!(
267                        tag = "resolver";
268                        "error when sending out dns upstream reply to process_poll, {:?}", e
269                    );
270                }
271                Ok(())
272            };
273
274            // Socket and the sender is owned by the task now
275            let transaction = Transaction { task: fuchsia_async::Task::spawn(fut), receiver };
276
277            self.transactions_map
278                .lock()
279                .insert(DnsUpstreamQueryRefWrapper(thread_context), transaction);
280        } else {
281            warn!(
282                tag = "resolver";
283                "on_start_dns_upstream_query() failed to get local_dns_record, ignoring the query"
284            );
285        }
286
287        // Trigger the waker so that our poll method gets called by the executor
288        self.waker.replace(None).and_then(|waker| {
289            waker.wake();
290            Some(())
291        });
292    }
293
294    // Cancel the pending query
295    fn on_cancel_dns_upstream_query(
296        &self,
297        _instance: &ot::Instance,
298        thread_context: &'static ot::PlatDnsUpstreamQuery,
299    ) {
300        if let None =
301            self.transactions_map.lock().remove(&DnsUpstreamQueryRefWrapper(thread_context))
302        {
303            warn!(
304                tag = "resolver";
305                "on_cancel_dns_upstream_query() target transaction not presented for remove, ignoring"
306            );
307        }
308    }
309}
310
311#[no_mangle]
312unsafe extern "C" fn otPlatDnsStartUpstreamQuery(
313    a_instance: *mut otInstance,
314    a_txn: *mut otPlatDnsUpstreamQuery,
315    a_query: *const otMessage,
316) {
317    Resolver::on_start_dns_upstream_query(
318        &PlatformBacking::as_ref().resolver,
319        // SAFETY: `instance` must be a pointer to a valid `otInstance`,
320        //         which is guaranteed by the caller.
321        ot::Instance::ref_from_ot_ptr(a_instance).unwrap(),
322        // SAFETY: no dereference is happening in fuchsia platform side
323        ot::PlatDnsUpstreamQuery::mut_from_ot_mut_ptr(a_txn).unwrap(),
324        // SAFETY: caller ensures the dns query is valid
325        ot::Message::ref_from_ot_ptr(a_query as *mut otMessage).unwrap(),
326    )
327}
328
329#[no_mangle]
330unsafe extern "C" fn otPlatDnsCancelUpstreamQuery(
331    a_instance: *mut otInstance,
332    a_txn: *mut otPlatDnsUpstreamQuery,
333) {
334    Resolver::on_cancel_dns_upstream_query(
335        &PlatformBacking::as_ref().resolver,
336        // SAFETY: `instance` must be a pointer to a valid `otInstance`,
337        //         which is guaranteed by the caller.
338        ot::Instance::ref_from_ot_ptr(a_instance).unwrap(),
339        // SAFETY: no dereference is happening in fuchsia platform side
340        ot::PlatDnsUpstreamQuery::mut_from_ot_mut_ptr(a_txn).unwrap(),
341    )
342}