1use core::fmt::Debug;
6
7use net_types::SpecifiedAddr;
8use net_types::ip::{IpVersion, Ipv4, Ipv6};
9use netstack3_base::socket::SocketCookie;
10use netstack3_base::{
11 InstantBindingsTypes, InterfaceProperties, IpDeviceAddr, IpDeviceAddressIdContext, Marks,
12 MatcherBindingsTypes, RngContext, StrongDeviceIdentifier, TimerBindingsTypes, TimerContext,
13 TxMetadataBindingsTypes,
14};
15use packet::{FragmentedByteSlice, PartialSerializer};
16use packet_formats::ip::IpExt;
17
18use crate::state::State;
19use crate::{FilterIpExt, IpPacket};
20
/// Types that bindings must provide for filtering.
///
/// This trait only bundles its supertrait bounds (instant, matcher, and timer
/// binding types, plus `'static`). It is blanket-implemented for every type
/// that satisfies them.
pub trait FilterBindingsTypes:
    InstantBindingsTypes + MatcherBindingsTypes + TimerBindingsTypes + 'static
{
}

// Blanket impl: any type meeting the supertrait bounds is a `FilterBindingsTypes`.
impl<BT: InstantBindingsTypes + MatcherBindingsTypes + TimerBindingsTypes + 'static>
    FilterBindingsTypes for BT
{
}
31
/// Functionality that bindings must provide for filtering: timers, randomness,
/// and the types in [`FilterBindingsTypes`].
pub trait FilterBindingsContext: TimerContext + RngContext + FilterBindingsTypes {}

// Blanket impl so that any context meeting the bounds is a `FilterBindingsContext`.
impl<BC: TimerContext + RngContext + FilterBindingsTypes> FilterBindingsContext for BC {}
35
/// Provides access to the filtering state for IP version `I`.
///
/// The context's device ID type is required to implement
/// `InterfaceProperties<BT::DeviceClass>`.
pub trait FilterIpContext<I: FilterIpExt, BT: FilterBindingsTypes>:
    IpDeviceAddressIdContext<I, DeviceId: InterfaceProperties<BT::DeviceClass>>
{
    /// The NAT address-lookup context passed to the callback of
    /// [`FilterIpContext::with_filter_state_and_nat_ctx`].
    ///
    /// It shares this context's device ID and weak address ID types.
    type NatCtx<'a>: NatContext<I, BT, DeviceId = Self::DeviceId, WeakAddressId = Self::WeakAddressId>;

    /// Calls `cb` with an immutable reference to the filter state.
    ///
    /// The default implementation delegates to
    /// [`FilterIpContext::with_filter_state_and_nat_ctx`] and ignores the NAT
    /// context.
    fn with_filter_state<O, F: FnOnce(&State<I, Self::WeakAddressId, BT>) -> O>(
        &mut self,
        cb: F,
    ) -> O {
        self.with_filter_state_and_nat_ctx(|state, _ctx| cb(state))
    }

    /// Calls `cb` with an immutable reference to the filter state and a mutable
    /// reference to the NAT context. Returns whatever `cb` returns.
    fn with_filter_state_and_nat_ctx<
        O,
        F: FnOnce(&State<I, Self::WeakAddressId, BT>, &mut Self::NatCtx<'_>) -> O,
    >(
        &mut self,
        cb: F,
    ) -> O;
}
69
/// Address-lookup functionality used for NAT (see [`FilterIpContext::NatCtx`]).
///
/// The context's device ID type is required to implement
/// `InterfaceProperties<BT::DeviceClass>`.
pub trait NatContext<I: IpExt, BT: FilterBindingsTypes>:
    IpDeviceAddressIdContext<I, DeviceId: InterfaceProperties<BT::DeviceClass>>
{
    /// Returns the ID of an address on `device_id` to use as the local address
    /// when communicating with `remote`, or `None` if there is no such address.
    ///
    /// NOTE(review): the selection policy is implementation-defined. The test
    /// fake in this file ignores `remote` entirely; confirm the real policy.
    fn get_local_addr_for_remote(
        &mut self,
        device_id: &Self::DeviceId,
        remote: Option<SpecifiedAddr<I::Addr>>,
    ) -> Option<Self::AddressId>;

    /// Returns the ID for `addr` on `device_id`, or `None` if the
    /// implementation has no such address for that device.
    fn get_address_id(
        &mut self,
        device_id: &Self::DeviceId,
        addr: IpDeviceAddr<I::Addr>,
    ) -> Option<Self::AddressId>;
}
89
/// Provides simultaneous mutable access to the IPv4 and IPv6 filter state.
pub trait FilterContext<BT: FilterBindingsTypes>:
    IpDeviceAddressIdContext<Ipv4, DeviceId: InterfaceProperties<BT::DeviceClass>>
    + IpDeviceAddressIdContext<Ipv6, DeviceId: InterfaceProperties<BT::DeviceClass>>
{
    /// Calls `cb` with mutable references to the IPv4 and IPv6 filter state.
    /// Returns whatever `cb` returns.
    fn with_all_filter_state_mut<
        O,
        F: FnOnce(
            &mut State<Ipv4, <Self as IpDeviceAddressIdContext<Ipv4>>::WeakAddressId, BT>,
            &mut State<Ipv6, <Self as IpDeviceAddressIdContext<Ipv6>>::WeakAddressId, BT>,
        ) -> O,
    >(
        &mut self,
        cb: F,
    ) -> O;
}
108
/// The verdict returned by [`SocketOpsFilter::on_egress`].
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum SocketEgressFilterResult {
    /// The filter lets the packet through.
    Pass {
        /// Congestion indication reported by the filter.
        ///
        /// NOTE(review): how callers react to this flag is not visible here;
        /// confirm.
        congestion: bool,
    },

    /// The filter drops the packet.
    Drop {
        /// Congestion indication reported by the filter, as in `Pass`.
        congestion: bool,
    },
}
124
/// The verdict returned by [`SocketOpsFilter::on_ingress`].
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum SocketIngressFilterResult {
    /// The filter accepts the packet.
    Accept,

    /// The filter drops the packet.
    Drop,
}
134
/// Filter hooks for traffic associated with a socket.
///
/// Each hook receives the packet, the device `D` it is sent or received on,
/// the socket's [`SocketCookie`], and its [`Marks`].
pub trait SocketOpsFilter<D: StrongDeviceIdentifier> {
    /// Filters an outgoing packet, given as a typed IP packet for version `I`.
    fn on_egress<I: FilterIpExt, P: IpPacket<I> + PartialSerializer>(
        &self,
        packet: &P,
        device: &D,
        cookie: SocketCookie,
        marks: &Marks,
    ) -> SocketEgressFilterResult;

    /// Filters an incoming packet, given as raw bytes (a
    /// `FragmentedByteSlice`) together with its IP version.
    fn on_ingress(
        &self,
        ip_version: IpVersion,
        packet: FragmentedByteSlice<'_, &[u8]>,
        device: &D,
        cookie: SocketCookie,
        marks: &Marks,
    ) -> SocketIngressFilterResult;
}
156
/// A bindings context that can supply a [`SocketOpsFilter`] for devices of
/// type `D`.
pub trait SocketOpsFilterBindingContext<D: StrongDeviceIdentifier>:
    TxMetadataBindingsTypes
{
    /// Returns the socket filter to apply.
    fn socket_ops_filter(&self) -> impl SocketOpsFilter<D>;
}
164
/// Test bindings always use a socket filter that lets all traffic through
/// unmodified ([`crate::testutil::NoOpSocketOpsFilter`]).
#[cfg(any(test, feature = "testutils"))]
impl<TimerId, Event, BindingsState, FrameMeta, D> SocketOpsFilterBindingContext<D>
    for netstack3_base::testutil::FakeBindingsCtx<TimerId, Event, BindingsState, FrameMeta>
where
    TimerId: Debug + PartialEq + Clone + Send + Sync + 'static,
    Event: Debug + 'static,
    BindingsState: 'static,
    FrameMeta: 'static,
    D: StrongDeviceIdentifier,
{
    fn socket_ops_filter(&self) -> impl SocketOpsFilter<D> {
        // Stateless no-op filter; a fresh value is returned on every call.
        crate::testutil::NoOpSocketOpsFilter
    }
}
179
/// Fake implementations of the filtering contexts, for use in unit tests.
#[cfg(test)]
pub(crate) mod testutil {
    use alloc::sync::{Arc, Weak};
    use alloc::vec::Vec;
    use core::hash::{Hash, Hasher};
    use core::ops::Deref;
    use core::time::Duration;

    use derivative::Derivative;
    use net_types::ip::{AddrSubnet, GenericOverIp, Ip};
    use netstack3_base::testutil::{
        FakeAtomicInstant, FakeCryptoRng, FakeDeviceClass, FakeInstant, FakeMatcherDeviceId,
        FakeTimerCtx, FakeWeakDeviceId, WithFakeTimerContext,
    };
    use netstack3_base::{
        AnyDevice, AssignedAddrIpExt, DeviceIdContext, InspectableValue, InstantContext,
        IntoCoreTimerCtx, IpAddressId, WeakIpAddressId,
    };
    use netstack3_hashmap::HashMap;

    use super::*;
    use crate::conntrack;
    use crate::logic::FilterTimerId;
    use crate::logic::nat::NatConfig;
    use crate::state::validation::ValidRoutines;
    use crate::state::{IpRoutines, NatRoutines, OneWayBoolean, Routines};

    /// IP extension trait bundling the bounds required by the fakes in this module.
    pub trait TestIpExt: FilterIpExt + AssignedAddrIpExt {}

    impl<I: FilterIpExt + AssignedAddrIpExt> TestIpExt for I {}

    /// Owns the single (primary) address assigned to a fake device.
    ///
    /// [`FakeAddressId`]s handed out by [`FakeNatCtx`] clone this `Arc`, and
    /// [`FakeWeakAddressId`]s downgrade it.
    #[derive(Debug)]
    pub struct FakePrimaryAddressId<I: AssignedAddrIpExt>(
        pub Arc<AddrSubnet<I::Addr, I::AssignedWitness>>,
    );

    /// Fake strong address ID: a shared reference to an address subnet.
    ///
    /// Equality and hashing are derived, so they compare the address subnet by
    /// value. This differs from [`FakeWeakAddressId`], which compares by pointer.
    #[derive(Clone, Debug, Hash, Eq, PartialEq)]
    pub struct FakeAddressId<I: AssignedAddrIpExt>(Arc<AddrSubnet<I::Addr, I::AssignedWitness>>);

    /// Fake weak address ID, obtained via [`IpAddressId::downgrade`] on a
    /// [`FakeAddressId`].
    #[derive(Clone, Debug)]
    pub struct FakeWeakAddressId<I: AssignedAddrIpExt>(
        pub Weak<AddrSubnet<I::Addr, I::AssignedWitness>>,
    );

    // Weak IDs compare by identity (the allocation they point to), not by the
    // address value they refer to.
    impl<I: AssignedAddrIpExt> PartialEq for FakeWeakAddressId<I> {
        fn eq(&self, other: &Self) -> bool {
            let Self(lhs) = self;
            let Self(rhs) = other;
            Weak::ptr_eq(lhs, rhs)
        }
    }

    impl<I: AssignedAddrIpExt> Eq for FakeWeakAddressId<I> {}

    // Must stay consistent with `PartialEq` above: hash the pointer, not the value.
    impl<I: AssignedAddrIpExt> Hash for FakeWeakAddressId<I> {
        fn hash<H: Hasher>(&self, state: &mut H) {
            let Self(this) = self;
            this.as_ptr().hash(state)
        }
    }

    impl<I: AssignedAddrIpExt> WeakIpAddressId<I::Addr> for FakeWeakAddressId<I> {
        type Strong = FakeAddressId<I>;

        fn upgrade(&self) -> Option<Self::Strong> {
            let Self(inner) = self;
            inner.upgrade().map(FakeAddressId)
        }

        /// The address counts as assigned while any strong reference to it
        /// (the owning [`FakePrimaryAddressId`] or any [`FakeAddressId`]) is alive.
        fn is_assigned(&self) -> bool {
            let Self(inner) = self;
            inner.strong_count() != 0
        }
    }

    impl<I: AssignedAddrIpExt> InspectableValue for FakeWeakAddressId<I> {
        /// Not implemented for the fake; panics if called.
        fn record<Inspector: netstack3_base::Inspector>(
            &self,
            _name: &str,
            _inspector: &mut Inspector,
        ) {
            unimplemented!()
        }
    }

    impl<I: AssignedAddrIpExt> Deref for FakeAddressId<I> {
        type Target = AddrSubnet<I::Addr, I::AssignedWitness>;

        fn deref(&self) -> &Self::Target {
            let Self(inner) = self;
            inner.deref()
        }
    }

    impl<I: AssignedAddrIpExt> IpAddressId<I::Addr> for FakeAddressId<I> {
        type Weak = FakeWeakAddressId<I>;

        fn downgrade(&self) -> Self::Weak {
            let Self(inner) = self;
            FakeWeakAddressId(Arc::downgrade(inner))
        }

        fn addr(&self) -> IpDeviceAddr<I::Addr> {
            let Self(inner) = self;

            // `I::AssignedWitness` is a different type per IP version, so dispatch
            // on the version to pick the matching `IpDeviceAddr` constructor.
            #[derive(GenericOverIp)]
            #[generic_over_ip(I, Ip)]
            struct WrapIn<I: AssignedAddrIpExt>(I::AssignedWitness);
            I::map_ip(
                WrapIn(inner.addr()),
                |WrapIn(v4_addr)| IpDeviceAddr::new_from_witness(v4_addr),
                |WrapIn(v6_addr)| IpDeviceAddr::new_from_ipv6_device_addr(v6_addr),
            )
        }

        fn addr_sub(&self) -> AddrSubnet<I::Addr, I::AssignedWitness> {
            let Self(inner) = self;
            **inner
        }
    }

    /// Fake core context holding the filter state for `I` and a [`FakeNatCtx`].
    pub struct FakeCtx<I: TestIpExt> {
        state: State<I, FakeWeakAddressId<I>, FakeBindingsCtx<I>>,
        nat: FakeNatCtx<I>,
    }

    /// Fake NAT context that maps each fake device to its single primary address.
    #[derive(Derivative)]
    #[derivative(Default(bound = ""))]
    pub struct FakeNatCtx<I: TestIpExt> {
        pub(crate) device_addrs: HashMap<FakeMatcherDeviceId, FakePrimaryAddressId<I>>,
    }

    impl<I: TestIpExt> FakeCtx<I> {
        /// Creates a context with default (empty) routines, NAT not marked as
        /// installed, and no device addresses.
        pub fn new(bindings_ctx: &mut FakeBindingsCtx<I>) -> Self {
            Self {
                state: State {
                    installed_routines: ValidRoutines::default(),
                    uninstalled_routines: Vec::default(),
                    conntrack: conntrack::Table::new::<IntoCoreTimerCtx>(bindings_ctx),
                    nat_installed: OneWayBoolean::default(),
                },
                nat: FakeNatCtx::default(),
            }
        }

        /// Creates a context with `routines` installed as the IP routines.
        ///
        /// # Panics
        ///
        /// Panics if the routines fail validation.
        pub fn with_ip_routines(
            bindings_ctx: &mut FakeBindingsCtx<I>,
            routines: IpRoutines<I, FakeDeviceClass, ()>,
        ) -> Self {
            Self::with_validated_routines(
                bindings_ctx,
                Routines { ip: routines, ..Default::default() },
                OneWayBoolean::default(),
                FakeNatCtx::default(),
            )
        }

        /// Creates a context with `routines` installed as the NAT routines and
        /// NAT marked as installed. Each device in `device_addrs` gets the paired
        /// address as its primary address.
        ///
        /// # Panics
        ///
        /// Panics if the routines fail validation.
        pub fn with_nat_routines_and_device_addrs(
            bindings_ctx: &mut FakeBindingsCtx<I>,
            routines: NatRoutines<I, FakeDeviceClass, ()>,
            device_addrs: impl IntoIterator<
                Item = (FakeMatcherDeviceId, AddrSubnet<I::Addr, I::AssignedWitness>),
            >,
        ) -> Self {
            Self::with_validated_routines(
                bindings_ctx,
                Routines { nat: routines, ..Default::default() },
                OneWayBoolean::TRUE,
                FakeNatCtx::new(device_addrs),
            )
        }

        /// Validates `routines` and builds a context around the result, with a
        /// fresh conntrack table.
        ///
        /// # Panics
        ///
        /// Panics with "invalid state" if validation fails.
        fn with_validated_routines(
            bindings_ctx: &mut FakeBindingsCtx<I>,
            routines: Routines<I, FakeDeviceClass, ()>,
            nat_installed: OneWayBoolean,
            nat: FakeNatCtx<I>,
        ) -> Self {
            let (installed_routines, uninstalled_routines) =
                ValidRoutines::new(routines).expect("invalid state");
            Self {
                state: State {
                    installed_routines,
                    uninstalled_routines,
                    conntrack: conntrack::Table::new::<IntoCoreTimerCtx>(bindings_ctx),
                    nat_installed,
                },
                nat,
            }
        }

        /// Returns the connection tracking table.
        ///
        /// Only shared access is needed, so this borrows `self` immutably.
        pub fn conntrack(
            &self,
        ) -> &conntrack::Table<I, NatConfig<I, FakeWeakAddressId<I>>, FakeBindingsCtx<I>> {
            &self.state.conntrack
        }
    }

    impl<I: TestIpExt> DeviceIdContext<AnyDevice> for FakeCtx<I> {
        type DeviceId = FakeMatcherDeviceId;
        type WeakDeviceId = FakeWeakDeviceId<FakeMatcherDeviceId>;
    }

    impl<I: TestIpExt> IpDeviceAddressIdContext<I> for FakeCtx<I> {
        type AddressId = FakeAddressId<I>;
        type WeakAddressId = FakeWeakAddressId<I>;
    }

    impl<I: TestIpExt> FilterIpContext<I, FakeBindingsCtx<I>> for FakeCtx<I> {
        type NatCtx<'a> = FakeNatCtx<I>;

        fn with_filter_state_and_nat_ctx<
            O,
            F: FnOnce(&State<I, FakeWeakAddressId<I>, FakeBindingsCtx<I>>, &mut Self::NatCtx<'_>) -> O,
        >(
            &mut self,
            cb: F,
        ) -> O {
            // Destructure to borrow the two fields disjointly.
            let Self { state, nat } = self;
            cb(state, nat)
        }
    }

    impl<I: TestIpExt> FakeNatCtx<I> {
        /// Creates a NAT context in which each device in `device_addrs` has the
        /// paired address as its primary address.
        pub fn new(
            device_addrs: impl IntoIterator<
                Item = (FakeMatcherDeviceId, AddrSubnet<I::Addr, I::AssignedWitness>),
            >,
        ) -> Self {
            Self {
                device_addrs: device_addrs
                    .into_iter()
                    .map(|(device, addr)| (device, FakePrimaryAddressId(Arc::new(addr))))
                    .collect(),
            }
        }
    }

    impl<I: TestIpExt> DeviceIdContext<AnyDevice> for FakeNatCtx<I> {
        type DeviceId = FakeMatcherDeviceId;
        type WeakDeviceId = FakeWeakDeviceId<FakeMatcherDeviceId>;
    }

    impl<I: TestIpExt> IpDeviceAddressIdContext<I> for FakeNatCtx<I> {
        type AddressId = FakeAddressId<I>;
        type WeakAddressId = FakeWeakAddressId<I>;
    }

    impl<I: TestIpExt> NatContext<I, FakeBindingsCtx<I>> for FakeNatCtx<I> {
        /// Returns the device's primary address (if any), ignoring `remote`.
        fn get_local_addr_for_remote(
            &mut self,
            device_id: &Self::DeviceId,
            _remote: Option<SpecifiedAddr<I::Addr>>,
        ) -> Option<Self::AddressId> {
            let FakePrimaryAddressId(primary) = self.device_addrs.get(device_id)?;
            Some(FakeAddressId(primary.clone()))
        }

        /// Returns the device's primary address ID only if it equals `addr`.
        fn get_address_id(
            &mut self,
            device_id: &Self::DeviceId,
            addr: IpDeviceAddr<I::Addr>,
        ) -> Option<Self::AddressId> {
            let FakePrimaryAddressId(id) = self.device_addrs.get(device_id)?;
            let id = FakeAddressId(id.clone());
            if id.addr() == addr { Some(id) } else { None }
        }
    }

    /// Fake bindings context backed by a fake timer context and a fake RNG.
    pub struct FakeBindingsCtx<I: Ip> {
        pub timer_ctx: FakeTimerCtx<FilterTimerId<I>>,
        pub rng: FakeCryptoRng,
    }

    impl<I: Ip> FakeBindingsCtx<I> {
        /// Creates a bindings context with a default timer context and RNG.
        pub(crate) fn new() -> Self {
            Self { timer_ctx: FakeTimerCtx::default(), rng: FakeCryptoRng::default() }
        }

        /// Advances the fake clock by `time_elapsed`, via the fake timer
        /// context's instant.
        pub(crate) fn sleep(&mut self, time_elapsed: Duration) {
            self.timer_ctx.instant.sleep(time_elapsed)
        }
    }

    impl<I: Ip> InstantBindingsTypes for FakeBindingsCtx<I> {
        type Instant = FakeInstant;
        type AtomicInstant = FakeAtomicInstant;
    }

    impl<I: Ip> MatcherBindingsTypes for FakeBindingsCtx<I> {
        type DeviceClass = FakeDeviceClass;
    }

    impl<I: Ip> InstantContext for FakeBindingsCtx<I> {
        fn now(&self) -> Self::Instant {
            self.timer_ctx.now()
        }
    }

    // Timer types and operations are all delegated to the inner fake timer context.
    impl<I: Ip> TimerBindingsTypes for FakeBindingsCtx<I> {
        type Timer = <FakeTimerCtx<FilterTimerId<I>> as TimerBindingsTypes>::Timer;
        type DispatchId = <FakeTimerCtx<FilterTimerId<I>> as TimerBindingsTypes>::DispatchId;
        type UniqueTimerId = <FakeTimerCtx<FilterTimerId<I>> as TimerBindingsTypes>::UniqueTimerId;
    }

    impl<I: Ip> TimerContext for FakeBindingsCtx<I> {
        fn new_timer(&mut self, id: Self::DispatchId) -> Self::Timer {
            self.timer_ctx.new_timer(id)
        }

        fn schedule_timer_instant(
            &mut self,
            time: Self::Instant,
            timer: &mut Self::Timer,
        ) -> Option<Self::Instant> {
            self.timer_ctx.schedule_timer_instant(time, timer)
        }

        fn cancel_timer(&mut self, timer: &mut Self::Timer) -> Option<Self::Instant> {
            self.timer_ctx.cancel_timer(timer)
        }

        fn scheduled_instant(&self, timer: &mut Self::Timer) -> Option<Self::Instant> {
            self.timer_ctx.scheduled_instant(timer)
        }

        fn unique_timer_id(&self, timer: &Self::Timer) -> Self::UniqueTimerId {
            self.timer_ctx.unique_timer_id(timer)
        }
    }

    impl<I: Ip> WithFakeTimerContext<FilterTimerId<I>> for FakeBindingsCtx<I> {
        fn with_fake_timer_ctx<O, F: FnOnce(&FakeTimerCtx<FilterTimerId<I>>) -> O>(
            &self,
            f: F,
        ) -> O {
            f(&self.timer_ctx)
        }

        fn with_fake_timer_ctx_mut<O, F: FnOnce(&mut FakeTimerCtx<FilterTimerId<I>>) -> O>(
            &mut self,
            f: F,
        ) -> O {
            f(&mut self.timer_ctx)
        }
    }

    impl<I: Ip> RngContext for FakeBindingsCtx<I> {
        type Rng<'a>
            = FakeCryptoRng
        where
            Self: 'a;

        /// Returns a clone of the fake RNG.
        ///
        /// NOTE(review): whether clones share state depends on `FakeCryptoRng`;
        /// confirm.
        fn rng(&mut self) -> Self::Rng<'_> {
            self.rng.clone()
        }
    }
}