1use core::fmt::Debug;
6
7use net_types::ip::{Ipv4, Ipv6};
8use net_types::SpecifiedAddr;
9use netstack3_base::{
10 InstantBindingsTypes, IpDeviceAddr, IpDeviceAddressIdContext, Marks, RngContext,
11 StrongDeviceIdentifier, TimerBindingsTypes, TimerContext, TxMetadataBindingsTypes,
12};
13
14use packet_formats::ip::IpExt;
15
16use crate::matchers::InterfaceProperties;
17use crate::state::State;
18use crate::{FilterIpExt, IpPacket};
19
20pub trait FilterBindingsTypes: InstantBindingsTypes + TimerBindingsTypes + 'static {
26 type DeviceClass: Clone + Debug;
28}
29
30pub trait FilterBindingsContext: TimerContext + RngContext + FilterBindingsTypes {}
32impl<BC: TimerContext + RngContext + FilterBindingsTypes> FilterBindingsContext for BC {}
33
34pub trait FilterIpContext<I: FilterIpExt, BT: FilterBindingsTypes>:
43 IpDeviceAddressIdContext<I, DeviceId: InterfaceProperties<BT::DeviceClass>>
44{
45 type NatCtx<'a>: NatContext<
48 I,
49 BT,
50 DeviceId = Self::DeviceId,
51 WeakAddressId = Self::WeakAddressId,
52 >;
53
54 fn with_filter_state<O, F: FnOnce(&State<I, Self::WeakAddressId, BT>) -> O>(
56 &mut self,
57 cb: F,
58 ) -> O {
59 self.with_filter_state_and_nat_ctx(|state, _ctx| cb(state))
60 }
61
62 fn with_filter_state_and_nat_ctx<
65 O,
66 F: FnOnce(&State<I, Self::WeakAddressId, BT>, &mut Self::NatCtx<'_>) -> O,
67 >(
68 &mut self,
69 cb: F,
70 ) -> O;
71}
72
73pub trait NatContext<I: IpExt, BT: FilterBindingsTypes>:
75 IpDeviceAddressIdContext<I, DeviceId: InterfaceProperties<BT::DeviceClass>>
76{
77 fn get_local_addr_for_remote(
79 &mut self,
80 device_id: &Self::DeviceId,
81 remote: Option<SpecifiedAddr<I::Addr>>,
82 ) -> Option<Self::AddressId>;
83
84 fn get_address_id(
87 &mut self,
88 device_id: &Self::DeviceId,
89 addr: IpDeviceAddr<I::Addr>,
90 ) -> Option<Self::AddressId>;
91}
92
93pub trait FilterContext<BT: FilterBindingsTypes>:
96 IpDeviceAddressIdContext<Ipv4, DeviceId: InterfaceProperties<BT::DeviceClass>>
97 + IpDeviceAddressIdContext<Ipv6, DeviceId: InterfaceProperties<BT::DeviceClass>>
98{
99 fn with_all_filter_state_mut<
101 O,
102 F: FnOnce(
103 &mut State<Ipv4, <Self as IpDeviceAddressIdContext<Ipv4>>::WeakAddressId, BT>,
104 &mut State<Ipv6, <Self as IpDeviceAddressIdContext<Ipv6>>::WeakAddressId, BT>,
105 ) -> O,
106 >(
107 &mut self,
108 cb: F,
109 ) -> O;
110}
111
/// The verdict returned by [`SocketOpsFilter::on_egress`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SocketEgressFilterResult {
    /// Let the packet through.
    Pass {
        /// Congestion flag reported by the filter.
        ///
        /// NOTE(review): the exact meaning is defined by the filter
        /// implementation. Confirm it against the implementors.
        congestion: bool,
    },

    /// Discard the packet.
    Drop {
        /// Congestion flag reported by the filter (same caveat as on `Pass`).
        congestion: bool,
    },
}
127
128pub trait SocketOpsFilter<D: StrongDeviceIdentifier, T> {
130 fn on_egress<I: FilterIpExt, P: IpPacket<I>>(
132 &self,
133 packet: &P,
134 device: &D,
135 tx_metadata: &T,
136 marks: &Marks,
137 ) -> SocketEgressFilterResult;
138}
139
140pub trait SocketOpsFilterBindingContext<D: StrongDeviceIdentifier>:
142 TxMetadataBindingsTypes
143{
144 fn socket_ops_filter(&self) -> impl SocketOpsFilter<D, Self::TxMetadata>;
146}
147
// Lets the generic fake bindings context from `netstack3_base` be used with the
// filtering engine. It has no device classes, so the unit type stands in.
#[cfg(feature = "testutils")]
impl<TimerId, Event, State, FrameMeta> FilterBindingsTypes
    for netstack3_base::testutil::FakeBindingsCtx<TimerId, Event, State, FrameMeta>
where
    TimerId: Debug + PartialEq + Clone + Send + Sync + 'static,
    Event: Debug + 'static,
    State: 'static,
    FrameMeta: 'static,
{
    type DeviceClass = ();
}
159
160#[cfg(test)]
161pub(crate) mod testutil {
162 use alloc::collections::HashMap;
163 use alloc::sync::{Arc, Weak};
164 use alloc::vec::Vec;
165 use core::hash::{Hash, Hasher};
166 use core::ops::Deref;
167 use core::time::Duration;
168
169 use derivative::Derivative;
170 use net_types::ip::{AddrSubnet, GenericOverIp, Ip};
171 use netstack3_base::testutil::{
172 FakeAtomicInstant, FakeCryptoRng, FakeInstant, FakeTimerCtx, FakeWeakDeviceId,
173 WithFakeTimerContext,
174 };
175 use netstack3_base::{
176 AnyDevice, AssignedAddrIpExt, DeviceIdContext, InspectableValue, InstantContext,
177 IntoCoreTimerCtx, IpAddressId, WeakIpAddressId,
178 };
179
180 use super::*;
181 use crate::conntrack;
182 use crate::logic::nat::NatConfig;
183 use crate::logic::FilterTimerId;
184 use crate::matchers::testutil::FakeDeviceId;
185 use crate::state::validation::ValidRoutines;
186 use crate::state::{IpRoutines, NatRoutines, OneWayBoolean, Routines};
187
188 pub trait TestIpExt: FilterIpExt + AssignedAddrIpExt {}
189
190 impl<I: FilterIpExt + AssignedAddrIpExt> TestIpExt for I {}
191
192 #[derive(Debug)]
193 pub struct FakePrimaryAddressId<I: AssignedAddrIpExt>(
194 pub Arc<AddrSubnet<I::Addr, I::AssignedWitness>>,
195 );
196
197 #[derive(Clone, Debug, Hash, Eq, PartialEq)]
198 pub struct FakeAddressId<I: AssignedAddrIpExt>(Arc<AddrSubnet<I::Addr, I::AssignedWitness>>);
199
200 #[derive(Clone, Debug)]
201 pub struct FakeWeakAddressId<I: AssignedAddrIpExt>(
202 pub Weak<AddrSubnet<I::Addr, I::AssignedWitness>>,
203 );
204
205 impl<I: AssignedAddrIpExt> PartialEq for FakeWeakAddressId<I> {
206 fn eq(&self, other: &Self) -> bool {
207 let Self(lhs) = self;
208 let Self(rhs) = other;
209 Weak::ptr_eq(lhs, rhs)
210 }
211 }
212
213 impl<I: AssignedAddrIpExt> Eq for FakeWeakAddressId<I> {}
214
215 impl<I: AssignedAddrIpExt> Hash for FakeWeakAddressId<I> {
216 fn hash<H: Hasher>(&self, state: &mut H) {
217 let Self(this) = self;
218 this.as_ptr().hash(state)
219 }
220 }
221
222 impl<I: AssignedAddrIpExt> WeakIpAddressId<I::Addr> for FakeWeakAddressId<I> {
223 type Strong = FakeAddressId<I>;
224
225 fn upgrade(&self) -> Option<Self::Strong> {
226 let Self(inner) = self;
227 inner.upgrade().map(FakeAddressId)
228 }
229
230 fn is_assigned(&self) -> bool {
231 let Self(inner) = self;
232 inner.strong_count() != 0
233 }
234 }
235
236 impl<I: AssignedAddrIpExt> InspectableValue for FakeWeakAddressId<I> {
237 fn record<Inspector: netstack3_base::Inspector>(
238 &self,
239 _name: &str,
240 _inspector: &mut Inspector,
241 ) {
242 unimplemented!()
243 }
244 }
245
246 impl<I: AssignedAddrIpExt> Deref for FakeAddressId<I> {
247 type Target = AddrSubnet<I::Addr, I::AssignedWitness>;
248
249 fn deref(&self) -> &Self::Target {
250 let Self(inner) = self;
251 inner.deref()
252 }
253 }
254
255 impl<I: AssignedAddrIpExt> IpAddressId<I::Addr> for FakeAddressId<I> {
256 type Weak = FakeWeakAddressId<I>;
257
258 fn downgrade(&self) -> Self::Weak {
259 let Self(inner) = self;
260 FakeWeakAddressId(Arc::downgrade(inner))
261 }
262
263 fn addr(&self) -> IpDeviceAddr<I::Addr> {
264 let Self(inner) = self;
265
266 #[derive(GenericOverIp)]
267 #[generic_over_ip(I, Ip)]
268 struct WrapIn<I: AssignedAddrIpExt>(I::AssignedWitness);
269 I::map_ip(
270 WrapIn(inner.addr()),
271 |WrapIn(v4_addr)| IpDeviceAddr::new_from_witness(v4_addr),
272 |WrapIn(v6_addr)| IpDeviceAddr::new_from_ipv6_device_addr(v6_addr),
273 )
274 }
275
276 fn addr_sub(&self) -> AddrSubnet<I::Addr, I::AssignedWitness> {
277 let Self(inner) = self;
278 **inner
279 }
280 }
281
282 #[derive(Clone, Copy, Debug, PartialOrd, Ord, PartialEq, Eq, Hash)]
283 pub enum FakeDeviceClass {
284 Ethernet,
285 Wlan,
286 }
287
288 pub struct FakeCtx<I: TestIpExt> {
289 state: State<I, FakeWeakAddressId<I>, FakeBindingsCtx<I>>,
290 nat: FakeNatCtx<I>,
291 }
292
293 #[derive(Derivative)]
294 #[derivative(Default(bound = ""))]
295 pub struct FakeNatCtx<I: TestIpExt> {
296 pub(crate) device_addrs: HashMap<FakeDeviceId, FakePrimaryAddressId<I>>,
297 }
298
299 impl<I: TestIpExt> FakeCtx<I> {
300 pub fn new(bindings_ctx: &mut FakeBindingsCtx<I>) -> Self {
301 Self {
302 state: State {
303 installed_routines: ValidRoutines::default(),
304 uninstalled_routines: Vec::default(),
305 conntrack: conntrack::Table::new::<IntoCoreTimerCtx>(bindings_ctx),
306 nat_installed: OneWayBoolean::default(),
307 },
308 nat: FakeNatCtx::default(),
309 }
310 }
311
312 pub fn with_ip_routines(
313 bindings_ctx: &mut FakeBindingsCtx<I>,
314 routines: IpRoutines<I, FakeDeviceClass, ()>,
315 ) -> Self {
316 let (installed_routines, uninstalled_routines) =
317 ValidRoutines::new(Routines { ip: routines, ..Default::default() })
318 .expect("invalid state");
319 Self {
320 state: State {
321 installed_routines,
322 uninstalled_routines,
323 conntrack: conntrack::Table::new::<IntoCoreTimerCtx>(bindings_ctx),
324 nat_installed: OneWayBoolean::default(),
325 },
326 nat: FakeNatCtx::default(),
327 }
328 }
329
330 pub fn with_nat_routines_and_device_addrs(
331 bindings_ctx: &mut FakeBindingsCtx<I>,
332 routines: NatRoutines<I, FakeDeviceClass, ()>,
333 device_addrs: impl IntoIterator<
334 Item = (FakeDeviceId, AddrSubnet<I::Addr, I::AssignedWitness>),
335 >,
336 ) -> Self {
337 let (installed_routines, uninstalled_routines) =
338 ValidRoutines::new(Routines { nat: routines, ..Default::default() })
339 .expect("invalid state");
340 Self {
341 state: State {
342 installed_routines,
343 uninstalled_routines,
344 conntrack: conntrack::Table::new::<IntoCoreTimerCtx>(bindings_ctx),
345 nat_installed: OneWayBoolean::TRUE,
346 },
347 nat: FakeNatCtx {
348 device_addrs: device_addrs
349 .into_iter()
350 .map(|(device, addr)| (device, FakePrimaryAddressId(Arc::new(addr))))
351 .collect(),
352 },
353 }
354 }
355
356 pub fn conntrack(
357 &mut self,
358 ) -> &conntrack::Table<I, NatConfig<I, FakeWeakAddressId<I>>, FakeBindingsCtx<I>> {
359 &self.state.conntrack
360 }
361 }
362
363 impl<I: TestIpExt> DeviceIdContext<AnyDevice> for FakeCtx<I> {
364 type DeviceId = FakeDeviceId;
365 type WeakDeviceId = FakeWeakDeviceId<FakeDeviceId>;
366 }
367
368 impl<I: TestIpExt> IpDeviceAddressIdContext<I> for FakeCtx<I> {
369 type AddressId = FakeAddressId<I>;
370 type WeakAddressId = FakeWeakAddressId<I>;
371 }
372
373 impl<I: TestIpExt> FilterIpContext<I, FakeBindingsCtx<I>> for FakeCtx<I> {
374 type NatCtx<'a> = FakeNatCtx<I>;
375
376 fn with_filter_state_and_nat_ctx<
377 O,
378 F: FnOnce(&State<I, FakeWeakAddressId<I>, FakeBindingsCtx<I>>, &mut Self::NatCtx<'_>) -> O,
379 >(
380 &mut self,
381 cb: F,
382 ) -> O {
383 let Self { state, nat } = self;
384 cb(state, nat)
385 }
386 }
387
388 impl<I: TestIpExt> FakeNatCtx<I> {
389 pub fn new(
390 device_addrs: impl IntoIterator<
391 Item = (FakeDeviceId, AddrSubnet<I::Addr, I::AssignedWitness>),
392 >,
393 ) -> Self {
394 Self {
395 device_addrs: device_addrs
396 .into_iter()
397 .map(|(device, addr)| (device, FakePrimaryAddressId(Arc::new(addr))))
398 .collect(),
399 }
400 }
401 }
402
403 impl<I: TestIpExt> DeviceIdContext<AnyDevice> for FakeNatCtx<I> {
404 type DeviceId = FakeDeviceId;
405 type WeakDeviceId = FakeWeakDeviceId<FakeDeviceId>;
406 }
407
408 impl<I: TestIpExt> IpDeviceAddressIdContext<I> for FakeNatCtx<I> {
409 type AddressId = FakeAddressId<I>;
410 type WeakAddressId = FakeWeakAddressId<I>;
411 }
412
413 impl<I: TestIpExt> NatContext<I, FakeBindingsCtx<I>> for FakeNatCtx<I> {
414 fn get_local_addr_for_remote(
415 &mut self,
416 device_id: &Self::DeviceId,
417 _remote: Option<SpecifiedAddr<I::Addr>>,
418 ) -> Option<Self::AddressId> {
419 let FakePrimaryAddressId(primary) = self.device_addrs.get(device_id)?;
420 Some(FakeAddressId(primary.clone()))
421 }
422
423 fn get_address_id(
424 &mut self,
425 device_id: &Self::DeviceId,
426 addr: IpDeviceAddr<I::Addr>,
427 ) -> Option<Self::AddressId> {
428 let FakePrimaryAddressId(id) = self.device_addrs.get(device_id)?;
429 let id = FakeAddressId(id.clone());
430 if id.addr() == addr {
431 Some(id)
432 } else {
433 None
434 }
435 }
436 }
437
438 pub struct FakeBindingsCtx<I: Ip> {
439 pub timer_ctx: FakeTimerCtx<FilterTimerId<I>>,
440 pub rng: FakeCryptoRng,
441 }
442
443 impl<I: Ip> FakeBindingsCtx<I> {
444 pub(crate) fn new() -> Self {
445 Self { timer_ctx: FakeTimerCtx::default(), rng: FakeCryptoRng::default() }
446 }
447
448 pub(crate) fn sleep(&mut self, time_elapsed: Duration) {
449 self.timer_ctx.instant.sleep(time_elapsed)
450 }
451 }
452
453 impl<I: Ip> InstantBindingsTypes for FakeBindingsCtx<I> {
454 type Instant = FakeInstant;
455 type AtomicInstant = FakeAtomicInstant;
456 }
457
458 impl<I: Ip> FilterBindingsTypes for FakeBindingsCtx<I> {
459 type DeviceClass = FakeDeviceClass;
460 }
461
462 impl<I: Ip> InstantContext for FakeBindingsCtx<I> {
463 fn now(&self) -> Self::Instant {
464 self.timer_ctx.now()
465 }
466 }
467
468 impl<I: Ip> TimerBindingsTypes for FakeBindingsCtx<I> {
469 type Timer = <FakeTimerCtx<FilterTimerId<I>> as TimerBindingsTypes>::Timer;
470 type DispatchId = <FakeTimerCtx<FilterTimerId<I>> as TimerBindingsTypes>::DispatchId;
471 type UniqueTimerId = <FakeTimerCtx<FilterTimerId<I>> as TimerBindingsTypes>::UniqueTimerId;
472 }
473
474 impl<I: Ip> TimerContext for FakeBindingsCtx<I> {
475 fn new_timer(&mut self, id: Self::DispatchId) -> Self::Timer {
476 self.timer_ctx.new_timer(id)
477 }
478
479 fn schedule_timer_instant(
480 &mut self,
481 time: Self::Instant,
482 timer: &mut Self::Timer,
483 ) -> Option<Self::Instant> {
484 self.timer_ctx.schedule_timer_instant(time, timer)
485 }
486
487 fn cancel_timer(&mut self, timer: &mut Self::Timer) -> Option<Self::Instant> {
488 self.timer_ctx.cancel_timer(timer)
489 }
490
491 fn scheduled_instant(&self, timer: &mut Self::Timer) -> Option<Self::Instant> {
492 self.timer_ctx.scheduled_instant(timer)
493 }
494
495 fn unique_timer_id(&self, timer: &Self::Timer) -> Self::UniqueTimerId {
496 self.timer_ctx.unique_timer_id(timer)
497 }
498 }
499
500 impl<I: Ip> WithFakeTimerContext<FilterTimerId<I>> for FakeBindingsCtx<I> {
501 fn with_fake_timer_ctx<O, F: FnOnce(&FakeTimerCtx<FilterTimerId<I>>) -> O>(
502 &self,
503 f: F,
504 ) -> O {
505 f(&self.timer_ctx)
506 }
507
508 fn with_fake_timer_ctx_mut<O, F: FnOnce(&mut FakeTimerCtx<FilterTimerId<I>>) -> O>(
509 &mut self,
510 f: F,
511 ) -> O {
512 f(&mut self.timer_ctx)
513 }
514 }
515
516 impl<I: Ip> RngContext for FakeBindingsCtx<I> {
517 type Rng<'a>
518 = FakeCryptoRng
519 where
520 Self: 'a;
521
522 fn rng(&mut self) -> Self::Rng<'_> {
523 self.rng.clone()
524 }
525 }
526}