1use core::fmt::Debug;
6
7use net_types::ip::{IpVersion, Ipv4, Ipv6};
8use net_types::SpecifiedAddr;
9use netstack3_base::socket::SocketCookie;
10use netstack3_base::{
11 InstantBindingsTypes, IpDeviceAddr, IpDeviceAddressIdContext, Marks, RngContext,
12 StrongDeviceIdentifier, TimerBindingsTypes, TimerContext, TxMetadataBindingsTypes,
13};
14use packet::{FragmentedByteSlice, PartialSerializer};
15use packet_formats::ip::IpExt;
16
17use crate::matchers::InterfaceProperties;
18use crate::state::State;
19use crate::{FilterIpExt, IpPacket};
20
/// Types provided by bindings to the filtering engine.
///
/// In addition to the instant and timer types, bindings choose the type used
/// to classify devices.
pub trait FilterBindingsTypes: InstantBindingsTypes + TimerBindingsTypes + 'static {
    /// The type bindings use to classify devices.
    ///
    /// Device IDs used with the filtering contexts in this module are required
    /// to implement `InterfaceProperties<DeviceClass>` (presumably consulted by
    /// the interface matchers in `crate::matchers` — confirm).
    type DeviceClass: Clone + Debug;
}
30
31pub trait FilterBindingsContext: TimerContext + RngContext + FilterBindingsTypes {}
33impl<BC: TimerContext + RngContext + FilterBindingsTypes> FilterBindingsContext for BC {}
34
35pub trait FilterIpContext<I: FilterIpExt, BT: FilterBindingsTypes>:
44 IpDeviceAddressIdContext<I, DeviceId: InterfaceProperties<BT::DeviceClass>>
45{
46 type NatCtx<'a>: NatContext<
49 I,
50 BT,
51 DeviceId = Self::DeviceId,
52 WeakAddressId = Self::WeakAddressId,
53 >;
54
55 fn with_filter_state<O, F: FnOnce(&State<I, Self::WeakAddressId, BT>) -> O>(
57 &mut self,
58 cb: F,
59 ) -> O {
60 self.with_filter_state_and_nat_ctx(|state, _ctx| cb(state))
61 }
62
63 fn with_filter_state_and_nat_ctx<
66 O,
67 F: FnOnce(&State<I, Self::WeakAddressId, BT>, &mut Self::NatCtx<'_>) -> O,
68 >(
69 &mut self,
70 cb: F,
71 ) -> O;
72}
73
/// Context used during NAT to look up addresses assigned to devices.
pub trait NatContext<I: IpExt, BT: FilterBindingsTypes>:
    IpDeviceAddressIdContext<I, DeviceId: InterfaceProperties<BT::DeviceClass>>
{
    /// Returns the ID of a local address on `device_id` to use when
    /// communicating with `remote` (if one is given), or `None` if no suitable
    /// address is available.
    fn get_local_addr_for_remote(
        &mut self,
        device_id: &Self::DeviceId,
        remote: Option<SpecifiedAddr<I::Addr>>,
    ) -> Option<Self::AddressId>;

    /// Returns the ID of `addr` if it is assigned to `device_id`, or `None`
    /// otherwise.
    fn get_address_id(
        &mut self,
        device_id: &Self::DeviceId,
        addr: IpDeviceAddr<I::Addr>,
    ) -> Option<Self::AddressId>;
}
93
/// Context providing simultaneous access to the IPv4 and IPv6 filtering state.
pub trait FilterContext<BT: FilterBindingsTypes>:
    IpDeviceAddressIdContext<Ipv4, DeviceId: InterfaceProperties<BT::DeviceClass>>
    + IpDeviceAddressIdContext<Ipv6, DeviceId: InterfaceProperties<BT::DeviceClass>>
{
    /// Calls `cb` with mutable references to both the IPv4 and the IPv6
    /// filtering state, returning `cb`'s result.
    fn with_all_filter_state_mut<
        O,
        F: FnOnce(
            &mut State<Ipv4, <Self as IpDeviceAddressIdContext<Ipv4>>::WeakAddressId, BT>,
            &mut State<Ipv6, <Self as IpDeviceAddressIdContext<Ipv6>>::WeakAddressId, BT>,
        ) -> O,
    >(
        &mut self,
        cb: F,
    ) -> O;
}
112
/// The verdict produced by [`SocketOpsFilter::on_egress`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum SocketEgressFilterResult {
    /// The packet is allowed to continue being sent.
    Pass {
        /// Whether the filter signaled congestion for this packet.
        // NOTE(review): presumably requests a congestion notification (e.g.
        // ECN marking) — confirm against callers.
        congestion: bool,
    },

    /// The packet should be dropped.
    Drop {
        /// Whether the filter signaled congestion for this packet.
        congestion: bool,
    },
}
128
/// The verdict produced by [`SocketOpsFilter::on_ingress`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum SocketIngressFilterResult {
    /// The packet should be accepted.
    Accept,

    /// The packet should be dropped.
    Drop,
}
138
/// Filter hooks applied to traffic sent and received by sockets.
pub trait SocketOpsFilter<D: StrongDeviceIdentifier> {
    /// Filters an outgoing `packet` on `device` belonging to the socket
    /// identified by `cookie`, with the associated `marks`.
    fn on_egress<I: FilterIpExt, P: IpPacket<I> + PartialSerializer>(
        &self,
        packet: &P,
        device: &D,
        cookie: SocketCookie,
        marks: &Marks,
    ) -> SocketEgressFilterResult;

    /// Filters an incoming packet on `device` destined for the socket
    /// identified by `cookie`, with the associated `marks`.
    ///
    /// `packet` holds the packet bytes, possibly split across several slices,
    /// and `ip_version` gives their IP version. (Whether the IP header is
    /// included is not visible here — TODO confirm.)
    fn on_ingress(
        &self,
        ip_version: IpVersion,
        packet: FragmentedByteSlice<'_, &[u8]>,
        device: &D,
        cookie: SocketCookie,
        marks: &Marks,
    ) -> SocketIngressFilterResult;
}
160
/// Bindings context that supplies the [`SocketOpsFilter`] applied to socket
/// traffic.
pub trait SocketOpsFilterBindingContext<D: StrongDeviceIdentifier>:
    TxMetadataBindingsTypes
{
    /// Returns the socket operations filter for devices identified by `D`.
    fn socket_ops_filter(&self) -> impl SocketOpsFilter<D>;
}
168
/// The generic fake bindings context from `netstack3_base` classifies every
/// device with the unit type.
#[cfg(any(test, feature = "testutils"))]
impl<TimerId, Event, State, FrameMeta> FilterBindingsTypes
    for netstack3_base::testutil::FakeBindingsCtx<TimerId, Event, State, FrameMeta>
where
    TimerId: Debug + PartialEq + Clone + Send + Sync + 'static,
    Event: Debug + 'static,
    State: 'static,
    FrameMeta: 'static,
{
    type DeviceClass = ();
}
180
/// The generic fake bindings context from `netstack3_base` always supplies
/// `crate::testutil::NoOpSocketOpsFilter`.
#[cfg(any(test, feature = "testutils"))]
impl<TimerId, Event, State, FrameMeta, D> SocketOpsFilterBindingContext<D>
    for netstack3_base::testutil::FakeBindingsCtx<TimerId, Event, State, FrameMeta>
where
    TimerId: Debug + PartialEq + Clone + Send + Sync + 'static,
    Event: Debug + 'static,
    State: 'static,
    FrameMeta: 'static,
    D: StrongDeviceIdentifier,
{
    fn socket_ops_filter(&self) -> impl SocketOpsFilter<D> {
        crate::testutil::NoOpSocketOpsFilter
    }
}
195
#[cfg(test)]
pub(crate) mod testutil {
    use alloc::sync::{Arc, Weak};
    use alloc::vec::Vec;
    use core::hash::{Hash, Hasher};
    use core::ops::Deref;
    use core::time::Duration;

    use derivative::Derivative;
    use net_types::ip::{AddrSubnet, GenericOverIp, Ip};
    use netstack3_base::testutil::{
        FakeAtomicInstant, FakeCryptoRng, FakeInstant, FakeTimerCtx, FakeWeakDeviceId,
        WithFakeTimerContext,
    };
    use netstack3_base::{
        AnyDevice, AssignedAddrIpExt, DeviceIdContext, InspectableValue, InstantContext,
        IntoCoreTimerCtx, IpAddressId, WeakIpAddressId,
    };
    use netstack3_hashmap::HashMap;

    use super::*;
    use crate::conntrack;
    use crate::logic::nat::NatConfig;
    use crate::logic::FilterTimerId;
    use crate::matchers::testutil::FakeDeviceId;
    use crate::state::validation::ValidRoutines;
    use crate::state::{IpRoutines, NatRoutines, OneWayBoolean, Routines};

    /// The IP-version requirements shared by the fake contexts in this module.
    pub trait TestIpExt: FilterIpExt + AssignedAddrIpExt {}

    impl<I: FilterIpExt + AssignedAddrIpExt> TestIpExt for I {}

    /// The primary address of a fake device.
    ///
    /// Holds a strong reference to the address; weak IDs derived from it
    /// report themselves as assigned only while a strong reference is alive.
    #[derive(Debug)]
    pub struct FakePrimaryAddressId<I: AssignedAddrIpExt>(
        pub Arc<AddrSubnet<I::Addr, I::AssignedWitness>>,
    );

    /// A strong fake address ID.
    ///
    /// The derived `PartialEq`/`Hash` compare the referenced address and
    /// subnet by value rather than by allocation.
    #[derive(Clone, Debug, Hash, Eq, PartialEq)]
    pub struct FakeAddressId<I: AssignedAddrIpExt>(Arc<AddrSubnet<I::Addr, I::AssignedWitness>>);

    /// A weak fake address ID.
    ///
    /// Two weak IDs are equal, and hash identically, iff they point to the
    /// same allocation.
    #[derive(Clone, Debug)]
    pub struct FakeWeakAddressId<I: AssignedAddrIpExt>(
        pub Weak<AddrSubnet<I::Addr, I::AssignedWitness>>,
    );

    impl<I: AssignedAddrIpExt> PartialEq for FakeWeakAddressId<I> {
        fn eq(&self, other: &Self) -> bool {
            let Self(lhs) = self;
            let Self(rhs) = other;
            Weak::ptr_eq(lhs, rhs)
        }
    }

    impl<I: AssignedAddrIpExt> Eq for FakeWeakAddressId<I> {}

    impl<I: AssignedAddrIpExt> Hash for FakeWeakAddressId<I> {
        fn hash<H: Hasher>(&self, state: &mut H) {
            // Hash the pointer so hashing agrees with the pointer-identity
            // `PartialEq` implementation above.
            let Self(this) = self;
            this.as_ptr().hash(state)
        }
    }

    impl<I: AssignedAddrIpExt> WeakIpAddressId<I::Addr> for FakeWeakAddressId<I> {
        type Strong = FakeAddressId<I>;

        fn upgrade(&self) -> Option<Self::Strong> {
            let Self(inner) = self;
            inner.upgrade().map(FakeAddressId)
        }

        fn is_assigned(&self) -> bool {
            // The address counts as assigned while any strong reference (e.g.
            // a `FakePrimaryAddressId`) keeps it alive.
            let Self(inner) = self;
            inner.strong_count() != 0
        }
    }

    impl<I: AssignedAddrIpExt> InspectableValue for FakeWeakAddressId<I> {
        fn record<Inspector: netstack3_base::Inspector>(
            &self,
            _name: &str,
            _inspector: &mut Inspector,
        ) {
            // Recording to inspect is not supported by this fake.
            unimplemented!()
        }
    }

    impl<I: AssignedAddrIpExt> Deref for FakeAddressId<I> {
        type Target = AddrSubnet<I::Addr, I::AssignedWitness>;

        fn deref(&self) -> &Self::Target {
            let Self(inner) = self;
            inner.deref()
        }
    }

    impl<I: AssignedAddrIpExt> IpAddressId<I::Addr> for FakeAddressId<I> {
        type Weak = FakeWeakAddressId<I>;

        fn downgrade(&self) -> Self::Weak {
            let Self(inner) = self;
            FakeWeakAddressId(Arc::downgrade(inner))
        }

        fn addr(&self) -> IpDeviceAddr<I::Addr> {
            let Self(inner) = self;

            // `IpDeviceAddr` has a different constructor per IP version, so
            // dispatch on `I` with `map_ip`.
            #[derive(GenericOverIp)]
            #[generic_over_ip(I, Ip)]
            struct WrapIn<I: AssignedAddrIpExt>(I::AssignedWitness);
            I::map_ip(
                WrapIn(inner.addr()),
                |WrapIn(v4_addr)| IpDeviceAddr::new_from_witness(v4_addr),
                |WrapIn(v6_addr)| IpDeviceAddr::new_from_ipv6_device_addr(v6_addr),
            )
        }

        fn addr_sub(&self) -> AddrSubnet<I::Addr, I::AssignedWitness> {
            let Self(inner) = self;
            **inner
        }
    }

    /// The device class type used by [`FakeBindingsCtx`].
    #[derive(Clone, Copy, Debug, PartialOrd, Ord, PartialEq, Eq, Hash)]
    pub enum FakeDeviceClass {
        Ethernet,
        Wlan,
    }

    /// A fake core context holding IP filtering state and a [`FakeNatCtx`].
    pub struct FakeCtx<I: TestIpExt> {
        state: State<I, FakeWeakAddressId<I>, FakeBindingsCtx<I>>,
        nat: FakeNatCtx<I>,
    }

    /// A fake NAT context that maps each device to a single primary address.
    #[derive(Derivative)]
    #[derivative(Default(bound = ""))]
    pub struct FakeNatCtx<I: TestIpExt> {
        pub(crate) device_addrs: HashMap<FakeDeviceId, FakePrimaryAddressId<I>>,
    }

    impl<I: TestIpExt> FakeCtx<I> {
        /// Creates a context with no routines installed and NAT not installed.
        pub fn new(bindings_ctx: &mut FakeBindingsCtx<I>) -> Self {
            Self {
                state: State {
                    installed_routines: ValidRoutines::default(),
                    uninstalled_routines: Vec::default(),
                    conntrack: conntrack::Table::new::<IntoCoreTimerCtx>(bindings_ctx),
                    nat_installed: OneWayBoolean::default(),
                },
                nat: FakeNatCtx::default(),
            }
        }

        /// Creates a context with the provided IP routines installed.
        ///
        /// # Panics
        ///
        /// Panics if `ValidRoutines::new` rejects `routines`.
        pub fn with_ip_routines(
            bindings_ctx: &mut FakeBindingsCtx<I>,
            routines: IpRoutines<I, FakeDeviceClass, ()>,
        ) -> Self {
            let (installed_routines, uninstalled_routines) =
                ValidRoutines::new(Routines { ip: routines, ..Default::default() })
                    .expect("invalid state");
            Self {
                state: State {
                    installed_routines,
                    uninstalled_routines,
                    conntrack: conntrack::Table::new::<IntoCoreTimerCtx>(bindings_ctx),
                    nat_installed: OneWayBoolean::default(),
                },
                nat: FakeNatCtx::default(),
            }
        }

        /// Creates a context with the provided NAT routines installed and NAT
        /// marked as installed.
        ///
        /// Each device in `device_addrs` gets the paired address as its
        /// primary address.
        ///
        /// # Panics
        ///
        /// Panics if `ValidRoutines::new` rejects `routines`.
        pub fn with_nat_routines_and_device_addrs(
            bindings_ctx: &mut FakeBindingsCtx<I>,
            routines: NatRoutines<I, FakeDeviceClass, ()>,
            device_addrs: impl IntoIterator<
                Item = (FakeDeviceId, AddrSubnet<I::Addr, I::AssignedWitness>),
            >,
        ) -> Self {
            let (installed_routines, uninstalled_routines) =
                ValidRoutines::new(Routines { nat: routines, ..Default::default() })
                    .expect("invalid state");
            Self {
                state: State {
                    installed_routines,
                    uninstalled_routines,
                    conntrack: conntrack::Table::new::<IntoCoreTimerCtx>(bindings_ctx),
                    nat_installed: OneWayBoolean::TRUE,
                },
                // Reuse `FakeNatCtx::new` instead of duplicating its
                // address-wrapping logic.
                nat: FakeNatCtx::new(device_addrs),
            }
        }

        /// Returns the connection tracking table.
        ///
        /// This is read-only access, so a shared borrow suffices.
        pub fn conntrack(
            &self,
        ) -> &conntrack::Table<I, NatConfig<I, FakeWeakAddressId<I>>, FakeBindingsCtx<I>> {
            &self.state.conntrack
        }
    }

    impl<I: TestIpExt> DeviceIdContext<AnyDevice> for FakeCtx<I> {
        type DeviceId = FakeDeviceId;
        type WeakDeviceId = FakeWeakDeviceId<FakeDeviceId>;
    }

    impl<I: TestIpExt> IpDeviceAddressIdContext<I> for FakeCtx<I> {
        type AddressId = FakeAddressId<I>;
        type WeakAddressId = FakeWeakAddressId<I>;
    }

    impl<I: TestIpExt> FilterIpContext<I, FakeBindingsCtx<I>> for FakeCtx<I> {
        type NatCtx<'a> = FakeNatCtx<I>;

        fn with_filter_state_and_nat_ctx<
            O,
            F: FnOnce(&State<I, FakeWeakAddressId<I>, FakeBindingsCtx<I>>, &mut Self::NatCtx<'_>) -> O,
        >(
            &mut self,
            cb: F,
        ) -> O {
            // Split the borrow so `cb` can access the state and the NAT
            // context at the same time.
            let Self { state, nat } = self;
            cb(state, nat)
        }
    }

    impl<I: TestIpExt> FakeNatCtx<I> {
        /// Creates a NAT context where each device in `device_addrs` has the
        /// paired address as its primary address.
        pub fn new(
            device_addrs: impl IntoIterator<
                Item = (FakeDeviceId, AddrSubnet<I::Addr, I::AssignedWitness>),
            >,
        ) -> Self {
            Self {
                device_addrs: device_addrs
                    .into_iter()
                    .map(|(device, addr)| (device, FakePrimaryAddressId(Arc::new(addr))))
                    .collect(),
            }
        }

        /// Returns a new strong ID for `device_id`'s primary address.
        ///
        /// Returns `None` if the device has no address.
        fn primary_address_id(&self, device_id: &FakeDeviceId) -> Option<FakeAddressId<I>> {
            let FakePrimaryAddressId(primary) = self.device_addrs.get(device_id)?;
            Some(FakeAddressId(Arc::clone(primary)))
        }
    }

    impl<I: TestIpExt> DeviceIdContext<AnyDevice> for FakeNatCtx<I> {
        type DeviceId = FakeDeviceId;
        type WeakDeviceId = FakeWeakDeviceId<FakeDeviceId>;
    }

    impl<I: TestIpExt> IpDeviceAddressIdContext<I> for FakeNatCtx<I> {
        type AddressId = FakeAddressId<I>;
        type WeakAddressId = FakeWeakAddressId<I>;
    }

    impl<I: TestIpExt> NatContext<I, FakeBindingsCtx<I>> for FakeNatCtx<I> {
        fn get_local_addr_for_remote(
            &mut self,
            device_id: &Self::DeviceId,
            _remote: Option<SpecifiedAddr<I::Addr>>,
        ) -> Option<Self::AddressId> {
            // The fake ignores `remote` and always selects the primary address.
            self.primary_address_id(device_id)
        }

        fn get_address_id(
            &mut self,
            device_id: &Self::DeviceId,
            addr: IpDeviceAddr<I::Addr>,
        ) -> Option<Self::AddressId> {
            // The primary address is the only address the fake knows about.
            self.primary_address_id(device_id).filter(|id| id.addr() == addr)
        }
    }

    /// Fake bindings context backed by a fake timer context and a fake RNG.
    pub struct FakeBindingsCtx<I: Ip> {
        pub timer_ctx: FakeTimerCtx<FilterTimerId<I>>,
        pub rng: FakeCryptoRng,
    }

    impl<I: Ip> FakeBindingsCtx<I> {
        /// Creates a bindings context with default fake timer and RNG state.
        pub(crate) fn new() -> Self {
            Self { timer_ctx: FakeTimerCtx::default(), rng: FakeCryptoRng::default() }
        }

        /// Advances the timer context's fake instant by `time_elapsed`.
        pub(crate) fn sleep(&mut self, time_elapsed: Duration) {
            self.timer_ctx.instant.sleep(time_elapsed)
        }
    }

    impl<I: Ip> InstantBindingsTypes for FakeBindingsCtx<I> {
        type Instant = FakeInstant;
        type AtomicInstant = FakeAtomicInstant;
    }

    impl<I: Ip> FilterBindingsTypes for FakeBindingsCtx<I> {
        type DeviceClass = FakeDeviceClass;
    }

    impl<I: Ip> InstantContext for FakeBindingsCtx<I> {
        fn now(&self) -> Self::Instant {
            self.timer_ctx.now()
        }
    }

    impl<I: Ip> TimerBindingsTypes for FakeBindingsCtx<I> {
        type Timer = <FakeTimerCtx<FilterTimerId<I>> as TimerBindingsTypes>::Timer;
        type DispatchId = <FakeTimerCtx<FilterTimerId<I>> as TimerBindingsTypes>::DispatchId;
        type UniqueTimerId = <FakeTimerCtx<FilterTimerId<I>> as TimerBindingsTypes>::UniqueTimerId;
    }

    // All timer operations are delegated to the inner fake timer context.
    impl<I: Ip> TimerContext for FakeBindingsCtx<I> {
        fn new_timer(&mut self, id: Self::DispatchId) -> Self::Timer {
            self.timer_ctx.new_timer(id)
        }

        fn schedule_timer_instant(
            &mut self,
            time: Self::Instant,
            timer: &mut Self::Timer,
        ) -> Option<Self::Instant> {
            self.timer_ctx.schedule_timer_instant(time, timer)
        }

        fn cancel_timer(&mut self, timer: &mut Self::Timer) -> Option<Self::Instant> {
            self.timer_ctx.cancel_timer(timer)
        }

        fn scheduled_instant(&self, timer: &mut Self::Timer) -> Option<Self::Instant> {
            self.timer_ctx.scheduled_instant(timer)
        }

        fn unique_timer_id(&self, timer: &Self::Timer) -> Self::UniqueTimerId {
            self.timer_ctx.unique_timer_id(timer)
        }
    }

    impl<I: Ip> WithFakeTimerContext<FilterTimerId<I>> for FakeBindingsCtx<I> {
        fn with_fake_timer_ctx<O, F: FnOnce(&FakeTimerCtx<FilterTimerId<I>>) -> O>(
            &self,
            f: F,
        ) -> O {
            f(&self.timer_ctx)
        }

        fn with_fake_timer_ctx_mut<O, F: FnOnce(&mut FakeTimerCtx<FilterTimerId<I>>) -> O>(
            &mut self,
            f: F,
        ) -> O {
            f(&mut self.timer_ctx)
        }
    }

    impl<I: Ip> RngContext for FakeBindingsCtx<I> {
        type Rng<'a>
            = FakeCryptoRng
        where
            Self: 'a;

        fn rng(&mut self) -> Self::Rng<'_> {
            // Hands out a clone of the context's fake RNG.
            self.rng.clone()
        }
    }
}