netstack3_filter/
context.rs

1// Copyright 2024 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use core::fmt::Debug;
6
7use net_types::ip::{IpVersion, Ipv4, Ipv6};
8use net_types::SpecifiedAddr;
9use netstack3_base::socket::SocketCookie;
10use netstack3_base::{
11    InstantBindingsTypes, IpDeviceAddr, IpDeviceAddressIdContext, Marks, RngContext,
12    StrongDeviceIdentifier, TimerBindingsTypes, TimerContext, TxMetadataBindingsTypes,
13};
14use packet::{FragmentedByteSlice, PartialSerializer};
15use packet_formats::ip::IpExt;
16
17use crate::matchers::InterfaceProperties;
18use crate::state::State;
19use crate::{FilterIpExt, IpPacket};
20
21/// Trait defining required types for filtering provided by bindings.
22///
23/// Allows rules that match on device class to be installed, storing the
24/// [`FilterBindingsTypes::DeviceClass`] type at rest, while allowing Netstack3
25/// Core to have Bindings provide the type since it is platform-specific.
26pub trait FilterBindingsTypes: InstantBindingsTypes + TimerBindingsTypes + 'static {
27    /// The device class type for devices installed in the netstack.
28    type DeviceClass: Clone + Debug;
29}
30
31/// Trait aggregating functionality required from bindings.
32pub trait FilterBindingsContext: TimerContext + RngContext + FilterBindingsTypes {}
33impl<BC: TimerContext + RngContext + FilterBindingsTypes> FilterBindingsContext for BC {}
34
35/// The IP version-specific execution context for packet filtering.
36///
37/// This trait exists to abstract over access to the filtering state. It is
38/// useful to implement filtering logic in terms of this trait, as opposed to,
39/// for example, [`crate::logic::FilterHandler`] methods taking the state
40/// directly as an argument, because it allows Netstack3 Core to use lock
41/// ordering types to enforce that filtering state is only acquired at or before
42/// a given lock level, while keeping test code free of locking concerns.
43pub trait FilterIpContext<I: FilterIpExt, BT: FilterBindingsTypes>:
44    IpDeviceAddressIdContext<I, DeviceId: InterfaceProperties<BT::DeviceClass>>
45{
46    /// The execution context that allows the filtering engine to perform
47    /// Network Address Translation (NAT).
48    type NatCtx<'a>: NatContext<
49        I,
50        BT,
51        DeviceId = Self::DeviceId,
52        WeakAddressId = Self::WeakAddressId,
53    >;
54
55    /// Calls the function with a reference to filtering state.
56    fn with_filter_state<O, F: FnOnce(&State<I, Self::WeakAddressId, BT>) -> O>(
57        &mut self,
58        cb: F,
59    ) -> O {
60        self.with_filter_state_and_nat_ctx(|state, _ctx| cb(state))
61    }
62
63    /// Calls the function with a reference to filtering state and the NAT
64    /// context.
65    fn with_filter_state_and_nat_ctx<
66        O,
67        F: FnOnce(&State<I, Self::WeakAddressId, BT>, &mut Self::NatCtx<'_>) -> O,
68    >(
69        &mut self,
70        cb: F,
71    ) -> O;
72}
73
74/// The execution context for Network Address Translation (NAT).
75pub trait NatContext<I: IpExt, BT: FilterBindingsTypes>:
76    IpDeviceAddressIdContext<I, DeviceId: InterfaceProperties<BT::DeviceClass>>
77{
78    /// Returns the best local address for communicating with the remote.
79    fn get_local_addr_for_remote(
80        &mut self,
81        device_id: &Self::DeviceId,
82        remote: Option<SpecifiedAddr<I::Addr>>,
83    ) -> Option<Self::AddressId>;
84
85    /// Returns a strongly-held reference to the provided address, if it is assigned
86    /// to the specified device.
87    fn get_address_id(
88        &mut self,
89        device_id: &Self::DeviceId,
90        addr: IpDeviceAddr<I::Addr>,
91    ) -> Option<Self::AddressId>;
92}
93
94/// A context for mutably accessing all filtering state at once, to allow IPv4
95/// and IPv6 filtering state to be modified atomically.
96pub trait FilterContext<BT: FilterBindingsTypes>:
97    IpDeviceAddressIdContext<Ipv4, DeviceId: InterfaceProperties<BT::DeviceClass>>
98    + IpDeviceAddressIdContext<Ipv6, DeviceId: InterfaceProperties<BT::DeviceClass>>
99{
100    /// Calls the function with a mutable reference to all filtering state.
101    fn with_all_filter_state_mut<
102        O,
103        F: FnOnce(
104            &mut State<Ipv4, <Self as IpDeviceAddressIdContext<Ipv4>>::WeakAddressId, BT>,
105            &mut State<Ipv6, <Self as IpDeviceAddressIdContext<Ipv6>>::WeakAddressId, BT>,
106        ) -> O,
107    >(
108        &mut self,
109        cb: F,
110    ) -> O;
111}
112
/// The verdict produced by [`SocketOpsFilter::on_egress`].
///
/// Both verdicts carry a `congestion` flag, so congestion can be reported
/// whether or not the packet is actually sent.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum SocketEgressFilterResult {
    /// The packet should be sent normally.
    Pass {
        /// When `true`, congestion should be signaled to the higher level
        /// protocol.
        congestion: bool,
    },

    /// The packet should be dropped.
    Drop {
        /// When `true`, congestion should be signaled to the higher level
        /// protocol.
        congestion: bool,
    },
}
128
/// The verdict produced by [`SocketOpsFilter::on_ingress`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum SocketIngressFilterResult {
    /// The packet should be accepted.
    Accept,

    /// The packet should be dropped.
    Drop,
}
138
139/// Trait for a socket operations filter.
140pub trait SocketOpsFilter<D: StrongDeviceIdentifier> {
141    /// Called on every outgoing packet originated from a local socket.
142    fn on_egress<I: FilterIpExt, P: IpPacket<I> + PartialSerializer>(
143        &self,
144        packet: &P,
145        device: &D,
146        cookie: SocketCookie,
147        marks: &Marks,
148    ) -> SocketEgressFilterResult;
149
150    /// Called on every incoming packet handled by a local socket.
151    fn on_ingress(
152        &self,
153        ip_version: IpVersion,
154        packet: FragmentedByteSlice<'_, &[u8]>,
155        device: &D,
156        cookie: SocketCookie,
157        marks: &Marks,
158    ) -> SocketIngressFilterResult;
159}
160
161/// Implemented by bindings to provide socket operations filtering.
162pub trait SocketOpsFilterBindingContext<D: StrongDeviceIdentifier>:
163    TxMetadataBindingsTypes
164{
165    /// Returns the filter that should be called for socket ops.
166    fn socket_ops_filter(&self) -> impl SocketOpsFilter<D>;
167}
168
// The shared fake bindings context uses the unit type as its device class.
#[cfg(any(test, feature = "testutils"))]
impl<Id, Ev, St, Meta> FilterBindingsTypes
    for netstack3_base::testutil::FakeBindingsCtx<Id, Ev, St, Meta>
where
    Id: Debug + PartialEq + Clone + Send + Sync + 'static,
    Ev: Debug + 'static,
    St: 'static,
    Meta: 'static,
{
    type DeviceClass = ();
}
180
// The shared fake bindings context hands out the no-op socket ops filter from
// this crate's test utilities.
#[cfg(any(test, feature = "testutils"))]
impl<Id, Ev, St, Meta, D> SocketOpsFilterBindingContext<D>
    for netstack3_base::testutil::FakeBindingsCtx<Id, Ev, St, Meta>
where
    Id: Debug + PartialEq + Clone + Send + Sync + 'static,
    Ev: Debug + 'static,
    St: 'static,
    Meta: 'static,
    D: StrongDeviceIdentifier,
{
    fn socket_ops_filter(&self) -> impl SocketOpsFilter<D> {
        crate::testutil::NoOpSocketOpsFilter
    }
}
195
/// Fake contexts and identifier types used by this crate's unit tests.
#[cfg(test)]
pub(crate) mod testutil {
    use alloc::sync::{Arc, Weak};
    use alloc::vec::Vec;
    use core::hash::{Hash, Hasher};
    use core::ops::Deref;
    use core::time::Duration;

    use derivative::Derivative;
    use net_types::ip::{AddrSubnet, GenericOverIp, Ip};
    use netstack3_base::testutil::{
        FakeAtomicInstant, FakeCryptoRng, FakeInstant, FakeTimerCtx, FakeWeakDeviceId,
        WithFakeTimerContext,
    };
    use netstack3_base::{
        AnyDevice, AssignedAddrIpExt, DeviceIdContext, InspectableValue, InstantContext,
        IntoCoreTimerCtx, IpAddressId, WeakIpAddressId,
    };
    use netstack3_hashmap::HashMap;

    use super::*;
    use crate::conntrack;
    use crate::logic::nat::NatConfig;
    use crate::logic::FilterTimerId;
    use crate::matchers::testutil::FakeDeviceId;
    use crate::state::validation::ValidRoutines;
    use crate::state::{IpRoutines, NatRoutines, OneWayBoolean, Routines};

    /// Shorthand for the IP extension traits required by the fakes below.
    pub trait TestIpExt: FilterIpExt + AssignedAddrIpExt {}

    impl<I: FilterIpExt + AssignedAddrIpExt> TestIpExt for I {}

    /// The owning reference to a fake address, as stored per device in
    /// [`FakeNatCtx`].
    #[derive(Debug)]
    pub struct FakePrimaryAddressId<I: AssignedAddrIpExt>(
        pub Arc<AddrSubnet<I::Addr, I::AssignedWitness>>,
    );

    /// A strong reference to a fake address.
    #[derive(Clone, Debug, Hash, Eq, PartialEq)]
    pub struct FakeAddressId<I: AssignedAddrIpExt>(Arc<AddrSubnet<I::Addr, I::AssignedWitness>>);

    /// A weak reference to a fake address.
    ///
    /// Equality and hashing are both based on the identity of the pointed-to
    /// allocation rather than on the address value.
    #[derive(Clone, Debug)]
    pub struct FakeWeakAddressId<I: AssignedAddrIpExt>(
        pub Weak<AddrSubnet<I::Addr, I::AssignedWitness>>,
    );

    impl<I: AssignedAddrIpExt> PartialEq for FakeWeakAddressId<I> {
        fn eq(&self, other: &Self) -> bool {
            let Self(lhs) = self;
            let Self(rhs) = other;
            Weak::ptr_eq(lhs, rhs)
        }
    }

    impl<I: AssignedAddrIpExt> Eq for FakeWeakAddressId<I> {}

    impl<I: AssignedAddrIpExt> Hash for FakeWeakAddressId<I> {
        fn hash<H: Hasher>(&self, state: &mut H) {
            // Hash the pointer so that `Hash` agrees with the pointer-based
            // `PartialEq` implementation above.
            let Self(this) = self;
            this.as_ptr().hash(state)
        }
    }

    impl<I: AssignedAddrIpExt> WeakIpAddressId<I::Addr> for FakeWeakAddressId<I> {
        type Strong = FakeAddressId<I>;

        fn upgrade(&self) -> Option<Self::Strong> {
            let Self(inner) = self;
            inner.upgrade().map(FakeAddressId)
        }

        fn is_assigned(&self) -> bool {
            // The fake treats an address as assigned for as long as at least
            // one strong reference to it is alive.
            let Self(inner) = self;
            inner.strong_count() != 0
        }
    }

    impl<I: AssignedAddrIpExt> InspectableValue for FakeWeakAddressId<I> {
        /// Not supported by this fake; calling it panics.
        fn record<Inspector: netstack3_base::Inspector>(
            &self,
            _name: &str,
            _inspector: &mut Inspector,
        ) {
            unimplemented!()
        }
    }

    impl<I: AssignedAddrIpExt> Deref for FakeAddressId<I> {
        type Target = AddrSubnet<I::Addr, I::AssignedWitness>;

        fn deref(&self) -> &Self::Target {
            let Self(inner) = self;
            inner.deref()
        }
    }

    impl<I: AssignedAddrIpExt> IpAddressId<I::Addr> for FakeAddressId<I> {
        type Weak = FakeWeakAddressId<I>;

        fn downgrade(&self) -> Self::Weak {
            let Self(inner) = self;
            FakeWeakAddressId(Arc::downgrade(inner))
        }

        fn addr(&self) -> IpDeviceAddr<I::Addr> {
            let Self(inner) = self;

            // The assigned-address witness type differs per IP version, so the
            // conversion to `IpDeviceAddr` is dispatched on `I`.
            #[derive(GenericOverIp)]
            #[generic_over_ip(I, Ip)]
            struct WrapIn<I: AssignedAddrIpExt>(I::AssignedWitness);
            I::map_ip(
                WrapIn(inner.addr()),
                |WrapIn(v4_addr)| IpDeviceAddr::new_from_witness(v4_addr),
                |WrapIn(v6_addr)| IpDeviceAddr::new_from_ipv6_device_addr(v6_addr),
            )
        }

        fn addr_sub(&self) -> AddrSubnet<I::Addr, I::AssignedWitness> {
            let Self(inner) = self;
            **inner
        }
    }

    /// The device class type provided by [`FakeBindingsCtx`].
    #[derive(Clone, Copy, Debug, PartialOrd, Ord, PartialEq, Eq, Hash)]
    pub enum FakeDeviceClass {
        Ethernet,
        Wlan,
    }

    /// A fake [`FilterIpContext`] that owns the filtering state and a
    /// [`FakeNatCtx`].
    pub struct FakeCtx<I: TestIpExt> {
        state: State<I, FakeWeakAddressId<I>, FakeBindingsCtx<I>>,
        nat: FakeNatCtx<I>,
    }

    /// A fake [`NatContext`] backed by a map from device to its address.
    #[derive(Derivative)]
    #[derivative(Default(bound = ""))]
    pub struct FakeNatCtx<I: TestIpExt> {
        pub(crate) device_addrs: HashMap<FakeDeviceId, FakePrimaryAddressId<I>>,
    }

    impl<I: TestIpExt> FakeCtx<I> {
        /// Creates a context with default routines, a fresh conntrack table,
        /// and no device addresses.
        pub fn new(bindings_ctx: &mut FakeBindingsCtx<I>) -> Self {
            Self {
                state: State {
                    installed_routines: ValidRoutines::default(),
                    uninstalled_routines: Vec::default(),
                    conntrack: conntrack::Table::new::<IntoCoreTimerCtx>(bindings_ctx),
                    nat_installed: OneWayBoolean::default(),
                },
                nat: FakeNatCtx::default(),
            }
        }

        /// Creates a context with the provided IP routines installed and no
        /// device addresses.
        ///
        /// # Panics
        ///
        /// Panics if `routines` fails validation.
        pub fn with_ip_routines(
            bindings_ctx: &mut FakeBindingsCtx<I>,
            routines: IpRoutines<I, FakeDeviceClass, ()>,
        ) -> Self {
            let (installed_routines, uninstalled_routines) =
                ValidRoutines::new(Routines { ip: routines, ..Default::default() })
                    .expect("invalid state");
            Self {
                state: State {
                    installed_routines,
                    uninstalled_routines,
                    conntrack: conntrack::Table::new::<IntoCoreTimerCtx>(bindings_ctx),
                    nat_installed: OneWayBoolean::default(),
                },
                nat: FakeNatCtx::default(),
            }
        }

        /// Creates a context with the provided NAT routines installed, NAT
        /// marked as installed, and the given per-device addresses.
        ///
        /// # Panics
        ///
        /// Panics if `routines` fails validation.
        pub fn with_nat_routines_and_device_addrs(
            bindings_ctx: &mut FakeBindingsCtx<I>,
            routines: NatRoutines<I, FakeDeviceClass, ()>,
            device_addrs: impl IntoIterator<
                Item = (FakeDeviceId, AddrSubnet<I::Addr, I::AssignedWitness>),
            >,
        ) -> Self {
            let (installed_routines, uninstalled_routines) =
                ValidRoutines::new(Routines { nat: routines, ..Default::default() })
                    .expect("invalid state");
            Self {
                state: State {
                    installed_routines,
                    uninstalled_routines,
                    conntrack: conntrack::Table::new::<IntoCoreTimerCtx>(bindings_ctx),
                    nat_installed: OneWayBoolean::TRUE,
                },
                // Reuse the dedicated constructor rather than duplicating the
                // map-building logic here.
                nat: FakeNatCtx::new(device_addrs),
            }
        }

        /// Returns a reference to the connection tracking table.
        pub fn conntrack(
            &mut self,
        ) -> &conntrack::Table<I, NatConfig<I, FakeWeakAddressId<I>>, FakeBindingsCtx<I>> {
            &self.state.conntrack
        }
    }

    impl<I: TestIpExt> DeviceIdContext<AnyDevice> for FakeCtx<I> {
        type DeviceId = FakeDeviceId;
        type WeakDeviceId = FakeWeakDeviceId<FakeDeviceId>;
    }

    impl<I: TestIpExt> IpDeviceAddressIdContext<I> for FakeCtx<I> {
        type AddressId = FakeAddressId<I>;
        type WeakAddressId = FakeWeakAddressId<I>;
    }

    impl<I: TestIpExt> FilterIpContext<I, FakeBindingsCtx<I>> for FakeCtx<I> {
        type NatCtx<'a> = FakeNatCtx<I>;

        fn with_filter_state_and_nat_ctx<
            O,
            F: FnOnce(&State<I, FakeWeakAddressId<I>, FakeBindingsCtx<I>>, &mut Self::NatCtx<'_>) -> O,
        >(
            &mut self,
            cb: F,
        ) -> O {
            let Self { state, nat } = self;
            cb(state, nat)
        }
    }

    impl<I: TestIpExt> FakeNatCtx<I> {
        /// Creates a NAT context in which each listed device owns the address
        /// paired with it.
        pub fn new(
            device_addrs: impl IntoIterator<
                Item = (FakeDeviceId, AddrSubnet<I::Addr, I::AssignedWitness>),
            >,
        ) -> Self {
            Self {
                device_addrs: device_addrs
                    .into_iter()
                    .map(|(device, addr)| (device, FakePrimaryAddressId(Arc::new(addr))))
                    .collect(),
            }
        }
    }

    impl<I: TestIpExt> DeviceIdContext<AnyDevice> for FakeNatCtx<I> {
        type DeviceId = FakeDeviceId;
        type WeakDeviceId = FakeWeakDeviceId<FakeDeviceId>;
    }

    impl<I: TestIpExt> IpDeviceAddressIdContext<I> for FakeNatCtx<I> {
        type AddressId = FakeAddressId<I>;
        type WeakAddressId = FakeWeakAddressId<I>;
    }

    impl<I: TestIpExt> NatContext<I, FakeBindingsCtx<I>> for FakeNatCtx<I> {
        /// Returns the address stored for `device_id`, ignoring the remote.
        fn get_local_addr_for_remote(
            &mut self,
            device_id: &Self::DeviceId,
            _remote: Option<SpecifiedAddr<I::Addr>>,
        ) -> Option<Self::AddressId> {
            let FakePrimaryAddressId(primary) = self.device_addrs.get(device_id)?;
            Some(FakeAddressId(primary.clone()))
        }

        /// Returns the address stored for `device_id` only if it equals `addr`.
        fn get_address_id(
            &mut self,
            device_id: &Self::DeviceId,
            addr: IpDeviceAddr<I::Addr>,
        ) -> Option<Self::AddressId> {
            let FakePrimaryAddressId(primary) = self.device_addrs.get(device_id)?;
            let candidate = FakeAddressId(primary.clone());
            (candidate.addr() == addr).then_some(candidate)
        }
    }

    /// A fake bindings context providing timers and a random number generator.
    pub struct FakeBindingsCtx<I: Ip> {
        pub timer_ctx: FakeTimerCtx<FilterTimerId<I>>,
        pub rng: FakeCryptoRng,
    }

    impl<I: Ip> FakeBindingsCtx<I> {
        /// Creates a bindings context with a default timer context and RNG.
        pub(crate) fn new() -> Self {
            Self { timer_ctx: FakeTimerCtx::default(), rng: FakeCryptoRng::default() }
        }

        /// Forwards `time_elapsed` to the fake clock's `sleep`.
        pub(crate) fn sleep(&mut self, time_elapsed: Duration) {
            self.timer_ctx.instant.sleep(time_elapsed)
        }
    }

    impl<I: Ip> InstantBindingsTypes for FakeBindingsCtx<I> {
        type Instant = FakeInstant;
        type AtomicInstant = FakeAtomicInstant;
    }

    impl<I: Ip> FilterBindingsTypes for FakeBindingsCtx<I> {
        type DeviceClass = FakeDeviceClass;
    }

    impl<I: Ip> InstantContext for FakeBindingsCtx<I> {
        fn now(&self) -> Self::Instant {
            self.timer_ctx.now()
        }
    }

    // All timer types and operations are delegated to the inner
    // `FakeTimerCtx`.
    impl<I: Ip> TimerBindingsTypes for FakeBindingsCtx<I> {
        type Timer = <FakeTimerCtx<FilterTimerId<I>> as TimerBindingsTypes>::Timer;
        type DispatchId = <FakeTimerCtx<FilterTimerId<I>> as TimerBindingsTypes>::DispatchId;
        type UniqueTimerId = <FakeTimerCtx<FilterTimerId<I>> as TimerBindingsTypes>::UniqueTimerId;
    }

    impl<I: Ip> TimerContext for FakeBindingsCtx<I> {
        fn new_timer(&mut self, id: Self::DispatchId) -> Self::Timer {
            self.timer_ctx.new_timer(id)
        }

        fn schedule_timer_instant(
            &mut self,
            time: Self::Instant,
            timer: &mut Self::Timer,
        ) -> Option<Self::Instant> {
            self.timer_ctx.schedule_timer_instant(time, timer)
        }

        fn cancel_timer(&mut self, timer: &mut Self::Timer) -> Option<Self::Instant> {
            self.timer_ctx.cancel_timer(timer)
        }

        fn scheduled_instant(&self, timer: &mut Self::Timer) -> Option<Self::Instant> {
            self.timer_ctx.scheduled_instant(timer)
        }

        fn unique_timer_id(&self, timer: &Self::Timer) -> Self::UniqueTimerId {
            self.timer_ctx.unique_timer_id(timer)
        }
    }

    impl<I: Ip> WithFakeTimerContext<FilterTimerId<I>> for FakeBindingsCtx<I> {
        fn with_fake_timer_ctx<O, F: FnOnce(&FakeTimerCtx<FilterTimerId<I>>) -> O>(
            &self,
            f: F,
        ) -> O {
            f(&self.timer_ctx)
        }

        fn with_fake_timer_ctx_mut<O, F: FnOnce(&mut FakeTimerCtx<FilterTimerId<I>>) -> O>(
            &mut self,
            f: F,
        ) -> O {
            f(&mut self.timer_ctx)
        }
    }

    impl<I: Ip> RngContext for FakeBindingsCtx<I> {
        type Rng<'a>
            = FakeCryptoRng
        where
            Self: 'a;

        fn rng(&mut self) -> Self::Rng<'_> {
            self.rng.clone()
        }
    }
}