1use core::fmt::Debug;
6
7use net_types::SpecifiedAddr;
8use net_types::ip::{IpVersion, Ipv4, Ipv6};
9use netstack3_base::socket::SocketCookie;
10use netstack3_base::{
11 InstantBindingsTypes, InterfaceProperties, IpDeviceAddr, IpDeviceAddressIdContext, Marks,
12 MatcherBindingsTypes, RngContext, StrongDeviceIdentifier, TimerBindingsTypes, TimerContext,
13 TxMetadataBindingsTypes,
14};
15use packet::{FragmentedByteSlice, PartialSerializer};
16use packet_formats::ip::IpExt;
17
18use crate::matchers::BindingsPacketMatcher;
19use crate::state::State;
20use crate::{FilterIpExt, IpPacket};
21
/// Bindings-provided types that the filtering engine depends on.
///
/// Bundles the instant, matcher and timer type families. The bindings packet
/// matcher type is additionally required to implement
/// [`BindingsPacketMatcher`]. The associated-type bound is written in the
/// supertrait position so that it is implied for every user of this trait.
pub trait FilterBindingsTypes:
    InstantBindingsTypes
    + MatcherBindingsTypes<BindingsPacketMatcher: BindingsPacketMatcher>
    + TimerBindingsTypes
    + 'static
{
}
30
31impl<
32 BT: InstantBindingsTypes
33 + MatcherBindingsTypes<BindingsPacketMatcher: BindingsPacketMatcher>
34 + TimerBindingsTypes
35 + 'static,
36> FilterBindingsTypes for BT
37{
38}
39
40pub trait FilterBindingsContext: TimerContext + RngContext + FilterBindingsTypes {}
42impl<BC: TimerContext + RngContext + FilterBindingsTypes> FilterBindingsContext for BC {}
43
44pub trait FilterIpContext<I: FilterIpExt, BT: FilterBindingsTypes>:
53 IpDeviceAddressIdContext<I, DeviceId: InterfaceProperties<BT::DeviceClass>>
54{
55 type NatCtx<'a>: NatContext<I, BT, DeviceId = Self::DeviceId, WeakAddressId = Self::WeakAddressId>;
58
59 fn with_filter_state<O, F: FnOnce(&State<I, Self::WeakAddressId, BT>) -> O>(
61 &mut self,
62 cb: F,
63 ) -> O {
64 self.with_filter_state_and_nat_ctx(|state, _ctx| cb(state))
65 }
66
67 fn with_filter_state_and_nat_ctx<
70 O,
71 F: FnOnce(&State<I, Self::WeakAddressId, BT>, &mut Self::NatCtx<'_>) -> O,
72 >(
73 &mut self,
74 cb: F,
75 ) -> O;
76}
77
/// Device/address lookups used by the NAT implementation.
pub trait NatContext<I: IpExt, BT: FilterBindingsTypes>:
    IpDeviceAddressIdContext<I, DeviceId: InterfaceProperties<BT::DeviceClass>>
{
    /// Returns an address ID on `device_id` to use as the local address for
    /// traffic towards `remote` (if given), or `None` if none is found.
    ///
    /// NOTE(review): how much `remote` influences the choice is up to the
    /// implementation (the test fake ignores it); confirm for real contexts.
    fn get_local_addr_for_remote(
        &mut self,
        device_id: &Self::DeviceId,
        remote: Option<SpecifiedAddr<I::Addr>>,
    ) -> Option<Self::AddressId>;

    /// Returns the address ID corresponding to `addr` on `device_id`, or
    /// `None` if `addr` does not correspond to an address on that device.
    ///
    /// NOTE(review): presumably "on the device" means currently assigned;
    /// confirm against the real implementation.
    fn get_address_id(
        &mut self,
        device_id: &Self::DeviceId,
        addr: IpDeviceAddr<I::Addr>,
    ) -> Option<Self::AddressId>;
}
97
/// Core context providing simultaneous access to the filtering state of both
/// IP versions.
pub trait FilterContext<BT: FilterBindingsTypes>:
    IpDeviceAddressIdContext<Ipv4, DeviceId: InterfaceProperties<BT::DeviceClass>>
    + IpDeviceAddressIdContext<Ipv6, DeviceId: InterfaceProperties<BT::DeviceClass>>
{
    /// Calls `cb` with mutable access to the IPv4 and IPv6 filtering state
    /// (in that order) and returns the callback's result.
    fn with_all_filter_state_mut<
        O,
        F: FnOnce(
            &mut State<Ipv4, <Self as IpDeviceAddressIdContext<Ipv4>>::WeakAddressId, BT>,
            &mut State<Ipv6, <Self as IpDeviceAddressIdContext<Ipv6>>::WeakAddressId, BT>,
        ) -> O,
    >(
        &mut self,
        cb: F,
    ) -> O;
}
116
/// Outcome of [`SocketOpsFilter::on_egress`] for an outgoing packet.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum SocketEgressFilterResult {
    /// The filter lets the packet through.
    Pass {
        /// NOTE(review): presumably indicates that congestion was signaled by
        /// the filter — confirm against the bindings implementation.
        congestion: bool,
    },

    /// The filter drops the packet.
    Drop {
        /// NOTE(review): presumably indicates that congestion was signaled by
        /// the filter — confirm against the bindings implementation.
        congestion: bool,
    },
}
132
/// Outcome of [`SocketOpsFilter::on_ingress`] for an incoming packet.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum SocketIngressFilterResult {
    /// The filter accepts the packet.
    Accept,

    /// The filter drops the packet.
    Drop,
}
142
/// Socket-level packet filter with hooks for outgoing and incoming packets.
///
/// NOTE(review): implementations are supplied by bindings via
/// `SocketOpsFilterBindingContext`; confirm which sockets it applies to.
pub trait SocketOpsFilter<D: StrongDeviceIdentifier> {
    /// Filters an outgoing `packet` on `device`, for the socket identified by
    /// `cookie` and carrying `marks`.
    fn on_egress<I: FilterIpExt, P: IpPacket<I> + PartialSerializer>(
        &self,
        packet: &P,
        device: &D,
        cookie: SocketCookie,
        marks: &Marks,
    ) -> SocketEgressFilterResult;

    /// Filters an incoming packet of IP version `ip_version`, supplied as raw
    /// (possibly fragmented) bytes, received on `device` for the socket
    /// identified by `cookie` and carrying `marks`.
    fn on_ingress(
        &self,
        ip_version: IpVersion,
        packet: FragmentedByteSlice<'_, &[u8]>,
        device: &D,
        cookie: SocketCookie,
        marks: &Marks,
    ) -> SocketIngressFilterResult;
}
164
/// Bindings context able to supply a [`SocketOpsFilter`] for devices of type `D`.
pub trait SocketOpsFilterBindingContext<D: StrongDeviceIdentifier>:
    TxMetadataBindingsTypes
{
    /// Returns the socket filter to run on packets.
    fn socket_ops_filter(&self) -> impl SocketOpsFilter<D>;
}
172
/// Test-only implementation: the generic base fake bindings context always
/// hands out the no-op socket filter.
#[cfg(any(test, feature = "testutils"))]
impl<TimerId, Event, State, FrameMeta, D> SocketOpsFilterBindingContext<D>
    for netstack3_base::testutil::FakeBindingsCtx<TimerId, Event, State, FrameMeta>
where
    TimerId: Debug + PartialEq + Clone + Send + Sync + 'static,
    Event: Debug + 'static,
    State: 'static,
    FrameMeta: 'static,
    D: StrongDeviceIdentifier,
{
    fn socket_ops_filter(&self) -> impl SocketOpsFilter<D> {
        crate::testutil::NoOpSocketOpsFilter
    }
}
187
188#[cfg(test)]
189pub(crate) mod testutil {
190 use alloc::sync::{Arc, Weak};
191 use alloc::vec::Vec;
192 use core::hash::{Hash, Hasher};
193 use core::ops::Deref;
194 use core::sync::atomic::AtomicUsize;
195 use core::time::Duration;
196
197 use derivative::Derivative;
198 use net_types::ip::{AddrSubnet, GenericOverIp, Ip};
199 use netstack3_base::testutil::{
200 FakeAtomicInstant, FakeCryptoRng, FakeDeviceClass, FakeInstant, FakeMatcherDeviceId,
201 FakeTimerCtx, FakeWeakDeviceId, WithFakeTimerContext,
202 };
203 use netstack3_base::{
204 AnyDevice, AssignedAddrIpExt, DeviceIdContext, InspectableValue, Inspector, InstantContext,
205 IntoCoreTimerCtx, IpAddressId, WeakIpAddressId,
206 };
207 use netstack3_hashmap::HashMap;
208
209 use super::*;
210 use crate::conntrack;
211 use crate::logic::FilterTimerId;
212 use crate::logic::nat::NatConfig;
213 use crate::state::validation::ValidRoutines;
214 use crate::state::{IpRoutines, NatRoutines, OneWayBoolean, Routines};
215
216 pub trait TestIpExt: FilterIpExt + AssignedAddrIpExt {}
217
218 impl<I: FilterIpExt + AssignedAddrIpExt> TestIpExt for I {}
219
/// Owning handle to a fake address, stored per device in [`FakeNatCtx`].
///
/// Holds the strong `Arc`; [`FakeAddressId`]s are produced by cloning it.
#[derive(Debug)]
pub struct FakePrimaryAddressId<I: AssignedAddrIpExt>(
    pub Arc<AddrSubnet<I::Addr, I::AssignedWitness>>,
);

/// Strong fake address ID.
///
/// Equality and hashing are derived, so they compare the address/subnet
/// value behind the `Arc` rather than the allocation itself.
#[derive(Clone, Debug, Hash, Eq, PartialEq)]
pub struct FakeAddressId<I: AssignedAddrIpExt>(Arc<AddrSubnet<I::Addr, I::AssignedWitness>>);

/// Weak fake address ID referring to the same allocation as a
/// [`FakeAddressId`].
#[derive(Clone, Debug)]
pub struct FakeWeakAddressId<I: AssignedAddrIpExt>(
    pub Weak<AddrSubnet<I::Addr, I::AssignedWitness>>,
);
232
233 impl<I: AssignedAddrIpExt> PartialEq for FakeWeakAddressId<I> {
234 fn eq(&self, other: &Self) -> bool {
235 let Self(lhs) = self;
236 let Self(rhs) = other;
237 Weak::ptr_eq(lhs, rhs)
238 }
239 }
240
241 impl<I: AssignedAddrIpExt> Eq for FakeWeakAddressId<I> {}
242
243 impl<I: AssignedAddrIpExt> Hash for FakeWeakAddressId<I> {
244 fn hash<H: Hasher>(&self, state: &mut H) {
245 let Self(this) = self;
246 this.as_ptr().hash(state)
247 }
248 }
249
250 impl<I: AssignedAddrIpExt> WeakIpAddressId<I::Addr> for FakeWeakAddressId<I> {
251 type Strong = FakeAddressId<I>;
252
253 fn upgrade(&self) -> Option<Self::Strong> {
254 let Self(inner) = self;
255 inner.upgrade().map(FakeAddressId)
256 }
257
258 fn is_assigned(&self) -> bool {
259 let Self(inner) = self;
260 inner.strong_count() != 0
261 }
262 }
263
264 impl<I: AssignedAddrIpExt> InspectableValue for FakeWeakAddressId<I> {
265 fn record<Inspector: netstack3_base::Inspector>(
266 &self,
267 _name: &str,
268 _inspector: &mut Inspector,
269 ) {
270 unimplemented!()
271 }
272 }
273
274 impl<I: AssignedAddrIpExt> Deref for FakeAddressId<I> {
275 type Target = AddrSubnet<I::Addr, I::AssignedWitness>;
276
277 fn deref(&self) -> &Self::Target {
278 let Self(inner) = self;
279 inner.deref()
280 }
281 }
282
283 impl<I: AssignedAddrIpExt> IpAddressId<I::Addr> for FakeAddressId<I> {
284 type Weak = FakeWeakAddressId<I>;
285
286 fn downgrade(&self) -> Self::Weak {
287 let Self(inner) = self;
288 FakeWeakAddressId(Arc::downgrade(inner))
289 }
290
291 fn addr(&self) -> IpDeviceAddr<I::Addr> {
292 let Self(inner) = self;
293
294 #[derive(GenericOverIp)]
295 #[generic_over_ip(I, Ip)]
296 struct WrapIn<I: AssignedAddrIpExt>(I::AssignedWitness);
297 I::map_ip(
298 WrapIn(inner.addr()),
299 |WrapIn(v4_addr)| IpDeviceAddr::new_from_witness(v4_addr),
300 |WrapIn(v6_addr)| IpDeviceAddr::new_from_ipv6_device_addr(v6_addr),
301 )
302 }
303
304 fn addr_sub(&self) -> AddrSubnet<I::Addr, I::AssignedWitness> {
305 let Self(inner) = self;
306 **inner
307 }
308 }
309
/// Fake core context for filtering tests: the filtering state for one IP
/// version plus a fake NAT context.
pub struct FakeCtx<I: TestIpExt> {
    state: State<I, FakeWeakAddressId<I>, FakeBindingsCtx<I>>,
    nat: FakeNatCtx<I>,
}

/// Fake NAT context that maps each device to a single primary address.
#[derive(Derivative)]
#[derivative(Default(bound = ""))]
pub struct FakeNatCtx<I: TestIpExt> {
    pub(crate) device_addrs: HashMap<FakeMatcherDeviceId, FakePrimaryAddressId<I>>,
}
320
321 impl<I: TestIpExt> FakeCtx<I> {
322 pub fn new(bindings_ctx: &mut FakeBindingsCtx<I>) -> Self {
323 Self {
324 state: State {
325 installed_routines: ValidRoutines::default(),
326 uninstalled_routines: Vec::default(),
327 conntrack: conntrack::Table::new::<IntoCoreTimerCtx>(bindings_ctx),
328 nat_installed: OneWayBoolean::default(),
329 },
330 nat: FakeNatCtx::default(),
331 }
332 }
333
334 pub fn with_ip_routines(
335 bindings_ctx: &mut FakeBindingsCtx<I>,
336 routines: IpRoutines<I, FakeBindingsCtx<I>, ()>,
337 ) -> Self {
338 let (installed_routines, uninstalled_routines) =
339 ValidRoutines::new(Routines { ip: routines, ..Default::default() })
340 .expect("invalid state");
341 Self {
342 state: State {
343 installed_routines,
344 uninstalled_routines,
345 conntrack: conntrack::Table::new::<IntoCoreTimerCtx>(bindings_ctx),
346 nat_installed: OneWayBoolean::default(),
347 },
348 nat: FakeNatCtx::default(),
349 }
350 }
351
352 pub fn with_nat_routines_and_device_addrs(
353 bindings_ctx: &mut FakeBindingsCtx<I>,
354 routines: NatRoutines<I, FakeBindingsCtx<I>, ()>,
355 device_addrs: impl IntoIterator<
356 Item = (FakeMatcherDeviceId, AddrSubnet<I::Addr, I::AssignedWitness>),
357 >,
358 ) -> Self {
359 let (installed_routines, uninstalled_routines) =
360 ValidRoutines::new(Routines { nat: routines, ..Default::default() })
361 .expect("invalid state");
362 Self {
363 state: State {
364 installed_routines,
365 uninstalled_routines,
366 conntrack: conntrack::Table::new::<IntoCoreTimerCtx>(bindings_ctx),
367 nat_installed: OneWayBoolean::TRUE,
368 },
369 nat: FakeNatCtx {
370 device_addrs: device_addrs
371 .into_iter()
372 .map(|(device, addr)| (device, FakePrimaryAddressId(Arc::new(addr))))
373 .collect(),
374 },
375 }
376 }
377
378 pub fn conntrack(
379 &mut self,
380 ) -> &conntrack::Table<I, NatConfig<I, FakeWeakAddressId<I>>, FakeBindingsCtx<I>> {
381 &self.state.conntrack
382 }
383 }
384
385 impl<I: TestIpExt> DeviceIdContext<AnyDevice> for FakeCtx<I> {
386 type DeviceId = FakeMatcherDeviceId;
387 type WeakDeviceId = FakeWeakDeviceId<FakeMatcherDeviceId>;
388 }
389
390 impl<I: TestIpExt> IpDeviceAddressIdContext<I> for FakeCtx<I> {
391 type AddressId = FakeAddressId<I>;
392 type WeakAddressId = FakeWeakAddressId<I>;
393 }
394
395 impl<I: TestIpExt> FilterIpContext<I, FakeBindingsCtx<I>> for FakeCtx<I> {
396 type NatCtx<'a> = FakeNatCtx<I>;
397
398 fn with_filter_state_and_nat_ctx<
399 O,
400 F: FnOnce(&State<I, FakeWeakAddressId<I>, FakeBindingsCtx<I>>, &mut Self::NatCtx<'_>) -> O,
401 >(
402 &mut self,
403 cb: F,
404 ) -> O {
405 let Self { state, nat } = self;
406 cb(state, nat)
407 }
408 }
409
410 impl<I: TestIpExt> FakeNatCtx<I> {
411 pub fn new(
412 device_addrs: impl IntoIterator<
413 Item = (FakeMatcherDeviceId, AddrSubnet<I::Addr, I::AssignedWitness>),
414 >,
415 ) -> Self {
416 Self {
417 device_addrs: device_addrs
418 .into_iter()
419 .map(|(device, addr)| (device, FakePrimaryAddressId(Arc::new(addr))))
420 .collect(),
421 }
422 }
423 }
424
425 impl<I: TestIpExt> DeviceIdContext<AnyDevice> for FakeNatCtx<I> {
426 type DeviceId = FakeMatcherDeviceId;
427 type WeakDeviceId = FakeWeakDeviceId<FakeMatcherDeviceId>;
428 }
429
430 impl<I: TestIpExt> IpDeviceAddressIdContext<I> for FakeNatCtx<I> {
431 type AddressId = FakeAddressId<I>;
432 type WeakAddressId = FakeWeakAddressId<I>;
433 }
434
435 impl<I: TestIpExt> NatContext<I, FakeBindingsCtx<I>> for FakeNatCtx<I> {
436 fn get_local_addr_for_remote(
437 &mut self,
438 device_id: &Self::DeviceId,
439 _remote: Option<SpecifiedAddr<I::Addr>>,
440 ) -> Option<Self::AddressId> {
441 let FakePrimaryAddressId(primary) = self.device_addrs.get(device_id)?;
442 Some(FakeAddressId(primary.clone()))
443 }
444
445 fn get_address_id(
446 &mut self,
447 device_id: &Self::DeviceId,
448 addr: IpDeviceAddr<I::Addr>,
449 ) -> Option<Self::AddressId> {
450 let FakePrimaryAddressId(id) = self.device_addrs.get(device_id)?;
451 let id = FakeAddressId(id.clone());
452 if id.addr() == addr { Some(id) } else { None }
453 }
454 }
455
456 pub struct FakeBindingsCtx<I: Ip> {
457 pub timer_ctx: FakeTimerCtx<FilterTimerId<I>>,
458 pub rng: FakeCryptoRng,
459 }
460
461 impl<I: Ip> FakeBindingsCtx<I> {
462 pub(crate) fn new() -> Self {
463 Self { timer_ctx: FakeTimerCtx::default(), rng: FakeCryptoRng::default() }
464 }
465
466 pub(crate) fn sleep(&mut self, time_elapsed: Duration) {
467 self.timer_ctx.instant.sleep(time_elapsed)
468 }
469 }
470
471 impl<I: Ip> InstantBindingsTypes for FakeBindingsCtx<I> {
472 type Instant = FakeInstant;
473 type AtomicInstant = FakeAtomicInstant;
474 }
475
476 #[derive(Debug)]
477 pub struct FakeBindingsPacketMatcher {
478 num_calls: AtomicUsize,
479 result: bool,
480 }
481
482 impl FakeBindingsPacketMatcher {
483 pub fn new(result: bool) -> Arc<Self> {
484 Arc::new(Self { num_calls: AtomicUsize::new(0), result })
485 }
486
487 pub fn num_calls(&self) -> usize {
488 self.num_calls.load(core::sync::atomic::Ordering::SeqCst)
489 }
490 }
491
492 impl BindingsPacketMatcher for FakeBindingsPacketMatcher {
493 fn matches<I: FilterIpExt, P: IpPacket<I>>(&self, _packet: &P) -> bool {
494 let _: usize = self.num_calls.fetch_add(1, core::sync::atomic::Ordering::SeqCst);
495 self.result
496 }
497 }
498
499 impl InspectableValue for FakeBindingsPacketMatcher {
500 fn record<I: Inspector>(&self, _name: &str, _inspector: &mut I) {
501 unimplemented!()
502 }
503 }
504
505 impl<I: Ip> MatcherBindingsTypes for FakeBindingsCtx<I> {
506 type DeviceClass = FakeDeviceClass;
507 type BindingsPacketMatcher = Arc<FakeBindingsPacketMatcher>;
508 }
509
510 impl<I: Ip> InstantContext for FakeBindingsCtx<I> {
511 fn now(&self) -> Self::Instant {
512 self.timer_ctx.now()
513 }
514 }
515
516 impl<I: Ip> TimerBindingsTypes for FakeBindingsCtx<I> {
517 type Timer = <FakeTimerCtx<FilterTimerId<I>> as TimerBindingsTypes>::Timer;
518 type DispatchId = <FakeTimerCtx<FilterTimerId<I>> as TimerBindingsTypes>::DispatchId;
519 type UniqueTimerId = <FakeTimerCtx<FilterTimerId<I>> as TimerBindingsTypes>::UniqueTimerId;
520 }
521
522 impl<I: Ip> TimerContext for FakeBindingsCtx<I> {
523 fn new_timer(&mut self, id: Self::DispatchId) -> Self::Timer {
524 self.timer_ctx.new_timer(id)
525 }
526
527 fn schedule_timer_instant(
528 &mut self,
529 time: Self::Instant,
530 timer: &mut Self::Timer,
531 ) -> Option<Self::Instant> {
532 self.timer_ctx.schedule_timer_instant(time, timer)
533 }
534
535 fn cancel_timer(&mut self, timer: &mut Self::Timer) -> Option<Self::Instant> {
536 self.timer_ctx.cancel_timer(timer)
537 }
538
539 fn scheduled_instant(&self, timer: &mut Self::Timer) -> Option<Self::Instant> {
540 self.timer_ctx.scheduled_instant(timer)
541 }
542
543 fn unique_timer_id(&self, timer: &Self::Timer) -> Self::UniqueTimerId {
544 self.timer_ctx.unique_timer_id(timer)
545 }
546 }
547
548 impl<I: Ip> WithFakeTimerContext<FilterTimerId<I>> for FakeBindingsCtx<I> {
549 fn with_fake_timer_ctx<O, F: FnOnce(&FakeTimerCtx<FilterTimerId<I>>) -> O>(
550 &self,
551 f: F,
552 ) -> O {
553 f(&self.timer_ctx)
554 }
555
556 fn with_fake_timer_ctx_mut<O, F: FnOnce(&mut FakeTimerCtx<FilterTimerId<I>>) -> O>(
557 &mut self,
558 f: F,
559 ) -> O {
560 f(&mut self.timer_ctx)
561 }
562 }
563
564 impl<I: Ip> RngContext for FakeBindingsCtx<I> {
565 type Rng<'a>
566 = FakeCryptoRng
567 where
568 Self: 'a;
569
570 fn rng(&mut self) -> Self::Rng<'_> {
571 self.rng.clone()
572 }
573 }
574}