1use core::fmt::Debug;
6
7use net_types::ip::{Ipv4, Ipv6};
8use net_types::SpecifiedAddr;
9use netstack3_base::socket::SocketCookie;
10use netstack3_base::{
11 InstantBindingsTypes, IpDeviceAddr, IpDeviceAddressIdContext, Marks, RngContext,
12 StrongDeviceIdentifier, TimerBindingsTypes, TimerContext, TxMetadataBindingsTypes,
13};
14use packet::{FragmentedByteSlice, PartialSerializer};
15use packet_formats::ip::IpExt;
16
17use crate::matchers::InterfaceProperties;
18use crate::state::State;
19use crate::{FilterIpExt, IpPacket};
20
21pub trait FilterBindingsTypes: InstantBindingsTypes + TimerBindingsTypes + 'static {
27 type DeviceClass: Clone + Debug;
29}
30
31pub trait FilterBindingsContext: TimerContext + RngContext + FilterBindingsTypes {}
33impl<BC: TimerContext + RngContext + FilterBindingsTypes> FilterBindingsContext for BC {}
34
35pub trait FilterIpContext<I: FilterIpExt, BT: FilterBindingsTypes>:
44 IpDeviceAddressIdContext<I, DeviceId: InterfaceProperties<BT::DeviceClass>>
45{
46 type NatCtx<'a>: NatContext<
49 I,
50 BT,
51 DeviceId = Self::DeviceId,
52 WeakAddressId = Self::WeakAddressId,
53 >;
54
55 fn with_filter_state<O, F: FnOnce(&State<I, Self::WeakAddressId, BT>) -> O>(
57 &mut self,
58 cb: F,
59 ) -> O {
60 self.with_filter_state_and_nat_ctx(|state, _ctx| cb(state))
61 }
62
63 fn with_filter_state_and_nat_ctx<
66 O,
67 F: FnOnce(&State<I, Self::WeakAddressId, BT>, &mut Self::NatCtx<'_>) -> O,
68 >(
69 &mut self,
70 cb: F,
71 ) -> O;
72}
73
74pub trait NatContext<I: IpExt, BT: FilterBindingsTypes>:
76 IpDeviceAddressIdContext<I, DeviceId: InterfaceProperties<BT::DeviceClass>>
77{
78 fn get_local_addr_for_remote(
80 &mut self,
81 device_id: &Self::DeviceId,
82 remote: Option<SpecifiedAddr<I::Addr>>,
83 ) -> Option<Self::AddressId>;
84
85 fn get_address_id(
88 &mut self,
89 device_id: &Self::DeviceId,
90 addr: IpDeviceAddr<I::Addr>,
91 ) -> Option<Self::AddressId>;
92}
93
94pub trait FilterContext<BT: FilterBindingsTypes>:
97 IpDeviceAddressIdContext<Ipv4, DeviceId: InterfaceProperties<BT::DeviceClass>>
98 + IpDeviceAddressIdContext<Ipv6, DeviceId: InterfaceProperties<BT::DeviceClass>>
99{
100 fn with_all_filter_state_mut<
102 O,
103 F: FnOnce(
104 &mut State<Ipv4, <Self as IpDeviceAddressIdContext<Ipv4>>::WeakAddressId, BT>,
105 &mut State<Ipv6, <Self as IpDeviceAddressIdContext<Ipv6>>::WeakAddressId, BT>,
106 ) -> O,
107 >(
108 &mut self,
109 cb: F,
110 ) -> O;
111}
112
/// The verdict returned by [`SocketOpsFilter::on_egress`].
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum SocketEgressFilterResult {
    /// Let the packet through.
    Pass {
        /// Whether the filter reports congestion for this packet.
        congestion: bool,
    },

    /// Discard the packet.
    Drop {
        /// Whether the filter reports congestion for this packet.
        congestion: bool,
    },
}
128
/// The verdict returned by [`SocketOpsFilter::on_ingress`].
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum SocketIngressFilterResult {
    /// Let the packet through.
    Accept,

    /// Discard the packet.
    Drop,
}
138
139pub trait SocketOpsFilter<D: StrongDeviceIdentifier, T> {
141 fn on_egress<I: FilterIpExt, P: IpPacket<I> + PartialSerializer>(
143 &self,
144 packet: &P,
145 device: &D,
146 tx_metadata: &T,
147 marks: &Marks,
148 ) -> SocketEgressFilterResult;
149
150 fn on_ingress(
152 &self,
153 packet: &FragmentedByteSlice<'_, &[u8]>,
154 device: &D,
155 cookie: SocketCookie,
156 marks: &Marks,
157 ) -> SocketIngressFilterResult;
158}
159
160pub trait SocketOpsFilterBindingContext<D: StrongDeviceIdentifier>:
162 TxMetadataBindingsTypes
163{
164 fn socket_ops_filter(&self) -> impl SocketOpsFilter<D, Self::TxMetadata>;
166}
167
// Allows the generic fake bindings context from `netstack3_base` to be used
// where `FilterBindingsTypes` is required. It has no notion of device classes,
// so the unit type stands in for one.
#[cfg(feature = "testutils")]
impl<TimerId, Event, State, FrameMeta> FilterBindingsTypes
    for netstack3_base::testutil::FakeBindingsCtx<TimerId, Event, State, FrameMeta>
where
    TimerId: Debug + PartialEq + Clone + Send + Sync + 'static,
    Event: Debug + 'static,
    State: 'static,
    FrameMeta: 'static,
{
    type DeviceClass = ();
}
179
#[cfg(test)]
pub(crate) mod testutil {
    //! Fake implementations of the filtering contexts, for use in unit tests.

    use alloc::collections::HashMap;
    use alloc::sync::{Arc, Weak};
    use alloc::vec::Vec;
    use core::hash::{Hash, Hasher};
    use core::ops::Deref;
    use core::time::Duration;

    use derivative::Derivative;
    use net_types::ip::{AddrSubnet, GenericOverIp, Ip};
    use netstack3_base::testutil::{
        FakeAtomicInstant, FakeCryptoRng, FakeInstant, FakeTimerCtx, FakeWeakDeviceId,
        WithFakeTimerContext,
    };
    use netstack3_base::{
        AnyDevice, AssignedAddrIpExt, DeviceIdContext, InspectableValue, InstantContext,
        IntoCoreTimerCtx, IpAddressId, WeakIpAddressId,
    };

    use super::*;
    use crate::conntrack;
    use crate::logic::nat::NatConfig;
    use crate::logic::FilterTimerId;
    use crate::matchers::testutil::FakeDeviceId;
    use crate::state::validation::ValidRoutines;
    use crate::state::{IpRoutines, NatRoutines, OneWayBoolean, Routines};

    /// IP extension trait bundling the bounds required by the fakes in this
    /// module.
    pub trait TestIpExt: FilterIpExt + AssignedAddrIpExt {}

    impl<I: FilterIpExt + AssignedAddrIpExt> TestIpExt for I {}

    /// The single (primary) address [`FakeNatCtx`] records for a device.
    #[derive(Debug)]
    pub struct FakePrimaryAddressId<I: AssignedAddrIpExt>(
        pub Arc<AddrSubnet<I::Addr, I::AssignedWitness>>,
    );

    /// Fake strong address ID: a shared handle to an address/subnet pair.
    ///
    /// Equality and hashing delegate to the address/subnet value (via `Arc`'s
    /// impls), not to the allocation's identity.
    #[derive(Clone, Debug, Hash, Eq, PartialEq)]
    pub struct FakeAddressId<I: AssignedAddrIpExt>(Arc<AddrSubnet<I::Addr, I::AssignedWitness>>);

    /// Fake weak address ID, the downgraded form of [`FakeAddressId`].
    ///
    /// Unlike the strong ID, equality and hashing are based on the identity of
    /// the underlying allocation.
    #[derive(Clone, Debug)]
    pub struct FakeWeakAddressId<I: AssignedAddrIpExt>(
        pub Weak<AddrSubnet<I::Addr, I::AssignedWitness>>,
    );

    impl<I: AssignedAddrIpExt> PartialEq for FakeWeakAddressId<I> {
        fn eq(&self, other: &Self) -> bool {
            let Self(lhs) = self;
            let Self(rhs) = other;
            Weak::ptr_eq(lhs, rhs)
        }
    }

    impl<I: AssignedAddrIpExt> Eq for FakeWeakAddressId<I> {}

    impl<I: AssignedAddrIpExt> Hash for FakeWeakAddressId<I> {
        fn hash<H: Hasher>(&self, state: &mut H) {
            // Hash the pointer rather than the value so that hashing agrees with
            // the pointer-identity equality above.
            let Self(this) = self;
            this.as_ptr().hash(state)
        }
    }

    impl<I: AssignedAddrIpExt> WeakIpAddressId<I::Addr> for FakeWeakAddressId<I> {
        type Strong = FakeAddressId<I>;

        fn upgrade(&self) -> Option<Self::Strong> {
            let Self(inner) = self;
            inner.upgrade().map(FakeAddressId)
        }

        fn is_assigned(&self) -> bool {
            // The address counts as assigned for as long as any strong handle
            // to it is alive.
            let Self(inner) = self;
            inner.strong_count() != 0
        }
    }

    impl<I: AssignedAddrIpExt> InspectableValue for FakeWeakAddressId<I> {
        fn record<Inspector: netstack3_base::Inspector>(
            &self,
            _name: &str,
            _inspector: &mut Inspector,
        ) {
            unimplemented!()
        }
    }

    impl<I: AssignedAddrIpExt> Deref for FakeAddressId<I> {
        type Target = AddrSubnet<I::Addr, I::AssignedWitness>;

        fn deref(&self) -> &Self::Target {
            let Self(inner) = self;
            inner.deref()
        }
    }

    impl<I: AssignedAddrIpExt> IpAddressId<I::Addr> for FakeAddressId<I> {
        type Weak = FakeWeakAddressId<I>;

        fn downgrade(&self) -> Self::Weak {
            let Self(inner) = self;
            FakeWeakAddressId(Arc::downgrade(inner))
        }

        fn addr(&self) -> IpDeviceAddr<I::Addr> {
            let Self(inner) = self;

            // The assigned-witness type differs per IP version, so each version
            // needs its own `IpDeviceAddr` constructor.
            #[derive(GenericOverIp)]
            #[generic_over_ip(I, Ip)]
            struct WrapIn<I: AssignedAddrIpExt>(I::AssignedWitness);
            I::map_ip(
                WrapIn(inner.addr()),
                |WrapIn(v4_addr)| IpDeviceAddr::new_from_witness(v4_addr),
                |WrapIn(v6_addr)| IpDeviceAddr::new_from_ipv6_device_addr(v6_addr),
            )
        }

        fn addr_sub(&self) -> AddrSubnet<I::Addr, I::AssignedWitness> {
            let Self(inner) = self;
            **inner
        }
    }

    /// Device classes available to tests.
    #[derive(Clone, Copy, Debug, PartialOrd, Ord, PartialEq, Eq, Hash)]
    pub enum FakeDeviceClass {
        Ethernet,
        Wlan,
    }

    /// Fake core context for a single IP version, holding the filtering state
    /// and a [`FakeNatCtx`].
    pub struct FakeCtx<I: TestIpExt> {
        state: State<I, FakeWeakAddressId<I>, FakeBindingsCtx<I>>,
        nat: FakeNatCtx<I>,
    }

    /// Fake NAT context that maps each device to at most one (primary) address.
    #[derive(Derivative)]
    #[derivative(Default(bound = ""))]
    pub struct FakeNatCtx<I: TestIpExt> {
        pub(crate) device_addrs: HashMap<FakeDeviceId, FakePrimaryAddressId<I>>,
    }

    impl<I: TestIpExt> FakeCtx<I> {
        /// Creates a context with default (empty) installed routines, no
        /// uninstalled routines, NAT not marked as installed, and no device
        /// addresses.
        pub fn new(bindings_ctx: &mut FakeBindingsCtx<I>) -> Self {
            Self {
                state: State {
                    installed_routines: ValidRoutines::default(),
                    uninstalled_routines: Vec::default(),
                    conntrack: conntrack::Table::new::<IntoCoreTimerCtx>(bindings_ctx),
                    nat_installed: OneWayBoolean::default(),
                },
                nat: FakeNatCtx::default(),
            }
        }

        /// Creates a context with `routines` installed as the IP routines and
        /// no device addresses.
        ///
        /// # Panics
        ///
        /// Panics if `routines` fails validation.
        pub fn with_ip_routines(
            bindings_ctx: &mut FakeBindingsCtx<I>,
            routines: IpRoutines<I, FakeDeviceClass, ()>,
        ) -> Self {
            let (installed_routines, uninstalled_routines) =
                ValidRoutines::new(Routines { ip: routines, ..Default::default() })
                    .expect("invalid state");
            Self {
                state: State {
                    installed_routines,
                    uninstalled_routines,
                    conntrack: conntrack::Table::new::<IntoCoreTimerCtx>(bindings_ctx),
                    nat_installed: OneWayBoolean::default(),
                },
                nat: FakeNatCtx::default(),
            }
        }

        /// Creates a context with `routines` installed as the NAT routines, NAT
        /// marked as installed, and the NAT context populated from
        /// `device_addrs`.
        ///
        /// # Panics
        ///
        /// Panics if `routines` fails validation.
        pub fn with_nat_routines_and_device_addrs(
            bindings_ctx: &mut FakeBindingsCtx<I>,
            routines: NatRoutines<I, FakeDeviceClass, ()>,
            device_addrs: impl IntoIterator<
                Item = (FakeDeviceId, AddrSubnet<I::Addr, I::AssignedWitness>),
            >,
        ) -> Self {
            let (installed_routines, uninstalled_routines) =
                ValidRoutines::new(Routines { nat: routines, ..Default::default() })
                    .expect("invalid state");
            Self {
                state: State {
                    installed_routines,
                    uninstalled_routines,
                    conntrack: conntrack::Table::new::<IntoCoreTimerCtx>(bindings_ctx),
                    nat_installed: OneWayBoolean::TRUE,
                },
                // Reuse the NAT context constructor rather than duplicating its
                // address-wrapping logic here.
                nat: FakeNatCtx::new(device_addrs),
            }
        }

        /// Returns the connection tracking table.
        pub fn conntrack(
            &mut self,
        ) -> &conntrack::Table<I, NatConfig<I, FakeWeakAddressId<I>>, FakeBindingsCtx<I>> {
            &self.state.conntrack
        }
    }

    impl<I: TestIpExt> DeviceIdContext<AnyDevice> for FakeCtx<I> {
        type DeviceId = FakeDeviceId;
        type WeakDeviceId = FakeWeakDeviceId<FakeDeviceId>;
    }

    impl<I: TestIpExt> IpDeviceAddressIdContext<I> for FakeCtx<I> {
        type AddressId = FakeAddressId<I>;
        type WeakAddressId = FakeWeakAddressId<I>;
    }

    impl<I: TestIpExt> FilterIpContext<I, FakeBindingsCtx<I>> for FakeCtx<I> {
        type NatCtx<'a> = FakeNatCtx<I>;

        fn with_filter_state_and_nat_ctx<
            O,
            F: FnOnce(&State<I, FakeWeakAddressId<I>, FakeBindingsCtx<I>>, &mut Self::NatCtx<'_>) -> O,
        >(
            &mut self,
            cb: F,
        ) -> O {
            // Split the borrow so the state and the NAT context can be handed
            // out simultaneously.
            let Self { state, nat } = self;
            cb(state, nat)
        }
    }

    impl<I: TestIpExt> FakeNatCtx<I> {
        /// Creates a NAT context in which each device in `device_addrs` has the
        /// paired address as its primary address.
        ///
        /// If a device appears more than once, the last address wins.
        pub fn new(
            device_addrs: impl IntoIterator<
                Item = (FakeDeviceId, AddrSubnet<I::Addr, I::AssignedWitness>),
            >,
        ) -> Self {
            Self {
                device_addrs: device_addrs
                    .into_iter()
                    .map(|(device, addr)| (device, FakePrimaryAddressId(Arc::new(addr))))
                    .collect(),
            }
        }
    }

    impl<I: TestIpExt> DeviceIdContext<AnyDevice> for FakeNatCtx<I> {
        type DeviceId = FakeDeviceId;
        type WeakDeviceId = FakeWeakDeviceId<FakeDeviceId>;
    }

    impl<I: TestIpExt> IpDeviceAddressIdContext<I> for FakeNatCtx<I> {
        type AddressId = FakeAddressId<I>;
        type WeakAddressId = FakeWeakAddressId<I>;
    }

    impl<I: TestIpExt> NatContext<I, FakeBindingsCtx<I>> for FakeNatCtx<I> {
        /// Returns the device's primary address, ignoring `remote`.
        fn get_local_addr_for_remote(
            &mut self,
            device_id: &Self::DeviceId,
            _remote: Option<SpecifiedAddr<I::Addr>>,
        ) -> Option<Self::AddressId> {
            let FakePrimaryAddressId(primary) = self.device_addrs.get(device_id)?;
            Some(FakeAddressId(Arc::clone(primary)))
        }

        /// Returns the device's primary address only if it equals `addr`.
        fn get_address_id(
            &mut self,
            device_id: &Self::DeviceId,
            addr: IpDeviceAddr<I::Addr>,
        ) -> Option<Self::AddressId> {
            let FakePrimaryAddressId(primary) = self.device_addrs.get(device_id)?;
            let id = FakeAddressId(Arc::clone(primary));
            (id.addr() == addr).then_some(id)
        }
    }

    /// Fake bindings context providing a fake timer context and a fake RNG.
    pub struct FakeBindingsCtx<I: Ip> {
        pub timer_ctx: FakeTimerCtx<FilterTimerId<I>>,
        pub rng: FakeCryptoRng,
    }

    impl<I: Ip> FakeBindingsCtx<I> {
        /// Creates a bindings context with a default timer context and RNG.
        pub(crate) fn new() -> Self {
            Self { timer_ctx: FakeTimerCtx::default(), rng: FakeCryptoRng::default() }
        }

        /// Calls `sleep(time_elapsed)` on the timer context's fake instant,
        /// moving fake time forward by `time_elapsed`.
        pub(crate) fn sleep(&mut self, time_elapsed: Duration) {
            self.timer_ctx.instant.sleep(time_elapsed)
        }
    }

    impl<I: Ip> InstantBindingsTypes for FakeBindingsCtx<I> {
        type Instant = FakeInstant;
        type AtomicInstant = FakeAtomicInstant;
    }

    impl<I: Ip> FilterBindingsTypes for FakeBindingsCtx<I> {
        type DeviceClass = FakeDeviceClass;
    }

    impl<I: Ip> InstantContext for FakeBindingsCtx<I> {
        fn now(&self) -> Self::Instant {
            self.timer_ctx.now()
        }
    }

    impl<I: Ip> TimerBindingsTypes for FakeBindingsCtx<I> {
        type Timer = <FakeTimerCtx<FilterTimerId<I>> as TimerBindingsTypes>::Timer;
        type DispatchId = <FakeTimerCtx<FilterTimerId<I>> as TimerBindingsTypes>::DispatchId;
        type UniqueTimerId = <FakeTimerCtx<FilterTimerId<I>> as TimerBindingsTypes>::UniqueTimerId;
    }

    // Every timer operation is forwarded to the embedded fake timer context.
    impl<I: Ip> TimerContext for FakeBindingsCtx<I> {
        fn new_timer(&mut self, id: Self::DispatchId) -> Self::Timer {
            self.timer_ctx.new_timer(id)
        }

        fn schedule_timer_instant(
            &mut self,
            time: Self::Instant,
            timer: &mut Self::Timer,
        ) -> Option<Self::Instant> {
            self.timer_ctx.schedule_timer_instant(time, timer)
        }

        fn cancel_timer(&mut self, timer: &mut Self::Timer) -> Option<Self::Instant> {
            self.timer_ctx.cancel_timer(timer)
        }

        fn scheduled_instant(&self, timer: &mut Self::Timer) -> Option<Self::Instant> {
            self.timer_ctx.scheduled_instant(timer)
        }

        fn unique_timer_id(&self, timer: &Self::Timer) -> Self::UniqueTimerId {
            self.timer_ctx.unique_timer_id(timer)
        }
    }

    impl<I: Ip> WithFakeTimerContext<FilterTimerId<I>> for FakeBindingsCtx<I> {
        fn with_fake_timer_ctx<O, F: FnOnce(&FakeTimerCtx<FilterTimerId<I>>) -> O>(
            &self,
            f: F,
        ) -> O {
            f(&self.timer_ctx)
        }

        fn with_fake_timer_ctx_mut<O, F: FnOnce(&mut FakeTimerCtx<FilterTimerId<I>>) -> O>(
            &mut self,
            f: F,
        ) -> O {
            f(&mut self.timer_ctx)
        }
    }

    impl<I: Ip> RngContext for FakeBindingsCtx<I> {
        type Rng<'a>
            = FakeCryptoRng
        where
            Self: 'a;

        /// Returns a clone of the context's fake RNG.
        fn rng(&mut self) -> Self::Rng<'_> {
            self.rng.clone()
        }
    }
}