1use alloc::vec::Vec;
8use core::fmt::Debug;
9use core::ops::Deref as _;
10
11use net_types::ip::Ip;
12use netstack3_base::{
13 DeviceNameMatcher, DeviceWithName, Mark, MarkDomain, MarkStorage, Marks, Matcher, SubnetMatcher,
14};
15
16use crate::internal::routing::PacketOrigin;
17use crate::RoutingTableId;
18
19pub struct RulesTable<I: Ip, D> {
21 rules: Vec<Rule<I, D>>,
23}
24
25impl<I: Ip, D> RulesTable<I, D> {
26 pub(crate) fn new(main_table_id: RoutingTableId<I, D>) -> Self {
27 Self {
30 rules: alloc::vec![Rule {
31 matcher: RuleMatcher::match_all_packets(),
32 action: RuleAction::Lookup(main_table_id)
33 }],
34 }
35 }
36
37 pub(crate) fn iter(&self) -> impl Iterator<Item = &'_ Rule<I, D>> {
38 self.rules.iter()
39 }
40
41 #[cfg(any(test, feature = "testutils"))]
43 pub fn rules_mut(&mut self) -> &mut Vec<Rule<I, D>> {
44 &mut self.rules
45 }
46
47 pub fn replace(&mut self, new_rules: Vec<Rule<I, D>>) {
49 self.rules = new_rules;
50 }
51}
52
53pub struct Rule<I: Ip, D> {
55 pub matcher: RuleMatcher<I>,
57 pub action: RuleAction<RoutingTableId<I, D>>,
59}
60
61#[derive(Debug, Clone, PartialEq, Eq)]
63pub enum RuleAction<Lookup> {
64 Unreachable,
66 Lookup(Lookup),
68}
69
70#[derive(Debug, Clone, PartialEq, Eq)]
72pub enum BoundDeviceMatcher {
73 DeviceName(DeviceNameMatcher),
75 Unbound,
77}
78
79impl<'a, D: DeviceWithName> Matcher<Option<&'a D>> for BoundDeviceMatcher {
80 fn matches(&self, actual: &Option<&'a D>) -> bool {
81 match self {
82 BoundDeviceMatcher::DeviceName(name_matcher) => {
83 name_matcher.required_matches(actual.as_deref())
84 }
85 BoundDeviceMatcher::Unbound => actual.is_none(),
86 }
87 }
88}
89
90#[derive(Debug, Clone, PartialEq, Eq)]
96pub enum TrafficOriginMatcher {
97 Local {
100 bound_device_matcher: Option<BoundDeviceMatcher>,
102 },
103 NonLocal,
105}
106
107impl<'a, I: Ip, D: DeviceWithName> Matcher<PacketOrigin<I, &'a D>> for SubnetMatcher<I::Addr> {
108 fn matches(&self, actual: &PacketOrigin<I, &'a D>) -> bool {
109 match actual {
110 PacketOrigin::Local { bound_address, bound_device: _ } => {
111 self.required_matches(bound_address.as_deref())
112 }
113 PacketOrigin::NonLocal { source_address, incoming_device: _ } => {
114 self.matches(source_address.deref())
115 }
116 }
117 }
118}
119
120impl<'a, I: Ip, D: DeviceWithName> Matcher<PacketOrigin<I, &'a D>> for TrafficOriginMatcher {
121 fn matches(&self, actual: &PacketOrigin<I, &'a D>) -> bool {
122 match (self, actual) {
123 (
124 TrafficOriginMatcher::Local { bound_device_matcher },
125 PacketOrigin::Local { bound_address: _, bound_device },
126 ) => bound_device_matcher.matches(bound_device),
127 (
128 TrafficOriginMatcher::NonLocal,
129 PacketOrigin::NonLocal { source_address: _, incoming_device: _ },
130 ) => true,
131 (TrafficOriginMatcher::Local { .. }, PacketOrigin::NonLocal { .. })
132 | (TrafficOriginMatcher::NonLocal, PacketOrigin::Local { .. }) => false,
133 }
134 }
135}
136
137#[derive(Debug, Clone, Copy, PartialEq, Eq)]
139pub enum MarkMatcher {
140 Unmarked,
142 Marked {
144 mask: u32,
146 start: u32,
148 end: u32,
150 },
151}
152
153impl Matcher<Mark> for MarkMatcher {
154 fn matches(&self, Mark(actual): &Mark) -> bool {
155 match self {
156 MarkMatcher::Unmarked => actual.is_none(),
157 MarkMatcher::Marked { mask, start, end } => {
158 actual.is_some_and(|actual| (*start..=*end).contains(&(actual & *mask)))
159 }
160 }
161 }
162}
163
164#[derive(Default, Debug, Clone, Copy, PartialEq, Eq)]
166pub struct MarkMatchers(MarkStorage<Option<MarkMatcher>>);
167
168impl MarkMatchers {
169 pub fn new(matchers: impl IntoIterator<Item = (MarkDomain, MarkMatcher)>) -> Self {
177 MarkMatchers(MarkStorage::new(matchers))
178 }
179
180 pub fn iter(&self) -> impl Iterator<Item = (MarkDomain, &Option<MarkMatcher>)> {
182 let Self(storage) = self;
183 storage.iter()
184 }
185}
186
187impl Matcher<Marks> for MarkMatchers {
188 fn matches(&self, actual: &Marks) -> bool {
189 let Self(matchers) = self;
190 matchers.zip_with(actual).all(|(_domain, matcher, actual)| matcher.matches(actual))
191 }
192}
193
194#[derive(Debug, Clone, PartialEq, Eq)]
198pub struct RuleMatcher<I: Ip> {
199 pub source_address_matcher: Option<SubnetMatcher<I::Addr>>,
205 pub traffic_origin_matcher: Option<TrafficOriginMatcher>,
208 pub mark_matchers: MarkMatchers,
210}
211
212impl<I: Ip> RuleMatcher<I> {
213 pub fn match_all_packets() -> Self {
215 RuleMatcher {
216 source_address_matcher: None,
217 traffic_origin_matcher: None,
218 mark_matchers: MarkMatchers::default(),
219 }
220 }
221}
222
223pub struct RuleInput<'a, I: Ip, D> {
225 pub(crate) packet_origin: PacketOrigin<I, &'a D>,
226 pub(crate) marks: &'a Marks,
227}
228
229impl<'a, I: Ip, D: DeviceWithName> Matcher<RuleInput<'a, I, D>> for RuleMatcher<I> {
230 fn matches(&self, actual: &RuleInput<'a, I, D>) -> bool {
231 let Self { source_address_matcher, traffic_origin_matcher, mark_matchers } = self;
232 let RuleInput { packet_origin, marks } = actual;
233 source_address_matcher.matches(packet_origin)
234 && traffic_origin_matcher.matches(packet_origin)
235 && mark_matchers.matches(marks)
236 }
237}
238
#[cfg(test)]
mod test {
    use ip_test_macro::ip_test;
    use net_types::ip::Subnet;
    use net_types::SpecifiedAddr;
    use netstack3_base::testutil::{FakeDeviceId, MultipleDevicesId, TestIpExt};
    use test_case::test_case;

    use super::*;

    // Exercises the bound-device constraint of `TrafficOriginMatcher::Local`
    // on locally generated packets.
    //
    // The arguments are:
    // - the optional bound-device matcher;
    // - the device (if any) the packet's socket is bound to.
    //
    // The expected result follows `=>`.
    #[ip_test(I)]
    #[test_case(None, None => true)]
    #[test_case(None, Some(MultipleDevicesId::A) => true)]
    #[test_case(
        Some(BoundDeviceMatcher::Unbound),
        None => true)]
    #[test_case(
        Some(BoundDeviceMatcher::Unbound),
        Some(MultipleDevicesId::A) => false)]
    #[test_case(
        Some(BoundDeviceMatcher::DeviceName(DeviceNameMatcher("A".into()))),
        None => false)]
    #[test_case(
        Some(BoundDeviceMatcher::DeviceName(DeviceNameMatcher("A".into()))),
        Some(MultipleDevicesId::A) => true)]
    #[test_case(
        Some(BoundDeviceMatcher::DeviceName(DeviceNameMatcher("A".into()))),
        Some(MultipleDevicesId::B) => false)]
    fn rule_matcher_matches_bound_device<I: TestIpExt>(
        bound_device_matcher: Option<BoundDeviceMatcher>,
        bound_device: Option<MultipleDevicesId>,
    ) -> bool {
        // Only the traffic-origin matcher is configured.
        let matcher = RuleMatcher::<I> {
            traffic_origin_matcher: Some(TrafficOriginMatcher::Local { bound_device_matcher }),
            ..RuleMatcher::match_all_packets()
        };
        let input = RuleInput {
            packet_origin: PacketOrigin::Local {
                bound_address: None,
                bound_device: bound_device.as_ref(),
            },
            marks: &Default::default(),
        };
        matcher.matches(&input)
    }

    // Exercises the source-address matcher on locally generated packets. For
    // these packets the compared address is the socket's bound address; an
    // absent bound address fails a configured subnet matcher.
    #[ip_test(I)]
    #[test_case(None, None => true)]
    #[test_case(None, Some(I::LOOPBACK_ADDRESS) => true)]
    #[test_case(
        Some(<I as TestIpExt>::TEST_ADDRS.subnet),
        None => false)]
    #[test_case(
        Some(<I as TestIpExt>::TEST_ADDRS.subnet),
        Some(<I as TestIpExt>::TEST_ADDRS.local_ip) => true)]
    #[test_case(
        Some(<I as TestIpExt>::TEST_ADDRS.subnet),
        Some(<I as TestIpExt>::get_other_remote_ip_address(1)) => false)]
    fn rule_matcher_matches_local_addr<I: TestIpExt>(
        source_address_subnet: Option<Subnet<I::Addr>>,
        bound_address: Option<SpecifiedAddr<I::Addr>>,
    ) -> bool {
        // Only the source-address matcher is configured.
        let matcher = RuleMatcher::<I> {
            source_address_matcher: source_address_subnet.map(SubnetMatcher),
            ..RuleMatcher::match_all_packets()
        };
        let marks = Default::default();
        let input = RuleInput::<'_, _, FakeDeviceId> {
            packet_origin: PacketOrigin::Local { bound_address, bound_device: None },
            marks: &marks,
        };
        matcher.matches(&input)
    }

    // Exercises `TrafficOriginMatcher` against both local and non-local
    // packets. A `None` matcher accepts either origin; `Local` and `NonLocal`
    // each accept only their own origin.
    #[ip_test(I)]
    #[test_case(None, PacketOrigin::Local {
        bound_address: None,
        bound_device: None
    } => true)]
    #[test_case(None, PacketOrigin::NonLocal {
        source_address: <I as TestIpExt>::TEST_ADDRS.remote_ip,
        incoming_device: &FakeDeviceId
    } => true)]
    #[test_case(Some(TrafficOriginMatcher::Local {
        bound_device_matcher: None
    }), PacketOrigin::Local {
        bound_address: None,
        bound_device: None
    } => true)]
    #[test_case(Some(TrafficOriginMatcher::NonLocal),
        PacketOrigin::NonLocal {
            source_address: <I as TestIpExt>::TEST_ADDRS.remote_ip,
            incoming_device: &FakeDeviceId
        } => true)]
    #[test_case(Some(TrafficOriginMatcher::Local { bound_device_matcher: None }),
        PacketOrigin::NonLocal {
            source_address: <I as TestIpExt>::TEST_ADDRS.remote_ip,
            incoming_device: &FakeDeviceId
        } => false)]
    #[test_case(Some(TrafficOriginMatcher::NonLocal),
        PacketOrigin::Local {
            bound_address: None,
            bound_device: None
        } => false)]
    fn rule_matcher_matches_locally_generated<I: TestIpExt>(
        traffic_origin_matcher: Option<TrafficOriginMatcher>,
        packet_origin: PacketOrigin<I, &'static FakeDeviceId>,
    ) -> bool {
        let matcher =
            RuleMatcher::<I> { traffic_origin_matcher, ..RuleMatcher::match_all_packets() };
        let marks = Default::default();
        let input = RuleInput::<'_, _, FakeDeviceId> { packet_origin, marks: &marks };
        matcher.matches(&input)
    }

    // Exhaustively combines address, device and origin with a matcher that
    // requires three things:
    // - the source address is in `TEST_ADDRS.subnet`;
    // - the traffic is locally generated;
    // - the socket is bound to device "A".
    //
    // Only a local packet bound to `local_ip` on device A should match.
    #[ip_test(I)]
    #[test_case::test_matrix(
        [
            None,
            Some(<I as TestIpExt>::TEST_ADDRS.local_ip),
            Some(<I as TestIpExt>::get_other_remote_ip_address(1))
        ],
        [
            None,
            Some(&MultipleDevicesId::A),
            Some(&MultipleDevicesId::B),
            Some(&MultipleDevicesId::C),
        ],
        [true, false]
    )]
    fn rule_matcher_matches_multiple_conditions<I: TestIpExt>(
        ip: Option<SpecifiedAddr<I::Addr>>,
        device: Option<&'static MultipleDevicesId>,
        locally_generated: bool,
    ) {
        let matcher = RuleMatcher::<I> {
            source_address_matcher: Some(SubnetMatcher(I::TEST_ADDRS.subnet)),
            traffic_origin_matcher: Some(TrafficOriginMatcher::Local {
                bound_device_matcher: Some(BoundDeviceMatcher::DeviceName(DeviceNameMatcher(
                    "A".into(),
                ))),
            }),
            ..RuleMatcher::match_all_packets()
        };

        let packet_origin = if locally_generated {
            PacketOrigin::Local { bound_address: ip, bound_device: device }
        } else {
            // `PacketOrigin::NonLocal` needs both a source address and an
            // incoming device, so matrix entries lacking either are skipped.
            let (Some(source_address), Some(incoming_device)) = (ip, device) else {
                return;
            };
            PacketOrigin::NonLocal { source_address, incoming_device }
        };

        let input = RuleInput { packet_origin, marks: &Default::default() };

        if ip == Some(I::TEST_ADDRS.local_ip)
            && (device == Some(&MultipleDevicesId::A))
            && locally_generated
        {
            assert!(matcher.matches(&input))
        } else {
            assert!(!matcher.matches(&input))
        }
    }

    // Exercises a single `MarkMatcher`.
    //
    // `Unmarked` matches only an absent mark. The `Marked` cases use mask 1
    // and the range 0..=0, so they match exactly the present marks whose
    // lowest bit is clear.
    #[test_case(MarkMatcher::Unmarked, Mark(None) => true)]
    #[test_case(MarkMatcher::Unmarked, Mark(Some(0)) => false)]
    #[test_case(MarkMatcher::Marked {
        mask: 1,
        start: 0,
        end: 0,
    }, Mark(None) => false)]
    #[test_case(MarkMatcher::Marked {
        mask: 1,
        start: 0,
        end: 0,
    }, Mark(Some(0)) => true)]
    #[test_case(MarkMatcher::Marked {
        mask: 1,
        start: 0,
        end: 0,
    }, Mark(Some(1)) => false)]
    #[test_case(MarkMatcher::Marked {
        mask: 1,
        start: 0,
        end: 0,
    }, Mark(Some(2)) => true)]
    #[test_case(MarkMatcher::Marked {
        mask: 1,
        start: 0,
        end: 0,
    }, Mark(Some(3)) => false)]
    fn mark_matcher(matcher: MarkMatcher, mark: Mark) -> bool {
        matcher.matches(&mark)
    }

    // Exercises `MarkMatchers` across both mark domains. Every domain's
    // matcher must accept its mark, so a mark set in either domain defeats
    // an `Unmarked` matcher for that domain.
    #[test_case(
        MarkMatchers::new(
            [(MarkDomain::Mark1, MarkMatcher::Unmarked),
            (MarkDomain::Mark2, MarkMatcher::Unmarked)]
        ),
        Marks::new([]) => true
    )]
    #[test_case(
        MarkMatchers::new(
            [(MarkDomain::Mark1, MarkMatcher::Unmarked),
            (MarkDomain::Mark2, MarkMatcher::Unmarked)]
        ),
        Marks::new([(MarkDomain::Mark1, 1)]) => false
    )]
    #[test_case(
        MarkMatchers::new(
            [(MarkDomain::Mark1, MarkMatcher::Unmarked),
            (MarkDomain::Mark2, MarkMatcher::Unmarked)]
        ),
        Marks::new([(MarkDomain::Mark2, 1)]) => false
    )]
    #[test_case(
        MarkMatchers::new(
            [(MarkDomain::Mark1, MarkMatcher::Unmarked),
            (MarkDomain::Mark2, MarkMatcher::Unmarked)]
        ),
        Marks::new([
            (MarkDomain::Mark1, 1),
            (MarkDomain::Mark2, 1),
        ]) => false
    )]
    fn mark_matchers(matchers: MarkMatchers, marks: Marks) -> bool {
        matchers.matches(&marks)
    }
}