netstack3_filter/
matchers.rs

1// Copyright 2024 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use core::fmt::Debug;
6use core::num::NonZeroU64;
7use core::ops::RangeInclusive;
8use netstack3_base::{DeviceNameMatcher, DeviceWithName, InspectableValue, Matcher, SubnetMatcher};
9
10use derivative::Derivative;
11use net_types::ip::IpAddress;
12use packet_formats::ip::IpExt;
13
14use crate::logic::Interfaces;
15use crate::packets::{IpPacket, MaybeTransportPacket, TransportPacketData};
16
17/// A matcher for network interfaces.
18#[derive(Clone, Derivative)]
19#[derivative(Debug)]
20pub enum InterfaceMatcher<DeviceClass> {
21    /// The ID of the interface as assigned by the netstack.
22    Id(NonZeroU64),
23    /// Match based on name.
24    #[derivative(Debug = "transparent")]
25    Name(DeviceNameMatcher),
26    /// The device class of the interface.
27    DeviceClass(DeviceClass),
28}
29
30impl<DeviceClass: Debug> InspectableValue for InterfaceMatcher<DeviceClass> {
31    fn record<I: netstack3_base::Inspector>(&self, name: &str, inspector: &mut I) {
32        inspector.record_debug(name, self);
33    }
34}
35
36/// Allows filtering code to match on properties of an interface (ID, name, and
37/// device class) without Netstack3 Core (or Bindings, in the case of the device
38/// class) having to specifically expose that state.
39pub trait InterfaceProperties<DeviceClass>: DeviceWithName {
40    /// Returns whether the provided ID matches the interface.
41    fn id_matches(&self, id: &NonZeroU64) -> bool;
42
43    /// Returns whether the provided device class matches the interface.
44    fn device_class_matches(&self, device_class: &DeviceClass) -> bool;
45}
46
47impl<DeviceClass, I: InterfaceProperties<DeviceClass>> Matcher<I>
48    for InterfaceMatcher<DeviceClass>
49{
50    fn matches(&self, actual: &I) -> bool {
51        match self {
52            InterfaceMatcher::Id(id) => actual.id_matches(id),
53            InterfaceMatcher::Name(name_matcher) => name_matcher.matches(actual),
54            InterfaceMatcher::DeviceClass(device_class) => {
55                actual.device_class_matches(device_class)
56            }
57        }
58    }
59}
60
61/// A matcher for IP addresses.
62#[derive(Clone, Derivative)]
63#[derivative(Debug)]
64pub enum AddressMatcherType<A: IpAddress> {
65    /// A subnet that must contain the address.
66    #[derivative(Debug = "transparent")]
67    Subnet(SubnetMatcher<A>),
68    /// An inclusive range of IP addresses that must contain the address.
69    Range(RangeInclusive<A>),
70}
71
72impl<A: IpAddress> Matcher<A> for AddressMatcherType<A> {
73    fn matches(&self, actual: &A) -> bool {
74        match self {
75            Self::Subnet(subnet_matcher) => subnet_matcher.matches(actual),
76            Self::Range(range) => range.contains(actual),
77        }
78    }
79}
80
81/// A matcher for IP addresses.
82#[derive(Clone, Debug)]
83pub struct AddressMatcher<A: IpAddress> {
84    /// The type of the address matcher.
85    pub matcher: AddressMatcherType<A>,
86    /// Whether to check for an "inverse" or "negative" match (in which case,
87    /// if the matcher criteria do *not* apply, it *is* considered a match, and
88    /// vice versa).
89    pub invert: bool,
90}
91
92impl<A: IpAddress> InspectableValue for AddressMatcher<A> {
93    fn record<I: netstack3_base::Inspector>(&self, name: &str, inspector: &mut I) {
94        inspector.record_debug(name, self);
95    }
96}
97
98impl<A: IpAddress> Matcher<A> for AddressMatcher<A> {
99    fn matches(&self, addr: &A) -> bool {
100        let Self { matcher, invert } = self;
101        matcher.matches(addr) ^ *invert
102    }
103}
104
105/// A matcher for transport-layer port numbers.
106#[derive(Clone, Debug)]
107pub struct PortMatcher {
108    /// The range of port numbers in which the tested port number must fall.
109    pub range: RangeInclusive<u16>,
110    /// Whether to check for an "inverse" or "negative" match (in which case,
111    /// if the matcher criteria do *not* apply, it *is* considered a match, and
112    /// vice versa).
113    pub invert: bool,
114}
115
116impl Matcher<u16> for PortMatcher {
117    fn matches(&self, actual: &u16) -> bool {
118        let Self { range, invert } = self;
119        range.contains(actual) ^ *invert
120    }
121}
122
123/// A matcher for transport-layer protocol or port numbers.
124#[derive(Debug, Clone)]
125pub struct TransportProtocolMatcher<P> {
126    /// The transport-layer protocol.
127    pub proto: P,
128    /// If set, the matcher for the source port or identifier of the transport
129    /// header.
130    pub src_port: Option<PortMatcher>,
131    /// If set, the matcher for the destination port or identifier of the
132    /// transport header.
133    pub dst_port: Option<PortMatcher>,
134}
135
136impl<P: Debug> InspectableValue for TransportProtocolMatcher<P> {
137    fn record<I: netstack3_base::Inspector>(&self, name: &str, inspector: &mut I) {
138        inspector.record_debug(name, self);
139    }
140}
141
142impl<P: PartialEq, T: MaybeTransportPacket> Matcher<(P, T)> for TransportProtocolMatcher<P> {
143    fn matches(&self, actual: &(P, T)) -> bool {
144        let Self { proto, src_port, dst_port } = self;
145        let (packet_proto, packet) = actual;
146
147        proto == packet_proto
148            && src_port.required_matches(
149                packet.transport_packet_data().as_ref().map(TransportPacketData::src_port).as_ref(),
150            )
151            && dst_port.required_matches(
152                packet.transport_packet_data().as_ref().map(TransportPacketData::dst_port).as_ref(),
153            )
154    }
155}
156
157/// Top-level matcher for IP packets.
158#[derive(Derivative, Debug, Clone)]
159#[derivative(Default(bound = ""))]
160pub struct PacketMatcher<I: IpExt, DeviceClass> {
161    /// The interface on which the packet entered the stack.
162    ///
163    /// Only available in `INGRESS`, `LOCAL_INGRESS`, and `FORWARDING`.
164    pub in_interface: Option<InterfaceMatcher<DeviceClass>>,
165    /// The interface through which the packet exits the stack.
166    ///
167    /// Only available in `FORWARDING`, `LOCAL_EGRESS`, and `EGRESS`.
168    pub out_interface: Option<InterfaceMatcher<DeviceClass>>,
169    /// Matcher for the source IP address.
170    pub src_address: Option<AddressMatcher<I::Addr>>,
171    /// Matcher for the destination IP address.
172    pub dst_address: Option<AddressMatcher<I::Addr>>,
173    /// Matchers for the transport layer.
174    pub transport_protocol: Option<TransportProtocolMatcher<I::Proto>>,
175}
176
177impl<I: IpExt, DeviceClass> PacketMatcher<I, DeviceClass> {
178    pub(crate) fn matches<P: IpPacket<I>, D: InterfaceProperties<DeviceClass>>(
179        &self,
180        packet: &P,
181        interfaces: &Interfaces<'_, D>,
182    ) -> bool {
183        let Self { in_interface, out_interface, src_address, dst_address, transport_protocol } =
184            self;
185        let Interfaces { ingress: in_if, egress: out_if } = interfaces;
186
187        // If no fields are specified, match all traffic by default.
188        in_interface.required_matches(*in_if)
189            && out_interface.required_matches(*out_if)
190            && src_address.matches(&packet.src_addr())
191            && dst_address.matches(&packet.dst_addr())
192            && transport_protocol.matches(&(packet.protocol(), packet.maybe_transport_packet()))
193    }
194}
195
#[cfg(test)]
pub(crate) mod testutil {
    use alloc::string::String;
    use core::num::NonZeroU64;

    use netstack3_base::testutil::{FakeStrongDeviceId, FakeWeakDeviceId};
    use netstack3_base::{DeviceIdentifier, DeviceWithName, StrongDeviceIdentifier};

    use super::*;
    use crate::context::testutil::FakeDeviceClass;

    /// A fake interface carrying every property an [`InterfaceMatcher`] can
    /// match on.
    #[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
    pub struct FakeDeviceId {
        /// The interface ID.
        pub id: NonZeroU64,
        /// The interface name.
        pub name: String,
        /// The interface's device class.
        pub class: FakeDeviceClass,
    }

    impl StrongDeviceIdentifier for FakeDeviceId {
        type Weak = FakeWeakDeviceId<Self>;

        fn downgrade(&self) -> Self::Weak {
            FakeWeakDeviceId(self.clone())
        }
    }

    impl DeviceIdentifier for FakeDeviceId {
        fn is_loopback(&self) -> bool {
            // Fake interfaces never report themselves as loopback.
            false
        }
    }

    impl FakeStrongDeviceId for FakeDeviceId {
        fn is_alive(&self) -> bool {
            // Fake interfaces always report themselves as alive.
            true
        }
    }

    impl PartialEq<FakeWeakDeviceId<FakeDeviceId>> for FakeDeviceId {
        fn eq(&self, other: &FakeWeakDeviceId<FakeDeviceId>) -> bool {
            let FakeWeakDeviceId(strong) = other;
            self == strong
        }
    }

    impl DeviceWithName for FakeDeviceId {
        fn name_matches(&self, name: &str) -> bool {
            self.name == name
        }
    }

    impl InterfaceProperties<FakeDeviceClass> for FakeDeviceId {
        fn id_matches(&self, id: &NonZeroU64) -> bool {
            self.id == *id
        }

        fn device_class_matches(&self, class: &FakeDeviceClass) -> bool {
            self.class == *class
        }
    }

    /// Returns a fake WLAN interface with ID 1 and name `"wlan"`.
    pub fn wlan_interface() -> FakeDeviceId {
        let id = NonZeroU64::new(1).unwrap();
        FakeDeviceId { id, name: "wlan".into(), class: FakeDeviceClass::Wlan }
    }

    /// Returns a fake Ethernet interface with ID 2 and name `"eth"`.
    pub fn ethernet_interface() -> FakeDeviceId {
        let id = NonZeroU64::new(2).unwrap();
        FakeDeviceId { id, name: "eth".into(), class: FakeDeviceClass::Ethernet }
    }
}
272
/// [`InterfaceProperties`] implementations for the fake device ID types from
/// `netstack3_base::testutil`.
///
/// Every method panics via `unimplemented!()`. These impls let the fake IDs be
/// used where `InterfaceProperties<()>` is required, but they must never
/// actually be queried.
#[cfg(feature = "testutils")]
mod base_testutil {
    use core::num::NonZeroU64;

    use netstack3_base::testutil::{FakeDeviceId, FakeReferencyDeviceId, MultipleDevicesId};

    use super::InterfaceProperties;

    /// Implements `InterfaceProperties<()>` for each listed type, with
    /// methods that panic if called.
    macro_rules! impl_unimplemented_interface_properties {
        ($($device:ty),* $(,)?) => {
            $(
                impl InterfaceProperties<()> for $device {
                    fn id_matches(&self, _: &NonZeroU64) -> bool {
                        unimplemented!()
                    }

                    fn device_class_matches(&self, _: &()) -> bool {
                        unimplemented!()
                    }
                }
            )*
        };
    }

    impl_unimplemented_interface_properties!(
        FakeDeviceId,
        FakeReferencyDeviceId,
        MultipleDevicesId,
    );
}
308
// Unit tests for the packet matchers defined in this module.
#[cfg(test)]
mod tests {
    use ip_test_macro::ip_test;
    use net_types::ip::{Ipv4, Ipv4Addr, Ipv6, Ipv6Addr};
    use packet_formats::ip::{IpProto, Ipv4Proto};
    use test_case::test_case;

    use netstack3_base::SegmentHeader;

    use super::testutil::*;
    use super::*;
    use crate::context::testutil::FakeDeviceClass;
    use crate::packets::testutil::internal::{
        ArbitraryValue, FakeIcmpEchoRequest, FakeIpPacket, FakeTcpSegment, FakeUdpPacket,
        TestIpExt, TransportPacketExt,
    };

    // Each case identifies the WLAN interface by a single property. The
    // resulting matcher must accept the WLAN interface on both ingress and
    // egress, and reject the Ethernet interface.
    #[test_case(InterfaceMatcher::Id(wlan_interface().id))]
    #[test_case(InterfaceMatcher::Name(DeviceNameMatcher(wlan_interface().name.clone())))]
    #[test_case(InterfaceMatcher::DeviceClass(wlan_interface().class))]
    fn match_on_interface_properties(matcher: InterfaceMatcher<FakeDeviceClass>) {
        let matcher = PacketMatcher {
            in_interface: Some(matcher.clone()),
            out_interface: Some(matcher),
            ..Default::default()
        };

        assert_eq!(
            matcher.matches(
                &FakeIpPacket::<Ipv4, FakeTcpSegment>::arbitrary_value(),
                &Interfaces { ingress: Some(&wlan_interface()), egress: Some(&wlan_interface()) },
            ),
            true
        );
        assert_eq!(
            matcher.matches(
                &FakeIpPacket::<Ipv4, FakeTcpSegment>::arbitrary_value(),
                &Interfaces {
                    ingress: Some(&ethernet_interface()),
                    egress: Some(&ethernet_interface())
                },
            ),
            false
        );
    }

    // A set interface matcher must fail when the hook provides no interface on
    // that side (`None`), even if the other side matches. Only when both
    // interfaces are present does the packet match.
    #[test_case(InterfaceMatcher::Id(wlan_interface().id))]
    #[test_case(InterfaceMatcher::Name(DeviceNameMatcher(wlan_interface().name.clone())))]
    #[test_case(InterfaceMatcher::DeviceClass(wlan_interface().class))]
    fn interface_matcher_specified_but_not_available_in_hook_does_not_match(
        matcher: InterfaceMatcher<FakeDeviceClass>,
    ) {
        let matcher = PacketMatcher {
            in_interface: Some(matcher.clone()),
            out_interface: Some(matcher),
            ..Default::default()
        };

        assert_eq!(
            matcher.matches(
                &FakeIpPacket::<Ipv4, FakeTcpSegment>::arbitrary_value(),
                &Interfaces { ingress: None, egress: Some(&wlan_interface()) },
            ),
            false
        );
        assert_eq!(
            matcher.matches(
                &FakeIpPacket::<Ipv4, FakeTcpSegment>::arbitrary_value(),
                &Interfaces { ingress: Some(&wlan_interface()), egress: None },
            ),
            false
        );
        assert_eq!(
            matcher.matches(
                &FakeIpPacket::<Ipv4, FakeTcpSegment>::arbitrary_value(),
                &Interfaces { ingress: Some(&wlan_interface()), egress: Some(&wlan_interface()) },
            ),
            true
        );
    }

    // Selects which `AddressMatcherType` variant a test case builds.
    enum AddressMatcherTestCase {
        Subnet,
        Range,
    }

    // Checks both address-matcher variants, with and without inversion, as a
    // source-address and as a destination-address matcher. The `Range` variant
    // is built to cover exactly the addresses of `I::SUBNET`, so both variants
    // are expected to behave identically.
    #[ip_test(I)]
    #[test_case(AddressMatcherTestCase::Subnet, /* invert */ false)]
    #[test_case(AddressMatcherTestCase::Subnet, /* invert */ true)]
    #[test_case(AddressMatcherTestCase::Range, /* invert */ false)]
    #[test_case(AddressMatcherTestCase::Range, /* invert */ true)]
    fn match_on_subnet_or_address_range<I: TestIpExt>(
        test_case: AddressMatcherTestCase,
        invert: bool,
    ) {
        let matcher = AddressMatcher {
            matcher: match test_case {
                AddressMatcherTestCase::Subnet => {
                    AddressMatcherType::Subnet(SubnetMatcher(I::SUBNET))
                }
                AddressMatcherTestCase::Range => {
                    // Generate the inclusive address range that is equivalent to the subnet.
                    let start = I::SUBNET.network();
                    let end = I::map_ip(
                        start,
                        |start| {
                            // Offset from the network address to the last address in the
                            // subnet. NOTE(review): `2.pow(32 - prefix)` overflows for a /0
                            // prefix; assumes `I::SUBNET` is narrower than that.
                            let range_size = 2_u32.pow(32 - u32::from(I::SUBNET.prefix())) - 1;
                            let end = u32::from_be_bytes(start.ipv4_bytes()) + range_size;
                            Ipv4Addr::from(end.to_be_bytes())
                        },
                        |start| {
                            // Same computation in 128-bit arithmetic for IPv6.
                            let range_size = 2_u128.pow(128 - u32::from(I::SUBNET.prefix())) - 1;
                            let end = u128::from_be_bytes(start.ipv6_bytes()) + range_size;
                            Ipv6Addr::from(end.to_be_bytes())
                        },
                    );
                    AddressMatcherType::Range(start..=end)
                }
            },
            invert,
        };

        for matcher in [
            PacketMatcher { src_address: Some(matcher.clone()), ..Default::default() },
            PacketMatcher { dst_address: Some(matcher), ..Default::default() },
        ] {
            // The test expects the arbitrary packet's addresses to lie inside
            // `I::SUBNET`, so it matches exactly when the matcher is not inverted.
            assert_ne!(
                matcher.matches::<_, FakeDeviceId>(
                    &FakeIpPacket::<I, FakeTcpSegment>::arbitrary_value(),
                    &Interfaces { ingress: None, egress: None },
                ),
                invert
            );
            // An address outside the subnet matches only when inverted.
            assert_eq!(
                matcher.matches::<_, FakeDeviceId>(
                    &FakeIpPacket {
                        src_ip: I::IP_OUTSIDE_SUBNET,
                        dst_ip: I::IP_OUTSIDE_SUBNET,
                        body: FakeTcpSegment::arbitrary_value(),
                    },
                    &Interfaces { ingress: None, egress: None },
                ),
                invert
            );
        }
    }

    // The transport protocols exercised by `match_on_transport_protocol`.
    enum Protocol {
        Tcp,
        Udp,
        Icmp,
    }

    impl Protocol {
        // Returns the protocol number that the corresponding fake transport
        // packet type reports for IP version `I`.
        fn ip_proto<I: IpExt>(&self) -> I::Proto {
            match self {
                Self::Tcp => <&FakeTcpSegment as TransportPacketExt<I>>::proto(),
                Self::Udp => <&FakeUdpPacket as TransportPacketExt<I>>::proto(),
                Self::Icmp => <&FakeIcmpEchoRequest as TransportPacketExt<I>>::proto(),
            }
        }
    }

    // A protocol-only matcher (no port matchers) accepts exactly the packets
    // whose transport protocol equals `protocol`. The value after `=>` is the
    // expected result.
    #[test_case(Protocol::Tcp, FakeIpPacket::<Ipv4, FakeTcpSegment>::arbitrary_value() => true)]
    #[test_case(Protocol::Tcp, FakeIpPacket::<Ipv4, FakeUdpPacket>::arbitrary_value() => false)]
    #[test_case(
        Protocol::Tcp,
        FakeIpPacket::<Ipv4, FakeIcmpEchoRequest>::arbitrary_value()
        => false
    )]
    #[test_case(Protocol::Udp, FakeIpPacket::<Ipv4, FakeUdpPacket>::arbitrary_value() => true)]
    #[test_case(Protocol::Udp, FakeIpPacket::<Ipv4, FakeTcpSegment>::arbitrary_value()=> false)]
    #[test_case(
        Protocol::Udp,
        FakeIpPacket::<Ipv4, FakeIcmpEchoRequest>::arbitrary_value()
        => false
    )]
    #[test_case(
        Protocol::Icmp,
        FakeIpPacket::<Ipv4, FakeIcmpEchoRequest>::arbitrary_value()
        => true
    )]
    #[test_case(
        Protocol::Icmp,
        FakeIpPacket::<Ipv6, FakeIcmpEchoRequest>::arbitrary_value()
        => true
    )]
    #[test_case(Protocol::Icmp, FakeIpPacket::<Ipv4, FakeTcpSegment>::arbitrary_value() => false)]
    #[test_case(Protocol::Icmp, FakeIpPacket::<Ipv4, FakeUdpPacket>::arbitrary_value() => false)]
    fn match_on_transport_protocol<I: TestIpExt, P: IpPacket<I>>(
        protocol: Protocol,
        packet: P,
    ) -> bool {
        let matcher = PacketMatcher {
            transport_protocol: Some(TransportProtocolMatcher {
                proto: protocol.ip_proto::<I>(),
                src_port: None,
                dst_port: None,
            }),
            ..Default::default()
        };

        matcher.matches::<_, FakeDeviceId>(&packet, &Interfaces { ingress: None, egress: None })
    }

    // Each case is checked against both a TCP segment and a UDP packet carrying
    // the same (src, dst) ports. Both must produce `expect_match`.
    #[test_case(
        Some(PortMatcher { range: 1024..=65535, invert: false }), None, (11111, 80), true;
        "matching src port"
    )]
    #[test_case(
        Some(PortMatcher { range: 1024..=65535, invert: true }), None, (11111, 80), false;
        "invert match src port"
    )]
    #[test_case(
        Some(PortMatcher { range: 1024..=65535, invert: false }), None, (53, 80), false;
        "non-matching src port"
    )]
    #[test_case(
        None, Some(PortMatcher { range: 22..=22, invert: false }), (11111, 22), true;
        "match dst port"
    )]
    #[test_case(
        None, Some(PortMatcher { range: 22..=22, invert: true }), (11111, 22), false;
        "invert match dst port"
    )]
    #[test_case(
        None, Some(PortMatcher { range: 22..=22, invert: false }), (11111, 80), false;
        "non-matching dst port"
    )]
    fn match_on_port_range(
        src_port: Option<PortMatcher>,
        dst_port: Option<PortMatcher>,
        transport_header: (u16, u16),
        expect_match: bool,
    ) {
        // TCP
        let matcher = PacketMatcher {
            transport_protocol: Some(TransportProtocolMatcher {
                proto: Ipv4Proto::Proto(IpProto::Tcp),
                src_port: src_port.clone(),
                dst_port: dst_port.clone(),
            }),
            ..Default::default()
        };
        let (src, dst) = transport_header;
        assert_eq!(
            matcher.matches::<_, FakeDeviceId>(
                &FakeIpPacket::<Ipv4, _> {
                    body: FakeTcpSegment {
                        src_port: src,
                        dst_port: dst,
                        segment: SegmentHeader::arbitrary_value(),
                        payload_len: 8888,
                    },
                    ..ArbitraryValue::arbitrary_value()
                },
                &Interfaces { ingress: None, egress: None },
            ),
            expect_match
        );

        // UDP
        let matcher = PacketMatcher {
            transport_protocol: Some(TransportProtocolMatcher {
                proto: Ipv4Proto::Proto(IpProto::Udp),
                src_port,
                dst_port,
            }),
            ..Default::default()
        };
        let (src, dst) = transport_header;
        assert_eq!(
            matcher.matches::<_, FakeDeviceId>(
                &FakeIpPacket::<Ipv4, _> {
                    body: FakeUdpPacket { src_port: src, dst_port: dst },
                    ..ArbitraryValue::arbitrary_value()
                },
                &Interfaces { ingress: None, egress: None },
            ),
            expect_match
        );
    }

    // With both source and destination address matchers set, failing either
    // one rejects the packet. Only a packet satisfying both is accepted.
    #[ip_test(I)]
    fn packet_must_match_all_provided_matchers<I: TestIpExt>() {
        let matcher = PacketMatcher::<I, FakeDeviceClass> {
            src_address: Some(AddressMatcher {
                matcher: AddressMatcherType::Subnet(SubnetMatcher(I::SUBNET)),
                invert: false,
            }),
            dst_address: Some(AddressMatcher {
                matcher: AddressMatcherType::Subnet(SubnetMatcher(I::SUBNET)),
                invert: false,
            }),
            ..Default::default()
        };

        assert_eq!(
            matcher.matches::<_, FakeDeviceId>(
                &FakeIpPacket::<_, FakeTcpSegment> {
                    src_ip: I::IP_OUTSIDE_SUBNET,
                    ..ArbitraryValue::arbitrary_value()
                },
                &Interfaces { ingress: None, egress: None },
            ),
            false
        );
        assert_eq!(
            matcher.matches::<_, FakeDeviceId>(
                &FakeIpPacket::<_, FakeTcpSegment> {
                    dst_ip: I::IP_OUTSIDE_SUBNET,
                    ..ArbitraryValue::arbitrary_value()
                },
                &Interfaces { ingress: None, egress: None },
            ),
            false
        );
        assert_eq!(
            matcher.matches::<_, FakeDeviceId>(
                &FakeIpPacket::<_, FakeTcpSegment>::arbitrary_value(),
                &Interfaces { ingress: None, egress: None },
            ),
            true
        );
    }

    // A default `PacketMatcher` (no criteria set) matches every packet.
    #[test]
    fn match_by_default_if_no_specified_matchers() {
        assert_eq!(
            PacketMatcher::<_, FakeDeviceClass>::default().matches::<_, FakeDeviceId>(
                &FakeIpPacket::<Ipv4, FakeTcpSegment>::arbitrary_value(),
                &Interfaces { ingress: None, egress: None },
            ),
            true
        );
    }
}