netstack3_filter/
context.rs

1// Copyright 2024 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use core::fmt::Debug;
6
7use net_types::ip::{Ipv4, Ipv6};
8use net_types::SpecifiedAddr;
9use netstack3_base::{
10    InstantBindingsTypes, IpDeviceAddr, IpDeviceAddressIdContext, RngContext, TimerBindingsTypes,
11    TimerContext,
12};
13use packet_formats::ip::IpExt;
14
15use crate::matchers::InterfaceProperties;
16use crate::state::State;
17
18/// Trait defining required types for filtering provided by bindings.
19///
20/// Allows rules that match on device class to be installed, storing the
21/// [`FilterBindingsTypes::DeviceClass`] type at rest, while allowing Netstack3
22/// Core to have Bindings provide the type since it is platform-specific.
23pub trait FilterBindingsTypes: InstantBindingsTypes + TimerBindingsTypes + 'static {
24    /// The device class type for devices installed in the netstack.
25    type DeviceClass: Clone + Debug;
26}
27
28/// Trait aggregating functionality required from bindings.
29pub trait FilterBindingsContext: TimerContext + RngContext + FilterBindingsTypes {}
30impl<BC: TimerContext + RngContext + FilterBindingsTypes> FilterBindingsContext for BC {}
31
32/// The IP version-specific execution context for packet filtering.
33///
34/// This trait exists to abstract over access to the filtering state. It is
35/// useful to implement filtering logic in terms of this trait, as opposed to,
36/// for example, [`crate::logic::FilterHandler`] methods taking the state
37/// directly as an argument, because it allows Netstack3 Core to use lock
38/// ordering types to enforce that filtering state is only acquired at or before
39/// a given lock level, while keeping test code free of locking concerns.
40pub trait FilterIpContext<I: IpExt, BT: FilterBindingsTypes>:
41    IpDeviceAddressIdContext<I, DeviceId: InterfaceProperties<BT::DeviceClass>>
42{
43    /// The execution context that allows the filtering engine to perform
44    /// Network Address Translation (NAT).
45    type NatCtx<'a>: NatContext<
46        I,
47        BT,
48        DeviceId = Self::DeviceId,
49        WeakAddressId = Self::WeakAddressId,
50    >;
51
52    /// Calls the function with a reference to filtering state.
53    fn with_filter_state<O, F: FnOnce(&State<I, Self::WeakAddressId, BT>) -> O>(
54        &mut self,
55        cb: F,
56    ) -> O {
57        self.with_filter_state_and_nat_ctx(|state, _ctx| cb(state))
58    }
59
60    /// Calls the function with a reference to filtering state and the NAT
61    /// context.
62    fn with_filter_state_and_nat_ctx<
63        O,
64        F: FnOnce(&State<I, Self::WeakAddressId, BT>, &mut Self::NatCtx<'_>) -> O,
65    >(
66        &mut self,
67        cb: F,
68    ) -> O;
69}
70
71/// The execution context for Network Address Translation (NAT).
72pub trait NatContext<I: IpExt, BT: FilterBindingsTypes>:
73    IpDeviceAddressIdContext<I, DeviceId: InterfaceProperties<BT::DeviceClass>>
74{
75    /// Returns the best local address for communicating with the remote.
76    fn get_local_addr_for_remote(
77        &mut self,
78        device_id: &Self::DeviceId,
79        remote: Option<SpecifiedAddr<I::Addr>>,
80    ) -> Option<Self::AddressId>;
81
82    /// Returns a strongly-held reference to the provided address, if it is assigned
83    /// to the specified device.
84    fn get_address_id(
85        &mut self,
86        device_id: &Self::DeviceId,
87        addr: IpDeviceAddr<I::Addr>,
88    ) -> Option<Self::AddressId>;
89}
90
91/// A context for mutably accessing all filtering state at once, to allow IPv4
92/// and IPv6 filtering state to be modified atomically.
93pub trait FilterContext<BT: FilterBindingsTypes>:
94    IpDeviceAddressIdContext<Ipv4, DeviceId: InterfaceProperties<BT::DeviceClass>>
95    + IpDeviceAddressIdContext<Ipv6, DeviceId: InterfaceProperties<BT::DeviceClass>>
96{
97    /// Calls the function with a mutable reference to all filtering state.
98    fn with_all_filter_state_mut<
99        O,
100        F: FnOnce(
101            &mut State<Ipv4, <Self as IpDeviceAddressIdContext<Ipv4>>::WeakAddressId, BT>,
102            &mut State<Ipv6, <Self as IpDeviceAddressIdContext<Ipv6>>::WeakAddressId, BT>,
103        ) -> O,
104    >(
105        &mut self,
106        cb: F,
107    ) -> O;
108}
109
// Lets the shared fake bindings context from `netstack3_base` act as filter
// bindings types when the `testutils` feature is enabled. The unit type is
// used as its device class.
#[cfg(feature = "testutils")]
impl<TimerId, Event, State, FrameMeta> FilterBindingsTypes
    for netstack3_base::testutil::FakeBindingsCtx<TimerId, Event, State, FrameMeta>
where
    TimerId: Debug + PartialEq + Clone + Send + Sync + 'static,
    Event: Debug + 'static,
    State: 'static,
    FrameMeta: 'static,
{
    type DeviceClass = ();
}
121
122#[cfg(test)]
123pub(crate) mod testutil {
124    use alloc::collections::HashMap;
125    use alloc::sync::{Arc, Weak};
126    use alloc::vec::Vec;
127    use core::hash::{Hash, Hasher};
128    use core::ops::Deref;
129    use core::time::Duration;
130
131    use derivative::Derivative;
132    use net_types::ip::{AddrSubnet, GenericOverIp, Ip};
133    use netstack3_base::testutil::{
134        FakeAtomicInstant, FakeCryptoRng, FakeInstant, FakeTimerCtx, FakeWeakDeviceId,
135        WithFakeTimerContext,
136    };
137    use netstack3_base::{
138        AnyDevice, AssignedAddrIpExt, DeviceIdContext, InspectableValue, InstantContext,
139        IntoCoreTimerCtx, IpAddressId, WeakIpAddressId,
140    };
141
142    use super::*;
143    use crate::conntrack;
144    use crate::logic::nat::NatConfig;
145    use crate::logic::FilterTimerId;
146    use crate::matchers::testutil::FakeDeviceId;
147    use crate::state::validation::ValidRoutines;
148    use crate::state::{IpRoutines, NatRoutines, OneWayBoolean, Routines};
149
150    pub trait TestIpExt: IpExt + AssignedAddrIpExt {}
151
152    impl<I: IpExt + AssignedAddrIpExt> TestIpExt for I {}
153
154    #[derive(Debug)]
155    pub struct FakePrimaryAddressId<I: AssignedAddrIpExt>(
156        pub Arc<AddrSubnet<I::Addr, I::AssignedWitness>>,
157    );
158
159    #[derive(Clone, Debug, Hash, Eq, PartialEq)]
160    pub struct FakeAddressId<I: AssignedAddrIpExt>(Arc<AddrSubnet<I::Addr, I::AssignedWitness>>);
161
162    #[derive(Clone, Debug)]
163    pub struct FakeWeakAddressId<I: AssignedAddrIpExt>(
164        pub Weak<AddrSubnet<I::Addr, I::AssignedWitness>>,
165    );
166
167    impl<I: AssignedAddrIpExt> PartialEq for FakeWeakAddressId<I> {
168        fn eq(&self, other: &Self) -> bool {
169            let Self(lhs) = self;
170            let Self(rhs) = other;
171            Weak::ptr_eq(lhs, rhs)
172        }
173    }
174
175    impl<I: AssignedAddrIpExt> Eq for FakeWeakAddressId<I> {}
176
177    impl<I: AssignedAddrIpExt> Hash for FakeWeakAddressId<I> {
178        fn hash<H: Hasher>(&self, state: &mut H) {
179            let Self(this) = self;
180            this.as_ptr().hash(state)
181        }
182    }
183
184    impl<I: AssignedAddrIpExt> WeakIpAddressId<I::Addr> for FakeWeakAddressId<I> {
185        type Strong = FakeAddressId<I>;
186
187        fn upgrade(&self) -> Option<Self::Strong> {
188            let Self(inner) = self;
189            inner.upgrade().map(FakeAddressId)
190        }
191
192        fn is_assigned(&self) -> bool {
193            let Self(inner) = self;
194            inner.strong_count() != 0
195        }
196    }
197
198    impl<I: AssignedAddrIpExt> InspectableValue for FakeWeakAddressId<I> {
199        fn record<Inspector: netstack3_base::Inspector>(
200            &self,
201            _name: &str,
202            _inspector: &mut Inspector,
203        ) {
204            unimplemented!()
205        }
206    }
207
208    impl<I: AssignedAddrIpExt> Deref for FakeAddressId<I> {
209        type Target = AddrSubnet<I::Addr, I::AssignedWitness>;
210
211        fn deref(&self) -> &Self::Target {
212            let Self(inner) = self;
213            inner.deref()
214        }
215    }
216
217    impl<I: AssignedAddrIpExt> IpAddressId<I::Addr> for FakeAddressId<I> {
218        type Weak = FakeWeakAddressId<I>;
219
220        fn downgrade(&self) -> Self::Weak {
221            let Self(inner) = self;
222            FakeWeakAddressId(Arc::downgrade(inner))
223        }
224
225        fn addr(&self) -> IpDeviceAddr<I::Addr> {
226            let Self(inner) = self;
227
228            #[derive(GenericOverIp)]
229            #[generic_over_ip(I, Ip)]
230            struct WrapIn<I: AssignedAddrIpExt>(I::AssignedWitness);
231            I::map_ip(
232                WrapIn(inner.addr()),
233                |WrapIn(v4_addr)| IpDeviceAddr::new_from_witness(v4_addr),
234                |WrapIn(v6_addr)| IpDeviceAddr::new_from_ipv6_device_addr(v6_addr),
235            )
236        }
237
238        fn addr_sub(&self) -> AddrSubnet<I::Addr, I::AssignedWitness> {
239            let Self(inner) = self;
240            **inner
241        }
242    }
243
244    #[derive(Clone, Copy, Debug, PartialOrd, Ord, PartialEq, Eq, Hash)]
245    pub enum FakeDeviceClass {
246        Ethernet,
247        Wlan,
248    }
249
250    pub struct FakeCtx<I: TestIpExt> {
251        state: State<I, FakeWeakAddressId<I>, FakeBindingsCtx<I>>,
252        nat: FakeNatCtx<I>,
253    }
254
255    #[derive(Derivative)]
256    #[derivative(Default(bound = ""))]
257    pub struct FakeNatCtx<I: TestIpExt> {
258        pub(crate) device_addrs: HashMap<FakeDeviceId, FakePrimaryAddressId<I>>,
259    }
260
261    impl<I: TestIpExt> FakeCtx<I> {
262        pub fn new(bindings_ctx: &mut FakeBindingsCtx<I>) -> Self {
263            Self {
264                state: State {
265                    installed_routines: ValidRoutines::default(),
266                    uninstalled_routines: Vec::default(),
267                    conntrack: conntrack::Table::new::<IntoCoreTimerCtx>(bindings_ctx),
268                    nat_installed: OneWayBoolean::default(),
269                },
270                nat: FakeNatCtx::default(),
271            }
272        }
273
274        pub fn with_ip_routines(
275            bindings_ctx: &mut FakeBindingsCtx<I>,
276            routines: IpRoutines<I, FakeDeviceClass, ()>,
277        ) -> Self {
278            let (installed_routines, uninstalled_routines) =
279                ValidRoutines::new(Routines { ip: routines, ..Default::default() })
280                    .expect("invalid state");
281            Self {
282                state: State {
283                    installed_routines,
284                    uninstalled_routines,
285                    conntrack: conntrack::Table::new::<IntoCoreTimerCtx>(bindings_ctx),
286                    nat_installed: OneWayBoolean::default(),
287                },
288                nat: FakeNatCtx::default(),
289            }
290        }
291
292        pub fn with_nat_routines_and_device_addrs(
293            bindings_ctx: &mut FakeBindingsCtx<I>,
294            routines: NatRoutines<I, FakeDeviceClass, ()>,
295            device_addrs: impl IntoIterator<
296                Item = (FakeDeviceId, AddrSubnet<I::Addr, I::AssignedWitness>),
297            >,
298        ) -> Self {
299            let (installed_routines, uninstalled_routines) =
300                ValidRoutines::new(Routines { nat: routines, ..Default::default() })
301                    .expect("invalid state");
302            Self {
303                state: State {
304                    installed_routines,
305                    uninstalled_routines,
306                    conntrack: conntrack::Table::new::<IntoCoreTimerCtx>(bindings_ctx),
307                    nat_installed: OneWayBoolean::TRUE,
308                },
309                nat: FakeNatCtx {
310                    device_addrs: device_addrs
311                        .into_iter()
312                        .map(|(device, addr)| (device, FakePrimaryAddressId(Arc::new(addr))))
313                        .collect(),
314                },
315            }
316        }
317
318        pub fn conntrack(
319            &mut self,
320        ) -> &conntrack::Table<I, NatConfig<I, FakeWeakAddressId<I>>, FakeBindingsCtx<I>> {
321            &self.state.conntrack
322        }
323    }
324
325    impl<I: TestIpExt> DeviceIdContext<AnyDevice> for FakeCtx<I> {
326        type DeviceId = FakeDeviceId;
327        type WeakDeviceId = FakeWeakDeviceId<FakeDeviceId>;
328    }
329
330    impl<I: TestIpExt> IpDeviceAddressIdContext<I> for FakeCtx<I> {
331        type AddressId = FakeAddressId<I>;
332        type WeakAddressId = FakeWeakAddressId<I>;
333    }
334
335    impl<I: TestIpExt> FilterIpContext<I, FakeBindingsCtx<I>> for FakeCtx<I> {
336        type NatCtx<'a> = FakeNatCtx<I>;
337
338        fn with_filter_state_and_nat_ctx<
339            O,
340            F: FnOnce(&State<I, FakeWeakAddressId<I>, FakeBindingsCtx<I>>, &mut Self::NatCtx<'_>) -> O,
341        >(
342            &mut self,
343            cb: F,
344        ) -> O {
345            let Self { state, nat } = self;
346            cb(state, nat)
347        }
348    }
349
350    impl<I: TestIpExt> FakeNatCtx<I> {
351        pub fn new(
352            device_addrs: impl IntoIterator<
353                Item = (FakeDeviceId, AddrSubnet<I::Addr, I::AssignedWitness>),
354            >,
355        ) -> Self {
356            Self {
357                device_addrs: device_addrs
358                    .into_iter()
359                    .map(|(device, addr)| (device, FakePrimaryAddressId(Arc::new(addr))))
360                    .collect(),
361            }
362        }
363    }
364
365    impl<I: TestIpExt> DeviceIdContext<AnyDevice> for FakeNatCtx<I> {
366        type DeviceId = FakeDeviceId;
367        type WeakDeviceId = FakeWeakDeviceId<FakeDeviceId>;
368    }
369
370    impl<I: TestIpExt> IpDeviceAddressIdContext<I> for FakeNatCtx<I> {
371        type AddressId = FakeAddressId<I>;
372        type WeakAddressId = FakeWeakAddressId<I>;
373    }
374
375    impl<I: TestIpExt> NatContext<I, FakeBindingsCtx<I>> for FakeNatCtx<I> {
376        fn get_local_addr_for_remote(
377            &mut self,
378            device_id: &Self::DeviceId,
379            _remote: Option<SpecifiedAddr<I::Addr>>,
380        ) -> Option<Self::AddressId> {
381            let FakePrimaryAddressId(primary) = self.device_addrs.get(device_id)?;
382            Some(FakeAddressId(primary.clone()))
383        }
384
385        fn get_address_id(
386            &mut self,
387            device_id: &Self::DeviceId,
388            addr: IpDeviceAddr<I::Addr>,
389        ) -> Option<Self::AddressId> {
390            let FakePrimaryAddressId(id) = self.device_addrs.get(device_id)?;
391            let id = FakeAddressId(id.clone());
392            if id.addr() == addr {
393                Some(id)
394            } else {
395                None
396            }
397        }
398    }
399
400    pub struct FakeBindingsCtx<I: Ip> {
401        pub timer_ctx: FakeTimerCtx<FilterTimerId<I>>,
402        pub rng: FakeCryptoRng,
403    }
404
405    impl<I: Ip> FakeBindingsCtx<I> {
406        pub(crate) fn new() -> Self {
407            Self { timer_ctx: FakeTimerCtx::default(), rng: FakeCryptoRng::default() }
408        }
409
410        pub(crate) fn sleep(&mut self, time_elapsed: Duration) {
411            self.timer_ctx.instant.sleep(time_elapsed)
412        }
413    }
414
415    impl<I: Ip> InstantBindingsTypes for FakeBindingsCtx<I> {
416        type Instant = FakeInstant;
417        type AtomicInstant = FakeAtomicInstant;
418    }
419
420    impl<I: Ip> FilterBindingsTypes for FakeBindingsCtx<I> {
421        type DeviceClass = FakeDeviceClass;
422    }
423
424    impl<I: Ip> InstantContext for FakeBindingsCtx<I> {
425        fn now(&self) -> Self::Instant {
426            self.timer_ctx.now()
427        }
428    }
429
430    impl<I: Ip> TimerBindingsTypes for FakeBindingsCtx<I> {
431        type Timer = <FakeTimerCtx<FilterTimerId<I>> as TimerBindingsTypes>::Timer;
432        type DispatchId = <FakeTimerCtx<FilterTimerId<I>> as TimerBindingsTypes>::DispatchId;
433        type UniqueTimerId = <FakeTimerCtx<FilterTimerId<I>> as TimerBindingsTypes>::UniqueTimerId;
434    }
435
436    impl<I: Ip> TimerContext for FakeBindingsCtx<I> {
437        fn new_timer(&mut self, id: Self::DispatchId) -> Self::Timer {
438            self.timer_ctx.new_timer(id)
439        }
440
441        fn schedule_timer_instant(
442            &mut self,
443            time: Self::Instant,
444            timer: &mut Self::Timer,
445        ) -> Option<Self::Instant> {
446            self.timer_ctx.schedule_timer_instant(time, timer)
447        }
448
449        fn cancel_timer(&mut self, timer: &mut Self::Timer) -> Option<Self::Instant> {
450            self.timer_ctx.cancel_timer(timer)
451        }
452
453        fn scheduled_instant(&self, timer: &mut Self::Timer) -> Option<Self::Instant> {
454            self.timer_ctx.scheduled_instant(timer)
455        }
456
457        fn unique_timer_id(&self, timer: &Self::Timer) -> Self::UniqueTimerId {
458            self.timer_ctx.unique_timer_id(timer)
459        }
460    }
461
462    impl<I: Ip> WithFakeTimerContext<FilterTimerId<I>> for FakeBindingsCtx<I> {
463        fn with_fake_timer_ctx<O, F: FnOnce(&FakeTimerCtx<FilterTimerId<I>>) -> O>(
464            &self,
465            f: F,
466        ) -> O {
467            f(&self.timer_ctx)
468        }
469
470        fn with_fake_timer_ctx_mut<O, F: FnOnce(&mut FakeTimerCtx<FilterTimerId<I>>) -> O>(
471            &mut self,
472            f: F,
473        ) -> O {
474            f(&mut self.timer_ctx)
475        }
476    }
477
478    impl<I: Ip> RngContext for FakeBindingsCtx<I> {
479        type Rng<'a>
480            = FakeCryptoRng
481        where
482            Self: 'a;
483
484        fn rng(&mut self) -> Self::Rng<'_> {
485            self.rng.clone()
486        }
487    }
488}