netstack3_filter/
matchers.rs

1// Copyright 2024 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use core::fmt::Debug;
6use netstack3_base::{
7    AddressMatcher, InspectableValue, InterfaceMatcher, InterfaceProperties, Matcher, PortMatcher,
8};
9
10use derivative::Derivative;
11use packet_formats::ip::IpExt;
12
13use crate::logic::Interfaces;
14use crate::packets::{FilterIpExt, IpPacket, MaybeTransportPacket, TransportPacketData};
15
16/// A matcher for transport-layer protocol or port numbers.
17#[derive(Debug, Clone)]
18pub struct TransportProtocolMatcher<P> {
19    /// The transport-layer protocol.
20    pub proto: P,
21    /// If set, the matcher for the source port or identifier of the transport
22    /// header.
23    pub src_port: Option<PortMatcher>,
24    /// If set, the matcher for the destination port or identifier of the
25    /// transport header.
26    pub dst_port: Option<PortMatcher>,
27}
28
29impl<P: Debug> InspectableValue for TransportProtocolMatcher<P> {
30    fn record<I: netstack3_base::Inspector>(&self, name: &str, inspector: &mut I) {
31        inspector.record_debug(name, self);
32    }
33}
34
35impl<P: PartialEq, T: MaybeTransportPacket> Matcher<(Option<P>, T)>
36    for TransportProtocolMatcher<P>
37{
38    fn matches(&self, actual: &(Option<P>, T)) -> bool {
39        let Self { proto, src_port, dst_port } = self;
40        let (packet_proto, packet) = actual;
41
42        let Some(packet_proto) = packet_proto else {
43            return false;
44        };
45
46        proto == packet_proto && {
47            let transport_data = packet.transport_packet_data();
48            src_port.required_matches(
49                transport_data.as_ref().map(TransportPacketData::src_port).as_ref(),
50            ) && dst_port.required_matches(
51                transport_data.as_ref().map(TransportPacketData::dst_port).as_ref(),
52            )
53        }
54    }
55}
56
57/// Top-level matcher for IP packets.
58#[derive(Derivative, Debug, Clone)]
59#[derivative(Default(bound = ""))]
60pub struct PacketMatcher<I: IpExt, DeviceClass> {
61    /// The interface on which the packet entered the stack.
62    ///
63    /// Only available in `INGRESS`, `LOCAL_INGRESS`, and `FORWARDING`.
64    pub in_interface: Option<InterfaceMatcher<DeviceClass>>,
65    /// The interface through which the packet exits the stack.
66    ///
67    /// Only available in `FORWARDING`, `LOCAL_EGRESS`, and `EGRESS`.
68    pub out_interface: Option<InterfaceMatcher<DeviceClass>>,
69    /// Matcher for the source IP address.
70    pub src_address: Option<AddressMatcher<I::Addr>>,
71    /// Matcher for the destination IP address.
72    pub dst_address: Option<AddressMatcher<I::Addr>>,
73    /// Matchers for the transport layer.
74    pub transport_protocol: Option<TransportProtocolMatcher<I::Proto>>,
75}
76
77impl<I: FilterIpExt, DeviceClass> PacketMatcher<I, DeviceClass> {
78    pub(crate) fn matches<P: IpPacket<I>, D: InterfaceProperties<DeviceClass>>(
79        &self,
80        packet: &P,
81        interfaces: &Interfaces<'_, D>,
82    ) -> bool {
83        let Self { in_interface, out_interface, src_address, dst_address, transport_protocol } =
84            self;
85        let Interfaces { ingress: in_if, egress: out_if } = interfaces;
86
87        // If no fields are specified, match all traffic by default.
88        in_interface.required_matches(*in_if)
89            && out_interface.required_matches(*out_if)
90            && src_address.matches(&packet.src_addr())
91            && dst_address.matches(&packet.dst_addr())
92            && transport_protocol.matches(&(packet.protocol(), packet.maybe_transport_packet()))
93    }
94}
95
96#[cfg(test)]
97mod tests {
98    use ip_test_macro::ip_test;
99    use net_types::ip::{Ipv4, Ipv4Addr, Ipv6, Ipv6Addr};
100    use packet_formats::ip::{IpProto, Ipv4Proto};
101    use test_case::test_case;
102
103    use netstack3_base::testutil::{FakeDeviceClass, FakeMatcherDeviceId};
104    use netstack3_base::{AddressMatcherType, SegmentHeader, SubnetMatcher};
105
106    use super::*;
107    use crate::packets::testutil::internal::{
108        ArbitraryValue, FakeIcmpEchoRequest, FakeIpPacket, FakeNullPacket, FakeTcpSegment,
109        FakeUdpPacket, TestIpExt, TransportPacketExt,
110    };
111
112    #[test_case(InterfaceMatcher::Id(FakeMatcherDeviceId::wlan_interface().id))]
113    #[test_case(InterfaceMatcher::Name(FakeMatcherDeviceId::wlan_interface().name))]
114    #[test_case(InterfaceMatcher::DeviceClass(FakeMatcherDeviceId::wlan_interface().class))]
115    fn match_on_interface_properties(matcher: InterfaceMatcher<FakeDeviceClass>) {
116        let matcher = PacketMatcher {
117            in_interface: Some(matcher.clone()),
118            out_interface: Some(matcher),
119            ..Default::default()
120        };
121
122        assert_eq!(
123            matcher.matches(
124                &FakeIpPacket::<Ipv4, FakeTcpSegment>::arbitrary_value(),
125                &Interfaces {
126                    ingress: Some(&FakeMatcherDeviceId::wlan_interface()),
127                    egress: Some(&FakeMatcherDeviceId::wlan_interface())
128                },
129            ),
130            true
131        );
132        assert_eq!(
133            matcher.matches(
134                &FakeIpPacket::<Ipv4, FakeTcpSegment>::arbitrary_value(),
135                &Interfaces {
136                    ingress: Some(&FakeMatcherDeviceId::ethernet_interface()),
137                    egress: Some(&FakeMatcherDeviceId::ethernet_interface())
138                },
139            ),
140            false
141        );
142    }
143
144    #[test_case(InterfaceMatcher::Id(FakeMatcherDeviceId::wlan_interface().id))]
145    #[test_case(InterfaceMatcher::Name(FakeMatcherDeviceId::wlan_interface().name))]
146    #[test_case(InterfaceMatcher::DeviceClass(FakeMatcherDeviceId::wlan_interface().class))]
147    fn interface_matcher_specified_but_not_available_in_hook_does_not_match(
148        matcher: InterfaceMatcher<FakeDeviceClass>,
149    ) {
150        let matcher = PacketMatcher {
151            in_interface: Some(matcher.clone()),
152            out_interface: Some(matcher),
153            ..Default::default()
154        };
155
156        assert_eq!(
157            matcher.matches(
158                &FakeIpPacket::<Ipv4, FakeTcpSegment>::arbitrary_value(),
159                &Interfaces { ingress: None, egress: Some(&FakeMatcherDeviceId::wlan_interface()) },
160            ),
161            false
162        );
163        assert_eq!(
164            matcher.matches(
165                &FakeIpPacket::<Ipv4, FakeTcpSegment>::arbitrary_value(),
166                &Interfaces { ingress: Some(&FakeMatcherDeviceId::wlan_interface()), egress: None },
167            ),
168            false
169        );
170        assert_eq!(
171            matcher.matches(
172                &FakeIpPacket::<Ipv4, FakeTcpSegment>::arbitrary_value(),
173                &Interfaces {
174                    ingress: Some(&FakeMatcherDeviceId::wlan_interface()),
175                    egress: Some(&FakeMatcherDeviceId::wlan_interface())
176                },
177            ),
178            true
179        );
180    }
181
182    enum AddressMatcherTestCase {
183        Subnet,
184        Range,
185    }
186
187    #[ip_test(I)]
188    #[test_case(AddressMatcherTestCase::Subnet, /* invert */ false)]
189    #[test_case(AddressMatcherTestCase::Subnet, /* invert */ true)]
190    #[test_case(AddressMatcherTestCase::Range, /* invert */ false)]
191    #[test_case(AddressMatcherTestCase::Range, /* invert */ true)]
192    fn match_on_subnet_or_address_range<I: TestIpExt>(
193        test_case: AddressMatcherTestCase,
194        invert: bool,
195    ) {
196        let matcher = AddressMatcher {
197            matcher: match test_case {
198                AddressMatcherTestCase::Subnet => {
199                    AddressMatcherType::Subnet(SubnetMatcher(I::SUBNET))
200                }
201                AddressMatcherTestCase::Range => {
202                    // Generate the inclusive address range that is equivalent to the subnet.
203                    let start = I::SUBNET.network();
204                    let end = I::map_ip(
205                        start,
206                        |start| {
207                            let range_size = 2_u32.pow(32 - u32::from(I::SUBNET.prefix())) - 1;
208                            let end = u32::from_be_bytes(start.ipv4_bytes()) + range_size;
209                            Ipv4Addr::from(end.to_be_bytes())
210                        },
211                        |start| {
212                            let range_size = 2_u128.pow(128 - u32::from(I::SUBNET.prefix())) - 1;
213                            let end = u128::from_be_bytes(start.ipv6_bytes()) + range_size;
214                            Ipv6Addr::from(end.to_be_bytes())
215                        },
216                    );
217                    AddressMatcherType::Range(start..=end)
218                }
219            },
220            invert,
221        };
222
223        for matcher in [
224            PacketMatcher { src_address: Some(matcher.clone()), ..Default::default() },
225            PacketMatcher { dst_address: Some(matcher), ..Default::default() },
226        ] {
227            assert_ne!(
228                matcher.matches::<_, FakeMatcherDeviceId>(
229                    &FakeIpPacket::<I, FakeTcpSegment>::arbitrary_value(),
230                    &Interfaces { ingress: None, egress: None },
231                ),
232                invert
233            );
234            assert_eq!(
235                matcher.matches::<_, FakeMatcherDeviceId>(
236                    &FakeIpPacket {
237                        src_ip: I::IP_OUTSIDE_SUBNET,
238                        dst_ip: I::IP_OUTSIDE_SUBNET,
239                        body: FakeTcpSegment::arbitrary_value(),
240                    },
241                    &Interfaces { ingress: None, egress: None },
242                ),
243                invert
244            );
245        }
246    }
247
248    enum Protocol {
249        Tcp,
250        Udp,
251        Icmp,
252    }
253
254    impl Protocol {
255        fn ip_proto<I: FilterIpExt>(&self) -> Option<I::Proto> {
256            match self {
257                Self::Tcp => <&FakeTcpSegment as TransportPacketExt<I>>::proto(),
258                Self::Udp => <&FakeUdpPacket as TransportPacketExt<I>>::proto(),
259                Self::Icmp => <&FakeIcmpEchoRequest as TransportPacketExt<I>>::proto(),
260            }
261        }
262    }
263
264    #[test_case(Protocol::Tcp, FakeIpPacket::<Ipv4, FakeTcpSegment>::arbitrary_value() => true)]
265    #[test_case(Protocol::Tcp, FakeIpPacket::<Ipv4, FakeUdpPacket>::arbitrary_value() => false)]
266    #[test_case(
267        Protocol::Tcp,
268        FakeIpPacket::<Ipv4, FakeIcmpEchoRequest>::arbitrary_value()
269        => false
270    )]
271    #[test_case(Protocol::Tcp, FakeIpPacket::<Ipv4, FakeNullPacket>::arbitrary_value() => false)]
272    #[test_case(Protocol::Udp, FakeIpPacket::<Ipv4, FakeUdpPacket>::arbitrary_value() => true)]
273    #[test_case(Protocol::Udp, FakeIpPacket::<Ipv4, FakeTcpSegment>::arbitrary_value()=> false)]
274    #[test_case(
275        Protocol::Udp,
276        FakeIpPacket::<Ipv4, FakeIcmpEchoRequest>::arbitrary_value()
277        => false
278    )]
279    #[test_case(
280        Protocol::Icmp,
281        FakeIpPacket::<Ipv4, FakeIcmpEchoRequest>::arbitrary_value()
282        => true
283    )]
284    #[test_case(Protocol::Udp, FakeIpPacket::<Ipv4, FakeNullPacket>::arbitrary_value() => false)]
285    #[test_case(
286        Protocol::Icmp,
287        FakeIpPacket::<Ipv6, FakeIcmpEchoRequest>::arbitrary_value()
288        => true
289    )]
290    #[test_case(Protocol::Icmp, FakeIpPacket::<Ipv4, FakeTcpSegment>::arbitrary_value() => false)]
291    #[test_case(Protocol::Icmp, FakeIpPacket::<Ipv4, FakeUdpPacket>::arbitrary_value() => false)]
292    #[test_case(Protocol::Icmp, FakeIpPacket::<Ipv4, FakeNullPacket>::arbitrary_value() => false)]
293    fn match_on_transport_protocol<I: TestIpExt, P: IpPacket<I>>(
294        protocol: Protocol,
295        packet: P,
296    ) -> bool {
297        let matcher = PacketMatcher {
298            transport_protocol: Some(TransportProtocolMatcher {
299                proto: protocol.ip_proto::<I>().unwrap(),
300                src_port: None,
301                dst_port: None,
302            }),
303            ..Default::default()
304        };
305
306        matcher
307            .matches::<_, FakeMatcherDeviceId>(&packet, &Interfaces { ingress: None, egress: None })
308    }
309
310    #[test_case(
311        Some(PortMatcher { range: 1024..=65535, invert: false }), None, (11111, 80), true;
312        "matching src port"
313    )]
314    #[test_case(
315        Some(PortMatcher { range: 1024..=65535, invert: true }), None, (11111, 80), false;
316        "invert match src port"
317    )]
318    #[test_case(
319        Some(PortMatcher { range: 1024..=65535, invert: false }), None, (53, 80), false;
320        "non-matching src port"
321    )]
322    #[test_case(
323        None, Some(PortMatcher { range: 22..=22, invert: false }), (11111, 22), true;
324        "match dst port"
325    )]
326    #[test_case(
327        None, Some(PortMatcher { range: 22..=22, invert: true }), (11111, 22), false;
328        "invert match dst port"
329    )]
330    #[test_case(
331        None, Some(PortMatcher { range: 22..=22, invert: false }), (11111, 80), false;
332        "non-matching dst port"
333    )]
334    fn match_on_port_range(
335        src_port: Option<PortMatcher>,
336        dst_port: Option<PortMatcher>,
337        transport_header: (u16, u16),
338        expect_match: bool,
339    ) {
340        // TCP
341        let matcher = PacketMatcher {
342            transport_protocol: Some(TransportProtocolMatcher {
343                proto: Ipv4Proto::Proto(IpProto::Tcp),
344                src_port: src_port.clone(),
345                dst_port: dst_port.clone(),
346            }),
347            ..Default::default()
348        };
349        let (src, dst) = transport_header;
350        assert_eq!(
351            matcher.matches::<_, FakeMatcherDeviceId>(
352                &FakeIpPacket::<Ipv4, _> {
353                    body: FakeTcpSegment {
354                        src_port: src,
355                        dst_port: dst,
356                        segment: SegmentHeader::arbitrary_value(),
357                        payload_len: 8888,
358                    },
359                    ..ArbitraryValue::arbitrary_value()
360                },
361                &Interfaces { ingress: None, egress: None },
362            ),
363            expect_match
364        );
365
366        // UDP
367        let matcher = PacketMatcher {
368            transport_protocol: Some(TransportProtocolMatcher {
369                proto: Ipv4Proto::Proto(IpProto::Udp),
370                src_port,
371                dst_port,
372            }),
373            ..Default::default()
374        };
375        let (src, dst) = transport_header;
376        assert_eq!(
377            matcher.matches::<_, FakeMatcherDeviceId>(
378                &FakeIpPacket::<Ipv4, _> {
379                    body: FakeUdpPacket { src_port: src, dst_port: dst },
380                    ..ArbitraryValue::arbitrary_value()
381                },
382                &Interfaces { ingress: None, egress: None },
383            ),
384            expect_match
385        );
386    }
387
388    #[ip_test(I)]
389    fn packet_must_match_all_provided_matchers<I: TestIpExt>() {
390        let matcher = PacketMatcher::<I, FakeDeviceClass> {
391            src_address: Some(AddressMatcher {
392                matcher: AddressMatcherType::Subnet(SubnetMatcher(I::SUBNET)),
393                invert: false,
394            }),
395            dst_address: Some(AddressMatcher {
396                matcher: AddressMatcherType::Subnet(SubnetMatcher(I::SUBNET)),
397                invert: false,
398            }),
399            ..Default::default()
400        };
401
402        assert_eq!(
403            matcher.matches::<_, FakeMatcherDeviceId>(
404                &FakeIpPacket::<_, FakeTcpSegment> {
405                    src_ip: I::IP_OUTSIDE_SUBNET,
406                    ..ArbitraryValue::arbitrary_value()
407                },
408                &Interfaces { ingress: None, egress: None },
409            ),
410            false
411        );
412        assert_eq!(
413            matcher.matches::<_, FakeMatcherDeviceId>(
414                &FakeIpPacket::<_, FakeTcpSegment> {
415                    dst_ip: I::IP_OUTSIDE_SUBNET,
416                    ..ArbitraryValue::arbitrary_value()
417                },
418                &Interfaces { ingress: None, egress: None },
419            ),
420            false
421        );
422        assert_eq!(
423            matcher.matches::<_, FakeMatcherDeviceId>(
424                &FakeIpPacket::<_, FakeTcpSegment>::arbitrary_value(),
425                &Interfaces { ingress: None, egress: None },
426            ),
427            true
428        );
429    }
430
431    #[test]
432    fn match_by_default_if_no_specified_matchers() {
433        assert_eq!(
434            PacketMatcher::<_, FakeDeviceClass>::default().matches::<_, FakeMatcherDeviceId>(
435                &FakeIpPacket::<Ipv4, FakeTcpSegment>::arbitrary_value(),
436                &Interfaces { ingress: None, egress: None },
437            ),
438            true
439        );
440    }
441}