netstack3_filter/
context.rs

1// Copyright 2024 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use core::fmt::Debug;
6
7use net_types::SpecifiedAddr;
8use net_types::ip::{IpVersion, Ipv4, Ipv6};
9use netstack3_base::socket::SocketCookie;
10use netstack3_base::{
11    InstantBindingsTypes, InterfaceProperties, IpDeviceAddr, IpDeviceAddressIdContext, Marks,
12    MatcherBindingsTypes, RngContext, TimerBindingsTypes, TimerContext, TxMetadataBindingsTypes,
13};
14use packet::FragmentedByteSlice;
15use packet_formats::ip::IpExt;
16
17use crate::FilterIpExt;
18use crate::matchers::BindingsPacketMatcher;
19use crate::packets::FilterIpPacket;
20use crate::state::State;
21
/// Trait defining required types for filtering provided by bindings.
///
/// This trait declares no items of its own; it only bundles the
/// [`InstantBindingsTypes`], [`MatcherBindingsTypes`] and
/// [`TimerBindingsTypes`] bounds (plus `'static`) under a single name.
pub trait FilterBindingsTypes:
    InstantBindingsTypes + MatcherBindingsTypes + TimerBindingsTypes + 'static
{
}

// Blanket impl: any type that satisfies the bundled bounds is automatically a
// `FilterBindingsTypes`, so the trait never needs to be implemented by hand.
impl<BT: InstantBindingsTypes + MatcherBindingsTypes + TimerBindingsTypes + 'static>
    FilterBindingsTypes for BT
{
}
32
/// Trait aggregating functionality required from bindings.
///
/// `D` is the type parameter handed to [`BindingsPacketMatcher`]: the
/// bindings-provided packet matcher must implement `BindingsPacketMatcher<D>`.
/// On top of that, timer and RNG support are required.
pub trait FilterBindingsContext<D>:
    TimerContext + RngContext + FilterBindingsTypes<BindingsPacketMatcher: BindingsPacketMatcher<D>>
{
}
// Blanket impl: the trait is implemented for every bindings context that meets
// the bounds above, for each `D` its packet matcher supports.
impl<D, BC> FilterBindingsContext<D> for BC
where
    BC: TimerContext + RngContext + FilterBindingsTypes,
    BC::BindingsPacketMatcher: BindingsPacketMatcher<D>,
{
}
44
/// The IP version-specific execution context for packet filtering.
///
/// This trait exists to abstract over access to the filtering state. It is
/// useful to implement filtering logic in terms of this trait, as opposed to,
/// for example, [`crate::logic::FilterHandler`] methods taking the state
/// directly as an argument, because it allows Netstack3 Core to use lock
/// ordering types to enforce that filtering state is only acquired at or before
/// a given lock level, while keeping test code free of locking concerns.
///
/// Implementors only need to provide
/// [`FilterIpContext::with_filter_state_and_nat_ctx`];
/// [`FilterIpContext::with_filter_state`] has a default implementation built
/// on top of it.
pub trait FilterIpContext<I: FilterIpExt, BT: FilterBindingsTypes>:
    IpDeviceAddressIdContext<I, DeviceId: InterfaceProperties<BT::DeviceClass>>
{
    /// The execution context that allows the filtering engine to perform
    /// Network Address Translation (NAT).
    ///
    /// Its device ID and weak address ID types are constrained to be the same
    /// as this context's.
    type NatCtx<'a>: NatContext<I, BT, DeviceId = Self::DeviceId, WeakAddressId = Self::WeakAddressId>;

    /// Calls the function with a reference to filtering state.
    ///
    /// Returns whatever `cb` returns. The default implementation delegates to
    /// [`FilterIpContext::with_filter_state_and_nat_ctx`] and ignores the NAT
    /// context.
    fn with_filter_state<O, F: FnOnce(&State<I, Self::WeakAddressId, BT>) -> O>(
        &mut self,
        cb: F,
    ) -> O {
        self.with_filter_state_and_nat_ctx(|state, _ctx| cb(state))
    }

    /// Calls the function with a reference to filtering state and the NAT
    /// context.
    ///
    /// The state is handed out immutably while the NAT context is handed out
    /// mutably. Returns whatever `cb` returns.
    fn with_filter_state_and_nat_ctx<
        O,
        F: FnOnce(&State<I, Self::WeakAddressId, BT>, &mut Self::NatCtx<'_>) -> O,
    >(
        &mut self,
        cb: F,
    ) -> O;
}
78
/// The execution context for Network Address Translation (NAT).
///
/// Both methods look up addresses on a specific device and return a
/// strongly-held address ID (`Self::AddressId`) on success.
pub trait NatContext<I: IpExt, BT: FilterBindingsTypes>:
    IpDeviceAddressIdContext<I, DeviceId: InterfaceProperties<BT::DeviceClass>>
{
    /// Returns the best local address for communicating with the remote.
    ///
    /// The lookup is scoped to `device_id`. Returns `None` if no address is
    /// available.
    ///
    /// NOTE(review): how "best" is decided, and how a `None` remote is
    /// treated, is up to the implementor and is not specified here.
    fn get_local_addr_for_remote(
        &mut self,
        device_id: &Self::DeviceId,
        remote: Option<SpecifiedAddr<I::Addr>>,
    ) -> Option<Self::AddressId>;

    /// Returns a strongly-held reference to the provided address, if it is assigned
    /// to the specified device.
    ///
    /// Returns `None` otherwise.
    fn get_address_id(
        &mut self,
        device_id: &Self::DeviceId,
        addr: IpDeviceAddr<I::Addr>,
    ) -> Option<Self::AddressId>;
}
98
/// A context for mutably accessing all filtering state at once, to allow IPv4
/// and IPv6 filtering state to be modified atomically.
///
/// Unlike [`FilterIpContext`], this trait is not generic over the IP version:
/// it requires address ID support for both [`Ipv4`] and [`Ipv6`].
pub trait FilterContext<BT: FilterBindingsTypes>:
    IpDeviceAddressIdContext<Ipv4, DeviceId: InterfaceProperties<BT::DeviceClass>>
    + IpDeviceAddressIdContext<Ipv6, DeviceId: InterfaceProperties<BT::DeviceClass>>
{
    /// Calls the function with a mutable reference to all filtering state.
    ///
    /// The callback receives the IPv4 state first and the IPv6 state second.
    /// Returns whatever `cb` returns.
    fn with_all_filter_state_mut<
        O,
        F: FnOnce(
            &mut State<Ipv4, <Self as IpDeviceAddressIdContext<Ipv4>>::WeakAddressId, BT>,
            &mut State<Ipv6, <Self as IpDeviceAddressIdContext<Ipv6>>::WeakAddressId, BT>,
        ) -> O,
    >(
        &mut self,
        cb: F,
    ) -> O;
}
117
/// Result returned from [`SocketOpsFilter::on_egress`].
///
/// Both verdicts carry an independent `congestion` flag, so congestion can be
/// signaled whether or not the packet is actually sent.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum SocketEgressFilterResult {
    /// Send the packet normally.
    Pass {
        /// Indicates that congestion should be signaled to the higher level protocol.
        congestion: bool,
    },

    /// Drop the packet.
    Drop {
        /// Indicates that congestion should be signaled to the higher level protocol.
        congestion: bool,
    },
}
133
/// Result returned from [`SocketOpsFilter::on_ingress`].
///
/// Unlike [`SocketEgressFilterResult`], the ingress verdict carries no
/// additional data.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum SocketIngressFilterResult {
    /// Accept the packet.
    Accept,

    /// Drop the packet.
    Drop,
}
143
/// Trait for a socket operations filter.
///
/// `D` is the device type passed to both hooks. Both hooks take `&self` and
/// return a verdict for a single packet.
pub trait SocketOpsFilter<D> {
    /// Called on every outgoing packet originated from a local socket.
    ///
    /// The packet is provided as a parsed [`FilterIpPacket`], generic over the
    /// IP version `I`. Returns the verdict for the packet, see
    /// [`SocketEgressFilterResult`].
    ///
    /// NOTE(review): `device` is presumably the egress device, and `cookie`
    /// and `marks` presumably belong to the originating socket - confirm
    /// against callers.
    fn on_egress<I: FilterIpExt, P: FilterIpPacket<I>>(
        &self,
        packet: &P,
        device: &D,
        cookie: SocketCookie,
        marks: &Marks,
    ) -> SocketEgressFilterResult;

    /// Called on every incoming packet handled by a local socket.
    ///
    /// Unlike [`SocketOpsFilter::on_egress`], the packet is provided as raw
    /// (possibly fragmented) bytes together with a runtime `ip_version`
    /// instead of a typed packet. Returns the verdict for the packet, see
    /// [`SocketIngressFilterResult`].
    fn on_ingress(
        &self,
        ip_version: IpVersion,
        packet: FragmentedByteSlice<'_, &[u8]>,
        device: &D,
        cookie: SocketCookie,
        marks: &Marks,
    ) -> SocketIngressFilterResult;
}
165
/// Implemented by bindings to provide socket operations filtering.
///
/// `D` is the device type the returned [`SocketOpsFilter`] operates on.
pub trait SocketOpsFilterBindingContext<D>: TxMetadataBindingsTypes {
    /// Returns the filter that should be called for socket ops.
    ///
    /// The concrete filter type is opaque to callers (`impl SocketOpsFilter<D>`).
    fn socket_ops_filter(&self) -> impl SocketOpsFilter<D>;
}
171
// Test-support wiring: the shared fake bindings context from `netstack3_base`
// always hands out `crate::testutil::NoOpSocketOpsFilter`, for any device
// type `D`.
#[cfg(any(test, feature = "testutils"))]
impl<TimerId, Event, State, FrameMeta, D> SocketOpsFilterBindingContext<D>
    for netstack3_base::testutil::FakeBindingsCtx<TimerId, Event, State, FrameMeta>
where
    TimerId: Debug + PartialEq + Clone + Send + Sync + 'static,
    Event: Debug + 'static,
    State: 'static,
    FrameMeta: 'static,
{
    /// Returns the filter used for socket ops in tests.
    fn socket_ops_filter(&self) -> impl SocketOpsFilter<D> {
        crate::testutil::NoOpSocketOpsFilter
    }
}
186
187#[cfg(test)]
188pub(crate) mod testutil {
189    use alloc::sync::{Arc, Weak};
190    use alloc::vec::Vec;
191    use core::hash::{Hash, Hasher};
192    use core::ops::Deref;
193    use core::sync::atomic::AtomicUsize;
194    use core::time::Duration;
195
196    use derivative::Derivative;
197    use net_types::ip::{AddrSubnet, GenericOverIp, Ip};
198    use netstack3_base::testutil::{
199        FakeAtomicInstant, FakeCryptoRng, FakeDeviceClass, FakeInstant, FakeMatcherDeviceId,
200        FakeTimerCtx, FakeWeakDeviceId, WithFakeTimerContext,
201    };
202    use netstack3_base::{
203        AnyDevice, AssignedAddrIpExt, DeviceIdContext, InspectableValue, Inspector, InstantContext,
204        IntoCoreTimerCtx, IpAddressId, WeakIpAddressId,
205    };
206    use netstack3_hashmap::HashMap;
207
208    use super::*;
209    use crate::logic::FilterTimerId;
210    use crate::logic::nat::NatConfig;
211    use crate::state::validation::ValidRoutines;
212    use crate::state::{IpRoutines, NatRoutines, OneWayBoolean, Routines};
213    use crate::{Interfaces, conntrack};
214
    /// Shorthand for the IP extension bounds needed by the fakes in this
    /// module: [`FilterIpExt`] plus [`AssignedAddrIpExt`].
    pub trait TestIpExt: FilterIpExt + AssignedAddrIpExt {}

    // Blanket impl so every IP version satisfying both bounds is a `TestIpExt`.
    impl<I: FilterIpExt + AssignedAddrIpExt> TestIpExt for I {}
218
    /// Fake primary handle to an address, wrapping the `Arc` that holds the
    /// address/subnet value.
    ///
    /// It is not `Clone`; strong [`FakeAddressId`]s are created by cloning the
    /// inner `Arc`.
    #[derive(Debug)]
    pub struct FakePrimaryAddressId<I: AssignedAddrIpExt>(
        pub Arc<AddrSubnet<I::Addr, I::AssignedWitness>>,
    );
223
    /// Fake strongly-held address ID: a cloneable `Arc` around the
    /// address/subnet value.
    ///
    /// The inner field is private to this module.
    #[derive(Clone, Debug, Hash, Eq, PartialEq)]
    pub struct FakeAddressId<I: AssignedAddrIpExt>(Arc<AddrSubnet<I::Addr, I::AssignedWitness>>);
226
    /// Fake weakly-held address ID: a `Weak` pointer to the address/subnet
    /// value held by the corresponding strong/primary handles.
    ///
    /// Equality and hashing are implemented by hand (not derived) in terms of
    /// the pointer.
    #[derive(Clone, Debug)]
    pub struct FakeWeakAddressId<I: AssignedAddrIpExt>(
        pub Weak<AddrSubnet<I::Addr, I::AssignedWitness>>,
    );
231
232    impl<I: AssignedAddrIpExt> PartialEq for FakeWeakAddressId<I> {
233        fn eq(&self, other: &Self) -> bool {
234            let Self(lhs) = self;
235            let Self(rhs) = other;
236            Weak::ptr_eq(lhs, rhs)
237        }
238    }
239
240    impl<I: AssignedAddrIpExt> Eq for FakeWeakAddressId<I> {}
241
242    impl<I: AssignedAddrIpExt> Hash for FakeWeakAddressId<I> {
243        fn hash<H: Hasher>(&self, state: &mut H) {
244            let Self(this) = self;
245            this.as_ptr().hash(state)
246        }
247    }
248
249    impl<I: AssignedAddrIpExt> WeakIpAddressId<I::Addr> for FakeWeakAddressId<I> {
250        type Strong = FakeAddressId<I>;
251
252        fn upgrade(&self) -> Option<Self::Strong> {
253            let Self(inner) = self;
254            inner.upgrade().map(FakeAddressId)
255        }
256
257        fn is_assigned(&self) -> bool {
258            let Self(inner) = self;
259            inner.strong_count() != 0
260        }
261    }
262
263    impl<I: AssignedAddrIpExt> InspectableValue for FakeWeakAddressId<I> {
264        fn record<Inspector: netstack3_base::Inspector>(
265            &self,
266            _name: &str,
267            _inspector: &mut Inspector,
268        ) {
269            unimplemented!()
270        }
271    }
272
273    impl<I: AssignedAddrIpExt> Deref for FakeAddressId<I> {
274        type Target = AddrSubnet<I::Addr, I::AssignedWitness>;
275
276        fn deref(&self) -> &Self::Target {
277            let Self(inner) = self;
278            inner.deref()
279        }
280    }
281
282    impl<I: AssignedAddrIpExt> IpAddressId<I::Addr> for FakeAddressId<I> {
283        type Weak = FakeWeakAddressId<I>;
284
285        fn downgrade(&self) -> Self::Weak {
286            let Self(inner) = self;
287            FakeWeakAddressId(Arc::downgrade(inner))
288        }
289
290        fn addr(&self) -> IpDeviceAddr<I::Addr> {
291            let Self(inner) = self;
292
293            #[derive(GenericOverIp)]
294            #[generic_over_ip(I, Ip)]
295            struct WrapIn<I: AssignedAddrIpExt>(I::AssignedWitness);
296            I::map_ip(
297                WrapIn(inner.addr()),
298                |WrapIn(v4_addr)| IpDeviceAddr::new_from_witness(v4_addr),
299                |WrapIn(v6_addr)| IpDeviceAddr::new_from_ipv6_device_addr(v6_addr),
300            )
301        }
302
303        fn addr_sub(&self) -> AddrSubnet<I::Addr, I::AssignedWitness> {
304            let Self(inner) = self;
305            **inner
306        }
307    }
308
    /// Fake core context for filtering tests: owns the filtering [`State`]
    /// together with a [`FakeNatCtx`].
    pub struct FakeCtx<I: TestIpExt> {
        /// Filtering state handed to callbacks by the `FilterIpContext` impl.
        state: State<I, FakeWeakAddressId<I>, FakeBindingsCtx<I>>,
        /// NAT context handed out alongside the state.
        nat: FakeNatCtx<I>,
    }
313
    /// Fake NAT context backed by a map from device to address.
    ///
    /// `Default` (derived with no extra bounds on `I`) yields a context with
    /// no device addresses.
    #[derive(Derivative)]
    #[derivative(Default(bound = ""))]
    pub struct FakeNatCtx<I: TestIpExt> {
        /// The address held for each device; at most one per device, since
        /// the map is keyed by device ID.
        pub(crate) device_addrs: HashMap<FakeMatcherDeviceId, FakePrimaryAddressId<I>>,
    }
319
320    impl<I: TestIpExt> FakeCtx<I> {
321        pub fn new(bindings_ctx: &mut FakeBindingsCtx<I>) -> Self {
322            Self {
323                state: State {
324                    installed_routines: ValidRoutines::default(),
325                    uninstalled_routines: Vec::default(),
326                    conntrack: conntrack::Table::new::<IntoCoreTimerCtx>(bindings_ctx),
327                    nat_installed: OneWayBoolean::default(),
328                },
329                nat: FakeNatCtx::default(),
330            }
331        }
332
333        pub fn with_ip_routines(
334            bindings_ctx: &mut FakeBindingsCtx<I>,
335            routines: IpRoutines<I, FakeBindingsCtx<I>, ()>,
336        ) -> Self {
337            let (installed_routines, uninstalled_routines) =
338                ValidRoutines::new(Routines { ip: routines, ..Default::default() })
339                    .expect("invalid state");
340            Self {
341                state: State {
342                    installed_routines,
343                    uninstalled_routines,
344                    conntrack: conntrack::Table::new::<IntoCoreTimerCtx>(bindings_ctx),
345                    nat_installed: OneWayBoolean::default(),
346                },
347                nat: FakeNatCtx::default(),
348            }
349        }
350
351        pub fn with_nat_routines_and_device_addrs(
352            bindings_ctx: &mut FakeBindingsCtx<I>,
353            routines: NatRoutines<I, FakeBindingsCtx<I>, ()>,
354            device_addrs: impl IntoIterator<
355                Item = (FakeMatcherDeviceId, AddrSubnet<I::Addr, I::AssignedWitness>),
356            >,
357        ) -> Self {
358            let (installed_routines, uninstalled_routines) =
359                ValidRoutines::new(Routines { nat: routines, ..Default::default() })
360                    .expect("invalid state");
361            Self {
362                state: State {
363                    installed_routines,
364                    uninstalled_routines,
365                    conntrack: conntrack::Table::new::<IntoCoreTimerCtx>(bindings_ctx),
366                    nat_installed: OneWayBoolean::TRUE,
367                },
368                nat: FakeNatCtx {
369                    device_addrs: device_addrs
370                        .into_iter()
371                        .map(|(device, addr)| (device, FakePrimaryAddressId(Arc::new(addr))))
372                        .collect(),
373                },
374            }
375        }
376
377        pub fn conntrack(
378            &mut self,
379        ) -> &conntrack::Table<I, NatConfig<I, FakeWeakAddressId<I>>, FakeBindingsCtx<I>> {
380            &self.state.conntrack
381        }
382    }
383
    // Device ID types used by `FakeCtx`: the fake matcher device and its weak
    // counterpart.
    impl<I: TestIpExt> DeviceIdContext<AnyDevice> for FakeCtx<I> {
        type DeviceId = FakeMatcherDeviceId;
        type WeakDeviceId = FakeWeakDeviceId<FakeMatcherDeviceId>;
    }

    // Address ID types used by `FakeCtx`: the fake strong/weak IDs defined in
    // this module.
    impl<I: TestIpExt> IpDeviceAddressIdContext<I> for FakeCtx<I> {
        type AddressId = FakeAddressId<I>;
        type WeakAddressId = FakeWeakAddressId<I>;
    }
393
394    impl<I: TestIpExt> FilterIpContext<I, FakeBindingsCtx<I>> for FakeCtx<I> {
395        type NatCtx<'a> = FakeNatCtx<I>;
396
397        fn with_filter_state_and_nat_ctx<
398            O,
399            F: FnOnce(&State<I, FakeWeakAddressId<I>, FakeBindingsCtx<I>>, &mut Self::NatCtx<'_>) -> O,
400        >(
401            &mut self,
402            cb: F,
403        ) -> O {
404            let Self { state, nat } = self;
405            cb(state, nat)
406        }
407    }
408
409    impl<I: TestIpExt> FakeNatCtx<I> {
410        pub fn new(
411            device_addrs: impl IntoIterator<
412                Item = (FakeMatcherDeviceId, AddrSubnet<I::Addr, I::AssignedWitness>),
413            >,
414        ) -> Self {
415            Self {
416                device_addrs: device_addrs
417                    .into_iter()
418                    .map(|(device, addr)| (device, FakePrimaryAddressId(Arc::new(addr))))
419                    .collect(),
420            }
421        }
422    }
423
    // Same device ID types as `FakeCtx`, as required by
    // `FilterIpContext::NatCtx`.
    impl<I: TestIpExt> DeviceIdContext<AnyDevice> for FakeNatCtx<I> {
        type DeviceId = FakeMatcherDeviceId;
        type WeakDeviceId = FakeWeakDeviceId<FakeMatcherDeviceId>;
    }

    // Same address ID types as `FakeCtx`, as required by
    // `FilterIpContext::NatCtx`.
    impl<I: TestIpExt> IpDeviceAddressIdContext<I> for FakeNatCtx<I> {
        type AddressId = FakeAddressId<I>;
        type WeakAddressId = FakeWeakAddressId<I>;
    }
433
434    impl<I: TestIpExt> NatContext<I, FakeBindingsCtx<I>> for FakeNatCtx<I> {
435        fn get_local_addr_for_remote(
436            &mut self,
437            device_id: &Self::DeviceId,
438            _remote: Option<SpecifiedAddr<I::Addr>>,
439        ) -> Option<Self::AddressId> {
440            let FakePrimaryAddressId(primary) = self.device_addrs.get(device_id)?;
441            Some(FakeAddressId(primary.clone()))
442        }
443
444        fn get_address_id(
445            &mut self,
446            device_id: &Self::DeviceId,
447            addr: IpDeviceAddr<I::Addr>,
448        ) -> Option<Self::AddressId> {
449            let FakePrimaryAddressId(id) = self.device_addrs.get(device_id)?;
450            let id = FakeAddressId(id.clone());
451            if id.addr() == addr { Some(id) } else { None }
452        }
453    }
454
    /// Fake bindings context for filtering tests, providing fake timers and a
    /// fake RNG.
    pub struct FakeBindingsCtx<I: Ip> {
        /// Fake timer context, dispatching [`FilterTimerId`]s; also the source
        /// of the current (fake) time.
        pub timer_ctx: FakeTimerCtx<FilterTimerId<I>>,
        /// Fake RNG handed out by the `RngContext` impl.
        pub rng: FakeCryptoRng,
    }
459
460    impl<I: Ip> FakeBindingsCtx<I> {
461        pub(crate) fn new() -> Self {
462            Self { timer_ctx: FakeTimerCtx::default(), rng: FakeCryptoRng::default() }
463        }
464
465        pub(crate) fn sleep(&mut self, time_elapsed: Duration) {
466            self.timer_ctx.instant.sleep(time_elapsed)
467        }
468    }
469
    // Time is represented with the fake instant types from `netstack3_base`.
    impl<I: Ip> InstantBindingsTypes for FakeBindingsCtx<I> {
        type Instant = FakeInstant;
        type AtomicInstant = FakeAtomicInstant;
    }
474
    /// Fake bindings packet matcher that returns a fixed result and counts
    /// how many times it has been invoked.
    #[derive(Debug)]
    pub struct FakeBindingsPacketMatcher {
        /// Number of times `matches` has been called.
        num_calls: AtomicUsize,
        /// The value every call to `matches` returns.
        result: bool,
    }
480
481    impl FakeBindingsPacketMatcher {
482        pub fn new(result: bool) -> Arc<Self> {
483            Arc::new(Self { num_calls: AtomicUsize::new(0), result })
484        }
485
486        pub fn num_calls(&self) -> usize {
487            self.num_calls.load(core::sync::atomic::Ordering::SeqCst)
488        }
489    }
490
491    impl BindingsPacketMatcher<FakeMatcherDeviceId> for FakeBindingsPacketMatcher {
492        fn matches<I: FilterIpExt, P: FilterIpPacket<I>>(
493            &self,
494            _packet: &P,
495            _interfaces: Interfaces<'_, FakeMatcherDeviceId>,
496        ) -> bool {
497            let _: usize = self.num_calls.fetch_add(1, core::sync::atomic::Ordering::SeqCst);
498            self.result
499        }
500    }
501
502    impl InspectableValue for FakeBindingsPacketMatcher {
503        fn record<I: Inspector>(&self, _name: &str, _inspector: &mut I) {
504            unimplemented!()
505        }
506    }
507
    // Matcher-related types: the fake device class, and a shared
    // (`Arc`-wrapped) fake packet matcher.
    impl<I: Ip> MatcherBindingsTypes for FakeBindingsCtx<I> {
        type DeviceClass = FakeDeviceClass;
        type BindingsPacketMatcher = Arc<FakeBindingsPacketMatcher>;
    }
512
513    impl<I: Ip> InstantContext for FakeBindingsCtx<I> {
514        fn now(&self) -> Self::Instant {
515            self.timer_ctx.now()
516        }
517    }
518
    // All timer types are borrowed from the wrapped `FakeTimerCtx`, so the
    // `TimerContext` impl can forward to it directly.
    impl<I: Ip> TimerBindingsTypes for FakeBindingsCtx<I> {
        type Timer = <FakeTimerCtx<FilterTimerId<I>> as TimerBindingsTypes>::Timer;
        type DispatchId = <FakeTimerCtx<FilterTimerId<I>> as TimerBindingsTypes>::DispatchId;
        type UniqueTimerId = <FakeTimerCtx<FilterTimerId<I>> as TimerBindingsTypes>::UniqueTimerId;
    }
524
525    impl<I: Ip> TimerContext for FakeBindingsCtx<I> {
526        fn new_timer(&mut self, id: Self::DispatchId) -> Self::Timer {
527            self.timer_ctx.new_timer(id)
528        }
529
530        fn schedule_timer_instant(
531            &mut self,
532            time: Self::Instant,
533            timer: &mut Self::Timer,
534        ) -> Option<Self::Instant> {
535            self.timer_ctx.schedule_timer_instant(time, timer)
536        }
537
538        fn cancel_timer(&mut self, timer: &mut Self::Timer) -> Option<Self::Instant> {
539            self.timer_ctx.cancel_timer(timer)
540        }
541
542        fn scheduled_instant(&self, timer: &mut Self::Timer) -> Option<Self::Instant> {
543            self.timer_ctx.scheduled_instant(timer)
544        }
545
546        fn unique_timer_id(&self, timer: &Self::Timer) -> Self::UniqueTimerId {
547            self.timer_ctx.unique_timer_id(timer)
548        }
549    }
550
551    impl<I: Ip> WithFakeTimerContext<FilterTimerId<I>> for FakeBindingsCtx<I> {
552        fn with_fake_timer_ctx<O, F: FnOnce(&FakeTimerCtx<FilterTimerId<I>>) -> O>(
553            &self,
554            f: F,
555        ) -> O {
556            f(&self.timer_ctx)
557        }
558
559        fn with_fake_timer_ctx_mut<O, F: FnOnce(&mut FakeTimerCtx<FilterTimerId<I>>) -> O>(
560            &mut self,
561            f: F,
562        ) -> O {
563            f(&mut self.timer_ctx)
564        }
565    }
566
567    impl<I: Ip> RngContext for FakeBindingsCtx<I> {
568        type Rng<'a>
569            = FakeCryptoRng
570        where
571            Self: 'a;
572
573        fn rng(&mut self) -> Self::Rng<'_> {
574            self.rng.clone()
575        }
576    }
577}