netstack3_filter/
context.rs

1// Copyright 2024 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use core::fmt::Debug;
6
7use net_types::SpecifiedAddr;
8use net_types::ip::{IpVersion, Ipv4, Ipv6};
9use netstack3_base::socket::SocketCookie;
10use netstack3_base::{
11    InstantBindingsTypes, InterfaceProperties, IpDeviceAddr, IpDeviceAddressIdContext, Marks,
12    MatcherBindingsTypes, RngContext, StrongDeviceIdentifier, TimerBindingsTypes, TimerContext,
13    TxMetadataBindingsTypes,
14};
15use packet::{FragmentedByteSlice, PartialSerializer};
16use packet_formats::ip::IpExt;
17
18use crate::state::State;
19use crate::{FilterIpExt, IpPacket};
20
/// Trait defining required types for filtering provided by bindings.
///
/// This trait declares no items of its own; it bundles
/// [`InstantBindingsTypes`], [`MatcherBindingsTypes`] and
/// [`TimerBindingsTypes`] together with a `'static` bound under a single
/// name.
pub trait FilterBindingsTypes:
    InstantBindingsTypes + MatcherBindingsTypes + TimerBindingsTypes + 'static
{
}

// Blanket implementation: any type satisfying all of the supertrait bounds is
// automatically a `FilterBindingsTypes`.
impl<BT: InstantBindingsTypes + MatcherBindingsTypes + TimerBindingsTypes + 'static>
    FilterBindingsTypes for BT
{
}
31
/// Trait aggregating functionality required from bindings.
///
/// Combines timer support ([`TimerContext`]), random number generation
/// ([`RngContext`]) and the bindings-provided types
/// ([`FilterBindingsTypes`]). It is implemented automatically for every type
/// that satisfies those three bounds.
pub trait FilterBindingsContext: TimerContext + RngContext + FilterBindingsTypes {}
impl<BC: TimerContext + RngContext + FilterBindingsTypes> FilterBindingsContext for BC {}
35
36/// The IP version-specific execution context for packet filtering.
37///
38/// This trait exists to abstract over access to the filtering state. It is
39/// useful to implement filtering logic in terms of this trait, as opposed to,
40/// for example, [`crate::logic::FilterHandler`] methods taking the state
41/// directly as an argument, because it allows Netstack3 Core to use lock
42/// ordering types to enforce that filtering state is only acquired at or before
43/// a given lock level, while keeping test code free of locking concerns.
44pub trait FilterIpContext<I: FilterIpExt, BT: FilterBindingsTypes>:
45    IpDeviceAddressIdContext<I, DeviceId: InterfaceProperties<BT::DeviceClass>>
46{
47    /// The execution context that allows the filtering engine to perform
48    /// Network Address Translation (NAT).
49    type NatCtx<'a>: NatContext<I, BT, DeviceId = Self::DeviceId, WeakAddressId = Self::WeakAddressId>;
50
51    /// Calls the function with a reference to filtering state.
52    fn with_filter_state<O, F: FnOnce(&State<I, Self::WeakAddressId, BT>) -> O>(
53        &mut self,
54        cb: F,
55    ) -> O {
56        self.with_filter_state_and_nat_ctx(|state, _ctx| cb(state))
57    }
58
59    /// Calls the function with a reference to filtering state and the NAT
60    /// context.
61    fn with_filter_state_and_nat_ctx<
62        O,
63        F: FnOnce(&State<I, Self::WeakAddressId, BT>, &mut Self::NatCtx<'_>) -> O,
64    >(
65        &mut self,
66        cb: F,
67    ) -> O;
68}
69
/// The execution context for Network Address Translation (NAT).
///
/// Provides the address lookups NAT needs on top of the device and address
/// identifier types supplied by [`IpDeviceAddressIdContext`].
pub trait NatContext<I: IpExt, BT: FilterBindingsTypes>:
    IpDeviceAddressIdContext<I, DeviceId: InterfaceProperties<BT::DeviceClass>>
{
    /// Returns the best local address for communicating with the remote.
    ///
    /// `remote` is optional. Returns `None` when no address is available.
    // NOTE(review): how "best" is chosen, and what a `None` remote means for
    // the selection, is defined by implementors and not visible here.
    fn get_local_addr_for_remote(
        &mut self,
        device_id: &Self::DeviceId,
        remote: Option<SpecifiedAddr<I::Addr>>,
    ) -> Option<Self::AddressId>;

    /// Returns a strongly-held reference to the provided address, if it is assigned
    /// to the specified device.
    ///
    /// Returns `None` otherwise.
    fn get_address_id(
        &mut self,
        device_id: &Self::DeviceId,
        addr: IpDeviceAddr<I::Addr>,
    ) -> Option<Self::AddressId>;
}
89
90/// A context for mutably accessing all filtering state at once, to allow IPv4
91/// and IPv6 filtering state to be modified atomically.
92pub trait FilterContext<BT: FilterBindingsTypes>:
93    IpDeviceAddressIdContext<Ipv4, DeviceId: InterfaceProperties<BT::DeviceClass>>
94    + IpDeviceAddressIdContext<Ipv6, DeviceId: InterfaceProperties<BT::DeviceClass>>
95{
96    /// Calls the function with a mutable reference to all filtering state.
97    fn with_all_filter_state_mut<
98        O,
99        F: FnOnce(
100            &mut State<Ipv4, <Self as IpDeviceAddressIdContext<Ipv4>>::WeakAddressId, BT>,
101            &mut State<Ipv6, <Self as IpDeviceAddressIdContext<Ipv6>>::WeakAddressId, BT>,
102        ) -> O,
103    >(
104        &mut self,
105        cb: F,
106    ) -> O;
107}
108
/// Result returned from [`SocketOpsFilter::on_egress`].
///
/// Both variants carry an independent `congestion` flag, so congestion can be
/// reported whether or not the packet is sent.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum SocketEgressFilterResult {
    /// Send the packet normally.
    Pass {
        /// Indicates that congestion should be signaled to the higher level protocol.
        congestion: bool,
    },

    /// Drop the packet.
    Drop {
        /// Indicates that congestion should be signaled to the higher level protocol.
        congestion: bool,
    },
}
124
/// Result returned from [`SocketOpsFilter::on_ingress`].
///
/// Unlike [`SocketEgressFilterResult`], the ingress verdict carries no
/// additional data.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum SocketIngressFilterResult {
    /// Accept the packet.
    Accept,

    /// Drop the packet.
    Drop,
}
134
135/// Trait for a socket operations filter.
136pub trait SocketOpsFilter<D: StrongDeviceIdentifier> {
137    /// Called on every outgoing packet originated from a local socket.
138    fn on_egress<I: FilterIpExt, P: IpPacket<I> + PartialSerializer>(
139        &self,
140        packet: &P,
141        device: &D,
142        cookie: SocketCookie,
143        marks: &Marks,
144    ) -> SocketEgressFilterResult;
145
146    /// Called on every incoming packet handled by a local socket.
147    fn on_ingress(
148        &self,
149        ip_version: IpVersion,
150        packet: FragmentedByteSlice<'_, &[u8]>,
151        device: &D,
152        cookie: SocketCookie,
153        marks: &Marks,
154    ) -> SocketIngressFilterResult;
155}
156
/// Implemented by bindings to provide socket operations filtering.
///
/// `D` is the strong device identifier type the returned filter operates on.
pub trait SocketOpsFilterBindingContext<D: StrongDeviceIdentifier>:
    TxMetadataBindingsTypes
{
    /// Returns the filter that should be called for socket ops.
    ///
    /// The concrete filter type is opaque to callers; only the
    /// [`SocketOpsFilter`] interface is exposed.
    fn socket_ops_filter(&self) -> impl SocketOpsFilter<D>;
}
164
// Test-only implementation for the shared fake bindings context from
// `netstack3_base`. Note that `State` below is an impl-local type parameter,
// not `crate::state::State`.
#[cfg(any(test, feature = "testutils"))]
impl<TimerId, Event, State, FrameMeta, D> SocketOpsFilterBindingContext<D>
    for netstack3_base::testutil::FakeBindingsCtx<TimerId, Event, State, FrameMeta>
where
    TimerId: Debug + PartialEq + Clone + Send + Sync + 'static,
    Event: Debug + 'static,
    State: 'static,
    FrameMeta: 'static,
    D: StrongDeviceIdentifier,
{
    /// Always hands out the crate's `NoOpSocketOpsFilter`.
    fn socket_ops_filter(&self) -> impl SocketOpsFilter<D> {
        crate::testutil::NoOpSocketOpsFilter
    }
}
179
180#[cfg(test)]
181pub(crate) mod testutil {
182    use alloc::sync::{Arc, Weak};
183    use alloc::vec::Vec;
184    use core::hash::{Hash, Hasher};
185    use core::ops::Deref;
186    use core::time::Duration;
187
188    use derivative::Derivative;
189    use net_types::ip::{AddrSubnet, GenericOverIp, Ip};
190    use netstack3_base::testutil::{
191        FakeAtomicInstant, FakeCryptoRng, FakeDeviceClass, FakeInstant, FakeMatcherDeviceId,
192        FakeTimerCtx, FakeWeakDeviceId, WithFakeTimerContext,
193    };
194    use netstack3_base::{
195        AnyDevice, AssignedAddrIpExt, DeviceIdContext, InspectableValue, InstantContext,
196        IntoCoreTimerCtx, IpAddressId, WeakIpAddressId,
197    };
198    use netstack3_hashmap::HashMap;
199
200    use super::*;
201    use crate::conntrack;
202    use crate::logic::FilterTimerId;
203    use crate::logic::nat::NatConfig;
204    use crate::state::validation::ValidRoutines;
205    use crate::state::{IpRoutines, NatRoutines, OneWayBoolean, Routines};
206
    /// Shorthand bound used by the fakes in this module: an IP version that is
    /// both a [`FilterIpExt`] and an [`AssignedAddrIpExt`].
    pub trait TestIpExt: FilterIpExt + AssignedAddrIpExt {}

    // Blanket implementation so every qualifying IP version is a `TestIpExt`.
    impl<I: FilterIpExt + AssignedAddrIpExt> TestIpExt for I {}
210
    /// Fake owning handle to an address and subnet assigned to a device.
    ///
    /// [`FakeNatCtx`] stores one of these per device, and [`FakeAddressId`]
    /// values are produced by cloning the inner `Arc`. The type itself does
    /// not derive `Clone`.
    #[derive(Debug)]
    pub struct FakePrimaryAddressId<I: AssignedAddrIpExt>(
        pub Arc<AddrSubnet<I::Addr, I::AssignedWitness>>,
    );
215
    /// Fake strong address identifier: a cloneable, reference-counted handle to
    /// an address and subnet. The inner `Arc` is private to this module.
    #[derive(Clone, Debug, Hash, Eq, PartialEq)]
    pub struct FakeAddressId<I: AssignedAddrIpExt>(Arc<AddrSubnet<I::Addr, I::AssignedWitness>>);
218
    /// Fake weak address identifier: a non-owning (`Weak`) handle to an address
    /// and subnet.
    ///
    /// Equality and hashing are implemented by hand in this module and are
    /// based on the pointer, not on the address value.
    #[derive(Clone, Debug)]
    pub struct FakeWeakAddressId<I: AssignedAddrIpExt>(
        pub Weak<AddrSubnet<I::Addr, I::AssignedWitness>>,
    );
223
224    impl<I: AssignedAddrIpExt> PartialEq for FakeWeakAddressId<I> {
225        fn eq(&self, other: &Self) -> bool {
226            let Self(lhs) = self;
227            let Self(rhs) = other;
228            Weak::ptr_eq(lhs, rhs)
229        }
230    }
231
232    impl<I: AssignedAddrIpExt> Eq for FakeWeakAddressId<I> {}
233
234    impl<I: AssignedAddrIpExt> Hash for FakeWeakAddressId<I> {
235        fn hash<H: Hasher>(&self, state: &mut H) {
236            let Self(this) = self;
237            this.as_ptr().hash(state)
238        }
239    }
240
241    impl<I: AssignedAddrIpExt> WeakIpAddressId<I::Addr> for FakeWeakAddressId<I> {
242        type Strong = FakeAddressId<I>;
243
244        fn upgrade(&self) -> Option<Self::Strong> {
245            let Self(inner) = self;
246            inner.upgrade().map(FakeAddressId)
247        }
248
249        fn is_assigned(&self) -> bool {
250            let Self(inner) = self;
251            inner.strong_count() != 0
252        }
253    }
254
    impl<I: AssignedAddrIpExt> InspectableValue for FakeWeakAddressId<I> {
        /// Not supported by this fake: calling it panics via
        /// `unimplemented!()`. Both arguments are ignored.
        fn record<Inspector: netstack3_base::Inspector>(
            &self,
            _name: &str,
            _inspector: &mut Inspector,
        ) {
            unimplemented!()
        }
    }
264
265    impl<I: AssignedAddrIpExt> Deref for FakeAddressId<I> {
266        type Target = AddrSubnet<I::Addr, I::AssignedWitness>;
267
268        fn deref(&self) -> &Self::Target {
269            let Self(inner) = self;
270            inner.deref()
271        }
272    }
273
274    impl<I: AssignedAddrIpExt> IpAddressId<I::Addr> for FakeAddressId<I> {
275        type Weak = FakeWeakAddressId<I>;
276
277        fn downgrade(&self) -> Self::Weak {
278            let Self(inner) = self;
279            FakeWeakAddressId(Arc::downgrade(inner))
280        }
281
282        fn addr(&self) -> IpDeviceAddr<I::Addr> {
283            let Self(inner) = self;
284
285            #[derive(GenericOverIp)]
286            #[generic_over_ip(I, Ip)]
287            struct WrapIn<I: AssignedAddrIpExt>(I::AssignedWitness);
288            I::map_ip(
289                WrapIn(inner.addr()),
290                |WrapIn(v4_addr)| IpDeviceAddr::new_from_witness(v4_addr),
291                |WrapIn(v6_addr)| IpDeviceAddr::new_from_ipv6_device_addr(v6_addr),
292            )
293        }
294
295        fn addr_sub(&self) -> AddrSubnet<I::Addr, I::AssignedWitness> {
296            let Self(inner) = self;
297            **inner
298        }
299    }
300
    /// Fake core context for filtering tests. It owns the filtering state for
    /// one IP version together with the fake NAT context.
    pub struct FakeCtx<I: TestIpExt> {
        // Filtering state exposed through `FilterIpContext`.
        state: State<I, FakeWeakAddressId<I>, FakeBindingsCtx<I>>,
        // NAT context handed to callbacks alongside the state.
        nat: FakeNatCtx<I>,
    }
305
    /// Fake NAT context: a table holding at most one address per device.
    ///
    /// `Default` is derived with an empty bound, so it is available without
    /// requiring `I: Default`.
    #[derive(Derivative)]
    #[derivative(Default(bound = ""))]
    pub struct FakeNatCtx<I: TestIpExt> {
        // The single address known for each device.
        pub(crate) device_addrs: HashMap<FakeMatcherDeviceId, FakePrimaryAddressId<I>>,
    }
311
312    impl<I: TestIpExt> FakeCtx<I> {
313        pub fn new(bindings_ctx: &mut FakeBindingsCtx<I>) -> Self {
314            Self {
315                state: State {
316                    installed_routines: ValidRoutines::default(),
317                    uninstalled_routines: Vec::default(),
318                    conntrack: conntrack::Table::new::<IntoCoreTimerCtx>(bindings_ctx),
319                    nat_installed: OneWayBoolean::default(),
320                },
321                nat: FakeNatCtx::default(),
322            }
323        }
324
325        pub fn with_ip_routines(
326            bindings_ctx: &mut FakeBindingsCtx<I>,
327            routines: IpRoutines<I, FakeDeviceClass, ()>,
328        ) -> Self {
329            let (installed_routines, uninstalled_routines) =
330                ValidRoutines::new(Routines { ip: routines, ..Default::default() })
331                    .expect("invalid state");
332            Self {
333                state: State {
334                    installed_routines,
335                    uninstalled_routines,
336                    conntrack: conntrack::Table::new::<IntoCoreTimerCtx>(bindings_ctx),
337                    nat_installed: OneWayBoolean::default(),
338                },
339                nat: FakeNatCtx::default(),
340            }
341        }
342
343        pub fn with_nat_routines_and_device_addrs(
344            bindings_ctx: &mut FakeBindingsCtx<I>,
345            routines: NatRoutines<I, FakeDeviceClass, ()>,
346            device_addrs: impl IntoIterator<
347                Item = (FakeMatcherDeviceId, AddrSubnet<I::Addr, I::AssignedWitness>),
348            >,
349        ) -> Self {
350            let (installed_routines, uninstalled_routines) =
351                ValidRoutines::new(Routines { nat: routines, ..Default::default() })
352                    .expect("invalid state");
353            Self {
354                state: State {
355                    installed_routines,
356                    uninstalled_routines,
357                    conntrack: conntrack::Table::new::<IntoCoreTimerCtx>(bindings_ctx),
358                    nat_installed: OneWayBoolean::TRUE,
359                },
360                nat: FakeNatCtx {
361                    device_addrs: device_addrs
362                        .into_iter()
363                        .map(|(device, addr)| (device, FakePrimaryAddressId(Arc::new(addr))))
364                        .collect(),
365                },
366            }
367        }
368
369        pub fn conntrack(
370            &mut self,
371        ) -> &conntrack::Table<I, NatConfig<I, FakeWeakAddressId<I>>, FakeBindingsCtx<I>> {
372            &self.state.conntrack
373        }
374    }
375
    // Device identifiers used by the fake core context.
    impl<I: TestIpExt> DeviceIdContext<AnyDevice> for FakeCtx<I> {
        type DeviceId = FakeMatcherDeviceId;
        type WeakDeviceId = FakeWeakDeviceId<FakeMatcherDeviceId>;
    }

    // Address identifiers used by the fake core context; these are the fake
    // strong/weak ID types defined in this module.
    impl<I: TestIpExt> IpDeviceAddressIdContext<I> for FakeCtx<I> {
        type AddressId = FakeAddressId<I>;
        type WeakAddressId = FakeWeakAddressId<I>;
    }
385
386    impl<I: TestIpExt> FilterIpContext<I, FakeBindingsCtx<I>> for FakeCtx<I> {
387        type NatCtx<'a> = FakeNatCtx<I>;
388
389        fn with_filter_state_and_nat_ctx<
390            O,
391            F: FnOnce(&State<I, FakeWeakAddressId<I>, FakeBindingsCtx<I>>, &mut Self::NatCtx<'_>) -> O,
392        >(
393            &mut self,
394            cb: F,
395        ) -> O {
396            let Self { state, nat } = self;
397            cb(state, nat)
398        }
399    }
400
401    impl<I: TestIpExt> FakeNatCtx<I> {
402        pub fn new(
403            device_addrs: impl IntoIterator<
404                Item = (FakeMatcherDeviceId, AddrSubnet<I::Addr, I::AssignedWitness>),
405            >,
406        ) -> Self {
407            Self {
408                device_addrs: device_addrs
409                    .into_iter()
410                    .map(|(device, addr)| (device, FakePrimaryAddressId(Arc::new(addr))))
411                    .collect(),
412            }
413        }
414    }
415
    // Device identifiers used by the fake NAT context; identical to those of
    // `FakeCtx`, as required by `FilterIpContext::NatCtx`.
    impl<I: TestIpExt> DeviceIdContext<AnyDevice> for FakeNatCtx<I> {
        type DeviceId = FakeMatcherDeviceId;
        type WeakDeviceId = FakeWeakDeviceId<FakeMatcherDeviceId>;
    }

    // Address identifiers used by the fake NAT context.
    impl<I: TestIpExt> IpDeviceAddressIdContext<I> for FakeNatCtx<I> {
        type AddressId = FakeAddressId<I>;
        type WeakAddressId = FakeWeakAddressId<I>;
    }
425
426    impl<I: TestIpExt> NatContext<I, FakeBindingsCtx<I>> for FakeNatCtx<I> {
427        fn get_local_addr_for_remote(
428            &mut self,
429            device_id: &Self::DeviceId,
430            _remote: Option<SpecifiedAddr<I::Addr>>,
431        ) -> Option<Self::AddressId> {
432            let FakePrimaryAddressId(primary) = self.device_addrs.get(device_id)?;
433            Some(FakeAddressId(primary.clone()))
434        }
435
436        fn get_address_id(
437            &mut self,
438            device_id: &Self::DeviceId,
439            addr: IpDeviceAddr<I::Addr>,
440        ) -> Option<Self::AddressId> {
441            let FakePrimaryAddressId(id) = self.device_addrs.get(device_id)?;
442            let id = FakeAddressId(id.clone());
443            if id.addr() == addr { Some(id) } else { None }
444        }
445    }
446
    /// Fake bindings context for filtering tests, providing a fake timer
    /// context (keyed by [`FilterTimerId`]) and a fake RNG.
    pub struct FakeBindingsCtx<I: Ip> {
        /// Backs the instant and timer trait implementations for this type.
        pub timer_ctx: FakeTimerCtx<FilterTimerId<I>>,
        /// Source for the `RngContext` implementation.
        pub rng: FakeCryptoRng,
    }
451
452    impl<I: Ip> FakeBindingsCtx<I> {
453        pub(crate) fn new() -> Self {
454            Self { timer_ctx: FakeTimerCtx::default(), rng: FakeCryptoRng::default() }
455        }
456
457        pub(crate) fn sleep(&mut self, time_elapsed: Duration) {
458            self.timer_ctx.instant.sleep(time_elapsed)
459        }
460    }
461
    // Time is represented with the fake instant types from `netstack3_base`.
    impl<I: Ip> InstantBindingsTypes for FakeBindingsCtx<I> {
        type Instant = FakeInstant;
        type AtomicInstant = FakeAtomicInstant;
    }

    // Device classes are represented with the fake device class type.
    impl<I: Ip> MatcherBindingsTypes for FakeBindingsCtx<I> {
        type DeviceClass = FakeDeviceClass;
    }
470
471    impl<I: Ip> InstantContext for FakeBindingsCtx<I> {
472        fn now(&self) -> Self::Instant {
473            self.timer_ctx.now()
474        }
475    }
476
    // All timer-related types are taken unchanged from the wrapped
    // `FakeTimerCtx<FilterTimerId<I>>`.
    impl<I: Ip> TimerBindingsTypes for FakeBindingsCtx<I> {
        type Timer = <FakeTimerCtx<FilterTimerId<I>> as TimerBindingsTypes>::Timer;
        type DispatchId = <FakeTimerCtx<FilterTimerId<I>> as TimerBindingsTypes>::DispatchId;
        type UniqueTimerId = <FakeTimerCtx<FilterTimerId<I>> as TimerBindingsTypes>::UniqueTimerId;
    }
482
483    impl<I: Ip> TimerContext for FakeBindingsCtx<I> {
484        fn new_timer(&mut self, id: Self::DispatchId) -> Self::Timer {
485            self.timer_ctx.new_timer(id)
486        }
487
488        fn schedule_timer_instant(
489            &mut self,
490            time: Self::Instant,
491            timer: &mut Self::Timer,
492        ) -> Option<Self::Instant> {
493            self.timer_ctx.schedule_timer_instant(time, timer)
494        }
495
496        fn cancel_timer(&mut self, timer: &mut Self::Timer) -> Option<Self::Instant> {
497            self.timer_ctx.cancel_timer(timer)
498        }
499
500        fn scheduled_instant(&self, timer: &mut Self::Timer) -> Option<Self::Instant> {
501            self.timer_ctx.scheduled_instant(timer)
502        }
503
504        fn unique_timer_id(&self, timer: &Self::Timer) -> Self::UniqueTimerId {
505            self.timer_ctx.unique_timer_id(timer)
506        }
507    }
508
509    impl<I: Ip> WithFakeTimerContext<FilterTimerId<I>> for FakeBindingsCtx<I> {
510        fn with_fake_timer_ctx<O, F: FnOnce(&FakeTimerCtx<FilterTimerId<I>>) -> O>(
511            &self,
512            f: F,
513        ) -> O {
514            f(&self.timer_ctx)
515        }
516
517        fn with_fake_timer_ctx_mut<O, F: FnOnce(&mut FakeTimerCtx<FilterTimerId<I>>) -> O>(
518            &mut self,
519            f: F,
520        ) -> O {
521            f(&mut self.timer_ctx)
522        }
523    }
524
525    impl<I: Ip> RngContext for FakeBindingsCtx<I> {
526        type Rng<'a>
527            = FakeCryptoRng
528        where
529            Self: 'a;
530
531        fn rng(&mut self) -> Self::Rng<'_> {
532            self.rng.clone()
533        }
534    }
535}