netstack3_filter/
matchers.rs

1// Copyright 2024 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use alloc::sync::Arc;
6use core::convert::Infallible as Never;
7use core::fmt::Debug;
8use netstack3_base::{
9    AddressMatcher, InspectableValue, InterfaceMatcher, InterfaceProperties, Matcher,
10    MatcherBindingsTypes, PortMatcher,
11};
12
13use derivative::Derivative;
14use packet_formats::ip::IpExt;
15
16use crate::FilterBindingsTypes;
17use crate::logic::Interfaces;
18use crate::packets::{FilterIpExt, FilterIpPacket, MaybeTransportPacket, TransportPacketData};
19
20/// A matcher for transport-layer protocol or port numbers.
21#[derive(Debug, Clone)]
22pub struct TransportProtocolMatcher<P> {
23    /// The transport-layer protocol.
24    pub proto: P,
25    /// If set, the matcher for the source port or identifier of the transport
26    /// header.
27    pub src_port: Option<PortMatcher>,
28    /// If set, the matcher for the destination port or identifier of the
29    /// transport header.
30    pub dst_port: Option<PortMatcher>,
31}
32
33impl<P: Debug> InspectableValue for TransportProtocolMatcher<P> {
34    fn record<I: netstack3_base::Inspector>(&self, name: &str, inspector: &mut I) {
35        inspector.record_debug(name, self);
36    }
37}
38
39impl<P: PartialEq, T: MaybeTransportPacket> Matcher<(Option<P>, T)>
40    for TransportProtocolMatcher<P>
41{
42    fn matches(&self, actual: &(Option<P>, T)) -> bool {
43        let Self { proto, src_port, dst_port } = self;
44        let (packet_proto, packet) = actual;
45
46        let Some(packet_proto) = packet_proto else {
47            return false;
48        };
49
50        proto == packet_proto && {
51            let transport_data = packet.transport_packet_data();
52            src_port.required_matches(
53                transport_data.as_ref().map(TransportPacketData::src_port).as_ref(),
54            ) && dst_port.required_matches(
55                transport_data.as_ref().map(TransportPacketData::dst_port).as_ref(),
56            )
57        }
58    }
59}
60
61/// A trait for external matchers supplied by bindings.
62pub trait BindingsPacketMatcher<D> {
63    /// Returns `true` if the packet matches.
64    fn matches<I: FilterIpExt, P: FilterIpPacket<I>>(
65        &self,
66        packet: &P,
67        interfaces: Interfaces<'_, D>,
68    ) -> bool;
69}
70
71impl<D> BindingsPacketMatcher<D> for Never {
72    fn matches<I: FilterIpExt, P: FilterIpPacket<I>>(
73        &self,
74        _packet: &P,
75        _interfaces: Interfaces<'_, D>,
76    ) -> bool {
77        match *self {}
78    }
79}
80
81impl<D, M: BindingsPacketMatcher<D>> BindingsPacketMatcher<D> for Arc<M> {
82    fn matches<I: FilterIpExt, P: FilterIpPacket<I>>(
83        &self,
84        packet: &P,
85        interfaces: Interfaces<'_, D>,
86    ) -> bool {
87        self.as_ref().matches(packet, interfaces)
88    }
89}
90
91/// Top-level matcher for IP packets.
92#[derive(Derivative)]
93#[derivative(Default(bound = ""), Clone(bound = ""), Debug(bound = ""))]
94pub struct PacketMatcher<I: IpExt, BT: MatcherBindingsTypes> {
95    /// The interface on which the packet entered the stack.
96    ///
97    /// Only available in `INGRESS`, `LOCAL_INGRESS`, and `FORWARDING`.
98    pub in_interface: Option<InterfaceMatcher<BT::DeviceClass>>,
99    /// The interface through which the packet exits the stack.
100    ///
101    /// Only available in `FORWARDING`, `LOCAL_EGRESS`, and `EGRESS`.
102    pub out_interface: Option<InterfaceMatcher<BT::DeviceClass>>,
103    /// Matcher for the source IP address.
104    pub src_address: Option<AddressMatcher<I::Addr>>,
105    /// Matcher for the destination IP address.
106    pub dst_address: Option<AddressMatcher<I::Addr>>,
107    /// Matchers for the transport layer.
108    pub transport_protocol: Option<TransportProtocolMatcher<I::Proto>>,
109    /// A custom matcher to run on the packet.
110    pub external_matcher: Option<BT::BindingsPacketMatcher>,
111}
112
113impl<I, BT> PacketMatcher<I, BT>
114where
115    I: FilterIpExt,
116    BT: FilterBindingsTypes,
117{
118    pub(crate) fn matches<P, D>(&self, packet: &P, interfaces: Interfaces<'_, D>) -> bool
119    where
120        P: FilterIpPacket<I>,
121        D: InterfaceProperties<BT::DeviceClass>,
122        BT: FilterBindingsTypes<BindingsPacketMatcher: BindingsPacketMatcher<D>>,
123    {
124        let Self {
125            in_interface,
126            out_interface,
127            src_address,
128            dst_address,
129            transport_protocol,
130            external_matcher,
131        } = self;
132        let Interfaces { ingress: in_if, egress: out_if } = interfaces;
133
134        // If no fields are specified, match all traffic by default.
135        in_interface.required_matches(in_if)
136            && out_interface.required_matches(out_if)
137            && src_address.matches(&packet.src_addr())
138            && dst_address.matches(&packet.dst_addr())
139            && transport_protocol.matches(&(packet.protocol(), packet.maybe_transport_packet()))
140            && external_matcher.as_ref().map_or(true, |m| m.matches(packet, interfaces))
141    }
142}
143
#[cfg(test)]
mod tests {
    use ip_test_macro::ip_test;
    use net_types::ip::{Ipv4, Ipv4Addr, Ipv6, Ipv6Addr};
    use packet_formats::ip::{IpProto, Ipv4Proto};
    use test_case::test_case;

    use netstack3_base::testutil::{FakeDeviceClass, FakeMatcherDeviceId};
    use netstack3_base::{AddressMatcherType, SegmentHeader, SubnetMatcher};

    use super::*;
    use crate::context::testutil::{FakeBindingsCtx, FakeBindingsPacketMatcher};
    use crate::packets::testutil::internal::{
        ArbitraryValue, FakeIcmpEchoRequest, FakeIpPacket, FakeNullPacket, FakeTcpSegment,
        FakeUdpPacket, TestIpExt, TransportPacketExt,
    };

    // Interface matchers built from the WLAN interface's ID, name, or device
    // class should accept WLAN traffic and reject Ethernet traffic.
    #[test_case(InterfaceMatcher::Id(FakeMatcherDeviceId::wlan_interface().id))]
    #[test_case(InterfaceMatcher::Name(FakeMatcherDeviceId::wlan_interface().name))]
    #[test_case(InterfaceMatcher::DeviceClass(FakeMatcherDeviceId::wlan_interface().class))]
    fn match_on_interface_properties(matcher: InterfaceMatcher<FakeDeviceClass>) {
        let matcher = PacketMatcher::<Ipv4, FakeBindingsCtx<Ipv4>> {
            in_interface: Some(matcher.clone()),
            out_interface: Some(matcher),
            ..Default::default()
        };

        assert_eq!(
            matcher.matches(
                &FakeIpPacket::<Ipv4, FakeTcpSegment>::arbitrary_value(),
                Interfaces {
                    ingress: Some(&FakeMatcherDeviceId::wlan_interface()),
                    egress: Some(&FakeMatcherDeviceId::wlan_interface())
                },
            ),
            true
        );
        assert_eq!(
            matcher.matches(
                &FakeIpPacket::<Ipv4, FakeTcpSegment>::arbitrary_value(),
                Interfaces {
                    ingress: Some(&FakeMatcherDeviceId::ethernet_interface()),
                    egress: Some(&FakeMatcherDeviceId::ethernet_interface())
                },
            ),
            false
        );
    }

    // A configured interface matcher requires the corresponding interface to be
    // present; if the hook does not provide it, the packet must not match.
    #[test_case(InterfaceMatcher::Id(FakeMatcherDeviceId::wlan_interface().id))]
    #[test_case(InterfaceMatcher::Name(FakeMatcherDeviceId::wlan_interface().name))]
    #[test_case(InterfaceMatcher::DeviceClass(FakeMatcherDeviceId::wlan_interface().class))]
    fn interface_matcher_specified_but_not_available_in_hook_does_not_match(
        matcher: InterfaceMatcher<FakeDeviceClass>,
    ) {
        let matcher = PacketMatcher::<Ipv4, FakeBindingsCtx<Ipv4>> {
            in_interface: Some(matcher.clone()),
            out_interface: Some(matcher),
            ..Default::default()
        };

        assert_eq!(
            matcher.matches(
                &FakeIpPacket::<Ipv4, FakeTcpSegment>::arbitrary_value(),
                Interfaces { ingress: None, egress: Some(&FakeMatcherDeviceId::wlan_interface()) },
            ),
            false
        );
        assert_eq!(
            matcher.matches(
                &FakeIpPacket::<Ipv4, FakeTcpSegment>::arbitrary_value(),
                Interfaces { ingress: Some(&FakeMatcherDeviceId::wlan_interface()), egress: None },
            ),
            false
        );
        assert_eq!(
            matcher.matches(
                &FakeIpPacket::<Ipv4, FakeTcpSegment>::arbitrary_value(),
                Interfaces {
                    ingress: Some(&FakeMatcherDeviceId::wlan_interface()),
                    egress: Some(&FakeMatcherDeviceId::wlan_interface())
                },
            ),
            true
        );
    }

    enum AddressMatcherTestCase {
        Subnet,
        Range,
    }

    // A subnet matcher and the equivalent inclusive address range should behave
    // identically, for both source and destination, with and without inversion.
    #[ip_test(I)]
    #[test_case(AddressMatcherTestCase::Subnet, /* invert */ false)]
    #[test_case(AddressMatcherTestCase::Subnet, /* invert */ true)]
    #[test_case(AddressMatcherTestCase::Range, /* invert */ false)]
    #[test_case(AddressMatcherTestCase::Range, /* invert */ true)]
    fn match_on_subnet_or_address_range<I: TestIpExt>(
        test_case: AddressMatcherTestCase,
        invert: bool,
    ) {
        let matcher = AddressMatcher {
            matcher: match test_case {
                AddressMatcherTestCase::Subnet => {
                    AddressMatcherType::Subnet(SubnetMatcher(I::SUBNET))
                }
                AddressMatcherTestCase::Range => {
                    // Generate the inclusive address range that is equivalent to the subnet:
                    // the network address through the last address covered by the prefix.
                    let start = I::SUBNET.network();
                    let prefix = u32::from(I::SUBNET.prefix());
                    let end = I::map_ip(
                        start,
                        |start| {
                            // The host-bit mask is `MAX >> prefix`. `checked_shr` yields `None`
                            // for a full-length prefix (a single-address range), and unlike
                            // `2.pow(32 - prefix) - 1` it cannot overflow for a /0.
                            let host_mask = u32::MAX.checked_shr(prefix).unwrap_or(0);
                            let end = u32::from_be_bytes(start.ipv4_bytes()) + host_mask;
                            Ipv4Addr::from(end.to_be_bytes())
                        },
                        |start| {
                            let host_mask = u128::MAX.checked_shr(prefix).unwrap_or(0);
                            let end = u128::from_be_bytes(start.ipv6_bytes()) + host_mask;
                            Ipv6Addr::from(end.to_be_bytes())
                        },
                    );
                    AddressMatcherType::Range(start..=end)
                }
            },
            invert,
        };

        for matcher in [
            PacketMatcher::<I, FakeBindingsCtx<I>> {
                src_address: Some(matcher.clone()),
                ..Default::default()
            },
            PacketMatcher::<I, FakeBindingsCtx<I>> {
                dst_address: Some(matcher),
                ..Default::default()
            },
        ] {
            assert_ne!(
                matcher.matches::<_, FakeMatcherDeviceId>(
                    &FakeIpPacket::<I, FakeTcpSegment>::arbitrary_value(),
                    Interfaces { ingress: None, egress: None },
                ),
                invert
            );
            assert_eq!(
                matcher.matches::<_, FakeMatcherDeviceId>(
                    &FakeIpPacket {
                        src_ip: I::IP_OUTSIDE_SUBNET,
                        dst_ip: I::IP_OUTSIDE_SUBNET,
                        body: FakeTcpSegment::arbitrary_value(),
                    },
                    Interfaces { ingress: None, egress: None },
                ),
                invert
            );
        }
    }

    enum Protocol {
        Tcp,
        Udp,
        Icmp,
    }

    impl Protocol {
        // Maps the test protocol to the IP-version-specific protocol number used
        // by the corresponding fake transport packet.
        fn ip_proto<I: FilterIpExt>(&self) -> Option<I::Proto> {
            match self {
                Self::Tcp => <&FakeTcpSegment as TransportPacketExt<I>>::proto(),
                Self::Udp => <&FakeUdpPacket as TransportPacketExt<I>>::proto(),
                Self::Icmp => <&FakeIcmpEchoRequest as TransportPacketExt<I>>::proto(),
            }
        }
    }

    // A protocol-only transport matcher accepts exactly the packets carrying
    // that protocol; a packet with no transport body never matches.
    #[test_case(Protocol::Tcp, FakeIpPacket::<Ipv4, FakeTcpSegment>::arbitrary_value() => true)]
    #[test_case(Protocol::Tcp, FakeIpPacket::<Ipv4, FakeUdpPacket>::arbitrary_value() => false)]
    #[test_case(
        Protocol::Tcp,
        FakeIpPacket::<Ipv4, FakeIcmpEchoRequest>::arbitrary_value()
        => false
    )]
    #[test_case(Protocol::Tcp, FakeIpPacket::<Ipv4, FakeNullPacket>::arbitrary_value() => false)]
    #[test_case(Protocol::Udp, FakeIpPacket::<Ipv4, FakeUdpPacket>::arbitrary_value() => true)]
    #[test_case(Protocol::Udp, FakeIpPacket::<Ipv4, FakeTcpSegment>::arbitrary_value() => false)]
    #[test_case(
        Protocol::Udp,
        FakeIpPacket::<Ipv4, FakeIcmpEchoRequest>::arbitrary_value()
        => false
    )]
    #[test_case(Protocol::Udp, FakeIpPacket::<Ipv4, FakeNullPacket>::arbitrary_value() => false)]
    #[test_case(
        Protocol::Icmp,
        FakeIpPacket::<Ipv4, FakeIcmpEchoRequest>::arbitrary_value()
        => true
    )]
    #[test_case(
        Protocol::Icmp,
        FakeIpPacket::<Ipv6, FakeIcmpEchoRequest>::arbitrary_value()
        => true
    )]
    #[test_case(Protocol::Icmp, FakeIpPacket::<Ipv4, FakeTcpSegment>::arbitrary_value() => false)]
    #[test_case(Protocol::Icmp, FakeIpPacket::<Ipv4, FakeUdpPacket>::arbitrary_value() => false)]
    #[test_case(Protocol::Icmp, FakeIpPacket::<Ipv4, FakeNullPacket>::arbitrary_value() => false)]
    fn match_on_transport_protocol<I, P>(protocol: Protocol, packet: P) -> bool
    where
        I: TestIpExt,
        P: FilterIpPacket<I>,
    {
        let matcher = PacketMatcher::<I, FakeBindingsCtx<I>> {
            transport_protocol: Some(TransportProtocolMatcher {
                proto: protocol.ip_proto::<I>().unwrap(),
                src_port: None,
                dst_port: None,
            }),
            ..Default::default()
        };

        matcher
            .matches::<_, FakeMatcherDeviceId>(&packet, Interfaces { ingress: None, egress: None })
    }

    // Port-range matching must behave identically for TCP and UDP.
    #[test_case(
        Some(PortMatcher { range: 1024..=65535, invert: false }), None, (11111, 80), true;
        "matching src port"
    )]
    #[test_case(
        Some(PortMatcher { range: 1024..=65535, invert: true }), None, (11111, 80), false;
        "invert match src port"
    )]
    #[test_case(
        Some(PortMatcher { range: 1024..=65535, invert: false }), None, (53, 80), false;
        "non-matching src port"
    )]
    #[test_case(
        None, Some(PortMatcher { range: 22..=22, invert: false }), (11111, 22), true;
        "match dst port"
    )]
    #[test_case(
        None, Some(PortMatcher { range: 22..=22, invert: true }), (11111, 22), false;
        "invert match dst port"
    )]
    #[test_case(
        None, Some(PortMatcher { range: 22..=22, invert: false }), (11111, 80), false;
        "non-matching dst port"
    )]
    fn match_on_port_range(
        src_port: Option<PortMatcher>,
        dst_port: Option<PortMatcher>,
        transport_header: (u16, u16),
        expect_match: bool,
    ) {
        let (src, dst) = transport_header;

        // TCP
        let matcher = PacketMatcher::<Ipv4, FakeBindingsCtx<Ipv4>> {
            transport_protocol: Some(TransportProtocolMatcher {
                proto: Ipv4Proto::Proto(IpProto::Tcp),
                src_port: src_port.clone(),
                dst_port: dst_port.clone(),
            }),
            ..Default::default()
        };
        assert_eq!(
            matcher.matches::<_, FakeMatcherDeviceId>(
                &FakeIpPacket::<Ipv4, _> {
                    body: FakeTcpSegment {
                        src_port: src,
                        dst_port: dst,
                        segment: SegmentHeader::arbitrary_value(),
                        payload_len: 8888,
                    },
                    ..ArbitraryValue::arbitrary_value()
                },
                Interfaces { ingress: None, egress: None },
            ),
            expect_match
        );

        // UDP
        let matcher = PacketMatcher::<Ipv4, FakeBindingsCtx<Ipv4>> {
            transport_protocol: Some(TransportProtocolMatcher {
                proto: Ipv4Proto::Proto(IpProto::Udp),
                src_port,
                dst_port,
            }),
            ..Default::default()
        };
        assert_eq!(
            matcher.matches::<_, FakeMatcherDeviceId>(
                &FakeIpPacket::<Ipv4, _> {
                    body: FakeUdpPacket { src_port: src, dst_port: dst },
                    ..ArbitraryValue::arbitrary_value()
                },
                Interfaces { ingress: None, egress: None },
            ),
            expect_match
        );
    }

    // When several matchers are set, a packet failing any one of them must not match.
    #[ip_test(I)]
    fn packet_must_match_all_provided_matchers<I: TestIpExt>() {
        let matcher = PacketMatcher::<I, FakeBindingsCtx<I>> {
            src_address: Some(AddressMatcher {
                matcher: AddressMatcherType::Subnet(SubnetMatcher(I::SUBNET)),
                invert: false,
            }),
            dst_address: Some(AddressMatcher {
                matcher: AddressMatcherType::Subnet(SubnetMatcher(I::SUBNET)),
                invert: false,
            }),
            ..Default::default()
        };

        assert_eq!(
            matcher.matches::<_, FakeMatcherDeviceId>(
                &FakeIpPacket::<_, FakeTcpSegment> {
                    src_ip: I::IP_OUTSIDE_SUBNET,
                    ..ArbitraryValue::arbitrary_value()
                },
                Interfaces { ingress: None, egress: None },
            ),
            false
        );
        assert_eq!(
            matcher.matches::<_, FakeMatcherDeviceId>(
                &FakeIpPacket::<_, FakeTcpSegment> {
                    dst_ip: I::IP_OUTSIDE_SUBNET,
                    ..ArbitraryValue::arbitrary_value()
                },
                Interfaces { ingress: None, egress: None },
            ),
            false
        );
        assert_eq!(
            matcher.matches::<_, FakeMatcherDeviceId>(
                &FakeIpPacket::<_, FakeTcpSegment>::arbitrary_value(),
                Interfaces { ingress: None, egress: None },
            ),
            true
        );
    }

    // A default (empty) matcher matches all traffic.
    #[test]
    fn match_by_default_if_no_specified_matchers() {
        assert_eq!(
            PacketMatcher::<Ipv4, FakeBindingsCtx<Ipv4>>::default()
                .matches::<_, FakeMatcherDeviceId>(
                    &FakeIpPacket::<Ipv4, FakeTcpSegment>::arbitrary_value(),
                    Interfaces { ingress: None, egress: None },
                ),
            true
        );
    }

    // The external matcher's verdict is used as-is, and it is consulted exactly once.
    #[test_case(true; "yes")]
    #[test_case(false; "no")]
    fn match_external_matcher(result: bool) {
        let external_matcher = FakeBindingsPacketMatcher::new(result);
        assert_eq!(
            PacketMatcher::<Ipv4, FakeBindingsCtx<Ipv4>> {
                external_matcher: Some(external_matcher.clone()),
                ..Default::default()
            }
            .matches::<_, FakeMatcherDeviceId>(
                &FakeIpPacket::<Ipv4, FakeTcpSegment>::arbitrary_value(),
                Interfaces {
                    ingress: Some(&FakeMatcherDeviceId::ethernet_interface()),
                    egress: None
                },
            ),
            result
        );
        assert_eq!(external_matcher.num_calls(), 1);
    }
}