1use core::fmt::Debug;
6use core::num::NonZeroU64;
7use core::ops::RangeInclusive;
8use netstack3_base::{DeviceNameMatcher, DeviceWithName, InspectableValue, Matcher, SubnetMatcher};
9
10use derivative::Derivative;
11use net_types::ip::IpAddress;
12use packet_formats::ip::IpExt;
13
14use crate::logic::Interfaces;
15use crate::packets::{FilterIpExt, IpPacket, MaybeTransportPacket, TransportPacketData};
16
17#[derive(Clone, Derivative)]
19#[derivative(Debug)]
20pub enum InterfaceMatcher<DeviceClass> {
21 Id(NonZeroU64),
23 #[derivative(Debug = "transparent")]
25 Name(DeviceNameMatcher),
26 DeviceClass(DeviceClass),
28}
29
30impl<DeviceClass: Debug> InspectableValue for InterfaceMatcher<DeviceClass> {
31 fn record<I: netstack3_base::Inspector>(&self, name: &str, inspector: &mut I) {
32 inspector.record_debug(name, self);
33 }
34}
35
36pub trait InterfaceProperties<DeviceClass>: DeviceWithName {
40 fn id_matches(&self, id: &NonZeroU64) -> bool;
42
43 fn device_class_matches(&self, device_class: &DeviceClass) -> bool;
45}
46
47impl<DeviceClass, I: InterfaceProperties<DeviceClass>> Matcher<I>
48 for InterfaceMatcher<DeviceClass>
49{
50 fn matches(&self, actual: &I) -> bool {
51 match self {
52 InterfaceMatcher::Id(id) => actual.id_matches(id),
53 InterfaceMatcher::Name(name_matcher) => name_matcher.matches(actual),
54 InterfaceMatcher::DeviceClass(device_class) => {
55 actual.device_class_matches(device_class)
56 }
57 }
58 }
59}
60
61#[derive(Clone, Derivative)]
63#[derivative(Debug)]
64pub enum AddressMatcherType<A: IpAddress> {
65 #[derivative(Debug = "transparent")]
67 Subnet(SubnetMatcher<A>),
68 Range(RangeInclusive<A>),
70}
71
72impl<A: IpAddress> Matcher<A> for AddressMatcherType<A> {
73 fn matches(&self, actual: &A) -> bool {
74 match self {
75 Self::Subnet(subnet_matcher) => subnet_matcher.matches(actual),
76 Self::Range(range) => range.contains(actual),
77 }
78 }
79}
80
81#[derive(Clone, Debug)]
83pub struct AddressMatcher<A: IpAddress> {
84 pub matcher: AddressMatcherType<A>,
86 pub invert: bool,
90}
91
92impl<A: IpAddress> InspectableValue for AddressMatcher<A> {
93 fn record<I: netstack3_base::Inspector>(&self, name: &str, inspector: &mut I) {
94 inspector.record_debug(name, self);
95 }
96}
97
98impl<A: IpAddress> Matcher<A> for AddressMatcher<A> {
99 fn matches(&self, addr: &A) -> bool {
100 let Self { matcher, invert } = self;
101 matcher.matches(addr) ^ *invert
102 }
103}
104
/// Matches a transport-layer port against an inclusive range, optionally
/// negating the outcome.
#[derive(Debug, Clone)]
pub struct PortMatcher {
    /// Inclusive range of ports that satisfy the matcher.
    pub range: RangeInclusive<u16>,
    /// When `true`, the matcher accepts exactly the ports outside `range`.
    pub invert: bool,
}
115
116impl Matcher<u16> for PortMatcher {
117 fn matches(&self, actual: &u16) -> bool {
118 let Self { range, invert } = self;
119 range.contains(actual) ^ *invert
120 }
121}
122
123#[derive(Debug, Clone)]
125pub struct TransportProtocolMatcher<P> {
126 pub proto: P,
128 pub src_port: Option<PortMatcher>,
131 pub dst_port: Option<PortMatcher>,
134}
135
136impl<P: Debug> InspectableValue for TransportProtocolMatcher<P> {
137 fn record<I: netstack3_base::Inspector>(&self, name: &str, inspector: &mut I) {
138 inspector.record_debug(name, self);
139 }
140}
141
142impl<P: PartialEq, T: MaybeTransportPacket> Matcher<(Option<P>, T)>
143 for TransportProtocolMatcher<P>
144{
145 fn matches(&self, actual: &(Option<P>, T)) -> bool {
146 let Self { proto, src_port, dst_port } = self;
147 let (packet_proto, packet) = actual;
148
149 let Some(packet_proto) = packet_proto else {
150 return false;
151 };
152
153 proto == packet_proto
154 && src_port.required_matches(
155 packet.transport_packet_data().as_ref().map(TransportPacketData::src_port).as_ref(),
156 )
157 && dst_port.required_matches(
158 packet.transport_packet_data().as_ref().map(TransportPacketData::dst_port).as_ref(),
159 )
160 }
161}
162
163#[derive(Derivative, Debug, Clone)]
165#[derivative(Default(bound = ""))]
166pub struct PacketMatcher<I: IpExt, DeviceClass> {
167 pub in_interface: Option<InterfaceMatcher<DeviceClass>>,
171 pub out_interface: Option<InterfaceMatcher<DeviceClass>>,
175 pub src_address: Option<AddressMatcher<I::Addr>>,
177 pub dst_address: Option<AddressMatcher<I::Addr>>,
179 pub transport_protocol: Option<TransportProtocolMatcher<I::Proto>>,
181}
182
183impl<I: FilterIpExt, DeviceClass> PacketMatcher<I, DeviceClass> {
184 pub(crate) fn matches<P: IpPacket<I>, D: InterfaceProperties<DeviceClass>>(
185 &self,
186 packet: &P,
187 interfaces: &Interfaces<'_, D>,
188 ) -> bool {
189 let Self { in_interface, out_interface, src_address, dst_address, transport_protocol } =
190 self;
191 let Interfaces { ingress: in_if, egress: out_if } = interfaces;
192
193 in_interface.required_matches(*in_if)
195 && out_interface.required_matches(*out_if)
196 && src_address.matches(&packet.src_addr())
197 && dst_address.matches(&packet.dst_addr())
198 && transport_protocol.matches(&(packet.protocol(), packet.maybe_transport_packet()))
199 }
200}
201
#[cfg(test)]
pub(crate) mod testutil {
    use alloc::string::String;
    use core::num::NonZeroU64;

    use netstack3_base::testutil::{FakeStrongDeviceId, FakeWeakDeviceId};
    use netstack3_base::{DeviceIdentifier, DeviceWithName, StrongDeviceIdentifier};

    use super::*;
    use crate::context::testutil::FakeDeviceClass;

    /// A fake device identifier that exposes an ID, a name, and a device
    /// class, so that every [`InterfaceMatcher`] variant can be exercised.
    #[derive(Clone, Debug, PartialOrd, Ord, PartialEq, Eq, Hash)]
    pub struct FakeDeviceId {
        pub id: NonZeroU64,
        pub name: String,
        pub class: FakeDeviceClass,
    }

    impl StrongDeviceIdentifier for FakeDeviceId {
        type Weak = FakeWeakDeviceId<Self>;

        /// Wraps a clone of this ID in the fake weak-ID type.
        fn downgrade(&self) -> Self::Weak {
            FakeWeakDeviceId(self.clone())
        }
    }

    impl DeviceIdentifier for FakeDeviceId {
        /// Fake devices are never loopback.
        fn is_loopback(&self) -> bool {
            false
        }
    }

    impl FakeStrongDeviceId for FakeDeviceId {
        /// Fake devices are always reported alive.
        fn is_alive(&self) -> bool {
            true
        }
    }

    impl PartialEq<FakeWeakDeviceId<FakeDeviceId>> for FakeDeviceId {
        /// A strong ID equals a weak ID when the wrapped strong IDs are equal.
        fn eq(&self, other: &FakeWeakDeviceId<FakeDeviceId>) -> bool {
            let FakeWeakDeviceId(strong) = other;
            self == strong
        }
    }

    impl DeviceWithName for FakeDeviceId {
        fn name_matches(&self, name: &str) -> bool {
            self.name == name
        }
    }

    impl InterfaceProperties<FakeDeviceClass> for FakeDeviceId {
        fn id_matches(&self, id: &NonZeroU64) -> bool {
            self.id == *id
        }

        fn device_class_matches(&self, class: &FakeDeviceClass) -> bool {
            self.class == *class
        }
    }

    /// Builds a [`FakeDeviceId`] from its raw parts; `id` must be non-zero.
    fn fake_interface(id: u64, name: &str, class: FakeDeviceClass) -> FakeDeviceId {
        FakeDeviceId { id: NonZeroU64::new(id).unwrap(), name: String::from(name), class }
    }

    /// A fake WLAN interface: ID 1, name `"wlan"`.
    pub fn wlan_interface() -> FakeDeviceId {
        fake_interface(1, "wlan", FakeDeviceClass::Wlan)
    }

    /// A fake Ethernet interface: ID 2, name `"eth"`.
    pub fn ethernet_interface() -> FakeDeviceId {
        fake_interface(2, "eth", FakeDeviceClass::Ethernet)
    }
}
278
#[cfg(feature = "testutils")]
mod base_testutil {
    use super::*;

    /// Implements `InterfaceProperties<()>` for each listed base test device
    /// type. Both methods panic via `unimplemented!()` if they are called.
    macro_rules! impl_unqueried_interface_properties {
        ($($device:ty),* $(,)?) => {
            $(
                impl InterfaceProperties<()> for $device {
                    fn id_matches(&self, _: &core::num::NonZeroU64) -> bool {
                        unimplemented!()
                    }

                    fn device_class_matches(&self, _: &()) -> bool {
                        unimplemented!()
                    }
                }
            )*
        };
    }

    impl_unqueried_interface_properties!(
        netstack3_base::testutil::FakeDeviceId,
        netstack3_base::testutil::FakeReferencyDeviceId,
        netstack3_base::testutil::MultipleDevicesId,
    );
}
314
#[cfg(test)]
mod tests {
    use ip_test_macro::ip_test;
    use net_types::ip::{Ipv4, Ipv4Addr, Ipv6, Ipv6Addr};
    use packet_formats::ip::{IpProto, Ipv4Proto};
    use test_case::test_case;

    use netstack3_base::SegmentHeader;

    use super::testutil::*;
    use super::*;
    use crate::context::testutil::FakeDeviceClass;
    use crate::packets::testutil::internal::{
        ArbitraryValue, FakeIcmpEchoRequest, FakeIpPacket, FakeNullPacket, FakeTcpSegment,
        FakeUdpPacket, TestIpExt, TransportPacketExt,
    };

    #[test_case(InterfaceMatcher::Id(wlan_interface().id))]
    #[test_case(InterfaceMatcher::Name(DeviceNameMatcher(wlan_interface().name.clone())))]
    #[test_case(InterfaceMatcher::DeviceClass(wlan_interface().class))]
    fn match_on_interface_properties(matcher: InterfaceMatcher<FakeDeviceClass>) {
        let matcher = PacketMatcher {
            in_interface: Some(matcher.clone()),
            out_interface: Some(matcher),
            ..Default::default()
        };

        assert_eq!(
            matcher.matches(
                &FakeIpPacket::<Ipv4, FakeTcpSegment>::arbitrary_value(),
                &Interfaces { ingress: Some(&wlan_interface()), egress: Some(&wlan_interface()) },
            ),
            true
        );
        assert_eq!(
            matcher.matches(
                &FakeIpPacket::<Ipv4, FakeTcpSegment>::arbitrary_value(),
                &Interfaces {
                    ingress: Some(&ethernet_interface()),
                    egress: Some(&ethernet_interface())
                },
            ),
            false
        );
    }

    #[test_case(InterfaceMatcher::Id(wlan_interface().id))]
    #[test_case(InterfaceMatcher::Name(DeviceNameMatcher(wlan_interface().name.clone())))]
    #[test_case(InterfaceMatcher::DeviceClass(wlan_interface().class))]
    fn interface_matcher_specified_but_not_available_in_hook_does_not_match(
        matcher: InterfaceMatcher<FakeDeviceClass>,
    ) {
        let matcher = PacketMatcher {
            in_interface: Some(matcher.clone()),
            out_interface: Some(matcher),
            ..Default::default()
        };

        assert_eq!(
            matcher.matches(
                &FakeIpPacket::<Ipv4, FakeTcpSegment>::arbitrary_value(),
                &Interfaces { ingress: None, egress: Some(&wlan_interface()) },
            ),
            false
        );
        assert_eq!(
            matcher.matches(
                &FakeIpPacket::<Ipv4, FakeTcpSegment>::arbitrary_value(),
                &Interfaces { ingress: Some(&wlan_interface()), egress: None },
            ),
            false
        );
        assert_eq!(
            matcher.matches(
                &FakeIpPacket::<Ipv4, FakeTcpSegment>::arbitrary_value(),
                &Interfaces { ingress: Some(&wlan_interface()), egress: Some(&wlan_interface()) },
            ),
            true
        );
    }

    enum AddressMatcherTestCase {
        Subnet,
        Range,
    }

    #[ip_test(I)]
    #[test_case(AddressMatcherTestCase::Subnet, false)]
    #[test_case(AddressMatcherTestCase::Subnet, true)]
    #[test_case(AddressMatcherTestCase::Range, false)]
    #[test_case(AddressMatcherTestCase::Range, true)]
    fn match_on_subnet_or_address_range<I: TestIpExt>(
        test_case: AddressMatcherTestCase,
        invert: bool,
    ) {
        let matcher = AddressMatcher {
            matcher: match test_case {
                AddressMatcherTestCase::Subnet => {
                    AddressMatcherType::Subnet(SubnetMatcher(I::SUBNET))
                }
                AddressMatcherTestCase::Range => {
                    // Build an inclusive range covering exactly the addresses
                    // of `I::SUBNET`: from its network address through the
                    // last address allowed by its prefix length.
                    let start = I::SUBNET.network();
                    let end = I::map_ip(
                        start,
                        |start| {
                            let range_size = 2_u32.pow(32 - u32::from(I::SUBNET.prefix())) - 1;
                            let end = u32::from_be_bytes(start.ipv4_bytes()) + range_size;
                            Ipv4Addr::from(end.to_be_bytes())
                        },
                        |start| {
                            let range_size = 2_u128.pow(128 - u32::from(I::SUBNET.prefix())) - 1;
                            let end = u128::from_be_bytes(start.ipv6_bytes()) + range_size;
                            Ipv6Addr::from(end.to_be_bytes())
                        },
                    );
                    AddressMatcherType::Range(start..=end)
                }
            },
            invert,
        };

        for matcher in [
            PacketMatcher { src_address: Some(matcher.clone()), ..Default::default() },
            PacketMatcher { dst_address: Some(matcher), ..Default::default() },
        ] {
            assert_ne!(
                matcher.matches::<_, FakeDeviceId>(
                    &FakeIpPacket::<I, FakeTcpSegment>::arbitrary_value(),
                    &Interfaces { ingress: None, egress: None },
                ),
                invert
            );
            assert_eq!(
                matcher.matches::<_, FakeDeviceId>(
                    &FakeIpPacket {
                        src_ip: I::IP_OUTSIDE_SUBNET,
                        dst_ip: I::IP_OUTSIDE_SUBNET,
                        body: FakeTcpSegment::arbitrary_value(),
                    },
                    &Interfaces { ingress: None, egress: None },
                ),
                invert
            );
        }
    }

    enum Protocol {
        Tcp,
        Udp,
        Icmp,
    }

    impl Protocol {
        fn ip_proto<I: FilterIpExt>(&self) -> Option<I::Proto> {
            match self {
                Self::Tcp => <&FakeTcpSegment as TransportPacketExt<I>>::proto(),
                Self::Udp => <&FakeUdpPacket as TransportPacketExt<I>>::proto(),
                Self::Icmp => <&FakeIcmpEchoRequest as TransportPacketExt<I>>::proto(),
            }
        }
    }

    #[test_case(Protocol::Tcp, FakeIpPacket::<Ipv4, FakeTcpSegment>::arbitrary_value() => true)]
    #[test_case(Protocol::Tcp, FakeIpPacket::<Ipv4, FakeUdpPacket>::arbitrary_value() => false)]
    #[test_case(
        Protocol::Tcp,
        FakeIpPacket::<Ipv4, FakeIcmpEchoRequest>::arbitrary_value()
        => false
    )]
    #[test_case(Protocol::Tcp, FakeIpPacket::<Ipv4, FakeNullPacket>::arbitrary_value() => false)]
    #[test_case(Protocol::Udp, FakeIpPacket::<Ipv4, FakeUdpPacket>::arbitrary_value() => true)]
    #[test_case(Protocol::Udp, FakeIpPacket::<Ipv4, FakeTcpSegment>::arbitrary_value()=> false)]
    #[test_case(
        Protocol::Udp,
        FakeIpPacket::<Ipv4, FakeIcmpEchoRequest>::arbitrary_value()
        => false
    )]
    #[test_case(
        Protocol::Icmp,
        FakeIpPacket::<Ipv4, FakeIcmpEchoRequest>::arbitrary_value()
        => true
    )]
    #[test_case(Protocol::Udp, FakeIpPacket::<Ipv4, FakeNullPacket>::arbitrary_value() => false)]
    #[test_case(
        Protocol::Icmp,
        FakeIpPacket::<Ipv6, FakeIcmpEchoRequest>::arbitrary_value()
        => true
    )]
    #[test_case(Protocol::Icmp, FakeIpPacket::<Ipv4, FakeTcpSegment>::arbitrary_value() => false)]
    #[test_case(Protocol::Icmp, FakeIpPacket::<Ipv4, FakeUdpPacket>::arbitrary_value() => false)]
    #[test_case(Protocol::Icmp, FakeIpPacket::<Ipv4, FakeNullPacket>::arbitrary_value() => false)]
    fn match_on_transport_protocol<I: TestIpExt, P: IpPacket<I>>(
        protocol: Protocol,
        packet: P,
    ) -> bool {
        let matcher = PacketMatcher {
            transport_protocol: Some(TransportProtocolMatcher {
                proto: protocol.ip_proto::<I>().unwrap(),
                src_port: None,
                dst_port: None,
            }),
            ..Default::default()
        };

        matcher.matches::<_, FakeDeviceId>(&packet, &Interfaces { ingress: None, egress: None })
    }

    #[test_case(
        Some(PortMatcher { range: 1024..=65535, invert: false }), None, (11111, 80), true;
        "matching src port"
    )]
    #[test_case(
        Some(PortMatcher { range: 1024..=65535, invert: true }), None, (11111, 80), false;
        "invert match src port"
    )]
    #[test_case(
        Some(PortMatcher { range: 1024..=65535, invert: false }), None, (53, 80), false;
        "non-matching src port"
    )]
    #[test_case(
        None, Some(PortMatcher { range: 22..=22, invert: false }), (11111, 22), true;
        "match dst port"
    )]
    #[test_case(
        None, Some(PortMatcher { range: 22..=22, invert: true }), (11111, 22), false;
        "invert match dst port"
    )]
    #[test_case(
        None, Some(PortMatcher { range: 22..=22, invert: false }), (11111, 80), false;
        "non-matching dst port"
    )]
    fn match_on_port_range(
        src_port: Option<PortMatcher>,
        dst_port: Option<PortMatcher>,
        transport_header: (u16, u16),
        expect_match: bool,
    ) {
        let (src, dst) = transport_header;

        // The same port matchers must give the same verdict for TCP ...
        let matcher = PacketMatcher {
            transport_protocol: Some(TransportProtocolMatcher {
                proto: Ipv4Proto::Proto(IpProto::Tcp),
                src_port: src_port.clone(),
                dst_port: dst_port.clone(),
            }),
            ..Default::default()
        };
        assert_eq!(
            matcher.matches::<_, FakeDeviceId>(
                &FakeIpPacket::<Ipv4, _> {
                    body: FakeTcpSegment {
                        src_port: src,
                        dst_port: dst,
                        segment: SegmentHeader::arbitrary_value(),
                        payload_len: 8888,
                    },
                    ..ArbitraryValue::arbitrary_value()
                },
                &Interfaces { ingress: None, egress: None },
            ),
            expect_match
        );

        // ... and for UDP.
        let matcher = PacketMatcher {
            transport_protocol: Some(TransportProtocolMatcher {
                proto: Ipv4Proto::Proto(IpProto::Udp),
                src_port,
                dst_port,
            }),
            ..Default::default()
        };
        assert_eq!(
            matcher.matches::<_, FakeDeviceId>(
                &FakeIpPacket::<Ipv4, _> {
                    body: FakeUdpPacket { src_port: src, dst_port: dst },
                    ..ArbitraryValue::arbitrary_value()
                },
                &Interfaces { ingress: None, egress: None },
            ),
            expect_match
        );
    }

    #[ip_test(I)]
    fn packet_must_match_all_provided_matchers<I: TestIpExt>() {
        let matcher = PacketMatcher::<I, FakeDeviceClass> {
            src_address: Some(AddressMatcher {
                matcher: AddressMatcherType::Subnet(SubnetMatcher(I::SUBNET)),
                invert: false,
            }),
            dst_address: Some(AddressMatcher {
                matcher: AddressMatcherType::Subnet(SubnetMatcher(I::SUBNET)),
                invert: false,
            }),
            ..Default::default()
        };

        assert_eq!(
            matcher.matches::<_, FakeDeviceId>(
                &FakeIpPacket::<_, FakeTcpSegment> {
                    src_ip: I::IP_OUTSIDE_SUBNET,
                    ..ArbitraryValue::arbitrary_value()
                },
                &Interfaces { ingress: None, egress: None },
            ),
            false
        );
        assert_eq!(
            matcher.matches::<_, FakeDeviceId>(
                &FakeIpPacket::<_, FakeTcpSegment> {
                    dst_ip: I::IP_OUTSIDE_SUBNET,
                    ..ArbitraryValue::arbitrary_value()
                },
                &Interfaces { ingress: None, egress: None },
            ),
            false
        );
        assert_eq!(
            matcher.matches::<_, FakeDeviceId>(
                &FakeIpPacket::<_, FakeTcpSegment>::arbitrary_value(),
                &Interfaces { ingress: None, egress: None },
            ),
            true
        );
    }

    #[test]
    fn match_by_default_if_no_specified_matchers() {
        assert_eq!(
            PacketMatcher::<_, FakeDeviceClass>::default().matches::<_, FakeDeviceId>(
                &FakeIpPacket::<Ipv4, FakeTcpSegment>::arbitrary_value(),
                &Interfaces { ingress: None, egress: None },
            ),
            true
        );
    }
}