netstack3_base/
matchers.rs

1// Copyright 2024 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5//! Trait definition for matchers.
6
7use alloc::format;
8use alloc::string::String;
9use core::convert::Infallible as Never;
10use core::fmt::Debug;
11use core::num::NonZeroU64;
12use core::ops::RangeInclusive;
13
14use bitflags::bitflags;
15use derivative::Derivative;
16use net_types::ip::{IpAddr, IpAddress, Ipv4Addr, Ipv6Addr, Subnet};
17
18use crate::{InspectableValue, Inspector, Mark, MarkDomain, MarkStorage, Marks};
19
/// Trait defining required types for matchers provided by bindings.
///
/// Allows rules that match on device class to be installed, storing the
/// [`MatcherBindingsTypes::DeviceClass`] type at rest, while allowing Netstack3
/// Core to have Bindings provide the type since it is platform-specific.
pub trait MatcherBindingsTypes {
    /// The device class type for devices installed in the netstack.
    ///
    /// Implementations must be `Clone` and `Debug`.
    type DeviceClass: Clone + Debug;
    /// The type used to represent a custom filter matcher.
    ///
    /// `Clone` is required because filter validator clones rules when
    /// installing them. The type must additionally be `Debug` and recordable
    /// through [`InspectableValue`].
    type BindingsPacketMatcher: Clone + Debug + InspectableValue;
}
34
/// Common pattern to define a matcher for a metadata input `T`.
///
/// Used in matching engines like filtering and routing rules.
pub trait Matcher<T> {
    /// Reports whether `actual` satisfies this matcher.
    fn matches(&self, actual: &T) -> bool;

    /// Reports whether a value is present *and* satisfies this matcher.
    ///
    /// The default implementation treats an absent value (`None`) as a
    /// non-match and otherwise defers to [`Matcher::matches`].
    fn required_matches(&self, actual: Option<&T>) -> bool {
        actual.is_some_and(|present| self.matches(present))
    }
}
47
48/// Implement `Matcher` for optional matchers, so that if a matcher is left
49/// unspecified, it matches all inputs by default.
50impl<T, O> Matcher<T> for Option<O>
51where
52    O: Matcher<T>,
53{
54    fn matches(&self, actual: &T) -> bool {
55        self.as_ref().map_or(true, |expected| expected.matches(actual))
56    }
57
58    fn required_matches(&self, actual: Option<&T>) -> bool {
59        self.as_ref().map_or(true, |expected| expected.required_matches(actual))
60    }
61}
62
/// Matcher that matches IP addresses in a subnet.
///
/// The wrapped [`Subnet`] is public; an address matches when the subnet
/// contains it.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub struct SubnetMatcher<A: IpAddress>(pub Subnet<A>);
66
67impl<A: IpAddress> Matcher<A> for SubnetMatcher<A> {
68    fn matches(&self, actual: &A) -> bool {
69        let Self(matcher) = self;
70        matcher.contains(actual)
71    }
72}
73
/// A matcher for network interfaces.
///
/// Each variant selects interfaces by exactly one property: ID, name, or
/// device class.
#[derive(Clone, Derivative, PartialEq, Eq)]
#[derivative(Debug)]
pub enum InterfaceMatcher<DeviceClass> {
    /// The ID of the interface as assigned by the netstack.
    Id(NonZeroU64),
    /// Match based on the interface's name.
    Name(String),
    /// The device class of the interface.
    DeviceClass(DeviceClass),
}
85
86impl<DeviceClass: Debug> InspectableValue for InterfaceMatcher<DeviceClass> {
87    fn record<I: Inspector>(&self, name: &str, inspector: &mut I) {
88        match self {
89            InterfaceMatcher::Id(id) => inspector.record_string(name, format!("Id({})", id.get())),
90            InterfaceMatcher::Name(iface_name) => {
91                inspector.record_string(name, format!("Name({iface_name})"))
92            }
93            InterfaceMatcher::DeviceClass(class) => {
94                inspector.record_debug(name, format!("Class({class:?})"))
95            }
96        };
97    }
98}
99
/// Allows code to match on properties of an interface (ID, name, and device
/// class) without Netstack3 Core (or Bindings, in the case of the device class)
/// having to specifically expose that state.
///
/// Each method is handed the expected value and answers whether the interface
/// agrees with it.
pub trait InterfaceProperties<DeviceClass> {
    /// Returns whether the provided ID matches the interface.
    fn id_matches(&self, id: &NonZeroU64) -> bool;

    /// Returns whether the provided name matches the interface.
    fn name_matches(&self, name: &str) -> bool;

    /// Returns whether the provided device class matches the interface.
    fn device_class_matches(&self, device_class: &DeviceClass) -> bool;
}
113
114impl<DeviceClass, I: InterfaceProperties<DeviceClass>> Matcher<I>
115    for InterfaceMatcher<DeviceClass>
116{
117    fn matches(&self, actual: &I) -> bool {
118        match self {
119            InterfaceMatcher::Id(id) => actual.id_matches(id),
120            InterfaceMatcher::Name(name) => actual.name_matches(name),
121            InterfaceMatcher::DeviceClass(device_class) => {
122                actual.device_class_matches(device_class)
123            }
124        }
125    }
126}
127
/// Matcher for the bound device of locally generated traffic.
///
/// Distinguishes between traffic bound to a device that satisfies an
/// [`InterfaceMatcher`] and traffic that has no bound device at all.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum BoundInterfaceMatcher<DeviceClass> {
    /// The packet is bound to a device which is matched by the matcher.
    Bound(InterfaceMatcher<DeviceClass>),
    /// There is no bound device.
    Unbound,
}
136
137impl<'a, DeviceClass, D: InterfaceProperties<DeviceClass>> Matcher<Option<&'a D>>
138    for BoundInterfaceMatcher<DeviceClass>
139{
140    fn matches(&self, actual: &Option<&'a D>) -> bool {
141        match self {
142            BoundInterfaceMatcher::Bound(matcher) => matcher.required_matches(actual.as_deref()),
143            BoundInterfaceMatcher::Unbound => actual.is_none(),
144        }
145    }
146}
147
148impl<DeviceClass: Debug> InspectableValue for BoundInterfaceMatcher<DeviceClass> {
149    fn record<I: Inspector>(&self, name: &str, inspector: &mut I) {
150        match self {
151            BoundInterfaceMatcher::Unbound => inspector.record_str(name, "Unbound"),
152            BoundInterfaceMatcher::Bound(interface) => {
153                inspector.record_inspectable_value(name, interface)
154            }
155        }
156    }
157}
158
/// A matcher to the socket mark.
///
/// `Marked` matches when a mark is present and `mark & mask` falls within
/// `start..=end`. `invert` negates that result, so with `invert` set an absent
/// mark matches as well.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MarkMatcher {
    /// Matches a packet if it is unmarked.
    Unmarked,
    /// The packet carries a mark that is in the range after masking.
    Marked {
        /// The mask to apply.
        mask: u32,
        /// Start of the range, inclusive.
        start: u32,
        /// End of the range, inclusive.
        end: u32,
        /// Inverts the meaning of the match.
        invert: bool,
    },
}
176
177impl Matcher<Mark> for MarkMatcher {
178    fn matches(&self, Mark(actual): &Mark) -> bool {
179        match self {
180            MarkMatcher::Unmarked => actual.is_none(),
181            MarkMatcher::Marked { mask, start, end, invert } => {
182                let val = actual.is_some_and(|actual| (*start..=*end).contains(&(actual & *mask)));
183
184                if *invert { !val } else { val }
185            }
186        }
187    }
188}
189
/// A matcher for the mark in a specific domain.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct MarkInDomainMatcher {
    /// The domain of the mark to match.
    pub domain: MarkDomain,
    /// The matcher for the mark.
    pub matcher: MarkMatcher,
}
198
/// The 2 mark matchers a rule can specify. All non-none matchers must match.
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq)]
pub struct MarkMatchers(MarkStorage<Option<MarkMatcher>>);
202
203impl MarkMatchers {
204    /// Creates [`MarkMatcher`]s from an iterator of `(MarkDomain, MarkMatcher)`.
205    ///
206    /// An unspecified domain will not have a matcher.
207    ///
208    /// # Panics
209    ///
210    /// Panics if the same domain is specified more than once.
211    pub fn new(matchers: impl IntoIterator<Item = (MarkDomain, MarkMatcher)>) -> Self {
212        MarkMatchers(MarkStorage::new(matchers))
213    }
214
215    /// Returns an iterator over the mark matchers of all domains.
216    pub fn iter(&self) -> impl Iterator<Item = (MarkDomain, &Option<MarkMatcher>)> {
217        let Self(storage) = self;
218        storage.iter()
219    }
220}
221
222impl Matcher<Marks> for MarkMatchers {
223    fn matches(&self, actual: &Marks) -> bool {
224        let Self(matchers) = self;
225        matchers.zip_with(actual).all(|(_domain, matcher, actual)| matcher.matches(actual))
226    }
227}
228
/// A matcher for a socket's cookie.
///
/// Derives `Debug`, `Clone`, `PartialEq` and `Eq` so that it can be logged,
/// duplicated and compared like the other matcher types in this module.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SocketCookieMatcher {
    /// The cookie to check against.
    pub cookie: u64,
    /// Invert the matching criterion (i.e. if the socket cookie isn't the same,
    /// it matches).
    pub invert: bool,
}
237
238impl Matcher<u64> for SocketCookieMatcher {
239    fn matches(&self, actual: &u64) -> bool {
240        let val = *actual == self.cookie;
241        if self.invert { !val } else { val }
242    }
243}
244
/// A matcher for transport-layer port numbers.
///
/// A port matches when it lies in `range`, with the outcome flipped if
/// `invert` is set.
#[derive(Clone, Debug)]
pub struct PortMatcher {
    /// The range of port numbers in which the tested port number must fall.
    pub range: RangeInclusive<u16>,
    /// Whether to check for an "inverse" or "negative" match (in which case,
    /// if the matcher criteria do *not* apply, it *is* considered a match, and
    /// vice versa).
    pub invert: bool,
}
255
256impl Matcher<u16> for PortMatcher {
257    fn matches(&self, actual: &u16) -> bool {
258        let Self { range, invert } = self;
259        range.contains(actual) ^ *invert
260    }
261}
262
bitflags! {
    /// A matcher for TCP state machine state.
    ///
    /// Each state occupies its own bit, so a single matcher can carry several
    /// states at once.
    #[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash)]
    pub struct TcpStateMatcher: u32 {
        /// The TCP ESTABLISHED state.
        const ESTABLISHED = 1 << 0;
        /// The TCP SYN_SENT state.
        const SYN_SENT = 1 << 1;
        /// The TCP SYN_RECV state.
        const SYN_RECV = 1 << 2;
        /// The TCP FIN_WAIT1 state.
        const FIN_WAIT1 = 1 << 3;
        /// The TCP FIN_WAIT2 state.
        const FIN_WAIT2 = 1 << 4;
        /// The TCP TIME_WAIT state.
        const TIME_WAIT = 1 << 5;
        /// The TCP CLOSE state.
        const CLOSE = 1 << 6;
        /// The TCP CLOSE_WAIT state.
        const CLOSE_WAIT = 1 << 7;
        /// The TCP LAST_ACK state.
        const LAST_ACK = 1 << 8;
        /// The TCP LISTEN state.
        const LISTEN = 1 << 9;
        /// The TCP CLOSING state.
        const CLOSING = 1 << 10;
    }
}
291
292impl Matcher<TcpSocketState> for TcpStateMatcher {
293    fn matches(&self, actual: &TcpSocketState) -> bool {
294        self.contains(actual.matcher_flag())
295    }
296}
297
/// Represents the state of a TCP socket's state machine.
///
/// Every variant has a corresponding [`TcpStateMatcher`] flag.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
#[allow(missing_docs)]
pub enum TcpSocketState {
    // Corresponds to `TcpStateMatcher::ESTABLISHED`.
    Established,
    // Corresponds to `TcpStateMatcher::SYN_SENT`.
    SynSent,
    // Corresponds to `TcpStateMatcher::SYN_RECV`.
    SynRecv,
    // Corresponds to `TcpStateMatcher::FIN_WAIT1`.
    FinWait1,
    // Corresponds to `TcpStateMatcher::FIN_WAIT2`.
    FinWait2,
    // Corresponds to `TcpStateMatcher::TIME_WAIT`.
    TimeWait,
    // Corresponds to `TcpStateMatcher::CLOSE`.
    Close,
    // Corresponds to `TcpStateMatcher::CLOSE_WAIT`.
    CloseWait,
    // Corresponds to `TcpStateMatcher::LAST_ACK`.
    LastAck,
    // Corresponds to `TcpStateMatcher::LISTEN`.
    Listen,
    // Corresponds to `TcpStateMatcher::CLOSING`.
    Closing,
}
314
315impl TcpSocketState {
316    fn matcher_flag(&self) -> TcpStateMatcher {
317        match self {
318            TcpSocketState::Established => TcpStateMatcher::ESTABLISHED,
319            TcpSocketState::SynSent => TcpStateMatcher::SYN_SENT,
320            TcpSocketState::SynRecv => TcpStateMatcher::SYN_RECV,
321            TcpSocketState::FinWait1 => TcpStateMatcher::FIN_WAIT1,
322            TcpSocketState::FinWait2 => TcpStateMatcher::FIN_WAIT2,
323            TcpSocketState::TimeWait => TcpStateMatcher::TIME_WAIT,
324            TcpSocketState::Close => TcpStateMatcher::CLOSE,
325            TcpSocketState::CloseWait => TcpStateMatcher::CLOSE_WAIT,
326            TcpSocketState::LastAck => TcpStateMatcher::LAST_ACK,
327            TcpSocketState::Listen => TcpStateMatcher::LISTEN,
328            TcpSocketState::Closing => TcpStateMatcher::CLOSING,
329        }
330    }
331}
332
/// Allows code to match on properties of a TCP socket without Netstack3 Core
/// having to specifically expose that state.
///
/// Each method receives a matcher and reports whether the socket satisfies it.
pub trait TcpSocketProperties {
    /// Returns whether the socket's source port is matched by the matcher.
    fn src_port_matches(&self, matcher: &PortMatcher) -> bool;

    /// Returns whether the socket's destination port is matched by the matcher.
    fn dst_port_matches(&self, matcher: &PortMatcher) -> bool;

    /// Returns whether the socket's TCP state is matched by the matcher.
    fn state_matches(&self, matcher: &TcpStateMatcher) -> bool;
}
345
346impl TcpSocketProperties for Never {
347    fn src_port_matches(&self, _matcher: &PortMatcher) -> bool {
348        unimplemented!()
349    }
350
351    fn dst_port_matches(&self, _matcher: &PortMatcher) -> bool {
352        unimplemented!()
353    }
354
355    fn state_matches(&self, _matcher: &TcpStateMatcher) -> bool {
356        unimplemented!()
357    }
358}
359
360impl<T> TcpSocketProperties for &T
361where
362    T: TcpSocketProperties,
363{
364    fn src_port_matches(&self, matcher: &PortMatcher) -> bool {
365        (*self).src_port_matches(matcher)
366    }
367
368    fn dst_port_matches(&self, matcher: &PortMatcher) -> bool {
369        (*self).dst_port_matches(matcher)
370    }
371
372    fn state_matches(&self, matcher: &TcpStateMatcher) -> bool {
373        (*self).state_matches(matcher)
374    }
375}
376
377/// The top-level matcher for TCP sockets.
378pub enum TcpSocketMatcher {
379    /// Match any TCP socket without further constraints.
380    Empty,
381    /// Match on the source port.
382    SrcPort(PortMatcher),
383    /// Match on the destination port.
384    DstPort(PortMatcher),
385    /// Match on the state of the TCP state machine.
386    State(TcpStateMatcher),
387}
388
389impl<T: TcpSocketProperties> Matcher<T> for TcpSocketMatcher {
390    fn matches(&self, actual: &T) -> bool {
391        match self {
392            TcpSocketMatcher::Empty => true,
393            TcpSocketMatcher::SrcPort(matcher) => actual.src_port_matches(matcher),
394            TcpSocketMatcher::DstPort(matcher) => actual.dst_port_matches(matcher),
395            TcpSocketMatcher::State(matcher) => actual.state_matches(matcher),
396        }
397    }
398}
399
bitflags! {
    /// A matcher for UDP states.
    ///
    /// Each state occupies its own bit, so a single matcher can carry both
    /// states at once.
    #[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash)]
    pub struct UdpStateMatcher: u32 {
        /// The UDP socket is bound but not connected.
        const BOUND = 1 << 0;
        /// The UDP socket is explicitly connected.
        const CONNECTED = 1 << 1;
    }
}
410
411impl Matcher<UdpSocketState> for UdpStateMatcher {
412    fn matches(&self, actual: &UdpSocketState) -> bool {
413        self.contains(actual.matcher_flag())
414    }
415}
416
/// Represents the state of a UDP socket.
///
/// Every variant has a corresponding [`UdpStateMatcher`] flag.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum UdpSocketState {
    /// The socket is bound to a local address and (maybe) port.
    Bound,
    /// The socket is connected to a remote peer and has a full 4-tuple.
    Connected,
}
425
426impl UdpSocketState {
427    fn matcher_flag(&self) -> UdpStateMatcher {
428        match self {
429            UdpSocketState::Bound => UdpStateMatcher::BOUND,
430            UdpSocketState::Connected => UdpStateMatcher::CONNECTED,
431        }
432    }
433}
434
/// Allows code to match on properties of a UDP socket without Netstack3 Core
/// having to specifically expose that state.
///
/// Each method receives a matcher and reports whether the socket satisfies it.
pub trait UdpSocketProperties {
    /// Returns whether the socket's source port is matched by the matcher.
    fn src_port_matches(&self, matcher: &PortMatcher) -> bool;

    /// Returns whether the socket's destination port is matched by the matcher.
    fn dst_port_matches(&self, matcher: &PortMatcher) -> bool;

    /// Returns whether the socket's UDP state is matched by the matcher.
    fn state_matches(&self, matcher: &UdpStateMatcher) -> bool;
}
447
448impl UdpSocketProperties for Never {
449    fn src_port_matches(&self, _matcher: &PortMatcher) -> bool {
450        unimplemented!()
451    }
452
453    fn dst_port_matches(&self, _matcher: &PortMatcher) -> bool {
454        unimplemented!()
455    }
456
457    fn state_matches(&self, _matcher: &UdpStateMatcher) -> bool {
458        unimplemented!()
459    }
460}
461
462impl<U> UdpSocketProperties for &U
463where
464    U: UdpSocketProperties,
465{
466    fn src_port_matches(&self, matcher: &PortMatcher) -> bool {
467        (*self).src_port_matches(matcher)
468    }
469
470    fn dst_port_matches(&self, matcher: &PortMatcher) -> bool {
471        (*self).dst_port_matches(matcher)
472    }
473
474    fn state_matches(&self, matcher: &UdpStateMatcher) -> bool {
475        (*self).state_matches(matcher)
476    }
477}
478
479/// The top-level matcher for UDP sockets.
480pub enum UdpSocketMatcher {
481    /// Match any UDP socket without further constraints.
482    Empty,
483    /// Match the source port.
484    SrcPort(PortMatcher),
485    /// Match the destination port.
486    DstPort(PortMatcher),
487    /// Match the UDP state.
488    State(UdpStateMatcher),
489}
490
491impl<T: UdpSocketProperties> Matcher<T> for UdpSocketMatcher {
492    fn matches(&self, actual: &T) -> bool {
493        match self {
494            UdpSocketMatcher::Empty => true,
495            UdpSocketMatcher::SrcPort(matcher) => actual.src_port_matches(matcher),
496            UdpSocketMatcher::DstPort(matcher) => actual.dst_port_matches(matcher),
497            UdpSocketMatcher::State(matcher) => actual.state_matches(matcher),
498        }
499    }
500}
501
/// Provides optional access to TCP and UDP socket properties.
///
/// Each accessor returns `None` when the socket is not of the corresponding
/// transport protocol.
pub trait MaybeSocketTransportProperties {
    /// The type that encapsulates TCP socket properties.
    type TcpProps<'a>: TcpSocketProperties
    where
        Self: 'a;

    /// The type that encapsulates UDP socket properties.
    type UdpProps<'a>: UdpSocketProperties
    where
        Self: 'a;

    /// Returns TCP socket properties if the socket is a TCP socket.
    fn tcp_socket_properties(&self) -> Option<&Self::TcpProps<'_>>;

    /// Returns UDP socket properties if the socket is a UDP socket.
    fn udp_socket_properties(&self) -> Option<&Self::UdpProps<'_>>;
}
520
521impl MaybeSocketTransportProperties for Never {
522    type TcpProps<'a>
523        = Never
524    where
525        Self: 'a;
526
527    type UdpProps<'a>
528        = Never
529    where
530        Self: 'a;
531
532    fn tcp_socket_properties(&self) -> Option<&Self::TcpProps<'_>> {
533        unimplemented!()
534    }
535
536    fn udp_socket_properties(&self) -> Option<&Self::UdpProps<'_>> {
537        unimplemented!()
538    }
539}
540
/// A matcher for the transport protocol of a socket.
///
/// A variant only matches sockets of its own protocol, and then only if the
/// wrapped matcher accepts the socket's properties.
pub enum SocketTransportProtocolMatcher {
    /// Match against a TCP socket.
    Tcp(TcpSocketMatcher),
    /// Match against a UDP socket.
    Udp(UdpSocketMatcher),
}
548
549impl<T: MaybeSocketTransportProperties> Matcher<T> for SocketTransportProtocolMatcher {
550    fn matches(&self, actual: &T) -> bool {
551        match self {
552            SocketTransportProtocolMatcher::Tcp(tcp_matcher) => {
553                actual.tcp_socket_properties().map_or(false, |props| tcp_matcher.matches(props))
554            }
555            SocketTransportProtocolMatcher::Udp(udp_matcher) => {
556                actual.udp_socket_properties().map_or(false, |props| udp_matcher.matches(props))
557            }
558        }
559    }
560}
561
/// The criterion an IP address is tested against: membership in a subnet or
/// in an inclusive address range.
#[derive(Clone, Derivative)]
#[derivative(Debug)]
pub enum AddressMatcherType<A: IpAddress> {
    /// A subnet that must contain the address.
    #[derivative(Debug = "transparent")]
    Subnet(SubnetMatcher<A>),
    /// An inclusive range of IP addresses that must contain the address.
    Range(RangeInclusive<A>),
}
572
573impl<A: IpAddress> Matcher<A> for AddressMatcherType<A> {
574    fn matches(&self, actual: &A) -> bool {
575        match self {
576            Self::Subnet(subnet_matcher) => subnet_matcher.matches(actual),
577            Self::Range(range) => range.contains(actual),
578        }
579    }
580}
581
/// A matcher for IP addresses.
///
/// Combines an [`AddressMatcherType`] criterion with an `invert` flag that
/// flips the criterion's outcome.
#[derive(Clone, Debug)]
pub struct AddressMatcher<A: IpAddress> {
    /// The type of the address matcher.
    pub matcher: AddressMatcherType<A>,
    /// Whether to check for an "inverse" or "negative" match (in which case,
    /// if the matcher criteria do *not* apply, it *is* considered a match, and
    /// vice versa).
    pub invert: bool,
}
592
593impl<A: IpAddress> InspectableValue for AddressMatcher<A> {
594    fn record<I: Inspector>(&self, name: &str, inspector: &mut I) {
595        let AddressMatcher { matcher, invert } = self;
596
597        inspector.record_child(name, |inspector| {
598            inspector.record_bool("invert", *invert);
599            match matcher {
600                AddressMatcherType::Subnet(SubnetMatcher(subnet)) => {
601                    inspector.record_display("subnet", subnet)
602                }
603                AddressMatcherType::Range(range) => {
604                    inspector.record_display("start", range.start());
605                    inspector.record_display("end", range.end());
606                }
607            }
608        })
609    }
610}
611
612impl<A: IpAddress> Matcher<A> for AddressMatcher<A> {
613    fn matches(&self, addr: &A) -> bool {
614        let Self { matcher, invert } = self;
615        matcher.matches(addr) ^ *invert
616    }
617}
618
/// An address matcher that matches any IP version as specified at runtime.
///
/// A matcher for one IP version never matches an address of the other
/// version.
pub enum AddressMatcherEither {
    /// The top-level IPv4 address matcher.
    V4(AddressMatcher<Ipv4Addr>),
    /// The top-level IPv6 address matcher.
    V6(AddressMatcher<Ipv6Addr>),
}
626
627impl Matcher<IpAddr> for AddressMatcherEither {
628    fn matches(&self, addr: &IpAddr) -> bool {
629        match self {
630            AddressMatcherEither::V4(matcher) => match addr {
631                IpAddr::V4(addr) => matcher.matches(addr),
632                IpAddr::V6(_) => false,
633            },
634            AddressMatcherEither::V6(matcher) => match addr {
635                IpAddr::V4(_) => false,
636                IpAddr::V6(addr) => matcher.matches(addr),
637            },
638        }
639    }
640}
641
/// Allows code to match on properties of a socket without Netstack3 Core
/// having to specifically expose that state.
///
/// Each method receives one kind of matcher (or expected value) and reports
/// whether the socket satisfies it.
pub trait IpSocketProperties<DeviceClass> {
    /// Returns whether the provided IP version matches the socket.
    fn family_matches(&self, family: &net_types::ip::IpVersion) -> bool;

    /// Returns whether the provided address matcher matches the socket's source
    /// address.
    fn src_addr_matches(&self, addr: &AddressMatcherEither) -> bool;

    /// Returns whether the provided address matcher matches the socket's
    /// destination address.
    fn dst_addr_matches(&self, addr: &AddressMatcherEither) -> bool;

    /// Returns whether the transport protocol matches the socket's
    /// transport-layer information.
    fn transport_protocol_matches(&self, matcher: &SocketTransportProtocolMatcher) -> bool;

    /// Returns whether the provided interface matcher matches the socket's
    /// bound interface, if present.
    fn bound_interface_matches(&self, iface: &BoundInterfaceMatcher<DeviceClass>) -> bool;

    /// Returns whether the provided cookie matcher matches the socket's cookie.
    fn cookie_matches(&self, cookie: &SocketCookieMatcher) -> bool;

    /// Returns whether the provided mark matcher matches the corresponding mark.
    fn mark_matches(&self, matcher: &MarkInDomainMatcher) -> bool;
}
670
/// The top-level matcher for IP sockets.
///
/// Each variant tests exactly one property of the socket; see
/// [`IpSocketProperties`] for the corresponding queries.
pub enum IpSocketMatcher<DeviceClass> {
    /// Matches the socket's address family.
    Family(net_types::ip::IpVersion),
    /// Matches the socket's source address.
    SrcAddr(AddressMatcherEither),
    /// Matches the socket's destination address.
    DstAddr(AddressMatcherEither),
    /// Matches the socket's transport protocol.
    Proto(SocketTransportProtocolMatcher),
    /// Matches the socket's bound interface.
    BoundInterface(BoundInterfaceMatcher<DeviceClass>),
    /// Matches the socket's cookie.
    Cookie(SocketCookieMatcher),
    /// Matches the socket's mark.
    Mark(MarkInDomainMatcher),
}
688
689impl<DeviceClass, S: IpSocketProperties<DeviceClass>> Matcher<S> for IpSocketMatcher<DeviceClass> {
690    fn matches(&self, actual: &S) -> bool {
691        match self {
692            IpSocketMatcher::Family(family) => actual.family_matches(family),
693            IpSocketMatcher::SrcAddr(addr) => actual.src_addr_matches(addr),
694            IpSocketMatcher::DstAddr(addr) => actual.dst_addr_matches(addr),
695            IpSocketMatcher::Proto(proto) => actual.transport_protocol_matches(proto),
696            IpSocketMatcher::BoundInterface(iface) => actual.bound_interface_matches(iface),
697            IpSocketMatcher::Cookie(cookie) => actual.cookie_matches(cookie),
698            IpSocketMatcher::Mark(mark) => actual.mark_matches(mark),
699        }
700    }
701}
702
/// Allows code to take an opaque matcher that works on IP sockets without
/// needing to know the type(s) of the underlying matcher(s).
///
/// Implemented in this module for a single [`IpSocketMatcher`] and for a
/// slice of them.
pub trait IpSocketPropertiesMatcher<DeviceClass> {
    /// Whether the matcher matches `actual`.
    fn matches_ip_socket<S: IpSocketProperties<DeviceClass>>(&self, actual: &S) -> bool;
}
709
710impl<DeviceClass> IpSocketPropertiesMatcher<DeviceClass> for IpSocketMatcher<DeviceClass> {
711    fn matches_ip_socket<S: IpSocketProperties<DeviceClass>>(&self, actual: &S) -> bool {
712        self.matches(actual)
713    }
714}
715
716impl<DeviceClass> IpSocketPropertiesMatcher<DeviceClass> for [IpSocketMatcher<DeviceClass>] {
717    fn matches_ip_socket<S: IpSocketProperties<DeviceClass>>(&self, actual: &S) -> bool {
718        self.iter().all(|matcher| matcher.matches(actual))
719    }
720}
721
#[cfg(any(test, feature = "testutils"))]
pub(crate) mod testutil {
    use alloc::string::String;
    use core::num::NonZeroU64;

    use crate::matchers::InterfaceProperties;
    use crate::testutil::{FakeDeviceClass, FakeStrongDeviceId, FakeWeakDeviceId};
    use crate::{DeviceIdentifier, StrongDeviceIdentifier};

    /// A fake device ID for testing matchers.
    ///
    /// Carries the three properties an interface matcher can test: ID, name
    /// and device class.
    #[derive(Clone, Debug, PartialOrd, Ord, PartialEq, Eq, Hash)]
    #[allow(missing_docs)]
    pub struct FakeMatcherDeviceId {
        pub id: NonZeroU64,
        pub name: String,
        pub class: FakeDeviceClass,
    }

    impl FakeMatcherDeviceId {
        /// Returns a [`FakeMatcherDeviceId`] for an arbitrary WLAN interface.
        ///
        /// The interface returned will always be identical.
        pub fn wlan_interface() -> FakeMatcherDeviceId {
            let id = NonZeroU64::new(1).unwrap();
            let name = String::from("wlan");
            Self { id, name, class: FakeDeviceClass::Wlan }
        }

        /// Returns a [`FakeMatcherDeviceId`] for an arbitrary Ethernet interface.
        ///
        /// The interface returned will always be identical.
        pub fn ethernet_interface() -> FakeMatcherDeviceId {
            let id = NonZeroU64::new(2).unwrap();
            let name = String::from("eth");
            Self { id, name, class: FakeDeviceClass::Ethernet }
        }
    }

    impl StrongDeviceIdentifier for FakeMatcherDeviceId {
        type Weak = FakeWeakDeviceId<Self>;

        /// Wraps a clone of this ID in the fake weak device ID type.
        fn downgrade(&self) -> Self::Weak {
            let strong = self.clone();
            FakeWeakDeviceId(strong)
        }
    }

    impl DeviceIdentifier for FakeMatcherDeviceId {
        /// Fake matcher devices are never loopback.
        fn is_loopback(&self) -> bool {
            false
        }
    }

    impl FakeStrongDeviceId for FakeMatcherDeviceId {
        /// Fake matcher devices always report as alive.
        fn is_alive(&self) -> bool {
            true
        }
    }

    impl PartialEq<FakeWeakDeviceId<FakeMatcherDeviceId>> for FakeMatcherDeviceId {
        /// Compares this ID with the strong ID wrapped by the weak ID.
        fn eq(&self, weak: &FakeWeakDeviceId<FakeMatcherDeviceId>) -> bool {
            let FakeWeakDeviceId(wrapped) = weak;
            self == wrapped
        }
    }

    impl InterfaceProperties<FakeDeviceClass> for FakeMatcherDeviceId {
        fn id_matches(&self, id: &NonZeroU64) -> bool {
            self.id == *id
        }

        fn name_matches(&self, name: &str) -> bool {
            self.name == name
        }

        fn device_class_matches(&self, class: &FakeDeviceClass) -> bool {
            self.class == *class
        }
    }
}
804
805#[cfg(test)]
806mod tests {
807    use ip_test_macro::ip_test;
808    use net_types::Witness;
809    use net_types::ip::{Ip, IpVersion, Ipv4, Ipv6};
810    use test_case::test_case;
811
812    use super::*;
813    use crate::testutil::{FakeDeviceClass, FakeMatcherDeviceId, TestIpExt};
814
815    /// Only matches `true`.
816    #[derive(Debug)]
817    struct TrueMatcher;
818
819    impl Matcher<bool> for TrueMatcher {
820        fn matches(&self, actual: &bool) -> bool {
821            *actual
822        }
823    }
824
825    #[test]
826    fn test_optional_matcher_optional_value() {
827        assert!(TrueMatcher.matches(&true));
828        assert!(!TrueMatcher.matches(&false));
829
830        assert!(TrueMatcher.required_matches(Some(&true)));
831        assert!(!TrueMatcher.required_matches(Some(&false)));
832        assert!(!TrueMatcher.required_matches(None));
833
834        assert!(Some(TrueMatcher).matches(&true));
835        assert!(!Some(TrueMatcher).matches(&false));
836        assert!(None::<TrueMatcher>.matches(&true));
837        assert!(None::<TrueMatcher>.matches(&false));
838
839        assert!(Some(TrueMatcher).required_matches(Some(&true)));
840        assert!(!Some(TrueMatcher).required_matches(Some(&false)));
841        assert!(!Some(TrueMatcher).required_matches(None));
842        assert!(None::<TrueMatcher>.required_matches(Some(&true)));
843        assert!(None::<TrueMatcher>.required_matches(Some(&false)));
844        assert!(None::<TrueMatcher>.required_matches(None));
845    }
846
847    #[test_case(
848        InterfaceMatcher::Id(FakeMatcherDeviceId::wlan_interface().id),
849        FakeMatcherDeviceId::wlan_interface() => true
850    )]
851    #[test_case(
852        InterfaceMatcher::Id(FakeMatcherDeviceId::wlan_interface().id),
853        FakeMatcherDeviceId::ethernet_interface() => false
854    )]
855    #[test_case(
856        InterfaceMatcher::Name(FakeMatcherDeviceId::wlan_interface().name),
857        FakeMatcherDeviceId::wlan_interface() => true
858    )]
859    #[test_case(
860        InterfaceMatcher::Name(FakeMatcherDeviceId::wlan_interface().name),
861        FakeMatcherDeviceId::ethernet_interface() => false
862    )]
863    #[test_case(
864        InterfaceMatcher::DeviceClass(FakeDeviceClass::Wlan),
865        FakeMatcherDeviceId::wlan_interface() => true
866    )]
867    #[test_case(
868        InterfaceMatcher::DeviceClass(FakeDeviceClass::Wlan),
869        FakeMatcherDeviceId::ethernet_interface() => false
870    )]
871    fn interface_matcher(
872        matcher: InterfaceMatcher<FakeDeviceClass>,
873        device: FakeMatcherDeviceId,
874    ) -> bool {
875        matcher.matches(&device)
876    }
877
878    #[test_case(BoundInterfaceMatcher::Unbound, None => true)]
879    #[test_case(
880        BoundInterfaceMatcher::Unbound,
881        Some(FakeMatcherDeviceId::wlan_interface()) => false
882    )]
883    #[test_case(
884        BoundInterfaceMatcher::Bound(
885            InterfaceMatcher::Id(FakeMatcherDeviceId::wlan_interface().id)
886        ),
887        None => false
888    )]
889    #[test_case(
890        BoundInterfaceMatcher::Bound(
891            InterfaceMatcher::Id(FakeMatcherDeviceId::wlan_interface().id)
892        ),
893        Some(FakeMatcherDeviceId::wlan_interface()) => true
894    )]
895    #[test_case(
896        BoundInterfaceMatcher::Bound(
897            InterfaceMatcher::Id(FakeMatcherDeviceId::wlan_interface().id)
898        ),
899        Some(FakeMatcherDeviceId::ethernet_interface()) => false
900    )]
901    #[test_case(
902        BoundInterfaceMatcher::Bound(
903            InterfaceMatcher::Name(FakeMatcherDeviceId::wlan_interface().name)
904        ),
905        None => false
906    )]
907    #[test_case(
908        BoundInterfaceMatcher::Bound(
909            InterfaceMatcher::Name(FakeMatcherDeviceId::wlan_interface().name)
910        ),
911        Some(FakeMatcherDeviceId::wlan_interface()) => true
912    )]
913    #[test_case(
914        BoundInterfaceMatcher::Bound(
915            InterfaceMatcher::Name(FakeMatcherDeviceId::wlan_interface().name)
916        ),
917        Some(FakeMatcherDeviceId::ethernet_interface()) => false
918    )]
919    #[test_case(
920        BoundInterfaceMatcher::Bound(
921            InterfaceMatcher::DeviceClass(FakeDeviceClass::Wlan)
922        ),
923        None => false
924    )]
925    #[test_case(
926        BoundInterfaceMatcher::Bound(
927            InterfaceMatcher::DeviceClass(FakeDeviceClass::Wlan)
928        ),
929        Some(FakeMatcherDeviceId::wlan_interface()) => true
930    )]
931    #[test_case(
932        BoundInterfaceMatcher::Bound(
933            InterfaceMatcher::DeviceClass(FakeDeviceClass::Wlan)
934        ),
935        Some(FakeMatcherDeviceId::ethernet_interface()) => false
936    )]
937    fn bound_interface_matcher(
938        matcher: BoundInterfaceMatcher<FakeDeviceClass>,
939        device: Option<FakeMatcherDeviceId>,
940    ) -> bool {
941        matcher.matches(&device.as_ref())
942    }
943
944    #[ip_test(I)]
945    fn subnet_matcher<I: Ip + TestIpExt>() {
946        let matcher = SubnetMatcher(I::TEST_ADDRS.subnet);
947        assert!(matcher.matches(&I::TEST_ADDRS.local_ip));
948        assert!(!matcher.matches(&I::get_other_remote_ip_address(1)));
949    }
950
951    #[test_case(MarkMatcher::Unmarked, Mark(None) => true; "unmarked matches none")]
952    #[test_case(MarkMatcher::Unmarked, Mark(Some(0)) => false; "unmarked does not match some")]
953    #[test_case(MarkMatcher::Marked {
954        mask: 1,
955        start: 0,
956        end: 0,
957        invert: false,
958    }, Mark(None) => false; "marked does not match none")]
959    #[test_case(MarkMatcher::Marked {
960        mask: 1,
961        start: 0,
962        end: 0,
963        invert: false,
964    }, Mark(Some(0)) => true; "marked 0 mask 1 matches 0")]
965    #[test_case(MarkMatcher::Marked {
966        mask: 1,
967        start: 0,
968        end: 0,
969        invert: false,
970    }, Mark(Some(1)) => false; "marked 0 mask 1 does not match 1")]
971    #[test_case(MarkMatcher::Marked {
972        mask: 1,
973        start: 0,
974        end: 0,
975        invert: false,
976    }, Mark(Some(2)) => true; "marked 0 mask 1 matches 2")]
977    #[test_case(MarkMatcher::Marked {
978        mask: 1,
979        start: 0,
980        end: 0,
981        invert: false,
982    }, Mark(Some(3)) => false; "marked 0 mask 1 does not match 3")]
983    #[test_case(MarkMatcher::Marked {
984        mask: !0,
985        start: 0,
986        end: 10,
987        invert: true,
988    }, Mark(Some(5)) => false; "marked invert no match in range")]
989    #[test_case(MarkMatcher::Marked {
990        mask: !0,
991        start: 0,
992        end: 10,
993        invert: true,
994    }, Mark(Some(11)) => true; "marked invert matches out of range")]
995    fn mark_matcher(matcher: MarkMatcher, mark: Mark) -> bool {
996        matcher.matches(&mark)
997    }
998
999    #[test_case(
1000        MarkMatchers::new(
1001            [(MarkDomain::Mark1, MarkMatcher::Unmarked),
1002            (MarkDomain::Mark2, MarkMatcher::Unmarked)]
1003        ),
1004        Marks::new([]) => true;
1005        "all unmarked matches empty"
1006    )]
1007    #[test_case(
1008        MarkMatchers::new(
1009            [(MarkDomain::Mark1, MarkMatcher::Unmarked),
1010            (MarkDomain::Mark2, MarkMatcher::Unmarked)]
1011        ),
1012        Marks::new([(MarkDomain::Mark1, 1)]) => false;
1013        "all unmarked does not match mark1"
1014    )]
1015    #[test_case(
1016        MarkMatchers::new(
1017            [(MarkDomain::Mark1, MarkMatcher::Unmarked),
1018            (MarkDomain::Mark2, MarkMatcher::Unmarked)]
1019        ),
1020        Marks::new([(MarkDomain::Mark2, 1)]) => false;
1021        "all unmarked does not match mark2"
1022    )]
1023    #[test_case(
1024        MarkMatchers::new(
1025            [(MarkDomain::Mark1, MarkMatcher::Unmarked),
1026            (MarkDomain::Mark2, MarkMatcher::Unmarked)]
1027        ),
1028        Marks::new([
1029            (MarkDomain::Mark1, 1),
1030            (MarkDomain::Mark2, 1),
1031        ]) => false;
1032        "all unmarked does not match mark1 and mark2"
1033    )]
1034    #[test_case(
1035        MarkMatchers::new(
1036            [(MarkDomain::Mark1, MarkMatcher::Marked { mask: !0, start: 1, end: 1, invert: false }),
1037            (MarkDomain::Mark2, MarkMatcher::Unmarked)]
1038        ),
1039        Marks::new([(MarkDomain::Mark1, 1)]) => true;
1040        "mark1 marked matches"
1041    )]
1042    #[test_case(
1043        MarkMatchers::new(
1044            [(MarkDomain::Mark1, MarkMatcher::Marked { mask: !0, start: 1, end: 1, invert: false }),
1045            (MarkDomain::Mark2, MarkMatcher::Unmarked)]
1046        ),
1047        Marks::new([(MarkDomain::Mark1, 2)]) => false;
1048        "mark1 marked no match"
1049    )]
1050    #[test_case(
1051        MarkMatchers::new(
1052            [(MarkDomain::Mark1, MarkMatcher::Marked { mask: !0, start: 1, end: 1, invert: false }),
1053            (MarkDomain::Mark2, MarkMatcher::Marked { mask: !0, start: 2, end: 2, invert: false })]
1054        ),
1055        Marks::new([(MarkDomain::Mark1, 1), (MarkDomain::Mark2, 2)]) => true;
1056        "all marked matches"
1057    )]
1058    #[test_case(
1059        MarkMatchers::new(
1060            [(MarkDomain::Mark1, MarkMatcher::Marked { mask: !0, start: 1, end: 1, invert: false }),
1061            (MarkDomain::Mark2, MarkMatcher::Marked { mask: !0, start: 2, end: 2, invert: false })]
1062        ),
1063        Marks::new([(MarkDomain::Mark1, 1), (MarkDomain::Mark2, 3)]) => false;
1064        "all marked no match mark2"
1065    )]
1066    fn mark_matchers(matchers: MarkMatchers, marks: Marks) -> bool {
1067        matchers.matches(&marks)
1068    }
1069
1070    #[test_case(SocketCookieMatcher { cookie: 123, invert: false }, 123 => true)]
1071    #[test_case(SocketCookieMatcher { cookie: 123, invert: false }, 456 => false)]
1072    #[test_case(SocketCookieMatcher { cookie: 123, invert: true }, 123 => false)]
1073    #[test_case(SocketCookieMatcher { cookie: 123, invert: true }, 456 => true)]
1074    fn socket_cookie_matcher(matcher: SocketCookieMatcher, actual: u64) -> bool {
1075        matcher.matches(&actual)
1076    }
1077
1078    #[test_case(PortMatcher { range: 10..=20, invert: false }, 9 => false)]
1079    #[test_case(PortMatcher { range: 10..=20, invert: false }, 10 => true)]
1080    #[test_case(PortMatcher { range: 10..=20, invert: false }, 15 => true)]
1081    #[test_case(PortMatcher { range: 10..=20, invert: false }, 20 => true)]
1082    #[test_case(PortMatcher { range: 10..=20, invert: false }, 21 => false)]
1083    #[test_case(PortMatcher { range: 10..=20, invert: true }, 9 => true)]
1084    #[test_case(PortMatcher { range: 10..=20, invert: true }, 10 => false)]
1085    #[test_case(PortMatcher { range: 10..=20, invert: true }, 15 => false)]
1086    #[test_case(PortMatcher { range: 10..=20, invert: true }, 20 => false)]
1087    #[test_case(PortMatcher { range: 10..=20, invert: true }, 21 => true)]
1088    fn port_matcher(matcher: PortMatcher, actual: u16) -> bool {
1089        matcher.matches(&actual)
1090    }
1091
1092    struct FakeTcpSocket {
1093        src_port: u16,
1094        dst_port: u16,
1095        state: TcpSocketState,
1096    }
1097
1098    impl MaybeSocketTransportProperties for FakeTcpSocket {
1099        type TcpProps<'a>
1100            = Self
1101        where
1102            Self: 'a;
1103
1104        type UdpProps<'a>
1105            = Never
1106        where
1107            Self: 'a;
1108
1109        fn tcp_socket_properties(&self) -> Option<&Self::TcpProps<'_>> {
1110            Some(self)
1111        }
1112
1113        fn udp_socket_properties(&self) -> Option<&Self::UdpProps<'_>> {
1114            None
1115        }
1116    }
1117
1118    impl TcpSocketProperties for FakeTcpSocket {
1119        fn src_port_matches(&self, matcher: &PortMatcher) -> bool {
1120            matcher.matches(&self.src_port)
1121        }
1122
1123        fn dst_port_matches(&self, matcher: &PortMatcher) -> bool {
1124            matcher.matches(&self.dst_port)
1125        }
1126
1127        fn state_matches(&self, matcher: &TcpStateMatcher) -> bool {
1128            matcher.matches(&self.state)
1129        }
1130    }
1131
1132    struct FakeUdpSocket {
1133        src_port: u16,
1134        dst_port: u16,
1135        state: UdpSocketState,
1136    }
1137
1138    impl MaybeSocketTransportProperties for FakeUdpSocket {
1139        type TcpProps<'a>
1140            = Never
1141        where
1142            Self: 'a;
1143
1144        type UdpProps<'a>
1145            = Self
1146        where
1147            Self: 'a;
1148
1149        fn tcp_socket_properties(&self) -> Option<&Self::TcpProps<'_>> {
1150            None
1151        }
1152
1153        fn udp_socket_properties(&self) -> Option<&Self::UdpProps<'_>> {
1154            Some(self)
1155        }
1156    }
1157
1158    impl UdpSocketProperties for FakeUdpSocket {
1159        fn src_port_matches(&self, matcher: &PortMatcher) -> bool {
1160            matcher.matches(&self.src_port)
1161        }
1162
1163        fn dst_port_matches(&self, matcher: &PortMatcher) -> bool {
1164            matcher.matches(&self.dst_port)
1165        }
1166
1167        fn state_matches(&self, matcher: &UdpStateMatcher) -> bool {
1168            matcher.matches(&self.state)
1169        }
1170    }
1171
1172    struct FakeIpSocket<I, T>
1173    where
1174        I: TestIpExt,
1175        T: MaybeSocketTransportProperties,
1176    {
1177        src_ip: I::Addr,
1178        dst_ip: I::Addr,
1179        proto: T,
1180        intf: Option<FakeMatcherDeviceId>,
1181        cookie: u64,
1182        marks: Marks,
1183    }
1184
1185    impl<I, T> MaybeSocketTransportProperties for FakeIpSocket<I, T>
1186    where
1187        I: TestIpExt,
1188        T: MaybeSocketTransportProperties,
1189    {
1190        type TcpProps<'a>
1191            = T::TcpProps<'a>
1192        where
1193            Self: 'a;
1194
1195        type UdpProps<'a>
1196            = T::UdpProps<'a>
1197        where
1198            Self: 'a;
1199
1200        fn tcp_socket_properties(&self) -> Option<&Self::TcpProps<'_>> {
1201            self.proto.tcp_socket_properties()
1202        }
1203
1204        fn udp_socket_properties(&self) -> Option<&Self::UdpProps<'_>> {
1205            self.proto.udp_socket_properties()
1206        }
1207    }
1208
1209    impl<I, T> IpSocketProperties<FakeDeviceClass> for FakeIpSocket<I, T>
1210    where
1211        I: TestIpExt,
1212        T: MaybeSocketTransportProperties,
1213    {
1214        fn family_matches(&self, family: &net_types::ip::IpVersion) -> bool {
1215            *family == I::VERSION
1216        }
1217
1218        fn src_addr_matches(&self, addr: &AddressMatcherEither) -> bool {
1219            addr.matches(&self.src_ip.into())
1220        }
1221
1222        fn dst_addr_matches(&self, addr: &AddressMatcherEither) -> bool {
1223            addr.matches(&self.dst_ip.into())
1224        }
1225
1226        fn transport_protocol_matches(&self, matcher: &SocketTransportProtocolMatcher) -> bool {
1227            matcher.matches(self)
1228        }
1229
1230        fn bound_interface_matches(&self, iface: &BoundInterfaceMatcher<FakeDeviceClass>) -> bool {
1231            iface.matches(&self.intf.as_ref())
1232        }
1233
1234        fn cookie_matches(&self, cookie: &SocketCookieMatcher) -> bool {
1235            cookie.matches(&self.cookie)
1236        }
1237
1238        fn mark_matches(&self, matcher: &MarkInDomainMatcher) -> bool {
1239            matcher.matcher.matches(self.marks.get(matcher.domain))
1240        }
1241    }
1242
1243    #[test_case(
1244        TcpSocketMatcher::Empty,
1245        FakeTcpSocket { src_port: 80, dst_port: 12345, state: TcpSocketState::Established } => true;
1246        "empty matcher"
1247    )]
1248    #[test_case(
1249        TcpSocketMatcher::SrcPort(PortMatcher { range: 80..=80, invert: false }),
1250        FakeTcpSocket { src_port: 80, dst_port: 12345, state: TcpSocketState::Established } => true;
1251        "src_port match"
1252    )]
1253    #[test_case(
1254        TcpSocketMatcher::SrcPort(PortMatcher { range: 80..=80, invert: false }),
1255        FakeTcpSocket { src_port: 81, dst_port: 12345, state: TcpSocketState::Established } => false;
1256        "src_port no match"
1257    )]
1258    #[test_case(
1259        TcpSocketMatcher::SrcPort(PortMatcher { range: 80..=80, invert: true }),
1260        FakeTcpSocket { src_port: 80, dst_port: 12345, state: TcpSocketState::Established } => false;
1261        "src_port invert no match"
1262    )]
1263    #[test_case(
1264        TcpSocketMatcher::SrcPort(PortMatcher { range: 80..=80, invert: true }),
1265        FakeTcpSocket { src_port: 81, dst_port: 12345, state: TcpSocketState::Established } => true;
1266        "src_port invert match"
1267    )]
1268    #[test_case(
1269        TcpSocketMatcher::DstPort(PortMatcher { range: 12345..=12345, invert: false }),
1270        FakeTcpSocket { src_port: 80, dst_port: 12345, state: TcpSocketState::Established } => true;
1271        "dst_port match"
1272    )]
1273    #[test_case(
1274        TcpSocketMatcher::DstPort(PortMatcher { range: 12345..=12345, invert: false }),
1275        FakeTcpSocket { src_port: 80, dst_port: 12346, state: TcpSocketState::Established } => false;
1276        "dst_port no match"
1277    )]
1278    #[test_case(
1279        TcpSocketMatcher::State(TcpStateMatcher::ESTABLISHED),
1280        FakeTcpSocket { src_port: 80, dst_port: 12345, state: TcpSocketState::Established } => true;
1281        "state match"
1282    )]
1283    #[test_case(
1284        TcpSocketMatcher::State(TcpStateMatcher::SYN_SENT),
1285        FakeTcpSocket { src_port: 80, dst_port: 12345, state: TcpSocketState::Established } => false;
1286        "state no match"
1287    )]
1288    #[test_case(
1289        TcpSocketMatcher::State(TcpStateMatcher::ESTABLISHED | TcpStateMatcher::SYN_SENT),
1290        FakeTcpSocket { src_port: 80, dst_port: 12345, state: TcpSocketState::Established } => true;
1291        "state multi match established"
1292    )]
1293    #[test_case(
1294        TcpSocketMatcher::State(TcpStateMatcher::ESTABLISHED | TcpStateMatcher::SYN_SENT),
1295        FakeTcpSocket { src_port: 80, dst_port: 12345, state: TcpSocketState::SynSent } => true;
1296        "state multi match syn_sent"
1297    )]
1298    #[test_case(
1299        TcpSocketMatcher::State(TcpStateMatcher::ESTABLISHED | TcpStateMatcher::SYN_SENT),
1300        FakeTcpSocket { src_port: 80, dst_port: 12345, state: TcpSocketState::FinWait1 } => false;
1301        "state multi no match"
1302    )]
1303    fn tcp_socket_matcher(matcher: TcpSocketMatcher, socket: FakeTcpSocket) -> bool {
1304        matcher.matches(&socket)
1305    }
1306
1307    #[test_case(
1308        UdpSocketMatcher::Empty,
1309        FakeUdpSocket { src_port: 53, dst_port: 12345, state: UdpSocketState::Bound } => true;
1310        "empty matcher"
1311    )]
1312    #[test_case(
1313        UdpSocketMatcher::SrcPort(PortMatcher { range: 53..=53, invert: false }),
1314        FakeUdpSocket { src_port: 53, dst_port: 12345, state: UdpSocketState::Bound } => true;
1315        "src_port match"
1316    )]
1317    #[test_case(
1318        UdpSocketMatcher::SrcPort(PortMatcher { range: 53..=53, invert: false }),
1319        FakeUdpSocket { src_port: 54, dst_port: 12345, state: UdpSocketState::Bound } => false;
1320        "src_port no match"
1321    )]
1322    #[test_case(
1323        UdpSocketMatcher::DstPort(PortMatcher { range: 12345..=12345, invert: false }),
1324        FakeUdpSocket { src_port: 53, dst_port: 12345, state: UdpSocketState::Bound } => true;
1325        "dst_port match"
1326    )]
1327    #[test_case(
1328        UdpSocketMatcher::DstPort(PortMatcher { range: 12345..=12345, invert: false }),
1329        FakeUdpSocket { src_port: 53, dst_port: 12346, state: UdpSocketState::Bound } => false;
1330        "dst_port no match"
1331    )]
1332    #[test_case(
1333        UdpSocketMatcher::State(UdpStateMatcher::BOUND),
1334        FakeUdpSocket { src_port: 53, dst_port: 12345, state: UdpSocketState::Bound } => true;
1335        "state match bound"
1336    )]
1337    #[test_case(
1338        UdpSocketMatcher::State(UdpStateMatcher::CONNECTED),
1339        FakeUdpSocket { src_port: 53, dst_port: 12345, state: UdpSocketState::Bound } => false;
1340        "state no match connected"
1341    )]
1342    #[test_case(
1343        UdpSocketMatcher::State(UdpStateMatcher::BOUND | UdpStateMatcher::CONNECTED),
1344        FakeUdpSocket { src_port: 53, dst_port: 12345, state: UdpSocketState::Bound } => true;
1345        "state multi match bound"
1346    )]
1347    #[test_case(
1348        UdpSocketMatcher::State(UdpStateMatcher::BOUND | UdpStateMatcher::CONNECTED),
1349        FakeUdpSocket { src_port: 53, dst_port: 12345, state: UdpSocketState::Connected } => true;
1350        "state multi match connected"
1351    )]
1352    fn udp_socket_matcher(matcher: UdpSocketMatcher, socket: FakeUdpSocket) -> bool {
1353        matcher.matches(&socket)
1354    }
1355
1356    #[ip_test(I)]
1357    #[test_case(
1358        IpSocketMatcher::Proto(SocketTransportProtocolMatcher::Tcp(TcpSocketMatcher::Empty)),
1359        FakeIpSocket {
1360            src_ip: <I as TestIpExt>::TEST_ADDRS.local_ip.get(),
1361            dst_ip: <I as TestIpExt>::TEST_ADDRS.remote_ip.get(),
1362            proto: FakeTcpSocket { src_port: 80, dst_port: 12345, state: TcpSocketState::Established },
1363            cookie: 0,
1364            intf: None,
1365            marks: Marks::default(),
1366        } => true;
1367        "tcp empty"
1368    )]
1369    #[test_case(
1370        IpSocketMatcher::Proto(SocketTransportProtocolMatcher::Tcp(TcpSocketMatcher::Empty)),
1371        FakeIpSocket {
1372            src_ip: <I as TestIpExt>::TEST_ADDRS.local_ip.get(),
1373            dst_ip: <I as TestIpExt>::TEST_ADDRS.remote_ip.get(),
1374            proto: FakeUdpSocket { src_port: 53, dst_port: 12345, state: UdpSocketState::Bound },
1375            cookie: 0,
1376            intf: None,
1377            marks: Marks::default(),
1378        } => false;
1379        "tcp empty no match udp"
1380    )]
1381    #[test_case(
1382        IpSocketMatcher::Proto(SocketTransportProtocolMatcher::Udp(UdpSocketMatcher::Empty)),
1383        FakeIpSocket {
1384            src_ip: <I as TestIpExt>::TEST_ADDRS.local_ip.get(),
1385            dst_ip: <I as TestIpExt>::TEST_ADDRS.remote_ip.get(),
1386            proto: FakeTcpSocket { src_port: 80, dst_port: 12345, state: TcpSocketState::Established },
1387            cookie: 0,
1388            intf: None,
1389            marks: Marks::default(),
1390        } => false;
1391        "udp empty no match tcp"
1392    )]
1393    #[test_case(
1394        IpSocketMatcher::Proto(SocketTransportProtocolMatcher::Udp(UdpSocketMatcher::Empty)),
1395        FakeIpSocket {
1396            src_ip: <I as TestIpExt>::TEST_ADDRS.local_ip.get(),
1397            dst_ip: <I as TestIpExt>::TEST_ADDRS.remote_ip.get(),
1398            proto: FakeUdpSocket { src_port: 53, dst_port: 12345, state: UdpSocketState::Bound },
1399            cookie: 0,
1400            intf: None,
1401            marks: Marks::default(),
1402        } => true;
1403        "udp empty"
1404    )]
1405    #[test_case(
1406        IpSocketMatcher::Proto(
1407            SocketTransportProtocolMatcher::Tcp(
1408                TcpSocketMatcher::SrcPort(PortMatcher { range: 80..=80, invert: false })
1409            )
1410        ),
1411        FakeIpSocket {
1412            src_ip: <I as TestIpExt>::TEST_ADDRS.local_ip.get(),
1413            dst_ip: <I as TestIpExt>::TEST_ADDRS.remote_ip.get(),
1414            proto: FakeTcpSocket { src_port: 80, dst_port: 12345, state: TcpSocketState::Established },
1415            cookie: 0,
1416            intf: None,
1417            marks: Marks::default(),
1418        } => true;
1419        "tcp src_port match"
1420    )]
1421    #[test_case(
1422        IpSocketMatcher::Proto(
1423            SocketTransportProtocolMatcher::Tcp(
1424                TcpSocketMatcher::SrcPort(PortMatcher { range: 80..=80, invert: false })
1425            )
1426        ),
1427        FakeIpSocket {
1428            src_ip: <I as TestIpExt>::TEST_ADDRS.local_ip.get(),
1429            dst_ip: <I as TestIpExt>::TEST_ADDRS.remote_ip.get(),
1430            proto: FakeTcpSocket { src_port: 81, dst_port: 12345, state: TcpSocketState::Established },
1431            cookie: 0,
1432            intf: None,
1433            marks: Marks::default(),
1434        } => false;
1435        "tcp src_port no match"
1436    )]
1437    #[test_case(
1438        IpSocketMatcher::Proto(
1439            SocketTransportProtocolMatcher::Udp(
1440                UdpSocketMatcher::SrcPort(PortMatcher { range: 53..=53, invert: false })
1441            )
1442        ),
1443        FakeIpSocket {
1444            src_ip: <I as TestIpExt>::TEST_ADDRS.local_ip.get(),
1445            dst_ip: <I as TestIpExt>::TEST_ADDRS.remote_ip.get(),
1446            proto: FakeUdpSocket { src_port: 53, dst_port: 12345, state: UdpSocketState::Bound },
1447            cookie: 0,
1448            intf: None,
1449            marks: Marks::default(),
1450        } => true;
1451        "udp src_port match"
1452    )]
1453    #[test_case(
1454        IpSocketMatcher::Proto(
1455            SocketTransportProtocolMatcher::Udp(
1456                UdpSocketMatcher::SrcPort(PortMatcher { range: 53..=53, invert: false })
1457            )
1458        ),
1459        FakeIpSocket {
1460            src_ip: <I as TestIpExt>::TEST_ADDRS.local_ip.get(),
1461            dst_ip: <I as TestIpExt>::TEST_ADDRS.remote_ip.get(),
1462            proto: FakeUdpSocket { src_port: 54, dst_port: 12345, state: UdpSocketState::Bound },
1463            cookie: 0,
1464            intf: None,
1465            marks: Marks::default(),
1466        } => false;
1467        "udp src_port no match"
1468    )]
1469    #[test_case(
1470        IpSocketMatcher::Cookie(SocketCookieMatcher { cookie: 123, invert: false }),
1471        FakeIpSocket {
1472            src_ip: <I as TestIpExt>::TEST_ADDRS.local_ip.get(),
1473            dst_ip: <I as TestIpExt>::TEST_ADDRS.remote_ip.get(),
1474            proto: FakeTcpSocket { src_port: 80, dst_port: 12345, state: TcpSocketState::Established },
1475            cookie: 123,
1476            intf: None,
1477            marks: Marks::default(),
1478        } => true;
1479        "cookie match"
1480    )]
1481    #[test_case(
1482        IpSocketMatcher::Cookie(SocketCookieMatcher { cookie: 123, invert: false }),
1483        FakeIpSocket {
1484            src_ip: <I as TestIpExt>::TEST_ADDRS.local_ip.get(),
1485            dst_ip: <I as TestIpExt>::TEST_ADDRS.remote_ip.get(),
1486            proto: FakeTcpSocket { src_port: 80, dst_port: 12345, state: TcpSocketState::Established },
1487            cookie: 456,
1488            intf: None,
1489            marks: Marks::default(),
1490        } => false;
1491        "cookie no match"
1492    )]
1493    #[test_case(
1494        IpSocketMatcher::Mark(MarkInDomainMatcher {
1495            domain: MarkDomain::Mark1,
1496            matcher: MarkMatcher::Unmarked,
1497        }),
1498        FakeIpSocket {
1499            src_ip: <I as TestIpExt>::TEST_ADDRS.local_ip.get(),
1500            dst_ip: <I as TestIpExt>::TEST_ADDRS.remote_ip.get(),
1501            proto: FakeTcpSocket { src_port: 80, dst_port: 12345, state: TcpSocketState::Established },
1502            cookie: 0,
1503            intf: None,
1504            marks: Marks::default(),
1505        } => true;
1506        "mark1 unmarked match"
1507    )]
1508    #[test_case(
1509        IpSocketMatcher::Mark(MarkInDomainMatcher {
1510            domain: MarkDomain::Mark1,
1511            matcher: MarkMatcher::Unmarked,
1512        }),
1513        FakeIpSocket {
1514            src_ip: <I as TestIpExt>::TEST_ADDRS.local_ip.get(),
1515            dst_ip: <I as TestIpExt>::TEST_ADDRS.remote_ip.get(),
1516            proto: FakeTcpSocket { src_port: 80, dst_port: 12345, state: TcpSocketState::Established },
1517            cookie: 0,
1518            intf: None,
1519            marks: Marks::new([(MarkDomain::Mark1, 1)]),
1520        } => false;
1521        "mark1 unmarked no match"
1522    )]
1523    #[test_case(
1524        IpSocketMatcher::Mark(MarkInDomainMatcher {
1525            domain: MarkDomain::Mark2,
1526            matcher: MarkMatcher::Unmarked,
1527        }),
1528        FakeIpSocket {
1529            src_ip: <I as TestIpExt>::TEST_ADDRS.local_ip.get(),
1530            dst_ip: <I as TestIpExt>::TEST_ADDRS.remote_ip.get(),
1531            proto: FakeTcpSocket { src_port: 80, dst_port: 12345, state: TcpSocketState::Established },
1532            cookie: 0,
1533            intf: None,
1534            marks: Marks::default(),
1535        } => true;
1536        "mark2 unmarked match"
1537    )]
1538    #[test_case(
1539        IpSocketMatcher::Mark(MarkInDomainMatcher {
1540            domain: MarkDomain::Mark2,
1541            matcher: MarkMatcher::Unmarked,
1542        }),
1543        FakeIpSocket {
1544            src_ip: <I as TestIpExt>::TEST_ADDRS.local_ip.get(),
1545            dst_ip: <I as TestIpExt>::TEST_ADDRS.remote_ip.get(),
1546            proto: FakeTcpSocket { src_port: 80, dst_port: 12345, state: TcpSocketState::Established },
1547            cookie: 0,
1548            intf: None,
1549            marks: Marks::new([(MarkDomain::Mark2, 1)]),
1550        } => false;
1551        "mark2 unmarked no match"
1552    )]
1553    #[test_case(
1554        IpSocketMatcher::BoundInterface(BoundInterfaceMatcher::Bound(
1555            InterfaceMatcher::Id(FakeMatcherDeviceId::wlan_interface().id)
1556        )),
1557        FakeIpSocket {
1558            src_ip: <I as TestIpExt>::TEST_ADDRS.local_ip.get(),
1559            dst_ip: <I as TestIpExt>::TEST_ADDRS.remote_ip.get(),
1560            proto: FakeTcpSocket { src_port: 80, dst_port: 12345, state: TcpSocketState::Established },
1561            cookie: 0,
1562            intf: Some(FakeMatcherDeviceId::wlan_interface()),
1563            marks: Marks::default(),
1564        } => true;
1565        "bound_interface match"
1566    )]
1567    #[test_case(
1568        IpSocketMatcher::BoundInterface(BoundInterfaceMatcher::Bound(
1569            InterfaceMatcher::Id(FakeMatcherDeviceId::wlan_interface().id)
1570        )),
1571        FakeIpSocket {
1572            src_ip: <I as TestIpExt>::TEST_ADDRS.local_ip.get(),
1573            dst_ip: <I as TestIpExt>::TEST_ADDRS.remote_ip.get(),
1574            proto: FakeTcpSocket { src_port: 80, dst_port: 12345, state: TcpSocketState::Established },
1575            cookie: 0,
1576            intf: Some(FakeMatcherDeviceId::ethernet_interface()),
1577            marks: Marks::default(),
1578        } => false;
1579        "bound_interface no match"
1580    )]
1581    fn ip_socket_matcher<I: TestIpExt, T: MaybeSocketTransportProperties>(
1582        matcher: IpSocketMatcher<FakeDeviceClass>,
1583        socket: FakeIpSocket<I, T>,
1584    ) -> bool {
1585        matcher.matches(&socket)
1586    }
1587
1588    #[ip_test(I)]
1589    fn address_matcher_type<I: TestIpExt>() {
1590        let local_ip = I::TEST_ADDRS.local_ip.get();
1591        let remote_ip = I::TEST_ADDRS.remote_ip.get();
1592
1593        let matcher = AddressMatcherType::Subnet(SubnetMatcher(I::TEST_ADDRS.subnet));
1594        assert!(matcher.matches(&local_ip));
1595        assert!(!matcher.matches(&I::get_other_remote_ip_address(1)));
1596
1597        let matcher = AddressMatcherType::Range(local_ip..=remote_ip);
1598        assert!(matcher.matches(&local_ip));
1599        assert!(matcher.matches(&remote_ip));
1600        assert!(!matcher.matches(&I::get_other_remote_ip_address(1)));
1601    }
1602
1603    #[ip_test(I)]
1604    fn address_matcher<I: TestIpExt>() {
1605        let local_ip = I::TEST_ADDRS.local_ip.get();
1606        let remote_ip = I::TEST_ADDRS.remote_ip.get();
1607
1608        let matcher = AddressMatcher {
1609            matcher: AddressMatcherType::Subnet(SubnetMatcher(I::TEST_ADDRS.subnet)),
1610            invert: false,
1611        };
1612        assert!(matcher.matches(&local_ip));
1613        assert!(matcher.matches(&remote_ip));
1614        assert!(!matcher.matches(&I::get_other_remote_ip_address(1)));
1615
1616        let matcher = AddressMatcher {
1617            matcher: AddressMatcherType::Subnet(SubnetMatcher(I::TEST_ADDRS.subnet)),
1618            invert: true,
1619        };
1620        assert!(!matcher.matches(&local_ip));
1621        assert!(!matcher.matches(&remote_ip));
1622        assert!(matcher.matches(&I::get_other_remote_ip_address(1)));
1623
1624        let matcher = AddressMatcher {
1625            matcher: AddressMatcherType::Range(local_ip..=remote_ip),
1626            invert: false,
1627        };
1628        assert!(matcher.matches(&local_ip));
1629        assert!(matcher.matches(&remote_ip));
1630        assert!(!matcher.matches(&I::get_other_remote_ip_address(1)));
1631
1632        let matcher = AddressMatcher {
1633            matcher: AddressMatcherType::Range(local_ip..=remote_ip),
1634            invert: true,
1635        };
1636        assert!(!matcher.matches(&local_ip));
1637        assert!(!matcher.matches(&remote_ip));
1638        assert!(matcher.matches(&I::get_other_remote_ip_address(1)));
1639    }
1640
1641    #[test]
1642    fn agnostic_address_matcher() {
1643        let v4_addr = IpAddr::V4(Ipv4Addr::new([192, 0, 2, 1]));
1644        let v6_addr = IpAddr::V6(Ipv6Addr::new([0x2001, 0xdb8, 0, 0, 0, 0, 0, 1]));
1645
1646        let v4_subnet = Subnet::new(Ipv4Addr::new([192, 0, 2, 0]), 24).unwrap();
1647        let v6_subnet = Subnet::new(Ipv6Addr::new([0x2001, 0xdb8, 0, 0, 0, 0, 0, 0]), 32).unwrap();
1648
1649        let v4_matcher = AddressMatcherEither::V4(AddressMatcher {
1650            matcher: AddressMatcherType::Subnet(SubnetMatcher(v4_subnet)),
1651            invert: false,
1652        });
1653        assert!(v4_matcher.matches(&v4_addr));
1654        assert!(!v4_matcher.matches(&v6_addr));
1655
1656        let v6_matcher = AddressMatcherEither::V6(AddressMatcher {
1657            matcher: AddressMatcherType::Subnet(SubnetMatcher(v6_subnet)),
1658            invert: false,
1659        });
1660        assert!(!v6_matcher.matches(&v4_addr));
1661        assert!(v6_matcher.matches(&v6_addr));
1662    }
1663
1664    #[test_case(IpSocketMatcher::Family(IpVersion::V4) => true; "v4 family matcher on v4 socket")]
1665    #[test_case(IpSocketMatcher::Family(IpVersion::V6) => false; "v6 family matcher on v4 socket")]
1666    #[test_case(IpSocketMatcher::SrcAddr(AddressMatcherEither::V4(AddressMatcher {
1667        matcher: AddressMatcherType::Subnet(SubnetMatcher(Ipv4::TEST_ADDRS.subnet)),
1668        invert: false,
1669    })) => true; "src_addr match")]
1670    #[test_case(IpSocketMatcher::SrcAddr(AddressMatcherEither::V4(AddressMatcher {
1671        matcher: AddressMatcherType::Subnet(SubnetMatcher(Subnet::new(Ipv4Addr::new([0, 0, 0, 0]), 32).unwrap())),
1672        invert: false,
1673    })) => false; "src_addr no match")]
1674    #[test_case(IpSocketMatcher::DstAddr(AddressMatcherEither::V4(AddressMatcher {
1675        matcher: AddressMatcherType::Subnet(SubnetMatcher(Ipv4::TEST_ADDRS.subnet)),
1676        invert: false,
1677    })) => true; "dst_addr match")]
1678    #[test_case(IpSocketMatcher::DstAddr(AddressMatcherEither::V4(AddressMatcher {
1679        matcher: AddressMatcherType::Subnet(SubnetMatcher(Subnet::new(Ipv4Addr::new([0, 0, 0, 0]), 32).unwrap())),
1680        invert: false,
1681    })) => false; "dst_addr no match")]
1682    fn ip_socket_matcher_test_v4(matcher: IpSocketMatcher<FakeDeviceClass>) -> bool {
1683        let socket = FakeIpSocket::<Ipv4, _> {
1684            src_ip: <Ipv4 as TestIpExt>::TEST_ADDRS.local_ip.get(),
1685            dst_ip: <Ipv4 as TestIpExt>::TEST_ADDRS.remote_ip.get(),
1686            proto: FakeTcpSocket {
1687                src_port: 80,
1688                dst_port: 12345,
1689                state: TcpSocketState::Established,
1690            },
1691            cookie: 0,
1692            intf: None,
1693            marks: Marks::default(),
1694        };
1695        matcher.matches(&socket)
1696    }
1697
1698    #[test_case(IpSocketMatcher::Family(IpVersion::V4) => false; "v4 family matcher on v6 socket")]
1699    #[test_case(IpSocketMatcher::Family(IpVersion::V6) => true; "v6 family matcher on v6 socket")]
1700    #[test_case(IpSocketMatcher::SrcAddr(AddressMatcherEither::V6(AddressMatcher {
1701        matcher: AddressMatcherType::Subnet(SubnetMatcher(Ipv6::TEST_ADDRS.subnet)),
1702        invert: false,
1703    })) => true; "src_addr match v6")]
1704    #[test_case(IpSocketMatcher::SrcAddr(AddressMatcherEither::V6(AddressMatcher {
1705        matcher: AddressMatcherType::Subnet(SubnetMatcher(Subnet::new(Ipv6Addr::new([0; 8]), 128).unwrap())),
1706        invert: false,
1707    })) => false; "src_addr no match v6")]
1708    #[test_case(IpSocketMatcher::DstAddr(AddressMatcherEither::V6(AddressMatcher {
1709        matcher: AddressMatcherType::Subnet(SubnetMatcher(Ipv6::TEST_ADDRS.subnet)),
1710        invert: false,
1711    })) => true; "dst_addr match v6")]
1712    #[test_case(IpSocketMatcher::DstAddr(AddressMatcherEither::V6(AddressMatcher {
1713        matcher: AddressMatcherType::Subnet(SubnetMatcher(Subnet::new(Ipv6Addr::new([0; 8]), 128).unwrap())),
1714        invert: false,
1715    })) => false; "dst_addr no match v6")]
1716    fn ip_socket_matcher_test_v6(matcher: IpSocketMatcher<FakeDeviceClass>) -> bool {
1717        let socket = FakeIpSocket::<Ipv6, _> {
1718            src_ip: <Ipv6 as TestIpExt>::TEST_ADDRS.local_ip.get(),
1719            dst_ip: <Ipv6 as TestIpExt>::TEST_ADDRS.remote_ip.get(),
1720            proto: FakeTcpSocket {
1721                src_port: 80,
1722                dst_port: 12345,
1723                state: TcpSocketState::Established,
1724            },
1725            cookie: 0,
1726            intf: None,
1727            marks: Marks::default(),
1728        };
1729        matcher.matches(&socket)
1730    }
1731}