1use core::fmt::Debug;
6
7use net_types::SpecifiedAddr;
8use net_types::ip::{IpVersion, Ipv4, Ipv6};
9use netstack3_base::socket::SocketCookie;
10use netstack3_base::{
11 InstantBindingsTypes, InterfaceProperties, IpDeviceAddr, IpDeviceAddressIdContext, Marks,
12 MatcherBindingsTypes, RngContext, TimerBindingsTypes, TimerContext, TxMetadataBindingsTypes,
13};
14use packet::FragmentedByteSlice;
15use packet_formats::ip::IpExt;
16
17use crate::FilterIpExt;
18use crate::matchers::BindingsPacketMatcher;
19use crate::packets::FilterIpPacket;
20use crate::state::State;
21
/// The bindings types required by the filtering engine.
///
/// This trait has no items of its own; it only bundles its supertraits so that
/// filtering code can name a single bound. It is blanket-implemented for every
/// `'static` type that implements [`InstantBindingsTypes`],
/// [`MatcherBindingsTypes`], and [`TimerBindingsTypes`].
pub trait FilterBindingsTypes:
    InstantBindingsTypes + MatcherBindingsTypes + TimerBindingsTypes + 'static
{
}
27
28impl<BT: InstantBindingsTypes + MatcherBindingsTypes + TimerBindingsTypes + 'static>
29 FilterBindingsTypes for BT
30{
31}
32
/// The bindings context required by the filtering engine.
///
/// Combines timer and RNG access with [`FilterBindingsTypes`]. It also requires
/// that the bindings-provided packet matcher can match packets for devices of
/// type `D`.
///
/// Like [`FilterBindingsTypes`], this trait has no items of its own and is
/// blanket-implemented for every type meeting these bounds.
// The matcher bound is written as an associated-type bound in supertrait
// position so that it is implied wherever `FilterBindingsContext<D>` is used.
pub trait FilterBindingsContext<D>:
    TimerContext + RngContext + FilterBindingsTypes<BindingsPacketMatcher: BindingsPacketMatcher<D>>
{
}
38impl<D, BC> FilterBindingsContext<D> for BC
39where
40 BC: TimerContext + RngContext + FilterBindingsTypes,
41 BC::BindingsPacketMatcher: BindingsPacketMatcher<D>,
42{
43}
44
/// Core context that gives access to the filtering state for IP version `I`.
///
/// Device and address identifiers come from the [`IpDeviceAddressIdContext`]
/// supertrait. Device IDs must expose [`InterfaceProperties`] for the bindings'
/// device class.
pub trait FilterIpContext<I: FilterIpExt, BT: FilterBindingsTypes>:
    IpDeviceAddressIdContext<I, DeviceId: InterfaceProperties<BT::DeviceClass>>
{
    /// The NAT context passed to callbacks alongside the filter state.
    ///
    /// It must use the same device ID and weak address ID types as this context.
    type NatCtx<'a>: NatContext<I, BT, DeviceId = Self::DeviceId, WeakAddressId = Self::WeakAddressId>;

    /// Calls `cb` with an immutable reference to the filter state and returns
    /// the callback's result.
    ///
    /// The default implementation delegates to
    /// [`FilterIpContext::with_filter_state_and_nat_ctx`] and ignores the NAT
    /// context.
    fn with_filter_state<O, F: FnOnce(&State<I, Self::WeakAddressId, BT>) -> O>(
        &mut self,
        cb: F,
    ) -> O {
        self.with_filter_state_and_nat_ctx(|state, _ctx| cb(state))
    }

    /// Calls `cb` with an immutable reference to the filter state and a mutable
    /// reference to the NAT context, and returns the callback's result.
    fn with_filter_state_and_nat_ctx<
        O,
        F: FnOnce(&State<I, Self::WeakAddressId, BT>, &mut Self::NatCtx<'_>) -> O,
    >(
        &mut self,
        cb: F,
    ) -> O;
}
78
/// Context used by NAT to look up addresses on devices.
pub trait NatContext<I: IpExt, BT: FilterBindingsTypes>:
    IpDeviceAddressIdContext<I, DeviceId: InterfaceProperties<BT::DeviceClass>>
{
    /// Returns an address on `device_id` to use as the local address when
    /// communicating with `remote` (if given). Returns `None` if no suitable
    /// address exists.
    ///
    /// NOTE(review): the address-selection policy is left to the implementor;
    /// presumably this is used to pick a rewritten source address — confirm.
    fn get_local_addr_for_remote(
        &mut self,
        device_id: &Self::DeviceId,
        remote: Option<SpecifiedAddr<I::Addr>>,
    ) -> Option<Self::AddressId>;

    /// Returns the address ID for `addr` on `device_id`, or `None` if there is
    /// no matching address on that device.
    fn get_address_id(
        &mut self,
        device_id: &Self::DeviceId,
        addr: IpDeviceAddr<I::Addr>,
    ) -> Option<Self::AddressId>;
}
98
/// Context that gives mutable access to the filtering state of both IP versions
/// at once.
pub trait FilterContext<BT: FilterBindingsTypes>:
    IpDeviceAddressIdContext<Ipv4, DeviceId: InterfaceProperties<BT::DeviceClass>>
    + IpDeviceAddressIdContext<Ipv6, DeviceId: InterfaceProperties<BT::DeviceClass>>
{
    /// Calls `cb` with mutable references to the IPv4 and IPv6 filter state, in
    /// that order, and returns the callback's result.
    fn with_all_filter_state_mut<
        O,
        F: FnOnce(
            &mut State<Ipv4, <Self as IpDeviceAddressIdContext<Ipv4>>::WeakAddressId, BT>,
            &mut State<Ipv6, <Self as IpDeviceAddressIdContext<Ipv6>>::WeakAddressId, BT>,
        ) -> O,
    >(
        &mut self,
        cb: F,
    ) -> O;
}
117
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
/// The verdict returned by [`SocketOpsFilter::on_egress`] for an outgoing packet.
pub enum SocketEgressFilterResult {
    /// The packet may proceed.
    Pass {
        /// Whether the filter signalled congestion for this packet.
        ///
        /// NOTE(review): how the caller reacts to this flag is not visible
        /// here — confirm the intended semantics.
        congestion: bool,
    },

    /// The packet should be dropped.
    Drop {
        /// Whether the filter signalled congestion for this packet (same meaning
        /// as `Pass::congestion`).
        congestion: bool,
    },
}
133
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
/// The verdict returned by [`SocketOpsFilter::on_ingress`] for an incoming packet.
pub enum SocketIngressFilterResult {
    /// The packet is accepted.
    Accept,

    /// The packet should be dropped.
    Drop,
}
143
/// A filter applied to packets sent and received by sockets, for devices of
/// type `D`.
pub trait SocketOpsFilter<D> {
    /// Runs the filter on an outgoing `packet` of IP version `I`.
    ///
    /// The packet is sent through `device` by the socket identified by `cookie`,
    /// and `marks` are the marks associated with it. Returns the verdict for
    /// the packet.
    fn on_egress<I: FilterIpExt, P: FilterIpPacket<I>>(
        &self,
        packet: &P,
        device: &D,
        cookie: SocketCookie,
        marks: &Marks,
    ) -> SocketEgressFilterResult;

    /// Runs the filter on an incoming packet of version `ip_version`.
    ///
    /// The packet is given as possibly-fragmented raw bytes in `packet`. It was
    /// received on `device` for the socket identified by `cookie`, and `marks`
    /// are the marks associated with it. Returns the verdict for the packet.
    fn on_ingress(
        &self,
        ip_version: IpVersion,
        packet: FragmentedByteSlice<'_, &[u8]>,
        device: &D,
        cookie: SocketCookie,
        marks: &Marks,
    ) -> SocketIngressFilterResult;
}
165
/// Bindings context that provides the [`SocketOpsFilter`] applied to socket
/// traffic on devices of type `D`.
pub trait SocketOpsFilterBindingContext<D>: TxMetadataBindingsTypes {
    /// Returns the socket operations filter to run on socket packets.
    fn socket_ops_filter(&self) -> impl SocketOpsFilter<D>;
}
171
// Test-only implementation for the base crate's fake bindings context: it
// always hands out `crate::testutil::NoOpSocketOpsFilter`.
#[cfg(any(test, feature = "testutils"))]
impl<TimerId, Event, State, FrameMeta, D> SocketOpsFilterBindingContext<D>
    for netstack3_base::testutil::FakeBindingsCtx<TimerId, Event, State, FrameMeta>
where
    TimerId: Debug + PartialEq + Clone + Send + Sync + 'static,
    Event: Debug + 'static,
    State: 'static,
    FrameMeta: 'static,
{
    fn socket_ops_filter(&self) -> impl SocketOpsFilter<D> {
        crate::testutil::NoOpSocketOpsFilter
    }
}
186
187#[cfg(test)]
188pub(crate) mod testutil {
189 use alloc::sync::{Arc, Weak};
190 use alloc::vec::Vec;
191 use core::hash::{Hash, Hasher};
192 use core::ops::Deref;
193 use core::sync::atomic::AtomicUsize;
194 use core::time::Duration;
195
196 use derivative::Derivative;
197 use net_types::ip::{AddrSubnet, GenericOverIp, Ip};
198 use netstack3_base::testutil::{
199 FakeAtomicInstant, FakeCryptoRng, FakeDeviceClass, FakeInstant, FakeMatcherDeviceId,
200 FakeTimerCtx, FakeWeakDeviceId, WithFakeTimerContext,
201 };
202 use netstack3_base::{
203 AnyDevice, AssignedAddrIpExt, DeviceIdContext, InspectableValue, Inspector, InstantContext,
204 IntoCoreTimerCtx, IpAddressId, WeakIpAddressId,
205 };
206 use netstack3_hashmap::HashMap;
207
208 use super::*;
209 use crate::logic::FilterTimerId;
210 use crate::logic::nat::NatConfig;
211 use crate::state::validation::ValidRoutines;
212 use crate::state::{IpRoutines, NatRoutines, OneWayBoolean, Routines};
213 use crate::{Interfaces, conntrack};
214
215 pub trait TestIpExt: FilterIpExt + AssignedAddrIpExt {}
216
217 impl<I: FilterIpExt + AssignedAddrIpExt> TestIpExt for I {}
218
    /// The address stored for each fake device in [`FakeNatCtx`], held as a
    /// strong, shared reference to its [`AddrSubnet`].
    #[derive(Debug)]
    pub struct FakePrimaryAddressId<I: AssignedAddrIpExt>(
        pub Arc<AddrSubnet<I::Addr, I::AssignedWitness>>,
    );
223
    /// A fake strong address ID that wraps a shared [`AddrSubnet`].
    ///
    /// Equality and hashing are derived through `Arc`, so they compare the
    /// wrapped address subnets by value, not by allocation.
    #[derive(Clone, Debug, Hash, Eq, PartialEq)]
    pub struct FakeAddressId<I: AssignedAddrIpExt>(Arc<AddrSubnet<I::Addr, I::AssignedWitness>>);
226
    /// A fake weak address ID: a weak reference to a shared [`AddrSubnet`].
    ///
    /// Equality and hashing are implemented manually in terms of pointer
    /// identity of the referenced allocation.
    #[derive(Clone, Debug)]
    pub struct FakeWeakAddressId<I: AssignedAddrIpExt>(
        pub Weak<AddrSubnet<I::Addr, I::AssignedWitness>>,
    );
231
232 impl<I: AssignedAddrIpExt> PartialEq for FakeWeakAddressId<I> {
233 fn eq(&self, other: &Self) -> bool {
234 let Self(lhs) = self;
235 let Self(rhs) = other;
236 Weak::ptr_eq(lhs, rhs)
237 }
238 }
239
    // Pointer identity is reflexive, symmetric, and transitive, so `Eq` holds.
    impl<I: AssignedAddrIpExt> Eq for FakeWeakAddressId<I> {}
241
242 impl<I: AssignedAddrIpExt> Hash for FakeWeakAddressId<I> {
243 fn hash<H: Hasher>(&self, state: &mut H) {
244 let Self(this) = self;
245 this.as_ptr().hash(state)
246 }
247 }
248
249 impl<I: AssignedAddrIpExt> WeakIpAddressId<I::Addr> for FakeWeakAddressId<I> {
250 type Strong = FakeAddressId<I>;
251
252 fn upgrade(&self) -> Option<Self::Strong> {
253 let Self(inner) = self;
254 inner.upgrade().map(FakeAddressId)
255 }
256
257 fn is_assigned(&self) -> bool {
258 let Self(inner) = self;
259 inner.strong_count() != 0
260 }
261 }
262
263 impl<I: AssignedAddrIpExt> InspectableValue for FakeWeakAddressId<I> {
264 fn record<Inspector: netstack3_base::Inspector>(
265 &self,
266 _name: &str,
267 _inspector: &mut Inspector,
268 ) {
269 unimplemented!()
270 }
271 }
272
273 impl<I: AssignedAddrIpExt> Deref for FakeAddressId<I> {
274 type Target = AddrSubnet<I::Addr, I::AssignedWitness>;
275
276 fn deref(&self) -> &Self::Target {
277 let Self(inner) = self;
278 inner.deref()
279 }
280 }
281
282 impl<I: AssignedAddrIpExt> IpAddressId<I::Addr> for FakeAddressId<I> {
283 type Weak = FakeWeakAddressId<I>;
284
285 fn downgrade(&self) -> Self::Weak {
286 let Self(inner) = self;
287 FakeWeakAddressId(Arc::downgrade(inner))
288 }
289
290 fn addr(&self) -> IpDeviceAddr<I::Addr> {
291 let Self(inner) = self;
292
293 #[derive(GenericOverIp)]
294 #[generic_over_ip(I, Ip)]
295 struct WrapIn<I: AssignedAddrIpExt>(I::AssignedWitness);
296 I::map_ip(
297 WrapIn(inner.addr()),
298 |WrapIn(v4_addr)| IpDeviceAddr::new_from_witness(v4_addr),
299 |WrapIn(v6_addr)| IpDeviceAddr::new_from_ipv6_device_addr(v6_addr),
300 )
301 }
302
303 fn addr_sub(&self) -> AddrSubnet<I::Addr, I::AssignedWitness> {
304 let Self(inner) = self;
305 **inner
306 }
307 }
308
    /// Fake core context holding the filter state and a fake NAT context for IP
    /// version `I`.
    pub struct FakeCtx<I: TestIpExt> {
        state: State<I, FakeWeakAddressId<I>, FakeBindingsCtx<I>>,
        nat: FakeNatCtx<I>,
    }
313
    /// Fake NAT context that stores at most one address per fake device.
    #[derive(Derivative)]
    // `bound = ""` drops the `I: Default` bound a plain derive would add, so a
    // default (empty) context exists for every `I`.
    #[derivative(Default(bound = ""))]
    pub struct FakeNatCtx<I: TestIpExt> {
        pub(crate) device_addrs: HashMap<FakeMatcherDeviceId, FakePrimaryAddressId<I>>,
    }
319
320 impl<I: TestIpExt> FakeCtx<I> {
321 pub fn new(bindings_ctx: &mut FakeBindingsCtx<I>) -> Self {
322 Self {
323 state: State {
324 installed_routines: ValidRoutines::default(),
325 uninstalled_routines: Vec::default(),
326 conntrack: conntrack::Table::new::<IntoCoreTimerCtx>(bindings_ctx),
327 nat_installed: OneWayBoolean::default(),
328 },
329 nat: FakeNatCtx::default(),
330 }
331 }
332
333 pub fn with_ip_routines(
334 bindings_ctx: &mut FakeBindingsCtx<I>,
335 routines: IpRoutines<I, FakeBindingsCtx<I>, ()>,
336 ) -> Self {
337 let (installed_routines, uninstalled_routines) =
338 ValidRoutines::new(Routines { ip: routines, ..Default::default() })
339 .expect("invalid state");
340 Self {
341 state: State {
342 installed_routines,
343 uninstalled_routines,
344 conntrack: conntrack::Table::new::<IntoCoreTimerCtx>(bindings_ctx),
345 nat_installed: OneWayBoolean::default(),
346 },
347 nat: FakeNatCtx::default(),
348 }
349 }
350
351 pub fn with_nat_routines_and_device_addrs(
352 bindings_ctx: &mut FakeBindingsCtx<I>,
353 routines: NatRoutines<I, FakeBindingsCtx<I>, ()>,
354 device_addrs: impl IntoIterator<
355 Item = (FakeMatcherDeviceId, AddrSubnet<I::Addr, I::AssignedWitness>),
356 >,
357 ) -> Self {
358 let (installed_routines, uninstalled_routines) =
359 ValidRoutines::new(Routines { nat: routines, ..Default::default() })
360 .expect("invalid state");
361 Self {
362 state: State {
363 installed_routines,
364 uninstalled_routines,
365 conntrack: conntrack::Table::new::<IntoCoreTimerCtx>(bindings_ctx),
366 nat_installed: OneWayBoolean::TRUE,
367 },
368 nat: FakeNatCtx {
369 device_addrs: device_addrs
370 .into_iter()
371 .map(|(device, addr)| (device, FakePrimaryAddressId(Arc::new(addr))))
372 .collect(),
373 },
374 }
375 }
376
377 pub fn conntrack(
378 &mut self,
379 ) -> &conntrack::Table<I, NatConfig<I, FakeWeakAddressId<I>>, FakeBindingsCtx<I>> {
380 &self.state.conntrack
381 }
382 }
383
384 impl<I: TestIpExt> DeviceIdContext<AnyDevice> for FakeCtx<I> {
385 type DeviceId = FakeMatcherDeviceId;
386 type WeakDeviceId = FakeWeakDeviceId<FakeMatcherDeviceId>;
387 }
388
389 impl<I: TestIpExt> IpDeviceAddressIdContext<I> for FakeCtx<I> {
390 type AddressId = FakeAddressId<I>;
391 type WeakAddressId = FakeWeakAddressId<I>;
392 }
393
394 impl<I: TestIpExt> FilterIpContext<I, FakeBindingsCtx<I>> for FakeCtx<I> {
395 type NatCtx<'a> = FakeNatCtx<I>;
396
397 fn with_filter_state_and_nat_ctx<
398 O,
399 F: FnOnce(&State<I, FakeWeakAddressId<I>, FakeBindingsCtx<I>>, &mut Self::NatCtx<'_>) -> O,
400 >(
401 &mut self,
402 cb: F,
403 ) -> O {
404 let Self { state, nat } = self;
405 cb(state, nat)
406 }
407 }
408
409 impl<I: TestIpExt> FakeNatCtx<I> {
410 pub fn new(
411 device_addrs: impl IntoIterator<
412 Item = (FakeMatcherDeviceId, AddrSubnet<I::Addr, I::AssignedWitness>),
413 >,
414 ) -> Self {
415 Self {
416 device_addrs: device_addrs
417 .into_iter()
418 .map(|(device, addr)| (device, FakePrimaryAddressId(Arc::new(addr))))
419 .collect(),
420 }
421 }
422 }
423
424 impl<I: TestIpExt> DeviceIdContext<AnyDevice> for FakeNatCtx<I> {
425 type DeviceId = FakeMatcherDeviceId;
426 type WeakDeviceId = FakeWeakDeviceId<FakeMatcherDeviceId>;
427 }
428
429 impl<I: TestIpExt> IpDeviceAddressIdContext<I> for FakeNatCtx<I> {
430 type AddressId = FakeAddressId<I>;
431 type WeakAddressId = FakeWeakAddressId<I>;
432 }
433
434 impl<I: TestIpExt> NatContext<I, FakeBindingsCtx<I>> for FakeNatCtx<I> {
435 fn get_local_addr_for_remote(
436 &mut self,
437 device_id: &Self::DeviceId,
438 _remote: Option<SpecifiedAddr<I::Addr>>,
439 ) -> Option<Self::AddressId> {
440 let FakePrimaryAddressId(primary) = self.device_addrs.get(device_id)?;
441 Some(FakeAddressId(primary.clone()))
442 }
443
444 fn get_address_id(
445 &mut self,
446 device_id: &Self::DeviceId,
447 addr: IpDeviceAddr<I::Addr>,
448 ) -> Option<Self::AddressId> {
449 let FakePrimaryAddressId(id) = self.device_addrs.get(device_id)?;
450 let id = FakeAddressId(id.clone());
451 if id.addr() == addr { Some(id) } else { None }
452 }
453 }
454
    /// Fake bindings context for IP version `I`, providing a fake timer context
    /// and a fake RNG.
    pub struct FakeBindingsCtx<I: Ip> {
        pub timer_ctx: FakeTimerCtx<FilterTimerId<I>>,
        pub rng: FakeCryptoRng,
    }
459
460 impl<I: Ip> FakeBindingsCtx<I> {
461 pub(crate) fn new() -> Self {
462 Self { timer_ctx: FakeTimerCtx::default(), rng: FakeCryptoRng::default() }
463 }
464
465 pub(crate) fn sleep(&mut self, time_elapsed: Duration) {
466 self.timer_ctx.instant.sleep(time_elapsed)
467 }
468 }
469
470 impl<I: Ip> InstantBindingsTypes for FakeBindingsCtx<I> {
471 type Instant = FakeInstant;
472 type AtomicInstant = FakeAtomicInstant;
473 }
474
    /// Fake bindings packet matcher that always returns a fixed result and counts
    /// how many times it has been invoked.
    #[derive(Debug)]
    pub struct FakeBindingsPacketMatcher {
        // Number of `matches` calls so far.
        num_calls: AtomicUsize,
        // Value returned by every `matches` call.
        result: bool,
    }
480
481 impl FakeBindingsPacketMatcher {
482 pub fn new(result: bool) -> Arc<Self> {
483 Arc::new(Self { num_calls: AtomicUsize::new(0), result })
484 }
485
486 pub fn num_calls(&self) -> usize {
487 self.num_calls.load(core::sync::atomic::Ordering::SeqCst)
488 }
489 }
490
491 impl BindingsPacketMatcher<FakeMatcherDeviceId> for FakeBindingsPacketMatcher {
492 fn matches<I: FilterIpExt, P: FilterIpPacket<I>>(
493 &self,
494 _packet: &P,
495 _interfaces: Interfaces<'_, FakeMatcherDeviceId>,
496 ) -> bool {
497 let _: usize = self.num_calls.fetch_add(1, core::sync::atomic::Ordering::SeqCst);
498 self.result
499 }
500 }
501
502 impl InspectableValue for FakeBindingsPacketMatcher {
503 fn record<I: Inspector>(&self, _name: &str, _inspector: &mut I) {
504 unimplemented!()
505 }
506 }
507
508 impl<I: Ip> MatcherBindingsTypes for FakeBindingsCtx<I> {
509 type DeviceClass = FakeDeviceClass;
510 type BindingsPacketMatcher = Arc<FakeBindingsPacketMatcher>;
511 }
512
513 impl<I: Ip> InstantContext for FakeBindingsCtx<I> {
514 fn now(&self) -> Self::Instant {
515 self.timer_ctx.now()
516 }
517 }
518
519 impl<I: Ip> TimerBindingsTypes for FakeBindingsCtx<I> {
520 type Timer = <FakeTimerCtx<FilterTimerId<I>> as TimerBindingsTypes>::Timer;
521 type DispatchId = <FakeTimerCtx<FilterTimerId<I>> as TimerBindingsTypes>::DispatchId;
522 type UniqueTimerId = <FakeTimerCtx<FilterTimerId<I>> as TimerBindingsTypes>::UniqueTimerId;
523 }
524
525 impl<I: Ip> TimerContext for FakeBindingsCtx<I> {
526 fn new_timer(&mut self, id: Self::DispatchId) -> Self::Timer {
527 self.timer_ctx.new_timer(id)
528 }
529
530 fn schedule_timer_instant(
531 &mut self,
532 time: Self::Instant,
533 timer: &mut Self::Timer,
534 ) -> Option<Self::Instant> {
535 self.timer_ctx.schedule_timer_instant(time, timer)
536 }
537
538 fn cancel_timer(&mut self, timer: &mut Self::Timer) -> Option<Self::Instant> {
539 self.timer_ctx.cancel_timer(timer)
540 }
541
542 fn scheduled_instant(&self, timer: &mut Self::Timer) -> Option<Self::Instant> {
543 self.timer_ctx.scheduled_instant(timer)
544 }
545
546 fn unique_timer_id(&self, timer: &Self::Timer) -> Self::UniqueTimerId {
547 self.timer_ctx.unique_timer_id(timer)
548 }
549 }
550
551 impl<I: Ip> WithFakeTimerContext<FilterTimerId<I>> for FakeBindingsCtx<I> {
552 fn with_fake_timer_ctx<O, F: FnOnce(&FakeTimerCtx<FilterTimerId<I>>) -> O>(
553 &self,
554 f: F,
555 ) -> O {
556 f(&self.timer_ctx)
557 }
558
559 fn with_fake_timer_ctx_mut<O, F: FnOnce(&mut FakeTimerCtx<FilterTimerId<I>>) -> O>(
560 &mut self,
561 f: F,
562 ) -> O {
563 f(&mut self.timer_ctx)
564 }
565 }
566
567 impl<I: Ip> RngContext for FakeBindingsCtx<I> {
568 type Rng<'a>
569 = FakeCryptoRng
570 where
571 Self: 'a;
572
573 fn rng(&mut self) -> Self::Rng<'_> {
574 self.rng.clone()
575 }
576 }
577}