1use core::fmt::Debug;
6
7use net_types::ip::{Ipv4, Ipv6};
8use net_types::SpecifiedAddr;
9use netstack3_base::{
10 InstantBindingsTypes, IpDeviceAddr, IpDeviceAddressIdContext, RngContext, TimerBindingsTypes,
11 TimerContext,
12};
13use packet_formats::ip::IpExt;
14
15use crate::matchers::InterfaceProperties;
16use crate::state::State;
17
18pub trait FilterBindingsTypes: InstantBindingsTypes + TimerBindingsTypes + 'static {
24 type DeviceClass: Clone + Debug;
26}
27
28pub trait FilterBindingsContext: TimerContext + RngContext + FilterBindingsTypes {}
30impl<BC: TimerContext + RngContext + FilterBindingsTypes> FilterBindingsContext for BC {}
31
32pub trait FilterIpContext<I: IpExt, BT: FilterBindingsTypes>:
41 IpDeviceAddressIdContext<I, DeviceId: InterfaceProperties<BT::DeviceClass>>
42{
43 type NatCtx<'a>: NatContext<
46 I,
47 BT,
48 DeviceId = Self::DeviceId,
49 WeakAddressId = Self::WeakAddressId,
50 >;
51
52 fn with_filter_state<O, F: FnOnce(&State<I, Self::WeakAddressId, BT>) -> O>(
54 &mut self,
55 cb: F,
56 ) -> O {
57 self.with_filter_state_and_nat_ctx(|state, _ctx| cb(state))
58 }
59
60 fn with_filter_state_and_nat_ctx<
63 O,
64 F: FnOnce(&State<I, Self::WeakAddressId, BT>, &mut Self::NatCtx<'_>) -> O,
65 >(
66 &mut self,
67 cb: F,
68 ) -> O;
69}
70
71pub trait NatContext<I: IpExt, BT: FilterBindingsTypes>:
73 IpDeviceAddressIdContext<I, DeviceId: InterfaceProperties<BT::DeviceClass>>
74{
75 fn get_local_addr_for_remote(
77 &mut self,
78 device_id: &Self::DeviceId,
79 remote: Option<SpecifiedAddr<I::Addr>>,
80 ) -> Option<Self::AddressId>;
81
82 fn get_address_id(
85 &mut self,
86 device_id: &Self::DeviceId,
87 addr: IpDeviceAddr<I::Addr>,
88 ) -> Option<Self::AddressId>;
89}
90
91pub trait FilterContext<BT: FilterBindingsTypes>:
94 IpDeviceAddressIdContext<Ipv4, DeviceId: InterfaceProperties<BT::DeviceClass>>
95 + IpDeviceAddressIdContext<Ipv6, DeviceId: InterfaceProperties<BT::DeviceClass>>
96{
97 fn with_all_filter_state_mut<
99 O,
100 F: FnOnce(
101 &mut State<Ipv4, <Self as IpDeviceAddressIdContext<Ipv4>>::WeakAddressId, BT>,
102 &mut State<Ipv6, <Self as IpDeviceAddressIdContext<Ipv6>>::WeakAddressId, BT>,
103 ) -> O,
104 >(
105 &mut self,
106 cb: F,
107 ) -> O;
108}
109
/// Lets the shared fake bindings context from `netstack3_base` act as filter
/// bindings types when the `testutils` feature is enabled.
///
/// The generic parameters use short local names so that none of them shadows
/// the `State` type imported at the top of this file.
#[cfg(feature = "testutils")]
impl<Id, Ev, St, Meta> FilterBindingsTypes
    for netstack3_base::testutil::FakeBindingsCtx<Id, Ev, St, Meta>
where
    Id: Debug + PartialEq + Clone + Send + Sync + 'static,
    Ev: Debug + 'static,
    St: 'static,
    Meta: 'static,
{
    // This fake carries no device class information.
    type DeviceClass = ();
}
121
122#[cfg(test)]
123pub(crate) mod testutil {
124 use alloc::collections::HashMap;
125 use alloc::sync::{Arc, Weak};
126 use alloc::vec::Vec;
127 use core::hash::{Hash, Hasher};
128 use core::ops::Deref;
129 use core::time::Duration;
130
131 use derivative::Derivative;
132 use net_types::ip::{AddrSubnet, GenericOverIp, Ip};
133 use netstack3_base::testutil::{
134 FakeAtomicInstant, FakeCryptoRng, FakeInstant, FakeTimerCtx, FakeWeakDeviceId,
135 WithFakeTimerContext,
136 };
137 use netstack3_base::{
138 AnyDevice, AssignedAddrIpExt, DeviceIdContext, InspectableValue, InstantContext,
139 IntoCoreTimerCtx, IpAddressId, WeakIpAddressId,
140 };
141
142 use super::*;
143 use crate::conntrack;
144 use crate::logic::nat::NatConfig;
145 use crate::logic::FilterTimerId;
146 use crate::matchers::testutil::FakeDeviceId;
147 use crate::state::validation::ValidRoutines;
148 use crate::state::{IpRoutines, NatRoutines, OneWayBoolean, Routines};
149
150 pub trait TestIpExt: IpExt + AssignedAddrIpExt {}
151
152 impl<I: IpExt + AssignedAddrIpExt> TestIpExt for I {}
153
154 #[derive(Debug)]
155 pub struct FakePrimaryAddressId<I: AssignedAddrIpExt>(
156 pub Arc<AddrSubnet<I::Addr, I::AssignedWitness>>,
157 );
158
159 #[derive(Clone, Debug, Hash, Eq, PartialEq)]
160 pub struct FakeAddressId<I: AssignedAddrIpExt>(Arc<AddrSubnet<I::Addr, I::AssignedWitness>>);
161
162 #[derive(Clone, Debug)]
163 pub struct FakeWeakAddressId<I: AssignedAddrIpExt>(
164 pub Weak<AddrSubnet<I::Addr, I::AssignedWitness>>,
165 );
166
167 impl<I: AssignedAddrIpExt> PartialEq for FakeWeakAddressId<I> {
168 fn eq(&self, other: &Self) -> bool {
169 let Self(lhs) = self;
170 let Self(rhs) = other;
171 Weak::ptr_eq(lhs, rhs)
172 }
173 }
174
175 impl<I: AssignedAddrIpExt> Eq for FakeWeakAddressId<I> {}
176
177 impl<I: AssignedAddrIpExt> Hash for FakeWeakAddressId<I> {
178 fn hash<H: Hasher>(&self, state: &mut H) {
179 let Self(this) = self;
180 this.as_ptr().hash(state)
181 }
182 }
183
184 impl<I: AssignedAddrIpExt> WeakIpAddressId<I::Addr> for FakeWeakAddressId<I> {
185 type Strong = FakeAddressId<I>;
186
187 fn upgrade(&self) -> Option<Self::Strong> {
188 let Self(inner) = self;
189 inner.upgrade().map(FakeAddressId)
190 }
191
192 fn is_assigned(&self) -> bool {
193 let Self(inner) = self;
194 inner.strong_count() != 0
195 }
196 }
197
198 impl<I: AssignedAddrIpExt> InspectableValue for FakeWeakAddressId<I> {
199 fn record<Inspector: netstack3_base::Inspector>(
200 &self,
201 _name: &str,
202 _inspector: &mut Inspector,
203 ) {
204 unimplemented!()
205 }
206 }
207
208 impl<I: AssignedAddrIpExt> Deref for FakeAddressId<I> {
209 type Target = AddrSubnet<I::Addr, I::AssignedWitness>;
210
211 fn deref(&self) -> &Self::Target {
212 let Self(inner) = self;
213 inner.deref()
214 }
215 }
216
217 impl<I: AssignedAddrIpExt> IpAddressId<I::Addr> for FakeAddressId<I> {
218 type Weak = FakeWeakAddressId<I>;
219
220 fn downgrade(&self) -> Self::Weak {
221 let Self(inner) = self;
222 FakeWeakAddressId(Arc::downgrade(inner))
223 }
224
225 fn addr(&self) -> IpDeviceAddr<I::Addr> {
226 let Self(inner) = self;
227
228 #[derive(GenericOverIp)]
229 #[generic_over_ip(I, Ip)]
230 struct WrapIn<I: AssignedAddrIpExt>(I::AssignedWitness);
231 I::map_ip(
232 WrapIn(inner.addr()),
233 |WrapIn(v4_addr)| IpDeviceAddr::new_from_witness(v4_addr),
234 |WrapIn(v6_addr)| IpDeviceAddr::new_from_ipv6_device_addr(v6_addr),
235 )
236 }
237
238 fn addr_sub(&self) -> AddrSubnet<I::Addr, I::AssignedWitness> {
239 let Self(inner) = self;
240 **inner
241 }
242 }
243
/// Device classes available to tests.
///
/// The derived ordering follows declaration order, so `Ethernet` sorts before
/// `Wlan`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum FakeDeviceClass {
    /// An Ethernet device.
    Ethernet,
    /// A WLAN device.
    Wlan,
}
249
250 pub struct FakeCtx<I: TestIpExt> {
251 state: State<I, FakeWeakAddressId<I>, FakeBindingsCtx<I>>,
252 nat: FakeNatCtx<I>,
253 }
254
255 #[derive(Derivative)]
256 #[derivative(Default(bound = ""))]
257 pub struct FakeNatCtx<I: TestIpExt> {
258 pub(crate) device_addrs: HashMap<FakeDeviceId, FakePrimaryAddressId<I>>,
259 }
260
261 impl<I: TestIpExt> FakeCtx<I> {
262 pub fn new(bindings_ctx: &mut FakeBindingsCtx<I>) -> Self {
263 Self {
264 state: State {
265 installed_routines: ValidRoutines::default(),
266 uninstalled_routines: Vec::default(),
267 conntrack: conntrack::Table::new::<IntoCoreTimerCtx>(bindings_ctx),
268 nat_installed: OneWayBoolean::default(),
269 },
270 nat: FakeNatCtx::default(),
271 }
272 }
273
274 pub fn with_ip_routines(
275 bindings_ctx: &mut FakeBindingsCtx<I>,
276 routines: IpRoutines<I, FakeDeviceClass, ()>,
277 ) -> Self {
278 let (installed_routines, uninstalled_routines) =
279 ValidRoutines::new(Routines { ip: routines, ..Default::default() })
280 .expect("invalid state");
281 Self {
282 state: State {
283 installed_routines,
284 uninstalled_routines,
285 conntrack: conntrack::Table::new::<IntoCoreTimerCtx>(bindings_ctx),
286 nat_installed: OneWayBoolean::default(),
287 },
288 nat: FakeNatCtx::default(),
289 }
290 }
291
292 pub fn with_nat_routines_and_device_addrs(
293 bindings_ctx: &mut FakeBindingsCtx<I>,
294 routines: NatRoutines<I, FakeDeviceClass, ()>,
295 device_addrs: impl IntoIterator<
296 Item = (FakeDeviceId, AddrSubnet<I::Addr, I::AssignedWitness>),
297 >,
298 ) -> Self {
299 let (installed_routines, uninstalled_routines) =
300 ValidRoutines::new(Routines { nat: routines, ..Default::default() })
301 .expect("invalid state");
302 Self {
303 state: State {
304 installed_routines,
305 uninstalled_routines,
306 conntrack: conntrack::Table::new::<IntoCoreTimerCtx>(bindings_ctx),
307 nat_installed: OneWayBoolean::TRUE,
308 },
309 nat: FakeNatCtx {
310 device_addrs: device_addrs
311 .into_iter()
312 .map(|(device, addr)| (device, FakePrimaryAddressId(Arc::new(addr))))
313 .collect(),
314 },
315 }
316 }
317
318 pub fn conntrack(
319 &mut self,
320 ) -> &conntrack::Table<I, NatConfig<I, FakeWeakAddressId<I>>, FakeBindingsCtx<I>> {
321 &self.state.conntrack
322 }
323 }
324
325 impl<I: TestIpExt> DeviceIdContext<AnyDevice> for FakeCtx<I> {
326 type DeviceId = FakeDeviceId;
327 type WeakDeviceId = FakeWeakDeviceId<FakeDeviceId>;
328 }
329
330 impl<I: TestIpExt> IpDeviceAddressIdContext<I> for FakeCtx<I> {
331 type AddressId = FakeAddressId<I>;
332 type WeakAddressId = FakeWeakAddressId<I>;
333 }
334
335 impl<I: TestIpExt> FilterIpContext<I, FakeBindingsCtx<I>> for FakeCtx<I> {
336 type NatCtx<'a> = FakeNatCtx<I>;
337
338 fn with_filter_state_and_nat_ctx<
339 O,
340 F: FnOnce(&State<I, FakeWeakAddressId<I>, FakeBindingsCtx<I>>, &mut Self::NatCtx<'_>) -> O,
341 >(
342 &mut self,
343 cb: F,
344 ) -> O {
345 let Self { state, nat } = self;
346 cb(state, nat)
347 }
348 }
349
350 impl<I: TestIpExt> FakeNatCtx<I> {
351 pub fn new(
352 device_addrs: impl IntoIterator<
353 Item = (FakeDeviceId, AddrSubnet<I::Addr, I::AssignedWitness>),
354 >,
355 ) -> Self {
356 Self {
357 device_addrs: device_addrs
358 .into_iter()
359 .map(|(device, addr)| (device, FakePrimaryAddressId(Arc::new(addr))))
360 .collect(),
361 }
362 }
363 }
364
365 impl<I: TestIpExt> DeviceIdContext<AnyDevice> for FakeNatCtx<I> {
366 type DeviceId = FakeDeviceId;
367 type WeakDeviceId = FakeWeakDeviceId<FakeDeviceId>;
368 }
369
370 impl<I: TestIpExt> IpDeviceAddressIdContext<I> for FakeNatCtx<I> {
371 type AddressId = FakeAddressId<I>;
372 type WeakAddressId = FakeWeakAddressId<I>;
373 }
374
375 impl<I: TestIpExt> NatContext<I, FakeBindingsCtx<I>> for FakeNatCtx<I> {
376 fn get_local_addr_for_remote(
377 &mut self,
378 device_id: &Self::DeviceId,
379 _remote: Option<SpecifiedAddr<I::Addr>>,
380 ) -> Option<Self::AddressId> {
381 let FakePrimaryAddressId(primary) = self.device_addrs.get(device_id)?;
382 Some(FakeAddressId(primary.clone()))
383 }
384
385 fn get_address_id(
386 &mut self,
387 device_id: &Self::DeviceId,
388 addr: IpDeviceAddr<I::Addr>,
389 ) -> Option<Self::AddressId> {
390 let FakePrimaryAddressId(id) = self.device_addrs.get(device_id)?;
391 let id = FakeAddressId(id.clone());
392 if id.addr() == addr {
393 Some(id)
394 } else {
395 None
396 }
397 }
398 }
399
400 pub struct FakeBindingsCtx<I: Ip> {
401 pub timer_ctx: FakeTimerCtx<FilterTimerId<I>>,
402 pub rng: FakeCryptoRng,
403 }
404
405 impl<I: Ip> FakeBindingsCtx<I> {
406 pub(crate) fn new() -> Self {
407 Self { timer_ctx: FakeTimerCtx::default(), rng: FakeCryptoRng::default() }
408 }
409
410 pub(crate) fn sleep(&mut self, time_elapsed: Duration) {
411 self.timer_ctx.instant.sleep(time_elapsed)
412 }
413 }
414
415 impl<I: Ip> InstantBindingsTypes for FakeBindingsCtx<I> {
416 type Instant = FakeInstant;
417 type AtomicInstant = FakeAtomicInstant;
418 }
419
420 impl<I: Ip> FilterBindingsTypes for FakeBindingsCtx<I> {
421 type DeviceClass = FakeDeviceClass;
422 }
423
424 impl<I: Ip> InstantContext for FakeBindingsCtx<I> {
425 fn now(&self) -> Self::Instant {
426 self.timer_ctx.now()
427 }
428 }
429
430 impl<I: Ip> TimerBindingsTypes for FakeBindingsCtx<I> {
431 type Timer = <FakeTimerCtx<FilterTimerId<I>> as TimerBindingsTypes>::Timer;
432 type DispatchId = <FakeTimerCtx<FilterTimerId<I>> as TimerBindingsTypes>::DispatchId;
433 type UniqueTimerId = <FakeTimerCtx<FilterTimerId<I>> as TimerBindingsTypes>::UniqueTimerId;
434 }
435
436 impl<I: Ip> TimerContext for FakeBindingsCtx<I> {
437 fn new_timer(&mut self, id: Self::DispatchId) -> Self::Timer {
438 self.timer_ctx.new_timer(id)
439 }
440
441 fn schedule_timer_instant(
442 &mut self,
443 time: Self::Instant,
444 timer: &mut Self::Timer,
445 ) -> Option<Self::Instant> {
446 self.timer_ctx.schedule_timer_instant(time, timer)
447 }
448
449 fn cancel_timer(&mut self, timer: &mut Self::Timer) -> Option<Self::Instant> {
450 self.timer_ctx.cancel_timer(timer)
451 }
452
453 fn scheduled_instant(&self, timer: &mut Self::Timer) -> Option<Self::Instant> {
454 self.timer_ctx.scheduled_instant(timer)
455 }
456
457 fn unique_timer_id(&self, timer: &Self::Timer) -> Self::UniqueTimerId {
458 self.timer_ctx.unique_timer_id(timer)
459 }
460 }
461
462 impl<I: Ip> WithFakeTimerContext<FilterTimerId<I>> for FakeBindingsCtx<I> {
463 fn with_fake_timer_ctx<O, F: FnOnce(&FakeTimerCtx<FilterTimerId<I>>) -> O>(
464 &self,
465 f: F,
466 ) -> O {
467 f(&self.timer_ctx)
468 }
469
470 fn with_fake_timer_ctx_mut<O, F: FnOnce(&mut FakeTimerCtx<FilterTimerId<I>>) -> O>(
471 &mut self,
472 f: F,
473 ) -> O {
474 f(&mut self.timer_ctx)
475 }
476 }
477
478 impl<I: Ip> RngContext for FakeBindingsCtx<I> {
479 type Rng<'a>
480 = FakeCryptoRng
481 where
482 Self: 'a;
483
484 fn rng(&mut self) -> Self::Rng<'_> {
485 self.rng.clone()
486 }
487 }
488}