netstack3_filter/
context.rs

1// Copyright 2024 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use core::fmt::Debug;
6
7use net_types::SpecifiedAddr;
8use net_types::ip::{IpVersion, Ipv4, Ipv6};
9use netstack3_base::socket::SocketCookie;
10use netstack3_base::{
11    InstantBindingsTypes, InterfaceProperties, IpDeviceAddr, IpDeviceAddressIdContext, Marks,
12    MatcherBindingsTypes, RngContext, StrongDeviceIdentifier, TimerBindingsTypes, TimerContext,
13    TxMetadataBindingsTypes,
14};
15use packet::{FragmentedByteSlice, PartialSerializer};
16use packet_formats::ip::IpExt;
17
18use crate::matchers::BindingsPacketMatcher;
19use crate::state::State;
20use crate::{FilterIpExt, IpPacket};
21
22/// Trait defining required types for filtering provided by bindings.
23pub trait FilterBindingsTypes:
24    InstantBindingsTypes
25    + MatcherBindingsTypes<BindingsPacketMatcher: BindingsPacketMatcher>
26    + TimerBindingsTypes
27    + 'static
28{
29}
30
31impl<
32    BT: InstantBindingsTypes
33        + MatcherBindingsTypes<BindingsPacketMatcher: BindingsPacketMatcher>
34        + TimerBindingsTypes
35        + 'static,
36> FilterBindingsTypes for BT
37{
38}
39
40/// Trait aggregating functionality required from bindings.
41pub trait FilterBindingsContext: TimerContext + RngContext + FilterBindingsTypes {}
42impl<BC: TimerContext + RngContext + FilterBindingsTypes> FilterBindingsContext for BC {}
43
44/// The IP version-specific execution context for packet filtering.
45///
46/// This trait exists to abstract over access to the filtering state. It is
47/// useful to implement filtering logic in terms of this trait, as opposed to,
48/// for example, [`crate::logic::FilterHandler`] methods taking the state
49/// directly as an argument, because it allows Netstack3 Core to use lock
50/// ordering types to enforce that filtering state is only acquired at or before
51/// a given lock level, while keeping test code free of locking concerns.
52pub trait FilterIpContext<I: FilterIpExt, BT: FilterBindingsTypes>:
53    IpDeviceAddressIdContext<I, DeviceId: InterfaceProperties<BT::DeviceClass>>
54{
55    /// The execution context that allows the filtering engine to perform
56    /// Network Address Translation (NAT).
57    type NatCtx<'a>: NatContext<I, BT, DeviceId = Self::DeviceId, WeakAddressId = Self::WeakAddressId>;
58
59    /// Calls the function with a reference to filtering state.
60    fn with_filter_state<O, F: FnOnce(&State<I, Self::WeakAddressId, BT>) -> O>(
61        &mut self,
62        cb: F,
63    ) -> O {
64        self.with_filter_state_and_nat_ctx(|state, _ctx| cb(state))
65    }
66
67    /// Calls the function with a reference to filtering state and the NAT
68    /// context.
69    fn with_filter_state_and_nat_ctx<
70        O,
71        F: FnOnce(&State<I, Self::WeakAddressId, BT>, &mut Self::NatCtx<'_>) -> O,
72    >(
73        &mut self,
74        cb: F,
75    ) -> O;
76}
77
78/// The execution context for Network Address Translation (NAT).
79pub trait NatContext<I: IpExt, BT: FilterBindingsTypes>:
80    IpDeviceAddressIdContext<I, DeviceId: InterfaceProperties<BT::DeviceClass>>
81{
82    /// Returns the best local address for communicating with the remote.
83    fn get_local_addr_for_remote(
84        &mut self,
85        device_id: &Self::DeviceId,
86        remote: Option<SpecifiedAddr<I::Addr>>,
87    ) -> Option<Self::AddressId>;
88
89    /// Returns a strongly-held reference to the provided address, if it is assigned
90    /// to the specified device.
91    fn get_address_id(
92        &mut self,
93        device_id: &Self::DeviceId,
94        addr: IpDeviceAddr<I::Addr>,
95    ) -> Option<Self::AddressId>;
96}
97
98/// A context for mutably accessing all filtering state at once, to allow IPv4
99/// and IPv6 filtering state to be modified atomically.
100pub trait FilterContext<BT: FilterBindingsTypes>:
101    IpDeviceAddressIdContext<Ipv4, DeviceId: InterfaceProperties<BT::DeviceClass>>
102    + IpDeviceAddressIdContext<Ipv6, DeviceId: InterfaceProperties<BT::DeviceClass>>
103{
104    /// Calls the function with a mutable reference to all filtering state.
105    fn with_all_filter_state_mut<
106        O,
107        F: FnOnce(
108            &mut State<Ipv4, <Self as IpDeviceAddressIdContext<Ipv4>>::WeakAddressId, BT>,
109            &mut State<Ipv6, <Self as IpDeviceAddressIdContext<Ipv6>>::WeakAddressId, BT>,
110        ) -> O,
111    >(
112        &mut self,
113        cb: F,
114    ) -> O;
115}
116
/// Verdict returned by [`SocketOpsFilter::on_egress`] for an outgoing packet.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum SocketEgressFilterResult {
    /// Let the packet be sent as usual.
    Pass {
        /// Whether congestion should be reported to the higher-level protocol.
        congestion: bool,
    },

    /// Discard the packet.
    Drop {
        /// Whether congestion should be reported to the higher-level protocol.
        congestion: bool,
    },
}
132
/// Verdict returned by [`SocketOpsFilter::on_ingress`] for an incoming packet.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum SocketIngressFilterResult {
    /// Let the packet through.
    Accept,

    /// Discard the packet.
    Drop,
}
142
143/// Trait for a socket operations filter.
144pub trait SocketOpsFilter<D: StrongDeviceIdentifier> {
145    /// Called on every outgoing packet originated from a local socket.
146    fn on_egress<I: FilterIpExt, P: IpPacket<I> + PartialSerializer>(
147        &self,
148        packet: &P,
149        device: &D,
150        cookie: SocketCookie,
151        marks: &Marks,
152    ) -> SocketEgressFilterResult;
153
154    /// Called on every incoming packet handled by a local socket.
155    fn on_ingress(
156        &self,
157        ip_version: IpVersion,
158        packet: FragmentedByteSlice<'_, &[u8]>,
159        device: &D,
160        cookie: SocketCookie,
161        marks: &Marks,
162    ) -> SocketIngressFilterResult;
163}
164
165/// Implemented by bindings to provide socket operations filtering.
166pub trait SocketOpsFilterBindingContext<D: StrongDeviceIdentifier>:
167    TxMetadataBindingsTypes
168{
169    /// Returns the filter that should be called for socket ops.
170    fn socket_ops_filter(&self) -> impl SocketOpsFilter<D>;
171}
172
// Fake bindings hand out a no-op socket operations filter.
#[cfg(any(test, feature = "testutils"))]
impl<TimerId, Event, S, FrameMeta, D> SocketOpsFilterBindingContext<D>
    for netstack3_base::testutil::FakeBindingsCtx<TimerId, Event, S, FrameMeta>
where
    TimerId: Debug + PartialEq + Clone + Send + Sync + 'static,
    Event: Debug + 'static,
    S: 'static,
    FrameMeta: 'static,
    D: StrongDeviceIdentifier,
{
    fn socket_ops_filter(&self) -> impl SocketOpsFilter<D> {
        crate::testutil::NoOpSocketOpsFilter
    }
}
187
/// Fake filtering contexts and bindings for use in unit tests.
#[cfg(test)]
pub(crate) mod testutil {
    use alloc::sync::{Arc, Weak};
    use alloc::vec::Vec;
    use core::hash::{Hash, Hasher};
    use core::ops::Deref;
    use core::sync::atomic::AtomicUsize;
    use core::time::Duration;

    use derivative::Derivative;
    use net_types::ip::{AddrSubnet, GenericOverIp, Ip};
    use netstack3_base::testutil::{
        FakeAtomicInstant, FakeCryptoRng, FakeDeviceClass, FakeInstant, FakeMatcherDeviceId,
        FakeTimerCtx, FakeWeakDeviceId, WithFakeTimerContext,
    };
    use netstack3_base::{
        AnyDevice, AssignedAddrIpExt, DeviceIdContext, InspectableValue, Inspector, InstantContext,
        IntoCoreTimerCtx, IpAddressId, WeakIpAddressId,
    };
    use netstack3_hashmap::HashMap;

    use super::*;
    use crate::conntrack;
    use crate::logic::FilterTimerId;
    use crate::logic::nat::NatConfig;
    use crate::state::validation::ValidRoutines;
    use crate::state::{IpRoutines, NatRoutines, OneWayBoolean, Routines};

    /// IP extension trait bundling everything the fakes in this module need.
    pub trait TestIpExt: FilterIpExt + AssignedAddrIpExt {}

    impl<I: FilterIpExt + AssignedAddrIpExt> TestIpExt for I {}

    /// The primary address assigned to a fake device.
    ///
    /// Owns a strong reference to the address. Weak IDs pointing at the same
    /// allocation report the address as assigned while any strong reference
    /// to it is alive.
    #[derive(Debug)]
    pub struct FakePrimaryAddressId<I: AssignedAddrIpExt>(
        pub Arc<AddrSubnet<I::Addr, I::AssignedWitness>>,
    );

    /// A strongly-held fake address ID.
    #[derive(Clone, Debug, Hash, Eq, PartialEq)]
    pub struct FakeAddressId<I: AssignedAddrIpExt>(Arc<AddrSubnet<I::Addr, I::AssignedWitness>>);

    /// A weakly-held fake address ID, compared and hashed by pointer identity.
    #[derive(Clone, Debug)]
    pub struct FakeWeakAddressId<I: AssignedAddrIpExt>(
        pub Weak<AddrSubnet<I::Addr, I::AssignedWitness>>,
    );

    impl<I: AssignedAddrIpExt> PartialEq for FakeWeakAddressId<I> {
        fn eq(&self, other: &Self) -> bool {
            let Self(lhs) = self;
            let Self(rhs) = other;
            Weak::ptr_eq(lhs, rhs)
        }
    }

    impl<I: AssignedAddrIpExt> Eq for FakeWeakAddressId<I> {}

    impl<I: AssignedAddrIpExt> Hash for FakeWeakAddressId<I> {
        fn hash<H: Hasher>(&self, state: &mut H) {
            // Hash the pointer so hashing agrees with the pointer-based `PartialEq`.
            let Self(this) = self;
            this.as_ptr().hash(state)
        }
    }

    impl<I: AssignedAddrIpExt> WeakIpAddressId<I::Addr> for FakeWeakAddressId<I> {
        type Strong = FakeAddressId<I>;

        fn upgrade(&self) -> Option<Self::Strong> {
            let Self(inner) = self;
            inner.upgrade().map(FakeAddressId)
        }

        fn is_assigned(&self) -> bool {
            // The address counts as assigned while any strong reference exists.
            let Self(inner) = self;
            inner.strong_count() != 0
        }
    }

    impl<I: AssignedAddrIpExt> InspectableValue for FakeWeakAddressId<I> {
        fn record<N: Inspector>(&self, _name: &str, _inspector: &mut N) {
            unimplemented!()
        }
    }

    impl<I: AssignedAddrIpExt> Deref for FakeAddressId<I> {
        type Target = AddrSubnet<I::Addr, I::AssignedWitness>;

        fn deref(&self) -> &Self::Target {
            let Self(inner) = self;
            inner.deref()
        }
    }

    impl<I: AssignedAddrIpExt> IpAddressId<I::Addr> for FakeAddressId<I> {
        type Weak = FakeWeakAddressId<I>;

        fn downgrade(&self) -> Self::Weak {
            let Self(inner) = self;
            FakeWeakAddressId(Arc::downgrade(inner))
        }

        fn addr(&self) -> IpDeviceAddr<I::Addr> {
            let Self(inner) = self;

            // Wrapper allowing the assigned witness to be mapped per IP version.
            #[derive(GenericOverIp)]
            #[generic_over_ip(I, Ip)]
            struct WrapIn<I: AssignedAddrIpExt>(I::AssignedWitness);
            I::map_ip(
                WrapIn(inner.addr()),
                |WrapIn(v4_addr)| IpDeviceAddr::new_from_witness(v4_addr),
                |WrapIn(v6_addr)| IpDeviceAddr::new_from_ipv6_device_addr(v6_addr),
            )
        }

        fn addr_sub(&self) -> AddrSubnet<I::Addr, I::AssignedWitness> {
            let Self(inner) = self;
            **inner
        }
    }

    /// A fake [`FilterIpContext`] holding filtering state and a fake NAT context.
    pub struct FakeCtx<I: TestIpExt> {
        state: State<I, FakeWeakAddressId<I>, FakeBindingsCtx<I>>,
        nat: FakeNatCtx<I>,
    }

    /// A fake [`NatContext`] that maps each device to a single primary address.
    #[derive(Derivative)]
    #[derivative(Default(bound = ""))]
    pub struct FakeNatCtx<I: TestIpExt> {
        pub(crate) device_addrs: HashMap<FakeMatcherDeviceId, FakePrimaryAddressId<I>>,
    }

    impl<I: TestIpExt> FakeCtx<I> {
        /// Creates a context with no routines installed and no device addresses.
        pub fn new(bindings_ctx: &mut FakeBindingsCtx<I>) -> Self {
            Self {
                state: State {
                    installed_routines: ValidRoutines::default(),
                    uninstalled_routines: Vec::default(),
                    conntrack: conntrack::Table::new::<IntoCoreTimerCtx>(bindings_ctx),
                    nat_installed: OneWayBoolean::default(),
                },
                nat: FakeNatCtx::default(),
            }
        }

        /// Creates a context with the given IP routines installed.
        ///
        /// # Panics
        ///
        /// Panics if `routines` fails validation.
        pub fn with_ip_routines(
            bindings_ctx: &mut FakeBindingsCtx<I>,
            routines: IpRoutines<I, FakeBindingsCtx<I>, ()>,
        ) -> Self {
            let (installed_routines, uninstalled_routines) =
                ValidRoutines::new(Routines { ip: routines, ..Default::default() })
                    .expect("invalid state");
            Self {
                state: State {
                    installed_routines,
                    uninstalled_routines,
                    conntrack: conntrack::Table::new::<IntoCoreTimerCtx>(bindings_ctx),
                    nat_installed: OneWayBoolean::default(),
                },
                nat: FakeNatCtx::default(),
            }
        }

        /// Creates a context with the given NAT routines installed and each
        /// device in `device_addrs` assigned the paired primary address.
        ///
        /// # Panics
        ///
        /// Panics if `routines` fails validation.
        pub fn with_nat_routines_and_device_addrs(
            bindings_ctx: &mut FakeBindingsCtx<I>,
            routines: NatRoutines<I, FakeBindingsCtx<I>, ()>,
            device_addrs: impl IntoIterator<
                Item = (FakeMatcherDeviceId, AddrSubnet<I::Addr, I::AssignedWitness>),
            >,
        ) -> Self {
            let (installed_routines, uninstalled_routines) =
                ValidRoutines::new(Routines { nat: routines, ..Default::default() })
                    .expect("invalid state");
            Self {
                state: State {
                    installed_routines,
                    uninstalled_routines,
                    conntrack: conntrack::Table::new::<IntoCoreTimerCtx>(bindings_ctx),
                    // This constructor installs NAT routines, so mark NAT as installed.
                    nat_installed: OneWayBoolean::TRUE,
                },
                // Reuse the NAT context constructor rather than duplicating its mapping.
                nat: FakeNatCtx::new(device_addrs),
            }
        }

        /// Returns the connection tracking table held in the filtering state.
        pub fn conntrack(
            &mut self,
        ) -> &conntrack::Table<I, NatConfig<I, FakeWeakAddressId<I>>, FakeBindingsCtx<I>> {
            &self.state.conntrack
        }
    }

    impl<I: TestIpExt> DeviceIdContext<AnyDevice> for FakeCtx<I> {
        type DeviceId = FakeMatcherDeviceId;
        type WeakDeviceId = FakeWeakDeviceId<FakeMatcherDeviceId>;
    }

    impl<I: TestIpExt> IpDeviceAddressIdContext<I> for FakeCtx<I> {
        type AddressId = FakeAddressId<I>;
        type WeakAddressId = FakeWeakAddressId<I>;
    }

    impl<I: TestIpExt> FilterIpContext<I, FakeBindingsCtx<I>> for FakeCtx<I> {
        type NatCtx<'a> = FakeNatCtx<I>;

        fn with_filter_state_and_nat_ctx<
            O,
            F: FnOnce(&State<I, FakeWeakAddressId<I>, FakeBindingsCtx<I>>, &mut Self::NatCtx<'_>) -> O,
        >(
            &mut self,
            cb: F,
        ) -> O {
            // Split the borrow so the callback gets both parts at once.
            let Self { state, nat } = self;
            cb(state, nat)
        }
    }

    impl<I: TestIpExt> FakeNatCtx<I> {
        /// Creates a NAT context in which each device in `device_addrs` is
        /// assigned the paired primary address.
        pub fn new(
            device_addrs: impl IntoIterator<
                Item = (FakeMatcherDeviceId, AddrSubnet<I::Addr, I::AssignedWitness>),
            >,
        ) -> Self {
            Self {
                device_addrs: device_addrs
                    .into_iter()
                    .map(|(device, addr)| (device, FakePrimaryAddressId(Arc::new(addr))))
                    .collect(),
            }
        }
    }

    impl<I: TestIpExt> DeviceIdContext<AnyDevice> for FakeNatCtx<I> {
        type DeviceId = FakeMatcherDeviceId;
        type WeakDeviceId = FakeWeakDeviceId<FakeMatcherDeviceId>;
    }

    impl<I: TestIpExt> IpDeviceAddressIdContext<I> for FakeNatCtx<I> {
        type AddressId = FakeAddressId<I>;
        type WeakAddressId = FakeWeakAddressId<I>;
    }

    impl<I: TestIpExt> NatContext<I, FakeBindingsCtx<I>> for FakeNatCtx<I> {
        /// Returns the device's primary address, ignoring `remote`.
        fn get_local_addr_for_remote(
            &mut self,
            device_id: &Self::DeviceId,
            _remote: Option<SpecifiedAddr<I::Addr>>,
        ) -> Option<Self::AddressId> {
            let FakePrimaryAddressId(primary) = self.device_addrs.get(device_id)?;
            Some(FakeAddressId(primary.clone()))
        }

        /// Returns the device's primary address only if it equals `addr`.
        fn get_address_id(
            &mut self,
            device_id: &Self::DeviceId,
            addr: IpDeviceAddr<I::Addr>,
        ) -> Option<Self::AddressId> {
            let FakePrimaryAddressId(id) = self.device_addrs.get(device_id)?;
            let id = FakeAddressId(id.clone());
            if id.addr() == addr { Some(id) } else { None }
        }
    }

    /// Fake bindings context backed by a fake timer context and a fake RNG.
    pub struct FakeBindingsCtx<I: Ip> {
        pub timer_ctx: FakeTimerCtx<FilterTimerId<I>>,
        pub rng: FakeCryptoRng,
    }

    impl<I: Ip> FakeBindingsCtx<I> {
        /// Creates a bindings context with default fake timers and RNG.
        pub(crate) fn new() -> Self {
            Self { timer_ctx: FakeTimerCtx::default(), rng: FakeCryptoRng::default() }
        }

        /// Sleeps the fake timer context's clock for `time_elapsed`.
        pub(crate) fn sleep(&mut self, time_elapsed: Duration) {
            self.timer_ctx.instant.sleep(time_elapsed)
        }
    }

    impl<I: Ip> InstantBindingsTypes for FakeBindingsCtx<I> {
        type Instant = FakeInstant;
        type AtomicInstant = FakeAtomicInstant;
    }

    /// A fake bindings packet matcher.
    ///
    /// It always returns a fixed result and counts how many times it was
    /// consulted.
    #[derive(Debug)]
    pub struct FakeBindingsPacketMatcher {
        num_calls: AtomicUsize,
        result: bool,
    }

    impl FakeBindingsPacketMatcher {
        /// Creates a matcher whose `matches` always returns `result`.
        pub fn new(result: bool) -> Arc<Self> {
            Arc::new(Self { num_calls: AtomicUsize::new(0), result })
        }

        /// Returns how many times `matches` has been called.
        pub fn num_calls(&self) -> usize {
            self.num_calls.load(core::sync::atomic::Ordering::SeqCst)
        }
    }

    impl BindingsPacketMatcher for FakeBindingsPacketMatcher {
        fn matches<I: FilterIpExt, P: IpPacket<I>>(&self, _packet: &P) -> bool {
            let _: usize = self.num_calls.fetch_add(1, core::sync::atomic::Ordering::SeqCst);
            self.result
        }
    }

    impl InspectableValue for FakeBindingsPacketMatcher {
        fn record<I: Inspector>(&self, _name: &str, _inspector: &mut I) {
            unimplemented!()
        }
    }

    impl<I: Ip> MatcherBindingsTypes for FakeBindingsCtx<I> {
        type DeviceClass = FakeDeviceClass;
        type BindingsPacketMatcher = Arc<FakeBindingsPacketMatcher>;
    }

    impl<I: Ip> InstantContext for FakeBindingsCtx<I> {
        fn now(&self) -> Self::Instant {
            self.timer_ctx.now()
        }
    }

    impl<I: Ip> TimerBindingsTypes for FakeBindingsCtx<I> {
        type Timer = <FakeTimerCtx<FilterTimerId<I>> as TimerBindingsTypes>::Timer;
        type DispatchId = <FakeTimerCtx<FilterTimerId<I>> as TimerBindingsTypes>::DispatchId;
        type UniqueTimerId = <FakeTimerCtx<FilterTimerId<I>> as TimerBindingsTypes>::UniqueTimerId;
    }

    impl<I: Ip> TimerContext for FakeBindingsCtx<I> {
        fn new_timer(&mut self, id: Self::DispatchId) -> Self::Timer {
            self.timer_ctx.new_timer(id)
        }

        fn schedule_timer_instant(
            &mut self,
            time: Self::Instant,
            timer: &mut Self::Timer,
        ) -> Option<Self::Instant> {
            self.timer_ctx.schedule_timer_instant(time, timer)
        }

        fn cancel_timer(&mut self, timer: &mut Self::Timer) -> Option<Self::Instant> {
            self.timer_ctx.cancel_timer(timer)
        }

        fn scheduled_instant(&self, timer: &mut Self::Timer) -> Option<Self::Instant> {
            self.timer_ctx.scheduled_instant(timer)
        }

        fn unique_timer_id(&self, timer: &Self::Timer) -> Self::UniqueTimerId {
            self.timer_ctx.unique_timer_id(timer)
        }
    }

    impl<I: Ip> WithFakeTimerContext<FilterTimerId<I>> for FakeBindingsCtx<I> {
        fn with_fake_timer_ctx<O, F: FnOnce(&FakeTimerCtx<FilterTimerId<I>>) -> O>(
            &self,
            f: F,
        ) -> O {
            f(&self.timer_ctx)
        }

        fn with_fake_timer_ctx_mut<O, F: FnOnce(&mut FakeTimerCtx<FilterTimerId<I>>) -> O>(
            &mut self,
            f: F,
        ) -> O {
            f(&mut self.timer_ctx)
        }
    }

    impl<I: Ip> RngContext for FakeBindingsCtx<I> {
        type Rng<'a>
            = FakeCryptoRng
        where
            Self: 'a;

        fn rng(&mut self) -> Self::Rng<'_> {
            self.rng.clone()
        }
    }
}