netstack3_device/
socket.rs

1// Copyright 2023 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5//! Link-layer sockets (analogous to Linux's AF_PACKET sockets).
6
7use alloc::collections::{HashMap, HashSet};
8use core::fmt::Debug;
9use core::hash::Hash;
10use core::num::NonZeroU16;
11
12use derivative::Derivative;
13use lock_order::lock::{OrderedLockAccess, OrderedLockRef};
14use net_types::ethernet::Mac;
15use net_types::ip::IpVersion;
16use netstack3_base::sync::{Mutex, PrimaryRc, RwLock, StrongRc, WeakRc};
17use netstack3_base::{
18    AnyDevice, ContextPair, Counter, Device, DeviceIdContext, FrameDestination, Inspectable,
19    Inspector, InspectorDeviceExt, InspectorExt, ReferenceNotifiers, ReferenceNotifiersExt as _,
20    RemoveResourceResultWithContext, ResourceCounterContext, SendFrameContext,
21    SendFrameErrorReason, StrongDeviceIdentifier, WeakDeviceIdentifier as _,
22};
23use packet::{BufferMut, ParsablePacket as _, Serializer};
24use packet_formats::error::ParseError;
25use packet_formats::ethernet::{EtherType, EthernetFrameLengthCheck};
26
27use crate::internal::base::DeviceLayerTypes;
28use crate::internal::id::WeakDeviceId;
29
/// Chooses which frames a socket is interested in, keyed on the link-layer
/// protocol number.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
pub enum Protocol {
    /// Matches every frame, no matter its protocol number.
    All,
    /// Matches only frames carrying exactly this protocol number.
    Specific(NonZeroU16),
}
38
39/// Selector for devices to send and receive packets on.
40#[derive(Clone, Debug, Derivative, Eq, Hash, PartialEq)]
41#[derivative(Default(bound = ""))]
42pub enum TargetDevice<D> {
43    /// Act on any device in the system.
44    #[derivative(Default)]
45    AnyDevice,
46    /// Act on a specific device.
47    SpecificDevice(D),
48}
49
50/// Information about the bound state of a socket.
51#[derive(Debug)]
52#[cfg_attr(test, derive(PartialEq))]
53pub struct SocketInfo<D> {
54    /// The protocol the socket is bound to, or `None` if no protocol is set.
55    pub protocol: Option<Protocol>,
56    /// The device selector for which the socket is set.
57    pub device: TargetDevice<D>,
58}
59
/// Associated types for device sockets that the bindings context supplies.
pub trait DeviceSocketTypes {
    /// Per-socket state that bindings hand to core, which core stores with the
    /// socket and exposes back to bindings.
    type SocketState<D: Send + Sync + Debug>: Send + Sync + Debug;
}
66
67/// The execution context for device sockets provided by bindings.
68pub trait DeviceSocketBindingsContext<DeviceId: StrongDeviceIdentifier>: DeviceSocketTypes {
69    /// Called for each received frame that matches the provided socket.
70    ///
71    /// `frame` and `raw_frame` are parsed and raw views into the same data.
72    fn receive_frame(
73        &self,
74        socket: &Self::SocketState<DeviceId::Weak>,
75        device: &DeviceId,
76        frame: Frame<&[u8]>,
77        raw_frame: &[u8],
78    );
79}
80
81/// Strong owner of socket state.
82///
83/// This type strongly owns the socket state.
84#[derive(Debug)]
85pub struct PrimaryDeviceSocketId<D: Send + Sync + Debug, BT: DeviceSocketTypes>(
86    PrimaryRc<SocketState<D, BT>>,
87);
88
89impl<D: Send + Sync + Debug, BT: DeviceSocketTypes> PrimaryDeviceSocketId<D, BT> {
90    /// Creates a new socket ID with `external_state`.
91    fn new(external_state: BT::SocketState<D>) -> Self {
92        Self(PrimaryRc::new(SocketState {
93            external_state,
94            counters: Default::default(),
95            target: Default::default(),
96        }))
97    }
98
99    /// Clones the primary's underlying reference and returns as a strong id.
100    fn clone_strong(&self) -> DeviceSocketId<D, BT> {
101        let PrimaryDeviceSocketId(rc) = self;
102        DeviceSocketId(PrimaryRc::clone_strong(rc))
103    }
104}
105
106/// Reference to live socket state.
107///
108/// The existence of a `StrongId` attests to the liveness of the state of the
109/// backing socket.
110#[derive(Derivative)]
111#[derivative(Clone(bound = ""), Hash(bound = ""), Eq(bound = ""), PartialEq(bound = ""))]
112pub struct DeviceSocketId<D: Send + Sync + Debug, BT: DeviceSocketTypes>(
113    StrongRc<SocketState<D, BT>>,
114);
115
116impl<D: Send + Sync + Debug, BT: DeviceSocketTypes> Debug for DeviceSocketId<D, BT> {
117    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
118        let Self(rc) = self;
119        f.debug_tuple("DeviceSocketId").field(&StrongRc::debug_id(rc)).finish()
120    }
121}
122
123impl<D: Send + Sync + Debug, BT: DeviceSocketTypes> OrderedLockAccess<Target<D>>
124    for DeviceSocketId<D, BT>
125{
126    type Lock = Mutex<Target<D>>;
127    fn ordered_lock_access(&self) -> OrderedLockRef<'_, Self::Lock> {
128        let Self(rc) = self;
129        OrderedLockRef::new(&rc.target)
130    }
131}
132
133/// A weak reference to socket state.
134///
135/// The existence of a [`WeakSocketDeviceId`] does not attest to the liveness of
136/// the backing socket.
137#[derive(Derivative)]
138#[derivative(Clone(bound = ""), Hash(bound = ""), Eq(bound = ""), PartialEq(bound = ""))]
139pub struct WeakDeviceSocketId<D: Send + Sync + Debug, BT: DeviceSocketTypes>(
140    WeakRc<SocketState<D, BT>>,
141);
142
143impl<D: Send + Sync + Debug, BT: DeviceSocketTypes> Debug for WeakDeviceSocketId<D, BT> {
144    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
145        let Self(rc) = self;
146        f.debug_tuple("WeakDeviceSocketId").field(&WeakRc::debug_id(rc)).finish()
147    }
148}
149
150/// Holds shared state for sockets.
151#[derive(Derivative)]
152#[derivative(Default(bound = ""))]
153pub struct Sockets<D: Send + Sync + Debug, BT: DeviceSocketTypes> {
154    /// Holds strong (but not owning) references to sockets that aren't
155    /// targeting a particular device.
156    any_device_sockets: RwLock<AnyDeviceSockets<D, BT>>,
157
158    /// Table of all sockets in the system, regardless of target.
159    ///
160    /// Holds the primary (owning) reference for all sockets.
161    // This needs to be after `any_device_sockets` so that when an instance of
162    // this type is dropped, any strong IDs get dropped before their
163    // corresponding primary IDs.
164    all_sockets: RwLock<AllSockets<D, BT>>,
165}
166
167/// The set of sockets associated with a device.
168#[derive(Derivative)]
169#[derivative(Default(bound = ""))]
170pub struct AnyDeviceSockets<D: Send + Sync + Debug, BT: DeviceSocketTypes>(
171    HashSet<DeviceSocketId<D, BT>>,
172);
173
174/// A collection of all device sockets in the system.
175#[derive(Derivative)]
176#[derivative(Default(bound = ""))]
177pub struct AllSockets<D: Send + Sync + Debug, BT: DeviceSocketTypes>(
178    HashMap<DeviceSocketId<D, BT>, PrimaryDeviceSocketId<D, BT>>,
179);
180
181/// State held by a device socket.
182#[derive(Debug)]
183pub struct SocketState<D: Send + Sync + Debug, BT: DeviceSocketTypes> {
184    /// State provided by bindings that is held in core.
185    pub external_state: BT::SocketState<D>,
186    /// The socket's target device and protocol.
187    // TODO(https://fxbug.dev/42077026): Consider splitting up the state here to
188    // improve performance.
189    target: Mutex<Target<D>>,
190    /// Statistics about the socket's usage.
191    counters: DeviceSocketCounters,
192}
193
194/// A device socket's binding information.
195#[derive(Debug, Derivative)]
196#[derivative(Default(bound = ""))]
197pub struct Target<D> {
198    protocol: Option<Protocol>,
199    device: TargetDevice<D>,
200}
201
202/// Per-device state for packet sockets.
203///
204/// Holds sockets that are bound to a particular device. An instance of this
205/// should be held in the state for each device in the system.
206#[derive(Derivative)]
207#[derivative(Default(bound = ""))]
208#[cfg_attr(
209    test,
210    derivative(Debug, PartialEq(bound = "BT::SocketState<D>: Hash + Eq, D: Hash + Eq"))
211)]
212pub struct DeviceSockets<D: Send + Sync + Debug, BT: DeviceSocketTypes>(
213    HashSet<DeviceSocketId<D, BT>>,
214);
215
216/// Convenience alias for use in device state storage.
217pub type HeldDeviceSockets<BT> = DeviceSockets<WeakDeviceId<BT>, BT>;
218
219/// Convenience alias for use in shared storage.
220///
221/// The type parameter is expected to implement [`DeviceSocketTypes`].
222pub type HeldSockets<BT> = Sockets<WeakDeviceId<BT>, BT>;
223
224/// Core context for accessing socket state.
225pub trait DeviceSocketContext<BT: DeviceSocketTypes>: DeviceIdContext<AnyDevice> {
226    /// The core context available in callbacks to methods on this context.
227    type SocketTablesCoreCtx<'a>: DeviceSocketAccessor<
228        BT,
229        DeviceId = Self::DeviceId,
230        WeakDeviceId = Self::WeakDeviceId,
231    >;
232
233    /// Executes the provided callback with access to the collection of all
234    /// sockets.
235    fn with_all_device_sockets<
236        F: FnOnce(&AllSockets<Self::WeakDeviceId, BT>, &mut Self::SocketTablesCoreCtx<'_>) -> R,
237        R,
238    >(
239        &mut self,
240        cb: F,
241    ) -> R;
242
243    /// Executes the provided callback with mutable access to the collection of
244    /// all sockets.
245    fn with_all_device_sockets_mut<F: FnOnce(&mut AllSockets<Self::WeakDeviceId, BT>) -> R, R>(
246        &mut self,
247        cb: F,
248    ) -> R;
249
250    /// Executes the provided callback with immutable access to socket state.
251    fn with_any_device_sockets<
252        F: FnOnce(&AnyDeviceSockets<Self::WeakDeviceId, BT>, &mut Self::SocketTablesCoreCtx<'_>) -> R,
253        R,
254    >(
255        &mut self,
256        cb: F,
257    ) -> R;
258
259    /// Executes the provided callback with mutable access to socket state.
260    fn with_any_device_sockets_mut<
261        F: FnOnce(
262            &mut AnyDeviceSockets<Self::WeakDeviceId, BT>,
263            &mut Self::SocketTablesCoreCtx<'_>,
264        ) -> R,
265        R,
266    >(
267        &mut self,
268        cb: F,
269    ) -> R;
270}
271
272/// Core context for accessing the state of an individual socket.
273pub trait SocketStateAccessor<BT: DeviceSocketTypes>: DeviceIdContext<AnyDevice> {
274    /// Provides read-only access to the state of a socket.
275    fn with_socket_state<
276        F: FnOnce(&BT::SocketState<Self::WeakDeviceId>, &Target<Self::WeakDeviceId>) -> R,
277        R,
278    >(
279        &mut self,
280        socket: &DeviceSocketId<Self::WeakDeviceId, BT>,
281        cb: F,
282    ) -> R;
283
284    /// Provides mutable access to the state of a socket.
285    fn with_socket_state_mut<
286        F: FnOnce(&BT::SocketState<Self::WeakDeviceId>, &mut Target<Self::WeakDeviceId>) -> R,
287        R,
288    >(
289        &mut self,
290        socket: &DeviceSocketId<Self::WeakDeviceId, BT>,
291        cb: F,
292    ) -> R;
293}
294
295/// Core context for accessing the socket state for a device.
296pub trait DeviceSocketAccessor<BT: DeviceSocketTypes>: SocketStateAccessor<BT> {
297    /// Core context available in callbacks to methods on this context.
298    type DeviceSocketCoreCtx<'a>: SocketStateAccessor<BT, DeviceId = Self::DeviceId, WeakDeviceId = Self::WeakDeviceId>
299        + ResourceCounterContext<DeviceSocketId<Self::WeakDeviceId, BT>, DeviceSocketCounters>;
300
301    /// Executes the provided callback with immutable access to device-specific
302    /// socket state.
303    fn with_device_sockets<
304        F: FnOnce(&DeviceSockets<Self::WeakDeviceId, BT>, &mut Self::DeviceSocketCoreCtx<'_>) -> R,
305        R,
306    >(
307        &mut self,
308        device: &Self::DeviceId,
309        cb: F,
310    ) -> R;
311
312    /// Executes the provided callback with mutable access to device-specific
313    /// socket state.
314    fn with_device_sockets_mut<
315        F: FnOnce(&mut DeviceSockets<Self::WeakDeviceId, BT>, &mut Self::DeviceSocketCoreCtx<'_>) -> R,
316        R,
317    >(
318        &mut self,
319        device: &Self::DeviceId,
320        cb: F,
321    ) -> R;
322}
323
/// Whether an operation should overwrite a value, and if so, with what.
enum MaybeUpdate<T> {
    /// Leave the current value untouched.
    NoChange,
    /// Replace the current value with this one.
    NewValue(T),
}
328
329fn update_device_and_protocol<CC: DeviceSocketContext<BT>, BT: DeviceSocketTypes>(
330    core_ctx: &mut CC,
331    socket: &DeviceSocketId<CC::WeakDeviceId, BT>,
332    new_device: TargetDevice<&CC::DeviceId>,
333    protocol_update: MaybeUpdate<Protocol>,
334) {
335    core_ctx.with_any_device_sockets_mut(|AnyDeviceSockets(any_device_sockets), core_ctx| {
336        // Even if we're never moving the socket from/to the any-device
337        // state, we acquire the lock to make the move between devices
338        // atomic from the perspective of frame delivery. Otherwise there
339        // would be a brief period during which arriving frames wouldn't be
340        // delivered to the socket from either device.
341        let old_device = core_ctx.with_socket_state_mut(
342            socket,
343            |_: &BT::SocketState<CC::WeakDeviceId>, Target { protocol, device }| {
344                match protocol_update {
345                    MaybeUpdate::NewValue(p) => *protocol = Some(p),
346                    MaybeUpdate::NoChange => (),
347                };
348                let old_device = match &device {
349                    TargetDevice::SpecificDevice(device) => device.upgrade(),
350                    TargetDevice::AnyDevice => {
351                        assert!(any_device_sockets.remove(socket));
352                        None
353                    }
354                };
355                *device = match &new_device {
356                    TargetDevice::AnyDevice => TargetDevice::AnyDevice,
357                    TargetDevice::SpecificDevice(d) => TargetDevice::SpecificDevice(d.downgrade()),
358                };
359                old_device
360            },
361        );
362
363        // This modification occurs without holding the socket's individual
364        // lock. That's safe because all modifications to the socket's
365        // device are done within a `with_sockets_mut` call, which
366        // synchronizes them.
367
368        if let Some(device) = old_device {
369            // Remove the reference to the socket from the old device if
370            // there is one, and it hasn't been removed.
371            core_ctx.with_device_sockets_mut(
372                &device,
373                |DeviceSockets(device_sockets), _core_ctx| {
374                    assert!(device_sockets.remove(socket), "socket not found in device state");
375                },
376            );
377        }
378
379        // Add the reference to the new device, if there is one.
380        match &new_device {
381            TargetDevice::SpecificDevice(new_device) => core_ctx.with_device_sockets_mut(
382                new_device,
383                |DeviceSockets(device_sockets), _core_ctx| {
384                    assert!(device_sockets.insert(socket.clone()));
385                },
386            ),
387            TargetDevice::AnyDevice => {
388                assert!(any_device_sockets.insert(socket.clone()))
389            }
390        }
391    })
392}
393
/// Entry point for operating on device sockets.
///
/// Wraps the contexts `C` that its operations run against.
pub struct DeviceSocketApi<C>(C);
396
397impl<C> DeviceSocketApi<C> {
398    /// Creates a new `DeviceSocketApi` for `ctx`.
399    pub fn new(ctx: C) -> Self {
400        Self(ctx)
401    }
402}
403
404/// A local alias for [`DeviceSocketId`] for use in [`DeviceSocketApi`].
405///
406/// TODO(https://github.com/rust-lang/rust/issues/8995): Make this an inherent
407/// associated type.
408type ApiSocketId<C> = DeviceSocketId<
409    <<C as ContextPair>::CoreContext as DeviceIdContext<AnyDevice>>::WeakDeviceId,
410    <C as ContextPair>::BindingsContext,
411>;
412
413impl<C> DeviceSocketApi<C>
414where
415    C: ContextPair,
416    C::CoreContext: DeviceSocketContext<C::BindingsContext>
417        + SocketStateAccessor<C::BindingsContext>
418        + ResourceCounterContext<ApiSocketId<C>, DeviceSocketCounters>,
419    C::BindingsContext: DeviceSocketBindingsContext<<C::CoreContext as DeviceIdContext<AnyDevice>>::DeviceId>
420        + ReferenceNotifiers
421        + 'static,
422{
423    fn core_ctx(&mut self) -> &mut C::CoreContext {
424        let Self(pair) = self;
425        pair.core_ctx()
426    }
427
428    fn contexts(&mut self) -> (&mut C::CoreContext, &mut C::BindingsContext) {
429        let Self(pair) = self;
430        pair.contexts()
431    }
432
433    /// Creates an packet socket with no protocol set configured for all devices.
434    pub fn create(
435        &mut self,
436        external_state: <C::BindingsContext as DeviceSocketTypes>::SocketState<
437            <C::CoreContext as DeviceIdContext<AnyDevice>>::WeakDeviceId,
438        >,
439    ) -> ApiSocketId<C> {
440        let core_ctx = self.core_ctx();
441
442        let strong = core_ctx.with_all_device_sockets_mut(|AllSockets(sockets)| {
443            let primary = PrimaryDeviceSocketId::new(external_state);
444            let strong = primary.clone_strong();
445            assert!(sockets.insert(strong.clone(), primary).is_none());
446            strong
447        });
448        core_ctx.with_any_device_sockets_mut(|AnyDeviceSockets(any_device_sockets), _core_ctx| {
449            // On creation, sockets do not target any device or protocol.
450            // Inserting them into the `any_device_sockets` table lets us treat
451            // newly-created sockets uniformly with sockets whose target device
452            // or protocol was set. The difference is unobservable at runtime
453            // since newly-created sockets won't match any frames being
454            // delivered.
455            assert!(any_device_sockets.insert(strong.clone()));
456        });
457        strong
458    }
459
460    /// Sets the device for which a packet socket will receive packets.
461    pub fn set_device(
462        &mut self,
463        socket: &ApiSocketId<C>,
464        device: TargetDevice<&<C::CoreContext as DeviceIdContext<AnyDevice>>::DeviceId>,
465    ) {
466        update_device_and_protocol(self.core_ctx(), socket, device, MaybeUpdate::NoChange)
467    }
468
469    /// Sets the device and protocol for which a socket will receive packets.
470    pub fn set_device_and_protocol(
471        &mut self,
472        socket: &ApiSocketId<C>,
473        device: TargetDevice<&<C::CoreContext as DeviceIdContext<AnyDevice>>::DeviceId>,
474        protocol: Protocol,
475    ) {
476        update_device_and_protocol(self.core_ctx(), socket, device, MaybeUpdate::NewValue(protocol))
477    }
478
479    /// Gets the bound info for a socket.
480    pub fn get_info(
481        &mut self,
482        id: &ApiSocketId<C>,
483    ) -> SocketInfo<<C::CoreContext as DeviceIdContext<AnyDevice>>::WeakDeviceId> {
484        self.core_ctx().with_socket_state(id, |_external_state, Target { device, protocol }| {
485            SocketInfo { device: device.clone(), protocol: *protocol }
486        })
487    }
488
489    /// Removes a bound socket.
490    pub fn remove(
491        &mut self,
492        id: ApiSocketId<C>,
493    ) -> RemoveResourceResultWithContext<
494        <C::BindingsContext as DeviceSocketTypes>::SocketState<
495            <C::CoreContext as DeviceIdContext<AnyDevice>>::WeakDeviceId,
496        >,
497        C::BindingsContext,
498    > {
499        let core_ctx = self.core_ctx();
500        core_ctx.with_any_device_sockets_mut(|AnyDeviceSockets(any_device_sockets), core_ctx| {
501            let old_device = core_ctx.with_socket_state_mut(&id, |_external_state, target| {
502                let Target { device, protocol: _ } = target;
503                match &device {
504                    TargetDevice::SpecificDevice(device) => device.upgrade(),
505                    TargetDevice::AnyDevice => {
506                        assert!(any_device_sockets.remove(&id));
507                        None
508                    }
509                }
510            });
511            if let Some(device) = old_device {
512                core_ctx.with_device_sockets_mut(
513                    &device,
514                    |DeviceSockets(device_sockets), _core_ctx| {
515                        assert!(device_sockets.remove(&id), "device doesn't have socket");
516                    },
517                )
518            }
519        });
520
521        core_ctx.with_all_device_sockets_mut(|AllSockets(sockets)| {
522            let primary = sockets
523                .remove(&id)
524                .unwrap_or_else(|| panic!("{id:?} not present in all socket map"));
525            // Make sure to drop the strong ID before trying to unwrap the primary
526            // ID.
527            drop(id);
528
529            let PrimaryDeviceSocketId(primary) = primary;
530            C::BindingsContext::unwrap_or_notify_with_new_reference_notifier(
531                primary,
532                |SocketState { external_state, counters: _, target: _ }| external_state,
533            )
534        })
535    }
536
537    /// Sends a frame for the specified socket.
538    pub fn send_frame<S, D>(
539        &mut self,
540        id: &ApiSocketId<C>,
541        metadata: DeviceSocketMetadata<D, <C::CoreContext as DeviceIdContext<D>>::DeviceId>,
542        body: S,
543    ) -> Result<(), SendFrameErrorReason>
544    where
545        S: Serializer,
546        S::Buffer: BufferMut,
547        D: DeviceSocketSendTypes,
548        C::CoreContext: DeviceIdContext<D>
549            + SendFrameContext<
550                C::BindingsContext,
551                DeviceSocketMetadata<D, <C::CoreContext as DeviceIdContext<D>>::DeviceId>,
552            >,
553        C::BindingsContext: DeviceLayerTypes,
554    {
555        let (core_ctx, bindings_ctx) = self.contexts();
556        let result = core_ctx.send_frame(bindings_ctx, metadata, body).map_err(|e| e.into_err());
557        match &result {
558            Ok(()) => {
559                core_ctx.increment_both(id, |counters: &DeviceSocketCounters| &counters.tx_frames)
560            }
561            Err(SendFrameErrorReason::QueueFull) => core_ctx
562                .increment_both(id, |counters: &DeviceSocketCounters| &counters.tx_err_queue_full),
563            Err(SendFrameErrorReason::Alloc) => core_ctx
564                .increment_both(id, |counters: &DeviceSocketCounters| &counters.tx_err_alloc),
565            Err(SendFrameErrorReason::SizeConstraintsViolation) => core_ctx
566                .increment_both(id, |counters: &DeviceSocketCounters| {
567                    &counters.tx_err_size_constraint
568                }),
569        }
570        result
571    }
572
573    /// Provides inspect data for raw IP sockets.
574    pub fn inspect<N>(&mut self, inspector: &mut N)
575    where
576        N: Inspector
577            + InspectorDeviceExt<<C::CoreContext as DeviceIdContext<AnyDevice>>::WeakDeviceId>,
578    {
579        self.core_ctx().with_all_device_sockets(|AllSockets(sockets), core_ctx| {
580            sockets.keys().for_each(|socket| {
581                inspector.record_debug_child(socket, |node| {
582                    core_ctx.with_socket_state(
583                        socket,
584                        |_external_state, Target { protocol, device }| {
585                            node.record_debug("Protocol", protocol);
586                            match device {
587                                TargetDevice::AnyDevice => node.record_str("Device", "Any"),
588                                TargetDevice::SpecificDevice(d) => {
589                                    N::record_device(node, "Device", d)
590                                }
591                            }
592                        },
593                    );
594                    node.record_child("Counters", |node| {
595                        node.delegate_inspectable(socket.counters())
596                    })
597                })
598            })
599        })
600    }
601}
602
603/// A provider of the types required to send on a device socket.
604pub trait DeviceSocketSendTypes: Device {
605    /// The metadata required to send a frame on the device.
606    type Metadata;
607}
608
609/// Metadata required to send a frame on a device socket.
610#[derive(Debug, PartialEq)]
611pub struct DeviceSocketMetadata<D: DeviceSocketSendTypes, DeviceId> {
612    /// The device ID to send via.
613    pub device_id: DeviceId,
614    /// The metadata required to send that's specific to the device type.
615    pub metadata: D::Metadata,
616    // TODO(https://fxbug.dev/391946195): Include send buffer ownership metadata
617    // here.
618}
619
620/// Parameters needed to apply system-framing of an Ethernet frame.
621#[derive(Debug, PartialEq)]
622pub struct EthernetHeaderParams {
623    /// The destination MAC address to send to.
624    pub dest_addr: Mac,
625    /// The upperlayer protocol of the data contained in this Ethernet frame.
626    pub protocol: EtherType,
627}
628
629/// Public identifier for a socket.
630///
631/// Strongly owns the state of the socket. So long as the `SocketId` for a
632/// socket is not dropped, the socket is guaranteed to exist.
633pub type SocketId<BC> = DeviceSocketId<WeakDeviceId<BC>, BC>;
634
635impl<D: Send + Sync + Debug, BT: DeviceSocketTypes> DeviceSocketId<D, BT> {
636    /// Provides immutable access to [`DeviceSocketTypes::SocketState`] for the
637    /// socket.
638    pub fn socket_state(&self) -> &BT::SocketState<D> {
639        let Self(strong) = self;
640        let SocketState { external_state, counters: _, target: _ } = &**strong;
641        external_state
642    }
643
644    /// Obtain a [`WeakDeviceSocketId`] from this [`DeviceSocketId`].
645    pub fn downgrade(&self) -> WeakDeviceSocketId<D, BT> {
646        let Self(inner) = self;
647        WeakDeviceSocketId(StrongRc::downgrade(inner))
648    }
649
650    /// Provides access to the socket's counters.
651    pub fn counters(&self) -> &DeviceSocketCounters {
652        let Self(strong) = self;
653        let SocketState { external_state: _, counters, target: _ } = &**strong;
654        counters
655    }
656}
657
658/// Allows the rest of the stack to dispatch packets to listening sockets.
659///
660/// This is implemented on top of [`DeviceSocketContext`] and abstracts packet
661/// socket delivery from the rest of the system.
662pub trait DeviceSocketHandler<D: Device, BC>: DeviceIdContext<D> {
663    /// Dispatch a received frame to sockets.
664    fn handle_frame(
665        &mut self,
666        bindings_ctx: &mut BC,
667        device: &Self::DeviceId,
668        frame: Frame<&[u8]>,
669        whole_frame: &[u8],
670    );
671}
672
673/// A frame received on a device.
674#[derive(Clone, Copy, Debug, Eq, PartialEq)]
675pub enum ReceivedFrame<B> {
676    /// An ethernet frame received on a device.
677    Ethernet {
678        /// Where the frame was destined.
679        destination: FrameDestination,
680        /// The parsed ethernet frame.
681        frame: EthernetFrame<B>,
682    },
683    /// An IP frame received on a device.
684    ///
685    /// Note that this is not an IP packet within an Ethernet Frame. This is an
686    /// IP packet received directly from the device (e.g. a pure IP device).
687    Ip(IpFrame<B>),
688}
689
690/// A frame sent on a device.
691#[derive(Clone, Copy, Debug, Eq, PartialEq)]
692pub enum SentFrame<B> {
693    /// An ethernet frame sent on a device.
694    Ethernet(EthernetFrame<B>),
695    /// An IP frame sent on a device.
696    ///
697    /// Note that this is not an IP packet within an Ethernet Frame. This is an
698    /// IP Packet send directly on the device (e.g. a pure IP device).
699    Ip(IpFrame<B>),
700}
701
/// Error returned when a frame couldn't be parsed as a [`SentFrame`].
///
/// This is a unit struct carrying no details. It derives `Clone`, `Copy`,
/// `Eq`, and `PartialEq` so callers can cheaply copy it and compare it
/// (e.g. in `assert_eq!`).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct ParseSentFrameError;
705
706impl SentFrame<&[u8]> {
707    /// Tries to parse the given frame as an Ethernet frame.
708    pub fn try_parse_as_ethernet(mut buf: &[u8]) -> Result<SentFrame<&[u8]>, ParseSentFrameError> {
709        packet_formats::ethernet::EthernetFrame::parse(&mut buf, EthernetFrameLengthCheck::NoCheck)
710            .map_err(|_: ParseError| ParseSentFrameError)
711            .map(|frame| SentFrame::Ethernet(frame.into()))
712    }
713}
714
715/// Data from an Ethernet frame.
716#[derive(Clone, Copy, Debug, Eq, PartialEq)]
717pub struct EthernetFrame<B> {
718    /// The source address of the frame.
719    pub src_mac: Mac,
720    /// The destination address of the frame.
721    pub dst_mac: Mac,
722    /// The EtherType of the frame, or `None` if there was none.
723    pub ethertype: Option<EtherType>,
724    /// The body of the frame.
725    pub body: B,
726}
727
728/// Data from an IP frame.
729#[derive(Clone, Copy, Debug, Eq, PartialEq)]
730pub struct IpFrame<B> {
731    /// The IP version of the frame.
732    pub ip_version: IpVersion,
733    /// The body of the frame.
734    pub body: B,
735}
736
737impl<B> IpFrame<B> {
738    fn ethertype(&self) -> EtherType {
739        let IpFrame { ip_version, body: _ } = self;
740        EtherType::from_ip_version(*ip_version)
741    }
742}
743
744/// A frame sent or received on a device
745#[derive(Clone, Copy, Debug, Eq, PartialEq)]
746pub enum Frame<B> {
747    /// A sent frame.
748    Sent(SentFrame<B>),
749    /// A received frame.
750    Received(ReceivedFrame<B>),
751}
752
753impl<B> From<SentFrame<B>> for Frame<B> {
754    fn from(value: SentFrame<B>) -> Self {
755        Self::Sent(value)
756    }
757}
758
759impl<B> From<ReceivedFrame<B>> for Frame<B> {
760    fn from(value: ReceivedFrame<B>) -> Self {
761        Self::Received(value)
762    }
763}
764
765impl<'a> From<packet_formats::ethernet::EthernetFrame<&'a [u8]>> for EthernetFrame<&'a [u8]> {
766    fn from(frame: packet_formats::ethernet::EthernetFrame<&'a [u8]>) -> Self {
767        Self {
768            src_mac: frame.src_mac(),
769            dst_mac: frame.dst_mac(),
770            ethertype: frame.ethertype(),
771            body: frame.into_body(),
772        }
773    }
774}
775
776impl<'a> ReceivedFrame<&'a [u8]> {
777    pub(crate) fn from_ethernet(
778        frame: packet_formats::ethernet::EthernetFrame<&'a [u8]>,
779        destination: FrameDestination,
780    ) -> Self {
781        Self::Ethernet { destination, frame: frame.into() }
782    }
783}
784
785impl<B> Frame<B> {
786    /// Returns ether type for the packet if it's known.
787    pub fn protocol(&self) -> Option<u16> {
788        let ethertype = match self {
789            Self::Sent(SentFrame::Ethernet(frame))
790            | Self::Received(ReceivedFrame::Ethernet { destination: _, frame }) => frame.ethertype,
791            Self::Sent(SentFrame::Ip(frame)) | Self::Received(ReceivedFrame::Ip(frame)) => {
792                Some(frame.ethertype())
793            }
794        };
795        ethertype.map(Into::into)
796    }
797
798    /// Convenience method for consuming the `Frame` and producing the body.
799    pub fn into_body(self) -> B {
800        match self {
801            Self::Received(ReceivedFrame::Ethernet { destination: _, frame })
802            | Self::Sent(SentFrame::Ethernet(frame)) => frame.body,
803            Self::Received(ReceivedFrame::Ip(frame)) | Self::Sent(SentFrame::Ip(frame)) => {
804                frame.body
805            }
806        }
807    }
808}
809
810impl<
811        D: Device,
812        BC: DeviceSocketBindingsContext<<CC as DeviceIdContext<AnyDevice>>::DeviceId>,
813        CC: DeviceSocketContext<BC> + DeviceIdContext<D>,
814    > DeviceSocketHandler<D, BC> for CC
815where
816    <CC as DeviceIdContext<D>>::DeviceId: Into<<CC as DeviceIdContext<AnyDevice>>::DeviceId>,
817{
818    fn handle_frame(
819        &mut self,
820        bindings_ctx: &mut BC,
821        device: &Self::DeviceId,
822        frame: Frame<&[u8]>,
823        whole_frame: &[u8],
824    ) {
825        let device = device.clone().into();
826
827        // TODO(https://fxbug.dev/42076496): Invert the order of acquisition
828        // for the lock on the sockets held in the device and the any-device
829        // sockets lock.
830        self.with_any_device_sockets(|AnyDeviceSockets(any_device_sockets), core_ctx| {
831            // Iterate through the device's sockets while also holding the
832            // any-device sockets lock. This prevents double delivery to the
833            // same socket. If the two tables were locked independently,
834            // we could end up with a race, with the following thread
835            // interleaving (thread A is executing this code for device D,
836            // thread B is updating the device to D for the same socket X):
837            //   A) lock the any device sockets table
838            //   A) deliver to socket X in the table
839            //   A) unlock the any device sockets table
840            //   B) lock the any device sockets table, then D's sockets
841            //   B) remove X from the any table and add to D's
842            //   B) unlock D's sockets and any device sockets
843            //   A) lock D's sockets
844            //   A) deliver to socket X in D's table (!)
845            core_ctx.with_device_sockets(&device, |DeviceSockets(device_sockets), core_ctx| {
846                for socket in any_device_sockets.iter().chain(device_sockets) {
847                    let delivered = core_ctx.with_socket_state(
848                        socket,
849                        |external_state, Target { protocol, device: _ }| {
850                            let should_deliver = match protocol {
851                                None => false,
852                                Some(p) => match p {
853                                    // Sent frames are only delivered to sockets
854                                    // matching all protocols for Linux
855                                    // compatibility. See https://github.com/google/gvisor/blob/68eae979409452209e4faaeac12aee4191b3d6f0/test/syscalls/linux/packet_socket.cc#L331-L392.
856                                    Protocol::Specific(p) => match frame {
857                                        Frame::Received(_) => Some(p.get()) == frame.protocol(),
858                                        Frame::Sent(_) => false,
859                                    },
860                                    Protocol::All => true,
861                                },
862                            };
863                            if should_deliver {
864                                bindings_ctx.receive_frame(
865                                    external_state,
866                                    &device,
867                                    frame,
868                                    whole_frame,
869                                )
870                            }
871                            should_deliver
872                        },
873                    );
874                    if delivered {
875                        core_ctx.increment_both(socket, |counters: &DeviceSocketCounters| {
876                            &counters.rx_frames
877                        });
878                    }
879                }
880            })
881        })
882    }
883}
884
/// Usage statistics about Device Sockets.
///
/// Tracked stack-wide and per-socket. The same type serves both roles:
/// frame delivery increments the stack-wide instance and the receiving
/// socket's own instance together. The counters are exported for inspection
/// under `Rx` and `Tx` children.
#[derive(Debug, Default)]
pub struct DeviceSocketCounters {
    /// Count of incoming frames that were delivered to the socket.
    ///
    /// Note that a single frame may be delivered to multiple device sockets.
    /// Thus this counter, when tracking the stack-wide aggregate, may exceed
    /// the total number of frames received by the stack.
    rx_frames: Counter,
    /// Count of outgoing frames that were sent by the socket.
    tx_frames: Counter,
    /// Count of failed tx frames due to [`SendFrameErrorReason::QueueFull`].
    tx_err_queue_full: Counter,
    /// Count of failed tx frames due to [`SendFrameErrorReason::Alloc`].
    tx_err_alloc: Counter,
    /// Count of failed tx frames due to [`SendFrameErrorReason::SizeConstraintsViolation`].
    tx_err_size_constraint: Counter,
}
905
906impl Inspectable for DeviceSocketCounters {
907    fn record<I: Inspector>(&self, inspector: &mut I) {
908        let Self { rx_frames, tx_frames, tx_err_queue_full, tx_err_alloc, tx_err_size_constraint } =
909            self;
910        inspector.record_child("Rx", |inspector| {
911            inspector.record_counter("DeliveredFrames", rx_frames);
912        });
913        inspector.record_child("Tx", |inspector| {
914            inspector.record_counter("SentFrames", tx_frames);
915            inspector.record_counter("QueueFullError", tx_err_queue_full);
916            inspector.record_counter("AllocError", tx_err_alloc);
917            inspector.record_counter("SizeConstraintError", tx_err_size_constraint);
918        });
919    }
920}
921
922impl<D: Send + Sync + Debug, BT: DeviceSocketTypes> OrderedLockAccess<AnyDeviceSockets<D, BT>>
923    for Sockets<D, BT>
924{
925    type Lock = RwLock<AnyDeviceSockets<D, BT>>;
926    fn ordered_lock_access(&self) -> OrderedLockRef<'_, Self::Lock> {
927        OrderedLockRef::new(&self.any_device_sockets)
928    }
929}
930
931impl<D: Send + Sync + Debug, BT: DeviceSocketTypes> OrderedLockAccess<AllSockets<D, BT>>
932    for Sockets<D, BT>
933{
934    type Lock = RwLock<AllSockets<D, BT>>;
935    fn ordered_lock_access(&self) -> OrderedLockRef<'_, Self::Lock> {
936        OrderedLockRef::new(&self.all_sockets)
937    }
938}
939
#[cfg(any(test, feature = "testutils"))]
mod testutil {
    use alloc::vec::Vec;
    use core::num::NonZeroU64;
    use netstack3_base::testutil::{FakeBindingsCtx, MonotonicIdentifier};
    use netstack3_base::StrongDeviceIdentifier;

    use super::*;
    use crate::internal::base::{
        DeviceClassMatcher, DeviceIdAndNameMatcher, DeviceLayerStateTypes,
    };

    /// A record of one frame handed to a fake socket. Every borrowed buffer
    /// is copied, so the record can outlive the delivery call.
    #[derive(Clone, Debug, PartialEq)]
    pub struct ReceivedFrame<D> {
        /// Device the frame arrived on (`receive_frame` below stores the
        /// downgraded ID here).
        pub device: D,
        /// The frame, with an owned copy of its body.
        pub frame: Frame<Vec<u8>>,
        /// Owned copy of the complete raw frame bytes.
        pub raw: Vec<u8>,
    }

    /// Fake per-socket state: a mutex-guarded log of the frames delivered to
    /// the socket, in delivery order.
    #[derive(Debug, Derivative)]
    #[derivative(Default(bound = ""))]
    pub struct ExternalSocketState<D>(pub Mutex<Vec<ReceivedFrame<D>>>);

    impl<TimerId, Event: Debug, State> DeviceSocketTypes
        for FakeBindingsCtx<TimerId, Event, State, ()>
    {
        type SocketState<D: Send + Sync + Debug> = ExternalSocketState<D>;
    }

    impl Frame<&[u8]> {
        /// Deep-copies this borrowed frame into one that owns its body,
        /// preserving its direction (sent/received) and link type.
        pub(crate) fn cloned(self) -> Frame<Vec<u8>> {
            match self {
                Self::Sent(sent) => Frame::Sent(match sent {
                    SentFrame::Ethernet(eth) => SentFrame::Ethernet(eth.cloned()),
                    SentFrame::Ip(ip) => SentFrame::Ip(ip.cloned()),
                }),
                Self::Received(received) => Frame::Received(match received {
                    super::ReceivedFrame::Ethernet { destination, frame } => {
                        super::ReceivedFrame::Ethernet { destination, frame: frame.cloned() }
                    }
                    super::ReceivedFrame::Ip(ip) => super::ReceivedFrame::Ip(ip.cloned()),
                }),
            }
        }
    }

    impl EthernetFrame<&[u8]> {
        /// Copies the borrowed body into a `Vec`; header fields are copied
        /// as-is.
        fn cloned(self) -> EthernetFrame<Vec<u8>> {
            let EthernetFrame { src_mac, dst_mac, ethertype, body } = self;
            EthernetFrame { body: body.to_vec(), src_mac, dst_mac, ethertype }
        }
    }

    impl IpFrame<&[u8]> {
        /// Copies the borrowed body into a `Vec`, keeping the IP version.
        fn cloned(self) -> IpFrame<Vec<u8>> {
            let IpFrame { ip_version, body } = self;
            IpFrame { body: body.to_vec(), ip_version }
        }
    }

    impl<TimerId, Event: Debug, State, D: StrongDeviceIdentifier> DeviceSocketBindingsContext<D>
        for FakeBindingsCtx<TimerId, Event, State, ()>
    {
        /// Appends an owned copy of the delivered frame to the socket's log,
        /// tagged with the downgraded ID of the receiving device.
        fn receive_frame(
            &self,
            state: &ExternalSocketState<D::Weak>,
            device: &D,
            frame: Frame<&[u8]>,
            raw_frame: &[u8],
        ) {
            let record = ReceivedFrame {
                device: device.downgrade(),
                frame: frame.cloned(),
                raw: raw_frame.to_vec(),
            };
            let ExternalSocketState(log) = state;
            log.lock().push(record);
        }
    }

    impl<
            TimerId: Debug + PartialEq + Clone + Send + Sync + 'static,
            Event: Debug + 'static,
            State: 'static,
        > DeviceLayerStateTypes for FakeBindingsCtx<TimerId, Event, State, ()>
    {
        type EthernetDeviceState = ();
        type LoopbackDeviceState = ();
        type PureIpDeviceState = ();
        type BlackholeDeviceState = ();
        type DeviceIdentifier = MonotonicIdentifier;
    }

    /// Not supported by the fake; panics if called.
    impl DeviceClassMatcher<()> for () {
        fn device_class_matches(&self, _: &()) -> bool {
            unimplemented!()
        }
    }

    /// Not supported by the fake; both matchers panic if called.
    impl DeviceIdAndNameMatcher for MonotonicIdentifier {
        fn id_matches(&self, _: &NonZeroU64) -> bool {
            unimplemented!()
        }

        fn name_matches(&self, _: &str) -> bool {
            unimplemented!()
        }
    }
}
1051
1052#[cfg(test)]
1053mod tests {
1054    use alloc::collections::HashMap;
1055    use alloc::vec;
1056    use alloc::vec::Vec;
1057    use core::marker::PhantomData;
1058
1059    use crate::internal::socket::testutil::{ExternalSocketState, ReceivedFrame};
1060    use netstack3_base::testutil::{
1061        FakeReferencyDeviceId, FakeStrongDeviceId, FakeWeakDeviceId, MultipleDevicesId,
1062    };
1063    use netstack3_base::{CounterContext, CtxPair, SendFrameError, SendableFrameMeta};
1064    use packet::ParsablePacket;
1065    use test_case::test_case;
1066
1067    use super::*;
1068
    /// Fake core context whose state is a [`FakeSockets`] over devices `D`.
    type FakeCoreCtx<D> = netstack3_base::testutil::FakeCoreCtx<FakeSockets<D>, (), D>;
    /// Fake bindings context with every type parameter set to `()`.
    type FakeBindingsCtx = netstack3_base::testutil::FakeBindingsCtx<(), (), (), ()>;
    /// Core and bindings contexts paired together, from which the tests build
    /// a [`DeviceSocketApi`].
    type FakeCtx<D> = CtxPair<FakeCoreCtx<D>, FakeBindingsCtx>;
1072
1073    /// A trait providing a shortcut to instantiate a [`DeviceSocketApi`] from a
1074    /// context.
1075    trait DeviceSocketApiExt: ContextPair + Sized {
1076        fn device_socket_api(&mut self) -> DeviceSocketApi<&mut Self> {
1077            DeviceSocketApi::new(self)
1078        }
1079    }
1080
1081    impl<O> DeviceSocketApiExt for O where O: ContextPair + Sized {}
1082
    /// In-memory socket tables that serve as the fake core context's state.
    #[derive(Derivative)]
    #[derivative(Default(bound = ""))]
    struct FakeSockets<D: FakeStrongDeviceId> {
        /// Table of sockets targeting any device.
        any_device_sockets: AnyDeviceSockets<D::Weak, FakeBindingsCtx>,
        /// Per-device socket tables, keyed by the strong device ID.
        device_sockets: HashMap<D, DeviceSockets<D::Weak, FakeBindingsCtx>>,
        /// Table of all sockets; tests scan it to find every socket.
        all_sockets: AllSockets<D::Weak, FakeBindingsCtx>,
        /// The stack-wide counters for device sockets.
        counters: DeviceSocketCounters,
        /// Serialized frames recorded as sent. NOTE(review): presumably filled
        /// by a send-frame impl outside this view; confirm.
        sent_frames: Vec<Vec<u8>>,
    }
1093
    /// Tuple of mutable views into the fake socket tables, used as the core
    /// context handed to callbacks.
    ///
    /// Each slot is generic so that a callback can be given a narrowed view,
    /// e.g. `()` in place of a table that is already borrowed, or a set of
    /// device IDs in place of the per-device map.
    pub struct FakeSocketsMutRefs<'m, AnyDevice, AllSockets, Devices, Device>(
        /// The any-device socket table, or a placeholder.
        &'m mut AnyDevice,
        /// The table of all sockets, or a placeholder.
        &'m mut AllSockets,
        /// The per-device socket state, or a stand-in for it.
        &'m mut Devices,
        /// Records the strong device ID type without storing a value.
        PhantomData<Device>,
        /// The stack-wide device socket counters.
        &'m DeviceSocketCounters,
    );
1102
    /// Helper trait to allow treating a `&mut self` as a
    /// [`FakeSocketsMutRefs`].
    ///
    /// Implemented both by the full fake core context and by the narrowed
    /// views themselves, so that shared trait impls work on either.
    pub trait AsFakeSocketsMutRefs {
        /// Type of the first slot (the any-device table or a placeholder).
        type AnyDevice: 'static;
        /// Type of the second slot (the all-sockets table or a placeholder).
        type AllSockets: 'static;
        /// Type of the third slot (per-device state or a stand-in).
        type Devices: 'static;
        /// The strong device ID type.
        type Device: 'static;
        /// Reborrows the underlying tables for the duration of `&mut self`.
        fn as_sockets_ref(
            &mut self,
        ) -> FakeSocketsMutRefs<'_, Self::AnyDevice, Self::AllSockets, Self::Devices, Self::Device>;
    }
1114
1115    impl<D: FakeStrongDeviceId> AsFakeSocketsMutRefs for FakeCoreCtx<D> {
1116        type AnyDevice = AnyDeviceSockets<D::Weak, FakeBindingsCtx>;
1117        type AllSockets = AllSockets<D::Weak, FakeBindingsCtx>;
1118        type Devices = HashMap<D, DeviceSockets<D::Weak, FakeBindingsCtx>>;
1119        type Device = D;
1120
1121        fn as_sockets_ref(
1122            &mut self,
1123        ) -> FakeSocketsMutRefs<
1124            '_,
1125            AnyDeviceSockets<D::Weak, FakeBindingsCtx>,
1126            AllSockets<D::Weak, FakeBindingsCtx>,
1127            HashMap<D, DeviceSockets<D::Weak, FakeBindingsCtx>>,
1128            D,
1129        > {
1130            let FakeSockets {
1131                any_device_sockets,
1132                device_sockets,
1133                all_sockets,
1134                counters,
1135                sent_frames: _,
1136            } = &mut self.state;
1137            FakeSocketsMutRefs(
1138                any_device_sockets,
1139                all_sockets,
1140                device_sockets,
1141                PhantomData,
1142                counters,
1143            )
1144        }
1145    }
1146
1147    impl<'m, AnyDevice: 'static, AllSockets: 'static, Devices: 'static, Device: 'static>
1148        AsFakeSocketsMutRefs for FakeSocketsMutRefs<'m, AnyDevice, AllSockets, Devices, Device>
1149    {
1150        type AnyDevice = AnyDevice;
1151        type AllSockets = AllSockets;
1152        type Devices = Devices;
1153        type Device = Device;
1154
1155        fn as_sockets_ref(
1156            &mut self,
1157        ) -> FakeSocketsMutRefs<'_, AnyDevice, AllSockets, Devices, Device> {
1158            let Self(any_device, all_sockets, devices, PhantomData, counters) = self;
1159            FakeSocketsMutRefs(any_device, all_sockets, devices, PhantomData, counters)
1160        }
1161    }
1162
1163    impl<D: Clone> TargetDevice<&D> {
1164        fn with_weak_id(&self) -> TargetDevice<FakeWeakDeviceId<D>> {
1165            match self {
1166                TargetDevice::AnyDevice => TargetDevice::AnyDevice,
1167                TargetDevice::SpecificDevice(d) => {
1168                    TargetDevice::SpecificDevice(FakeWeakDeviceId((*d).clone()))
1169                }
1170            }
1171        }
1172    }
1173
1174    impl<D: Eq + Hash + FakeStrongDeviceId> FakeSockets<D> {
1175        fn new(devices: impl IntoIterator<Item = D>) -> Self {
1176            let device_sockets =
1177                devices.into_iter().map(|d| (d, DeviceSockets::default())).collect();
1178            Self {
1179                any_device_sockets: AnyDeviceSockets::default(),
1180                device_sockets,
1181                all_sockets: Default::default(),
1182                counters: Default::default(),
1183                sent_frames: Default::default(),
1184            }
1185        }
1186    }
1187
1188    impl<
1189            'm,
1190            DeviceId: FakeStrongDeviceId,
1191            As: AsFakeSocketsMutRefs
1192                + DeviceIdContext<AnyDevice, DeviceId = DeviceId, WeakDeviceId = DeviceId::Weak>,
1193        > SocketStateAccessor<FakeBindingsCtx> for As
1194    {
1195        fn with_socket_state<
1196            F: FnOnce(&ExternalSocketState<Self::WeakDeviceId>, &Target<Self::WeakDeviceId>) -> R,
1197            R,
1198        >(
1199            &mut self,
1200            socket: &DeviceSocketId<Self::WeakDeviceId, FakeBindingsCtx>,
1201            cb: F,
1202        ) -> R {
1203            let DeviceSocketId(rc) = socket;
1204            // NB: Circumvent lock ordering for tests.
1205            let target = rc.target.lock();
1206            cb(&rc.external_state, &target)
1207        }
1208
1209        fn with_socket_state_mut<
1210            F: FnOnce(&ExternalSocketState<Self::WeakDeviceId>, &mut Target<Self::WeakDeviceId>) -> R,
1211            R,
1212        >(
1213            &mut self,
1214            socket: &DeviceSocketId<Self::WeakDeviceId, FakeBindingsCtx>,
1215            cb: F,
1216        ) -> R {
1217            let DeviceSocketId(rc) = socket;
1218            // NB: Circumvent lock ordering for tests.
1219            let mut target = rc.target.lock();
1220            cb(&rc.external_state, &mut target)
1221        }
1222    }
1223
1224    impl<
1225            'm,
1226            DeviceId: FakeStrongDeviceId,
1227            As: AsFakeSocketsMutRefs<
1228                    Devices = HashMap<DeviceId, DeviceSockets<DeviceId::Weak, FakeBindingsCtx>>,
1229                > + DeviceIdContext<AnyDevice, DeviceId = DeviceId, WeakDeviceId = DeviceId::Weak>,
1230        > DeviceSocketAccessor<FakeBindingsCtx> for As
1231    {
1232        type DeviceSocketCoreCtx<'a> =
1233            FakeSocketsMutRefs<'a, As::AnyDevice, As::AllSockets, HashSet<DeviceId>, DeviceId>;
1234        fn with_device_sockets<
1235            F: FnOnce(
1236                &DeviceSockets<Self::WeakDeviceId, FakeBindingsCtx>,
1237                &mut Self::DeviceSocketCoreCtx<'_>,
1238            ) -> R,
1239            R,
1240        >(
1241            &mut self,
1242            device: &Self::DeviceId,
1243            cb: F,
1244        ) -> R {
1245            let FakeSocketsMutRefs(any_device, all_sockets, device_sockets, PhantomData, counters) =
1246                self.as_sockets_ref();
1247            let mut devices = device_sockets.keys().cloned().collect();
1248            let device = device_sockets.get(device).unwrap();
1249            cb(
1250                device,
1251                &mut FakeSocketsMutRefs(
1252                    any_device,
1253                    all_sockets,
1254                    &mut devices,
1255                    PhantomData,
1256                    counters,
1257                ),
1258            )
1259        }
1260        fn with_device_sockets_mut<
1261            F: FnOnce(
1262                &mut DeviceSockets<Self::WeakDeviceId, FakeBindingsCtx>,
1263                &mut Self::DeviceSocketCoreCtx<'_>,
1264            ) -> R,
1265            R,
1266        >(
1267            &mut self,
1268            device: &Self::DeviceId,
1269            cb: F,
1270        ) -> R {
1271            let FakeSocketsMutRefs(any_device, all_sockets, device_sockets, PhantomData, counters) =
1272                self.as_sockets_ref();
1273            let mut devices = device_sockets.keys().cloned().collect();
1274            let device = device_sockets.get_mut(device).unwrap();
1275            cb(
1276                device,
1277                &mut FakeSocketsMutRefs(
1278                    any_device,
1279                    all_sockets,
1280                    &mut devices,
1281                    PhantomData,
1282                    counters,
1283                ),
1284            )
1285        }
1286    }
1287
1288    impl<
1289            'm,
1290            DeviceId: FakeStrongDeviceId,
1291            As: AsFakeSocketsMutRefs<
1292                    AnyDevice = AnyDeviceSockets<DeviceId::Weak, FakeBindingsCtx>,
1293                    AllSockets = AllSockets<DeviceId::Weak, FakeBindingsCtx>,
1294                    Devices = HashMap<DeviceId, DeviceSockets<DeviceId::Weak, FakeBindingsCtx>>,
1295                > + DeviceIdContext<AnyDevice, DeviceId = DeviceId, WeakDeviceId = DeviceId::Weak>,
1296        > DeviceSocketContext<FakeBindingsCtx> for As
1297    {
1298        type SocketTablesCoreCtx<'a> = FakeSocketsMutRefs<
1299            'a,
1300            (),
1301            (),
1302            HashMap<DeviceId, DeviceSockets<DeviceId::Weak, FakeBindingsCtx>>,
1303            DeviceId,
1304        >;
1305
1306        fn with_any_device_sockets<
1307            F: FnOnce(
1308                &AnyDeviceSockets<Self::WeakDeviceId, FakeBindingsCtx>,
1309                &mut Self::SocketTablesCoreCtx<'_>,
1310            ) -> R,
1311            R,
1312        >(
1313            &mut self,
1314            cb: F,
1315        ) -> R {
1316            let FakeSocketsMutRefs(
1317                any_device_sockets,
1318                _all_sockets,
1319                device_sockets,
1320                PhantomData,
1321                counters,
1322            ) = self.as_sockets_ref();
1323            cb(
1324                any_device_sockets,
1325                &mut FakeSocketsMutRefs(&mut (), &mut (), device_sockets, PhantomData, counters),
1326            )
1327        }
1328        fn with_any_device_sockets_mut<
1329            F: FnOnce(
1330                &mut AnyDeviceSockets<Self::WeakDeviceId, FakeBindingsCtx>,
1331                &mut Self::SocketTablesCoreCtx<'_>,
1332            ) -> R,
1333            R,
1334        >(
1335            &mut self,
1336            cb: F,
1337        ) -> R {
1338            let FakeSocketsMutRefs(
1339                any_device_sockets,
1340                _all_sockets,
1341                device_sockets,
1342                PhantomData,
1343                counters,
1344            ) = self.as_sockets_ref();
1345            cb(
1346                any_device_sockets,
1347                &mut FakeSocketsMutRefs(&mut (), &mut (), device_sockets, PhantomData, counters),
1348            )
1349        }
1350
1351        fn with_all_device_sockets<
1352            F: FnOnce(
1353                &AllSockets<Self::WeakDeviceId, FakeBindingsCtx>,
1354                &mut Self::SocketTablesCoreCtx<'_>,
1355            ) -> R,
1356            R,
1357        >(
1358            &mut self,
1359            cb: F,
1360        ) -> R {
1361            let FakeSocketsMutRefs(
1362                _any_device_sockets,
1363                all_sockets,
1364                device_sockets,
1365                PhantomData,
1366                counters,
1367            ) = self.as_sockets_ref();
1368            cb(
1369                all_sockets,
1370                &mut FakeSocketsMutRefs(&mut (), &mut (), device_sockets, PhantomData, counters),
1371            )
1372        }
1373
1374        fn with_all_device_sockets_mut<
1375            F: FnOnce(&mut AllSockets<Self::WeakDeviceId, FakeBindingsCtx>) -> R,
1376            R,
1377        >(
1378            &mut self,
1379            cb: F,
1380        ) -> R {
1381            let FakeSocketsMutRefs(_, all_sockets, _, _, _) = self.as_sockets_ref();
1382            cb(all_sockets)
1383        }
1384    }
1385
1386    impl<'m, X, Y, Z, D: FakeStrongDeviceId> DeviceIdContext<AnyDevice>
1387        for FakeSocketsMutRefs<'m, X, Y, Z, D>
1388    {
1389        type DeviceId = D;
1390        type WeakDeviceId = FakeWeakDeviceId<D>;
1391    }
1392
1393    impl<D: FakeStrongDeviceId> CounterContext<DeviceSocketCounters> for FakeCoreCtx<D> {
1394        fn counters(&self) -> &DeviceSocketCounters {
1395            &self.state.counters
1396        }
1397    }
1398
1399    impl<D: FakeStrongDeviceId>
1400        ResourceCounterContext<DeviceSocketId<D::Weak, FakeBindingsCtx>, DeviceSocketCounters>
1401        for FakeCoreCtx<D>
1402    {
1403        fn per_resource_counters<'a>(
1404            &'a self,
1405            socket: &'a DeviceSocketId<D::Weak, FakeBindingsCtx>,
1406        ) -> &'a DeviceSocketCounters {
1407            socket.counters()
1408        }
1409    }
1410
1411    impl<'m, X, Y, Z, D> CounterContext<DeviceSocketCounters> for FakeSocketsMutRefs<'m, X, Y, Z, D> {
1412        fn counters(&self) -> &DeviceSocketCounters {
1413            let FakeSocketsMutRefs(_, _, _, _, counters) = self;
1414            counters
1415        }
1416    }
1417
1418    impl<'m, X, Y, Z, D: FakeStrongDeviceId>
1419        ResourceCounterContext<DeviceSocketId<D::Weak, FakeBindingsCtx>, DeviceSocketCounters>
1420        for FakeSocketsMutRefs<'m, X, Y, Z, D>
1421    {
1422        fn per_resource_counters<'a>(
1423            &'a self,
1424            socket: &'a DeviceSocketId<D::Weak, FakeBindingsCtx>,
1425        ) -> &'a DeviceSocketCounters {
1426            socket.counters()
1427        }
1428    }
1429
    /// Arbitrary nonzero protocol number for tests that bind a specific
    /// protocol.
    const SOME_PROTOCOL: NonZeroU16 = NonZeroU16::new(2000).unwrap();
1431
1432    #[test]
1433    fn create_remove() {
1434        let mut ctx = FakeCtx::with_core_ctx(FakeCoreCtx::with_state(FakeSockets::new(
1435            MultipleDevicesId::all(),
1436        )));
1437        let mut api = ctx.device_socket_api();
1438
1439        let bound = api.create(Default::default());
1440        assert_eq!(
1441            api.get_info(&bound),
1442            SocketInfo { device: TargetDevice::AnyDevice, protocol: None }
1443        );
1444
1445        let ExternalSocketState(_received_frames) = api.remove(bound).into_removed();
1446    }
1447
1448    #[test_case(TargetDevice::AnyDevice)]
1449    #[test_case(TargetDevice::SpecificDevice(&MultipleDevicesId::A))]
1450    fn test_set_device(device: TargetDevice<&MultipleDevicesId>) {
1451        let mut ctx = FakeCtx::with_core_ctx(FakeCoreCtx::with_state(FakeSockets::new(
1452            MultipleDevicesId::all(),
1453        )));
1454        let mut api = ctx.device_socket_api();
1455
1456        let bound = api.create(Default::default());
1457        api.set_device(&bound, device.clone());
1458        assert_eq!(
1459            api.get_info(&bound),
1460            SocketInfo { device: device.with_weak_id(), protocol: None }
1461        );
1462
1463        let device_sockets = &api.core_ctx().state.device_sockets;
1464        if let TargetDevice::SpecificDevice(d) = device {
1465            let DeviceSockets(socket_ids) = device_sockets.get(&d).expect("device state exists");
1466            assert_eq!(socket_ids, &HashSet::from([bound]));
1467        }
1468    }
1469
1470    #[test]
1471    fn update_device() {
1472        let mut ctx = FakeCtx::with_core_ctx(FakeCoreCtx::with_state(FakeSockets::new(
1473            MultipleDevicesId::all(),
1474        )));
1475        let mut api = ctx.device_socket_api();
1476        let bound = api.create(Default::default());
1477
1478        api.set_device(&bound, TargetDevice::SpecificDevice(&MultipleDevicesId::A));
1479
1480        // Now update the device and make sure the socket only appears in the
1481        // one device's list.
1482        api.set_device(&bound, TargetDevice::SpecificDevice(&MultipleDevicesId::B));
1483        assert_eq!(
1484            api.get_info(&bound),
1485            SocketInfo {
1486                device: TargetDevice::SpecificDevice(FakeWeakDeviceId(MultipleDevicesId::B)),
1487                protocol: None
1488            }
1489        );
1490
1491        let device_sockets = &api.core_ctx().state.device_sockets;
1492        let device_socket_lists = device_sockets
1493            .iter()
1494            .map(|(d, DeviceSockets(indexes))| (d, indexes.iter().collect()))
1495            .collect::<HashMap<_, _>>();
1496
1497        assert_eq!(
1498            device_socket_lists,
1499            HashMap::from([
1500                (&MultipleDevicesId::A, vec![]),
1501                (&MultipleDevicesId::B, vec![&bound]),
1502                (&MultipleDevicesId::C, vec![])
1503            ])
1504        );
1505    }
1506
1507    #[test_case(Protocol::All, TargetDevice::AnyDevice)]
1508    #[test_case(Protocol::Specific(SOME_PROTOCOL), TargetDevice::AnyDevice)]
1509    #[test_case(Protocol::All, TargetDevice::SpecificDevice(&MultipleDevicesId::A))]
1510    #[test_case(
1511        Protocol::Specific(SOME_PROTOCOL),
1512        TargetDevice::SpecificDevice(&MultipleDevicesId::A)
1513    )]
1514    fn create_set_device_and_protocol_remove_multiple(
1515        protocol: Protocol,
1516        device: TargetDevice<&MultipleDevicesId>,
1517    ) {
1518        let mut ctx = FakeCtx::with_core_ctx(FakeCoreCtx::with_state(FakeSockets::new(
1519            MultipleDevicesId::all(),
1520        )));
1521        let mut api = ctx.device_socket_api();
1522
1523        let mut sockets = [(); 3].map(|()| api.create(Default::default()));
1524        for socket in &mut sockets {
1525            api.set_device_and_protocol(socket, device.clone(), protocol);
1526            assert_eq!(
1527                api.get_info(socket),
1528                SocketInfo { device: device.with_weak_id(), protocol: Some(protocol) }
1529            );
1530        }
1531
1532        for socket in sockets {
1533            let ExternalSocketState(_received_frames) = api.remove(socket).into_removed();
1534        }
1535    }
1536
1537    #[test]
1538    fn change_device_after_removal() {
1539        let device_to_remove = FakeReferencyDeviceId::default();
1540        let device_to_maintain = FakeReferencyDeviceId::default();
1541        let mut ctx = FakeCtx::with_core_ctx(FakeCoreCtx::with_state(FakeSockets::new([
1542            device_to_remove.clone(),
1543            device_to_maintain.clone(),
1544        ])));
1545        let mut api = ctx.device_socket_api();
1546
1547        let bound = api.create(Default::default());
1548        // Set the device for the socket before removing the device state
1549        // entirely.
1550        api.set_device(&bound, TargetDevice::SpecificDevice(&device_to_remove));
1551
1552        // Now remove the device; this should cause future attempts to upgrade
1553        // the device ID to fail.
1554        device_to_remove.mark_removed();
1555
1556        // Changing the device should gracefully handle the fact that the
1557        // earlier-bound device is now gone.
1558        api.set_device(&bound, TargetDevice::SpecificDevice(&device_to_maintain));
1559        assert_eq!(
1560            api.get_info(&bound),
1561            SocketInfo {
1562                device: TargetDevice::SpecificDevice(FakeWeakDeviceId(device_to_maintain.clone())),
1563                protocol: None,
1564            }
1565        );
1566
1567        let device_sockets = &api.core_ctx().state.device_sockets;
1568        let DeviceSockets(weak_sockets) =
1569            device_sockets.get(&device_to_maintain).expect("device state exists");
1570        assert_eq!(weak_sockets, &HashSet::from([bound]));
1571    }
1572
    /// Fixed Ethernet frame contents shared by the frame-delivery tests.
    struct TestData;
    impl TestData {
        /// Source MAC; bytes 6..12 of [`Self::BUFFER`].
        const SRC_MAC: Mac = Mac::new([0, 1, 2, 3, 4, 5]);
        /// Destination MAC; bytes 0..6 of [`Self::BUFFER`].
        const DST_MAC: Mac = Mac::new([6, 7, 8, 9, 10, 11]);
        /// Arbitrary protocol number.
        const PROTO: NonZeroU16 = NonZeroU16::new(0x08AB).unwrap();
        /// Payload following the 14-byte Ethernet header in [`Self::BUFFER`].
        const BODY: &'static [u8] = b"some pig";
        /// The serialized frame: destination MAC, source MAC, EtherType
        /// [`Self::PROTO`] in big-endian order, then [`Self::BODY`].
        const BUFFER: &'static [u8] = &[
            6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 4, 5, 0x08, 0xAB, b's', b'o', b'm', b'e', b' ', b'p',
            b'i', b'g',
        ];

        /// Creates an EthernetFrame with the values specified above.
        ///
        /// Parses [`Self::BUFFER`] without a minimum-length check; panics if
        /// parsing fails.
        fn frame() -> packet_formats::ethernet::EthernetFrame<&'static [u8]> {
            let mut buffer_view = Self::BUFFER;
            packet_formats::ethernet::EthernetFrame::parse(
                &mut buffer_view,
                EthernetFrameLengthCheck::NoCheck,
            )
            .unwrap()
        }
    }

    /// A protocol number different from [`TestData::PROTO`].
    const WRONG_PROTO: NonZeroU16 = NonZeroU16::new(0x08ff).unwrap();
1597
1598    fn make_bound<D: FakeStrongDeviceId>(
1599        ctx: &mut FakeCtx<D>,
1600        device: TargetDevice<D>,
1601        protocol: Option<Protocol>,
1602        state: ExternalSocketState<D::Weak>,
1603    ) -> DeviceSocketId<D::Weak, FakeBindingsCtx> {
1604        let mut api = ctx.device_socket_api();
1605        let id = api.create(state);
1606        let device = match &device {
1607            TargetDevice::AnyDevice => TargetDevice::AnyDevice,
1608            TargetDevice::SpecificDevice(d) => TargetDevice::SpecificDevice(d),
1609        };
1610        match protocol {
1611            Some(protocol) => api.set_device_and_protocol(&id, device, protocol),
1612            None => api.set_device(&id, device),
1613        };
1614        id
1615    }
1616
1617    /// Deliver one frame to the provided contexts and return the IDs of the
1618    /// sockets it was delivered to.
1619    fn deliver_one_frame(
1620        delivered_frame: Frame<&[u8]>,
1621        FakeCtx { core_ctx, bindings_ctx }: &mut FakeCtx<MultipleDevicesId>,
1622    ) -> HashSet<DeviceSocketId<FakeWeakDeviceId<MultipleDevicesId>, FakeBindingsCtx>> {
1623        DeviceSocketHandler::handle_frame(
1624            core_ctx,
1625            bindings_ctx,
1626            &MultipleDevicesId::A,
1627            delivered_frame.clone(),
1628            TestData::BUFFER,
1629        );
1630
1631        let FakeSockets {
1632            all_sockets: AllSockets(all_sockets),
1633            any_device_sockets: _,
1634            device_sockets: _,
1635            counters: _,
1636            sent_frames: _,
1637        } = &core_ctx.state;
1638
1639        all_sockets
1640            .iter()
1641            .filter_map(|(id, _primary)| {
1642                let DeviceSocketId(rc) = &id;
1643                let ExternalSocketState(frames) = &rc.external_state;
1644                let frames = frames.lock();
1645                (!frames.is_empty()).then(|| {
1646                    assert_eq!(
1647                        &*frames,
1648                        &[ReceivedFrame {
1649                            device: FakeWeakDeviceId(MultipleDevicesId::A),
1650                            frame: delivered_frame.cloned(),
1651                            raw: TestData::BUFFER.into(),
1652                        }]
1653                    );
1654                    id.clone()
1655                })
1656            })
1657            .collect()
1658    }
1659
1660    #[test]
1661    fn receive_frame_deliver_to_multiple() {
1662        let mut ctx = FakeCtx::with_core_ctx(FakeCoreCtx::with_state(FakeSockets::new(
1663            MultipleDevicesId::all(),
1664        )));
1665
1666        use Protocol::*;
1667        use TargetDevice::*;
1668        let never_bound = {
1669            let state = ExternalSocketState::<FakeWeakDeviceId<MultipleDevicesId>>::default();
1670            ctx.device_socket_api().create(state)
1671        };
1672
1673        let mut make_bound = |device, protocol| {
1674            let state = ExternalSocketState::<FakeWeakDeviceId<MultipleDevicesId>>::default();
1675            make_bound(&mut ctx, device, protocol, state)
1676        };
1677        let bound_a_no_protocol = make_bound(SpecificDevice(MultipleDevicesId::A), None);
1678        let bound_a_all_protocols = make_bound(SpecificDevice(MultipleDevicesId::A), Some(All));
1679        let bound_a_right_protocol =
1680            make_bound(SpecificDevice(MultipleDevicesId::A), Some(Specific(TestData::PROTO)));
1681        let bound_a_wrong_protocol =
1682            make_bound(SpecificDevice(MultipleDevicesId::A), Some(Specific(WRONG_PROTO)));
1683        let bound_b_no_protocol = make_bound(SpecificDevice(MultipleDevicesId::B), None);
1684        let bound_b_all_protocols = make_bound(SpecificDevice(MultipleDevicesId::B), Some(All));
1685        let bound_b_right_protocol =
1686            make_bound(SpecificDevice(MultipleDevicesId::B), Some(Specific(TestData::PROTO)));
1687        let bound_b_wrong_protocol =
1688            make_bound(SpecificDevice(MultipleDevicesId::B), Some(Specific(WRONG_PROTO)));
1689        let bound_any_no_protocol = make_bound(AnyDevice, None);
1690        let bound_any_all_protocols = make_bound(AnyDevice, Some(All));
1691        let bound_any_right_protocol = make_bound(AnyDevice, Some(Specific(TestData::PROTO)));
1692        let bound_any_wrong_protocol = make_bound(AnyDevice, Some(Specific(WRONG_PROTO)));
1693
1694        let mut sockets_with_received_frames = deliver_one_frame(
1695            super::ReceivedFrame::from_ethernet(
1696                TestData::frame(),
1697                FrameDestination::Individual { local: true },
1698            )
1699            .into(),
1700            &mut ctx,
1701        );
1702
1703        let sockets_not_expecting_frames = [
1704            never_bound,
1705            bound_a_no_protocol,
1706            bound_a_wrong_protocol,
1707            bound_b_no_protocol,
1708            bound_b_all_protocols,
1709            bound_b_right_protocol,
1710            bound_b_wrong_protocol,
1711            bound_any_no_protocol,
1712            bound_any_wrong_protocol,
1713        ];
1714        let sockets_expecting_frames = [
1715            bound_a_all_protocols,
1716            bound_a_right_protocol,
1717            bound_any_all_protocols,
1718            bound_any_right_protocol,
1719        ];
1720
1721        for (n, socket) in sockets_expecting_frames.iter().enumerate() {
1722            assert!(
1723                sockets_with_received_frames.remove(&socket),
1724                "socket {n} didn't receive the frame"
1725            );
1726        }
1727        assert!(sockets_with_received_frames.is_empty());
1728
1729        // Verify Counters were set appropriately for each socket.
1730        for (n, socket) in sockets_expecting_frames.iter().enumerate() {
1731            assert_eq!(socket.counters().rx_frames.get(), 1, "socket {n} has wrong rx_frames");
1732        }
1733        for (n, socket) in sockets_not_expecting_frames.iter().enumerate() {
1734            assert_eq!(socket.counters().rx_frames.get(), 0, "socket {n} has wrong rx_frames");
1735        }
1736    }
1737
1738    #[test]
1739    fn sent_frame_deliver_to_multiple() {
1740        let mut ctx = FakeCtx::with_core_ctx(FakeCoreCtx::with_state(FakeSockets::new(
1741            MultipleDevicesId::all(),
1742        )));
1743
1744        use Protocol::*;
1745        use TargetDevice::*;
1746        let never_bound = {
1747            let state = ExternalSocketState::<FakeWeakDeviceId<MultipleDevicesId>>::default();
1748            ctx.device_socket_api().create(state)
1749        };
1750
1751        let mut make_bound = |device, protocol| {
1752            let state = ExternalSocketState::<FakeWeakDeviceId<MultipleDevicesId>>::default();
1753            make_bound(&mut ctx, device, protocol, state)
1754        };
1755        let bound_a_no_protocol = make_bound(SpecificDevice(MultipleDevicesId::A), None);
1756        let bound_a_all_protocols = make_bound(SpecificDevice(MultipleDevicesId::A), Some(All));
1757        let bound_a_same_protocol =
1758            make_bound(SpecificDevice(MultipleDevicesId::A), Some(Specific(TestData::PROTO)));
1759        let bound_a_wrong_protocol =
1760            make_bound(SpecificDevice(MultipleDevicesId::A), Some(Specific(WRONG_PROTO)));
1761        let bound_b_no_protocol = make_bound(SpecificDevice(MultipleDevicesId::B), None);
1762        let bound_b_all_protocols = make_bound(SpecificDevice(MultipleDevicesId::B), Some(All));
1763        let bound_b_same_protocol =
1764            make_bound(SpecificDevice(MultipleDevicesId::B), Some(Specific(TestData::PROTO)));
1765        let bound_b_wrong_protocol =
1766            make_bound(SpecificDevice(MultipleDevicesId::B), Some(Specific(WRONG_PROTO)));
1767        let bound_any_no_protocol = make_bound(AnyDevice, None);
1768        let bound_any_all_protocols = make_bound(AnyDevice, Some(All));
1769        let bound_any_same_protocol = make_bound(AnyDevice, Some(Specific(TestData::PROTO)));
1770        let bound_any_wrong_protocol = make_bound(AnyDevice, Some(Specific(WRONG_PROTO)));
1771
1772        let mut sockets_with_received_frames =
1773            deliver_one_frame(SentFrame::Ethernet(TestData::frame().into()).into(), &mut ctx);
1774
1775        let sockets_not_expecting_frames = [
1776            never_bound,
1777            bound_a_no_protocol,
1778            bound_a_same_protocol,
1779            bound_a_wrong_protocol,
1780            bound_b_no_protocol,
1781            bound_b_all_protocols,
1782            bound_b_same_protocol,
1783            bound_b_wrong_protocol,
1784            bound_any_no_protocol,
1785            bound_any_same_protocol,
1786            bound_any_wrong_protocol,
1787        ];
1788        // Only any-protocol sockets receive sent frames.
1789        let sockets_expecting_frames = [bound_a_all_protocols, bound_any_all_protocols];
1790
1791        for (n, socket) in sockets_expecting_frames.iter().enumerate() {
1792            assert!(
1793                sockets_with_received_frames.remove(&socket),
1794                "socket {n} didn't receive the frame"
1795            );
1796        }
1797        assert!(sockets_with_received_frames.is_empty());
1798
1799        // Verify Counters were set appropriately for each socket.
1800        for (n, socket) in sockets_expecting_frames.iter().enumerate() {
1801            assert_eq!(socket.counters().rx_frames.get(), 1, "socket {n} has wrong rx_frames");
1802        }
1803        for (n, socket) in sockets_not_expecting_frames.iter().enumerate() {
1804            assert_eq!(socket.counters().rx_frames.get(), 0, "socket {n} has wrong rx_frames");
1805        }
1806    }
1807
1808    #[test]
1809    fn deliver_multiple_frames() {
1810        let mut ctx = FakeCtx::with_core_ctx(FakeCoreCtx::with_state(FakeSockets::new(
1811            MultipleDevicesId::all(),
1812        )));
1813        let socket = make_bound(
1814            &mut ctx,
1815            TargetDevice::AnyDevice,
1816            Some(Protocol::All),
1817            ExternalSocketState::default(),
1818        );
1819        let FakeCtx { mut core_ctx, mut bindings_ctx } = ctx;
1820
1821        const RECEIVE_COUNT: usize = 10;
1822        for _ in 0..RECEIVE_COUNT {
1823            DeviceSocketHandler::handle_frame(
1824                &mut core_ctx,
1825                &mut bindings_ctx,
1826                &MultipleDevicesId::A,
1827                super::ReceivedFrame::from_ethernet(
1828                    TestData::frame(),
1829                    FrameDestination::Individual { local: true },
1830                )
1831                .into(),
1832                TestData::BUFFER,
1833            );
1834        }
1835
1836        let FakeSockets {
1837            all_sockets: AllSockets(mut all_sockets),
1838            any_device_sockets: _,
1839            device_sockets: _,
1840            counters: _,
1841            sent_frames: _,
1842        } = core_ctx.into_state();
1843        let primary = all_sockets.remove(&socket).unwrap();
1844        let PrimaryDeviceSocketId(primary) = primary;
1845        assert!(all_sockets.is_empty());
1846        drop(socket);
1847        let SocketState { external_state: ExternalSocketState(received), counters, target: _ } =
1848            PrimaryRc::unwrap(primary);
1849        assert_eq!(
1850            received.into_inner(),
1851            vec![
1852                ReceivedFrame {
1853                    device: FakeWeakDeviceId(MultipleDevicesId::A),
1854                    frame: Frame::Received(super::ReceivedFrame::Ethernet {
1855                        destination: FrameDestination::Individual { local: true },
1856                        frame: EthernetFrame {
1857                            src_mac: TestData::SRC_MAC,
1858                            dst_mac: TestData::DST_MAC,
1859                            ethertype: Some(TestData::PROTO.get().into()),
1860                            body: Vec::from(TestData::BODY),
1861                        }
1862                    }),
1863                    raw: TestData::BUFFER.into()
1864                };
1865                RECEIVE_COUNT
1866            ]
1867        );
1868        assert_eq!(counters.rx_frames.get(), u64::try_from(RECEIVE_COUNT).unwrap());
1869    }
1870
1871    pub struct FakeSendMetadata;
1872    impl DeviceSocketSendTypes for AnyDevice {
1873        type Metadata = FakeSendMetadata;
1874    }
1875    impl<BC, D: FakeStrongDeviceId> SendableFrameMeta<FakeCoreCtx<D>, BC>
1876        for DeviceSocketMetadata<AnyDevice, D>
1877    {
1878        fn send_meta<S>(
1879            self,
1880            core_ctx: &mut FakeCoreCtx<D>,
1881            _bindings_ctx: &mut BC,
1882            frame: S,
1883        ) -> Result<(), SendFrameError<S>>
1884        where
1885            S: packet::Serializer,
1886            S::Buffer: packet::BufferMut,
1887        {
1888            let frame = match frame.serialize_vec_outer() {
1889                Err(e) => {
1890                    let _: (packet::SerializeError<core::convert::Infallible>, _) = e;
1891                    unreachable!()
1892                }
1893                Ok(frame) => frame.unwrap_a().as_ref().to_vec(),
1894            };
1895            core_ctx.state.sent_frames.push(frame);
1896            Ok(())
1897        }
1898    }
1899
1900    #[test]
1901    fn send_multiple_frames() {
1902        let mut ctx = FakeCtx::with_core_ctx(FakeCoreCtx::with_state(FakeSockets::new(
1903            MultipleDevicesId::all(),
1904        )));
1905
1906        const DEVICE: MultipleDevicesId = MultipleDevicesId::A;
1907        let socket = make_bound(
1908            &mut ctx,
1909            TargetDevice::SpecificDevice(DEVICE),
1910            Some(Protocol::All),
1911            ExternalSocketState::default(),
1912        );
1913        let mut api = ctx.device_socket_api();
1914
1915        const SEND_COUNT: usize = 10;
1916        const PAYLOAD: &'static [u8] = &[1, 2, 3, 4, 5];
1917        for _ in 0..SEND_COUNT {
1918            let buf = packet::Buf::new(PAYLOAD.to_vec(), ..);
1919            api.send_frame(
1920                &socket,
1921                DeviceSocketMetadata { device_id: DEVICE, metadata: FakeSendMetadata },
1922                buf,
1923            )
1924            .expect("send failed");
1925        }
1926
1927        assert_eq!(ctx.core_ctx().state.sent_frames, vec![PAYLOAD.to_vec(); SEND_COUNT]);
1928
1929        assert_eq!(socket.counters().tx_frames.get(), u64::try_from(SEND_COUNT).unwrap());
1930    }
1931}