1use core::fmt::Debug;
6use core::num::NonZeroU64;
7use core::ops::RangeInclusive;
8use netstack3_base::{DeviceNameMatcher, DeviceWithName, InspectableValue, Matcher, SubnetMatcher};
9
10use derivative::Derivative;
11use net_types::ip::IpAddress;
12use packet_formats::ip::IpExt;
13
14use crate::logic::Interfaces;
15use crate::packets::{IpPacket, MaybeTransportPacket, TransportPacketData};
16
/// A matcher for network interfaces.
///
/// An interface can be selected by its numeric ID, by its name, or by its
/// device class.
#[derive(Clone, Derivative)]
#[derivative(Debug)]
pub enum InterfaceMatcher<DeviceClass> {
    /// Matches an interface by its ID.
    Id(NonZeroU64),
    /// Matches an interface by name, delegating to the contained
    /// [`DeviceNameMatcher`].
    #[derivative(Debug = "transparent")]
    Name(DeviceNameMatcher),
    /// Matches an interface by its device class.
    DeviceClass(DeviceClass),
}
29
30impl<DeviceClass: Debug> InspectableValue for InterfaceMatcher<DeviceClass> {
31 fn record<I: netstack3_base::Inspector>(&self, name: &str, inspector: &mut I) {
32 inspector.record_debug(name, self);
33 }
34}
35
/// Properties of an interface that an [`InterfaceMatcher`] is checked against.
///
/// Name-based matching is provided by the [`DeviceWithName`] supertrait.
pub trait InterfaceProperties<DeviceClass>: DeviceWithName {
    /// Returns whether the provided ID matches this interface.
    fn id_matches(&self, id: &NonZeroU64) -> bool;

    /// Returns whether the provided device class matches this interface.
    fn device_class_matches(&self, device_class: &DeviceClass) -> bool;
}
46
47impl<DeviceClass, I: InterfaceProperties<DeviceClass>> Matcher<I>
48 for InterfaceMatcher<DeviceClass>
49{
50 fn matches(&self, actual: &I) -> bool {
51 match self {
52 InterfaceMatcher::Id(id) => actual.id_matches(id),
53 InterfaceMatcher::Name(name_matcher) => name_matcher.matches(actual),
54 InterfaceMatcher::DeviceClass(device_class) => {
55 actual.device_class_matches(device_class)
56 }
57 }
58 }
59}
60
/// The ways in which an IP address can be matched: membership in a subnet, or
/// membership in an inclusive range of addresses.
#[derive(Clone, Derivative)]
#[derivative(Debug)]
pub enum AddressMatcherType<A: IpAddress> {
    /// Matches addresses using the contained [`SubnetMatcher`].
    #[derivative(Debug = "transparent")]
    Subnet(SubnetMatcher<A>),
    /// Matches addresses that fall within the inclusive range.
    Range(RangeInclusive<A>),
}
71
72impl<A: IpAddress> Matcher<A> for AddressMatcherType<A> {
73 fn matches(&self, actual: &A) -> bool {
74 match self {
75 Self::Subnet(subnet_matcher) => subnet_matcher.matches(actual),
76 Self::Range(range) => range.contains(actual),
77 }
78 }
79}
80
/// A matcher for IP addresses whose result can optionally be inverted.
#[derive(Clone, Debug)]
pub struct AddressMatcher<A: IpAddress> {
    /// The underlying subnet or range matcher.
    pub matcher: AddressMatcherType<A>,
    /// Whether to invert the result of `matcher`: when `true`, an address
    /// matches only if `matcher` does *not* match it.
    pub invert: bool,
}
91
92impl<A: IpAddress> InspectableValue for AddressMatcher<A> {
93 fn record<I: netstack3_base::Inspector>(&self, name: &str, inspector: &mut I) {
94 inspector.record_debug(name, self);
95 }
96}
97
98impl<A: IpAddress> Matcher<A> for AddressMatcher<A> {
99 fn matches(&self, addr: &A) -> bool {
100 let Self { matcher, invert } = self;
101 matcher.matches(addr) ^ *invert
102 }
103}
104
/// A matcher for port numbers whose result can optionally be inverted.
#[derive(Clone, Debug)]
pub struct PortMatcher {
    /// The inclusive range of port numbers to match.
    pub range: RangeInclusive<u16>,
    /// Whether to invert the match: when `true`, a port matches only if it
    /// lies *outside* `range`.
    pub invert: bool,
}
115
116impl Matcher<u16> for PortMatcher {
117 fn matches(&self, actual: &u16) -> bool {
118 let Self { range, invert } = self;
119 range.contains(actual) ^ *invert
120 }
121}
122
/// A matcher for a packet's transport-layer protocol and, optionally, its
/// source and destination ports.
#[derive(Debug, Clone)]
pub struct TransportProtocolMatcher<P> {
    /// The transport-layer protocol the packet's protocol must be equal to.
    pub proto: P,
    /// An optional matcher applied to the packet's source port.
    pub src_port: Option<PortMatcher>,
    /// An optional matcher applied to the packet's destination port.
    pub dst_port: Option<PortMatcher>,
}
135
136impl<P: Debug> InspectableValue for TransportProtocolMatcher<P> {
137 fn record<I: netstack3_base::Inspector>(&self, name: &str, inspector: &mut I) {
138 inspector.record_debug(name, self);
139 }
140}
141
142impl<P: PartialEq, T: MaybeTransportPacket> Matcher<(P, T)> for TransportProtocolMatcher<P> {
143 fn matches(&self, actual: &(P, T)) -> bool {
144 let Self { proto, src_port, dst_port } = self;
145 let (packet_proto, packet) = actual;
146
147 proto == packet_proto
148 && src_port.required_matches(
149 packet.transport_packet_data().as_ref().map(TransportPacketData::src_port).as_ref(),
150 )
151 && dst_port.required_matches(
152 packet.transport_packet_data().as_ref().map(TransportPacketData::dst_port).as_ref(),
153 )
154 }
155}
156
/// Top-level matcher for IP packets.
///
/// Every field is optional, and the [`Default`] value leaves all of them
/// unset.
#[derive(Derivative, Debug, Clone)]
#[derivative(Default(bound = ""))]
pub struct PacketMatcher<I: IpExt, DeviceClass> {
    /// Matcher applied to the ingress interface.
    ///
    /// When this is set and no ingress interface is available, the packet does
    /// not match.
    pub in_interface: Option<InterfaceMatcher<DeviceClass>>,
    /// Matcher applied to the egress interface.
    ///
    /// When this is set and no egress interface is available, the packet does
    /// not match.
    pub out_interface: Option<InterfaceMatcher<DeviceClass>>,
    /// Matcher applied to the packet's source IP address.
    pub src_address: Option<AddressMatcher<I::Addr>>,
    /// Matcher applied to the packet's destination IP address.
    pub dst_address: Option<AddressMatcher<I::Addr>>,
    /// Matcher applied to the packet's transport protocol and ports.
    pub transport_protocol: Option<TransportProtocolMatcher<I::Proto>>,
}
176
177impl<I: IpExt, DeviceClass> PacketMatcher<I, DeviceClass> {
178 pub(crate) fn matches<P: IpPacket<I>, D: InterfaceProperties<DeviceClass>>(
179 &self,
180 packet: &P,
181 interfaces: &Interfaces<'_, D>,
182 ) -> bool {
183 let Self { in_interface, out_interface, src_address, dst_address, transport_protocol } =
184 self;
185 let Interfaces { ingress: in_if, egress: out_if } = interfaces;
186
187 in_interface.required_matches(*in_if)
189 && out_interface.required_matches(*out_if)
190 && src_address.matches(&packet.src_addr())
191 && dst_address.matches(&packet.dst_addr())
192 && transport_protocol.matches(&(packet.protocol(), packet.maybe_transport_packet()))
193 }
194}
195
#[cfg(test)]
pub(crate) mod testutil {
    use alloc::string::String;
    use core::num::NonZeroU64;

    use netstack3_base::testutil::{FakeStrongDeviceId, FakeWeakDeviceId};
    use netstack3_base::{DeviceIdentifier, DeviceWithName, StrongDeviceIdentifier};

    use super::*;
    use crate::context::testutil::FakeDeviceClass;

    /// A fake device carrying the three properties an [`InterfaceMatcher`]
    /// can match on: ID, name and device class.
    #[derive(Clone, Debug, PartialOrd, Ord, PartialEq, Eq, Hash)]
    pub struct FakeDeviceId {
        pub id: NonZeroU64,
        pub name: String,
        pub class: FakeDeviceClass,
    }

    impl StrongDeviceIdentifier for FakeDeviceId {
        type Weak = FakeWeakDeviceId<Self>;

        fn downgrade(&self) -> Self::Weak {
            let strong = self.clone();
            FakeWeakDeviceId(strong)
        }
    }

    impl DeviceIdentifier for FakeDeviceId {
        // Fake devices are never reported as loopback.
        fn is_loopback(&self) -> bool {
            false
        }
    }

    impl FakeStrongDeviceId for FakeDeviceId {
        // Fake devices are always reported as alive.
        fn is_alive(&self) -> bool {
            true
        }
    }

    impl PartialEq<FakeWeakDeviceId<FakeDeviceId>> for FakeDeviceId {
        fn eq(&self, weak: &FakeWeakDeviceId<FakeDeviceId>) -> bool {
            // Compare against the strong ID wrapped by the weak one.
            let FakeWeakDeviceId(wrapped) = weak;
            self == wrapped
        }
    }

    impl DeviceWithName for FakeDeviceId {
        fn name_matches(&self, name: &str) -> bool {
            self.name.as_str() == name
        }
    }

    impl InterfaceProperties<FakeDeviceClass> for FakeDeviceId {
        fn id_matches(&self, id: &NonZeroU64) -> bool {
            self.id == *id
        }

        fn device_class_matches(&self, class: &FakeDeviceClass) -> bool {
            self.class == *class
        }
    }

    /// Builds a [`FakeDeviceId`] from its parts.
    ///
    /// Panics if `id` is zero.
    fn fake_device(id: u64, name: &str, class: FakeDeviceClass) -> FakeDeviceId {
        FakeDeviceId { id: NonZeroU64::new(id).unwrap(), name: String::from(name), class }
    }

    /// Returns a fake WLAN device with ID 1 and name `"wlan"`.
    pub fn wlan_interface() -> FakeDeviceId {
        fake_device(1, "wlan", FakeDeviceClass::Wlan)
    }

    /// Returns a fake Ethernet device with ID 2 and name `"eth"`.
    pub fn ethernet_interface() -> FakeDeviceId {
        fake_device(2, "eth", FakeDeviceClass::Ethernet)
    }
}
272
#[cfg(feature = "testutils")]
mod base_testutil {
    use super::*;

    /// Implements [`InterfaceProperties<()>`] for each listed type with
    /// methods that panic via `unimplemented!()` when called.
    macro_rules! unimplemented_interface_properties {
        ($($device:ty),* $(,)?) => {
            $(
                impl InterfaceProperties<()> for $device {
                    fn id_matches(&self, _: &core::num::NonZeroU64) -> bool {
                        unimplemented!()
                    }

                    fn device_class_matches(&self, _: &()) -> bool {
                        unimplemented!()
                    }
                }
            )*
        };
    }

    unimplemented_interface_properties!(
        netstack3_base::testutil::FakeDeviceId,
        netstack3_base::testutil::FakeReferencyDeviceId,
        netstack3_base::testutil::MultipleDevicesId,
    );
}
308
#[cfg(test)]
mod tests {
    use ip_test_macro::ip_test;
    use net_types::ip::{Ipv4, Ipv4Addr, Ipv6, Ipv6Addr};
    use packet_formats::ip::{IpProto, Ipv4Proto};
    use test_case::test_case;

    use netstack3_base::SegmentHeader;

    use super::testutil::*;
    use super::*;
    use crate::context::testutil::FakeDeviceClass;
    use crate::packets::testutil::internal::{
        ArbitraryValue, FakeIcmpEchoRequest, FakeIpPacket, FakeTcpSegment, FakeUdpPacket,
        TestIpExt, TransportPacketExt,
    };

    // A matcher built from any property of the WLAN interface must accept the
    // WLAN interface and reject the Ethernet interface.
    #[test_case(InterfaceMatcher::Id(wlan_interface().id))]
    #[test_case(InterfaceMatcher::Name(DeviceNameMatcher(wlan_interface().name.clone())))]
    #[test_case(InterfaceMatcher::DeviceClass(wlan_interface().class))]
    fn match_on_interface_properties(matcher: InterfaceMatcher<FakeDeviceClass>) {
        let matcher = PacketMatcher {
            in_interface: Some(matcher.clone()),
            out_interface: Some(matcher),
            ..Default::default()
        };

        assert!(matcher.matches(
            &FakeIpPacket::<Ipv4, FakeTcpSegment>::arbitrary_value(),
            &Interfaces { ingress: Some(&wlan_interface()), egress: Some(&wlan_interface()) },
        ));
        assert!(!matcher.matches(
            &FakeIpPacket::<Ipv4, FakeTcpSegment>::arbitrary_value(),
            &Interfaces {
                ingress: Some(&ethernet_interface()),
                egress: Some(&ethernet_interface()),
            },
        ));
    }

    // When an interface matcher is configured but the corresponding interface
    // is absent, the packet must not match.
    #[test_case(InterfaceMatcher::Id(wlan_interface().id))]
    #[test_case(InterfaceMatcher::Name(DeviceNameMatcher(wlan_interface().name.clone())))]
    #[test_case(InterfaceMatcher::DeviceClass(wlan_interface().class))]
    fn interface_matcher_specified_but_not_available_in_hook_does_not_match(
        matcher: InterfaceMatcher<FakeDeviceClass>,
    ) {
        let matcher = PacketMatcher {
            in_interface: Some(matcher.clone()),
            out_interface: Some(matcher),
            ..Default::default()
        };

        assert!(!matcher.matches(
            &FakeIpPacket::<Ipv4, FakeTcpSegment>::arbitrary_value(),
            &Interfaces { ingress: None, egress: Some(&wlan_interface()) },
        ));
        assert!(!matcher.matches(
            &FakeIpPacket::<Ipv4, FakeTcpSegment>::arbitrary_value(),
            &Interfaces { ingress: Some(&wlan_interface()), egress: None },
        ));
        assert!(matcher.matches(
            &FakeIpPacket::<Ipv4, FakeTcpSegment>::arbitrary_value(),
            &Interfaces { ingress: Some(&wlan_interface()), egress: Some(&wlan_interface()) },
        ));
    }

    enum AddressMatcherTestCase {
        Subnet,
        Range,
    }

    #[ip_test(I)]
    #[test_case(AddressMatcherTestCase::Subnet, false)]
    #[test_case(AddressMatcherTestCase::Subnet, true)]
    #[test_case(AddressMatcherTestCase::Range, false)]
    #[test_case(AddressMatcherTestCase::Range, true)]
    fn match_on_subnet_or_address_range<I: TestIpExt>(
        test_case: AddressMatcherTestCase,
        invert: bool,
    ) {
        let matcher = AddressMatcher {
            matcher: match test_case {
                AddressMatcherTestCase::Subnet => {
                    AddressMatcherType::Subnet(SubnetMatcher(I::SUBNET))
                }
                AddressMatcherTestCase::Range => {
                    // Build a range covering exactly the addresses in
                    // `I::SUBNET`: from the network address up to and
                    // including the last address in the subnet.
                    let start = I::SUBNET.network();
                    let end = I::map_ip(
                        start,
                        |start| {
                            let range_size = 2_u32.pow(32 - u32::from(I::SUBNET.prefix())) - 1;
                            let end = u32::from_be_bytes(start.ipv4_bytes()) + range_size;
                            Ipv4Addr::from(end.to_be_bytes())
                        },
                        |start| {
                            let range_size = 2_u128.pow(128 - u32::from(I::SUBNET.prefix())) - 1;
                            let end = u128::from_be_bytes(start.ipv6_bytes()) + range_size;
                            Ipv6Addr::from(end.to_be_bytes())
                        },
                    );
                    AddressMatcherType::Range(start..=end)
                }
            },
            invert,
        };

        // The same address matcher must behave identically whether it is
        // applied to the source or to the destination address.
        for matcher in [
            PacketMatcher { src_address: Some(matcher.clone()), ..Default::default() },
            PacketMatcher { dst_address: Some(matcher), ..Default::default() },
        ] {
            // The arbitrary packet is expected to match exactly when the
            // matcher is not inverted.
            assert_ne!(
                matcher.matches::<_, FakeDeviceId>(
                    &FakeIpPacket::<I, FakeTcpSegment>::arbitrary_value(),
                    &Interfaces { ingress: None, egress: None },
                ),
                invert
            );
            // A packet with both addresses outside the subnet is expected to
            // match exactly when the matcher is inverted.
            assert_eq!(
                matcher.matches::<_, FakeDeviceId>(
                    &FakeIpPacket {
                        src_ip: I::IP_OUTSIDE_SUBNET,
                        dst_ip: I::IP_OUTSIDE_SUBNET,
                        body: FakeTcpSegment::arbitrary_value(),
                    },
                    &Interfaces { ingress: None, egress: None },
                ),
                invert
            );
        }
    }

    enum Protocol {
        Tcp,
        Udp,
        Icmp,
    }

    impl Protocol {
        // Maps the test-level protocol onto the IP-version-specific protocol
        // number reported by the corresponding fake transport packet type.
        fn ip_proto<I: IpExt>(&self) -> I::Proto {
            match self {
                Self::Tcp => <&FakeTcpSegment as TransportPacketExt<I>>::proto(),
                Self::Udp => <&FakeUdpPacket as TransportPacketExt<I>>::proto(),
                Self::Icmp => <&FakeIcmpEchoRequest as TransportPacketExt<I>>::proto(),
            }
        }
    }

    #[test_case(Protocol::Tcp, FakeIpPacket::<Ipv4, FakeTcpSegment>::arbitrary_value() => true)]
    #[test_case(Protocol::Tcp, FakeIpPacket::<Ipv4, FakeUdpPacket>::arbitrary_value() => false)]
    #[test_case(
        Protocol::Tcp,
        FakeIpPacket::<Ipv4, FakeIcmpEchoRequest>::arbitrary_value()
        => false
    )]
    #[test_case(Protocol::Udp, FakeIpPacket::<Ipv4, FakeUdpPacket>::arbitrary_value() => true)]
    #[test_case(Protocol::Udp, FakeIpPacket::<Ipv4, FakeTcpSegment>::arbitrary_value() => false)]
    #[test_case(
        Protocol::Udp,
        FakeIpPacket::<Ipv4, FakeIcmpEchoRequest>::arbitrary_value()
        => false
    )]
    #[test_case(
        Protocol::Icmp,
        FakeIpPacket::<Ipv4, FakeIcmpEchoRequest>::arbitrary_value()
        => true
    )]
    #[test_case(
        Protocol::Icmp,
        FakeIpPacket::<Ipv6, FakeIcmpEchoRequest>::arbitrary_value()
        => true
    )]
    #[test_case(Protocol::Icmp, FakeIpPacket::<Ipv4, FakeTcpSegment>::arbitrary_value() => false)]
    #[test_case(Protocol::Icmp, FakeIpPacket::<Ipv4, FakeUdpPacket>::arbitrary_value() => false)]
    fn match_on_transport_protocol<I: TestIpExt, P: IpPacket<I>>(
        protocol: Protocol,
        packet: P,
    ) -> bool {
        let matcher = PacketMatcher {
            transport_protocol: Some(TransportProtocolMatcher {
                proto: protocol.ip_proto::<I>(),
                src_port: None,
                dst_port: None,
            }),
            ..Default::default()
        };

        matcher.matches::<_, FakeDeviceId>(&packet, &Interfaces { ingress: None, egress: None })
    }

    #[test_case(
        Some(PortMatcher { range: 1024..=65535, invert: false }), None, (11111, 80), true;
        "matching src port"
    )]
    #[test_case(
        Some(PortMatcher { range: 1024..=65535, invert: true }), None, (11111, 80), false;
        "invert match src port"
    )]
    #[test_case(
        Some(PortMatcher { range: 1024..=65535, invert: false }), None, (53, 80), false;
        "non-matching src port"
    )]
    #[test_case(
        None, Some(PortMatcher { range: 22..=22, invert: false }), (11111, 22), true;
        "match dst port"
    )]
    #[test_case(
        None, Some(PortMatcher { range: 22..=22, invert: true }), (11111, 22), false;
        "invert match dst port"
    )]
    #[test_case(
        None, Some(PortMatcher { range: 22..=22, invert: false }), (11111, 80), false;
        "non-matching dst port"
    )]
    fn match_on_port_range(
        src_port: Option<PortMatcher>,
        dst_port: Option<PortMatcher>,
        transport_header: (u16, u16),
        expect_match: bool,
    ) {
        let (src, dst) = transport_header;

        // TCP: the port matchers apply to the segment's ports.
        let matcher = PacketMatcher {
            transport_protocol: Some(TransportProtocolMatcher {
                proto: Ipv4Proto::Proto(IpProto::Tcp),
                src_port: src_port.clone(),
                dst_port: dst_port.clone(),
            }),
            ..Default::default()
        };
        assert_eq!(
            matcher.matches::<_, FakeDeviceId>(
                &FakeIpPacket::<Ipv4, _> {
                    body: FakeTcpSegment {
                        src_port: src,
                        dst_port: dst,
                        segment: SegmentHeader::arbitrary_value(),
                        payload_len: 8888,
                    },
                    ..ArbitraryValue::arbitrary_value()
                },
                &Interfaces { ingress: None, egress: None },
            ),
            expect_match
        );

        // UDP: the same port matchers must yield the same outcome.
        let matcher = PacketMatcher {
            transport_protocol: Some(TransportProtocolMatcher {
                proto: Ipv4Proto::Proto(IpProto::Udp),
                src_port,
                dst_port,
            }),
            ..Default::default()
        };
        assert_eq!(
            matcher.matches::<_, FakeDeviceId>(
                &FakeIpPacket::<Ipv4, _> {
                    body: FakeUdpPacket { src_port: src, dst_port: dst },
                    ..ArbitraryValue::arbitrary_value()
                },
                &Interfaces { ingress: None, egress: None },
            ),
            expect_match
        );
    }

    #[ip_test(I)]
    fn packet_must_match_all_provided_matchers<I: TestIpExt>() {
        let matcher = PacketMatcher::<I, FakeDeviceClass> {
            src_address: Some(AddressMatcher {
                matcher: AddressMatcherType::Subnet(SubnetMatcher(I::SUBNET)),
                invert: false,
            }),
            dst_address: Some(AddressMatcher {
                matcher: AddressMatcherType::Subnet(SubnetMatcher(I::SUBNET)),
                invert: false,
            }),
            ..Default::default()
        };

        // Only the source address fails to match.
        assert!(!matcher.matches::<_, FakeDeviceId>(
            &FakeIpPacket::<_, FakeTcpSegment> {
                src_ip: I::IP_OUTSIDE_SUBNET,
                ..ArbitraryValue::arbitrary_value()
            },
            &Interfaces { ingress: None, egress: None },
        ));
        // Only the destination address fails to match.
        assert!(!matcher.matches::<_, FakeDeviceId>(
            &FakeIpPacket::<_, FakeTcpSegment> {
                dst_ip: I::IP_OUTSIDE_SUBNET,
                ..ArbitraryValue::arbitrary_value()
            },
            &Interfaces { ingress: None, egress: None },
        ));
        // Both addresses match.
        assert!(matcher.matches::<_, FakeDeviceId>(
            &FakeIpPacket::<_, FakeTcpSegment>::arbitrary_value(),
            &Interfaces { ingress: None, egress: None },
        ));
    }

    #[test]
    fn match_by_default_if_no_specified_matchers() {
        assert!(PacketMatcher::<_, FakeDeviceClass>::default().matches::<_, FakeDeviceId>(
            &FakeIpPacket::<Ipv4, FakeTcpSegment>::arbitrary_value(),
            &Interfaces { ingress: None, egress: None },
        ));
    }
}