openthread_fuchsia/backing/trel.rs

1// Copyright 2022 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use super::*;
6use crate::to_escaped_string::*;
7use anyhow::Context as _;
8use fidl::endpoints::create_endpoints;
9use fidl_fuchsia_net_mdns::*;
10use fuchsia_async::Task;
11use futures::stream::FusedStream;
12use openthread_sys::*;
13use ot::{PlatTrel as _, TrelCounters};
14use std::collections::HashMap;
15use std::net::{Ipv6Addr, SocketAddr, SocketAddrV6};
16use std::task::{Context, Poll};
17
18pub(crate) struct TrelInstance {
19    socket: fasync::net::UdpSocket,
20    publication_responder: Option<Task<Result<(), anyhow::Error>>>,
21    instance_name: String,
22    peer_instance_sockaddr_map: HashMap<String, ot::SockAddr>,
23
24    #[allow(dead_code)] // This field must be kept around for <https://fxbug.dev/42182233>
25    subscriber: ServiceSubscriber2Proxy,
26
27    subscriber_request_stream: ServiceSubscriptionListenerRequestStream,
28
29    counters: RefCell<TrelCounters>,
30}
31
/// Converts an optional list of TXT strings into a single DNS-compatible TXT record.
///
/// Each string is emitted as a one-byte length prefix followed by its bytes. Strings longer
/// than 255 bytes are truncated to 255 bytes, the most a one-byte prefix can describe.
/// `None` and an empty list both produce an empty record.
fn flatten_txt(txt: Option<Vec<Vec<u8>>>) -> Vec<u8> {
    // Largest length a single-byte length prefix can encode.
    const MAX_ENTRY_LEN: usize = u8::MAX as usize;

    let mut ret = vec![];

    for entry in txt.iter().flatten() {
        // Clamp to 255 bytes so the length always fits in the `u8` prefix below.
        let entry = &entry[..entry.len().min(MAX_ENTRY_LEN)];
        ret.push(u8::try_from(entry.len()).expect("entry length was clamped to u8::MAX"));
        ret.extend_from_slice(entry);
    }

    ret
}
47
48/// Converts an iterator over [`fidl_fuchsia_net::SocketAddress`]es to a vector of
49/// [`ot::Ip6Address`]es and a port.
50fn process_addresses_from_socket_addresses<
51    T: IntoIterator<Item = fidl_fuchsia_net::SocketAddress>,
52>(
53    addresses: T,
54) -> (Vec<ot::Ip6Address>, Option<u16>) {
55    let mut ret_port: Option<u16> = None;
56    let mut addresses =
57        addresses
58            .into_iter()
59            .flat_map(|x| {
60                if let fidl_fuchsia_net::SocketAddress::Ipv6(
61                    fidl_fuchsia_net::Ipv6SocketAddress { address, port, .. },
62                ) = x
63                {
64                    let addr = ot::Ip6Address::from(address.addr);
65                    if ret_port.is_none() {
66                        ret_port = Some(port);
67                    } else if ret_port != Some(port) {
68                        warn!(
69                            tag = "trel";
70                            "mDNS service has multiple ports for the same service, {:?} != {:?}",
71                            ret_port.unwrap(),
72                            port
73                        );
74                    }
75                    if !ipv6addr_is_unicast_link_local(&addr) {
76                        return Some(addr);
77                    }
78                }
79                None
80            })
81            .collect::<Vec<_>>();
82    addresses.sort();
83    (addresses, ret_port)
84}
85
/// Returns `true` if `addr` is a unicast address with link-local scope (`fe80::/10`).
///
/// This mirrors [`std::net::Ipv6Addr::is_unicast_link_local()`], which was
/// [experimental](https://github.com/rust-lang/rust/issues/27709) when this helper was written.
/// NOTE(review): newer toolchains may offer it as stable; consider switching when possible.
fn ipv6addr_is_unicast_link_local(addr: &std::net::Ipv6Addr) -> bool {
    // fe80::/10 means the first byte is 0xfe and the top two bits of the second are `10`.
    let [hi, lo, ..] = addr.octets();
    (hi == 0xfe) && ((lo & 0xc0) == 0x80)
}
93
94// Splits the TXT record into individual values.
95fn split_txt(txt: &[u8]) -> Vec<Vec<u8>> {
96    info!(tag = "trel"; "trel:split_txt: Splitting TXT record: {:?}", hex::encode(txt));
97    let txt =
98        ot::DnsTxtEntryIterator::try_new(txt).expect("can't parse TXT records from OpenThread");
99    txt.map(|x| x.expect("can't parse TXT records from OpenThread").to_vec()).collect::<Vec<_>>()
100}
101
102impl TrelInstance {
103    fn new(instance_name: String) -> Result<TrelInstance, anyhow::Error> {
104        let (client, server) = create_endpoints::<ServiceSubscriptionListenerMarker>();
105
106        let subscriber =
107            fuchsia_component::client::connect_to_protocol::<ServiceSubscriber2Marker>().unwrap();
108
109        subscriber
110            .subscribe_to_service(
111                ot::TREL_DNSSD_SERVICE_NAME_WITH_DOT,
112                &ServiceSubscriptionOptions { exclude_local: Some(true), ..Default::default() },
113                client,
114            )
115            .context("Unable to subscribe to TREL services")?;
116
117        Ok(TrelInstance {
118            socket: fasync::net::UdpSocket::bind(&SocketAddr::V6(SocketAddrV6::new(
119                Ipv6Addr::UNSPECIFIED,
120                0,
121                0,
122                0,
123            )))
124            .context("Unable to open TREL UDP socket")?,
125            publication_responder: None,
126            instance_name,
127            peer_instance_sockaddr_map: HashMap::default(),
128            subscriber,
129            subscriber_request_stream: server.into_stream(),
130            counters: RefCell::new(TrelCounters::default()),
131        })
132    }
133
134    fn port(&self) -> u16 {
135        self.socket.local_addr().unwrap().port()
136    }
137
138    fn register_service(&mut self, port: u16, txt: &[u8]) {
139        let txt = split_txt(txt);
140
141        let (client, server) = create_endpoints::<ServiceInstancePublicationResponder_Marker>();
142
143        let publisher =
144            fuchsia_component::client::connect_to_protocol::<ServiceInstancePublisherMarker>()
145                .unwrap();
146
147        let publish_init_future = publisher
148            .publish_service_instance(
149                ot::TREL_DNSSD_SERVICE_NAME_WITH_DOT,
150                self.instance_name.as_str(),
151                &ServiceInstancePublicationOptions::default(),
152                client,
153            )
154            .map(|x| -> Result<(), anyhow::Error> {
155                match x {
156                    Ok(Ok(x)) => Ok(x),
157                    Ok(Err(err)) => Err(anyhow::format_err!("{:?}", err)),
158                    Err(zx_err) => Err(zx_err.into()),
159                }
160            });
161
162        let publish_responder_future = server.into_stream().map_err(Into::into).try_for_each(
163            move |ServiceInstancePublicationResponder_Request::OnPublication {
164                      responder, ..
165                  }| {
166                let txt = txt.clone();
167                let _publisher = publisher.clone();
168                async move {
169                    responder
170                        .send(Ok(&ServiceInstancePublication {
171                            port: Some(port),
172                            text: Some(txt),
173                            ..Default::default()
174                        }))
175                        .map_err(Into::into)
176                }
177            },
178        );
179
180        let future =
181            futures::future::try_join(publish_init_future, publish_responder_future).map_ok(|_| ());
182
183        self.publication_responder = Some(fuchsia_async::Task::spawn(future));
184    }
185
186    pub fn handle_service_subscriber_request(
187        &mut self,
188        ot_instance: &ot::Instance,
189        service_subscriber_request: ServiceSubscriptionListenerRequest,
190    ) -> Result<(), anyhow::Error> {
191        match service_subscriber_request {
192            // A DNS-SD IPv6 service instance has been discovered.
193            ServiceSubscriptionListenerRequest::OnInstanceDiscovered {
194                instance:
195                    ServiceInstance {
196                        instance: Some(instance_name),
197                        addresses: Some(addresses),
198                        text_strings,
199                        ..
200                    },
201                responder,
202            } => {
203                let txt = flatten_txt(text_strings);
204
205                let (addresses, port) = process_addresses_from_socket_addresses(addresses);
206
207                info!(
208                    tag = "trel";
209                    "ServiceSubscriptionListenerRequest::OnInstanceDiscovered: [PII]({instance_name:?}) port:{port:?} addresses:{addresses:?}"
210                );
211
212                if let Some(address) = addresses.first() {
213                    let sockaddr = ot::SockAddr::new(*address, port.unwrap());
214
215                    self.peer_instance_sockaddr_map.insert(instance_name, sockaddr);
216
217                    let info = ot::PlatTrelPeerInfo::new(false, &txt, sockaddr);
218                    info!(tag = "trel"; "otPlatTrelHandleDiscoveredPeerInfo: Adding {:?}", info);
219                    ot_instance.plat_trel_handle_discovered_peer_info(&info);
220                };
221
222                responder.send().context("Unable to respond to OnInstanceDiscovered")?;
223            }
224
225            // A DNS-SD IPv6 service instance has changed.
226            ServiceSubscriptionListenerRequest::OnInstanceChanged {
227                instance:
228                    ServiceInstance {
229                        instance: Some(instance_name),
230                        addresses: Some(addresses),
231                        text_strings,
232                        ..
233                    },
234                responder,
235            } => {
236                let txt = flatten_txt(text_strings);
237                let (addresses, port) = process_addresses_from_socket_addresses(addresses);
238
239                info!(
240                    tag = "trel";
241                    "ServiceSubscriptionListenerRequest::OnInstanceChanged: [PII]({instance_name:?}) port:{port:?} addresses:{addresses:?}"
242                );
243
244                if let Some(address) = addresses.first() {
245                    let sockaddr = ot::SockAddr::new(*address, port.unwrap());
246
247                    if let Some(old_sockaddr) =
248                        self.peer_instance_sockaddr_map.insert(instance_name, sockaddr)
249                    {
250                        if old_sockaddr != sockaddr {
251                            // Remove old sockaddr with the same instance name
252                            let info_old = ot::PlatTrelPeerInfo::new(true, &[], old_sockaddr);
253                            info!(
254                                tag = "trel";
255                                "otPlatTrelHandleDiscoveredPeerInfo: Removing {:?}", info_old
256                            );
257                            ot_instance.plat_trel_handle_discovered_peer_info(&info_old);
258                        }
259
260                        let info = ot::PlatTrelPeerInfo::new(false, &txt, sockaddr);
261                        info!(
262                            tag = "trel";
263                            "otPlatTrelHandleDiscoveredPeerInfo: Updating {:?}", info
264                        );
265                        ot_instance.plat_trel_handle_discovered_peer_info(&info);
266                    }
267                };
268
269                responder.send().context("Unable to respond to OnInstanceChanged")?;
270            }
271
272            // A DNS-SD IPv6 service instance has been lost.
273            ServiceSubscriptionListenerRequest::OnInstanceLost { instance, responder, .. } => {
274                info!(
275                    tag = "trel";
276                    "ServiceSubscriptionListenerRequest::OnInstanceLost [PII]({instance:?})"
277                );
278                if let Some(sockaddr) = self.peer_instance_sockaddr_map.remove(&instance) {
279                    let info = ot::PlatTrelPeerInfo::new(true, &[], sockaddr);
280                    info!(tag = "trel"; "otPlatTrelHandleDiscoveredPeerInfo: Removing {:?}", info);
281                    ot_instance.plat_trel_handle_discovered_peer_info(&info);
282                }
283
284                responder.send().context("Unable to respond to OnInstanceLost")?;
285            }
286
287            ServiceSubscriptionListenerRequest::OnInstanceChanged { instance, responder } => {
288                warn!(
289                    tag = "trel";
290                    "ServiceSubscriptionListenerRequest::OnInstanceChanged: [PII]({instance:?})"
291                );
292                // Skip changes without an IPv6 address.
293                responder.send().context("Unable to respond to OnInstanceChanged")?;
294            }
295
296            ServiceSubscriptionListenerRequest::OnInstanceDiscovered {
297                instance,
298                responder,
299                ..
300            } => {
301                warn!(
302                    tag = "trel";
303                    "ServiceSubscriptionListenerRequest::OnInstanceDiscovered: [PII]({instance:?})"
304                );
305                // Skip discoveries without an IPv6 address.
306                responder.send().context("Unable to respond to OnInstanceDiscovered")?;
307            }
308
309            ServiceSubscriptionListenerRequest::OnQuery { resource_type, responder, .. } => {
310                info!(
311                    tag = "trel";
312                    "ServiceSubscriptionListenerRequest::OnQuery: {resource_type:?}"
313                );
314
315                // We don't care about queries.
316                responder.send().context("Unable to respond to OnQuery")?;
317            }
318        }
319        Ok(())
320    }
321
322    pub fn get_trel_counters(&self) -> *const otPlatTrelCounters {
323        self.counters.borrow().as_ot_ptr()
324    }
325
326    pub fn reset_trel_counters(&self) {
327        self.counters.borrow_mut().reset_counters()
328    }
329
330    /// Async entrypoint for I/O.
331    ///
332    /// This is explicitly not `mut` so that `on_trel_send` can be called reentrantly from here.
333    pub fn poll_io(&self, instance: &ot::Instance, cx: &mut Context<'_>) {
334        let mut buffer = [0u8; crate::UDP_PACKET_MAX_LENGTH];
335        loop {
336            match self.socket.async_recv_from(&mut buffer, cx) {
337                Poll::Ready(Ok((len, sockaddr))) => {
338                    let sockaddr: ot::SockAddr = sockaddr.as_socket_ipv6().unwrap().into();
339                    debug!(tag = "trel"; "Incoming {} byte TREL packet from {:?}", len, sockaddr);
340                    {
341                        let mut counters = self.counters.borrow_mut();
342                        counters.update_rx_bytes(len.try_into().unwrap());
343                        counters.update_rx_packets(1);
344                    }
345                    instance.plat_trel_handle_received(&buffer[..len], &sockaddr)
346                }
347                Poll::Ready(Err(err)) => {
348                    warn!(tag = "trel"; "Error receiving packet: {:?}", err);
349                    break;
350                }
351                _ => {
352                    break;
353                }
354            }
355        }
356    }
357
358    /// Async entrypoint for non-I/O
359    pub fn poll(&mut self, instance: &ot::Instance, cx: &mut Context<'_>) {
360        if let Some(task) = &mut self.publication_responder {
361            if let Poll::Ready(x) = task.poll_unpin(cx) {
362                warn!(
363                    tag = "trel";
364                    "TrelInstance: publication_responder finished unexpectedly: {:?}", x
365                );
366                self.publication_responder = None;
367            }
368        }
369
370        if !self.subscriber_request_stream.is_terminated() {
371            while let Poll::Ready(Some(event)) = self.subscriber_request_stream.poll_next_unpin(cx)
372            {
373                match event {
374                    Ok(event) => {
375                        if let Err(err) = self.handle_service_subscriber_request(instance, event) {
376                            error!(
377                                tag = "trel";
378                                "Error handling service subscriber request: {err:?}"
379                            );
380                        }
381                    }
382                    Err(err) => {
383                        error!(tag = "trel"; "subscriber_request_stream FIDL error: {:?}", err);
384                    }
385                }
386            }
387        }
388    }
389}
390
391impl PlatformBacking {
392    fn on_trel_enable(&self, instance: &ot::Instance) -> Result<u16, anyhow::Error> {
393        let mut trel = self.trel.borrow_mut();
394        if let Some(trel) = trel.as_ref() {
395            Ok(trel.port())
396        } else {
397            let instance_name = hex::encode(instance.get_extended_address().as_slice());
398            let trel_instance = TrelInstance::new(instance_name)?;
399            let port = trel_instance.port();
400            trel.replace(trel_instance);
401            Ok(port)
402        }
403    }
404
405    fn on_trel_disable(&self, _instance: &ot::Instance) {
406        self.trel.replace(None);
407    }
408
409    fn on_trel_register_service(&self, _instance: &ot::Instance, port: u16, txt: &[u8]) {
410        let mut trel = self.trel.borrow_mut();
411        if let Some(trel) = trel.as_mut() {
412            info!(
413                tag = "trel";
414                "otPlatTrelRegisterService: port:{} txt:{:?}",
415                port,
416                txt.to_escaped_string()
417            );
418            trel.register_service(port, txt);
419        } else {
420            debug!(tag = "trel"; "otPlatTrelRegisterService: TREL is disabled, cannot register.");
421        }
422    }
423
424    fn on_trel_send(&self, _instance: &ot::Instance, payload: &[u8], sockaddr: &ot::SockAddr) {
425        let trel = self.trel.borrow();
426        if let Some(trel) = trel.as_ref() {
427            let mut counters = trel.counters.borrow_mut();
428            debug!(tag = "trel"; "otPlatTrelSend: {:?} -> {}", sockaddr, hex::encode(payload));
429            match trel.socket.send_to(payload, (*sockaddr).into()).now_or_never() {
430                Some(Ok(_)) => {
431                    counters.update_tx_bytes(payload.len().try_into().unwrap());
432                    counters.update_tx_packets(1);
433                }
434                Some(Err(err)) => {
435                    counters.update_tx_failure(1);
436                    warn!(tag = "trel"; "otPlatTrelSend: send_to failed: {:?}", err);
437                }
438                None => {
439                    warn!(tag = "trel"; "otPlatTrelSend: send_to didn't finish immediately");
440                }
441            }
442        } else {
443            debug!(tag = "trel"; "otPlatTrelSend: TREL is disabled, cannot send.");
444        }
445    }
446}
447
448#[no_mangle]
449unsafe extern "C" fn otPlatTrelEnable(instance: *mut otInstance, port_ptr: *mut u16) {
450    match PlatformBacking::on_trel_enable(
451        // SAFETY: Must only be called from OpenThread thread,
452        PlatformBacking::as_ref(),
453        // SAFETY: `instance` must be a pointer to a valid `otInstance`,
454        //         which is guaranteed by the caller.
455        ot::Instance::ref_from_ot_ptr(instance).unwrap(),
456    ) {
457        Ok(port) => {
458            info!(tag = "trel"; "otPlatTrelEnable: Ready on port {}", port);
459            *port_ptr = port;
460        }
461        Err(err) => {
462            warn!(tag = "trel"; "otPlatTrelEnable: Unable to start TREL: {:?}", err);
463        }
464    }
465}
466
467#[no_mangle]
468unsafe extern "C" fn otPlatTrelDisable(instance: *mut otInstance) {
469    PlatformBacking::on_trel_disable(
470        // SAFETY: Must only be called from OpenThread thread,
471        PlatformBacking::as_ref(),
472        // SAFETY: `instance` must be a pointer to a valid `otInstance`,
473        //         which is guaranteed by the caller.
474        ot::Instance::ref_from_ot_ptr(instance).unwrap(),
475    );
476    info!(tag = "trel"; "otPlatTrelDisable: Closed.");
477}
478
479#[no_mangle]
480unsafe extern "C" fn otPlatTrelRegisterService(
481    instance: *mut otInstance,
482    port: u16,
483    txt_data: *const u8,
484    txt_len: u8,
485) {
486    PlatformBacking::on_trel_register_service(
487        // SAFETY: Must only be called from OpenThread thread,
488        PlatformBacking::as_ref(),
489        // SAFETY: `instance` must be a pointer to a valid `otInstance`,
490        //         which is guaranteed by the caller.
491        ot::Instance::ref_from_ot_ptr(instance).unwrap(),
492        port,
493        // SAFETY: Caller guarantees either txt_data is valid or txt_len is zero.
494        std::slice::from_raw_parts(txt_data, txt_len.into()),
495    );
496}
497
498#[no_mangle]
499unsafe extern "C" fn otPlatTrelSend(
500    instance: *mut otInstance,
501    payload_data: *const u8,
502    payload_len: u16,
503    dest: *const otSockAddr,
504) {
505    PlatformBacking::on_trel_send(
506        // SAFETY: Must only be called from OpenThread thread,
507        PlatformBacking::as_ref(),
508        // SAFETY: `instance` must be a pointer to a valid `otInstance`,
509        //         which is guaranteed by the caller.
510        ot::Instance::ref_from_ot_ptr(instance).unwrap(),
511        // SAFETY: Caller guarantees either payload_data is valid or payload_len is zero.
512        std::slice::from_raw_parts(payload_data, payload_len.into()),
513        // SAFETY: Caller guarantees dest points to a valid otSockAddr.
514        ot::SockAddr::ref_from_ot_ptr(dest).unwrap(),
515    );
516}
517
518#[no_mangle]
519unsafe extern "C" fn otPlatTrelGetCounters(
520    _instance: *mut otInstance,
521) -> *const otPlatTrelCounters {
522    if let Some(trel) = PlatformBacking::as_ref().trel.borrow().as_ref() {
523        trel.get_trel_counters()
524    } else {
525        std::ptr::null()
526    }
527}
528
529#[no_mangle]
530unsafe extern "C" fn otPlatTrelNotifyPeerSocketAddressDifference(
531    _instance: *mut otsys::otInstance,
532    peer_sock_addr: &ot::SockAddr,
533    rx_sock_addr: &ot::SockAddr,
534) {
535    info!(tag = "trel"; "otPlatTrelNotifyPeerSocketAddressDifference: Not Implemented. peer_sock_addr {}, rx_sock_addr {}", peer_sock_addr, rx_sock_addr);
536}
537
538#[no_mangle]
539unsafe extern "C" fn otPlatTrelResetCounters(_instance: *mut otInstance) {
540    if let Some(trel) = PlatformBacking::as_ref().trel.borrow().as_ref() {
541        trel.reset_trel_counters()
542    }
543}
544
#[cfg(test)]
mod test {
    use super::*;

    /// A length-prefixed TXT record is split back into its individual entries.
    #[test]
    fn test_split_txt() {
        let record = b"\x13xa=a7bfc4981f4e4d22\x13xp=029c6f4dbae059cb";
        let expected = vec![b"xa=a7bfc4981f4e4d22".to_vec(), b"xp=029c6f4dbae059cb".to_vec()];
        assert_eq!(split_txt(record), expected);
    }

    /// TXT entries are joined with one-byte length prefixes; no entries yields nothing.
    #[test]
    fn test_flatten_txt() {
        let cases: Vec<(Option<Vec<Vec<u8>>>, Vec<u8>)> = vec![
            (None, vec![]),
            (Some(vec![]), vec![]),
            (Some(vec![b"xa=a7bfc4981f4e4d22".to_vec()]), b"\x13xa=a7bfc4981f4e4d22".to_vec()),
            (
                Some(vec![b"xa=a7bfc4981f4e4d22".to_vec(), b"xp=029c6f4dbae059cb".to_vec()]),
                b"\x13xa=a7bfc4981f4e4d22\x13xp=029c6f4dbae059cb".to_vec(),
            ),
        ];

        for (input, expected) in cases {
            assert_eq!(flatten_txt(input), expected);
        }
    }
}