netstack3_filter/
context.rs

1// Copyright 2024 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use core::fmt::Debug;
6
7use net_types::ip::{Ipv4, Ipv6};
8use net_types::SpecifiedAddr;
9use netstack3_base::{
10    InstantBindingsTypes, IpDeviceAddr, IpDeviceAddressIdContext, Marks, RngContext,
11    StrongDeviceIdentifier, TimerBindingsTypes, TimerContext, TxMetadataBindingsTypes,
12};
13
14use packet_formats::ip::IpExt;
15
16use crate::matchers::InterfaceProperties;
17use crate::state::State;
18use crate::{FilterIpExt, IpPacket};
19
20/// Trait defining required types for filtering provided by bindings.
21///
22/// Allows rules that match on device class to be installed, storing the
23/// [`FilterBindingsTypes::DeviceClass`] type at rest, while allowing Netstack3
24/// Core to have Bindings provide the type since it is platform-specific.
25pub trait FilterBindingsTypes: InstantBindingsTypes + TimerBindingsTypes + 'static {
26    /// The device class type for devices installed in the netstack.
27    type DeviceClass: Clone + Debug;
28}
29
30/// Trait aggregating functionality required from bindings.
31pub trait FilterBindingsContext: TimerContext + RngContext + FilterBindingsTypes {}
32impl<BC: TimerContext + RngContext + FilterBindingsTypes> FilterBindingsContext for BC {}
33
34/// The IP version-specific execution context for packet filtering.
35///
36/// This trait exists to abstract over access to the filtering state. It is
37/// useful to implement filtering logic in terms of this trait, as opposed to,
38/// for example, [`crate::logic::FilterHandler`] methods taking the state
39/// directly as an argument, because it allows Netstack3 Core to use lock
40/// ordering types to enforce that filtering state is only acquired at or before
41/// a given lock level, while keeping test code free of locking concerns.
42pub trait FilterIpContext<I: FilterIpExt, BT: FilterBindingsTypes>:
43    IpDeviceAddressIdContext<I, DeviceId: InterfaceProperties<BT::DeviceClass>>
44{
45    /// The execution context that allows the filtering engine to perform
46    /// Network Address Translation (NAT).
47    type NatCtx<'a>: NatContext<
48        I,
49        BT,
50        DeviceId = Self::DeviceId,
51        WeakAddressId = Self::WeakAddressId,
52    >;
53
54    /// Calls the function with a reference to filtering state.
55    fn with_filter_state<O, F: FnOnce(&State<I, Self::WeakAddressId, BT>) -> O>(
56        &mut self,
57        cb: F,
58    ) -> O {
59        self.with_filter_state_and_nat_ctx(|state, _ctx| cb(state))
60    }
61
62    /// Calls the function with a reference to filtering state and the NAT
63    /// context.
64    fn with_filter_state_and_nat_ctx<
65        O,
66        F: FnOnce(&State<I, Self::WeakAddressId, BT>, &mut Self::NatCtx<'_>) -> O,
67    >(
68        &mut self,
69        cb: F,
70    ) -> O;
71}
72
73/// The execution context for Network Address Translation (NAT).
74pub trait NatContext<I: IpExt, BT: FilterBindingsTypes>:
75    IpDeviceAddressIdContext<I, DeviceId: InterfaceProperties<BT::DeviceClass>>
76{
77    /// Returns the best local address for communicating with the remote.
78    fn get_local_addr_for_remote(
79        &mut self,
80        device_id: &Self::DeviceId,
81        remote: Option<SpecifiedAddr<I::Addr>>,
82    ) -> Option<Self::AddressId>;
83
84    /// Returns a strongly-held reference to the provided address, if it is assigned
85    /// to the specified device.
86    fn get_address_id(
87        &mut self,
88        device_id: &Self::DeviceId,
89        addr: IpDeviceAddr<I::Addr>,
90    ) -> Option<Self::AddressId>;
91}
92
93/// A context for mutably accessing all filtering state at once, to allow IPv4
94/// and IPv6 filtering state to be modified atomically.
95pub trait FilterContext<BT: FilterBindingsTypes>:
96    IpDeviceAddressIdContext<Ipv4, DeviceId: InterfaceProperties<BT::DeviceClass>>
97    + IpDeviceAddressIdContext<Ipv6, DeviceId: InterfaceProperties<BT::DeviceClass>>
98{
99    /// Calls the function with a mutable reference to all filtering state.
100    fn with_all_filter_state_mut<
101        O,
102        F: FnOnce(
103            &mut State<Ipv4, <Self as IpDeviceAddressIdContext<Ipv4>>::WeakAddressId, BT>,
104            &mut State<Ipv6, <Self as IpDeviceAddressIdContext<Ipv6>>::WeakAddressId, BT>,
105        ) -> O,
106    >(
107        &mut self,
108        cb: F,
109    ) -> O;
110}
111
/// Verdict produced by [`SocketOpsFilter::on_egress`] for one outgoing packet.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum SocketEgressFilterResult {
    /// The packet is sent as usual.
    Pass {
        /// Whether congestion should be reported to the upper-layer protocol.
        congestion: bool,
    },

    /// The packet is discarded.
    Drop {
        /// Whether congestion should be reported to the upper-layer protocol.
        congestion: bool,
    },
}
127
128/// Trait for a socket operations filter.
129pub trait SocketOpsFilter<D: StrongDeviceIdentifier, T> {
130    /// Called on every outgoing packet originated from a local socket.
131    fn on_egress<I: FilterIpExt, P: IpPacket<I>>(
132        &self,
133        packet: &P,
134        device: &D,
135        tx_metadata: &T,
136        marks: &Marks,
137    ) -> SocketEgressFilterResult;
138}
139
140/// Implemented by bindings to provide socket operations filtering.
141pub trait SocketOpsFilterBindingContext<D: StrongDeviceIdentifier>:
142    TxMetadataBindingsTypes
143{
144    /// Returns the filter that should be called for socket ops.
145    fn socket_ops_filter(&self) -> impl SocketOpsFilter<D, Self::TxMetadata>;
146}
147
#[cfg(feature = "testutils")]
impl<TimerId, Event, State, FrameMeta> FilterBindingsTypes
    for netstack3_base::testutil::FakeBindingsCtx<TimerId, Event, State, FrameMeta>
where
    TimerId: Debug + PartialEq + Clone + Send + Sync + 'static,
    Event: Debug + 'static,
    State: 'static,
    FrameMeta: 'static,
{
    /// These fake bindings use the unit type as their device class.
    type DeviceClass = ();
}
159
160#[cfg(test)]
161pub(crate) mod testutil {
162    use alloc::collections::HashMap;
163    use alloc::sync::{Arc, Weak};
164    use alloc::vec::Vec;
165    use core::hash::{Hash, Hasher};
166    use core::ops::Deref;
167    use core::time::Duration;
168
169    use derivative::Derivative;
170    use net_types::ip::{AddrSubnet, GenericOverIp, Ip};
171    use netstack3_base::testutil::{
172        FakeAtomicInstant, FakeCryptoRng, FakeInstant, FakeTimerCtx, FakeWeakDeviceId,
173        WithFakeTimerContext,
174    };
175    use netstack3_base::{
176        AnyDevice, AssignedAddrIpExt, DeviceIdContext, InspectableValue, InstantContext,
177        IntoCoreTimerCtx, IpAddressId, WeakIpAddressId,
178    };
179
180    use super::*;
181    use crate::conntrack;
182    use crate::logic::nat::NatConfig;
183    use crate::logic::FilterTimerId;
184    use crate::matchers::testutil::FakeDeviceId;
185    use crate::state::validation::ValidRoutines;
186    use crate::state::{IpRoutines, NatRoutines, OneWayBoolean, Routines};
187
188    pub trait TestIpExt: FilterIpExt + AssignedAddrIpExt {}
189
190    impl<I: FilterIpExt + AssignedAddrIpExt> TestIpExt for I {}
191
192    #[derive(Debug)]
193    pub struct FakePrimaryAddressId<I: AssignedAddrIpExt>(
194        pub Arc<AddrSubnet<I::Addr, I::AssignedWitness>>,
195    );
196
197    #[derive(Clone, Debug, Hash, Eq, PartialEq)]
198    pub struct FakeAddressId<I: AssignedAddrIpExt>(Arc<AddrSubnet<I::Addr, I::AssignedWitness>>);
199
200    #[derive(Clone, Debug)]
201    pub struct FakeWeakAddressId<I: AssignedAddrIpExt>(
202        pub Weak<AddrSubnet<I::Addr, I::AssignedWitness>>,
203    );
204
205    impl<I: AssignedAddrIpExt> PartialEq for FakeWeakAddressId<I> {
206        fn eq(&self, other: &Self) -> bool {
207            let Self(lhs) = self;
208            let Self(rhs) = other;
209            Weak::ptr_eq(lhs, rhs)
210        }
211    }
212
213    impl<I: AssignedAddrIpExt> Eq for FakeWeakAddressId<I> {}
214
215    impl<I: AssignedAddrIpExt> Hash for FakeWeakAddressId<I> {
216        fn hash<H: Hasher>(&self, state: &mut H) {
217            let Self(this) = self;
218            this.as_ptr().hash(state)
219        }
220    }
221
222    impl<I: AssignedAddrIpExt> WeakIpAddressId<I::Addr> for FakeWeakAddressId<I> {
223        type Strong = FakeAddressId<I>;
224
225        fn upgrade(&self) -> Option<Self::Strong> {
226            let Self(inner) = self;
227            inner.upgrade().map(FakeAddressId)
228        }
229
230        fn is_assigned(&self) -> bool {
231            let Self(inner) = self;
232            inner.strong_count() != 0
233        }
234    }
235
236    impl<I: AssignedAddrIpExt> InspectableValue for FakeWeakAddressId<I> {
237        fn record<Inspector: netstack3_base::Inspector>(
238            &self,
239            _name: &str,
240            _inspector: &mut Inspector,
241        ) {
242            unimplemented!()
243        }
244    }
245
246    impl<I: AssignedAddrIpExt> Deref for FakeAddressId<I> {
247        type Target = AddrSubnet<I::Addr, I::AssignedWitness>;
248
249        fn deref(&self) -> &Self::Target {
250            let Self(inner) = self;
251            inner.deref()
252        }
253    }
254
255    impl<I: AssignedAddrIpExt> IpAddressId<I::Addr> for FakeAddressId<I> {
256        type Weak = FakeWeakAddressId<I>;
257
258        fn downgrade(&self) -> Self::Weak {
259            let Self(inner) = self;
260            FakeWeakAddressId(Arc::downgrade(inner))
261        }
262
263        fn addr(&self) -> IpDeviceAddr<I::Addr> {
264            let Self(inner) = self;
265
266            #[derive(GenericOverIp)]
267            #[generic_over_ip(I, Ip)]
268            struct WrapIn<I: AssignedAddrIpExt>(I::AssignedWitness);
269            I::map_ip(
270                WrapIn(inner.addr()),
271                |WrapIn(v4_addr)| IpDeviceAddr::new_from_witness(v4_addr),
272                |WrapIn(v6_addr)| IpDeviceAddr::new_from_ipv6_device_addr(v6_addr),
273            )
274        }
275
276        fn addr_sub(&self) -> AddrSubnet<I::Addr, I::AssignedWitness> {
277            let Self(inner) = self;
278            **inner
279        }
280    }
281
282    #[derive(Clone, Copy, Debug, PartialOrd, Ord, PartialEq, Eq, Hash)]
283    pub enum FakeDeviceClass {
284        Ethernet,
285        Wlan,
286    }
287
288    pub struct FakeCtx<I: TestIpExt> {
289        state: State<I, FakeWeakAddressId<I>, FakeBindingsCtx<I>>,
290        nat: FakeNatCtx<I>,
291    }
292
293    #[derive(Derivative)]
294    #[derivative(Default(bound = ""))]
295    pub struct FakeNatCtx<I: TestIpExt> {
296        pub(crate) device_addrs: HashMap<FakeDeviceId, FakePrimaryAddressId<I>>,
297    }
298
299    impl<I: TestIpExt> FakeCtx<I> {
300        pub fn new(bindings_ctx: &mut FakeBindingsCtx<I>) -> Self {
301            Self {
302                state: State {
303                    installed_routines: ValidRoutines::default(),
304                    uninstalled_routines: Vec::default(),
305                    conntrack: conntrack::Table::new::<IntoCoreTimerCtx>(bindings_ctx),
306                    nat_installed: OneWayBoolean::default(),
307                },
308                nat: FakeNatCtx::default(),
309            }
310        }
311
312        pub fn with_ip_routines(
313            bindings_ctx: &mut FakeBindingsCtx<I>,
314            routines: IpRoutines<I, FakeDeviceClass, ()>,
315        ) -> Self {
316            let (installed_routines, uninstalled_routines) =
317                ValidRoutines::new(Routines { ip: routines, ..Default::default() })
318                    .expect("invalid state");
319            Self {
320                state: State {
321                    installed_routines,
322                    uninstalled_routines,
323                    conntrack: conntrack::Table::new::<IntoCoreTimerCtx>(bindings_ctx),
324                    nat_installed: OneWayBoolean::default(),
325                },
326                nat: FakeNatCtx::default(),
327            }
328        }
329
330        pub fn with_nat_routines_and_device_addrs(
331            bindings_ctx: &mut FakeBindingsCtx<I>,
332            routines: NatRoutines<I, FakeDeviceClass, ()>,
333            device_addrs: impl IntoIterator<
334                Item = (FakeDeviceId, AddrSubnet<I::Addr, I::AssignedWitness>),
335            >,
336        ) -> Self {
337            let (installed_routines, uninstalled_routines) =
338                ValidRoutines::new(Routines { nat: routines, ..Default::default() })
339                    .expect("invalid state");
340            Self {
341                state: State {
342                    installed_routines,
343                    uninstalled_routines,
344                    conntrack: conntrack::Table::new::<IntoCoreTimerCtx>(bindings_ctx),
345                    nat_installed: OneWayBoolean::TRUE,
346                },
347                nat: FakeNatCtx {
348                    device_addrs: device_addrs
349                        .into_iter()
350                        .map(|(device, addr)| (device, FakePrimaryAddressId(Arc::new(addr))))
351                        .collect(),
352                },
353            }
354        }
355
356        pub fn conntrack(
357            &mut self,
358        ) -> &conntrack::Table<I, NatConfig<I, FakeWeakAddressId<I>>, FakeBindingsCtx<I>> {
359            &self.state.conntrack
360        }
361    }
362
363    impl<I: TestIpExt> DeviceIdContext<AnyDevice> for FakeCtx<I> {
364        type DeviceId = FakeDeviceId;
365        type WeakDeviceId = FakeWeakDeviceId<FakeDeviceId>;
366    }
367
368    impl<I: TestIpExt> IpDeviceAddressIdContext<I> for FakeCtx<I> {
369        type AddressId = FakeAddressId<I>;
370        type WeakAddressId = FakeWeakAddressId<I>;
371    }
372
373    impl<I: TestIpExt> FilterIpContext<I, FakeBindingsCtx<I>> for FakeCtx<I> {
374        type NatCtx<'a> = FakeNatCtx<I>;
375
376        fn with_filter_state_and_nat_ctx<
377            O,
378            F: FnOnce(&State<I, FakeWeakAddressId<I>, FakeBindingsCtx<I>>, &mut Self::NatCtx<'_>) -> O,
379        >(
380            &mut self,
381            cb: F,
382        ) -> O {
383            let Self { state, nat } = self;
384            cb(state, nat)
385        }
386    }
387
388    impl<I: TestIpExt> FakeNatCtx<I> {
389        pub fn new(
390            device_addrs: impl IntoIterator<
391                Item = (FakeDeviceId, AddrSubnet<I::Addr, I::AssignedWitness>),
392            >,
393        ) -> Self {
394            Self {
395                device_addrs: device_addrs
396                    .into_iter()
397                    .map(|(device, addr)| (device, FakePrimaryAddressId(Arc::new(addr))))
398                    .collect(),
399            }
400        }
401    }
402
403    impl<I: TestIpExt> DeviceIdContext<AnyDevice> for FakeNatCtx<I> {
404        type DeviceId = FakeDeviceId;
405        type WeakDeviceId = FakeWeakDeviceId<FakeDeviceId>;
406    }
407
408    impl<I: TestIpExt> IpDeviceAddressIdContext<I> for FakeNatCtx<I> {
409        type AddressId = FakeAddressId<I>;
410        type WeakAddressId = FakeWeakAddressId<I>;
411    }
412
413    impl<I: TestIpExt> NatContext<I, FakeBindingsCtx<I>> for FakeNatCtx<I> {
414        fn get_local_addr_for_remote(
415            &mut self,
416            device_id: &Self::DeviceId,
417            _remote: Option<SpecifiedAddr<I::Addr>>,
418        ) -> Option<Self::AddressId> {
419            let FakePrimaryAddressId(primary) = self.device_addrs.get(device_id)?;
420            Some(FakeAddressId(primary.clone()))
421        }
422
423        fn get_address_id(
424            &mut self,
425            device_id: &Self::DeviceId,
426            addr: IpDeviceAddr<I::Addr>,
427        ) -> Option<Self::AddressId> {
428            let FakePrimaryAddressId(id) = self.device_addrs.get(device_id)?;
429            let id = FakeAddressId(id.clone());
430            if id.addr() == addr {
431                Some(id)
432            } else {
433                None
434            }
435        }
436    }
437
438    pub struct FakeBindingsCtx<I: Ip> {
439        pub timer_ctx: FakeTimerCtx<FilterTimerId<I>>,
440        pub rng: FakeCryptoRng,
441    }
442
443    impl<I: Ip> FakeBindingsCtx<I> {
444        pub(crate) fn new() -> Self {
445            Self { timer_ctx: FakeTimerCtx::default(), rng: FakeCryptoRng::default() }
446        }
447
448        pub(crate) fn sleep(&mut self, time_elapsed: Duration) {
449            self.timer_ctx.instant.sleep(time_elapsed)
450        }
451    }
452
453    impl<I: Ip> InstantBindingsTypes for FakeBindingsCtx<I> {
454        type Instant = FakeInstant;
455        type AtomicInstant = FakeAtomicInstant;
456    }
457
458    impl<I: Ip> FilterBindingsTypes for FakeBindingsCtx<I> {
459        type DeviceClass = FakeDeviceClass;
460    }
461
462    impl<I: Ip> InstantContext for FakeBindingsCtx<I> {
463        fn now(&self) -> Self::Instant {
464            self.timer_ctx.now()
465        }
466    }
467
468    impl<I: Ip> TimerBindingsTypes for FakeBindingsCtx<I> {
469        type Timer = <FakeTimerCtx<FilterTimerId<I>> as TimerBindingsTypes>::Timer;
470        type DispatchId = <FakeTimerCtx<FilterTimerId<I>> as TimerBindingsTypes>::DispatchId;
471        type UniqueTimerId = <FakeTimerCtx<FilterTimerId<I>> as TimerBindingsTypes>::UniqueTimerId;
472    }
473
474    impl<I: Ip> TimerContext for FakeBindingsCtx<I> {
475        fn new_timer(&mut self, id: Self::DispatchId) -> Self::Timer {
476            self.timer_ctx.new_timer(id)
477        }
478
479        fn schedule_timer_instant(
480            &mut self,
481            time: Self::Instant,
482            timer: &mut Self::Timer,
483        ) -> Option<Self::Instant> {
484            self.timer_ctx.schedule_timer_instant(time, timer)
485        }
486
487        fn cancel_timer(&mut self, timer: &mut Self::Timer) -> Option<Self::Instant> {
488            self.timer_ctx.cancel_timer(timer)
489        }
490
491        fn scheduled_instant(&self, timer: &mut Self::Timer) -> Option<Self::Instant> {
492            self.timer_ctx.scheduled_instant(timer)
493        }
494
495        fn unique_timer_id(&self, timer: &Self::Timer) -> Self::UniqueTimerId {
496            self.timer_ctx.unique_timer_id(timer)
497        }
498    }
499
500    impl<I: Ip> WithFakeTimerContext<FilterTimerId<I>> for FakeBindingsCtx<I> {
501        fn with_fake_timer_ctx<O, F: FnOnce(&FakeTimerCtx<FilterTimerId<I>>) -> O>(
502            &self,
503            f: F,
504        ) -> O {
505            f(&self.timer_ctx)
506        }
507
508        fn with_fake_timer_ctx_mut<O, F: FnOnce(&mut FakeTimerCtx<FilterTimerId<I>>) -> O>(
509            &mut self,
510            f: F,
511        ) -> O {
512            f(&mut self.timer_ctx)
513        }
514    }
515
516    impl<I: Ip> RngContext for FakeBindingsCtx<I> {
517        type Rng<'a>
518            = FakeCryptoRng
519        where
520            Self: 'a;
521
522        fn rng(&mut self) -> Self::Rng<'_> {
523            self.rng.clone()
524        }
525    }
526}