1use alloc::collections::{HashMap, HashSet};
8use core::fmt::Debug;
9use core::hash::Hash;
10use core::num::NonZeroU16;
11
12use derivative::Derivative;
13use lock_order::lock::{OrderedLockAccess, OrderedLockRef};
14use net_types::ethernet::Mac;
15use net_types::ip::IpVersion;
16use netstack3_base::sync::{Mutex, PrimaryRc, RwLock, StrongRc, WeakRc};
17use netstack3_base::{
18 AnyDevice, ContextPair, Counter, Device, DeviceIdContext, FrameDestination, Inspectable,
19 Inspector, InspectorDeviceExt, InspectorExt, ReferenceNotifiers, ReferenceNotifiersExt as _,
20 RemoveResourceResultWithContext, ResourceCounterContext, SendFrameContext,
21 SendFrameErrorReason, StrongDeviceIdentifier, WeakDeviceIdentifier as _,
22};
23use packet::{BufferMut, ParsablePacket as _, Serializer};
24use packet_formats::error::ParseError;
25use packet_formats::ethernet::{EtherType, EthernetFrameLengthCheck};
26
27use crate::internal::base::DeviceLayerTypes;
28use crate::internal::id::WeakDeviceId;
29
/// The protocol(s) a device socket wants to receive frames for.
///
/// A socket with no protocol set receives no frames at all.
#[derive(Copy, Clone, Debug, Eq, Hash, PartialEq)]
pub enum Protocol {
    /// Match every frame, both sent and received.
    All,
    /// Match only *received* frames whose protocol number (EtherType) equals
    /// the contained value; sent frames never match.
    Specific(NonZeroU16),
}
38
/// The device(s) a device socket is bound to.
#[derive(Clone, Debug, Derivative, Eq, Hash, PartialEq)]
#[derivative(Default(bound = ""))]
pub enum TargetDevice<D> {
    /// Not bound to a specific device; this is the default.
    ///
    /// Such sockets live in the any-device socket set and are offered frames
    /// from every device.
    #[derivative(Default)]
    AnyDevice,
    /// Bound to the contained device.
    SpecificDevice(D),
}
49
/// A snapshot of a device socket's configuration, as returned by
/// [`DeviceSocketApi::get_info`].
#[derive(Debug)]
#[cfg_attr(test, derive(PartialEq))]
pub struct SocketInfo<D> {
    /// The protocol the socket is bound to, or `None` if none is set.
    pub protocol: Option<Protocol>,
    /// The device the socket is bound to.
    pub device: TargetDevice<D>,
}
59
/// Types supplied by bindings for device sockets.
pub trait DeviceSocketTypes {
    /// Bindings-owned state stored with each device socket.
    ///
    /// Throughout this module `D` is instantiated with a weak device
    /// identifier type.
    type SocketState<D: Send + Sync + Debug>: Send + Sync + Debug;
}
66
/// The bindings context for device sockets.
pub trait DeviceSocketBindingsContext<DeviceId: StrongDeviceIdentifier>: DeviceSocketTypes {
    /// Delivers a frame to the socket whose bindings state is `socket`.
    ///
    /// `device` is the device the frame was sent or received on. `frame` is
    /// the parsed view of the frame. `raw_frame` is the complete frame as
    /// passed to [`DeviceSocketHandler::handle_frame`] (its `whole_frame`).
    fn receive_frame(
        &self,
        socket: &Self::SocketState<DeviceId::Weak>,
        device: &DeviceId,
        frame: Frame<&[u8]>,
        raw_frame: &[u8],
    );
}
80
/// The primary (owning) reference to a device socket's state.
///
/// It is created by [`DeviceSocketApi::create`], kept in [`AllSockets`], and
/// unwrapped by [`DeviceSocketApi::remove`].
#[derive(Debug)]
pub struct PrimaryDeviceSocketId<D: Send + Sync + Debug, BT: DeviceSocketTypes>(
    PrimaryRc<SocketState<D, BT>>,
);
88
89impl<D: Send + Sync + Debug, BT: DeviceSocketTypes> PrimaryDeviceSocketId<D, BT> {
90 fn new(external_state: BT::SocketState<D>) -> Self {
92 Self(PrimaryRc::new(SocketState {
93 external_state,
94 counters: Default::default(),
95 target: Default::default(),
96 }))
97 }
98
99 fn clone_strong(&self) -> DeviceSocketId<D, BT> {
101 let PrimaryDeviceSocketId(rc) = self;
102 DeviceSocketId(PrimaryRc::clone_strong(rc))
103 }
104}
105
/// A strong reference to a device socket.
///
/// `Clone`, `Hash`, `Eq` and `PartialEq` are derived with no bounds on `D` or
/// `BT`. They delegate to the wrapped `StrongRc`, so two IDs compare by the
/// reference, not by the socket's contents.
#[derive(Derivative)]
#[derivative(Clone(bound = ""), Hash(bound = ""), Eq(bound = ""), PartialEq(bound = ""))]
pub struct DeviceSocketId<D: Send + Sync + Debug, BT: DeviceSocketTypes>(
    StrongRc<SocketState<D, BT>>,
);
115
116impl<D: Send + Sync + Debug, BT: DeviceSocketTypes> Debug for DeviceSocketId<D, BT> {
117 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
118 let Self(rc) = self;
119 f.debug_tuple("DeviceSocketId").field(&StrongRc::debug_id(rc)).finish()
120 }
121}
122
123impl<D: Send + Sync + Debug, BT: DeviceSocketTypes> OrderedLockAccess<Target<D>>
124 for DeviceSocketId<D, BT>
125{
126 type Lock = Mutex<Target<D>>;
127 fn ordered_lock_access(&self) -> OrderedLockRef<'_, Self::Lock> {
128 let Self(rc) = self;
129 OrderedLockRef::new(&rc.target)
130 }
131}
132
/// A weak reference to a device socket, obtained from
/// [`DeviceSocketId::downgrade`].
///
/// `Clone`, `Hash`, `Eq` and `PartialEq` are derived with no bounds on `D` or
/// `BT`, delegating to the wrapped `WeakRc`.
#[derive(Derivative)]
#[derivative(Clone(bound = ""), Hash(bound = ""), Eq(bound = ""), PartialEq(bound = ""))]
pub struct WeakDeviceSocketId<D: Send + Sync + Debug, BT: DeviceSocketTypes>(
    WeakRc<SocketState<D, BT>>,
);
142
143impl<D: Send + Sync + Debug, BT: DeviceSocketTypes> Debug for WeakDeviceSocketId<D, BT> {
144 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
145 let Self(rc) = self;
146 f.debug_tuple("WeakDeviceSocketId").field(&WeakRc::debug_id(rc)).finish()
147 }
148}
149
/// The global device socket tables.
#[derive(Derivative)]
#[derivative(Default(bound = ""))]
pub struct Sockets<D: Send + Sync + Debug, BT: DeviceSocketTypes> {
    /// Sockets whose target device is [`TargetDevice::AnyDevice`].
    ///
    /// Moving a socket between devices happens while this lock is held
    /// mutably. See `update_device_and_protocol`.
    any_device_sockets: RwLock<AnyDeviceSockets<D, BT>>,

    /// Every live socket, mapping its strong ID to its primary ID.
    all_sockets: RwLock<AllSockets<D, BT>>,
}
166
/// The set of sockets whose target device is [`TargetDevice::AnyDevice`].
///
/// These sockets are offered frames from every device.
#[derive(Derivative)]
#[derivative(Default(bound = ""))]
pub struct AnyDeviceSockets<D: Send + Sync + Debug, BT: DeviceSocketTypes>(
    HashSet<DeviceSocketId<D, BT>>,
);
173
/// Every live device socket, keyed by strong ID.
///
/// The map owns each socket's [`PrimaryDeviceSocketId`].
#[derive(Derivative)]
#[derivative(Default(bound = ""))]
pub struct AllSockets<D: Send + Sync + Debug, BT: DeviceSocketTypes>(
    HashMap<DeviceSocketId<D, BT>, PrimaryDeviceSocketId<D, BT>>,
);
180
/// The state of a single device socket.
#[derive(Debug)]
pub struct SocketState<D: Send + Sync + Debug, BT: DeviceSocketTypes> {
    /// State owned by bindings.
    pub external_state: BT::SocketState<D>,
    /// The socket's protocol and device binding.
    ///
    /// Core changes the device only inside a mutable borrow of the
    /// any-device socket set, which keeps the socket-table membership
    /// consistent with this value.
    target: Mutex<Target<D>>,
    /// Per-socket counters.
    counters: DeviceSocketCounters,
}
193
/// What a device socket is bound to.
///
/// The default is no protocol on any device.
#[derive(Debug, Derivative)]
#[derivative(Default(bound = ""))]
pub struct Target<D> {
    /// The protocol to match. `None` matches no frames.
    protocol: Option<Protocol>,
    /// The device(s) to receive frames from.
    device: TargetDevice<D>,
}
201
/// The sockets bound to one specific device.
#[derive(Derivative)]
#[derivative(Default(bound = ""))]
#[cfg_attr(
    test,
    derivative(Debug, PartialEq(bound = "BT::SocketState<D>: Hash + Eq, D: Hash + Eq"))
)]
pub struct DeviceSockets<D: Send + Sync + Debug, BT: DeviceSocketTypes>(
    HashSet<DeviceSocketId<D, BT>>,
);
215
/// Per-device socket state, specialized to the device layer's weak device ID.
pub type HeldDeviceSockets<BT> = DeviceSockets<WeakDeviceId<BT>, BT>;

/// The global socket tables, specialized to the device layer's weak device ID.
pub type HeldSockets<BT> = Sockets<WeakDeviceId<BT>, BT>;
223
/// Core context giving access to the global device socket tables.
pub trait DeviceSocketContext<BT: DeviceSocketTypes>: DeviceIdContext<AnyDevice> {
    /// The core context passed to callbacks while a global table is held.
    /// Through it, per-device socket sets can be reached.
    type SocketTablesCoreCtx<'a>: DeviceSocketAccessor<
        BT,
        DeviceId = Self::DeviceId,
        WeakDeviceId = Self::WeakDeviceId,
    >;

    /// Calls `cb` with shared access to the table of all sockets.
    fn with_all_device_sockets<
        F: FnOnce(&AllSockets<Self::WeakDeviceId, BT>, &mut Self::SocketTablesCoreCtx<'_>) -> R,
        R,
    >(
        &mut self,
        cb: F,
    ) -> R;

    /// Calls `cb` with mutable access to the table of all sockets.
    fn with_all_device_sockets_mut<F: FnOnce(&mut AllSockets<Self::WeakDeviceId, BT>) -> R, R>(
        &mut self,
        cb: F,
    ) -> R;

    /// Calls `cb` with shared access to the set of sockets that target any
    /// device.
    fn with_any_device_sockets<
        F: FnOnce(&AnyDeviceSockets<Self::WeakDeviceId, BT>, &mut Self::SocketTablesCoreCtx<'_>) -> R,
        R,
    >(
        &mut self,
        cb: F,
    ) -> R;

    /// Calls `cb` with mutable access to the set of sockets that target any
    /// device.
    fn with_any_device_sockets_mut<
        F: FnOnce(
            &mut AnyDeviceSockets<Self::WeakDeviceId, BT>,
            &mut Self::SocketTablesCoreCtx<'_>,
        ) -> R,
        R,
    >(
        &mut self,
        cb: F,
    ) -> R;
}
271
/// Core context giving access to the state of an individual socket.
pub trait SocketStateAccessor<BT: DeviceSocketTypes>: DeviceIdContext<AnyDevice> {
    /// Calls `cb` with the socket's bindings state and shared access to its
    /// [`Target`].
    fn with_socket_state<
        F: FnOnce(&BT::SocketState<Self::WeakDeviceId>, &Target<Self::WeakDeviceId>) -> R,
        R,
    >(
        &mut self,
        socket: &DeviceSocketId<Self::WeakDeviceId, BT>,
        cb: F,
    ) -> R;

    /// Calls `cb` with the socket's bindings state and mutable access to its
    /// [`Target`].
    fn with_socket_state_mut<
        F: FnOnce(&BT::SocketState<Self::WeakDeviceId>, &mut Target<Self::WeakDeviceId>) -> R,
        R,
    >(
        &mut self,
        socket: &DeviceSocketId<Self::WeakDeviceId, BT>,
        cb: F,
    ) -> R;
}
294
/// Core context giving access to the sockets bound to a specific device.
pub trait DeviceSocketAccessor<BT: DeviceSocketTypes>: SocketStateAccessor<BT> {
    /// The core context passed to callbacks while a device's socket set is
    /// held. Through it, individual socket state and socket counters can be
    /// reached.
    type DeviceSocketCoreCtx<'a>: SocketStateAccessor<BT, DeviceId = Self::DeviceId, WeakDeviceId = Self::WeakDeviceId>
        + ResourceCounterContext<DeviceSocketId<Self::WeakDeviceId, BT>, DeviceSocketCounters>;

    /// Calls `cb` with shared access to the sockets bound to `device`.
    fn with_device_sockets<
        F: FnOnce(&DeviceSockets<Self::WeakDeviceId, BT>, &mut Self::DeviceSocketCoreCtx<'_>) -> R,
        R,
    >(
        &mut self,
        device: &Self::DeviceId,
        cb: F,
    ) -> R;

    /// Calls `cb` with mutable access to the sockets bound to `device`.
    fn with_device_sockets_mut<
        F: FnOnce(&mut DeviceSockets<Self::WeakDeviceId, BT>, &mut Self::DeviceSocketCoreCtx<'_>) -> R,
        R,
    >(
        &mut self,
        device: &Self::DeviceId,
        cb: F,
    ) -> R;
}
323
/// An optional update to a value: keep the current value, or replace it.
enum MaybeUpdate<T> {
    /// Leave the current value unchanged.
    NoChange,
    /// Replace the current value with the contained one.
    NewValue(T),
}
328
329fn update_device_and_protocol<CC: DeviceSocketContext<BT>, BT: DeviceSocketTypes>(
330 core_ctx: &mut CC,
331 socket: &DeviceSocketId<CC::WeakDeviceId, BT>,
332 new_device: TargetDevice<&CC::DeviceId>,
333 protocol_update: MaybeUpdate<Protocol>,
334) {
335 core_ctx.with_any_device_sockets_mut(|AnyDeviceSockets(any_device_sockets), core_ctx| {
336 let old_device = core_ctx.with_socket_state_mut(
342 socket,
343 |_: &BT::SocketState<CC::WeakDeviceId>, Target { protocol, device }| {
344 match protocol_update {
345 MaybeUpdate::NewValue(p) => *protocol = Some(p),
346 MaybeUpdate::NoChange => (),
347 };
348 let old_device = match &device {
349 TargetDevice::SpecificDevice(device) => device.upgrade(),
350 TargetDevice::AnyDevice => {
351 assert!(any_device_sockets.remove(socket));
352 None
353 }
354 };
355 *device = match &new_device {
356 TargetDevice::AnyDevice => TargetDevice::AnyDevice,
357 TargetDevice::SpecificDevice(d) => TargetDevice::SpecificDevice(d.downgrade()),
358 };
359 old_device
360 },
361 );
362
363 if let Some(device) = old_device {
369 core_ctx.with_device_sockets_mut(
372 &device,
373 |DeviceSockets(device_sockets), _core_ctx| {
374 assert!(device_sockets.remove(socket), "socket not found in device state");
375 },
376 );
377 }
378
379 match &new_device {
381 TargetDevice::SpecificDevice(new_device) => core_ctx.with_device_sockets_mut(
382 new_device,
383 |DeviceSockets(device_sockets), _core_ctx| {
384 assert!(device_sockets.insert(socket.clone()));
385 },
386 ),
387 TargetDevice::AnyDevice => {
388 assert!(any_device_sockets.insert(socket.clone()))
389 }
390 }
391 })
392}
393
/// The device socket API, wrapping a core/bindings context pair.
pub struct DeviceSocketApi<C>(C);
396
397impl<C> DeviceSocketApi<C> {
398 pub fn new(ctx: C) -> Self {
400 Self(ctx)
401 }
402}
403
/// The socket ID type used by [`DeviceSocketApi`] for the context pair `C`.
type ApiSocketId<C> = DeviceSocketId<
    <<C as ContextPair>::CoreContext as DeviceIdContext<AnyDevice>>::WeakDeviceId,
    <C as ContextPair>::BindingsContext,
>;
412
413impl<C> DeviceSocketApi<C>
414where
415 C: ContextPair,
416 C::CoreContext: DeviceSocketContext<C::BindingsContext>
417 + SocketStateAccessor<C::BindingsContext>
418 + ResourceCounterContext<ApiSocketId<C>, DeviceSocketCounters>,
419 C::BindingsContext: DeviceSocketBindingsContext<<C::CoreContext as DeviceIdContext<AnyDevice>>::DeviceId>
420 + ReferenceNotifiers
421 + 'static,
422{
423 fn core_ctx(&mut self) -> &mut C::CoreContext {
424 let Self(pair) = self;
425 pair.core_ctx()
426 }
427
428 fn contexts(&mut self) -> (&mut C::CoreContext, &mut C::BindingsContext) {
429 let Self(pair) = self;
430 pair.contexts()
431 }
432
433 pub fn create(
435 &mut self,
436 external_state: <C::BindingsContext as DeviceSocketTypes>::SocketState<
437 <C::CoreContext as DeviceIdContext<AnyDevice>>::WeakDeviceId,
438 >,
439 ) -> ApiSocketId<C> {
440 let core_ctx = self.core_ctx();
441
442 let strong = core_ctx.with_all_device_sockets_mut(|AllSockets(sockets)| {
443 let primary = PrimaryDeviceSocketId::new(external_state);
444 let strong = primary.clone_strong();
445 assert!(sockets.insert(strong.clone(), primary).is_none());
446 strong
447 });
448 core_ctx.with_any_device_sockets_mut(|AnyDeviceSockets(any_device_sockets), _core_ctx| {
449 assert!(any_device_sockets.insert(strong.clone()));
456 });
457 strong
458 }
459
460 pub fn set_device(
462 &mut self,
463 socket: &ApiSocketId<C>,
464 device: TargetDevice<&<C::CoreContext as DeviceIdContext<AnyDevice>>::DeviceId>,
465 ) {
466 update_device_and_protocol(self.core_ctx(), socket, device, MaybeUpdate::NoChange)
467 }
468
469 pub fn set_device_and_protocol(
471 &mut self,
472 socket: &ApiSocketId<C>,
473 device: TargetDevice<&<C::CoreContext as DeviceIdContext<AnyDevice>>::DeviceId>,
474 protocol: Protocol,
475 ) {
476 update_device_and_protocol(self.core_ctx(), socket, device, MaybeUpdate::NewValue(protocol))
477 }
478
479 pub fn get_info(
481 &mut self,
482 id: &ApiSocketId<C>,
483 ) -> SocketInfo<<C::CoreContext as DeviceIdContext<AnyDevice>>::WeakDeviceId> {
484 self.core_ctx().with_socket_state(id, |_external_state, Target { device, protocol }| {
485 SocketInfo { device: device.clone(), protocol: *protocol }
486 })
487 }
488
489 pub fn remove(
491 &mut self,
492 id: ApiSocketId<C>,
493 ) -> RemoveResourceResultWithContext<
494 <C::BindingsContext as DeviceSocketTypes>::SocketState<
495 <C::CoreContext as DeviceIdContext<AnyDevice>>::WeakDeviceId,
496 >,
497 C::BindingsContext,
498 > {
499 let core_ctx = self.core_ctx();
500 core_ctx.with_any_device_sockets_mut(|AnyDeviceSockets(any_device_sockets), core_ctx| {
501 let old_device = core_ctx.with_socket_state_mut(&id, |_external_state, target| {
502 let Target { device, protocol: _ } = target;
503 match &device {
504 TargetDevice::SpecificDevice(device) => device.upgrade(),
505 TargetDevice::AnyDevice => {
506 assert!(any_device_sockets.remove(&id));
507 None
508 }
509 }
510 });
511 if let Some(device) = old_device {
512 core_ctx.with_device_sockets_mut(
513 &device,
514 |DeviceSockets(device_sockets), _core_ctx| {
515 assert!(device_sockets.remove(&id), "device doesn't have socket");
516 },
517 )
518 }
519 });
520
521 core_ctx.with_all_device_sockets_mut(|AllSockets(sockets)| {
522 let primary = sockets
523 .remove(&id)
524 .unwrap_or_else(|| panic!("{id:?} not present in all socket map"));
525 drop(id);
528
529 let PrimaryDeviceSocketId(primary) = primary;
530 C::BindingsContext::unwrap_or_notify_with_new_reference_notifier(
531 primary,
532 |SocketState { external_state, counters: _, target: _ }| external_state,
533 )
534 })
535 }
536
537 pub fn send_frame<S, D>(
539 &mut self,
540 id: &ApiSocketId<C>,
541 metadata: DeviceSocketMetadata<D, <C::CoreContext as DeviceIdContext<D>>::DeviceId>,
542 body: S,
543 ) -> Result<(), SendFrameErrorReason>
544 where
545 S: Serializer,
546 S::Buffer: BufferMut,
547 D: DeviceSocketSendTypes,
548 C::CoreContext: DeviceIdContext<D>
549 + SendFrameContext<
550 C::BindingsContext,
551 DeviceSocketMetadata<D, <C::CoreContext as DeviceIdContext<D>>::DeviceId>,
552 >,
553 C::BindingsContext: DeviceLayerTypes,
554 {
555 let (core_ctx, bindings_ctx) = self.contexts();
556 let result = core_ctx.send_frame(bindings_ctx, metadata, body).map_err(|e| e.into_err());
557 match &result {
558 Ok(()) => {
559 core_ctx.increment_both(id, |counters: &DeviceSocketCounters| &counters.tx_frames)
560 }
561 Err(SendFrameErrorReason::QueueFull) => core_ctx
562 .increment_both(id, |counters: &DeviceSocketCounters| &counters.tx_err_queue_full),
563 Err(SendFrameErrorReason::Alloc) => core_ctx
564 .increment_both(id, |counters: &DeviceSocketCounters| &counters.tx_err_alloc),
565 Err(SendFrameErrorReason::SizeConstraintsViolation) => core_ctx
566 .increment_both(id, |counters: &DeviceSocketCounters| {
567 &counters.tx_err_size_constraint
568 }),
569 }
570 result
571 }
572
573 pub fn inspect<N>(&mut self, inspector: &mut N)
575 where
576 N: Inspector
577 + InspectorDeviceExt<<C::CoreContext as DeviceIdContext<AnyDevice>>::WeakDeviceId>,
578 {
579 self.core_ctx().with_all_device_sockets(|AllSockets(sockets), core_ctx| {
580 sockets.keys().for_each(|socket| {
581 inspector.record_debug_child(socket, |node| {
582 core_ctx.with_socket_state(
583 socket,
584 |_external_state, Target { protocol, device }| {
585 node.record_debug("Protocol", protocol);
586 match device {
587 TargetDevice::AnyDevice => node.record_str("Device", "Any"),
588 TargetDevice::SpecificDevice(d) => {
589 N::record_device(node, "Device", d)
590 }
591 }
592 },
593 );
594 node.record_child("Counters", |node| {
595 node.delegate_inspectable(socket.counters())
596 })
597 })
598 })
599 })
600 }
601}
602
/// A [`Device`] type that device sockets can send frames on.
pub trait DeviceSocketSendTypes: Device {
    /// Device-type-specific metadata needed to send a frame.
    type Metadata;
}
608
/// Metadata for sending a frame from a device socket.
#[derive(Debug, PartialEq)]
pub struct DeviceSocketMetadata<D: DeviceSocketSendTypes, DeviceId> {
    /// The device to send the frame on.
    pub device_id: DeviceId,
    /// Metadata specific to the device type `D`.
    pub metadata: D::Metadata,
}
619
/// Fields for an Ethernet header: destination address and EtherType.
#[derive(Debug, PartialEq)]
pub struct EthernetHeaderParams {
    /// The destination MAC address.
    pub dest_addr: Mac,
    /// The EtherType of the payload.
    pub protocol: EtherType,
}
628
/// A strong device socket ID specialized to the device layer's weak device ID.
pub type SocketId<BC> = DeviceSocketId<WeakDeviceId<BC>, BC>;
634
635impl<D: Send + Sync + Debug, BT: DeviceSocketTypes> DeviceSocketId<D, BT> {
636 pub fn socket_state(&self) -> &BT::SocketState<D> {
639 let Self(strong) = self;
640 let SocketState { external_state, counters: _, target: _ } = &**strong;
641 external_state
642 }
643
644 pub fn downgrade(&self) -> WeakDeviceSocketId<D, BT> {
646 let Self(inner) = self;
647 WeakDeviceSocketId(StrongRc::downgrade(inner))
648 }
649
650 pub fn counters(&self) -> &DeviceSocketCounters {
652 let Self(strong) = self;
653 let SocketState { external_state: _, counters, target: _ } = &**strong;
654 counters
655 }
656}
657
/// Offers frames seen on a device to the device sockets that want them.
pub trait DeviceSocketHandler<D: Device, BC>: DeviceIdContext<D> {
    /// Offers `frame`, sent or received on `device`, to matching sockets.
    ///
    /// `whole_frame` holds the complete frame bytes and is passed through to
    /// bindings together with the parsed `frame`.
    fn handle_frame(
        &mut self,
        bindings_ctx: &mut BC,
        device: &Self::DeviceId,
        frame: Frame<&[u8]>,
        whole_frame: &[u8],
    );
}
672
/// A frame received on a device.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ReceivedFrame<B> {
    /// An Ethernet frame.
    Ethernet {
        /// How the frame was addressed to this host.
        destination: FrameDestination,
        /// The parsed frame.
        frame: EthernetFrame<B>,
    },
    /// An IP packet with no Ethernet header.
    Ip(IpFrame<B>),
}
689
/// A frame sent from a device.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SentFrame<B> {
    /// An Ethernet frame.
    Ethernet(EthernetFrame<B>),
    /// An IP packet with no Ethernet header.
    Ip(IpFrame<B>),
}
701
/// Error returned when an outgoing frame cannot be parsed.
#[derive(Debug)]
pub struct ParseSentFrameError;
705
706impl SentFrame<&[u8]> {
707 pub fn try_parse_as_ethernet(mut buf: &[u8]) -> Result<SentFrame<&[u8]>, ParseSentFrameError> {
709 packet_formats::ethernet::EthernetFrame::parse(&mut buf, EthernetFrameLengthCheck::NoCheck)
710 .map_err(|_: ParseError| ParseSentFrameError)
711 .map(|frame| SentFrame::Ethernet(frame.into()))
712 }
713}
714
/// An Ethernet frame as seen by device sockets.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct EthernetFrame<B> {
    /// The source MAC address.
    pub src_mac: Mac,
    /// The destination MAC address.
    pub dst_mac: Mac,
    /// The frame's EtherType, if the parser reported one.
    pub ethertype: Option<EtherType>,
    /// The frame body.
    pub body: B,
}
727
/// An IP packet as seen by device sockets, with no link-layer header.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct IpFrame<B> {
    /// The packet's IP version.
    pub ip_version: IpVersion,
    /// The frame body.
    pub body: B,
}
736
737impl<B> IpFrame<B> {
738 fn ethertype(&self) -> EtherType {
739 let IpFrame { ip_version, body: _ } = self;
740 EtherType::from_ip_version(*ip_version)
741 }
742}
743
/// A frame offered to device sockets, either sent or received.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Frame<B> {
    /// A frame sent from a device.
    Sent(SentFrame<B>),
    /// A frame received on a device.
    Received(ReceivedFrame<B>),
}
752
753impl<B> From<SentFrame<B>> for Frame<B> {
754 fn from(value: SentFrame<B>) -> Self {
755 Self::Sent(value)
756 }
757}
758
759impl<B> From<ReceivedFrame<B>> for Frame<B> {
760 fn from(value: ReceivedFrame<B>) -> Self {
761 Self::Received(value)
762 }
763}
764
765impl<'a> From<packet_formats::ethernet::EthernetFrame<&'a [u8]>> for EthernetFrame<&'a [u8]> {
766 fn from(frame: packet_formats::ethernet::EthernetFrame<&'a [u8]>) -> Self {
767 Self {
768 src_mac: frame.src_mac(),
769 dst_mac: frame.dst_mac(),
770 ethertype: frame.ethertype(),
771 body: frame.into_body(),
772 }
773 }
774}
775
776impl<'a> ReceivedFrame<&'a [u8]> {
777 pub(crate) fn from_ethernet(
778 frame: packet_formats::ethernet::EthernetFrame<&'a [u8]>,
779 destination: FrameDestination,
780 ) -> Self {
781 Self::Ethernet { destination, frame: frame.into() }
782 }
783}
784
785impl<B> Frame<B> {
786 pub fn protocol(&self) -> Option<u16> {
788 let ethertype = match self {
789 Self::Sent(SentFrame::Ethernet(frame))
790 | Self::Received(ReceivedFrame::Ethernet { destination: _, frame }) => frame.ethertype,
791 Self::Sent(SentFrame::Ip(frame)) | Self::Received(ReceivedFrame::Ip(frame)) => {
792 Some(frame.ethertype())
793 }
794 };
795 ethertype.map(Into::into)
796 }
797
798 pub fn into_body(self) -> B {
800 match self {
801 Self::Received(ReceivedFrame::Ethernet { destination: _, frame })
802 | Self::Sent(SentFrame::Ethernet(frame)) => frame.body,
803 Self::Received(ReceivedFrame::Ip(frame)) | Self::Sent(SentFrame::Ip(frame)) => {
804 frame.body
805 }
806 }
807 }
808}
809
810impl<
811 D: Device,
812 BC: DeviceSocketBindingsContext<<CC as DeviceIdContext<AnyDevice>>::DeviceId>,
813 CC: DeviceSocketContext<BC> + DeviceIdContext<D>,
814 > DeviceSocketHandler<D, BC> for CC
815where
816 <CC as DeviceIdContext<D>>::DeviceId: Into<<CC as DeviceIdContext<AnyDevice>>::DeviceId>,
817{
818 fn handle_frame(
819 &mut self,
820 bindings_ctx: &mut BC,
821 device: &Self::DeviceId,
822 frame: Frame<&[u8]>,
823 whole_frame: &[u8],
824 ) {
825 let device = device.clone().into();
826
827 self.with_any_device_sockets(|AnyDeviceSockets(any_device_sockets), core_ctx| {
831 core_ctx.with_device_sockets(&device, |DeviceSockets(device_sockets), core_ctx| {
846 for socket in any_device_sockets.iter().chain(device_sockets) {
847 let delivered = core_ctx.with_socket_state(
848 socket,
849 |external_state, Target { protocol, device: _ }| {
850 let should_deliver = match protocol {
851 None => false,
852 Some(p) => match p {
853 Protocol::Specific(p) => match frame {
857 Frame::Received(_) => Some(p.get()) == frame.protocol(),
858 Frame::Sent(_) => false,
859 },
860 Protocol::All => true,
861 },
862 };
863 if should_deliver {
864 bindings_ctx.receive_frame(
865 external_state,
866 &device,
867 frame,
868 whole_frame,
869 )
870 }
871 should_deliver
872 },
873 );
874 if delivered {
875 core_ctx.increment_both(socket, |counters: &DeviceSocketCounters| {
876 &counters.rx_frames
877 });
878 }
879 }
880 })
881 })
882 }
883}
884
/// Device socket counters.
///
/// Each socket has its own copy (see [`DeviceSocketId::counters`]).
#[derive(Debug, Default)]
pub struct DeviceSocketCounters {
    /// Frames delivered to the socket.
    rx_frames: Counter,
    /// Frames sent successfully.
    tx_frames: Counter,
    /// Sends that failed with `SendFrameErrorReason::QueueFull`.
    tx_err_queue_full: Counter,
    /// Sends that failed with `SendFrameErrorReason::Alloc`.
    tx_err_alloc: Counter,
    /// Sends that failed with `SendFrameErrorReason::SizeConstraintsViolation`.
    tx_err_size_constraint: Counter,
}
905
906impl Inspectable for DeviceSocketCounters {
907 fn record<I: Inspector>(&self, inspector: &mut I) {
908 let Self { rx_frames, tx_frames, tx_err_queue_full, tx_err_alloc, tx_err_size_constraint } =
909 self;
910 inspector.record_child("Rx", |inspector| {
911 inspector.record_counter("DeliveredFrames", rx_frames);
912 });
913 inspector.record_child("Tx", |inspector| {
914 inspector.record_counter("SentFrames", tx_frames);
915 inspector.record_counter("QueueFullError", tx_err_queue_full);
916 inspector.record_counter("AllocError", tx_err_alloc);
917 inspector.record_counter("SizeConstraintError", tx_err_size_constraint);
918 });
919 }
920}
921
922impl<D: Send + Sync + Debug, BT: DeviceSocketTypes> OrderedLockAccess<AnyDeviceSockets<D, BT>>
923 for Sockets<D, BT>
924{
925 type Lock = RwLock<AnyDeviceSockets<D, BT>>;
926 fn ordered_lock_access(&self) -> OrderedLockRef<'_, Self::Lock> {
927 OrderedLockRef::new(&self.any_device_sockets)
928 }
929}
930
931impl<D: Send + Sync + Debug, BT: DeviceSocketTypes> OrderedLockAccess<AllSockets<D, BT>>
932 for Sockets<D, BT>
933{
934 type Lock = RwLock<AllSockets<D, BT>>;
935 fn ordered_lock_access(&self) -> OrderedLockRef<'_, Self::Lock> {
936 OrderedLockRef::new(&self.all_sockets)
937 }
938}
939
#[cfg(any(test, feature = "testutils"))]
mod testutil {
    use alloc::vec::Vec;
    use core::num::NonZeroU64;
    use netstack3_base::testutil::{FakeBindingsCtx, MonotonicIdentifier};
    use netstack3_base::StrongDeviceIdentifier;

    use super::*;
    use crate::internal::base::{
        DeviceClassMatcher, DeviceIdAndNameMatcher, DeviceLayerStateTypes,
    };

    /// A frame recorded by a fake bindings socket.
    #[derive(Clone, Debug, PartialEq)]
    pub struct ReceivedFrame<D> {
        /// The (weak) device the frame was delivered on.
        pub device: D,
        /// An owned copy of the parsed frame.
        pub frame: Frame<Vec<u8>>,
        /// An owned copy of the raw frame bytes.
        pub raw: Vec<u8>,
    }

    /// Fake bindings state for a device socket: the frames delivered to it,
    /// in delivery order.
    #[derive(Debug, Derivative)]
    #[derivative(Default(bound = ""))]
    pub struct ExternalSocketState<D>(pub Mutex<Vec<ReceivedFrame<D>>>);

    impl<TimerId, Event: Debug, State> DeviceSocketTypes
        for FakeBindingsCtx<TimerId, Event, State, ()>
    {
        type SocketState<D: Send + Sync + Debug> = ExternalSocketState<D>;
    }

    impl Frame<&[u8]> {
        /// Returns a copy of this frame that owns its body.
        pub(crate) fn cloned(self) -> Frame<Vec<u8>> {
            match self {
                Frame::Sent(sent) => Frame::Sent(match sent {
                    SentFrame::Ethernet(eth) => SentFrame::Ethernet(eth.cloned()),
                    SentFrame::Ip(ip) => SentFrame::Ip(ip.cloned()),
                }),
                Frame::Received(received) => Frame::Received(match received {
                    super::ReceivedFrame::Ethernet { destination, frame } => {
                        super::ReceivedFrame::Ethernet { destination, frame: frame.cloned() }
                    }
                    super::ReceivedFrame::Ip(ip) => super::ReceivedFrame::Ip(ip.cloned()),
                }),
            }
        }
    }

    impl EthernetFrame<&[u8]> {
        /// Returns a copy of this frame that owns its body.
        fn cloned(self) -> EthernetFrame<Vec<u8>> {
            EthernetFrame {
                src_mac: self.src_mac,
                dst_mac: self.dst_mac,
                ethertype: self.ethertype,
                body: self.body.to_vec(),
            }
        }
    }

    impl IpFrame<&[u8]> {
        /// Returns a copy of this frame that owns its body.
        fn cloned(self) -> IpFrame<Vec<u8>> {
            IpFrame { ip_version: self.ip_version, body: self.body.to_vec() }
        }
    }

    impl<TimerId, Event: Debug, State, D: StrongDeviceIdentifier> DeviceSocketBindingsContext<D>
        for FakeBindingsCtx<TimerId, Event, State, ()>
    {
        /// Appends an owned copy of the delivered frame to the socket's queue.
        fn receive_frame(
            &self,
            state: &ExternalSocketState<D::Weak>,
            device: &D,
            frame: Frame<&[u8]>,
            raw_frame: &[u8],
        ) {
            let record = ReceivedFrame {
                device: device.downgrade(),
                frame: frame.cloned(),
                raw: raw_frame.to_vec(),
            };
            state.0.lock().push(record)
        }
    }

    impl<
            TimerId: Debug + PartialEq + Clone + Send + Sync + 'static,
            Event: Debug + 'static,
            State: 'static,
        > DeviceLayerStateTypes for FakeBindingsCtx<TimerId, Event, State, ()>
    {
        type EthernetDeviceState = ();
        type LoopbackDeviceState = ();
        type PureIpDeviceState = ();
        type BlackholeDeviceState = ();
        type DeviceIdentifier = MonotonicIdentifier;
    }

    /// Stub; device-class matching is not supported by this fake.
    impl DeviceClassMatcher<()> for () {
        fn device_class_matches(&self, _: &()) -> bool {
            unimplemented!()
        }
    }

    /// Stub; ID and name matching are not supported by this fake.
    impl DeviceIdAndNameMatcher for MonotonicIdentifier {
        fn id_matches(&self, _id: &NonZeroU64) -> bool {
            unimplemented!()
        }

        fn name_matches(&self, _name: &str) -> bool {
            unimplemented!()
        }
    }
}
1051
1052#[cfg(test)]
1053mod tests {
1054 use alloc::collections::HashMap;
1055 use alloc::vec;
1056 use alloc::vec::Vec;
1057 use core::marker::PhantomData;
1058
1059 use crate::internal::socket::testutil::{ExternalSocketState, ReceivedFrame};
1060 use netstack3_base::testutil::{
1061 FakeReferencyDeviceId, FakeStrongDeviceId, FakeWeakDeviceId, MultipleDevicesId,
1062 };
1063 use netstack3_base::{CounterContext, CtxPair, SendFrameError, SendableFrameMeta};
1064 use packet::ParsablePacket;
1065 use test_case::test_case;
1066
1067 use super::*;
1068
    /// Fake core context whose state is the fake socket tables.
    type FakeCoreCtx<D> = netstack3_base::testutil::FakeCoreCtx<FakeSockets<D>, (), D>;
    /// Fake bindings context with unit timer, event and state types.
    type FakeBindingsCtx = netstack3_base::testutil::FakeBindingsCtx<(), (), (), ()>;
    /// The fake core and bindings contexts paired together.
    type FakeCtx<D> = CtxPair<FakeCoreCtx<D>, FakeBindingsCtx>;
1072
    /// Extension trait for getting a [`DeviceSocketApi`] from any context pair.
    trait DeviceSocketApiExt: ContextPair + Sized {
        /// Returns a device socket API that borrows this context pair.
        fn device_socket_api(&mut self) -> DeviceSocketApi<&mut Self> {
            DeviceSocketApi::new(self)
        }
    }

    impl<O> DeviceSocketApiExt for O where O: ContextPair + Sized {}
1082
    /// In-memory socket tables that back the fake core context.
    #[derive(Derivative)]
    #[derivative(Default(bound = ""))]
    struct FakeSockets<D: FakeStrongDeviceId> {
        /// Sockets that target any device.
        any_device_sockets: AnyDeviceSockets<D::Weak, FakeBindingsCtx>,
        /// Per-device socket sets, keyed by device.
        device_sockets: HashMap<D, DeviceSockets<D::Weak, FakeBindingsCtx>>,
        /// All live sockets.
        all_sockets: AllSockets<D::Weak, FakeBindingsCtx>,
        /// A single counters instance that is not tied to any one socket.
        counters: DeviceSocketCounters,
        /// Raw bytes of frames sent through the fake. Presumably filled by a
        /// `SendFrameContext` impl outside this view (TODO confirm).
        sent_frames: Vec<Vec<u8>>,
    }
1093
    /// Borrowed views into fake socket tables. Used as the fake core context
    /// inside socket-table callbacks.
    ///
    /// Fields, in order: any-device sockets, all sockets, per-device state, a
    /// marker for the device ID type, and shared counters.
    pub struct FakeSocketsMutRefs<'m, AnyDevice, AllSockets, Devices, Device>(
        &'m mut AnyDevice,
        &'m mut AllSockets,
        &'m mut Devices,
        PhantomData<Device>,
        &'m DeviceSocketCounters,
    );
1102
    /// Types that can lend out their fake socket tables as
    /// [`FakeSocketsMutRefs`].
    pub trait AsFakeSocketsMutRefs {
        /// The any-device socket table type.
        type AnyDevice: 'static;
        /// The all-sockets table type.
        type AllSockets: 'static;
        /// The per-device state type.
        type Devices: 'static;
        /// The device ID type.
        type Device: 'static;
        /// Borrows the socket tables.
        fn as_sockets_ref(
            &mut self,
        ) -> FakeSocketsMutRefs<'_, Self::AnyDevice, Self::AllSockets, Self::Devices, Self::Device>;
    }
1114
1115 impl<D: FakeStrongDeviceId> AsFakeSocketsMutRefs for FakeCoreCtx<D> {
1116 type AnyDevice = AnyDeviceSockets<D::Weak, FakeBindingsCtx>;
1117 type AllSockets = AllSockets<D::Weak, FakeBindingsCtx>;
1118 type Devices = HashMap<D, DeviceSockets<D::Weak, FakeBindingsCtx>>;
1119 type Device = D;
1120
1121 fn as_sockets_ref(
1122 &mut self,
1123 ) -> FakeSocketsMutRefs<
1124 '_,
1125 AnyDeviceSockets<D::Weak, FakeBindingsCtx>,
1126 AllSockets<D::Weak, FakeBindingsCtx>,
1127 HashMap<D, DeviceSockets<D::Weak, FakeBindingsCtx>>,
1128 D,
1129 > {
1130 let FakeSockets {
1131 any_device_sockets,
1132 device_sockets,
1133 all_sockets,
1134 counters,
1135 sent_frames: _,
1136 } = &mut self.state;
1137 FakeSocketsMutRefs(
1138 any_device_sockets,
1139 all_sockets,
1140 device_sockets,
1141 PhantomData,
1142 counters,
1143 )
1144 }
1145 }
1146
1147 impl<'m, AnyDevice: 'static, AllSockets: 'static, Devices: 'static, Device: 'static>
1148 AsFakeSocketsMutRefs for FakeSocketsMutRefs<'m, AnyDevice, AllSockets, Devices, Device>
1149 {
1150 type AnyDevice = AnyDevice;
1151 type AllSockets = AllSockets;
1152 type Devices = Devices;
1153 type Device = Device;
1154
1155 fn as_sockets_ref(
1156 &mut self,
1157 ) -> FakeSocketsMutRefs<'_, AnyDevice, AllSockets, Devices, Device> {
1158 let Self(any_device, all_sockets, devices, PhantomData, counters) = self;
1159 FakeSocketsMutRefs(any_device, all_sockets, devices, PhantomData, counters)
1160 }
1161 }
1162
1163 impl<D: Clone> TargetDevice<&D> {
1164 fn with_weak_id(&self) -> TargetDevice<FakeWeakDeviceId<D>> {
1165 match self {
1166 TargetDevice::AnyDevice => TargetDevice::AnyDevice,
1167 TargetDevice::SpecificDevice(d) => {
1168 TargetDevice::SpecificDevice(FakeWeakDeviceId((*d).clone()))
1169 }
1170 }
1171 }
1172 }
1173
1174 impl<D: Eq + Hash + FakeStrongDeviceId> FakeSockets<D> {
1175 fn new(devices: impl IntoIterator<Item = D>) -> Self {
1176 let device_sockets =
1177 devices.into_iter().map(|d| (d, DeviceSockets::default())).collect();
1178 Self {
1179 any_device_sockets: AnyDeviceSockets::default(),
1180 device_sockets,
1181 all_sockets: Default::default(),
1182 counters: Default::default(),
1183 sent_frames: Default::default(),
1184 }
1185 }
1186 }
1187
1188 impl<
1189 'm,
1190 DeviceId: FakeStrongDeviceId,
1191 As: AsFakeSocketsMutRefs
1192 + DeviceIdContext<AnyDevice, DeviceId = DeviceId, WeakDeviceId = DeviceId::Weak>,
1193 > SocketStateAccessor<FakeBindingsCtx> for As
1194 {
1195 fn with_socket_state<
1196 F: FnOnce(&ExternalSocketState<Self::WeakDeviceId>, &Target<Self::WeakDeviceId>) -> R,
1197 R,
1198 >(
1199 &mut self,
1200 socket: &DeviceSocketId<Self::WeakDeviceId, FakeBindingsCtx>,
1201 cb: F,
1202 ) -> R {
1203 let DeviceSocketId(rc) = socket;
1204 let target = rc.target.lock();
1206 cb(&rc.external_state, &target)
1207 }
1208
1209 fn with_socket_state_mut<
1210 F: FnOnce(&ExternalSocketState<Self::WeakDeviceId>, &mut Target<Self::WeakDeviceId>) -> R,
1211 R,
1212 >(
1213 &mut self,
1214 socket: &DeviceSocketId<Self::WeakDeviceId, FakeBindingsCtx>,
1215 cb: F,
1216 ) -> R {
1217 let DeviceSocketId(rc) = socket;
1218 let mut target = rc.target.lock();
1220 cb(&rc.external_state, &mut target)
1221 }
1222 }
1223
1224 impl<
1225 'm,
1226 DeviceId: FakeStrongDeviceId,
1227 As: AsFakeSocketsMutRefs<
1228 Devices = HashMap<DeviceId, DeviceSockets<DeviceId::Weak, FakeBindingsCtx>>,
1229 > + DeviceIdContext<AnyDevice, DeviceId = DeviceId, WeakDeviceId = DeviceId::Weak>,
1230 > DeviceSocketAccessor<FakeBindingsCtx> for As
1231 {
1232 type DeviceSocketCoreCtx<'a> =
1233 FakeSocketsMutRefs<'a, As::AnyDevice, As::AllSockets, HashSet<DeviceId>, DeviceId>;
1234 fn with_device_sockets<
1235 F: FnOnce(
1236 &DeviceSockets<Self::WeakDeviceId, FakeBindingsCtx>,
1237 &mut Self::DeviceSocketCoreCtx<'_>,
1238 ) -> R,
1239 R,
1240 >(
1241 &mut self,
1242 device: &Self::DeviceId,
1243 cb: F,
1244 ) -> R {
1245 let FakeSocketsMutRefs(any_device, all_sockets, device_sockets, PhantomData, counters) =
1246 self.as_sockets_ref();
1247 let mut devices = device_sockets.keys().cloned().collect();
1248 let device = device_sockets.get(device).unwrap();
1249 cb(
1250 device,
1251 &mut FakeSocketsMutRefs(
1252 any_device,
1253 all_sockets,
1254 &mut devices,
1255 PhantomData,
1256 counters,
1257 ),
1258 )
1259 }
1260 fn with_device_sockets_mut<
1261 F: FnOnce(
1262 &mut DeviceSockets<Self::WeakDeviceId, FakeBindingsCtx>,
1263 &mut Self::DeviceSocketCoreCtx<'_>,
1264 ) -> R,
1265 R,
1266 >(
1267 &mut self,
1268 device: &Self::DeviceId,
1269 cb: F,
1270 ) -> R {
1271 let FakeSocketsMutRefs(any_device, all_sockets, device_sockets, PhantomData, counters) =
1272 self.as_sockets_ref();
1273 let mut devices = device_sockets.keys().cloned().collect();
1274 let device = device_sockets.get_mut(device).unwrap();
1275 cb(
1276 device,
1277 &mut FakeSocketsMutRefs(
1278 any_device,
1279 all_sockets,
1280 &mut devices,
1281 PhantomData,
1282 counters,
1283 ),
1284 )
1285 }
1286 }
1287
1288 impl<
1289 'm,
1290 DeviceId: FakeStrongDeviceId,
1291 As: AsFakeSocketsMutRefs<
1292 AnyDevice = AnyDeviceSockets<DeviceId::Weak, FakeBindingsCtx>,
1293 AllSockets = AllSockets<DeviceId::Weak, FakeBindingsCtx>,
1294 Devices = HashMap<DeviceId, DeviceSockets<DeviceId::Weak, FakeBindingsCtx>>,
1295 > + DeviceIdContext<AnyDevice, DeviceId = DeviceId, WeakDeviceId = DeviceId::Weak>,
1296 > DeviceSocketContext<FakeBindingsCtx> for As
1297 {
1298 type SocketTablesCoreCtx<'a> = FakeSocketsMutRefs<
1299 'a,
1300 (),
1301 (),
1302 HashMap<DeviceId, DeviceSockets<DeviceId::Weak, FakeBindingsCtx>>,
1303 DeviceId,
1304 >;
1305
1306 fn with_any_device_sockets<
1307 F: FnOnce(
1308 &AnyDeviceSockets<Self::WeakDeviceId, FakeBindingsCtx>,
1309 &mut Self::SocketTablesCoreCtx<'_>,
1310 ) -> R,
1311 R,
1312 >(
1313 &mut self,
1314 cb: F,
1315 ) -> R {
1316 let FakeSocketsMutRefs(
1317 any_device_sockets,
1318 _all_sockets,
1319 device_sockets,
1320 PhantomData,
1321 counters,
1322 ) = self.as_sockets_ref();
1323 cb(
1324 any_device_sockets,
1325 &mut FakeSocketsMutRefs(&mut (), &mut (), device_sockets, PhantomData, counters),
1326 )
1327 }
1328 fn with_any_device_sockets_mut<
1329 F: FnOnce(
1330 &mut AnyDeviceSockets<Self::WeakDeviceId, FakeBindingsCtx>,
1331 &mut Self::SocketTablesCoreCtx<'_>,
1332 ) -> R,
1333 R,
1334 >(
1335 &mut self,
1336 cb: F,
1337 ) -> R {
1338 let FakeSocketsMutRefs(
1339 any_device_sockets,
1340 _all_sockets,
1341 device_sockets,
1342 PhantomData,
1343 counters,
1344 ) = self.as_sockets_ref();
1345 cb(
1346 any_device_sockets,
1347 &mut FakeSocketsMutRefs(&mut (), &mut (), device_sockets, PhantomData, counters),
1348 )
1349 }
1350
1351 fn with_all_device_sockets<
1352 F: FnOnce(
1353 &AllSockets<Self::WeakDeviceId, FakeBindingsCtx>,
1354 &mut Self::SocketTablesCoreCtx<'_>,
1355 ) -> R,
1356 R,
1357 >(
1358 &mut self,
1359 cb: F,
1360 ) -> R {
1361 let FakeSocketsMutRefs(
1362 _any_device_sockets,
1363 all_sockets,
1364 device_sockets,
1365 PhantomData,
1366 counters,
1367 ) = self.as_sockets_ref();
1368 cb(
1369 all_sockets,
1370 &mut FakeSocketsMutRefs(&mut (), &mut (), device_sockets, PhantomData, counters),
1371 )
1372 }
1373
1374 fn with_all_device_sockets_mut<
1375 F: FnOnce(&mut AllSockets<Self::WeakDeviceId, FakeBindingsCtx>) -> R,
1376 R,
1377 >(
1378 &mut self,
1379 cb: F,
1380 ) -> R {
1381 let FakeSocketsMutRefs(_, all_sockets, _, _, _) = self.as_sockets_ref();
1382 cb(all_sockets)
1383 }
1384 }
1385
1386 impl<'m, X, Y, Z, D: FakeStrongDeviceId> DeviceIdContext<AnyDevice>
1387 for FakeSocketsMutRefs<'m, X, Y, Z, D>
1388 {
1389 type DeviceId = D;
1390 type WeakDeviceId = FakeWeakDeviceId<D>;
1391 }
1392
1393 impl<D: FakeStrongDeviceId> CounterContext<DeviceSocketCounters> for FakeCoreCtx<D> {
1394 fn counters(&self) -> &DeviceSocketCounters {
1395 &self.state.counters
1396 }
1397 }
1398
1399 impl<D: FakeStrongDeviceId>
1400 ResourceCounterContext<DeviceSocketId<D::Weak, FakeBindingsCtx>, DeviceSocketCounters>
1401 for FakeCoreCtx<D>
1402 {
1403 fn per_resource_counters<'a>(
1404 &'a self,
1405 socket: &'a DeviceSocketId<D::Weak, FakeBindingsCtx>,
1406 ) -> &'a DeviceSocketCounters {
1407 socket.counters()
1408 }
1409 }
1410
1411 impl<'m, X, Y, Z, D> CounterContext<DeviceSocketCounters> for FakeSocketsMutRefs<'m, X, Y, Z, D> {
1412 fn counters(&self) -> &DeviceSocketCounters {
1413 let FakeSocketsMutRefs(_, _, _, _, counters) = self;
1414 counters
1415 }
1416 }
1417
1418 impl<'m, X, Y, Z, D: FakeStrongDeviceId>
1419 ResourceCounterContext<DeviceSocketId<D::Weak, FakeBindingsCtx>, DeviceSocketCounters>
1420 for FakeSocketsMutRefs<'m, X, Y, Z, D>
1421 {
1422 fn per_resource_counters<'a>(
1423 &'a self,
1424 socket: &'a DeviceSocketId<D::Weak, FakeBindingsCtx>,
1425 ) -> &'a DeviceSocketCounters {
1426 socket.counters()
1427 }
1428 }
1429
/// An arbitrary non-zero protocol number used with `Protocol::Specific`.
const SOME_PROTOCOL: NonZeroU16 = match NonZeroU16::new(2000) {
    Some(protocol) => protocol,
    None => unreachable!(),
};
1431
// A freshly created socket reports no device and no protocol, and can be
// removed right away.
#[test]
fn create_remove() {
    let mut ctx = FakeCtx::with_core_ctx(FakeCoreCtx::with_state(FakeSockets::new(
        MultipleDevicesId::all(),
    )));
    let mut api = ctx.device_socket_api();

    let socket = api.create(ExternalSocketState::default());
    let info = api.get_info(&socket);
    assert_eq!(info, SocketInfo { device: TargetDevice::AnyDevice, protocol: None });

    let ExternalSocketState(_frames) = api.remove(socket).into_removed();
}
1447
1448 #[test_case(TargetDevice::AnyDevice)]
1449 #[test_case(TargetDevice::SpecificDevice(&MultipleDevicesId::A))]
1450 fn test_set_device(device: TargetDevice<&MultipleDevicesId>) {
1451 let mut ctx = FakeCtx::with_core_ctx(FakeCoreCtx::with_state(FakeSockets::new(
1452 MultipleDevicesId::all(),
1453 )));
1454 let mut api = ctx.device_socket_api();
1455
1456 let bound = api.create(Default::default());
1457 api.set_device(&bound, device.clone());
1458 assert_eq!(
1459 api.get_info(&bound),
1460 SocketInfo { device: device.with_weak_id(), protocol: None }
1461 );
1462
1463 let device_sockets = &api.core_ctx().state.device_sockets;
1464 if let TargetDevice::SpecificDevice(d) = device {
1465 let DeviceSockets(socket_ids) = device_sockets.get(&d).expect("device state exists");
1466 assert_eq!(socket_ids, &HashSet::from([bound]));
1467 }
1468 }
1469
// Rebinding a socket moves it: afterwards it is listed only under the newest
// device.
#[test]
fn update_device() {
    let mut ctx = FakeCtx::with_core_ctx(FakeCoreCtx::with_state(FakeSockets::new(
        MultipleDevicesId::all(),
    )));
    let mut api = ctx.device_socket_api();
    let bound = api.create(Default::default());

    // Bind to A first, then rebind to B.
    for target in [&MultipleDevicesId::A, &MultipleDevicesId::B] {
        api.set_device(&bound, TargetDevice::SpecificDevice(target));
    }

    let expected_info = SocketInfo {
        device: TargetDevice::SpecificDevice(FakeWeakDeviceId(MultipleDevicesId::B)),
        protocol: None,
    };
    assert_eq!(api.get_info(&bound), expected_info);

    let per_device: HashMap<_, Vec<_>> = api
        .core_ctx()
        .state
        .device_sockets
        .iter()
        .map(|(device, DeviceSockets(ids))| (device, ids.iter().collect()))
        .collect();

    assert_eq!(
        per_device,
        HashMap::from([
            (&MultipleDevicesId::A, vec![]),
            (&MultipleDevicesId::B, vec![&bound]),
            (&MultipleDevicesId::C, vec![])
        ])
    );
}
1506
1507 #[test_case(Protocol::All, TargetDevice::AnyDevice)]
1508 #[test_case(Protocol::Specific(SOME_PROTOCOL), TargetDevice::AnyDevice)]
1509 #[test_case(Protocol::All, TargetDevice::SpecificDevice(&MultipleDevicesId::A))]
1510 #[test_case(
1511 Protocol::Specific(SOME_PROTOCOL),
1512 TargetDevice::SpecificDevice(&MultipleDevicesId::A)
1513 )]
1514 fn create_set_device_and_protocol_remove_multiple(
1515 protocol: Protocol,
1516 device: TargetDevice<&MultipleDevicesId>,
1517 ) {
1518 let mut ctx = FakeCtx::with_core_ctx(FakeCoreCtx::with_state(FakeSockets::new(
1519 MultipleDevicesId::all(),
1520 )));
1521 let mut api = ctx.device_socket_api();
1522
1523 let mut sockets = [(); 3].map(|()| api.create(Default::default()));
1524 for socket in &mut sockets {
1525 api.set_device_and_protocol(socket, device.clone(), protocol);
1526 assert_eq!(
1527 api.get_info(socket),
1528 SocketInfo { device: device.with_weak_id(), protocol: Some(protocol) }
1529 );
1530 }
1531
1532 for socket in sockets {
1533 let ExternalSocketState(_received_frames) = api.remove(socket).into_removed();
1534 }
1535 }
1536
// A socket bound to a device that has been marked removed can still be moved
// onto another device, which then lists it.
#[test]
fn change_device_after_removal() {
    let device_to_remove = FakeReferencyDeviceId::default();
    let device_to_maintain = FakeReferencyDeviceId::default();
    let devices = [device_to_remove.clone(), device_to_maintain.clone()];
    let mut ctx = FakeCtx::with_core_ctx(FakeCoreCtx::with_state(FakeSockets::new(devices)));
    let mut api = ctx.device_socket_api();

    let bound = api.create(Default::default());
    api.set_device(&bound, TargetDevice::SpecificDevice(&device_to_remove));

    // Mark the original device removed while the socket is still bound to it.
    device_to_remove.mark_removed();

    api.set_device(&bound, TargetDevice::SpecificDevice(&device_to_maintain));
    let expected_device =
        TargetDevice::SpecificDevice(FakeWeakDeviceId(device_to_maintain.clone()));
    assert_eq!(api.get_info(&bound), SocketInfo { device: expected_device, protocol: None });

    let DeviceSockets(weak_sockets) = api
        .core_ctx()
        .state
        .device_sockets
        .get(&device_to_maintain)
        .expect("device state exists");
    assert_eq!(weak_sockets, &HashSet::from([bound]));
}
1572
1573 struct TestData;
1574 impl TestData {
1575 const SRC_MAC: Mac = Mac::new([0, 1, 2, 3, 4, 5]);
1576 const DST_MAC: Mac = Mac::new([6, 7, 8, 9, 10, 11]);
1577 const PROTO: NonZeroU16 = NonZeroU16::new(0x08AB).unwrap();
1579 const BODY: &'static [u8] = b"some pig";
1580 const BUFFER: &'static [u8] = &[
1581 6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 4, 5, 0x08, 0xAB, b's', b'o', b'm', b'e', b' ', b'p',
1582 b'i', b'g',
1583 ];
1584
1585 fn frame() -> packet_formats::ethernet::EthernetFrame<&'static [u8]> {
1587 let mut buffer_view = Self::BUFFER;
1588 packet_formats::ethernet::EthernetFrame::parse(
1589 &mut buffer_view,
1590 EthernetFrameLengthCheck::NoCheck,
1591 )
1592 .unwrap()
1593 }
1594 }
1595
/// A protocol number that differs from `TestData::PROTO`, so sockets bound to
/// it should not match the test frame.
const WRONG_PROTO: NonZeroU16 = match NonZeroU16::new(0x08ff) {
    Some(protocol) => protocol,
    None => unreachable!(),
};
1597
1598 fn make_bound<D: FakeStrongDeviceId>(
1599 ctx: &mut FakeCtx<D>,
1600 device: TargetDevice<D>,
1601 protocol: Option<Protocol>,
1602 state: ExternalSocketState<D::Weak>,
1603 ) -> DeviceSocketId<D::Weak, FakeBindingsCtx> {
1604 let mut api = ctx.device_socket_api();
1605 let id = api.create(state);
1606 let device = match &device {
1607 TargetDevice::AnyDevice => TargetDevice::AnyDevice,
1608 TargetDevice::SpecificDevice(d) => TargetDevice::SpecificDevice(d),
1609 };
1610 match protocol {
1611 Some(protocol) => api.set_device_and_protocol(&id, device, protocol),
1612 None => api.set_device(&id, device),
1613 };
1614 id
1615 }
1616
1617 fn deliver_one_frame(
1620 delivered_frame: Frame<&[u8]>,
1621 FakeCtx { core_ctx, bindings_ctx }: &mut FakeCtx<MultipleDevicesId>,
1622 ) -> HashSet<DeviceSocketId<FakeWeakDeviceId<MultipleDevicesId>, FakeBindingsCtx>> {
1623 DeviceSocketHandler::handle_frame(
1624 core_ctx,
1625 bindings_ctx,
1626 &MultipleDevicesId::A,
1627 delivered_frame.clone(),
1628 TestData::BUFFER,
1629 );
1630
1631 let FakeSockets {
1632 all_sockets: AllSockets(all_sockets),
1633 any_device_sockets: _,
1634 device_sockets: _,
1635 counters: _,
1636 sent_frames: _,
1637 } = &core_ctx.state;
1638
1639 all_sockets
1640 .iter()
1641 .filter_map(|(id, _primary)| {
1642 let DeviceSocketId(rc) = &id;
1643 let ExternalSocketState(frames) = &rc.external_state;
1644 let frames = frames.lock();
1645 (!frames.is_empty()).then(|| {
1646 assert_eq!(
1647 &*frames,
1648 &[ReceivedFrame {
1649 device: FakeWeakDeviceId(MultipleDevicesId::A),
1650 frame: delivered_frame.cloned(),
1651 raw: TestData::BUFFER.into(),
1652 }]
1653 );
1654 id.clone()
1655 })
1656 })
1657 .collect()
1658 }
1659
// A frame received on device A reaches exactly the sockets bound to A or to
// any device whose protocol is `All` or matches the frame's EtherType.
#[test]
fn receive_frame_deliver_to_multiple() {
    let mut ctx = FakeCtx::with_core_ctx(FakeCoreCtx::with_state(FakeSockets::new(
        MultipleDevicesId::all(),
    )));

    use Protocol::*;
    use TargetDevice::*;
    let never_bound = {
        let state = ExternalSocketState::<FakeWeakDeviceId<MultipleDevicesId>>::default();
        ctx.device_socket_api().create(state)
    };

    // Named `bind` so it does not shadow the `make_bound` helper it wraps.
    let mut bind = |device, protocol| {
        let state = ExternalSocketState::<FakeWeakDeviceId<MultipleDevicesId>>::default();
        make_bound(&mut ctx, device, protocol, state)
    };
    let bound_a_no_protocol = bind(SpecificDevice(MultipleDevicesId::A), None);
    let bound_a_all_protocols = bind(SpecificDevice(MultipleDevicesId::A), Some(All));
    let bound_a_right_protocol =
        bind(SpecificDevice(MultipleDevicesId::A), Some(Specific(TestData::PROTO)));
    let bound_a_wrong_protocol =
        bind(SpecificDevice(MultipleDevicesId::A), Some(Specific(WRONG_PROTO)));
    let bound_b_no_protocol = bind(SpecificDevice(MultipleDevicesId::B), None);
    let bound_b_all_protocols = bind(SpecificDevice(MultipleDevicesId::B), Some(All));
    let bound_b_right_protocol =
        bind(SpecificDevice(MultipleDevicesId::B), Some(Specific(TestData::PROTO)));
    let bound_b_wrong_protocol =
        bind(SpecificDevice(MultipleDevicesId::B), Some(Specific(WRONG_PROTO)));
    let bound_any_no_protocol = bind(AnyDevice, None);
    let bound_any_all_protocols = bind(AnyDevice, Some(All));
    let bound_any_right_protocol = bind(AnyDevice, Some(Specific(TestData::PROTO)));
    let bound_any_wrong_protocol = bind(AnyDevice, Some(Specific(WRONG_PROTO)));

    let mut sockets_with_received_frames = deliver_one_frame(
        super::ReceivedFrame::from_ethernet(
            TestData::frame(),
            FrameDestination::Individual { local: true },
        )
        .into(),
        &mut ctx,
    );

    let sockets_not_expecting_frames = [
        never_bound,
        bound_a_no_protocol,
        bound_a_wrong_protocol,
        bound_b_no_protocol,
        bound_b_all_protocols,
        bound_b_right_protocol,
        bound_b_wrong_protocol,
        bound_any_no_protocol,
        bound_any_wrong_protocol,
    ];
    let sockets_expecting_frames = [
        bound_a_all_protocols,
        bound_a_right_protocol,
        bound_any_all_protocols,
        bound_any_right_protocol,
    ];

    for (n, socket) in sockets_expecting_frames.iter().enumerate() {
        // `socket` is already a `&DeviceSocketId`. Borrowing it again would make
        // `remove` look up a `&DeviceSocketId` key, which the set cannot `Borrow`.
        assert!(sockets_with_received_frames.remove(socket), "socket {n} didn't receive the frame");
    }
    assert!(sockets_with_received_frames.is_empty());

    for (n, socket) in sockets_expecting_frames.iter().enumerate() {
        assert_eq!(socket.counters().rx_frames.get(), 1, "socket {n} has wrong rx_frames");
    }
    for (n, socket) in sockets_not_expecting_frames.iter().enumerate() {
        assert_eq!(socket.counters().rx_frames.get(), 0, "socket {n} has wrong rx_frames");
    }
}
1737
// A frame sent out of device A is looped back only to sockets listening for
// all protocols, on A or on any device. Sockets bound to its specific
// EtherType do not receive it.
#[test]
fn sent_frame_deliver_to_multiple() {
    let mut ctx = FakeCtx::with_core_ctx(FakeCoreCtx::with_state(FakeSockets::new(
        MultipleDevicesId::all(),
    )));

    use Protocol::*;
    use TargetDevice::*;
    let never_bound = {
        let state = ExternalSocketState::<FakeWeakDeviceId<MultipleDevicesId>>::default();
        ctx.device_socket_api().create(state)
    };

    // Named `bind` so it does not shadow the `make_bound` helper it wraps.
    let mut bind = |device, protocol| {
        let state = ExternalSocketState::<FakeWeakDeviceId<MultipleDevicesId>>::default();
        make_bound(&mut ctx, device, protocol, state)
    };
    let bound_a_no_protocol = bind(SpecificDevice(MultipleDevicesId::A), None);
    let bound_a_all_protocols = bind(SpecificDevice(MultipleDevicesId::A), Some(All));
    let bound_a_same_protocol =
        bind(SpecificDevice(MultipleDevicesId::A), Some(Specific(TestData::PROTO)));
    let bound_a_wrong_protocol =
        bind(SpecificDevice(MultipleDevicesId::A), Some(Specific(WRONG_PROTO)));
    let bound_b_no_protocol = bind(SpecificDevice(MultipleDevicesId::B), None);
    let bound_b_all_protocols = bind(SpecificDevice(MultipleDevicesId::B), Some(All));
    let bound_b_same_protocol =
        bind(SpecificDevice(MultipleDevicesId::B), Some(Specific(TestData::PROTO)));
    let bound_b_wrong_protocol =
        bind(SpecificDevice(MultipleDevicesId::B), Some(Specific(WRONG_PROTO)));
    let bound_any_no_protocol = bind(AnyDevice, None);
    let bound_any_all_protocols = bind(AnyDevice, Some(All));
    let bound_any_same_protocol = bind(AnyDevice, Some(Specific(TestData::PROTO)));
    let bound_any_wrong_protocol = bind(AnyDevice, Some(Specific(WRONG_PROTO)));

    let mut sockets_with_received_frames =
        deliver_one_frame(SentFrame::Ethernet(TestData::frame().into()).into(), &mut ctx);

    let sockets_not_expecting_frames = [
        never_bound,
        bound_a_no_protocol,
        bound_a_same_protocol,
        bound_a_wrong_protocol,
        bound_b_no_protocol,
        bound_b_all_protocols,
        bound_b_same_protocol,
        bound_b_wrong_protocol,
        bound_any_no_protocol,
        bound_any_same_protocol,
        bound_any_wrong_protocol,
    ];
    let sockets_expecting_frames = [bound_a_all_protocols, bound_any_all_protocols];

    for (n, socket) in sockets_expecting_frames.iter().enumerate() {
        // `socket` is already a `&DeviceSocketId`; pass it through unchanged.
        assert!(sockets_with_received_frames.remove(socket), "socket {n} didn't receive the frame");
    }
    assert!(sockets_with_received_frames.is_empty());

    for (n, socket) in sockets_expecting_frames.iter().enumerate() {
        assert_eq!(socket.counters().rx_frames.get(), 1, "socket {n} has wrong rx_frames");
    }
    for (n, socket) in sockets_not_expecting_frames.iter().enumerate() {
        assert_eq!(socket.counters().rx_frames.get(), 0, "socket {n} has wrong rx_frames");
    }
}
1807
// Repeated deliveries accumulate on the socket in order, and its rx counter
// tracks every one of them.
#[test]
fn deliver_multiple_frames() {
    const RECEIVE_COUNT: usize = 10;

    let mut ctx = FakeCtx::with_core_ctx(FakeCoreCtx::with_state(FakeSockets::new(
        MultipleDevicesId::all(),
    )));
    let socket = make_bound(
        &mut ctx,
        TargetDevice::AnyDevice,
        Some(Protocol::All),
        ExternalSocketState::default(),
    );
    let FakeCtx { mut core_ctx, mut bindings_ctx } = ctx;

    for _ in 0..RECEIVE_COUNT {
        let frame = super::ReceivedFrame::from_ethernet(
            TestData::frame(),
            FrameDestination::Individual { local: true },
        );
        DeviceSocketHandler::handle_frame(
            &mut core_ctx,
            &mut bindings_ctx,
            &MultipleDevicesId::A,
            frame.into(),
            TestData::BUFFER,
        );
    }

    let FakeSockets {
        all_sockets: AllSockets(mut all_sockets),
        any_device_sockets: _,
        device_sockets: _,
        counters: _,
        sent_frames: _,
    } = core_ctx.into_state();
    let PrimaryDeviceSocketId(primary) = all_sockets.remove(&socket).unwrap();
    assert!(all_sockets.is_empty());
    // Drop the last outstanding ID so the primary reference can be unwrapped.
    drop(socket);
    let SocketState { external_state: ExternalSocketState(received), counters, target: _ } =
        PrimaryRc::unwrap(primary);

    let expected = ReceivedFrame {
        device: FakeWeakDeviceId(MultipleDevicesId::A),
        frame: Frame::Received(super::ReceivedFrame::Ethernet {
            destination: FrameDestination::Individual { local: true },
            frame: EthernetFrame {
                src_mac: TestData::SRC_MAC,
                dst_mac: TestData::DST_MAC,
                ethertype: Some(TestData::PROTO.get().into()),
                body: Vec::from(TestData::BODY),
            },
        }),
        raw: TestData::BUFFER.into(),
    };
    assert_eq!(received.into_inner(), vec![expected; RECEIVE_COUNT]);
    assert_eq!(counters.rx_frames.get(), u64::try_from(RECEIVE_COUNT).unwrap());
}
1870
1871 pub struct FakeSendMetadata;
1872 impl DeviceSocketSendTypes for AnyDevice {
1873 type Metadata = FakeSendMetadata;
1874 }
1875 impl<BC, D: FakeStrongDeviceId> SendableFrameMeta<FakeCoreCtx<D>, BC>
1876 for DeviceSocketMetadata<AnyDevice, D>
1877 {
1878 fn send_meta<S>(
1879 self,
1880 core_ctx: &mut FakeCoreCtx<D>,
1881 _bindings_ctx: &mut BC,
1882 frame: S,
1883 ) -> Result<(), SendFrameError<S>>
1884 where
1885 S: packet::Serializer,
1886 S::Buffer: packet::BufferMut,
1887 {
1888 let frame = match frame.serialize_vec_outer() {
1889 Err(e) => {
1890 let _: (packet::SerializeError<core::convert::Infallible>, _) = e;
1891 unreachable!()
1892 }
1893 Ok(frame) => frame.unwrap_a().as_ref().to_vec(),
1894 };
1895 core_ctx.state.sent_frames.push(frame);
1896 Ok(())
1897 }
1898 }
1899
// Every send through a bound socket is recorded by the fake core context and
// counted in the socket's tx counter.
#[test]
fn send_multiple_frames() {
    const DEVICE: MultipleDevicesId = MultipleDevicesId::A;
    const SEND_COUNT: usize = 10;
    const PAYLOAD: &[u8] = &[1, 2, 3, 4, 5];

    let mut ctx = FakeCtx::with_core_ctx(FakeCoreCtx::with_state(FakeSockets::new(
        MultipleDevicesId::all(),
    )));
    let socket = make_bound(
        &mut ctx,
        TargetDevice::SpecificDevice(DEVICE),
        Some(Protocol::All),
        ExternalSocketState::default(),
    );

    let mut api = ctx.device_socket_api();
    for _ in 0..SEND_COUNT {
        let metadata = DeviceSocketMetadata { device_id: DEVICE, metadata: FakeSendMetadata };
        let body = packet::Buf::new(PAYLOAD.to_vec(), ..);
        api.send_frame(&socket, metadata, body).expect("send failed");
    }

    let expected_frames = vec![PAYLOAD.to_vec(); SEND_COUNT];
    assert_eq!(ctx.core_ctx().state.sent_frames, expected_frames);
    assert_eq!(socket.counters().tx_frames.get(), u64::try_from(SEND_COUNT).unwrap());
}
1931}