netstack3_base/
matchers.rs

1// Copyright 2024 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5//! Trait definition for matchers.
6
7use alloc::format;
8use alloc::string::String;
9use core::convert::Infallible as Never;
10use core::fmt::Debug;
11use core::num::NonZeroU64;
12use core::ops::RangeInclusive;
13
14use bitflags::bitflags;
15use derivative::Derivative;
16use net_types::ip::{IpAddr, IpAddress, Ipv4Addr, Ipv6Addr, Subnet};
17
18use crate::{InspectableValue, Inspector, Mark, MarkDomain, MarkStorage, Marks};
19
/// Trait defining required types for matchers provided by bindings.
///
/// Allows rules that match on device class to be installed, storing the
/// [`MatcherBindingsTypes::DeviceClass`] type at rest, while allowing Netstack3
/// Core to have Bindings provide the type since it is platform-specific.
pub trait MatcherBindingsTypes {
    /// The device class type for devices installed in the netstack.
    ///
    /// It must be cloneable and debug-printable.
    type DeviceClass: Clone + Debug;
}
29
/// A predicate over a metadata input of type `T`.
///
/// This is the shared building block of matching engines such as filtering
/// and routing rules.
pub trait Matcher<T> {
    /// Returns `true` if `actual` satisfies this matcher.
    fn matches(&self, actual: &T) -> bool;

    /// Returns `true` only if a value is present *and* satisfies this matcher.
    ///
    /// An absent value (`None`) never matches.
    fn required_matches(&self, actual: Option<&T>) -> bool {
        match actual {
            Some(value) => self.matches(value),
            None => false,
        }
    }
}
42
43/// Implement `Matcher` for optional matchers, so that if a matcher is left
44/// unspecified, it matches all inputs by default.
45impl<T, O> Matcher<T> for Option<O>
46where
47    O: Matcher<T>,
48{
49    fn matches(&self, actual: &T) -> bool {
50        self.as_ref().map_or(true, |expected| expected.matches(actual))
51    }
52
53    fn required_matches(&self, actual: Option<&T>) -> bool {
54        self.as_ref().map_or(true, |expected| expected.required_matches(actual))
55    }
56}
57
58/// Matcher that matches IP addresses in a subnet.
59#[derive(Debug, Copy, Clone, PartialEq, Eq)]
60pub struct SubnetMatcher<A: IpAddress>(pub Subnet<A>);
61
62impl<A: IpAddress> Matcher<A> for SubnetMatcher<A> {
63    fn matches(&self, actual: &A) -> bool {
64        let Self(matcher) = self;
65        matcher.contains(actual)
66    }
67}
68
/// A matcher for network interfaces.
///
/// Each variant selects interfaces by exactly one property; the comparison
/// itself is delegated to the interface through [`InterfaceProperties`].
#[derive(Clone, Derivative, PartialEq, Eq)]
#[derivative(Debug)]
pub enum InterfaceMatcher<DeviceClass> {
    /// Match based on the ID of the interface as assigned by the netstack.
    Id(NonZeroU64),
    /// Match based on the name of the interface.
    Name(String),
    /// Match based on the device class of the interface.
    DeviceClass(DeviceClass),
}
80
81impl<DeviceClass: Debug> InspectableValue for InterfaceMatcher<DeviceClass> {
82    fn record<I: Inspector>(&self, name: &str, inspector: &mut I) {
83        match self {
84            InterfaceMatcher::Id(id) => inspector.record_string(name, format!("Id({})", id.get())),
85            InterfaceMatcher::Name(iface_name) => {
86                inspector.record_string(name, format!("Name({iface_name})"))
87            }
88            InterfaceMatcher::DeviceClass(class) => {
89                inspector.record_debug(name, format!("Class({class:?})"))
90            }
91        };
92    }
93}
94
/// Allows code to match on properties of an interface (ID, name, and device
/// class) without Netstack3 Core (or Bindings, in the case of the device class)
/// having to specifically expose that state.
///
/// [`InterfaceMatcher`] calls exactly one of these methods, depending on which
/// property it was configured to match.
pub trait InterfaceProperties<DeviceClass> {
    /// Returns whether the provided ID matches the interface.
    fn id_matches(&self, id: &NonZeroU64) -> bool;

    /// Returns whether the provided name matches the interface.
    fn name_matches(&self, name: &str) -> bool;

    /// Returns whether the provided device class matches the interface.
    fn device_class_matches(&self, device_class: &DeviceClass) -> bool;
}
108
109impl<DeviceClass, I: InterfaceProperties<DeviceClass>> Matcher<I>
110    for InterfaceMatcher<DeviceClass>
111{
112    fn matches(&self, actual: &I) -> bool {
113        match self {
114            InterfaceMatcher::Id(id) => actual.id_matches(id),
115            InterfaceMatcher::Name(name) => actual.name_matches(name),
116            InterfaceMatcher::DeviceClass(device_class) => {
117                actual.device_class_matches(device_class)
118            }
119        }
120    }
121}
122
/// Matcher for the bound device of locally generated traffic.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum BoundInterfaceMatcher<DeviceClass> {
    /// The packet is bound to a device which is matched by the matcher.
    ///
    /// Traffic with no bound device never matches this variant.
    Bound(InterfaceMatcher<DeviceClass>),
    /// There is no bound device.
    Unbound,
}
131
132impl<'a, DeviceClass, D: InterfaceProperties<DeviceClass>> Matcher<Option<&'a D>>
133    for BoundInterfaceMatcher<DeviceClass>
134{
135    fn matches(&self, actual: &Option<&'a D>) -> bool {
136        match self {
137            BoundInterfaceMatcher::Bound(matcher) => matcher.required_matches(actual.as_deref()),
138            BoundInterfaceMatcher::Unbound => actual.is_none(),
139        }
140    }
141}
142
143impl<DeviceClass: Debug> InspectableValue for BoundInterfaceMatcher<DeviceClass> {
144    fn record<I: Inspector>(&self, name: &str, inspector: &mut I) {
145        match self {
146            BoundInterfaceMatcher::Unbound => inspector.record_str(name, "Unbound"),
147            BoundInterfaceMatcher::Bound(interface) => {
148                inspector.record_inspectable_value(name, interface)
149            }
150        }
151    }
152}
153
/// A matcher to the socket mark.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MarkMatcher {
    /// Matches a packet if it is unmarked.
    Unmarked,
    /// The packet carries a mark that is in the range after masking.
    ///
    /// The mark is first ANDed with `mask`; the result must lie within
    /// `start..=end`. An unmarked packet is never in range, so it matches
    /// only when `invert` is set.
    Marked {
        /// The mask to apply.
        mask: u32,
        /// Start of the range, inclusive.
        start: u32,
        /// End of the range, inclusive.
        end: u32,
        /// Inverts the meaning of the match.
        invert: bool,
    },
}
171
172impl Matcher<Mark> for MarkMatcher {
173    fn matches(&self, Mark(actual): &Mark) -> bool {
174        match self {
175            MarkMatcher::Unmarked => actual.is_none(),
176            MarkMatcher::Marked { mask, start, end, invert } => {
177                let val = actual.is_some_and(|actual| (*start..=*end).contains(&(actual & *mask)));
178
179                if *invert { !val } else { val }
180            }
181        }
182    }
183}
184
185/// The 2 mark matchers a rule can specify. All non-none markers must match.
186#[derive(Default, Debug, Clone, Copy, PartialEq, Eq)]
187pub struct MarkMatchers(MarkStorage<Option<MarkMatcher>>);
188
189impl MarkMatchers {
190    /// Creates [`MarkMatcher`]s from an iterator of `(MarkDomain, MarkMatcher)`.
191    ///
192    /// An unspecified domain will not have a matcher.
193    ///
194    /// # Panics
195    ///
196    /// Panics if the same domain is specified more than once.
197    pub fn new(matchers: impl IntoIterator<Item = (MarkDomain, MarkMatcher)>) -> Self {
198        MarkMatchers(MarkStorage::new(matchers))
199    }
200
201    /// Returns an iterator over the mark matchers of all domains.
202    pub fn iter(&self) -> impl Iterator<Item = (MarkDomain, &Option<MarkMatcher>)> {
203        let Self(storage) = self;
204        storage.iter()
205    }
206}
207
208impl Matcher<Marks> for MarkMatchers {
209    fn matches(&self, actual: &Marks) -> bool {
210        let Self(matchers) = self;
211        matchers.zip_with(actual).all(|(_domain, matcher, actual)| matcher.matches(actual))
212    }
213}
214
/// A matcher for a socket's cookie.
///
/// Derives the same common traits as [`MarkMatcher`] so that it can be
/// logged, copied and compared like the other plain-data matchers.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct SocketCookieMatcher {
    /// The cookie to check against.
    pub cookie: u64,
    /// Invert the matching criterion (i.e. if the socket cookie isn't the same,
    /// it matches).
    pub invert: bool,
}
223
224impl Matcher<u64> for SocketCookieMatcher {
225    fn matches(&self, actual: &u64) -> bool {
226        let val = *actual == self.cookie;
227        if self.invert { !val } else { val }
228    }
229}
230
231/// A matcher for transport-layer port numbers.
232#[derive(Clone, Debug)]
233pub struct PortMatcher {
234    /// The range of port numbers in which the tested port number must fall.
235    pub range: RangeInclusive<u16>,
236    /// Whether to check for an "inverse" or "negative" match (in which case,
237    /// if the matcher criteria do *not* apply, it *is* considered a match, and
238    /// vice versa).
239    pub invert: bool,
240}
241
242impl Matcher<u16> for PortMatcher {
243    fn matches(&self, actual: &u16) -> bool {
244        let Self { range, invert } = self;
245        range.contains(actual) ^ *invert
246    }
247}
248
bitflags! {
    /// A matcher for TCP state machine state.
    ///
    /// Each flag stands for one TCP state. A matcher accepts a socket when the
    /// flag for the socket's current state is set, so combining flags matches
    /// any of the selected states.
    #[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash)]
    pub struct TcpStateMatcher: u32 {
        /// The TCP ESTABLISHED state.
        const ESTABLISHED = 1 << 0;
        /// The TCP SYN_SENT state.
        const SYN_SENT = 1 << 1;
        /// The TCP SYN_RECV state.
        const SYN_RECV = 1 << 2;
        /// The TCP FIN_WAIT1 state.
        const FIN_WAIT1 = 1 << 3;
        /// The TCP FIN_WAIT2 state.
        const FIN_WAIT2 = 1 << 4;
        /// The TCP TIME_WAIT state.
        const TIME_WAIT = 1 << 5;
        /// The TCP CLOSE state.
        const CLOSE = 1 << 6;
        /// The TCP CLOSE_WAIT state.
        const CLOSE_WAIT = 1 << 7;
        /// The TCP LAST_ACK state.
        const LAST_ACK = 1 << 8;
        /// The TCP LISTEN state.
        const LISTEN = 1 << 9;
        /// The TCP CLOSING state.
        const CLOSING = 1 << 10;
    }
}
277
278impl Matcher<TcpSocketState> for TcpStateMatcher {
279    fn matches(&self, actual: &TcpSocketState) -> bool {
280        self.contains(actual.matcher_flag())
281    }
282}
283
/// Represents the state of a TCP socket's state machine.
///
/// Every variant corresponds to the like-named flag of [`TcpStateMatcher`].
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
#[allow(missing_docs)]
pub enum TcpSocketState {
    Established,
    SynSent,
    SynRecv,
    FinWait1,
    FinWait2,
    TimeWait,
    Close,
    CloseWait,
    LastAck,
    Listen,
    Closing,
}
300
301impl TcpSocketState {
302    fn matcher_flag(&self) -> TcpStateMatcher {
303        match self {
304            TcpSocketState::Established => TcpStateMatcher::ESTABLISHED,
305            TcpSocketState::SynSent => TcpStateMatcher::SYN_SENT,
306            TcpSocketState::SynRecv => TcpStateMatcher::SYN_RECV,
307            TcpSocketState::FinWait1 => TcpStateMatcher::FIN_WAIT1,
308            TcpSocketState::FinWait2 => TcpStateMatcher::FIN_WAIT2,
309            TcpSocketState::TimeWait => TcpStateMatcher::TIME_WAIT,
310            TcpSocketState::Close => TcpStateMatcher::CLOSE,
311            TcpSocketState::CloseWait => TcpStateMatcher::CLOSE_WAIT,
312            TcpSocketState::LastAck => TcpStateMatcher::LAST_ACK,
313            TcpSocketState::Listen => TcpStateMatcher::LISTEN,
314            TcpSocketState::Closing => TcpStateMatcher::CLOSING,
315        }
316    }
317}
318
/// Allows code to match on properties of a TCP socket without Netstack3 Core
/// having to specifically expose that state.
///
/// [`TcpSocketMatcher`] calls exactly one of these methods, depending on which
/// property it was configured to match.
pub trait TcpSocketProperties {
    /// Returns whether the socket's source port is matched by the matcher.
    fn src_port_matches(&self, matcher: &PortMatcher) -> bool;

    /// Returns whether the socket's destination port is matched by the matcher.
    fn dst_port_matches(&self, matcher: &PortMatcher) -> bool;

    /// Returns whether the socket's TCP state is matched by the matcher.
    fn state_matches(&self, matcher: &TcpStateMatcher) -> bool;
}
331
332impl TcpSocketProperties for Never {
333    fn src_port_matches(&self, _matcher: &PortMatcher) -> bool {
334        unimplemented!()
335    }
336
337    fn dst_port_matches(&self, _matcher: &PortMatcher) -> bool {
338        unimplemented!()
339    }
340
341    fn state_matches(&self, _matcher: &TcpStateMatcher) -> bool {
342        unimplemented!()
343    }
344}
345
346impl<T> TcpSocketProperties for &T
347where
348    T: TcpSocketProperties,
349{
350    fn src_port_matches(&self, matcher: &PortMatcher) -> bool {
351        (*self).src_port_matches(matcher)
352    }
353
354    fn dst_port_matches(&self, matcher: &PortMatcher) -> bool {
355        (*self).dst_port_matches(matcher)
356    }
357
358    fn state_matches(&self, matcher: &TcpStateMatcher) -> bool {
359        (*self).state_matches(matcher)
360    }
361}
362
363/// The top-level matcher for TCP sockets.
364pub enum TcpSocketMatcher {
365    /// Match any TCP socket without further constraints.
366    Empty,
367    /// Match on the source port.
368    SrcPort(PortMatcher),
369    /// Match on the destination port.
370    DstPort(PortMatcher),
371    /// Match on the state of the TCP state machine.
372    State(TcpStateMatcher),
373}
374
375impl<T: TcpSocketProperties> Matcher<T> for TcpSocketMatcher {
376    fn matches(&self, actual: &T) -> bool {
377        match self {
378            TcpSocketMatcher::Empty => true,
379            TcpSocketMatcher::SrcPort(matcher) => actual.src_port_matches(matcher),
380            TcpSocketMatcher::DstPort(matcher) => actual.dst_port_matches(matcher),
381            TcpSocketMatcher::State(matcher) => actual.state_matches(matcher),
382        }
383    }
384}
385
bitflags! {
    /// A matcher for UDP states.
    ///
    /// A matcher accepts a socket when the flag for the socket's current state
    /// is set, so setting both flags matches either state.
    #[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash)]
    pub struct UdpStateMatcher: u32 {
        /// The UDP socket is bound but not connected.
        const BOUND = 1 << 0;
        /// The UDP socket is explicitly connected.
        const CONNECTED = 1 << 1;
    }
}
396
397impl Matcher<UdpSocketState> for UdpStateMatcher {
398    fn matches(&self, actual: &UdpSocketState) -> bool {
399        self.contains(actual.matcher_flag())
400    }
401}
402
/// Represents the state of a UDP socket.
///
/// Every variant corresponds to the like-named flag of [`UdpStateMatcher`].
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum UdpSocketState {
    /// The socket is bound to a local address and (maybe) port.
    Bound,
    /// The socket is connected to a remote peer and has a full 4-tuple.
    Connected,
}
411
412impl UdpSocketState {
413    fn matcher_flag(&self) -> UdpStateMatcher {
414        match self {
415            UdpSocketState::Bound => UdpStateMatcher::BOUND,
416            UdpSocketState::Connected => UdpStateMatcher::CONNECTED,
417        }
418    }
419}
420
/// Allows code to match on properties of a UDP socket without Netstack3 Core
/// having to specifically expose that state.
///
/// [`UdpSocketMatcher`] calls exactly one of these methods, depending on which
/// property it was configured to match.
pub trait UdpSocketProperties {
    /// Returns whether the socket's source port is matched by the matcher.
    fn src_port_matches(&self, matcher: &PortMatcher) -> bool;

    /// Returns whether the socket's destination port is matched by the matcher.
    fn dst_port_matches(&self, matcher: &PortMatcher) -> bool;

    /// Returns whether the socket's UDP state is matched by the matcher.
    fn state_matches(&self, matcher: &UdpStateMatcher) -> bool;
}
433
434impl UdpSocketProperties for Never {
435    fn src_port_matches(&self, _matcher: &PortMatcher) -> bool {
436        unimplemented!()
437    }
438
439    fn dst_port_matches(&self, _matcher: &PortMatcher) -> bool {
440        unimplemented!()
441    }
442
443    fn state_matches(&self, _matcher: &UdpStateMatcher) -> bool {
444        unimplemented!()
445    }
446}
447
448impl<U> UdpSocketProperties for &U
449where
450    U: UdpSocketProperties,
451{
452    fn src_port_matches(&self, matcher: &PortMatcher) -> bool {
453        (*self).src_port_matches(matcher)
454    }
455
456    fn dst_port_matches(&self, matcher: &PortMatcher) -> bool {
457        (*self).dst_port_matches(matcher)
458    }
459
460    fn state_matches(&self, matcher: &UdpStateMatcher) -> bool {
461        (*self).state_matches(matcher)
462    }
463}
464
465/// The top-level matcher for UDP sockets.
466pub enum UdpSocketMatcher {
467    /// Match any UDP socket without further constraints.
468    Empty,
469    /// Match the source port.
470    SrcPort(PortMatcher),
471    /// Match the destination port.
472    DstPort(PortMatcher),
473    /// Match the UDP state.
474    State(UdpStateMatcher),
475}
476
477impl<T: UdpSocketProperties> Matcher<T> for UdpSocketMatcher {
478    fn matches(&self, actual: &T) -> bool {
479        match self {
480            UdpSocketMatcher::Empty => true,
481            UdpSocketMatcher::SrcPort(matcher) => actual.src_port_matches(matcher),
482            UdpSocketMatcher::DstPort(matcher) => actual.dst_port_matches(matcher),
483            UdpSocketMatcher::State(matcher) => actual.state_matches(matcher),
484        }
485    }
486}
487
/// Provides optional access to the TCP or UDP properties of a socket.
///
/// Each accessor returns `None` when the socket is not of the corresponding
/// transport protocol.
pub trait MaybeSocketTransportProperties {
    /// The type that encapsulates TCP socket properties.
    type TcpProps<'a>: TcpSocketProperties
    where
        Self: 'a;

    /// The type that encapsulates UDP socket properties.
    type UdpProps<'a>: UdpSocketProperties
    where
        Self: 'a;

    /// Returns TCP socket properties if the socket is a TCP socket.
    fn tcp_socket_properties(&self) -> Option<&Self::TcpProps<'_>>;

    /// Returns UDP socket properties if the socket is a UDP socket.
    fn udp_socket_properties(&self) -> Option<&Self::UdpProps<'_>>;
}
506
507impl MaybeSocketTransportProperties for Never {
508    type TcpProps<'a>
509        = Never
510    where
511        Self: 'a;
512
513    type UdpProps<'a>
514        = Never
515    where
516        Self: 'a;
517
518    fn tcp_socket_properties(&self) -> Option<&Self::TcpProps<'_>> {
519        unimplemented!()
520    }
521
522    fn udp_socket_properties(&self) -> Option<&Self::UdpProps<'_>> {
523        unimplemented!()
524    }
525}
526
/// A matcher for the transport protocol of a socket.
///
/// A socket that is not of the variant's protocol never matches.
pub enum SocketTransportProtocolMatcher {
    /// Match against a TCP socket.
    Tcp(TcpSocketMatcher),
    /// Match against a UDP socket.
    Udp(UdpSocketMatcher),
}
534
535impl<T: MaybeSocketTransportProperties> Matcher<T> for SocketTransportProtocolMatcher {
536    fn matches(&self, actual: &T) -> bool {
537        match self {
538            SocketTransportProtocolMatcher::Tcp(tcp_matcher) => {
539                actual.tcp_socket_properties().map_or(false, |props| tcp_matcher.matches(props))
540            }
541            SocketTransportProtocolMatcher::Udp(udp_matcher) => {
542                actual.udp_socket_properties().map_or(false, |props| udp_matcher.matches(props))
543            }
544        }
545    }
546}
547
/// The criterion an [`AddressMatcher`] applies to an IP address: either
/// membership in a subnet or membership in an inclusive address range.
#[derive(Clone, Derivative)]
#[derivative(Debug)]
pub enum AddressMatcherType<A: IpAddress> {
    /// A subnet that must contain the address.
    #[derivative(Debug = "transparent")]
    Subnet(SubnetMatcher<A>),
    /// An inclusive range of IP addresses that must contain the address.
    Range(RangeInclusive<A>),
}
558
559impl<A: IpAddress> Matcher<A> for AddressMatcherType<A> {
560    fn matches(&self, actual: &A) -> bool {
561        match self {
562            Self::Subnet(subnet_matcher) => subnet_matcher.matches(actual),
563            Self::Range(range) => range.contains(actual),
564        }
565    }
566}
567
/// A matcher for IP addresses.
///
/// Combines an [`AddressMatcherType`] criterion with an optional inversion
/// of its result.
#[derive(Clone, Debug)]
pub struct AddressMatcher<A: IpAddress> {
    /// The type of the address matcher.
    pub matcher: AddressMatcherType<A>,
    /// Whether to check for an "inverse" or "negative" match (in which case,
    /// if the matcher criteria do *not* apply, it *is* considered a match, and
    /// vice versa).
    pub invert: bool,
}
578
579impl<A: IpAddress> InspectableValue for AddressMatcher<A> {
580    fn record<I: Inspector>(&self, name: &str, inspector: &mut I) {
581        let AddressMatcher { matcher, invert } = self;
582
583        inspector.record_child(name, |inspector| {
584            inspector.record_bool("invert", *invert);
585            match matcher {
586                AddressMatcherType::Subnet(SubnetMatcher(subnet)) => {
587                    inspector.record_display("subnet", subnet)
588                }
589                AddressMatcherType::Range(range) => {
590                    inspector.record_display("start", range.start());
591                    inspector.record_display("end", range.end());
592                }
593            }
594        })
595    }
596}
597
598impl<A: IpAddress> Matcher<A> for AddressMatcher<A> {
599    fn matches(&self, addr: &A) -> bool {
600        let Self { matcher, invert } = self;
601        matcher.matches(addr) ^ *invert
602    }
603}
604
605/// An address matcher that matches any IP version as specified at runtime.
606pub enum AddressMatcherEither {
607    /// The top-level IPv4 address matcher.
608    V4(AddressMatcher<Ipv4Addr>),
609    /// The top-level IPv6 address matcher.
610    V6(AddressMatcher<Ipv6Addr>),
611}
612
613impl Matcher<IpAddr> for AddressMatcherEither {
614    fn matches(&self, addr: &IpAddr) -> bool {
615        match self {
616            AddressMatcherEither::V4(matcher) => match addr {
617                IpAddr::V4(addr) => matcher.matches(addr),
618                IpAddr::V6(_) => false,
619            },
620            AddressMatcherEither::V6(matcher) => match addr {
621                IpAddr::V4(_) => false,
622                IpAddr::V6(addr) => matcher.matches(addr),
623            },
624        }
625    }
626}
627
/// Allows code to match on properties of a socket without Netstack3 Core
/// having to specifically expose that state.
///
/// [`IpSocketMatcher`] calls exactly one of these methods, depending on which
/// property it was configured to match.
pub trait IpSocketProperties<DeviceClass> {
    /// Returns whether the provided IP version matches the socket.
    fn family_matches(&self, family: &net_types::ip::IpVersion) -> bool;

    /// Returns whether the provided address matcher matches the socket's source
    /// address.
    fn src_addr_matches(&self, addr: &AddressMatcherEither) -> bool;

    /// Returns whether the provided address matcher matches the socket's
    /// destination address.
    fn dst_addr_matches(&self, addr: &AddressMatcherEither) -> bool;

    /// Returns whether the transport protocol matches the socket's
    /// transport-layer information.
    fn transport_protocol_matches(&self, matcher: &SocketTransportProtocolMatcher) -> bool;

    /// Returns whether the provided interface matcher matches the socket's
    /// bound interface, if present.
    fn bound_interface_matches(&self, iface: &BoundInterfaceMatcher<DeviceClass>) -> bool;

    /// Returns whether the provided cookie matcher matches the socket's cookie.
    fn cookie_matches(&self, cookie: &SocketCookieMatcher) -> bool;

    /// Returns whether the provided mark matcher matches the socket's mark 1,
    /// if present.
    fn mark1_matches(&self, mark: &MarkMatcher) -> bool;

    /// Returns whether the provided mark matcher matches the socket's mark 2,
    /// if present.
    fn mark2_matches(&self, mark: &MarkMatcher) -> bool;
}
661
/// The top-level matcher for IP sockets.
///
/// Each variant checks a single socket property. A slice of matchers matches
/// a socket only if every element does.
pub enum IpSocketMatcher<DeviceClass> {
    /// Matches the socket's address family.
    Family(net_types::ip::IpVersion),
    /// Matches the socket's source address.
    SrcAddr(AddressMatcherEither),
    /// Matches the socket's destination address.
    DstAddr(AddressMatcherEither),
    /// Matches the socket's transport protocol.
    Proto(SocketTransportProtocolMatcher),
    /// Matches the socket's bound interface.
    BoundInterface(BoundInterfaceMatcher<DeviceClass>),
    /// Matches the socket's cookie.
    Cookie(SocketCookieMatcher),
    /// Matches the socket's mark 1.
    Mark1(MarkMatcher),
    /// Matches the socket's mark 2.
    Mark2(MarkMatcher),
}
681
682impl<DeviceClass, S: IpSocketProperties<DeviceClass>> Matcher<S>
683    for &[IpSocketMatcher<DeviceClass>]
684{
685    fn matches(&self, actual: &S) -> bool {
686        self.iter().all(|matcher| matcher.matches(actual))
687    }
688}
689
690impl<DeviceClass, S: IpSocketProperties<DeviceClass>> Matcher<S> for IpSocketMatcher<DeviceClass> {
691    fn matches(&self, actual: &S) -> bool {
692        match self {
693            IpSocketMatcher::Family(family) => actual.family_matches(family),
694            IpSocketMatcher::SrcAddr(addr) => actual.src_addr_matches(addr),
695            IpSocketMatcher::DstAddr(addr) => actual.dst_addr_matches(addr),
696            IpSocketMatcher::Proto(proto) => actual.transport_protocol_matches(proto),
697            IpSocketMatcher::BoundInterface(iface) => actual.bound_interface_matches(iface),
698            IpSocketMatcher::Cookie(cookie) => actual.cookie_matches(cookie),
699            IpSocketMatcher::Mark1(mark) => actual.mark1_matches(mark),
700            IpSocketMatcher::Mark2(mark) => actual.mark2_matches(mark),
701        }
702    }
703}
704
#[cfg(any(test, feature = "testutils"))]
pub(crate) mod testutil {
    use alloc::string::String;
    use core::num::NonZeroU64;

    use crate::matchers::InterfaceProperties;
    use crate::testutil::{FakeDeviceClass, FakeStrongDeviceId, FakeWeakDeviceId};
    use crate::{DeviceIdentifier, StrongDeviceIdentifier};

    /// A fake device ID for testing matchers.
    ///
    /// It carries the three properties interface matchers can select on.
    #[derive(Clone, Debug, PartialOrd, Ord, PartialEq, Eq, Hash)]
    #[allow(missing_docs)]
    pub struct FakeMatcherDeviceId {
        pub id: NonZeroU64,
        pub name: String,
        pub class: FakeDeviceClass,
    }

    impl FakeMatcherDeviceId {
        /// Returns a [`FakeMatcherDeviceId`] for an arbitrary WLAN interface.
        ///
        /// Every call returns an identical interface.
        pub fn wlan_interface() -> FakeMatcherDeviceId {
            let id = NonZeroU64::new(1).unwrap();
            let name = String::from("wlan");
            Self { id, name, class: FakeDeviceClass::Wlan }
        }

        /// Returns a [`FakeMatcherDeviceId`] for an arbitrary Ethernet interface.
        ///
        /// Every call returns an identical interface.
        pub fn ethernet_interface() -> FakeMatcherDeviceId {
            let id = NonZeroU64::new(2).unwrap();
            let name = String::from("eth");
            Self { id, name, class: FakeDeviceClass::Ethernet }
        }
    }

    impl StrongDeviceIdentifier for FakeMatcherDeviceId {
        type Weak = FakeWeakDeviceId<Self>;

        fn downgrade(&self) -> Self::Weak {
            let strong = self.clone();
            FakeWeakDeviceId(strong)
        }
    }

    impl DeviceIdentifier for FakeMatcherDeviceId {
        // The fake device is never a loopback device.
        fn is_loopback(&self) -> bool {
            false
        }
    }

    impl FakeStrongDeviceId for FakeMatcherDeviceId {
        // The fake device is always reported as alive.
        fn is_alive(&self) -> bool {
            true
        }
    }

    impl PartialEq<FakeWeakDeviceId<FakeMatcherDeviceId>> for FakeMatcherDeviceId {
        fn eq(&self, weak: &FakeWeakDeviceId<FakeMatcherDeviceId>) -> bool {
            // Compare against the device wrapped by the weak ID.
            let FakeWeakDeviceId(inner) = weak;
            self == inner
        }
    }

    impl InterfaceProperties<FakeDeviceClass> for FakeMatcherDeviceId {
        fn id_matches(&self, id: &NonZeroU64) -> bool {
            self.id == *id
        }

        fn name_matches(&self, name: &str) -> bool {
            self.name == name
        }

        fn device_class_matches(&self, class: &FakeDeviceClass) -> bool {
            self.class == *class
        }
    }
}
787
788#[cfg(test)]
789mod tests {
790    use ip_test_macro::ip_test;
791    use net_types::Witness;
792    use net_types::ip::{Ip, IpVersion, Ipv4, Ipv6};
793    use test_case::test_case;
794
795    use super::*;
796    use crate::testutil::{FakeDeviceClass, FakeMatcherDeviceId, TestIpExt};
797
798    /// Only matches `true`.
799    #[derive(Debug)]
800    struct TrueMatcher;
801
802    impl Matcher<bool> for TrueMatcher {
803        fn matches(&self, actual: &bool) -> bool {
804            *actual
805        }
806    }
807
808    #[test]
809    fn test_optional_matcher_optional_value() {
810        assert!(TrueMatcher.matches(&true));
811        assert!(!TrueMatcher.matches(&false));
812
813        assert!(TrueMatcher.required_matches(Some(&true)));
814        assert!(!TrueMatcher.required_matches(Some(&false)));
815        assert!(!TrueMatcher.required_matches(None));
816
817        assert!(Some(TrueMatcher).matches(&true));
818        assert!(!Some(TrueMatcher).matches(&false));
819        assert!(None::<TrueMatcher>.matches(&true));
820        assert!(None::<TrueMatcher>.matches(&false));
821
822        assert!(Some(TrueMatcher).required_matches(Some(&true)));
823        assert!(!Some(TrueMatcher).required_matches(Some(&false)));
824        assert!(!Some(TrueMatcher).required_matches(None));
825        assert!(None::<TrueMatcher>.required_matches(Some(&true)));
826        assert!(None::<TrueMatcher>.required_matches(Some(&false)));
827        assert!(None::<TrueMatcher>.required_matches(None));
828    }
829
830    #[test_case(
831        InterfaceMatcher::Id(FakeMatcherDeviceId::wlan_interface().id),
832        FakeMatcherDeviceId::wlan_interface() => true
833    )]
834    #[test_case(
835        InterfaceMatcher::Id(FakeMatcherDeviceId::wlan_interface().id),
836        FakeMatcherDeviceId::ethernet_interface() => false
837    )]
838    #[test_case(
839        InterfaceMatcher::Name(FakeMatcherDeviceId::wlan_interface().name),
840        FakeMatcherDeviceId::wlan_interface() => true
841    )]
842    #[test_case(
843        InterfaceMatcher::Name(FakeMatcherDeviceId::wlan_interface().name),
844        FakeMatcherDeviceId::ethernet_interface() => false
845    )]
846    #[test_case(
847        InterfaceMatcher::DeviceClass(FakeDeviceClass::Wlan),
848        FakeMatcherDeviceId::wlan_interface() => true
849    )]
850    #[test_case(
851        InterfaceMatcher::DeviceClass(FakeDeviceClass::Wlan),
852        FakeMatcherDeviceId::ethernet_interface() => false
853    )]
854    fn interface_matcher(
855        matcher: InterfaceMatcher<FakeDeviceClass>,
856        device: FakeMatcherDeviceId,
857    ) -> bool {
858        matcher.matches(&device)
859    }
860
    // Table-driven test for `BoundInterfaceMatcher`.
    //
    // Each case supplies a matcher plus an optional bound device and states the
    // expected result of `matches`. In the cases below, `Unbound` matches only
    // `None`, while `Bound(..)` never matches `None` and matches a `Some` device
    // only when the inner `InterfaceMatcher` (by id, name, or device class)
    // accepts it.
    #[test_case(BoundInterfaceMatcher::Unbound, None => true)]
    #[test_case(
        BoundInterfaceMatcher::Unbound,
        Some(FakeMatcherDeviceId::wlan_interface()) => false
    )]
    // Bound, matching on interface id.
    #[test_case(
        BoundInterfaceMatcher::Bound(
            InterfaceMatcher::Id(FakeMatcherDeviceId::wlan_interface().id)
        ),
        None => false
    )]
    #[test_case(
        BoundInterfaceMatcher::Bound(
            InterfaceMatcher::Id(FakeMatcherDeviceId::wlan_interface().id)
        ),
        Some(FakeMatcherDeviceId::wlan_interface()) => true
    )]
    #[test_case(
        BoundInterfaceMatcher::Bound(
            InterfaceMatcher::Id(FakeMatcherDeviceId::wlan_interface().id)
        ),
        Some(FakeMatcherDeviceId::ethernet_interface()) => false
    )]
    // Bound, matching on interface name.
    #[test_case(
        BoundInterfaceMatcher::Bound(
            InterfaceMatcher::Name(FakeMatcherDeviceId::wlan_interface().name)
        ),
        None => false
    )]
    #[test_case(
        BoundInterfaceMatcher::Bound(
            InterfaceMatcher::Name(FakeMatcherDeviceId::wlan_interface().name)
        ),
        Some(FakeMatcherDeviceId::wlan_interface()) => true
    )]
    #[test_case(
        BoundInterfaceMatcher::Bound(
            InterfaceMatcher::Name(FakeMatcherDeviceId::wlan_interface().name)
        ),
        Some(FakeMatcherDeviceId::ethernet_interface()) => false
    )]
    // Bound, matching on device class.
    #[test_case(
        BoundInterfaceMatcher::Bound(
            InterfaceMatcher::DeviceClass(FakeDeviceClass::Wlan)
        ),
        None => false
    )]
    #[test_case(
        BoundInterfaceMatcher::Bound(
            InterfaceMatcher::DeviceClass(FakeDeviceClass::Wlan)
        ),
        Some(FakeMatcherDeviceId::wlan_interface()) => true
    )]
    #[test_case(
        BoundInterfaceMatcher::Bound(
            InterfaceMatcher::DeviceClass(FakeDeviceClass::Wlan)
        ),
        Some(FakeMatcherDeviceId::ethernet_interface()) => false
    )]
    fn bound_interface_matcher(
        matcher: BoundInterfaceMatcher<FakeDeviceClass>,
        device: Option<FakeMatcherDeviceId>,
    ) -> bool {
        // The matcher is evaluated against `Option<&FakeMatcherDeviceId>`, so the
        // owned test input is borrowed first.
        matcher.matches(&device.as_ref())
    }
926
927    #[ip_test(I)]
928    fn subnet_matcher<I: Ip + TestIpExt>() {
929        let matcher = SubnetMatcher(I::TEST_ADDRS.subnet);
930        assert!(matcher.matches(&I::TEST_ADDRS.local_ip));
931        assert!(!matcher.matches(&I::get_other_remote_ip_address(1)));
932    }
933
934    #[test_case(MarkMatcher::Unmarked, Mark(None) => true; "unmarked matches none")]
935    #[test_case(MarkMatcher::Unmarked, Mark(Some(0)) => false; "unmarked does not match some")]
936    #[test_case(MarkMatcher::Marked {
937        mask: 1,
938        start: 0,
939        end: 0,
940        invert: false,
941    }, Mark(None) => false; "marked does not match none")]
942    #[test_case(MarkMatcher::Marked {
943        mask: 1,
944        start: 0,
945        end: 0,
946        invert: false,
947    }, Mark(Some(0)) => true; "marked 0 mask 1 matches 0")]
948    #[test_case(MarkMatcher::Marked {
949        mask: 1,
950        start: 0,
951        end: 0,
952        invert: false,
953    }, Mark(Some(1)) => false; "marked 0 mask 1 does not match 1")]
954    #[test_case(MarkMatcher::Marked {
955        mask: 1,
956        start: 0,
957        end: 0,
958        invert: false,
959    }, Mark(Some(2)) => true; "marked 0 mask 1 matches 2")]
960    #[test_case(MarkMatcher::Marked {
961        mask: 1,
962        start: 0,
963        end: 0,
964        invert: false,
965    }, Mark(Some(3)) => false; "marked 0 mask 1 does not match 3")]
966    #[test_case(MarkMatcher::Marked {
967        mask: !0,
968        start: 0,
969        end: 10,
970        invert: true,
971    }, Mark(Some(5)) => false; "marked invert no match in range")]
972    #[test_case(MarkMatcher::Marked {
973        mask: !0,
974        start: 0,
975        end: 10,
976        invert: true,
977    }, Mark(Some(11)) => true; "marked invert matches out of range")]
978    fn mark_matcher(matcher: MarkMatcher, mark: Mark) -> bool {
979        matcher.matches(&mark)
980    }
981
982    #[test_case(
983        MarkMatchers::new(
984            [(MarkDomain::Mark1, MarkMatcher::Unmarked),
985            (MarkDomain::Mark2, MarkMatcher::Unmarked)]
986        ),
987        Marks::new([]) => true;
988        "all unmarked matches empty"
989    )]
990    #[test_case(
991        MarkMatchers::new(
992            [(MarkDomain::Mark1, MarkMatcher::Unmarked),
993            (MarkDomain::Mark2, MarkMatcher::Unmarked)]
994        ),
995        Marks::new([(MarkDomain::Mark1, 1)]) => false;
996        "all unmarked does not match mark1"
997    )]
998    #[test_case(
999        MarkMatchers::new(
1000            [(MarkDomain::Mark1, MarkMatcher::Unmarked),
1001            (MarkDomain::Mark2, MarkMatcher::Unmarked)]
1002        ),
1003        Marks::new([(MarkDomain::Mark2, 1)]) => false;
1004        "all unmarked does not match mark2"
1005    )]
1006    #[test_case(
1007        MarkMatchers::new(
1008            [(MarkDomain::Mark1, MarkMatcher::Unmarked),
1009            (MarkDomain::Mark2, MarkMatcher::Unmarked)]
1010        ),
1011        Marks::new([
1012            (MarkDomain::Mark1, 1),
1013            (MarkDomain::Mark2, 1),
1014        ]) => false;
1015        "all unmarked does not match mark1 and mark2"
1016    )]
1017    #[test_case(
1018        MarkMatchers::new(
1019            [(MarkDomain::Mark1, MarkMatcher::Marked { mask: !0, start: 1, end: 1, invert: false }),
1020            (MarkDomain::Mark2, MarkMatcher::Unmarked)]
1021        ),
1022        Marks::new([(MarkDomain::Mark1, 1)]) => true;
1023        "mark1 marked matches"
1024    )]
1025    #[test_case(
1026        MarkMatchers::new(
1027            [(MarkDomain::Mark1, MarkMatcher::Marked { mask: !0, start: 1, end: 1, invert: false }),
1028            (MarkDomain::Mark2, MarkMatcher::Unmarked)]
1029        ),
1030        Marks::new([(MarkDomain::Mark1, 2)]) => false;
1031        "mark1 marked no match"
1032    )]
1033    #[test_case(
1034        MarkMatchers::new(
1035            [(MarkDomain::Mark1, MarkMatcher::Marked { mask: !0, start: 1, end: 1, invert: false }),
1036            (MarkDomain::Mark2, MarkMatcher::Marked { mask: !0, start: 2, end: 2, invert: false })]
1037        ),
1038        Marks::new([(MarkDomain::Mark1, 1), (MarkDomain::Mark2, 2)]) => true;
1039        "all marked matches"
1040    )]
1041    #[test_case(
1042        MarkMatchers::new(
1043            [(MarkDomain::Mark1, MarkMatcher::Marked { mask: !0, start: 1, end: 1, invert: false }),
1044            (MarkDomain::Mark2, MarkMatcher::Marked { mask: !0, start: 2, end: 2, invert: false })]
1045        ),
1046        Marks::new([(MarkDomain::Mark1, 1), (MarkDomain::Mark2, 3)]) => false;
1047        "all marked no match mark2"
1048    )]
1049    fn mark_matchers(matchers: MarkMatchers, marks: Marks) -> bool {
1050        matchers.matches(&marks)
1051    }
1052
    // Table-driven test for `SocketCookieMatcher`: without `invert` only the
    // configured cookie matches; with `invert` every other cookie matches.
    #[test_case(SocketCookieMatcher { cookie: 123, invert: false }, 123 => true)]
    #[test_case(SocketCookieMatcher { cookie: 123, invert: false }, 456 => false)]
    #[test_case(SocketCookieMatcher { cookie: 123, invert: true }, 123 => false)]
    #[test_case(SocketCookieMatcher { cookie: 123, invert: true }, 456 => true)]
    fn socket_cookie_matcher(matcher: SocketCookieMatcher, actual: u64) -> bool {
        matcher.matches(&actual)
    }
1060
    // Table-driven test for `PortMatcher` over the inclusive range `10..=20`.
    //
    // The cases probe just outside, on, and inside both range boundaries, first
    // without `invert` and then with `invert`, which flips every expected result.
    #[test_case(PortMatcher { range: 10..=20, invert: false }, 9 => false)]
    #[test_case(PortMatcher { range: 10..=20, invert: false }, 10 => true)]
    #[test_case(PortMatcher { range: 10..=20, invert: false }, 15 => true)]
    #[test_case(PortMatcher { range: 10..=20, invert: false }, 20 => true)]
    #[test_case(PortMatcher { range: 10..=20, invert: false }, 21 => false)]
    #[test_case(PortMatcher { range: 10..=20, invert: true }, 9 => true)]
    #[test_case(PortMatcher { range: 10..=20, invert: true }, 10 => false)]
    #[test_case(PortMatcher { range: 10..=20, invert: true }, 15 => false)]
    #[test_case(PortMatcher { range: 10..=20, invert: true }, 20 => false)]
    #[test_case(PortMatcher { range: 10..=20, invert: true }, 21 => true)]
    fn port_matcher(matcher: PortMatcher, actual: u16) -> bool {
        matcher.matches(&actual)
    }
1074
    /// Fake TCP socket used as the input for socket matchers in these tests.
    struct FakeTcpSocket {
        /// Port checked by `src_port_matches`.
        src_port: u16,
        /// Port checked by `dst_port_matches`.
        dst_port: u16,
        /// Socket state checked by `state_matches`.
        state: TcpSocketState,
    }
1080
1081    impl MaybeSocketTransportProperties for FakeTcpSocket {
1082        type TcpProps<'a>
1083            = Self
1084        where
1085            Self: 'a;
1086
1087        type UdpProps<'a>
1088            = Never
1089        where
1090            Self: 'a;
1091
1092        fn tcp_socket_properties(&self) -> Option<&Self::TcpProps<'_>> {
1093            Some(self)
1094        }
1095
1096        fn udp_socket_properties(&self) -> Option<&Self::UdpProps<'_>> {
1097            None
1098        }
1099    }
1100
1101    impl TcpSocketProperties for FakeTcpSocket {
1102        fn src_port_matches(&self, matcher: &PortMatcher) -> bool {
1103            matcher.matches(&self.src_port)
1104        }
1105
1106        fn dst_port_matches(&self, matcher: &PortMatcher) -> bool {
1107            matcher.matches(&self.dst_port)
1108        }
1109
1110        fn state_matches(&self, matcher: &TcpStateMatcher) -> bool {
1111            matcher.matches(&self.state)
1112        }
1113    }
1114
    /// Fake UDP socket used as the input for socket matchers in these tests.
    struct FakeUdpSocket {
        /// Port checked by `src_port_matches`.
        src_port: u16,
        /// Port checked by `dst_port_matches`.
        dst_port: u16,
        /// Socket state checked by `state_matches`.
        state: UdpSocketState,
    }
1120
1121    impl MaybeSocketTransportProperties for FakeUdpSocket {
1122        type TcpProps<'a>
1123            = Never
1124        where
1125            Self: 'a;
1126
1127        type UdpProps<'a>
1128            = Self
1129        where
1130            Self: 'a;
1131
1132        fn tcp_socket_properties(&self) -> Option<&Self::TcpProps<'_>> {
1133            None
1134        }
1135
1136        fn udp_socket_properties(&self) -> Option<&Self::UdpProps<'_>> {
1137            Some(self)
1138        }
1139    }
1140
1141    impl UdpSocketProperties for FakeUdpSocket {
1142        fn src_port_matches(&self, matcher: &PortMatcher) -> bool {
1143            matcher.matches(&self.src_port)
1144        }
1145
1146        fn dst_port_matches(&self, matcher: &PortMatcher) -> bool {
1147            matcher.matches(&self.dst_port)
1148        }
1149
1150        fn state_matches(&self, matcher: &UdpStateMatcher) -> bool {
1151            matcher.matches(&self.state)
1152        }
1153    }
1154
1155    struct FakeIpSocket<I, T>
1156    where
1157        I: TestIpExt,
1158        T: MaybeSocketTransportProperties,
1159    {
1160        src_ip: I::Addr,
1161        dst_ip: I::Addr,
1162        proto: T,
1163        intf: Option<FakeMatcherDeviceId>,
1164        cookie: u64,
1165        mark_1: Mark,
1166        mark_2: Mark,
1167    }
1168
1169    impl<I, T> MaybeSocketTransportProperties for FakeIpSocket<I, T>
1170    where
1171        I: TestIpExt,
1172        T: MaybeSocketTransportProperties,
1173    {
1174        type TcpProps<'a>
1175            = T::TcpProps<'a>
1176        where
1177            Self: 'a;
1178
1179        type UdpProps<'a>
1180            = T::UdpProps<'a>
1181        where
1182            Self: 'a;
1183
1184        fn tcp_socket_properties(&self) -> Option<&Self::TcpProps<'_>> {
1185            self.proto.tcp_socket_properties()
1186        }
1187
1188        fn udp_socket_properties(&self) -> Option<&Self::UdpProps<'_>> {
1189            self.proto.udp_socket_properties()
1190        }
1191    }
1192
1193    impl<I, T> IpSocketProperties<FakeDeviceClass> for FakeIpSocket<I, T>
1194    where
1195        I: TestIpExt,
1196        T: MaybeSocketTransportProperties,
1197    {
1198        fn family_matches(&self, family: &net_types::ip::IpVersion) -> bool {
1199            *family == I::VERSION
1200        }
1201
1202        fn src_addr_matches(&self, addr: &AddressMatcherEither) -> bool {
1203            addr.matches(&self.src_ip.into())
1204        }
1205
1206        fn dst_addr_matches(&self, addr: &AddressMatcherEither) -> bool {
1207            addr.matches(&self.dst_ip.into())
1208        }
1209
1210        fn transport_protocol_matches(&self, matcher: &SocketTransportProtocolMatcher) -> bool {
1211            matcher.matches(self)
1212        }
1213
1214        fn bound_interface_matches(&self, iface: &BoundInterfaceMatcher<FakeDeviceClass>) -> bool {
1215            iface.matches(&self.intf.as_ref())
1216        }
1217
1218        fn cookie_matches(&self, cookie: &SocketCookieMatcher) -> bool {
1219            cookie.matches(&self.cookie)
1220        }
1221
1222        fn mark1_matches(&self, mark: &MarkMatcher) -> bool {
1223            mark.matches(&self.mark_1)
1224        }
1225
1226        fn mark2_matches(&self, mark: &MarkMatcher) -> bool {
1227            mark.matches(&self.mark_2)
1228        }
1229    }
1230
    // Table-driven test for `TcpSocketMatcher` against a `FakeTcpSocket`.
    //
    // Covers the empty matcher, source-port matching (including `invert`),
    // destination-port matching, and state matching with single and combined
    // `TcpStateMatcher` flags.
    #[test_case(
        TcpSocketMatcher::Empty,
        FakeTcpSocket { src_port: 80, dst_port: 12345, state: TcpSocketState::Established } => true;
        "empty matcher"
    )]
    // Source port.
    #[test_case(
        TcpSocketMatcher::SrcPort(PortMatcher { range: 80..=80, invert: false }),
        FakeTcpSocket { src_port: 80, dst_port: 12345, state: TcpSocketState::Established } => true;
        "src_port match"
    )]
    #[test_case(
        TcpSocketMatcher::SrcPort(PortMatcher { range: 80..=80, invert: false }),
        FakeTcpSocket { src_port: 81, dst_port: 12345, state: TcpSocketState::Established } => false;
        "src_port no match"
    )]
    #[test_case(
        TcpSocketMatcher::SrcPort(PortMatcher { range: 80..=80, invert: true }),
        FakeTcpSocket { src_port: 80, dst_port: 12345, state: TcpSocketState::Established } => false;
        "src_port invert no match"
    )]
    #[test_case(
        TcpSocketMatcher::SrcPort(PortMatcher { range: 80..=80, invert: true }),
        FakeTcpSocket { src_port: 81, dst_port: 12345, state: TcpSocketState::Established } => true;
        "src_port invert match"
    )]
    // Destination port.
    #[test_case(
        TcpSocketMatcher::DstPort(PortMatcher { range: 12345..=12345, invert: false }),
        FakeTcpSocket { src_port: 80, dst_port: 12345, state: TcpSocketState::Established } => true;
        "dst_port match"
    )]
    #[test_case(
        TcpSocketMatcher::DstPort(PortMatcher { range: 12345..=12345, invert: false }),
        FakeTcpSocket { src_port: 80, dst_port: 12346, state: TcpSocketState::Established } => false;
        "dst_port no match"
    )]
    // Socket state: a matcher with several flags set matches any of them.
    #[test_case(
        TcpSocketMatcher::State(TcpStateMatcher::ESTABLISHED),
        FakeTcpSocket { src_port: 80, dst_port: 12345, state: TcpSocketState::Established } => true;
        "state match"
    )]
    #[test_case(
        TcpSocketMatcher::State(TcpStateMatcher::SYN_SENT),
        FakeTcpSocket { src_port: 80, dst_port: 12345, state: TcpSocketState::Established } => false;
        "state no match"
    )]
    #[test_case(
        TcpSocketMatcher::State(TcpStateMatcher::ESTABLISHED | TcpStateMatcher::SYN_SENT),
        FakeTcpSocket { src_port: 80, dst_port: 12345, state: TcpSocketState::Established } => true;
        "state multi match established"
    )]
    #[test_case(
        TcpSocketMatcher::State(TcpStateMatcher::ESTABLISHED | TcpStateMatcher::SYN_SENT),
        FakeTcpSocket { src_port: 80, dst_port: 12345, state: TcpSocketState::SynSent } => true;
        "state multi match syn_sent"
    )]
    #[test_case(
        TcpSocketMatcher::State(TcpStateMatcher::ESTABLISHED | TcpStateMatcher::SYN_SENT),
        FakeTcpSocket { src_port: 80, dst_port: 12345, state: TcpSocketState::FinWait1 } => false;
        "state multi no match"
    )]
    fn tcp_socket_matcher(matcher: TcpSocketMatcher, socket: FakeTcpSocket) -> bool {
        matcher.matches(&socket)
    }
1294
    // Table-driven test for `UdpSocketMatcher` against a `FakeUdpSocket`.
    //
    // Covers the empty matcher, source- and destination-port matching, and
    // state matching with single and combined `UdpStateMatcher` flags.
    #[test_case(
        UdpSocketMatcher::Empty,
        FakeUdpSocket { src_port: 53, dst_port: 12345, state: UdpSocketState::Bound } => true;
        "empty matcher"
    )]
    // Source port.
    #[test_case(
        UdpSocketMatcher::SrcPort(PortMatcher { range: 53..=53, invert: false }),
        FakeUdpSocket { src_port: 53, dst_port: 12345, state: UdpSocketState::Bound } => true;
        "src_port match"
    )]
    #[test_case(
        UdpSocketMatcher::SrcPort(PortMatcher { range: 53..=53, invert: false }),
        FakeUdpSocket { src_port: 54, dst_port: 12345, state: UdpSocketState::Bound } => false;
        "src_port no match"
    )]
    // Destination port.
    #[test_case(
        UdpSocketMatcher::DstPort(PortMatcher { range: 12345..=12345, invert: false }),
        FakeUdpSocket { src_port: 53, dst_port: 12345, state: UdpSocketState::Bound } => true;
        "dst_port match"
    )]
    #[test_case(
        UdpSocketMatcher::DstPort(PortMatcher { range: 12345..=12345, invert: false }),
        FakeUdpSocket { src_port: 53, dst_port: 12346, state: UdpSocketState::Bound } => false;
        "dst_port no match"
    )]
    // Socket state: a matcher with several flags set matches any of them.
    #[test_case(
        UdpSocketMatcher::State(UdpStateMatcher::BOUND),
        FakeUdpSocket { src_port: 53, dst_port: 12345, state: UdpSocketState::Bound } => true;
        "state match bound"
    )]
    #[test_case(
        UdpSocketMatcher::State(UdpStateMatcher::CONNECTED),
        FakeUdpSocket { src_port: 53, dst_port: 12345, state: UdpSocketState::Bound } => false;
        "state no match connected"
    )]
    #[test_case(
        UdpSocketMatcher::State(UdpStateMatcher::BOUND | UdpStateMatcher::CONNECTED),
        FakeUdpSocket { src_port: 53, dst_port: 12345, state: UdpSocketState::Bound } => true;
        "state multi match bound"
    )]
    #[test_case(
        UdpSocketMatcher::State(UdpStateMatcher::BOUND | UdpStateMatcher::CONNECTED),
        FakeUdpSocket { src_port: 53, dst_port: 12345, state: UdpSocketState::Connected } => true;
        "state multi match connected"
    )]
    fn udp_socket_matcher(matcher: UdpSocketMatcher, socket: FakeUdpSocket) -> bool {
        matcher.matches(&socket)
    }
1343
    // Table-driven test for `IpSocketMatcher` against a `FakeIpSocket`, run for
    // each IP version via `#[ip_test(I)]`.
    //
    // Every case uses the test local/remote addresses for the socket and varies
    // one aspect at a time: transport protocol, cookie, marks, or bound
    // interface. `T` is inferred per case from the `proto` fake in use.
    #[ip_test(I)]
    // Transport protocol: a TCP matcher must not match a UDP socket and vice
    // versa, even when the inner matcher is empty.
    #[test_case(
        IpSocketMatcher::Proto(SocketTransportProtocolMatcher::Tcp(TcpSocketMatcher::Empty)),
        FakeIpSocket {
            src_ip: <I as TestIpExt>::TEST_ADDRS.local_ip.get(),
            dst_ip: <I as TestIpExt>::TEST_ADDRS.remote_ip.get(),
            proto: FakeTcpSocket { src_port: 80, dst_port: 12345, state: TcpSocketState::Established },
            cookie: 0,
            intf: None,
            mark_1: None.into(),
            mark_2: None.into(),
        } => true;
        "tcp empty"
    )]
    #[test_case(
        IpSocketMatcher::Proto(SocketTransportProtocolMatcher::Tcp(TcpSocketMatcher::Empty)),
        FakeIpSocket {
            src_ip: <I as TestIpExt>::TEST_ADDRS.local_ip.get(),
            dst_ip: <I as TestIpExt>::TEST_ADDRS.remote_ip.get(),
            proto: FakeUdpSocket { src_port: 53, dst_port: 12345, state: UdpSocketState::Bound },
            cookie: 0,
            intf: None,
            mark_1: None.into(),
            mark_2: None.into(),
        } => false;
        "tcp empty no match udp"
    )]
    #[test_case(
        IpSocketMatcher::Proto(SocketTransportProtocolMatcher::Udp(UdpSocketMatcher::Empty)),
        FakeIpSocket {
            src_ip: <I as TestIpExt>::TEST_ADDRS.local_ip.get(),
            dst_ip: <I as TestIpExt>::TEST_ADDRS.remote_ip.get(),
            proto: FakeTcpSocket { src_port: 80, dst_port: 12345, state: TcpSocketState::Established },
            cookie: 0,
            intf: None,
            mark_1: None.into(),
            mark_2: None.into(),
        } => false;
        "udp empty no match tcp"
    )]
    #[test_case(
        IpSocketMatcher::Proto(SocketTransportProtocolMatcher::Udp(UdpSocketMatcher::Empty)),
        FakeIpSocket {
            src_ip: <I as TestIpExt>::TEST_ADDRS.local_ip.get(),
            dst_ip: <I as TestIpExt>::TEST_ADDRS.remote_ip.get(),
            proto: FakeUdpSocket { src_port: 53, dst_port: 12345, state: UdpSocketState::Bound },
            cookie: 0,
            intf: None,
            mark_1: None.into(),
            mark_2: None.into(),
        } => true;
        "udp empty"
    )]
    // Transport protocol with a nested source-port matcher.
    #[test_case(
        IpSocketMatcher::Proto(
            SocketTransportProtocolMatcher::Tcp(
                TcpSocketMatcher::SrcPort(PortMatcher { range: 80..=80, invert: false })
            )
        ),
        FakeIpSocket {
            src_ip: <I as TestIpExt>::TEST_ADDRS.local_ip.get(),
            dst_ip: <I as TestIpExt>::TEST_ADDRS.remote_ip.get(),
            proto: FakeTcpSocket { src_port: 80, dst_port: 12345, state: TcpSocketState::Established },
            cookie: 0,
            intf: None,
            mark_1: None.into(),
            mark_2: None.into(),
        } => true;
        "tcp src_port match"
    )]
    #[test_case(
        IpSocketMatcher::Proto(
            SocketTransportProtocolMatcher::Tcp(
                TcpSocketMatcher::SrcPort(PortMatcher { range: 80..=80, invert: false })
            )
        ),
        FakeIpSocket {
            src_ip: <I as TestIpExt>::TEST_ADDRS.local_ip.get(),
            dst_ip: <I as TestIpExt>::TEST_ADDRS.remote_ip.get(),
            proto: FakeTcpSocket { src_port: 81, dst_port: 12345, state: TcpSocketState::Established },
            cookie: 0,
            intf: None,
            mark_1: None.into(),
            mark_2: None.into(),
        } => false;
        "tcp src_port no match"
    )]
    #[test_case(
        IpSocketMatcher::Proto(
            SocketTransportProtocolMatcher::Udp(
                UdpSocketMatcher::SrcPort(PortMatcher { range: 53..=53, invert: false })
            )
        ),
        FakeIpSocket {
            src_ip: <I as TestIpExt>::TEST_ADDRS.local_ip.get(),
            dst_ip: <I as TestIpExt>::TEST_ADDRS.remote_ip.get(),
            proto: FakeUdpSocket { src_port: 53, dst_port: 12345, state: UdpSocketState::Bound },
            cookie: 0,
            intf: None,
            mark_1: None.into(),
            mark_2: None.into(),
        } => true;
        "udp src_port match"
    )]
    #[test_case(
        IpSocketMatcher::Proto(
            SocketTransportProtocolMatcher::Udp(
                UdpSocketMatcher::SrcPort(PortMatcher { range: 53..=53, invert: false })
            )
        ),
        FakeIpSocket {
            src_ip: <I as TestIpExt>::TEST_ADDRS.local_ip.get(),
            dst_ip: <I as TestIpExt>::TEST_ADDRS.remote_ip.get(),
            proto: FakeUdpSocket { src_port: 54, dst_port: 12345, state: UdpSocketState::Bound },
            cookie: 0,
            intf: None,
            mark_1: None.into(),
            mark_2: None.into(),
        } => false;
        "udp src_port no match"
    )]
    // Socket cookie.
    #[test_case(
        IpSocketMatcher::Cookie(SocketCookieMatcher { cookie: 123, invert: false }),
        FakeIpSocket {
            src_ip: <I as TestIpExt>::TEST_ADDRS.local_ip.get(),
            dst_ip: <I as TestIpExt>::TEST_ADDRS.remote_ip.get(),
            proto: FakeTcpSocket { src_port: 80, dst_port: 12345, state: TcpSocketState::Established },
            cookie: 123,
            intf: None,
            mark_1: None.into(),
            mark_2: None.into(),
        } => true;
        "cookie match"
    )]
    #[test_case(
        IpSocketMatcher::Cookie(SocketCookieMatcher { cookie: 123, invert: false }),
        FakeIpSocket {
            src_ip: <I as TestIpExt>::TEST_ADDRS.local_ip.get(),
            dst_ip: <I as TestIpExt>::TEST_ADDRS.remote_ip.get(),
            proto: FakeTcpSocket { src_port: 80, dst_port: 12345, state: TcpSocketState::Established },
            cookie: 456,
            intf: None,
            mark_1: None.into(),
            mark_2: None.into(),
        } => false;
        "cookie no match"
    )]
    // Marks: `Mark1` and `Mark2` matchers each look only at their own mark.
    #[test_case(
        IpSocketMatcher::Mark1(MarkMatcher::Unmarked),
        FakeIpSocket {
            src_ip: <I as TestIpExt>::TEST_ADDRS.local_ip.get(),
            dst_ip: <I as TestIpExt>::TEST_ADDRS.remote_ip.get(),
            proto: FakeTcpSocket { src_port: 80, dst_port: 12345, state: TcpSocketState::Established },
            cookie: 0,
            intf: None,
            mark_1: None.into(),
            mark_2: None.into(),
        } => true;
        "mark1 unmarked match"
    )]
    #[test_case(
        IpSocketMatcher::Mark1(MarkMatcher::Unmarked),
        FakeIpSocket {
            src_ip: <I as TestIpExt>::TEST_ADDRS.local_ip.get(),
            dst_ip: <I as TestIpExt>::TEST_ADDRS.remote_ip.get(),
            proto: FakeTcpSocket { src_port: 80, dst_port: 12345, state: TcpSocketState::Established },
            cookie: 0,
            intf: None,
            mark_1: Some(1).into(),
            mark_2: None.into(),
        } => false;
        "mark1 unmarked no match"
    )]
    #[test_case(
        IpSocketMatcher::Mark2(MarkMatcher::Unmarked),
        FakeIpSocket {
            src_ip: <I as TestIpExt>::TEST_ADDRS.local_ip.get(),
            dst_ip: <I as TestIpExt>::TEST_ADDRS.remote_ip.get(),
            proto: FakeTcpSocket { src_port: 80, dst_port: 12345, state: TcpSocketState::Established },
            cookie: 0,
            intf: None,
            mark_1: None.into(),
            mark_2: None.into(),
        } => true;
        "mark2 unmarked match"
    )]
    #[test_case(
        IpSocketMatcher::Mark2(MarkMatcher::Unmarked),
        FakeIpSocket {
            src_ip: <I as TestIpExt>::TEST_ADDRS.local_ip.get(),
            dst_ip: <I as TestIpExt>::TEST_ADDRS.remote_ip.get(),
            proto: FakeTcpSocket { src_port: 80, dst_port: 12345, state: TcpSocketState::Established },
            cookie: 0,
            intf: None,
            mark_1: None.into(),
            mark_2: Some(1).into(),
        } => false;
        "mark2 unmarked no match"
    )]
    // Bound interface.
    #[test_case(
        IpSocketMatcher::BoundInterface(BoundInterfaceMatcher::Bound(
            InterfaceMatcher::Id(FakeMatcherDeviceId::wlan_interface().id)
        )),
        FakeIpSocket {
            src_ip: <I as TestIpExt>::TEST_ADDRS.local_ip.get(),
            dst_ip: <I as TestIpExt>::TEST_ADDRS.remote_ip.get(),
            proto: FakeTcpSocket { src_port: 80, dst_port: 12345, state: TcpSocketState::Established },
            cookie: 0,
            intf: Some(FakeMatcherDeviceId::wlan_interface()),
            mark_1: None.into(),
            mark_2: None.into(),
        } => true;
        "bound_interface match"
    )]
    #[test_case(
        IpSocketMatcher::BoundInterface(BoundInterfaceMatcher::Bound(
            InterfaceMatcher::Id(FakeMatcherDeviceId::wlan_interface().id)
        )),
        FakeIpSocket {
            src_ip: <I as TestIpExt>::TEST_ADDRS.local_ip.get(),
            dst_ip: <I as TestIpExt>::TEST_ADDRS.remote_ip.get(),
            proto: FakeTcpSocket { src_port: 80, dst_port: 12345, state: TcpSocketState::Established },
            cookie: 0,
            intf: Some(FakeMatcherDeviceId::ethernet_interface()),
            mark_1: None.into(),
            mark_2: None.into(),
        } => false;
        "bound_interface no match"
    )]
    fn ip_socket_matcher<I: TestIpExt, T: MaybeSocketTransportProperties>(
        matcher: IpSocketMatcher<FakeDeviceClass>,
        socket: FakeIpSocket<I, T>,
    ) -> bool {
        matcher.matches(&socket)
    }
1579
1580    #[ip_test(I)]
1581    fn address_matcher_type<I: TestIpExt>() {
1582        let local_ip = I::TEST_ADDRS.local_ip.get();
1583        let remote_ip = I::TEST_ADDRS.remote_ip.get();
1584
1585        let matcher = AddressMatcherType::Subnet(SubnetMatcher(I::TEST_ADDRS.subnet));
1586        assert!(matcher.matches(&local_ip));
1587        assert!(!matcher.matches(&I::get_other_remote_ip_address(1)));
1588
1589        let matcher = AddressMatcherType::Range(local_ip..=remote_ip);
1590        assert!(matcher.matches(&local_ip));
1591        assert!(matcher.matches(&remote_ip));
1592        assert!(!matcher.matches(&I::get_other_remote_ip_address(1)));
1593    }
1594
1595    #[ip_test(I)]
1596    fn address_matcher<I: TestIpExt>() {
1597        let local_ip = I::TEST_ADDRS.local_ip.get();
1598        let remote_ip = I::TEST_ADDRS.remote_ip.get();
1599
1600        let matcher = AddressMatcher {
1601            matcher: AddressMatcherType::Subnet(SubnetMatcher(I::TEST_ADDRS.subnet)),
1602            invert: false,
1603        };
1604        assert!(matcher.matches(&local_ip));
1605        assert!(matcher.matches(&remote_ip));
1606        assert!(!matcher.matches(&I::get_other_remote_ip_address(1)));
1607
1608        let matcher = AddressMatcher {
1609            matcher: AddressMatcherType::Subnet(SubnetMatcher(I::TEST_ADDRS.subnet)),
1610            invert: true,
1611        };
1612        assert!(!matcher.matches(&local_ip));
1613        assert!(!matcher.matches(&remote_ip));
1614        assert!(matcher.matches(&I::get_other_remote_ip_address(1)));
1615
1616        let matcher = AddressMatcher {
1617            matcher: AddressMatcherType::Range(local_ip..=remote_ip),
1618            invert: false,
1619        };
1620        assert!(matcher.matches(&local_ip));
1621        assert!(matcher.matches(&remote_ip));
1622        assert!(!matcher.matches(&I::get_other_remote_ip_address(1)));
1623
1624        let matcher = AddressMatcher {
1625            matcher: AddressMatcherType::Range(local_ip..=remote_ip),
1626            invert: true,
1627        };
1628        assert!(!matcher.matches(&local_ip));
1629        assert!(!matcher.matches(&remote_ip));
1630        assert!(matcher.matches(&I::get_other_remote_ip_address(1)));
1631    }
1632
1633    #[test]
1634    fn agnostic_address_matcher() {
1635        let v4_addr = IpAddr::V4(Ipv4Addr::new([192, 0, 2, 1]));
1636        let v6_addr = IpAddr::V6(Ipv6Addr::new([0x2001, 0xdb8, 0, 0, 0, 0, 0, 1]));
1637
1638        let v4_subnet = Subnet::new(Ipv4Addr::new([192, 0, 2, 0]), 24).unwrap();
1639        let v6_subnet = Subnet::new(Ipv6Addr::new([0x2001, 0xdb8, 0, 0, 0, 0, 0, 0]), 32).unwrap();
1640
1641        let v4_matcher = AddressMatcherEither::V4(AddressMatcher {
1642            matcher: AddressMatcherType::Subnet(SubnetMatcher(v4_subnet)),
1643            invert: false,
1644        });
1645        assert!(v4_matcher.matches(&v4_addr));
1646        assert!(!v4_matcher.matches(&v6_addr));
1647
1648        let v6_matcher = AddressMatcherEither::V6(AddressMatcher {
1649            matcher: AddressMatcherType::Subnet(SubnetMatcher(v6_subnet)),
1650            invert: false,
1651        });
1652        assert!(!v6_matcher.matches(&v4_addr));
1653        assert!(v6_matcher.matches(&v6_addr));
1654    }
1655
1656    #[test_case(IpSocketMatcher::Family(IpVersion::V4) => true; "v4 family matcher on v4 socket")]
1657    #[test_case(IpSocketMatcher::Family(IpVersion::V6) => false; "v6 family matcher on v4 socket")]
1658    #[test_case(IpSocketMatcher::SrcAddr(AddressMatcherEither::V4(AddressMatcher {
1659        matcher: AddressMatcherType::Subnet(SubnetMatcher(Ipv4::TEST_ADDRS.subnet)),
1660        invert: false,
1661    })) => true; "src_addr match")]
1662    #[test_case(IpSocketMatcher::SrcAddr(AddressMatcherEither::V4(AddressMatcher {
1663        matcher: AddressMatcherType::Subnet(SubnetMatcher(Subnet::new(Ipv4Addr::new([0, 0, 0, 0]), 32).unwrap())),
1664        invert: false,
1665    })) => false; "src_addr no match")]
1666    #[test_case(IpSocketMatcher::DstAddr(AddressMatcherEither::V4(AddressMatcher {
1667        matcher: AddressMatcherType::Subnet(SubnetMatcher(Ipv4::TEST_ADDRS.subnet)),
1668        invert: false,
1669    })) => true; "dst_addr match")]
1670    #[test_case(IpSocketMatcher::DstAddr(AddressMatcherEither::V4(AddressMatcher {
1671        matcher: AddressMatcherType::Subnet(SubnetMatcher(Subnet::new(Ipv4Addr::new([0, 0, 0, 0]), 32).unwrap())),
1672        invert: false,
1673    })) => false; "dst_addr no match")]
1674    fn ip_socket_matcher_test_v4(matcher: IpSocketMatcher<FakeDeviceClass>) -> bool {
1675        let socket = FakeIpSocket::<Ipv4, _> {
1676            src_ip: <Ipv4 as TestIpExt>::TEST_ADDRS.local_ip.get(),
1677            dst_ip: <Ipv4 as TestIpExt>::TEST_ADDRS.remote_ip.get(),
1678            proto: FakeTcpSocket {
1679                src_port: 80,
1680                dst_port: 12345,
1681                state: TcpSocketState::Established,
1682            },
1683            cookie: 0,
1684            intf: None,
1685            mark_1: None.into(),
1686            mark_2: None.into(),
1687        };
1688        matcher.matches(&socket)
1689    }
1690
1691    #[test_case(IpSocketMatcher::Family(IpVersion::V4) => false; "v4 family matcher on v6 socket")]
1692    #[test_case(IpSocketMatcher::Family(IpVersion::V6) => true; "v6 family matcher on v6 socket")]
1693    #[test_case(IpSocketMatcher::SrcAddr(AddressMatcherEither::V6(AddressMatcher {
1694        matcher: AddressMatcherType::Subnet(SubnetMatcher(Ipv6::TEST_ADDRS.subnet)),
1695        invert: false,
1696    })) => true; "src_addr match v6")]
1697    #[test_case(IpSocketMatcher::SrcAddr(AddressMatcherEither::V6(AddressMatcher {
1698        matcher: AddressMatcherType::Subnet(SubnetMatcher(Subnet::new(Ipv6Addr::new([0; 8]), 128).unwrap())),
1699        invert: false,
1700    })) => false; "src_addr no match v6")]
1701    #[test_case(IpSocketMatcher::DstAddr(AddressMatcherEither::V6(AddressMatcher {
1702        matcher: AddressMatcherType::Subnet(SubnetMatcher(Ipv6::TEST_ADDRS.subnet)),
1703        invert: false,
1704    })) => true; "dst_addr match v6")]
1705    #[test_case(IpSocketMatcher::DstAddr(AddressMatcherEither::V6(AddressMatcher {
1706        matcher: AddressMatcherType::Subnet(SubnetMatcher(Subnet::new(Ipv6Addr::new([0; 8]), 128).unwrap())),
1707        invert: false,
1708    })) => false; "dst_addr no match v6")]
1709    fn ip_socket_matcher_test_v6(matcher: IpSocketMatcher<FakeDeviceClass>) -> bool {
1710        let socket = FakeIpSocket::<Ipv6, _> {
1711            src_ip: <Ipv6 as TestIpExt>::TEST_ADDRS.local_ip.get(),
1712            dst_ip: <Ipv6 as TestIpExt>::TEST_ADDRS.remote_ip.get(),
1713            proto: FakeTcpSocket {
1714                src_port: 80,
1715                dst_port: 12345,
1716                state: TcpSocketState::Established,
1717            },
1718            cookie: 0,
1719            intf: None,
1720            mark_1: None.into(),
1721            mark_2: None.into(),
1722        };
1723        matcher.matches(&socket)
1724    }
1725}