1use core::convert::Infallible as Never;
9use core::fmt::Debug;
10use core::hash::Hash;
11use core::marker::PhantomData;
12use core::num::NonZeroU16;
13
14use derivative::Derivative;
15use net_types::ip::{GenericOverIp, Ip, IpAddress, IpVersionMarker, Ipv4, Ipv6};
16use net_types::{
17 AddrAndZone, MulticastAddress, ScopeableAddress, SpecifiedAddr, Witness, ZonedAddr,
18};
19
20use crate::data_structures::socketmap::{
21 Entry, IterShadows, OccupiedEntry as SocketMapOccupiedEntry, SocketMap, Tagged,
22};
23use crate::device::{
24 DeviceIdentifier, EitherDeviceId, StrongDeviceIdentifier, WeakDeviceIdentifier,
25};
26use crate::error::{ExistsError, NotFoundError, ZonedAddressError};
27use crate::ip::BroadcastIpExt;
28use crate::socket::address::{
29 AddrVecIter, ConnAddr, ConnIpAddr, ListenerAddr, ListenerIpAddr, SocketIpAddr,
30};
31
/// An IP version that is paired with a counterpart ("other") version, used to
/// describe dual-stack sockets.
///
/// The pairing is symmetric: the `OtherVersion` of `Self::OtherVersion` is
/// `Self`.
pub trait DualStackIpExt: Ip {
    /// The IP version that is not `Self` (IPv6 for IPv4 and IPv4 for IPv6).
    type OtherVersion: DualStackIpExt<OtherVersion = Self>;
}

impl DualStackIpExt for Ipv4 {
    type OtherVersion = Ipv6;
}

impl DualStackIpExt for Ipv6 {
    type OtherVersion = Ipv4;
}
46
/// A pair of values, one for the `I` stack and one for `I::OtherVersion`.
///
/// `T` is instantiated once per IP version through [`GenericOverIp`]. The `I`
/// instantiation is stored in `this_stack` and the `I::OtherVersion`
/// instantiation in `other_stack`.
pub struct DualStackTuple<I: DualStackIpExt, T: GenericOverIp<I> + GenericOverIp<I::OtherVersion>> {
    /// The value for the `I` stack.
    this_stack: <T as GenericOverIp<I>>::Type,
    /// The value for the `I::OtherVersion` stack.
    other_stack: <T as GenericOverIp<I::OtherVersion>>::Type,
    /// Records which IP version is "this" stack.
    _marker: IpVersionMarker<I>,
}
53
54impl<I: DualStackIpExt, T: GenericOverIp<I> + GenericOverIp<I::OtherVersion>> DualStackTuple<I, T> {
55 pub fn new(this_stack: T, other_stack: <T as GenericOverIp<I::OtherVersion>>::Type) -> Self
57 where
58 T: GenericOverIp<I, Type = T>,
59 {
60 Self { this_stack, other_stack, _marker: IpVersionMarker::new() }
61 }
62
63 pub fn into_inner(
65 self,
66 ) -> (<T as GenericOverIp<I>>::Type, <T as GenericOverIp<I::OtherVersion>>::Type) {
67 let Self { this_stack, other_stack, _marker } = self;
68 (this_stack, other_stack)
69 }
70
71 pub fn into_this_stack(self) -> <T as GenericOverIp<I>>::Type {
73 self.this_stack
74 }
75
76 pub fn this_stack(&self) -> &<T as GenericOverIp<I>>::Type {
78 &self.this_stack
79 }
80
81 pub fn into_other_stack(self) -> <T as GenericOverIp<I::OtherVersion>>::Type {
83 self.other_stack
84 }
85
86 pub fn other_stack(&self) -> &<T as GenericOverIp<I::OtherVersion>>::Type {
88 &self.other_stack
89 }
90
91 pub fn flip(self) -> DualStackTuple<I::OtherVersion, T> {
93 let Self { this_stack, other_stack, _marker } = self;
94 DualStackTuple {
95 this_stack: other_stack,
96 other_stack: this_stack,
97 _marker: IpVersionMarker::new(),
98 }
99 }
100
101 pub fn cast<X>(self) -> DualStackTuple<X, T>
110 where
111 X: DualStackIpExt,
112 T: GenericOverIp<X>
113 + GenericOverIp<X::OtherVersion>
114 + GenericOverIp<Ipv4>
115 + GenericOverIp<Ipv6>,
116 {
117 I::map_ip_in(
118 self,
119 |v4| X::map_ip_out(v4, |t| t, |t| t.flip()),
120 |v6| X::map_ip_out(v6, |t| t.flip(), |t| t),
121 )
122 }
123}
124
/// A [`DualStackTuple`] can be re-instantiated for any dual-stack IP version
/// `NewIp`, as long as `T` is generic over every IP version involved.
impl<
        I: DualStackIpExt,
        NewIp: DualStackIpExt,
        T: GenericOverIp<NewIp>
            + GenericOverIp<NewIp::OtherVersion>
            + GenericOverIp<I>
            + GenericOverIp<I::OtherVersion>,
    > GenericOverIp<NewIp> for DualStackTuple<I, T>
{
    type Type = DualStackTuple<NewIp, T>;
}
136
/// Socket-related constants available for every IP version.
pub trait SocketIpExt: Ip {
    /// This IP version's loopback address, as a [`SocketIpAddr`].
    const LOOPBACK_ADDRESS_AS_SOCKET_IP_ADDR: SocketIpAddr<Self::Addr> = unsafe {
        // SAFETY: `new_from_specified_unchecked` is `unsafe`, so the caller
        // must guarantee that the address is one the checked constructor
        // (`SocketIpAddr::new`) would accept. The
        // `loopback_addr_is_valid_socket_addr` test below checks this for each
        // IP version instantiated by `ip_test`.
        SocketIpAddr::new_from_specified_unchecked(Self::LOOPBACK_ADDRESS)
    };
}

/// Every IP version gets the default [`SocketIpExt`] items.
impl<I: Ip> SocketIpExt for I {}
148
#[cfg(test)]
mod socket_ip_ext_test {
    use super::*;
    use ip_test_macro::ip_test;

    /// The constant built with `new_from_specified_unchecked` must also be
    /// accepted by the checked constructor. This is what makes the `unsafe`
    /// in `SocketIpExt` sound.
    #[ip_test(I)]
    fn loopback_addr_is_valid_socket_addr<I: SocketIpExt>() {
        let loopback = I::LOOPBACK_ADDRESS_AS_SOCKET_IP_ADDR.addr();
        let _validated =
            SocketIpAddr::new(loopback).expect("loopback address should be a valid SocketIpAddr");
    }
}
164
/// A value that belongs either to this IP stack (`T`) or to the other IP
/// stack (`O`) of a dual-stack socket.
#[derive(Debug, PartialEq, Eq)]
pub enum EitherStack<T, O> {
    /// A value for this stack.
    ThisStack(T),
    /// A value for the other stack.
    OtherStack(O),
}
179
180impl<T, O> Clone for EitherStack<T, O>
181where
182 T: Clone,
183 O: Clone,
184{
185 #[cfg_attr(feature = "instrumented", track_caller)]
186 fn clone(&self) -> Self {
187 match self {
188 Self::ThisStack(t) => Self::ThisStack(t.clone()),
189 Self::OtherStack(t) => Self::OtherStack(t.clone()),
190 }
191 }
192}
193
/// A value whose type depends on whether the socket supports dual-stack
/// operation.
#[derive(Debug)]
#[allow(missing_docs)]
pub enum MaybeDualStack<DS, NDS> {
    /// The dual-stack-capable case.
    DualStack(DS),
    /// The not-dual-stack-capable case.
    NotDualStack(NDS),
}
217
/// [`MaybeDualStack`] is generic over IP whenever both of its payload types
/// are. Each payload is mapped independently.
impl<I: DualStackIpExt, DS: GenericOverIp<I>, NDS: GenericOverIp<I>> GenericOverIp<I>
    for MaybeDualStack<DS, NDS>
{
    type Type = MaybeDualStack<<DS as GenericOverIp<I>>::Type, <NDS as GenericOverIp<I>>::Type>;
}
225
/// Error returned when a socket's dual-stack-enabled setting cannot be
/// changed.
#[derive(Copy, Clone, Debug, Eq, GenericOverIp, PartialEq)]
#[generic_over_ip()]
pub enum SetDualStackEnabledError {
    /// The socket is already bound.
    SocketIsBound,
    /// The socket is not dual-stack capable.
    NotCapable,
}
236
/// Error indicating that an operation requires a dual-stack-capable socket
/// but the socket is not one.
#[derive(Copy, Clone, Debug, Eq, GenericOverIp, PartialEq)]
#[generic_over_ip()]
pub struct NotDualStackCapableError;
242
/// Which directions of a socket have been shut down. The `Default` value
/// (`false`, `false`) means neither.
#[derive(Copy, Clone, Debug, Default, Eq, PartialEq)]
pub struct Shutdown {
    /// Whether the send direction is shut down.
    pub send: bool,
    /// Whether the receive direction is shut down.
    pub receive: bool,
}
255
/// The direction(s) a shutdown request applies to.
#[derive(Copy, Clone, Debug, Eq, GenericOverIp, PartialEq)]
#[generic_over_ip()]
pub enum ShutdownType {
    /// Only the send direction.
    Send,
    /// Only the receive direction.
    Receive,
    /// Both directions.
    SendAndReceive,
}
267
268impl ShutdownType {
269 pub fn to_send_receive(&self) -> (bool, bool) {
271 match self {
272 Self::Send => (true, false),
273 Self::Receive => (false, true),
274 Self::SendAndReceive => (true, true),
275 }
276 }
277
278 pub fn from_send_receive(send: bool, receive: bool) -> Option<Self> {
280 match (send, receive) {
281 (true, false) => Some(Self::Send),
282 (false, true) => Some(Self::Receive),
283 (true, true) => Some(Self::SendAndReceive),
284 (false, false) => None,
285 }
286 }
287}
288
289pub trait SocketIpAddrExt<A: IpAddress>: Witness<A> + ScopeableAddress {
291 fn must_have_zone(&self) -> bool
297 where
298 Self: Copy,
299 {
300 self.try_into_null_zoned().is_some()
301 }
302
303 fn try_into_null_zoned(self) -> Option<AddrAndZone<Self, ()>> {
307 if self.get().is_loopback() {
308 return None;
309 }
310 AddrAndZone::new(self, ())
311 }
312}
313
314impl<A: IpAddress, W: Witness<A> + ScopeableAddress> SocketIpAddrExt<A> for W {}
315
/// Resolves a possibly-zoned address against an optional device.
pub trait SocketZonedAddrExt<W, A, D> {
    /// Splits `self` into its address and the device the socket should use.
    ///
    /// `device` is an optional device supplied by the caller (presumably the
    /// socket's currently bound device; confirm at call sites). The returned
    /// device is `None` if neither a zone nor `device` is present.
    ///
    /// # Errors
    ///
    /// In the `ZonedAddr` implementation, returns a [`ZonedAddressError`] in
    /// two cases: the zone and `device` disagree, or the address requires a
    /// zone but neither a zone nor a device was given.
    fn resolve_addr_with_device(
        self,
        device: Option<D::Weak>,
    ) -> Result<(W, Option<EitherDeviceId<D, D::Weak>>), ZonedAddressError>
    where
        D: StrongDeviceIdentifier;
}
332
333impl<W, A, D> SocketZonedAddrExt<W, A, D> for ZonedAddr<W, D>
334where
335 W: ScopeableAddress + AsRef<SpecifiedAddr<A>>,
336 A: IpAddress,
337{
338 fn resolve_addr_with_device(
339 self,
340 device: Option<D::Weak>,
341 ) -> Result<(W, Option<EitherDeviceId<D, D::Weak>>), ZonedAddressError>
342 where
343 D: StrongDeviceIdentifier,
344 {
345 let (addr, zone) = self.into_addr_zone();
346 let device = match (zone, device) {
347 (Some(zone), Some(device)) => {
348 if device != zone {
349 return Err(ZonedAddressError::DeviceZoneMismatch);
350 }
351 Some(EitherDeviceId::Strong(zone))
352 }
353 (Some(zone), None) => Some(EitherDeviceId::Strong(zone)),
354 (None, Some(device)) => Some(EitherDeviceId::Weak(device)),
355 (None, None) => {
356 if addr.as_ref().must_have_zone() {
357 return Err(ZonedAddressError::RequiredZoneNotProvided);
358 } else {
359 None
360 }
361 }
362 };
363 Ok((addr, device))
364 }
365}
366
/// A socket's addresses and current device, used by
/// [`SocketDeviceUpdate::check_update`] to decide whether the socket's device
/// may change.
pub struct SocketDeviceUpdate<'a, A: IpAddress, D: WeakDeviceIdentifier> {
    /// The socket's local address, if any.
    pub local_ip: Option<&'a SpecifiedAddr<A>>,
    /// The socket's remote address, if any.
    pub remote_ip: Option<&'a SpecifiedAddr<A>>,
    /// The device the socket currently uses, if any.
    pub old_device: Option<&'a D>,
}
380
381impl<'a, A: IpAddress, D: WeakDeviceIdentifier> SocketDeviceUpdate<'a, A, D> {
382 pub fn check_update<N>(
385 self,
386 new_device: Option<&N>,
387 ) -> Result<(), SocketDeviceUpdateNotAllowedError>
388 where
389 D: PartialEq<N>,
390 {
391 let Self { local_ip, remote_ip, old_device } = self;
392 let must_have_zone = local_ip.is_some_and(|a| a.must_have_zone())
393 || remote_ip.is_some_and(|a| a.must_have_zone());
394
395 if !must_have_zone {
396 return Ok(());
397 }
398
399 let old_device = old_device.unwrap_or_else(|| {
400 panic!("local_ip={:?} or remote_ip={:?} must have zone", local_ip, remote_ip)
401 });
402
403 if new_device.is_some_and(|new_device| old_device == new_device) {
404 Ok(())
405 } else {
406 Err(SocketDeviceUpdateNotAllowedError)
407 }
408 }
409}
410
/// Returned by [`SocketDeviceUpdate::check_update`] when the device change is
/// not allowed.
pub struct SocketDeviceUpdateNotAllowedError;
413
/// Specifies the identifier types used in socket-map addresses, alongside the
/// IP addresses.
pub trait SocketMapAddrSpec {
    /// The local identifier of an address. Must convert into a `NonZeroU16`.
    type LocalIdentifier: Copy + Clone + Debug + Send + Sync + Hash + Eq + Into<NonZeroU16>;
    /// The remote identifier of a connected address.
    type RemoteIdentifier: Copy + Clone + Debug + Send + Sync + Hash + Eq;
}
424
/// A summary of a listener address, passed to
/// [`SocketMapStateSpec::listener_tag`].
pub struct ListenerAddrInfo {
    /// Whether the listener address has a device.
    pub has_device: bool,
    /// Whether the listener address has a specific local IP address (as
    /// opposed to none).
    pub specified_addr: bool,
}
433
434impl<A: IpAddress, D: DeviceIdentifier, LI> ListenerAddr<ListenerIpAddr<A, LI>, D> {
435 pub(crate) fn info(&self) -> ListenerAddrInfo {
436 let Self { device, ip: ListenerIpAddr { addr, identifier: _ } } = self;
437 ListenerAddrInfo { has_device: device.is_some(), specified_addr: addr.is_some() }
438 }
439}
440
/// Describes the per-address state kept in a socket map and how that state is
/// tagged.
pub trait SocketMapStateSpec {
    /// The tag computed for each occupied address (see the `Tagged` impl for
    /// [`Bound`]).
    type AddrVecTag: Eq + Copy + Debug + 'static;

    /// Computes the tag for listener state `state` stored at an address
    /// summarized by `info`.
    fn listener_tag(info: ListenerAddrInfo, state: &Self::ListenerAddrState) -> Self::AddrVecTag;

    /// Computes the tag for connection state `state`. `has_device` is whether
    /// the connected address has a device.
    fn connected_tag(has_device: bool, state: &Self::ConnAddrState) -> Self::AddrVecTag;

    /// Identifies a listening socket.
    type ListenerId: Clone + Debug;
    /// Identifies a connected socket.
    type ConnId: Clone + Debug;

    /// The sharing options of a listening socket. They are checked when
    /// another socket is inserted at the same address.
    type ListenerSharingState: Clone + Debug;

    /// The sharing options of a connected socket.
    type ConnSharingState: Clone + Debug;

    /// The state stored at a listener address.
    type ListenerAddrState: SocketMapAddrStateSpec<Id = Self::ListenerId, SharingState = Self::ListenerSharingState>
        + Debug;

    /// The state stored at a connected address.
    type ConnAddrState: SocketMapAddrStateSpec<Id = Self::ConnId, SharingState = Self::ConnSharingState>
        + Debug;
}
476
/// A new socket's sharing state is incompatible with the state already stored
/// at an address.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub struct IncompatibleError;
481
/// A one-shot handle that inserts a single item into some container.
pub trait Inserter<T> {
    /// Inserts `item`, consuming the inserter.
    fn insert(self, item: T);
}
490
491impl<'a, T, E: Extend<T>> Inserter<T> for &'a mut E {
492 fn insert(self, item: T) {
493 self.extend([item])
494 }
495}
496
/// `Never` (`Infallible`) has no values, so this inserter can never be called.
/// It is presumably meant as the `Inserter` type of address states that never
/// hand one out.
impl<T> Inserter<T> for Never {
    fn insert(self, _: T) {
        // An empty match on an uninhabited type proves this is unreachable.
        match self {}
    }
}
502
/// The state stored at a single socket-map address. It may hold several
/// socket IDs that share the address.
pub trait SocketMapAddrStateSpec {
    /// The identifier of a socket stored in this state.
    type Id;

    /// The sharing options used to decide whether another socket may be added.
    type SharingState;

    /// A handle returned by [`SocketMapAddrStateSpec::try_get_inserter`] that
    /// adds one ID to the state.
    type Inserter<'a>: Inserter<Self::Id> + 'a
    where
        Self: 'a,
        Self::Id: 'a;

    /// Creates state holding the single socket `id` with sharing state
    /// `new_sharing_state`.
    fn new(new_sharing_state: &Self::SharingState, id: Self::Id) -> Self;

    /// Returns whether `id` is stored in this state.
    fn contains_id(&self, id: &Self::Id) -> bool;

    /// Returns an inserter for adding a socket with `new_sharing_state`.
    ///
    /// # Errors
    ///
    /// Returns [`IncompatibleError`] if such a socket cannot share this
    /// address.
    fn try_get_inserter<'a, 'b>(
        &'b mut self,
        new_sharing_state: &'a Self::SharingState,
    ) -> Result<Self::Inserter<'b>, IncompatibleError>;

    /// Checks, without modifying the state, whether a socket with
    /// `new_sharing_state` could be added.
    ///
    /// # Errors
    ///
    /// Returns [`IncompatibleError`] if it could not.
    fn could_insert(&self, new_sharing_state: &Self::SharingState)
        -> Result<(), IncompatibleError>;

    /// Removes `id`. Returns [`RemoveResult::IsLast`] when it was the last ID,
    /// in which case the caller removes the whole address entry.
    fn remove_by_id(&mut self, id: Self::Id) -> RemoveResult;
}
553
/// Address state that can change the sharing state of one of its sockets in
/// place.
pub trait SocketMapAddrStateUpdateSharingSpec: SocketMapAddrStateSpec {
    /// Changes the sharing state of socket `id` to `new_sharing_state`.
    ///
    /// # Errors
    ///
    /// Returns [`IncompatibleError`] if the new sharing state is not
    /// compatible with the other sockets at this address.
    fn try_update_sharing(
        &mut self,
        id: Self::Id,
        new_sharing_state: &Self::SharingState,
    ) -> Result<(), IncompatibleError>;
}
564
/// Decides whether a socket may be inserted at an address of type `Addr`,
/// taking the rest of the socket map into account.
pub trait SocketMapConflictPolicy<
    Addr,
    SharingState,
    I: Ip,
    D: DeviceIdentifier,
    A: SocketMapAddrSpec,
>: SocketMapStateSpec
{
    /// Checks whether inserting a socket with `new_sharing_state` at `addr`
    /// conflicts with entries already in `socketmap`.
    ///
    /// `Sockets::try_insert_with` calls this before looking at the entry for
    /// `addr` itself. `Sockets::could_insert` calls it only when `addr` has no
    /// entry.
    ///
    /// # Errors
    ///
    /// Returns the [`InsertError`] describing the conflict.
    fn check_insert_conflicts(
        new_sharing_state: &SharingState,
        addr: &Addr,
        socketmap: &SocketMap<AddrVec<I, D, A>, Bound<Self>>,
    ) -> Result<(), InsertError>;
}
588
/// Decides whether a socket that is already inserted may change its sharing
/// state.
pub trait SocketMapUpdateSharingPolicy<Addr, SharingState, I: Ip, D: DeviceIdentifier, A>:
    SocketMapConflictPolicy<Addr, SharingState, I, D, A>
where
    A: SocketMapAddrSpec,
{
    /// Checks whether the socket at `addr` may go from `old_sharing` to
    /// `new_sharing`, given the rest of `socketmap`.
    ///
    /// # Errors
    ///
    /// Returns [`UpdateSharingError`] if the update is not allowed.
    fn allows_sharing_update(
        socketmap: &SocketMap<AddrVec<I, D, A>, Bound<Self>>,
        addr: &Addr,
        old_sharing: &SharingState,
        new_sharing: &SharingState,
    ) -> Result<(), UpdateSharingError>;
}
605
/// The value stored at a socket-map address.
///
/// Listener state is expected at listener addresses and connection state at
/// connected addresses. Several accessors in this file panic on a mismatch.
#[derive(Derivative)]
#[derivative(Debug(bound = "S::ListenerAddrState: Debug, S::ConnAddrState: Debug"))]
#[allow(missing_docs)]
pub enum Bound<S: SocketMapStateSpec + ?Sized> {
    Listen(S::ListenerAddrState),
    Conn(S::ConnAddrState),
}
614
/// A key in the socket map: either a listener address or a connected address,
/// each with an optional device `D`.
///
/// The trait impls below only require bounds on `D`, not on `I` or `A`.
#[derive(Derivative)]
#[derivative(
    Debug(bound = "D: Debug"),
    Clone(bound = "D: Clone"),
    Eq(bound = "D: Eq"),
    PartialEq(bound = "D: PartialEq"),
    Hash(bound = "D: Hash")
)]
#[allow(missing_docs)]
pub enum AddrVec<I: Ip, D, A: SocketMapAddrSpec + ?Sized> {
    Listen(ListenerAddr<ListenerIpAddr<I::Addr, A::LocalIdentifier>, D>),
    Conn(ConnAddr<ConnIpAddr<I::Addr, A::LocalIdentifier, A::RemoteIdentifier>, D>),
}
642
643impl<I: Ip, D: DeviceIdentifier, A: SocketMapAddrSpec, S: SocketMapStateSpec + ?Sized>
644 Tagged<AddrVec<I, D, A>> for Bound<S>
645{
646 type Tag = S::AddrVecTag;
647 fn tag(&self, address: &AddrVec<I, D, A>) -> Self::Tag {
648 match (self, address) {
649 (Bound::Listen(l), AddrVec::Listen(addr)) => S::listener_tag(addr.info(), l),
650 (Bound::Conn(c), AddrVec::Conn(ConnAddr { device, ip: _ })) => {
651 S::connected_tag(device.is_some(), c)
652 }
653 (Bound::Listen(_), AddrVec::Conn(_)) => {
654 unreachable!("found listen state for conn addr")
655 }
656 (Bound::Conn(_), AddrVec::Listen(_)) => {
657 unreachable!("found conn state for listen addr")
658 }
659 }
660 }
661}
662
663impl<I: Ip, D: DeviceIdentifier, A: SocketMapAddrSpec> IterShadows for AddrVec<I, D, A> {
664 type IterShadows = AddrVecIter<I, D, A>;
665
666 fn iter_shadows(&self) -> Self::IterShadows {
667 let (socket_ip_addr, device) = match self.clone() {
668 AddrVec::Conn(ConnAddr { ip, device }) => (ip.into(), device),
669 AddrVec::Listen(ListenerAddr { ip, device }) => (ip.into(), device),
670 };
671 let mut iter = match device {
672 Some(device) => AddrVecIter::with_device(socket_ip_addr, device),
673 None => AddrVecIter::without_device(socket_ip_addr),
674 };
675 assert_eq!(iter.next().as_ref(), Some(self));
677 iter
678 }
679}
680
/// A coarse classification of a socket address.
///
/// `AnyListener` is a listener with no local IP, `SpecificListener` is a
/// listener with a local IP, and `Connected` is a connected address (see the
/// `From` impls below).
#[derive(Copy, Clone, Debug, Eq, Hash, PartialEq)]
#[allow(missing_docs)]
pub enum SocketAddrType {
    AnyListener,
    SpecificListener,
    Connected,
}
689
690impl<'a, A: IpAddress, LI> From<&'a ListenerIpAddr<A, LI>> for SocketAddrType {
691 fn from(ListenerIpAddr { addr, identifier: _ }: &'a ListenerIpAddr<A, LI>) -> Self {
692 match addr {
693 Some(_) => SocketAddrType::SpecificListener,
694 None => SocketAddrType::AnyListener,
695 }
696 }
697}
698
699impl<'a, A: IpAddress, LI, RI> From<&'a ConnIpAddr<A, LI, RI>> for SocketAddrType {
700 fn from(_: &'a ConnIpAddr<A, LI, RI>) -> Self {
701 SocketAddrType::Connected
702 }
703}
704
/// The outcome of [`SocketMapAddrStateSpec::remove_by_id`].
pub enum RemoveResult {
    /// The ID was removed and the address entry stays in the map.
    Success,
    /// The removed ID was the last one at this address. The caller removes the
    /// address entry (as `SocketStateEntry::remove` does).
    IsLast,
}
713
/// The ID of either a listener or a connected socket.
#[derive(Derivative)]
#[derivative(Clone(bound = "S::ListenerId: Clone, S::ConnId: Clone"), Debug(bound = ""))]
pub enum SocketId<S: SocketMapStateSpec> {
    /// A listening socket.
    Listener(S::ListenerId),
    /// A connected socket.
    Connection(S::ConnId),
}
720
/// A map from bound socket addresses ([`AddrVec`]) to the state stored there
/// ([`Bound`]).
///
/// `Default` is derived without bounds on the type parameters.
#[derive(Derivative)]
#[derivative(Default(bound = ""))]
pub struct BoundSocketMap<I: Ip, D: DeviceIdentifier, A: SocketMapAddrSpec, S: SocketMapStateSpec> {
    /// The underlying map, keyed by address.
    addr_to_state: SocketMap<AddrVec<I, D, A>, Bound<S>>,
}
739
740impl<I: Ip, D: DeviceIdentifier, A: SocketMapAddrSpec, S: SocketMapStateSpec>
741 BoundSocketMap<I, D, A, S>
742{
743 pub fn len(&self) -> usize {
745 self.addr_to_state.len()
746 }
747}
748
/// Marker type that selects listener sockets in [`Sockets`] and
/// [`ConvertSocketMapState`]. It has no values.
pub enum Listener {}
/// Marker type that selects connected sockets in [`Sockets`] and
/// [`ConvertSocketMapState`]. It has no values.
pub enum Connection {}

/// A view of a socket map restricted to one kind of socket.
///
/// `AddrToStateMap` is a shared or mutable reference to the map. `SocketType`
/// is [`Listener`] or [`Connection`].
pub struct Sockets<AddrToStateMap, SocketType>(AddrToStateMap, PhantomData<SocketType>);
756
757impl<
758 'a,
759 I: Ip,
760 D: DeviceIdentifier,
761 SocketType: ConvertSocketMapState<I, D, A, S>,
762 A: SocketMapAddrSpec,
763 S: SocketMapStateSpec,
764 > Sockets<&'a SocketMap<AddrVec<I, D, A>, Bound<S>>, SocketType>
765where
766 S: SocketMapConflictPolicy<SocketType::Addr, SocketType::SharingState, I, D, A>,
767{
768 pub fn get_by_addr(self, addr: &SocketType::Addr) -> Option<&'a SocketType::AddrState> {
770 let Self(addr_to_state, _marker) = self;
771 addr_to_state.get(&SocketType::to_addr_vec(addr)).map(|state| {
772 SocketType::from_bound_ref(state)
773 .unwrap_or_else(|| unreachable!("found {:?} for address {:?}", state, addr))
774 })
775 }
776
777 pub fn could_insert(
783 self,
784 addr: &SocketType::Addr,
785 sharing: &SocketType::SharingState,
786 ) -> Result<(), InsertError> {
787 let Self(addr_to_state, _) = self;
788 match self.get_by_addr(addr) {
789 Some(state) => {
790 state.could_insert(sharing).map_err(|IncompatibleError| InsertError::Exists)
791 }
792 None => S::check_insert_conflicts(&sharing, &addr, &addr_to_state),
793 }
794 }
795}
796
/// A handle to one socket in the socket map: the socket's ID together with
/// the occupied map entry for its address.
#[derive(Derivative)]
#[derivative(Debug(bound = ""))]
pub struct SocketStateEntry<
    'a,
    I: Ip,
    D: DeviceIdentifier,
    A: SocketMapAddrSpec,
    S: SocketMapStateSpec,
    SocketType,
> {
    /// The socket this entry refers to.
    id: SocketId<S>,
    /// The occupied map entry for the socket's address.
    addr_entry: SocketMapOccupiedEntry<'a, AddrVec<I, D, A>, Bound<S>>,
    /// Records whether this is a listener or a connection entry.
    _marker: PhantomData<SocketType>,
}
812
813impl<
814 'a,
815 I: Ip,
816 D: DeviceIdentifier,
817 SocketType: ConvertSocketMapState<I, D, A, S>,
818 A: SocketMapAddrSpec,
819 S: SocketMapStateSpec
820 + SocketMapConflictPolicy<SocketType::Addr, SocketType::SharingState, I, D, A>,
821 > Sockets<&'a mut SocketMap<AddrVec<I, D, A>, Bound<S>>, SocketType>
822where
823 SocketType::SharingState: Clone,
824 SocketType::Id: Clone,
825{
826 pub fn try_insert(
829 self,
830 socket_addr: SocketType::Addr,
831 tag_state: SocketType::SharingState,
832 id: SocketType::Id,
833 ) -> Result<SocketStateEntry<'a, I, D, A, S, SocketType>, (InsertError, SocketType::SharingState)>
834 {
835 self.try_insert_with(socket_addr, tag_state, |_addr, _sharing| (id, ()))
836 .map(|(entry, ())| entry)
837 }
838
839 pub fn try_insert_with<R>(
844 self,
845 socket_addr: SocketType::Addr,
846 tag_state: SocketType::SharingState,
847 make_id: impl FnOnce(SocketType::Addr, SocketType::SharingState) -> (SocketType::Id, R),
848 ) -> Result<
849 (SocketStateEntry<'a, I, D, A, S, SocketType>, R),
850 (InsertError, SocketType::SharingState),
851 > {
852 let Self(addr_to_state, _) = self;
853 match S::check_insert_conflicts(&tag_state, &socket_addr, &addr_to_state) {
854 Err(e) => return Err((e, tag_state)),
855 Ok(()) => (),
856 };
857
858 let addr = SocketType::to_addr_vec(&socket_addr);
859
860 match addr_to_state.entry(addr) {
861 Entry::Occupied(mut o) => {
862 let (id, ret) = o.map_mut(|bound| {
863 let bound = match SocketType::from_bound_mut(bound) {
864 Some(bound) => bound,
865 None => unreachable!("found {:?} for address {:?}", bound, socket_addr),
866 };
867 match <SocketType::AddrState as SocketMapAddrStateSpec>::try_get_inserter(
868 bound, &tag_state,
869 ) {
870 Ok(v) => {
871 let (id, ret) = make_id(socket_addr, tag_state);
872 v.insert(id.clone());
873 Ok((SocketType::to_socket_id(id), ret))
874 }
875 Err(IncompatibleError) => Err((InsertError::Exists, tag_state)),
876 }
877 })?;
878 Ok((SocketStateEntry { id, addr_entry: o, _marker: Default::default() }, ret))
879 }
880 Entry::Vacant(v) => {
881 let (id, ret) = make_id(socket_addr, tag_state.clone());
882 let addr_entry = v.insert(SocketType::to_bound(SocketType::AddrState::new(
883 &tag_state,
884 id.clone(),
885 )));
886 let id = SocketType::to_socket_id(id);
887 Ok((SocketStateEntry { id, addr_entry, _marker: Default::default() }, ret))
888 }
889 }
890 }
891
892 pub fn entry(
894 self,
895 id: &SocketType::Id,
896 addr: &SocketType::Addr,
897 ) -> Option<SocketStateEntry<'a, I, D, A, S, SocketType>> {
898 let Self(addr_to_state, _) = self;
899 let addr_entry = match addr_to_state.entry(SocketType::to_addr_vec(addr)) {
900 Entry::Vacant(_) => return None,
901 Entry::Occupied(o) => o,
902 };
903 let state = SocketType::from_bound_ref(addr_entry.get())?;
904
905 state.contains_id(id).then_some(SocketStateEntry {
906 id: SocketType::to_socket_id(id.clone()),
907 addr_entry,
908 _marker: PhantomData::default(),
909 })
910 }
911
912 pub fn remove(self, id: &SocketType::Id, addr: &SocketType::Addr) -> Result<(), NotFoundError> {
914 self.entry(id, addr)
915 .map(|entry| {
916 entry.remove();
917 })
918 .ok_or(NotFoundError)
919 }
920}
921
/// Returned when a socket's sharing state cannot be updated.
#[derive(Debug)]
pub struct UpdateSharingError;
926
927impl<
928 'a,
929 I: Ip,
930 D: DeviceIdentifier,
931 SocketType: ConvertSocketMapState<I, D, A, S>,
932 A: SocketMapAddrSpec,
933 S: SocketMapStateSpec,
934 > SocketStateEntry<'a, I, D, A, S, SocketType>
935where
936 SocketType::Id: Clone,
937{
938 pub fn get_addr(&self) -> &SocketType::Addr {
940 let Self { id: _, addr_entry, _marker } = self;
941 SocketType::from_addr_vec_ref(addr_entry.key())
942 }
943
944 pub fn id(&self) -> &SocketType::Id {
946 let Self { id, addr_entry: _, _marker } = self;
947 SocketType::from_socket_id_ref(id)
948 }
949
950 pub fn try_update_addr(self, new_addr: SocketType::Addr) -> Result<Self, (ExistsError, Self)> {
952 let Self { id, addr_entry, _marker } = self;
953
954 let new_addrvec = SocketType::to_addr_vec(&new_addr);
955 let old_addr = addr_entry.key().clone();
956 let (addr_state, addr_to_state) = addr_entry.remove_from_map();
957 let addr_to_state = match addr_to_state.entry(new_addrvec) {
958 Entry::Occupied(o) => o.into_map(),
959 Entry::Vacant(v) => {
960 if v.descendant_counts().len() != 0 {
961 v.into_map()
962 } else {
963 let new_addr_entry = v.insert(addr_state);
964 return Ok(SocketStateEntry { id, addr_entry: new_addr_entry, _marker });
965 }
966 }
967 };
968 let to_restore = addr_state;
969 let addr_entry = match addr_to_state.entry(old_addr) {
971 Entry::Occupied(_) => unreachable!("just-removed-from entry is occupied"),
972 Entry::Vacant(v) => v.insert(to_restore),
973 };
974 return Err((ExistsError, SocketStateEntry { id, addr_entry, _marker }));
975 }
976
977 pub fn remove(self) {
979 let Self { id, mut addr_entry, _marker } = self;
980 let addr = addr_entry.key().clone();
981 match addr_entry.map_mut(|value| {
982 let value = match SocketType::from_bound_mut(value) {
983 Some(value) => value,
984 None => unreachable!("found {:?} for address {:?}", value, addr),
985 };
986 value.remove_by_id(SocketType::from_socket_id_ref(&id).clone())
987 }) {
988 RemoveResult::Success => (),
989 RemoveResult::IsLast => {
990 let _: Bound<S> = addr_entry.remove();
991 }
992 }
993 }
994
995 pub fn try_update_sharing(
997 &mut self,
998 old_sharing_state: &SocketType::SharingState,
999 new_sharing_state: SocketType::SharingState,
1000 ) -> Result<(), UpdateSharingError>
1001 where
1002 SocketType::AddrState: SocketMapAddrStateUpdateSharingSpec,
1003 S: SocketMapUpdateSharingPolicy<SocketType::Addr, SocketType::SharingState, I, D, A>,
1004 {
1005 let Self { id, addr_entry, _marker } = self;
1006 let addr = SocketType::from_addr_vec_ref(addr_entry.key());
1007
1008 S::allows_sharing_update(
1009 addr_entry.get_map(),
1010 addr,
1011 old_sharing_state,
1012 &new_sharing_state,
1013 )?;
1014
1015 addr_entry
1016 .map_mut(|value| {
1017 let value = match SocketType::from_bound_mut(value) {
1018 Some(value) => value,
1019 None => unreachable!("found invalid state {:?}", value),
1023 };
1024
1025 value.try_update_sharing(
1026 SocketType::from_socket_id_ref(id).clone(),
1027 &new_sharing_state,
1028 )
1029 })
1030 .map_err(|IncompatibleError| UpdateSharingError)
1031 }
1032}
1033
1034impl<I: Ip, D: DeviceIdentifier, A: SocketMapAddrSpec, S> BoundSocketMap<I, D, A, S>
1035where
1036 AddrVec<I, D, A>: IterShadows,
1037 S: SocketMapStateSpec,
1038{
1039 pub fn listeners(&self) -> Sockets<&SocketMap<AddrVec<I, D, A>, Bound<S>>, Listener>
1041 where
1042 S: SocketMapConflictPolicy<
1043 ListenerAddr<ListenerIpAddr<I::Addr, A::LocalIdentifier>, D>,
1044 <S as SocketMapStateSpec>::ListenerSharingState,
1045 I,
1046 D,
1047 A,
1048 >,
1049 S::ListenerAddrState:
1050 SocketMapAddrStateSpec<Id = S::ListenerId, SharingState = S::ListenerSharingState>,
1051 {
1052 let Self { addr_to_state } = self;
1053 Sockets(addr_to_state, Default::default())
1054 }
1055
1056 pub fn listeners_mut(&mut self) -> Sockets<&mut SocketMap<AddrVec<I, D, A>, Bound<S>>, Listener>
1058 where
1059 S: SocketMapConflictPolicy<
1060 ListenerAddr<ListenerIpAddr<I::Addr, A::LocalIdentifier>, D>,
1061 <S as SocketMapStateSpec>::ListenerSharingState,
1062 I,
1063 D,
1064 A,
1065 >,
1066 S::ListenerAddrState:
1067 SocketMapAddrStateSpec<Id = S::ListenerId, SharingState = S::ListenerSharingState>,
1068 {
1069 let Self { addr_to_state } = self;
1070 Sockets(addr_to_state, Default::default())
1071 }
1072
1073 pub fn conns(&self) -> Sockets<&SocketMap<AddrVec<I, D, A>, Bound<S>>, Connection>
1075 where
1076 S: SocketMapConflictPolicy<
1077 ConnAddr<ConnIpAddr<I::Addr, A::LocalIdentifier, A::RemoteIdentifier>, D>,
1078 <S as SocketMapStateSpec>::ConnSharingState,
1079 I,
1080 D,
1081 A,
1082 >,
1083 S::ConnAddrState:
1084 SocketMapAddrStateSpec<Id = S::ConnId, SharingState = S::ConnSharingState>,
1085 {
1086 let Self { addr_to_state } = self;
1087 Sockets(addr_to_state, Default::default())
1088 }
1089
1090 pub fn conns_mut(&mut self) -> Sockets<&mut SocketMap<AddrVec<I, D, A>, Bound<S>>, Connection>
1092 where
1093 S: SocketMapConflictPolicy<
1094 ConnAddr<ConnIpAddr<I::Addr, A::LocalIdentifier, A::RemoteIdentifier>, D>,
1095 <S as SocketMapStateSpec>::ConnSharingState,
1096 I,
1097 D,
1098 A,
1099 >,
1100 S::ConnAddrState:
1101 SocketMapAddrStateSpec<Id = S::ConnId, SharingState = S::ConnSharingState>,
1102 {
1103 let Self { addr_to_state } = self;
1104 Sockets(addr_to_state, Default::default())
1105 }
1106
1107 #[cfg(test)]
1108 pub(crate) fn iter_addrs(&self) -> impl Iterator<Item = &AddrVec<I, D, A>> {
1109 let Self { addr_to_state } = self;
1110 addr_to_state.iter().map(|(a, _v): (_, &Bound<S>)| a)
1111 }
1112
1113 pub fn get_shadower_counts(&self, addr: &AddrVec<I, D, A>) -> usize {
1115 let Self { addr_to_state } = self;
1116 addr_to_state.descendant_counts(&addr).map(|(_sharing, size)| size.get()).sum()
1117 }
1118}
1119
/// The sockets found by `BoundSocketMap::iter_receivers`.
pub enum FoundSockets<A, It> {
    /// The single matching entry (used for unicast destinations).
    Single(A),
    /// A lazy iterator over all matching entries (used for broadcast and
    /// multicast destinations). It may yield nothing.
    Multicast(It),
}
1128
/// A matching socket-map entry: a reference to the stored state together with
/// the address it is stored under.
#[allow(missing_docs)]
#[derive(Debug)]
pub enum AddrEntry<'a, I: Ip, D, A: SocketMapAddrSpec, S: SocketMapStateSpec> {
    Listen(&'a S::ListenerAddrState, ListenerAddr<ListenerIpAddr<I::Addr, A::LocalIdentifier>, D>),
    Conn(
        &'a S::ConnAddrState,
        ConnAddr<ConnIpAddr<I::Addr, A::LocalIdentifier, A::RemoteIdentifier>, D>,
    ),
}
1139
1140impl<I, D, A, S> BoundSocketMap<I, D, A, S>
1141where
1142 I: BroadcastIpExt<Addr: MulticastAddress>,
1143 D: DeviceIdentifier,
1144 A: SocketMapAddrSpec,
1145 S: SocketMapStateSpec
1146 + SocketMapConflictPolicy<
1147 ListenerAddr<ListenerIpAddr<I::Addr, A::LocalIdentifier>, D>,
1148 <S as SocketMapStateSpec>::ListenerSharingState,
1149 I,
1150 D,
1151 A,
1152 > + SocketMapConflictPolicy<
1153 ConnAddr<ConnIpAddr<I::Addr, A::LocalIdentifier, A::RemoteIdentifier>, D>,
1154 <S as SocketMapStateSpec>::ConnSharingState,
1155 I,
1156 D,
1157 A,
1158 >,
1159{
1160 pub fn iter_receivers(
1166 &self,
1167 (src_ip, src_port): (Option<SocketIpAddr<I::Addr>>, Option<A::RemoteIdentifier>),
1168 (dst_ip, dst_port): (SocketIpAddr<I::Addr>, A::LocalIdentifier),
1169 device: D,
1170 broadcast: Option<I::BroadcastMarker>,
1171 ) -> Option<
1172 FoundSockets<
1173 AddrEntry<'_, I, D, A, S>,
1174 impl Iterator<Item = AddrEntry<'_, I, D, A, S>> + '_,
1175 >,
1176 > {
1177 let mut matching_entries = AddrVecIter::with_device(
1178 match (src_ip, src_port) {
1179 (Some(specified_src_ip), Some(src_port)) => {
1180 ConnIpAddr { local: (dst_ip, dst_port), remote: (specified_src_ip, src_port) }
1181 .into()
1182 }
1183 _ => ListenerIpAddr { addr: Some(dst_ip), identifier: dst_port }.into(),
1184 },
1185 device,
1186 )
1187 .filter_map(move |addr: AddrVec<I, D, A>| match addr {
1188 AddrVec::Listen(l) => {
1189 self.listeners().get_by_addr(&l).map(|state| AddrEntry::Listen(state, l))
1190 }
1191 AddrVec::Conn(c) => self.conns().get_by_addr(&c).map(|state| AddrEntry::Conn(state, c)),
1192 });
1193
1194 if broadcast.is_some() || dst_ip.addr().is_multicast() {
1195 Some(FoundSockets::Multicast(matching_entries))
1196 } else {
1197 let single_entry: Option<_> = matching_entries.next();
1198 single_entry.map(FoundSockets::Single)
1199 }
1200 }
1201}
1202
/// Why a socket could not be inserted into the socket map.
#[derive(Debug, Eq, PartialEq)]
pub enum InsertError {
    /// A shadow address of the new entry already exists (meaning inferred
    /// from the name; the producer is `check_insert_conflicts`).
    ShadowAddrExists,
    /// The address is occupied by state that is incompatible with the new
    /// socket's sharing state.
    Exists,
    /// A shadower of the new entry already exists (meaning inferred from the
    /// name; the producer is `check_insert_conflicts`).
    ShadowerExists,
    /// An indirect conflict was detected by `check_insert_conflicts`.
    IndirectConflict,
}
1215
/// Converts between the generic socket-map types ([`AddrVec`], [`Bound`],
/// [`SocketId`]) and the types for one kind of socket.
///
/// Implemented in this file by [`Listener`] and [`Connection`].
pub trait ConvertSocketMapState<I: Ip, D, A: SocketMapAddrSpec, S: SocketMapStateSpec> {
    /// The socket ID type for this kind of socket.
    type Id;
    /// The sharing state type for this kind of socket.
    type SharingState;
    /// The address type for this kind of socket.
    type Addr: Debug;
    /// The per-address state type for this kind of socket.
    type AddrState: SocketMapAddrStateSpec<Id = Self::Id, SharingState = Self::SharingState>;

    /// Wraps an address into the map key type.
    fn to_addr_vec(addr: &Self::Addr) -> AddrVec<I, D, A>;
    /// Unwraps a map key. Both implementations here panic on the other kind.
    fn from_addr_vec_ref(addr: &AddrVec<I, D, A>) -> &Self::Addr;
    /// Unwraps stored state, or returns `None` if it is the other kind.
    fn from_bound_ref(bound: &Bound<S>) -> Option<&Self::AddrState>;
    /// Mutable variant of [`ConvertSocketMapState::from_bound_ref`].
    fn from_bound_mut(bound: &mut Bound<S>) -> Option<&mut Self::AddrState>;
    /// Wraps per-address state into the stored value type.
    fn to_bound(state: Self::AddrState) -> Bound<S>;
    /// Wraps an ID into [`SocketId`].
    fn to_socket_id(id: Self::Id) -> SocketId<S>;
    /// Unwraps a [`SocketId`]. Both implementations here panic on the other
    /// kind.
    fn from_socket_id_ref(id: &SocketId<S>) -> &Self::Id;
}
1232
1233impl<I: Ip, D: DeviceIdentifier, A: SocketMapAddrSpec, S: SocketMapStateSpec>
1234 ConvertSocketMapState<I, D, A, S> for Listener
1235{
1236 type Id = S::ListenerId;
1237 type SharingState = S::ListenerSharingState;
1238 type Addr = ListenerAddr<ListenerIpAddr<I::Addr, A::LocalIdentifier>, D>;
1239 type AddrState = S::ListenerAddrState;
1240 fn to_addr_vec(addr: &Self::Addr) -> AddrVec<I, D, A> {
1241 AddrVec::Listen(addr.clone())
1242 }
1243
1244 fn from_addr_vec_ref(addr: &AddrVec<I, D, A>) -> &Self::Addr {
1245 match addr {
1246 AddrVec::Listen(l) => l,
1247 AddrVec::Conn(c) => unreachable!("conn addr for listener: {c:?}"),
1248 }
1249 }
1250
1251 fn from_bound_ref(bound: &Bound<S>) -> Option<&S::ListenerAddrState> {
1252 match bound {
1253 Bound::Listen(l) => Some(l),
1254 Bound::Conn(_c) => None,
1255 }
1256 }
1257
1258 fn from_bound_mut(bound: &mut Bound<S>) -> Option<&mut S::ListenerAddrState> {
1259 match bound {
1260 Bound::Listen(l) => Some(l),
1261 Bound::Conn(_c) => None,
1262 }
1263 }
1264
1265 fn to_bound(state: S::ListenerAddrState) -> Bound<S> {
1266 Bound::Listen(state)
1267 }
1268 fn from_socket_id_ref(id: &SocketId<S>) -> &Self::Id {
1269 match id {
1270 SocketId::Listener(id) => id,
1271 SocketId::Connection(_) => unreachable!("connection ID for listener"),
1272 }
1273 }
1274 fn to_socket_id(id: Self::Id) -> SocketId<S> {
1275 SocketId::Listener(id)
1276 }
1277}
1278
1279impl<I: Ip, D: DeviceIdentifier, A: SocketMapAddrSpec, S: SocketMapStateSpec>
1280 ConvertSocketMapState<I, D, A, S> for Connection
1281{
1282 type Id = S::ConnId;
1283 type SharingState = S::ConnSharingState;
1284 type Addr = ConnAddr<ConnIpAddr<I::Addr, A::LocalIdentifier, A::RemoteIdentifier>, D>;
1285 type AddrState = S::ConnAddrState;
1286 fn to_addr_vec(addr: &Self::Addr) -> AddrVec<I, D, A> {
1287 AddrVec::Conn(addr.clone())
1288 }
1289
1290 fn from_addr_vec_ref(addr: &AddrVec<I, D, A>) -> &Self::Addr {
1291 match addr {
1292 AddrVec::Conn(c) => c,
1293 AddrVec::Listen(l) => unreachable!("listener addr for conn: {l:?}"),
1294 }
1295 }
1296
1297 fn from_bound_ref(bound: &Bound<S>) -> Option<&S::ConnAddrState> {
1298 match bound {
1299 Bound::Listen(_l) => None,
1300 Bound::Conn(c) => Some(c),
1301 }
1302 }
1303
1304 fn from_bound_mut(bound: &mut Bound<S>) -> Option<&mut S::ConnAddrState> {
1305 match bound {
1306 Bound::Listen(_l) => None,
1307 Bound::Conn(c) => Some(c),
1308 }
1309 }
1310
1311 fn to_bound(state: S::ConnAddrState) -> Bound<S> {
1312 Bound::Conn(state)
1313 }
1314
1315 fn from_socket_id_ref(id: &SocketId<S>) -> &Self::Id {
1316 match id {
1317 SocketId::Connection(id) => id,
1318 SocketId::Listener(_) => unreachable!("listener ID for connection"),
1319 }
1320 }
1321 fn to_socket_id(id: Self::Id) -> SocketId<S> {
1322 SocketId::Connection(id)
1323 }
1324}
1325
#[cfg(test)]
mod tests {
    use alloc::collections::HashSet;
    use alloc::vec;
    use alloc::vec::Vec;

    use assert_matches::assert_matches;
    use net_declare::{net_ip_v4, net_ip_v6};
    use net_types::ip::{Ipv4Addr, Ipv6, Ipv6Addr};
    use test_case::test_case;

    use crate::device::testutil::{FakeDeviceId, FakeWeakDeviceId};
    use crate::testutil::set_logger_for_test;

    use super::*;

    // Every IPv4 address exercised here (global unicast, loopback, multicast)
    // reports that it does not require a zone.
    #[test_case(net_ip_v4!("8.8.8.8"))]
    #[test_case(net_ip_v4!("127.0.0.1"))]
    #[test_case(net_ip_v4!("127.0.8.9"))]
    #[test_case(net_ip_v4!("224.1.2.3"))]
    fn must_never_have_zone_ipv4(addr: Ipv4Addr) {
        let addr = SpecifiedAddr::new(addr).unwrap();
        assert!(!addr.must_have_zone());
    }

    // Only the link-local cases (`fe80::/10` unicast and `ff02::/16`
    // multicast, including all-nodes) are expected to require a zone; global
    // unicast, loopback, and `ff03::/16` multicast are not.
    #[test_case(net_ip_v6!("1::2:3"), false)]
    #[test_case(net_ip_v6!("::1"), false; "localhost")]
    #[test_case(net_ip_v6!("1::"), false)]
    #[test_case(net_ip_v6!("ff03:1:2:3::1"), false)]
    #[test_case(net_ip_v6!("ff02:1:2:3::1"), true)]
    #[test_case(Ipv6::ALL_NODES_LINK_LOCAL_MULTICAST_ADDRESS.get(), true)]
    #[test_case(net_ip_v6!("fe80::1"), true)]
    fn must_have_zone_ipv6(addr: Ipv6Addr, must_have: bool) {
        let addr = SpecifiedAddr::new(addr).unwrap();
        assert_eq!(addr.must_have_zone(), must_have);
    }

    // `try_into_null_zoned` yields `None` for loopback, and for a link-local
    // multicast address yields an `AddrAndZone` with a `()` placeholder zone
    // that `map_zone` can replace with a real one.
    #[test]
    fn try_into_null_zoned_ipv6() {
        assert_eq!(Ipv6::LOOPBACK_ADDRESS.try_into_null_zoned(), None);
        let zoned = Ipv6::ALL_NODES_LINK_LOCAL_MULTICAST_ADDRESS.into_specified();
        const ZONE: u32 = 5;
        assert_eq!(
            zoned.try_into_null_zoned().map(|a| a.map_zone(|()| ZONE)),
            Some(AddrAndZone::new(zoned, ZONE).unwrap())
        );
    }

    /// Socket-map state spec used by these tests. It is uninhabited and only
    /// used at the type level.
    enum FakeSpec {}

    /// Fake listener socket ID.
    #[derive(Copy, Clone, Eq, PartialEq, Debug, Hash)]
    struct Listener(usize);

    /// Per-address state: a sharing tag plus every socket ID bound at the
    /// address.
    #[derive(PartialEq, Eq, Debug)]
    struct Multiple<T>(char, Vec<T>);

    impl<T> Multiple<T> {
        /// Returns the sharing tag of this address state.
        fn tag(&self) -> char {
            let Multiple(c, _) = self;
            *c
        }
    }

    /// Fake connected socket ID.
    #[derive(Copy, Clone, Eq, PartialEq, Debug, Hash)]
    struct Conn(usize);

    /// Address spec with `NonZeroU16` local identifiers and no remote
    /// identifier. Uninhabited; only used at the type level.
    enum FakeAddrSpec {}

    impl SocketMapAddrSpec for FakeAddrSpec {
        type LocalIdentifier = NonZeroU16;
        type RemoteIdentifier = ();
    }

    /// Both listeners and connections use a `char` sharing tag and store their
    /// IDs in a [`Multiple`]; the address-vector tag is the sharing tag.
    impl SocketMapStateSpec for FakeSpec {
        type AddrVecTag = char;

        type ListenerId = Listener;
        type ConnId = Conn;

        type ListenerSharingState = char;
        type ConnSharingState = char;

        type ListenerAddrState = Multiple<Listener>;
        type ConnAddrState = Multiple<Conn>;

        fn listener_tag(_: ListenerAddrInfo, state: &Self::ListenerAddrState) -> Self::AddrVecTag {
            state.tag()
        }

        fn connected_tag(_has_device: bool, state: &Self::ConnAddrState) -> Self::AddrVecTag {
            state.tag()
        }
    }

    /// IPv4 bound-socket map instantiated with the fake types above.
    type FakeBoundSocketMap =
        BoundSocketMap<Ipv4, FakeWeakDeviceId<FakeDeviceId>, FakeAddrSpec, FakeSpec>;

    /// Hands out sequential numeric IDs for fake sockets, starting at 0.
    #[derive(Default)]
    struct FakeSocketIdGen {
        next_id: usize,
    }

    impl FakeSocketIdGen {
        /// Returns the current ID and advances the counter by one.
        fn next(&mut self) -> usize {
            let next_next_id = self.next_id + 1;
            core::mem::replace(&mut self.next_id, next_next_id)
        }
    }

    /// Sockets may share an address only when their sharing tags are equal.
    impl<I: Eq> SocketMapAddrStateSpec for Multiple<I> {
        type Id = I;
        type SharingState = char;
        type Inserter<'a>
            = &'a mut Vec<I>
        where
            I: 'a;

        /// Creates state tagged `new_sharing_state` holding only `id`.
        fn new(new_sharing_state: &char, id: I) -> Self {
            Self(*new_sharing_state, vec![id])
        }

        fn contains_id(&self, id: &Self::Id) -> bool {
            self.1.contains(id)
        }

        /// Hands out the ID vector as the inserter if the tags match.
        fn try_get_inserter<'a, 'b>(
            &'b mut self,
            new_state: &'a char,
        ) -> Result<Self::Inserter<'b>, IncompatibleError> {
            let Self(c, v) = self;
            (new_state == c).then_some(v).ok_or(IncompatibleError)
        }

        /// Same tag check as `try_get_inserter`, without borrowing mutably.
        fn could_insert(
            &self,
            new_sharing_state: &Self::SharingState,
        ) -> Result<(), IncompatibleError> {
            let Self(c, _) = self;
            (new_sharing_state == c).then_some(()).ok_or(IncompatibleError)
        }

        /// Removes `id`, reporting `IsLast` when no IDs remain.
        ///
        /// # Panics
        ///
        /// Panics if `id` is not present.
        fn remove_by_id(&mut self, id: I) -> RemoveResult {
            let Self(_, v) = self;
            let index = v.iter().position(|i| i == &id).expect("did not find id");
            // Order of the remaining IDs is not significant, so O(1) removal is fine.
            let _: I = v.swap_remove(index);
            if v.is_empty() {
                RemoveResult::IsLast
            } else {
                RemoveResult::Success
            }
        }
    }

    /// Conflict policy for the fake spec. An insertion at `addr` is rejected
    /// when:
    /// - any address that `addr` shadows is occupied (`ShadowAddrExists`);
    /// - `addr` itself is occupied under a different sharing tag (`Exists`);
    /// - the map reports any descendant entries of `addr` (`ShadowerExists`).
    impl<A: Into<AddrVec<Ipv4, FakeWeakDeviceId<FakeDeviceId>, FakeAddrSpec>> + Clone>
        SocketMapConflictPolicy<A, char, Ipv4, FakeWeakDeviceId<FakeDeviceId>, FakeAddrSpec>
        for FakeSpec
    {
        fn check_insert_conflicts(
            new_state: &char,
            addr: &A,
            socketmap: &SocketMap<
                AddrVec<Ipv4, FakeWeakDeviceId<FakeDeviceId>, FakeAddrSpec>,
                Bound<FakeSpec>,
            >,
        ) -> Result<(), InsertError> {
            let dest = addr.clone().into();
            if dest.iter_shadows().any(|a| socketmap.get(&a).is_some()) {
                return Err(InsertError::ShadowAddrExists);
            }
            match socketmap.get(&dest) {
                Some(Bound::Listen(Multiple(c, _))) | Some(Bound::Conn(Multiple(c, _))) => {
                    // Joining an occupied address requires the same sharing tag.
                    if c != new_state {
                        return Err(InsertError::Exists);
                    }
                }
                None => (),
            }
            if socketmap.descendant_counts(&dest).len() != 0 {
                Err(InsertError::ShadowerExists)
            } else {
                Ok(())
            }
        }
    }

    /// Changing the tag is a no-op when it is unchanged; otherwise it is only
    /// allowed when `id` is the sole socket at the address.
    impl<I: Eq> SocketMapAddrStateUpdateSharingSpec for Multiple<I> {
        fn try_update_sharing(
            &mut self,
            id: Self::Id,
            new_sharing_state: &Self::SharingState,
        ) -> Result<(), IncompatibleError> {
            let Self(sharing, v) = self;
            if new_sharing_state == sharing {
                return Ok(());
            }

            // Only a socket that is alone at this address may change its tag.
            if v.len() != 1 {
                return Err(IncompatibleError);
            }
            // The sole occupant must be the socket requesting the update.
            assert!(v.contains(&id));
            *sharing = *new_sharing_state;
            Ok(())
        }
    }

    /// The fake map-level policy permits every sharing update unconditionally.
    impl<A: Into<AddrVec<Ipv4, FakeWeakDeviceId<FakeDeviceId>, FakeAddrSpec>> + Clone>
        SocketMapUpdateSharingPolicy<A, char, Ipv4, FakeWeakDeviceId<FakeDeviceId>, FakeAddrSpec>
        for FakeSpec
    {
        fn allows_sharing_update(
            _socketmap: &SocketMap<
                AddrVec<Ipv4, FakeWeakDeviceId<FakeDeviceId>, FakeAddrSpec>,
                Bound<Self>,
            >,
            _addr: &A,
            _old_sharing: &char,
            _new_sharing_state: &char,
        ) -> Result<(), UpdateSharingError> {
            Ok(())
        }
    }

    /// Listener address `1.2.3.4` with identifier 1 and no bound device.
    const LISTENER_ADDR: ListenerAddr<
        ListenerIpAddr<Ipv4Addr, NonZeroU16>,
        FakeWeakDeviceId<FakeDeviceId>,
    > = ListenerAddr {
        ip: ListenerIpAddr {
            // SAFETY: 1.2.3.4 is a specified unicast IPv4 address.
            // NOTE(review): confirm this meets `new_unchecked`'s documented contract.
            addr: Some(unsafe { SocketIpAddr::new_unchecked(net_ip_v4!("1.2.3.4")) }),
            identifier: NonZeroU16::new(1).unwrap(),
        },
        device: None,
    };

    /// Connection from `5.6.7.8` (identifier 1) to `8.7.6.5` with no bound
    /// device.
    const CONN_ADDR: ConnAddr<
        ConnIpAddr<Ipv4Addr, NonZeroU16, ()>,
        FakeWeakDeviceId<FakeDeviceId>,
    > = ConnAddr {
        ip: ConnIpAddr {
            // SAFETY: 5.6.7.8 and 8.7.6.5 are specified unicast IPv4 addresses.
            // NOTE(review): confirm this meets `new_unchecked`'s documented contract.
            local: (
                unsafe { SocketIpAddr::new_unchecked(net_ip_v4!("5.6.7.8")) },
                NonZeroU16::new(1).unwrap(),
            ),
            remote: (unsafe { SocketIpAddr::new_unchecked(net_ip_v4!("8.7.6.5")) }, ()),
        },
        device: None,
    };

    // An inserted listener is retrievable by address; removing it empties the
    // address.
    #[test]
    fn bound_insert_get_remove_listener() {
        set_logger_for_test();
        let mut bound = FakeBoundSocketMap::default();
        let mut fake_id_gen = FakeSocketIdGen::default();

        let addr = LISTENER_ADDR;

        let id = {
            let entry =
                bound.listeners_mut().try_insert(addr, 'v', Listener(fake_id_gen.next())).unwrap();
            assert_eq!(entry.get_addr(), &addr);
            *entry.id()
        };

        assert_eq!(bound.listeners().get_by_addr(&addr), Some(&Multiple('v', vec![id])));

        assert_eq!(bound.listeners_mut().remove(&id, &addr), Ok(()));
        assert_eq!(bound.listeners().get_by_addr(&addr), None);
    }

    // An inserted connection is retrievable by address; removing it empties
    // the address.
    #[test]
    fn bound_insert_get_remove_conn() {
        set_logger_for_test();
        let mut bound = FakeBoundSocketMap::default();
        let mut fake_id_gen = FakeSocketIdGen::default();

        let addr = CONN_ADDR;

        let id = {
            let entry = bound.conns_mut().try_insert(addr, 'v', Conn(fake_id_gen.next())).unwrap();
            assert_eq!(entry.get_addr(), &addr);
            *entry.id()
        };

        assert_eq!(bound.conns().get_by_addr(&addr), Some(&Multiple('v', vec![id])));

        assert_eq!(bound.conns_mut().remove(&id, &addr), Ok(()));
        assert_eq!(bound.conns().get_by_addr(&addr), None);
    }

    // `iter_addrs` yields exactly the set of occupied listener and connection
    // addresses.
    #[test]
    fn bound_iter_addrs() {
        set_logger_for_test();
        let mut bound = FakeBoundSocketMap::default();
        let mut fake_id_gen = FakeSocketIdGen::default();

        let listener_addrs = [
            (Some(net_ip_v4!("1.1.1.1")), 1),
            (Some(net_ip_v4!("2.2.2.2")), 2),
            (Some(net_ip_v4!("1.1.1.1")), 3),
            (None, 4),
        ]
        .map(|(ip, identifier)| ListenerAddr {
            device: None,
            ip: ListenerIpAddr {
                addr: ip.map(|x| SocketIpAddr::new(x).unwrap()),
                identifier: NonZeroU16::new(identifier).unwrap(),
            },
        });
        let conn_addrs = [
            (net_ip_v4!("3.3.3.3"), 3, net_ip_v4!("4.4.4.4")),
            (net_ip_v4!("4.4.4.4"), 3, net_ip_v4!("3.3.3.3")),
        ]
        .map(|(local_ip, local_identifier, remote_ip)| ConnAddr {
            ip: ConnIpAddr {
                local: (
                    SocketIpAddr::new(local_ip).unwrap(),
                    NonZeroU16::new(local_identifier).unwrap(),
                ),
                remote: (SocketIpAddr::new(remote_ip).unwrap(), ()),
            },
            device: None,
        });

        for addr in listener_addrs.iter().cloned() {
            let _entry =
                bound.listeners_mut().try_insert(addr, 'a', Listener(fake_id_gen.next())).unwrap();
        }
        for addr in conn_addrs.iter().cloned() {
            let _entry = bound.conns_mut().try_insert(addr, 'a', Conn(fake_id_gen.next())).unwrap();
        }
        let expected_addrs = listener_addrs
            .into_iter()
            .map(Into::into)
            .chain(conn_addrs.into_iter().map(Into::into))
            .collect::<HashSet<_>>();

        assert_eq!(expected_addrs, bound.iter_addrs().cloned().collect());
    }

    // `try_insert_with` must not invoke its callback when the insertion is
    // rejected, whatever the reason.
    #[test]
    fn try_insert_with_callback_not_called_on_error() {
        set_logger_for_test();
        let mut bound = FakeBoundSocketMap::default();
        let addr = LISTENER_ADDR;

        let _: &Listener = bound.listeners_mut().try_insert(addr, 'a', Listener(0)).unwrap().id();

        // Stand-in callback: any invocation fails the test.
        fn is_never_called<A, B, T>(_: A, _: B) -> (T, ()) {
            panic!("should never be called");
        }

        // Same address, different sharing tag.
        assert_matches!(
            bound.listeners_mut().try_insert_with(addr, 'b', is_never_called),
            Err((InsertError::Exists, _))
        );
        // Device-bound variant of the occupied device-less address.
        assert_matches!(
            bound.listeners_mut().try_insert_with(
                ListenerAddr { device: Some(FakeWeakDeviceId(FakeDeviceId)), ..addr },
                'b',
                is_never_called
            ),
            Err((InsertError::ShadowAddrExists, _))
        );
        // Connection whose local address is the occupied listener address.
        assert_matches!(
            bound.conns_mut().try_insert_with(
                ConnAddr {
                    device: None,
                    ip: ConnIpAddr {
                        local: (addr.ip.addr.unwrap(), addr.ip.identifier),
                        remote: (SocketIpAddr::new(net_ip_v4!("1.1.1.1")).unwrap(), ()),
                    },
                },
                'b',
                is_never_called,
            ),
            Err((InsertError::ShadowAddrExists, _))
        );
    }

    // A second listener with a different sharing tag at the same address is
    // rejected, and the rejected tag is handed back.
    #[test]
    fn insert_listener_conflict_with_listener() {
        set_logger_for_test();
        let mut bound = FakeBoundSocketMap::default();
        let mut fake_id_gen = FakeSocketIdGen::default();
        let addr = LISTENER_ADDR;

        let _: &Listener =
            bound.listeners_mut().try_insert(addr, 'a', Listener(fake_id_gen.next())).unwrap().id();
        assert_matches!(
            bound.listeners_mut().try_insert(addr, 'b', Listener(fake_id_gen.next())),
            Err((InsertError::Exists, 'b'))
        );
    }

    // Binding a device-specific listener is rejected while the same
    // device-less address is occupied.
    #[test]
    fn insert_listener_conflict_with_shadower() {
        set_logger_for_test();
        let mut bound = FakeBoundSocketMap::default();
        let mut fake_id_gen = FakeSocketIdGen::default();
        let addr = LISTENER_ADDR;
        let shadows_addr = {
            assert_eq!(addr.device, None);
            ListenerAddr { device: Some(FakeWeakDeviceId(FakeDeviceId)), ..addr }
        };

        let _: &Listener =
            bound.listeners_mut().try_insert(addr, 'a', Listener(fake_id_gen.next())).unwrap().id();
        assert_matches!(
            bound.listeners_mut().try_insert(shadows_addr, 'b', Listener(fake_id_gen.next())),
            Err((InsertError::ShadowAddrExists, 'b'))
        );
    }

    // A connection whose local address matches an occupied listener address
    // is rejected.
    #[test]
    fn insert_conn_conflict_with_listener() {
        set_logger_for_test();
        let mut bound = FakeBoundSocketMap::default();
        let mut fake_id_gen = FakeSocketIdGen::default();
        let addr = LISTENER_ADDR;
        let shadows_addr = ConnAddr {
            device: None,
            ip: ConnIpAddr {
                local: (addr.ip.addr.unwrap(), addr.ip.identifier),
                remote: (SocketIpAddr::new(net_ip_v4!("1.1.1.1")).unwrap(), ()),
            },
        };

        let _: &Listener =
            bound.listeners_mut().try_insert(addr, 'a', Listener(fake_id_gen.next())).unwrap().id();
        assert_matches!(
            bound.conns_mut().try_insert(shadows_addr, 'b', Conn(fake_id_gen.next())),
            Err((InsertError::ShadowAddrExists, 'b'))
        );
    }

    // Two listeners sharing tag 'x' coexist; removing one leaves the other.
    #[test]
    fn insert_and_remove_listener() {
        set_logger_for_test();
        let mut bound = FakeBoundSocketMap::default();
        let mut fake_id_gen = FakeSocketIdGen::default();
        let addr = LISTENER_ADDR;

        let a = *bound
            .listeners_mut()
            .try_insert(addr, 'x', Listener(fake_id_gen.next()))
            .unwrap()
            .id();
        let b = *bound
            .listeners_mut()
            .try_insert(addr, 'x', Listener(fake_id_gen.next()))
            .unwrap()
            .id();
        assert_ne!(a, b);

        assert_eq!(bound.listeners_mut().remove(&a, &addr), Ok(()));
        assert_eq!(bound.listeners().get_by_addr(&addr), Some(&Multiple('x', vec![b])));
    }

    // Two connections sharing tag 'x' coexist; removing one leaves the other.
    #[test]
    fn insert_and_remove_conn() {
        set_logger_for_test();
        let mut bound = FakeBoundSocketMap::default();
        let mut fake_id_gen = FakeSocketIdGen::default();
        let addr = CONN_ADDR;

        let a = *bound.conns_mut().try_insert(addr, 'x', Conn(fake_id_gen.next())).unwrap().id();
        let b = *bound.conns_mut().try_insert(addr, 'x', Conn(fake_id_gen.next())).unwrap().id();
        assert_ne!(a, b);

        assert_eq!(bound.conns_mut().remove(&a, &addr), Ok(()));
        assert_eq!(bound.conns().get_by_addr(&addr), Some(&Multiple('x', vec![b])));
    }

    // Moving either of two listeners to the IP-less address with the same
    // identifier fails with `ExistsError`, and the returned entry still refers
    // to the original socket and address.
    #[test]
    fn update_listener_to_shadowed_addr_fails() {
        let mut bound = FakeBoundSocketMap::default();
        let mut fake_id_gen = FakeSocketIdGen::default();

        let first_addr = LISTENER_ADDR;
        let second_addr = ListenerAddr {
            ip: ListenerIpAddr {
                addr: Some(SocketIpAddr::new(net_ip_v4!("1.1.1.1")).unwrap()),
                ..LISTENER_ADDR.ip
            },
            ..LISTENER_ADDR
        };
        let both_shadow = ListenerAddr {
            ip: ListenerIpAddr { addr: None, identifier: first_addr.ip.identifier },
            device: None,
        };

        let first = *bound
            .listeners_mut()
            .try_insert(first_addr, 'a', Listener(fake_id_gen.next()))
            .unwrap()
            .id();
        let second = *bound
            .listeners_mut()
            .try_insert(second_addr, 'b', Listener(fake_id_gen.next()))
            .unwrap()
            .id();

        let (ExistsError, entry) = bound
            .listeners_mut()
            .entry(&second, &second_addr)
            .unwrap()
            .try_update_addr(both_shadow)
            .expect_err("update should fail");

        assert_eq!(entry.id(), &second);
        // Release the entry's borrow of `bound` before looking up `first`.
        drop(entry);

        let (ExistsError, entry) = bound
            .listeners_mut()
            .entry(&first, &first_addr)
            .unwrap()
            .try_update_addr(both_shadow)
            .expect_err("update should fail");
        assert_eq!(entry.get_addr(), &first_addr);
    }

    // After a connection is removed, looking up its entry yields `None`.
    #[test]
    fn nonexistent_conn_entry() {
        let mut map = FakeBoundSocketMap::default();
        let mut fake_id_gen = FakeSocketIdGen::default();
        let addr = CONN_ADDR;
        let conn_id = *map
            .conns_mut()
            .try_insert(addr, 'a', Conn(fake_id_gen.next()))
            .expect("failed to insert")
            .id();
        assert_matches!(map.conns_mut().remove(&conn_id, &addr), Ok(()));

        assert!(map.conns_mut().entry(&conn_id, &addr).is_none());
    }

    // A lone connection can change its sharing tag; once a second connection
    // joins under the new tag, further tag changes are rejected.
    #[test]
    fn update_conn_sharing() {
        let mut map = FakeBoundSocketMap::default();
        let mut fake_id_gen = FakeSocketIdGen::default();
        let addr = CONN_ADDR;
        let mut entry = map
            .conns_mut()
            .try_insert(addr, 'a', Conn(fake_id_gen.next()))
            .expect("failed to insert");

        entry.try_update_sharing(&'a', 'd').expect("worked");
        // The address now carries tag 'd', so a second socket with 'd' may join.
        let mut second_conn = map
            .conns_mut()
            .try_insert(addr, 'd', Conn(fake_id_gen.next()))
            .expect("can insert");
        assert_matches!(second_conn.try_update_sharing(&'d', 'e'), Err(UpdateSharingError));
    }
}