1use core::convert::Infallible as Never;
9use core::fmt::Debug;
10use core::hash::Hash;
11use core::marker::PhantomData;
12use core::num::NonZeroU16;
13
14use derivative::Derivative;
15use net_types::ip::{GenericOverIp, Ip, IpAddress, IpVersionMarker, Ipv4, Ipv6};
16use net_types::{
17 AddrAndZone, MulticastAddress, ScopeableAddress, SpecifiedAddr, Witness, ZonedAddr,
18};
19use thiserror::Error;
20
21use crate::LocalAddressError;
22use crate::data_structures::socketmap::{
23 Entry, IterShadows, OccupiedEntry as SocketMapOccupiedEntry, SocketMap, Tagged,
24};
25use crate::device::{
26 DeviceIdentifier, EitherDeviceId, StrongDeviceIdentifier, WeakDeviceIdentifier,
27};
28use crate::error::{ExistsError, NotFoundError, ZonedAddressError};
29use crate::ip::BroadcastIpExt;
30use crate::socket::address::{
31 AddrVecIter, ConnAddr, ConnIpAddr, ListenerAddr, ListenerIpAddr, SocketIpAddr,
32};
33
/// An [`Ip`] version paired with its counterpart ("other") IP version, as used
/// by the dual-stack helpers in this module.
pub trait DualStackIpExt: Ip {
    /// The other IP version: [`Ipv6`] for [`Ipv4`] and [`Ipv4`] for [`Ipv6`].
    ///
    /// The `OtherVersion = Self` bound makes the relation symmetric, so
    /// `<I::OtherVersion as DualStackIpExt>::OtherVersion` is `I` again.
    type OtherVersion: DualStackIpExt<OtherVersion = Self>;
}

impl DualStackIpExt for Ipv4 {
    type OtherVersion = Ipv6;
}

impl DualStackIpExt for Ipv6 {
    type OtherVersion = Ipv4;
}
48
/// A pair of values of the generic-over-IP type `T`: one instantiated for `I`
/// ("this stack") and one instantiated for `I::OtherVersion` ("the other stack").
pub struct DualStackTuple<I: DualStackIpExt, T: GenericOverIp<I> + GenericOverIp<I::OtherVersion>> {
    // `T` instantiated for `I`.
    this_stack: <T as GenericOverIp<I>>::Type,
    // `T` instantiated for `I::OtherVersion`.
    other_stack: <T as GenericOverIp<I::OtherVersion>>::Type,
    // Records which IP version is "this" stack.
    _marker: IpVersionMarker<I>,
}
55
56impl<I: DualStackIpExt, T: GenericOverIp<I> + GenericOverIp<I::OtherVersion>> DualStackTuple<I, T> {
57 pub fn new(this_stack: T, other_stack: <T as GenericOverIp<I::OtherVersion>>::Type) -> Self
59 where
60 T: GenericOverIp<I, Type = T>,
61 {
62 Self { this_stack, other_stack, _marker: IpVersionMarker::new() }
63 }
64
65 pub fn into_inner(
67 self,
68 ) -> (<T as GenericOverIp<I>>::Type, <T as GenericOverIp<I::OtherVersion>>::Type) {
69 let Self { this_stack, other_stack, _marker } = self;
70 (this_stack, other_stack)
71 }
72
73 pub fn into_this_stack(self) -> <T as GenericOverIp<I>>::Type {
75 self.this_stack
76 }
77
78 pub fn this_stack(&self) -> &<T as GenericOverIp<I>>::Type {
80 &self.this_stack
81 }
82
83 pub fn into_other_stack(self) -> <T as GenericOverIp<I::OtherVersion>>::Type {
85 self.other_stack
86 }
87
88 pub fn other_stack(&self) -> &<T as GenericOverIp<I::OtherVersion>>::Type {
90 &self.other_stack
91 }
92
93 pub fn flip(self) -> DualStackTuple<I::OtherVersion, T> {
95 let Self { this_stack, other_stack, _marker } = self;
96 DualStackTuple {
97 this_stack: other_stack,
98 other_stack: this_stack,
99 _marker: IpVersionMarker::new(),
100 }
101 }
102
103 pub fn cast<X>(self) -> DualStackTuple<X, T>
112 where
113 X: DualStackIpExt,
114 T: GenericOverIp<X>
115 + GenericOverIp<X::OtherVersion>
116 + GenericOverIp<Ipv4>
117 + GenericOverIp<Ipv6>,
118 {
119 I::map_ip_in(
120 self,
121 |v4| X::map_ip_out(v4, |t| t, |t| t.flip()),
122 |v6| X::map_ip_out(v6, |t| t.flip(), |t| t),
123 )
124 }
125}
126
/// Retargets a [`DualStackTuple`] at another IP version `NewIp`. The resulting
/// type holds `T` for `NewIp` and for `NewIp::OtherVersion`.
impl<
    I: DualStackIpExt,
    NewIp: DualStackIpExt,
    T: GenericOverIp<NewIp>
        + GenericOverIp<NewIp::OtherVersion>
        + GenericOverIp<I>
        + GenericOverIp<I::OtherVersion>,
> GenericOverIp<NewIp> for DualStackTuple<I, T>
{
    type Type = DualStackTuple<NewIp, T>;
}
138
/// Socket-related constants available for every [`Ip`] version.
pub trait SocketIpExt: Ip {
    /// This IP version's loopback address, as a [`SocketIpAddr`].
    const LOOPBACK_ADDRESS_AS_SOCKET_IP_ADDR: SocketIpAddr<Self::Addr> = unsafe {
        // SAFETY: `new_from_specified_unchecked` is `unsafe`, presumably because it
        // skips the validation that the checked `SocketIpAddr::new` performs (TODO
        // confirm against its docs). `Self::LOOPBACK_ADDRESS` is expected to pass
        // that validation. The `loopback_addr_is_valid_socket_addr` test in this
        // file asserts that `SocketIpAddr::new` accepts it.
        SocketIpAddr::new_from_specified_unchecked(Self::LOOPBACK_ADDRESS)
    };
}

// Every IP version gets the default `SocketIpExt` constants.
impl<I: Ip> SocketIpExt for I {}
150
#[cfg(test)]
mod socket_ip_ext_test {
    use super::*;
    use ip_test_macro::ip_test;

    // Justifies the `unsafe` construction of
    // `SocketIpExt::LOOPBACK_ADDRESS_AS_SOCKET_IP_ADDR`: the checked
    // `SocketIpAddr::new` constructor must accept the same address.
    #[ip_test(I)]
    fn loopback_addr_is_valid_socket_addr<I: SocketIpExt>() {
        let _addr = SocketIpAddr::new(I::LOOPBACK_ADDRESS_AS_SOCKET_IP_ADDR.addr())
            .expect("loopback address should be a valid SocketIpAddr");
    }
}
166
/// A value that belongs either to this IP stack (`T`) or to the other one (`O`).
#[derive(Debug, PartialEq, Eq)]
pub enum EitherStack<T, O> {
    /// The value belongs to this stack.
    ThisStack(T),
    /// The value belongs to the other stack.
    OtherStack(O),
}

// `Clone` is written by hand rather than derived, presumably so that `clone`
// can carry `#[track_caller]` when the `instrumented` feature is enabled.
impl<T: Clone, O: Clone> Clone for EitherStack<T, O> {
    #[cfg_attr(feature = "instrumented", track_caller)]
    fn clone(&self) -> Self {
        match self {
            EitherStack::ThisStack(this) => EitherStack::ThisStack(T::clone(this)),
            EitherStack::OtherStack(other) => EitherStack::OtherStack(O::clone(other)),
        }
    }
}
195
/// Either a dual-stack value (`DS`) or a non-dual-stack value (`NDS`).
#[derive(Debug)]
#[allow(missing_docs)]
pub enum MaybeDualStack<DS, NDS> {
    DualStack(DS),
    NotDualStack(NDS),
}

/// Maps both alternatives of a [`MaybeDualStack`] to the IP version `I`.
impl<I: DualStackIpExt, DS: GenericOverIp<I>, NDS: GenericOverIp<I>> GenericOverIp<I>
    for MaybeDualStack<DS, NDS>
{
    type Type = MaybeDualStack<<DS as GenericOverIp<I>>::Type, <NDS as GenericOverIp<I>>::Type>;
}
227
/// Errors returned when enabling or disabling dual-stack operation on a socket.
#[derive(Copy, Clone, Debug, Eq, GenericOverIp, PartialEq, Error)]
#[generic_over_ip()]
pub enum SetDualStackEnabledError {
    /// The socket is bound. The setting may only change while it is unbound.
    #[error("a socket can only have dual stack enabled or disabled while unbound")]
    SocketIsBound,
    /// The socket's protocol is not dual-stack capable.
    #[error(transparent)]
    NotCapable(#[from] NotDualStackCapableError),
}

/// Error indicating that a socket's protocol does not support dual-stack operation.
#[derive(Copy, Clone, Debug, Eq, GenericOverIp, PartialEq, Error)]
#[generic_over_ip()]
#[error("socket's protocol is not dual-stack capable")]
pub struct NotDualStackCapableError;
246
/// Which directions of a socket have been shut down.
///
/// The derived `Default` has both flags `false`, meaning nothing is shut down.
#[derive(Copy, Clone, Debug, Default, Eq, PartialEq)]
pub struct Shutdown {
    /// Whether the send direction has been shut down.
    pub send: bool,
    /// Whether the receive direction has been shut down.
    pub receive: bool,
}

/// The direction(s) a shutdown request applies to.
#[derive(Copy, Clone, Debug, Eq, GenericOverIp, PartialEq)]
#[generic_over_ip()]
pub enum ShutdownType {
    /// Shut down sending only.
    Send,
    /// Shut down receiving only.
    Receive,
    /// Shut down both sending and receiving.
    SendAndReceive,
}
271
272impl ShutdownType {
273 pub fn to_send_receive(&self) -> (bool, bool) {
275 match self {
276 Self::Send => (true, false),
277 Self::Receive => (false, true),
278 Self::SendAndReceive => (true, true),
279 }
280 }
281
282 pub fn from_send_receive(send: bool, receive: bool) -> Option<Self> {
284 match (send, receive) {
285 (true, false) => Some(Self::Send),
286 (false, true) => Some(Self::Receive),
287 (true, true) => Some(Self::SendAndReceive),
288 (false, false) => None,
289 }
290 }
291}
292
293pub trait SocketIpAddrExt<A: IpAddress>: Witness<A> + ScopeableAddress {
295 fn must_have_zone(&self) -> bool
301 where
302 Self: Copy,
303 {
304 self.try_into_null_zoned().is_some()
305 }
306
307 fn try_into_null_zoned(self) -> Option<AddrAndZone<Self, ()>> {
311 if self.get().is_loopback() {
312 return None;
313 }
314 AddrAndZone::new(self, ())
315 }
316}
317
318impl<A: IpAddress, W: Witness<A> + ScopeableAddress> SocketIpAddrExt<A> for W {}
319
/// Resolves a possibly-zoned socket address together with the device a socket
/// is bound to.
pub trait SocketZonedAddrExt<W, A, D> {
    /// Splits `self` into its address and the device to use with it.
    ///
    /// `device` is the device the socket is already bound to, if any. The
    /// returned device is either the address's zone (strong ID) or `device`
    /// (weak ID). Returns `None` when neither is present and none is required.
    /// See the `ZonedAddr` implementation for the exact rules and error cases.
    fn resolve_addr_with_device(
        self,
        device: Option<D::Weak>,
    ) -> Result<(W, Option<EitherDeviceId<D, D::Weak>>), ZonedAddressError>
    where
        D: StrongDeviceIdentifier;
}
336
337impl<W, A, D> SocketZonedAddrExt<W, A, D> for ZonedAddr<W, D>
338where
339 W: ScopeableAddress + AsRef<SpecifiedAddr<A>>,
340 A: IpAddress,
341{
342 fn resolve_addr_with_device(
343 self,
344 device: Option<D::Weak>,
345 ) -> Result<(W, Option<EitherDeviceId<D, D::Weak>>), ZonedAddressError>
346 where
347 D: StrongDeviceIdentifier,
348 {
349 let (addr, zone) = self.into_addr_zone();
350 let device = match (zone, device) {
351 (Some(zone), Some(device)) => {
352 if device != zone {
353 return Err(ZonedAddressError::DeviceZoneMismatch);
354 }
355 Some(EitherDeviceId::Strong(zone))
356 }
357 (Some(zone), None) => Some(EitherDeviceId::Strong(zone)),
358 (None, Some(device)) => Some(EitherDeviceId::Weak(device)),
359 (None, None) => {
360 if addr.as_ref().must_have_zone() {
361 return Err(ZonedAddressError::RequiredZoneNotProvided);
362 } else {
363 None
364 }
365 }
366 };
367 Ok((addr, device))
368 }
369}
370
/// Inputs for deciding whether a socket may change the device it is bound to.
pub struct SocketDeviceUpdate<'a, A: IpAddress, D: WeakDeviceIdentifier> {
    /// The socket's local IP address, if any.
    pub local_ip: Option<&'a SpecifiedAddr<A>>,
    /// The socket's remote IP address, if any.
    pub remote_ip: Option<&'a SpecifiedAddr<A>>,
    /// The device the socket is currently bound to, if any.
    pub old_device: Option<&'a D>,
}
384
385impl<'a, A: IpAddress, D: WeakDeviceIdentifier> SocketDeviceUpdate<'a, A, D> {
386 pub fn check_update<N>(
389 self,
390 new_device: Option<&N>,
391 ) -> Result<(), SocketDeviceUpdateNotAllowedError>
392 where
393 D: PartialEq<N>,
394 {
395 let Self { local_ip, remote_ip, old_device } = self;
396 let must_have_zone = local_ip.is_some_and(|a| a.must_have_zone())
397 || remote_ip.is_some_and(|a| a.must_have_zone());
398
399 if !must_have_zone {
400 return Ok(());
401 }
402
403 let old_device = old_device.unwrap_or_else(|| {
404 panic!("local_ip={:?} or remote_ip={:?} must have zone", local_ip, remote_ip)
405 });
406
407 if new_device.is_some_and(|new_device| old_device == new_device) {
408 Ok(())
409 } else {
410 Err(SocketDeviceUpdateNotAllowedError)
411 }
412 }
413}
414
/// Returned by `SocketDeviceUpdate::check_update` when a socket's device cannot
/// change because one of its addresses requires a zone.
pub struct SocketDeviceUpdateNotAllowedError;

/// The identifier types used in socket-map addresses.
pub trait SocketMapAddrSpec {
    /// The local identifier (presumably a port); convertible to `NonZeroU16`.
    type LocalIdentifier: Copy + Clone + Debug + Send + Sync + Hash + Eq + Into<NonZeroU16>;
    /// The remote identifier.
    type RemoteIdentifier: Copy + Clone + Debug + Send + Sync + Hash + Eq;
}

/// A summary of a listener address, passed to `SocketMapStateSpec::listener_tag`.
pub struct ListenerAddrInfo {
    /// Whether the listener is bound to a device.
    pub has_device: bool,
    /// Whether the listener is bound to a specific local IP address.
    pub specified_addr: bool,
}
437
438impl<A: IpAddress, D: DeviceIdentifier, LI> ListenerAddr<ListenerIpAddr<A, LI>, D> {
439 pub(crate) fn info(&self) -> ListenerAddrInfo {
440 let Self { device, ip: ListenerIpAddr { addr, identifier: _ } } = self;
441 ListenerAddrInfo { has_device: device.is_some(), specified_addr: addr.is_some() }
442 }
443}
444
/// The state types stored in a socket map, and how entries are tagged.
pub trait SocketMapStateSpec {
    /// The tag computed for each entry (see the `Tagged` impl for `Bound`).
    type AddrVecTag: Eq + Copy + Debug + 'static;

    /// Computes the tag for a listener entry from its address summary and state.
    fn listener_tag(info: ListenerAddrInfo, state: &Self::ListenerAddrState) -> Self::AddrVecTag;

    /// Computes the tag for a connected entry. `has_device` says whether the
    /// connected address includes a device.
    fn connected_tag(has_device: bool, state: &Self::ConnAddrState) -> Self::AddrVecTag;

    /// Identifier of a listening socket.
    type ListenerId: Clone + Debug;
    /// Identifier of a connected socket.
    type ConnId: Clone + Debug;

    /// Sharing options of listening sockets.
    type ListenerSharingState: Clone + Debug;

    /// Sharing options of connected sockets.
    type ConnSharingState: Clone + Debug;

    /// The state stored at a listener address.
    type ListenerAddrState: SocketMapAddrStateSpec<Id = Self::ListenerId, SharingState = Self::ListenerSharingState>
        + Debug;

    /// The state stored at a connected address.
    type ConnAddrState: SocketMapAddrStateSpec<Id = Self::ConnId, SharingState = Self::ConnSharingState>
        + Debug;
}
480
/// Returned when a new socket's sharing state is incompatible with the sockets
/// already stored at an address.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub struct IncompatibleError;

/// A one-shot sink that accepts a single item.
pub trait Inserter<T> {
    /// Inserts `item`, consuming the inserter.
    fn insert(self, item: T);
}

// Any mutable reference to an `Extend` collection is an inserter.
impl<'a, T, E: Extend<T>> Inserter<T> for &'a mut E {
    fn insert(self, item: T) {
        // `Some(item)` iterates exactly once, so exactly one item is added.
        self.extend(Some(item))
    }
}

// The uninhabited `Never` type can act as an inserter that is never called.
impl<T> Inserter<T> for Never {
    fn insert(self, _: T) {
        let never: Never = self;
        match never {}
    }
}
506
/// The state stored at one address in the socket map (one or more sockets).
pub trait SocketMapAddrStateSpec {
    /// ID type of the sockets stored in this state.
    type Id;

    /// Sharing options that decide whether more sockets may join this state.
    type SharingState;

    /// Inserter returned by [`SocketMapAddrStateSpec::try_get_inserter`].
    type Inserter<'a>: Inserter<Self::Id> + 'a
    where
        Self: 'a,
        Self::Id: 'a;

    /// Creates a state holding the single socket `id` with `new_sharing_state`.
    fn new(new_sharing_state: &Self::SharingState, id: Self::Id) -> Self;

    /// Returns whether socket `id` is stored in this state.
    fn contains_id(&self, id: &Self::Id) -> bool;

    /// Returns an inserter for adding a socket with `new_sharing_state` to this state.
    ///
    /// # Errors
    ///
    /// Returns [`IncompatibleError`] if such a socket cannot be added here.
    fn try_get_inserter<'a, 'b>(
        &'b mut self,
        new_sharing_state: &'a Self::SharingState,
    ) -> Result<Self::Inserter<'b>, IncompatibleError>;

    /// Checks, without modifying anything, whether a socket with
    /// `new_sharing_state` could be added to this state.
    fn could_insert(&self, new_sharing_state: &Self::SharingState)
        -> Result<(), IncompatibleError>;

    /// Removes socket `id`. [`RemoveResult::IsLast`] tells the caller that the
    /// whole entry should be removed from the map.
    fn remove_by_id(&mut self, id: Self::Id) -> RemoveResult;
}

/// Address states that support changing a stored socket's sharing state.
pub trait SocketMapAddrStateUpdateSharingSpec: SocketMapAddrStateSpec {
    /// Changes the sharing state of socket `id` to `new_sharing_state`.
    ///
    /// # Errors
    ///
    /// Returns [`IncompatibleError`] if the change is not possible.
    fn try_update_sharing(
        &mut self,
        id: Self::Id,
        new_sharing_state: &Self::SharingState,
    ) -> Result<(), IncompatibleError>;
}
568
/// Decides whether a new socket at `Addr` conflicts with sockets already in the map.
pub trait SocketMapConflictPolicy<
    Addr,
    SharingState,
    I: Ip,
    D: DeviceIdentifier,
    A: SocketMapAddrSpec,
>: SocketMapStateSpec
{
    /// Checks whether inserting a socket with `new_sharing_state` at `addr`
    /// conflicts with the existing entries in `socketmap`.
    ///
    /// `Sockets::try_insert_with` calls this before inserting. `Sockets::could_insert`
    /// calls it when `addr` has no state yet.
    fn check_insert_conflicts(
        new_sharing_state: &SharingState,
        addr: &Addr,
        socketmap: &SocketMap<AddrVec<I, D, A>, Bound<Self>>,
    ) -> Result<(), InsertError>;
}

/// Decides whether a bound socket's sharing state may change in place.
pub trait SocketMapUpdateSharingPolicy<Addr, SharingState, I: Ip, D: DeviceIdentifier, A>:
    SocketMapConflictPolicy<Addr, SharingState, I, D, A>
where
    A: SocketMapAddrSpec,
{
    /// Checks whether the socket at `addr` may change from `old_sharing` to
    /// `new_sharing`, given the other entries in `socketmap`.
    fn allows_sharing_update(
        socketmap: &SocketMap<AddrVec<I, D, A>, Bound<Self>>,
        addr: &Addr,
        old_sharing: &SharingState,
        new_sharing: &SharingState,
    ) -> Result<(), UpdateSharingError>;
}
609
/// The value stored at a socket-map address: listener state or connected state.
#[derive(Derivative)]
#[derivative(Debug(bound = "S::ListenerAddrState: Debug, S::ConnAddrState: Debug"))]
#[allow(missing_docs)]
pub enum Bound<S: SocketMapStateSpec + ?Sized> {
    Listen(S::ListenerAddrState),
    Conn(S::ConnAddrState),
}

/// A socket-map key: either a listener address or a connected address.
///
/// The trait impls come from `derivative` with bounds only on `D`. A plain
/// `derive` would also require these traits of `I` and `A`.
#[derive(Derivative)]
#[derivative(
    Debug(bound = "D: Debug"),
    Clone(bound = "D: Clone"),
    Eq(bound = "D: Eq"),
    PartialEq(bound = "D: PartialEq"),
    Hash(bound = "D: Hash")
)]
#[allow(missing_docs)]
pub enum AddrVec<I: Ip, D, A: SocketMapAddrSpec + ?Sized> {
    Listen(ListenerAddr<ListenerIpAddr<I::Addr, A::LocalIdentifier>, D>),
    Conn(ConnAddr<ConnIpAddr<I::Addr, A::LocalIdentifier, A::RemoteIdentifier>, D>),
}
646
647impl<I: Ip, D: DeviceIdentifier, A: SocketMapAddrSpec, S: SocketMapStateSpec + ?Sized>
648 Tagged<AddrVec<I, D, A>> for Bound<S>
649{
650 type Tag = S::AddrVecTag;
651 fn tag(&self, address: &AddrVec<I, D, A>) -> Self::Tag {
652 match (self, address) {
653 (Bound::Listen(l), AddrVec::Listen(addr)) => S::listener_tag(addr.info(), l),
654 (Bound::Conn(c), AddrVec::Conn(ConnAddr { device, ip: _ })) => {
655 S::connected_tag(device.is_some(), c)
656 }
657 (Bound::Listen(_), AddrVec::Conn(_)) => {
658 unreachable!("found listen state for conn addr")
659 }
660 (Bound::Conn(_), AddrVec::Listen(_)) => {
661 unreachable!("found conn state for listen addr")
662 }
663 }
664 }
665}
666
667impl<I: Ip, D: DeviceIdentifier, A: SocketMapAddrSpec> IterShadows for AddrVec<I, D, A> {
668 type IterShadows = AddrVecIter<I, D, A>;
669
670 fn iter_shadows(&self) -> Self::IterShadows {
671 let (socket_ip_addr, device) = match self.clone() {
672 AddrVec::Conn(ConnAddr { ip, device }) => (ip.into(), device),
673 AddrVec::Listen(ListenerAddr { ip, device }) => (ip.into(), device),
674 };
675 let mut iter = match device {
676 Some(device) => AddrVecIter::with_device(socket_ip_addr, device),
677 None => AddrVecIter::without_device(socket_ip_addr),
678 };
679 assert_eq!(iter.next().as_ref(), Some(self));
681 iter
682 }
683}
684
685#[derive(Copy, Clone, Debug, Eq, Hash, PartialEq)]
687#[allow(missing_docs)]
688pub enum SocketAddrType {
689 AnyListener,
690 SpecificListener,
691 Connected,
692}
693
694impl<'a, A: IpAddress, LI> From<&'a ListenerIpAddr<A, LI>> for SocketAddrType {
695 fn from(ListenerIpAddr { addr, identifier: _ }: &'a ListenerIpAddr<A, LI>) -> Self {
696 match addr {
697 Some(_) => SocketAddrType::SpecificListener,
698 None => SocketAddrType::AnyListener,
699 }
700 }
701}
702
703impl<'a, A: IpAddress, LI, RI> From<&'a ConnIpAddr<A, LI, RI>> for SocketAddrType {
704 fn from(_: &'a ConnIpAddr<A, LI, RI>) -> Self {
705 SocketAddrType::Connected
706 }
707}
708
/// Outcome of [`SocketMapAddrStateSpec::remove_by_id`].
pub enum RemoveResult {
    /// The ID was removed and the address state remains in use.
    Success,
    /// The removed ID was the last one. The caller should remove the whole
    /// entry from the map.
    IsLast,
}

/// The ID of a socket in the map: a listener ID or a connection ID.
#[derive(Derivative)]
#[derivative(Clone(bound = "S::ListenerId: Clone, S::ConnId: Clone"), Debug(bound = ""))]
pub enum SocketId<S: SocketMapStateSpec> {
    /// A listening socket.
    Listener(S::ListenerId),
    /// A connected socket.
    Connection(S::ConnId),
}

/// A map from bound socket addresses to the socket state stored there.
#[derive(Derivative)]
#[derivative(Default(bound = ""))]
pub struct BoundSocketMap<I: Ip, D: DeviceIdentifier, A: SocketMapAddrSpec, S: SocketMapStateSpec> {
    addr_to_state: SocketMap<AddrVec<I, D, A>, Bound<S>>,
}
743
744impl<I: Ip, D: DeviceIdentifier, A: SocketMapAddrSpec, S: SocketMapStateSpec>
745 BoundSocketMap<I, D, A, S>
746{
747 pub fn len(&self) -> usize {
749 self.addr_to_state.len()
750 }
751}
752
/// Uninhabited marker type selecting the listener entries of a socket map.
pub enum Listener {}
/// Uninhabited marker type selecting the connected entries of a socket map.
pub enum Connection {}

/// A view of a socket map restricted to one socket kind (`SocketType` is
/// [`Listener`] or [`Connection`]). `AddrToStateMap` is a shared or mutable
/// reference to the map.
pub struct Sockets<AddrToStateMap, SocketType>(AddrToStateMap, PhantomData<SocketType>);
760
761impl<
762 'a,
763 I: Ip,
764 D: DeviceIdentifier,
765 SocketType: ConvertSocketMapState<I, D, A, S>,
766 A: SocketMapAddrSpec,
767 S: SocketMapStateSpec,
768> Sockets<&'a SocketMap<AddrVec<I, D, A>, Bound<S>>, SocketType>
769where
770 S: SocketMapConflictPolicy<SocketType::Addr, SocketType::SharingState, I, D, A>,
771{
772 pub fn get_by_addr(self, addr: &SocketType::Addr) -> Option<&'a SocketType::AddrState> {
774 let Self(addr_to_state, _marker) = self;
775 addr_to_state.get(&SocketType::to_addr_vec(addr)).map(|state| {
776 SocketType::from_bound_ref(state)
777 .unwrap_or_else(|| unreachable!("found {:?} for address {:?}", state, addr))
778 })
779 }
780
781 pub fn could_insert(
787 self,
788 addr: &SocketType::Addr,
789 sharing: &SocketType::SharingState,
790 ) -> Result<(), InsertError> {
791 let Self(addr_to_state, _) = self;
792 match self.get_by_addr(addr) {
793 Some(state) => {
794 state.could_insert(sharing).map_err(|IncompatibleError| InsertError::Exists)
795 }
796 None => S::check_insert_conflicts(&sharing, &addr, &addr_to_state),
797 }
798 }
799}
800
/// A handle to one socket's place in a socket map: the socket's ID plus the
/// occupied map entry for the address it is bound to.
#[derive(Derivative)]
#[derivative(Debug(bound = ""))]
pub struct SocketStateEntry<
    'a,
    I: Ip,
    D: DeviceIdentifier,
    A: SocketMapAddrSpec,
    S: SocketMapStateSpec,
    SocketType,
> {
    // The socket this handle refers to.
    id: SocketId<S>,
    // The occupied entry for the socket's address; its state contains `id`.
    addr_entry: SocketMapOccupiedEntry<'a, AddrVec<I, D, A>, Bound<S>>,
    // Records the socket kind (`Listener` or `Connection`).
    _marker: PhantomData<SocketType>,
}
816
impl<
    'a,
    I: Ip,
    D: DeviceIdentifier,
    SocketType: ConvertSocketMapState<I, D, A, S>,
    A: SocketMapAddrSpec,
    S: SocketMapStateSpec
        + SocketMapConflictPolicy<SocketType::Addr, SocketType::SharingState, I, D, A>,
> Sockets<&'a mut SocketMap<AddrVec<I, D, A>, Bound<S>>, SocketType>
where
    SocketType::SharingState: Clone,
    SocketType::Id: Clone,
{
    /// Inserts socket `id` at `socket_addr` with sharing state `tag_state`.
    ///
    /// Shorthand for [`Sockets::try_insert_with`] with a constructor that just
    /// returns `id`.
    pub fn try_insert(
        self,
        socket_addr: SocketType::Addr,
        tag_state: SocketType::SharingState,
        id: SocketType::Id,
    ) -> Result<SocketStateEntry<'a, I, D, A, S, SocketType>, InsertError> {
        self.try_insert_with(socket_addr, tag_state, |_addr, _sharing| (id, ()))
            .map(|(entry, ())| entry)
    }

    /// Inserts a socket at `socket_addr` with sharing state `tag_state`, building
    /// its ID with `make_id`.
    ///
    /// `S::check_insert_conflicts` runs first. If the address already has state,
    /// the socket joins that state, provided it accepts `tag_state`. Otherwise
    /// a new state is created. `make_id` is called only once the insertion is
    /// known to succeed, and its extra return value `R` is passed back to the caller.
    ///
    /// # Errors
    ///
    /// Returns the conflict reported by `S::check_insert_conflicts`, or
    /// [`InsertError::Exists`] if the existing state rejects `tag_state`.
    ///
    /// # Panics
    ///
    /// Panics if the existing state at the address is of the other socket kind.
    pub fn try_insert_with<R>(
        self,
        socket_addr: SocketType::Addr,
        tag_state: SocketType::SharingState,
        make_id: impl FnOnce(SocketType::Addr, SocketType::SharingState) -> (SocketType::Id, R),
    ) -> Result<(SocketStateEntry<'a, I, D, A, S, SocketType>, R), InsertError> {
        let Self(addr_to_state, _) = self;
        S::check_insert_conflicts(&tag_state, &socket_addr, &addr_to_state)?;

        let addr = SocketType::to_addr_vec(&socket_addr);

        match addr_to_state.entry(addr) {
            Entry::Occupied(mut o) => {
                // Join the existing state at this address, if its sharing allows.
                let (id, ret) = o.map_mut(|bound| {
                    let bound = match SocketType::from_bound_mut(bound) {
                        Some(bound) => bound,
                        None => unreachable!("found {:?} for address {:?}", bound, socket_addr),
                    };
                    match <SocketType::AddrState as SocketMapAddrStateSpec>::try_get_inserter(
                        bound, &tag_state,
                    ) {
                        Ok(v) => {
                            let (id, ret) = make_id(socket_addr, tag_state);
                            v.insert(id.clone());
                            Ok((SocketType::to_socket_id(id), ret))
                        }
                        Err(IncompatibleError) => Err(InsertError::Exists),
                    }
                })?;
                Ok((SocketStateEntry { id, addr_entry: o, _marker: Default::default() }, ret))
            }
            Entry::Vacant(v) => {
                // First socket at this address: create a fresh state for it. A
                // clone of the sharing state goes to `make_id`; the original is
                // used to build the new state.
                let (id, ret) = make_id(socket_addr, tag_state.clone());
                let addr_entry = v.insert(SocketType::to_bound(SocketType::AddrState::new(
                    &tag_state,
                    id.clone(),
                )));
                let id = SocketType::to_socket_id(id);
                Ok((SocketStateEntry { id, addr_entry, _marker: Default::default() }, ret))
            }
        }
    }

    /// Returns a handle to socket `id` at `addr`.
    ///
    /// Returns `None` if `addr` has no state, the state is of the other socket
    /// kind, or the state does not contain `id`.
    pub fn entry(
        self,
        id: &SocketType::Id,
        addr: &SocketType::Addr,
    ) -> Option<SocketStateEntry<'a, I, D, A, S, SocketType>> {
        let Self(addr_to_state, _) = self;
        let addr_entry = match addr_to_state.entry(SocketType::to_addr_vec(addr)) {
            Entry::Vacant(_) => return None,
            Entry::Occupied(o) => o,
        };
        let state = SocketType::from_bound_ref(addr_entry.get())?;

        state.contains_id(id).then_some(SocketStateEntry {
            id: SocketType::to_socket_id(id.clone()),
            addr_entry,
            _marker: PhantomData::default(),
        })
    }

    /// Removes socket `id` from `addr`. The address entry is removed too if `id`
    /// was its last socket.
    ///
    /// # Errors
    ///
    /// Returns [`NotFoundError`] in the cases where [`Sockets::entry`] returns `None`.
    pub fn remove(self, id: &SocketType::Id, addr: &SocketType::Addr) -> Result<(), NotFoundError> {
        self.entry(id, addr)
            .map(|entry| {
                entry.remove();
            })
            .ok_or(NotFoundError)
    }
}
918
/// Returned when a bound socket's sharing state cannot be changed.
#[derive(Debug)]
pub struct UpdateSharingError;
923
924impl<
925 'a,
926 I: Ip,
927 D: DeviceIdentifier,
928 SocketType: ConvertSocketMapState<I, D, A, S>,
929 A: SocketMapAddrSpec,
930 S: SocketMapStateSpec,
931> SocketStateEntry<'a, I, D, A, S, SocketType>
932where
933 SocketType::Id: Clone,
934{
935 pub fn get_addr(&self) -> &SocketType::Addr {
937 let Self { id: _, addr_entry, _marker } = self;
938 SocketType::from_addr_vec_ref(addr_entry.key())
939 }
940
941 pub fn id(&self) -> &SocketType::Id {
943 let Self { id, addr_entry: _, _marker } = self;
944 SocketType::from_socket_id_ref(id)
945 }
946
947 pub fn try_update_addr(self, new_addr: SocketType::Addr) -> Result<Self, (ExistsError, Self)> {
949 let Self { id, addr_entry, _marker } = self;
950
951 let new_addrvec = SocketType::to_addr_vec(&new_addr);
952 let old_addr = addr_entry.key().clone();
953 let (addr_state, addr_to_state) = addr_entry.remove_from_map();
954 let addr_to_state = match addr_to_state.entry(new_addrvec) {
955 Entry::Occupied(o) => o.into_map(),
956 Entry::Vacant(v) => {
957 if v.descendant_counts().len() != 0 {
958 v.into_map()
959 } else {
960 let new_addr_entry = v.insert(addr_state);
961 return Ok(SocketStateEntry { id, addr_entry: new_addr_entry, _marker });
962 }
963 }
964 };
965 let to_restore = addr_state;
966 let addr_entry = match addr_to_state.entry(old_addr) {
968 Entry::Occupied(_) => unreachable!("just-removed-from entry is occupied"),
969 Entry::Vacant(v) => v.insert(to_restore),
970 };
971 return Err((ExistsError, SocketStateEntry { id, addr_entry, _marker }));
972 }
973
974 pub fn remove(self) {
976 let Self { id, mut addr_entry, _marker } = self;
977 let addr = addr_entry.key().clone();
978 match addr_entry.map_mut(|value| {
979 let value = match SocketType::from_bound_mut(value) {
980 Some(value) => value,
981 None => unreachable!("found {:?} for address {:?}", value, addr),
982 };
983 value.remove_by_id(SocketType::from_socket_id_ref(&id).clone())
984 }) {
985 RemoveResult::Success => (),
986 RemoveResult::IsLast => {
987 let _: Bound<S> = addr_entry.remove();
988 }
989 }
990 }
991
992 pub fn try_update_sharing(
994 &mut self,
995 old_sharing_state: &SocketType::SharingState,
996 new_sharing_state: SocketType::SharingState,
997 ) -> Result<(), UpdateSharingError>
998 where
999 SocketType::AddrState: SocketMapAddrStateUpdateSharingSpec,
1000 S: SocketMapUpdateSharingPolicy<SocketType::Addr, SocketType::SharingState, I, D, A>,
1001 {
1002 let Self { id, addr_entry, _marker } = self;
1003 let addr = SocketType::from_addr_vec_ref(addr_entry.key());
1004
1005 S::allows_sharing_update(
1006 addr_entry.get_map(),
1007 addr,
1008 old_sharing_state,
1009 &new_sharing_state,
1010 )?;
1011
1012 addr_entry
1013 .map_mut(|value| {
1014 let value = match SocketType::from_bound_mut(value) {
1015 Some(value) => value,
1016 None => unreachable!("found invalid state {:?}", value),
1020 };
1021
1022 value.try_update_sharing(
1023 SocketType::from_socket_id_ref(id).clone(),
1024 &new_sharing_state,
1025 )
1026 })
1027 .map_err(|IncompatibleError| UpdateSharingError)
1028 }
1029}
1030
1031impl<I: Ip, D: DeviceIdentifier, A: SocketMapAddrSpec, S> BoundSocketMap<I, D, A, S>
1032where
1033 AddrVec<I, D, A>: IterShadows,
1034 S: SocketMapStateSpec,
1035{
1036 pub fn listeners(&self) -> Sockets<&SocketMap<AddrVec<I, D, A>, Bound<S>>, Listener>
1038 where
1039 S: SocketMapConflictPolicy<
1040 ListenerAddr<ListenerIpAddr<I::Addr, A::LocalIdentifier>, D>,
1041 <S as SocketMapStateSpec>::ListenerSharingState,
1042 I,
1043 D,
1044 A,
1045 >,
1046 S::ListenerAddrState:
1047 SocketMapAddrStateSpec<Id = S::ListenerId, SharingState = S::ListenerSharingState>,
1048 {
1049 let Self { addr_to_state } = self;
1050 Sockets(addr_to_state, Default::default())
1051 }
1052
1053 pub fn listeners_mut(&mut self) -> Sockets<&mut SocketMap<AddrVec<I, D, A>, Bound<S>>, Listener>
1055 where
1056 S: SocketMapConflictPolicy<
1057 ListenerAddr<ListenerIpAddr<I::Addr, A::LocalIdentifier>, D>,
1058 <S as SocketMapStateSpec>::ListenerSharingState,
1059 I,
1060 D,
1061 A,
1062 >,
1063 S::ListenerAddrState:
1064 SocketMapAddrStateSpec<Id = S::ListenerId, SharingState = S::ListenerSharingState>,
1065 {
1066 let Self { addr_to_state } = self;
1067 Sockets(addr_to_state, Default::default())
1068 }
1069
1070 pub fn conns(&self) -> Sockets<&SocketMap<AddrVec<I, D, A>, Bound<S>>, Connection>
1072 where
1073 S: SocketMapConflictPolicy<
1074 ConnAddr<ConnIpAddr<I::Addr, A::LocalIdentifier, A::RemoteIdentifier>, D>,
1075 <S as SocketMapStateSpec>::ConnSharingState,
1076 I,
1077 D,
1078 A,
1079 >,
1080 S::ConnAddrState:
1081 SocketMapAddrStateSpec<Id = S::ConnId, SharingState = S::ConnSharingState>,
1082 {
1083 let Self { addr_to_state } = self;
1084 Sockets(addr_to_state, Default::default())
1085 }
1086
1087 pub fn conns_mut(&mut self) -> Sockets<&mut SocketMap<AddrVec<I, D, A>, Bound<S>>, Connection>
1089 where
1090 S: SocketMapConflictPolicy<
1091 ConnAddr<ConnIpAddr<I::Addr, A::LocalIdentifier, A::RemoteIdentifier>, D>,
1092 <S as SocketMapStateSpec>::ConnSharingState,
1093 I,
1094 D,
1095 A,
1096 >,
1097 S::ConnAddrState:
1098 SocketMapAddrStateSpec<Id = S::ConnId, SharingState = S::ConnSharingState>,
1099 {
1100 let Self { addr_to_state } = self;
1101 Sockets(addr_to_state, Default::default())
1102 }
1103
1104 #[cfg(test)]
1105 pub(crate) fn iter_addrs(&self) -> impl Iterator<Item = &AddrVec<I, D, A>> {
1106 let Self { addr_to_state } = self;
1107 addr_to_state.iter().map(|(a, _v): (_, &Bound<S>)| a)
1108 }
1109
1110 pub fn get_shadower_counts(&self, addr: &AddrVec<I, D, A>) -> usize {
1112 let Self { addr_to_state } = self;
1113 addr_to_state.descendant_counts(&addr).map(|(_sharing, size)| size.get()).sum()
1114 }
1115}
1116
/// The result of `BoundSocketMap::iter_receivers`.
pub enum FoundSockets<A, It> {
    /// The first matching entry (used for non-broadcast, non-multicast destinations).
    Single(A),
    /// Every matching entry (used for broadcast or multicast destinations).
    Multicast(It),
}

/// A matching socket-map entry: the stored state plus the address it was found at.
#[allow(missing_docs)]
#[derive(Debug)]
pub enum AddrEntry<'a, I: Ip, D, A: SocketMapAddrSpec, S: SocketMapStateSpec> {
    Listen(&'a S::ListenerAddrState, ListenerAddr<ListenerIpAddr<I::Addr, A::LocalIdentifier>, D>),
    Conn(
        &'a S::ConnAddrState,
        ConnAddr<ConnIpAddr<I::Addr, A::LocalIdentifier, A::RemoteIdentifier>, D>,
    ),
}
1136
1137impl<I, D, A, S> BoundSocketMap<I, D, A, S>
1138where
1139 I: BroadcastIpExt<Addr: MulticastAddress>,
1140 D: DeviceIdentifier,
1141 A: SocketMapAddrSpec,
1142 S: SocketMapStateSpec
1143 + SocketMapConflictPolicy<
1144 ListenerAddr<ListenerIpAddr<I::Addr, A::LocalIdentifier>, D>,
1145 <S as SocketMapStateSpec>::ListenerSharingState,
1146 I,
1147 D,
1148 A,
1149 > + SocketMapConflictPolicy<
1150 ConnAddr<ConnIpAddr<I::Addr, A::LocalIdentifier, A::RemoteIdentifier>, D>,
1151 <S as SocketMapStateSpec>::ConnSharingState,
1152 I,
1153 D,
1154 A,
1155 >,
1156{
1157 pub fn lookup_connected(
1163 &self,
1164 (src_ip, src_port): (SocketIpAddr<I::Addr>, A::RemoteIdentifier),
1165 (dst_ip, dst_port): (SocketIpAddr<I::Addr>, A::LocalIdentifier),
1166 device: D,
1167 ) -> Option<&'_ S::ConnAddrState> {
1168 let mut addr = ConnAddr {
1169 ip: ConnIpAddr { local: (dst_ip, dst_port), remote: (src_ip, src_port) },
1170 device: Some(device),
1171 };
1172 let entry = self.conns().get_by_addr(&addr);
1173 if entry.is_some() {
1174 return entry;
1175 }
1176 addr.device = None;
1177 self.conns().get_by_addr(&addr)
1178 }
1179
1180 pub fn iter_receivers(
1186 &self,
1187 (src_ip, src_port): (Option<SocketIpAddr<I::Addr>>, Option<A::RemoteIdentifier>),
1188 (dst_ip, dst_port): (SocketIpAddr<I::Addr>, A::LocalIdentifier),
1189 device: D,
1190 broadcast: Option<I::BroadcastMarker>,
1191 ) -> Option<
1192 FoundSockets<
1193 AddrEntry<'_, I, D, A, S>,
1194 impl Iterator<Item = AddrEntry<'_, I, D, A, S>> + '_,
1195 >,
1196 > {
1197 let mut matching_entries = AddrVecIter::with_device(
1198 match (src_ip, src_port) {
1199 (Some(specified_src_ip), Some(src_port)) => {
1200 ConnIpAddr { local: (dst_ip, dst_port), remote: (specified_src_ip, src_port) }
1201 .into()
1202 }
1203 _ => ListenerIpAddr { addr: Some(dst_ip), identifier: dst_port }.into(),
1204 },
1205 device,
1206 )
1207 .filter_map(move |addr: AddrVec<I, D, A>| match addr {
1208 AddrVec::Listen(l) => {
1209 self.listeners().get_by_addr(&l).map(|state| AddrEntry::Listen(state, l))
1210 }
1211 AddrVec::Conn(c) => self.conns().get_by_addr(&c).map(|state| AddrEntry::Conn(state, c)),
1212 });
1213
1214 if broadcast.is_some() || dst_ip.addr().is_multicast() {
1215 Some(FoundSockets::Multicast(matching_entries))
1216 } else {
1217 let single_entry: Option<_> = matching_entries.next();
1218 single_entry.map(FoundSockets::Single)
1219 }
1220 }
1221}
1222
1223#[derive(Debug, Eq, PartialEq)]
1225pub enum InsertError {
1226 ShadowAddrExists,
1228 Exists,
1230 ShadowerExists,
1232 IndirectConflict,
1234}
1235
1236impl From<InsertError> for LocalAddressError {
1237 fn from(value: InsertError) -> Self {
1238 match value {
1239 InsertError::ShadowAddrExists
1240 | InsertError::Exists
1241 | InsertError::IndirectConflict
1242 | InsertError::ShadowerExists => LocalAddressError::AddressInUse,
1243 }
1244 }
1245}
1246
/// Converts between the generic socket-map types and the types of one socket
/// kind ([`Listener`] or [`Connection`]).
pub trait ConvertSocketMapState<I: Ip, D, A: SocketMapAddrSpec, S: SocketMapStateSpec> {
    /// The socket ID type for this kind.
    type Id;
    /// The sharing-state type for this kind.
    type SharingState;
    /// The address type for this kind.
    type Addr: Debug;
    /// The per-address state type for this kind.
    type AddrState: SocketMapAddrStateSpec<Id = Self::Id, SharingState = Self::SharingState>;

    /// Wraps `addr` as a map key.
    fn to_addr_vec(addr: &Self::Addr) -> AddrVec<I, D, A>;
    /// Unwraps a map key of this kind (the impls panic on the other kind).
    fn from_addr_vec_ref(addr: &AddrVec<I, D, A>) -> &Self::Addr;
    /// Borrows the state if `bound` is of this kind.
    fn from_bound_ref(bound: &Bound<S>) -> Option<&Self::AddrState>;
    /// Mutably borrows the state if `bound` is of this kind.
    fn from_bound_mut(bound: &mut Bound<S>) -> Option<&mut Self::AddrState>;
    /// Wraps `state` as a map value.
    fn to_bound(state: Self::AddrState) -> Bound<S>;
    /// Wraps `id` as a [`SocketId`].
    fn to_socket_id(id: Self::Id) -> SocketId<S>;
    /// Unwraps a [`SocketId`] of this kind (the impls panic on the other kind).
    fn from_socket_id_ref(id: &SocketId<S>) -> &Self::Id;
}
1263
1264impl<I: Ip, D: DeviceIdentifier, A: SocketMapAddrSpec, S: SocketMapStateSpec>
1265 ConvertSocketMapState<I, D, A, S> for Listener
1266{
1267 type Id = S::ListenerId;
1268 type SharingState = S::ListenerSharingState;
1269 type Addr = ListenerAddr<ListenerIpAddr<I::Addr, A::LocalIdentifier>, D>;
1270 type AddrState = S::ListenerAddrState;
1271 fn to_addr_vec(addr: &Self::Addr) -> AddrVec<I, D, A> {
1272 AddrVec::Listen(addr.clone())
1273 }
1274
1275 fn from_addr_vec_ref(addr: &AddrVec<I, D, A>) -> &Self::Addr {
1276 match addr {
1277 AddrVec::Listen(l) => l,
1278 AddrVec::Conn(c) => unreachable!("conn addr for listener: {c:?}"),
1279 }
1280 }
1281
1282 fn from_bound_ref(bound: &Bound<S>) -> Option<&S::ListenerAddrState> {
1283 match bound {
1284 Bound::Listen(l) => Some(l),
1285 Bound::Conn(_c) => None,
1286 }
1287 }
1288
1289 fn from_bound_mut(bound: &mut Bound<S>) -> Option<&mut S::ListenerAddrState> {
1290 match bound {
1291 Bound::Listen(l) => Some(l),
1292 Bound::Conn(_c) => None,
1293 }
1294 }
1295
1296 fn to_bound(state: S::ListenerAddrState) -> Bound<S> {
1297 Bound::Listen(state)
1298 }
1299 fn from_socket_id_ref(id: &SocketId<S>) -> &Self::Id {
1300 match id {
1301 SocketId::Listener(id) => id,
1302 SocketId::Connection(_) => unreachable!("connection ID for listener"),
1303 }
1304 }
1305 fn to_socket_id(id: Self::Id) -> SocketId<S> {
1306 SocketId::Listener(id)
1307 }
1308}
1309
1310impl<I: Ip, D: DeviceIdentifier, A: SocketMapAddrSpec, S: SocketMapStateSpec>
1311 ConvertSocketMapState<I, D, A, S> for Connection
1312{
1313 type Id = S::ConnId;
1314 type SharingState = S::ConnSharingState;
1315 type Addr = ConnAddr<ConnIpAddr<I::Addr, A::LocalIdentifier, A::RemoteIdentifier>, D>;
1316 type AddrState = S::ConnAddrState;
1317 fn to_addr_vec(addr: &Self::Addr) -> AddrVec<I, D, A> {
1318 AddrVec::Conn(addr.clone())
1319 }
1320
1321 fn from_addr_vec_ref(addr: &AddrVec<I, D, A>) -> &Self::Addr {
1322 match addr {
1323 AddrVec::Conn(c) => c,
1324 AddrVec::Listen(l) => unreachable!("listener addr for conn: {l:?}"),
1325 }
1326 }
1327
1328 fn from_bound_ref(bound: &Bound<S>) -> Option<&S::ConnAddrState> {
1329 match bound {
1330 Bound::Listen(_l) => None,
1331 Bound::Conn(c) => Some(c),
1332 }
1333 }
1334
1335 fn from_bound_mut(bound: &mut Bound<S>) -> Option<&mut S::ConnAddrState> {
1336 match bound {
1337 Bound::Listen(_l) => None,
1338 Bound::Conn(c) => Some(c),
1339 }
1340 }
1341
1342 fn to_bound(state: S::ConnAddrState) -> Bound<S> {
1343 Bound::Conn(state)
1344 }
1345
1346 fn from_socket_id_ref(id: &SocketId<S>) -> &Self::Id {
1347 match id {
1348 SocketId::Connection(id) => id,
1349 SocketId::Listener(_) => unreachable!("listener ID for connection"),
1350 }
1351 }
1352 fn to_socket_id(id: Self::Id) -> SocketId<S> {
1353 SocketId::Connection(id)
1354 }
1355}
1356
/// Identifies a group of sockets that may share an address when port reuse
/// is enabled; only sockets in the same domain can share.
#[derive(Debug, Eq, PartialEq, Clone, Copy, Hash)]
pub struct SharingDomain(u64);

impl SharingDomain {
    /// Creates a sharing domain from a raw identifier.
    pub const fn new(id: u64) -> Self {
        SharingDomain(id)
    }
}

/// Port-reuse setting of a socket.
///
/// Defaults to [`ReusePortOption::Disabled`].
#[derive(Default, Debug, Eq, PartialEq, Clone, Copy, Hash)]
pub enum ReusePortOption {
    /// Port reuse is off; the socket never shares its address.
    #[default]
    Disabled,

    /// Port reuse is on, restricted to sockets in the given [`SharingDomain`].
    Enabled(SharingDomain),
}

impl ReusePortOption {
    /// Returns `true` if port reuse is enabled, in any domain.
    pub fn is_enabled(&self) -> bool {
        matches!(self, ReusePortOption::Enabled(_))
    }

    /// Returns `true` if a socket with this option may share an address with
    /// a socket whose option is `other`.
    ///
    /// Both sides must be enabled and belong to the same [`SharingDomain`];
    /// if either side is disabled, sharing is refused.
    pub fn is_shareable_with(&self, other: &Self) -> bool {
        // Spelled out without a catch-all so that adding a variant forces an
        // explicit decision about its sharing semantics.
        match (self, other) {
            (ReusePortOption::Enabled(mine), ReusePortOption::Enabled(theirs)) => mine == theirs,
            (ReusePortOption::Disabled, _) | (_, ReusePortOption::Disabled) => false,
        }
    }
}
1400
1401#[cfg(test)]
1402mod tests {
1403 use alloc::vec;
1404 use alloc::vec::Vec;
1405
1406 use assert_matches::assert_matches;
1407 use net_declare::{net_ip_v4, net_ip_v6};
1408 use net_types::ip::{Ipv4Addr, Ipv6, Ipv6Addr};
1409 use netstack3_hashmap::HashSet;
1410 use test_case::test_case;
1411
1412 use crate::device::testutil::{FakeDeviceId, FakeWeakDeviceId};
1413 use crate::testutil::set_logger_for_test;
1414
1415 use super::*;
1416
1417 #[test_case(net_ip_v4!("8.8.8.8"))]
1418 #[test_case(net_ip_v4!("127.0.0.1"))]
1419 #[test_case(net_ip_v4!("127.0.8.9"))]
1420 #[test_case(net_ip_v4!("224.1.2.3"))]
1421 fn must_never_have_zone_ipv4(addr: Ipv4Addr) {
1422 let addr = SpecifiedAddr::new(addr).unwrap();
1424 assert_eq!(addr.must_have_zone(), false);
1425 }
1426
1427 #[test_case(net_ip_v6!("1::2:3"), false)]
1428 #[test_case(net_ip_v6!("::1"), false; "localhost")]
1429 #[test_case(net_ip_v6!("1::"), false)]
1430 #[test_case(net_ip_v6!("ff03:1:2:3::1"), false)]
1431 #[test_case(net_ip_v6!("ff02:1:2:3::1"), true)]
1432 #[test_case(Ipv6::ALL_NODES_LINK_LOCAL_MULTICAST_ADDRESS.get(), true)]
1433 #[test_case(net_ip_v6!("fe80::1"), true)]
1434 fn must_have_zone_ipv6(addr: Ipv6Addr, must_have: bool) {
1435 let addr = SpecifiedAddr::new(addr).unwrap();
1438 assert_eq!(addr.must_have_zone(), must_have);
1439 }
1440
1441 #[test]
1442 fn try_into_null_zoned_ipv6() {
1443 assert_eq!(Ipv6::LOOPBACK_ADDRESS.try_into_null_zoned(), None);
1444 let zoned = Ipv6::ALL_NODES_LINK_LOCAL_MULTICAST_ADDRESS.into_specified();
1445 const ZONE: u32 = 5;
1446 assert_eq!(
1447 zoned.try_into_null_zoned().map(|a| a.map_zone(|()| ZONE)),
1448 Some(AddrAndZone::new(zoned, ZONE).unwrap())
1449 );
1450 }
1451
    /// Uninhabited type used only at the type level as the fake
    /// `SocketMapStateSpec` for these tests.
    enum FakeSpec {}
1453
    /// Fake listener socket ID wrapping a plain counter value.
    #[derive(Copy, Clone, Eq, PartialEq, Debug, Hash)]
    struct Listener(usize);
1456
1457 #[derive(PartialEq, Eq, Debug, Copy, Clone)]
1458 struct SharingState {
1459 tag: char,
1460 shared: bool,
1461 }
1462
1463 impl SharingState {
1464 fn exclusive(tag: char) -> Self {
1465 Self { tag, shared: false }
1466 }
1467
1468 fn shared(tag: char) -> Self {
1469 Self { tag, shared: true }
1470 }
1471 }
1472
1473 impl SharingState {
1474 fn can_share_with(&self, other: &Self) -> bool {
1475 self.tag == other.tag && self.shared && other.shared
1476 }
1477 }
1478
1479 #[derive(PartialEq, Eq, Debug)]
1480 struct Multiple<T> {
1481 sharing_state: SharingState,
1482 entries: Vec<T>,
1483 }
1484
1485 impl<T> Multiple<T> {
1486 fn new_exclusive(tag: char, entries: Vec<T>) -> Self {
1487 Self { sharing_state: SharingState { tag, shared: false }, entries }
1488 }
1489 }
1490
    /// Fake connection socket ID wrapping a plain counter value.
    #[derive(Copy, Clone, Eq, PartialEq, Debug, Hash)]
    struct Conn(usize);
1493
    /// Fake address spec: local identifiers are `NonZeroU16`s and remote
    /// identifiers are `()`, so remote endpoints differ only by IP address.
    enum FakeAddrSpec {}

    impl SocketMapAddrSpec for FakeAddrSpec {
        type LocalIdentifier = NonZeroU16;
        type RemoteIdentifier = ();
    }
1500
1501 impl SocketMapStateSpec for FakeSpec {
1502 type AddrVecTag = SharingState;
1503
1504 type ListenerId = Listener;
1505 type ConnId = Conn;
1506
1507 type ListenerSharingState = SharingState;
1508 type ConnSharingState = SharingState;
1509
1510 type ListenerAddrState = Multiple<Listener>;
1511 type ConnAddrState = Multiple<Conn>;
1512
1513 fn listener_tag(_: ListenerAddrInfo, state: &Self::ListenerAddrState) -> Self::AddrVecTag {
1514 state.sharing_state
1515 }
1516
1517 fn connected_tag(_has_device: bool, state: &Self::ConnAddrState) -> Self::AddrVecTag {
1518 state.sharing_state
1519 }
1520 }
1521
    /// IPv4 bound-socket map built from the fake device, address and state
    /// specs.
    type FakeBoundSocketMap =
        BoundSocketMap<Ipv4, FakeWeakDeviceId<FakeDeviceId>, FakeAddrSpec, FakeSpec>;
1524
1525 #[derive(Default)]
1529 struct FakeSocketIdGen {
1530 next_id: usize,
1531 }
1532
1533 impl FakeSocketIdGen {
1534 fn next(&mut self) -> usize {
1535 let next_next_id = self.next_id + 1;
1536 core::mem::replace(&mut self.next_id, next_next_id)
1537 }
1538 }
1539
1540 impl<I: Eq> SocketMapAddrStateSpec for Multiple<I> {
1541 type Id = I;
1542 type SharingState = SharingState;
1543 type Inserter<'a>
1544 = &'a mut Vec<I>
1545 where
1546 I: 'a;
1547
1548 fn new(sharing_state: &SharingState, id: I) -> Self {
1549 Self { sharing_state: *sharing_state, entries: vec![id] }
1550 }
1551
1552 fn contains_id(&self, id: &Self::Id) -> bool {
1553 self.entries.contains(id)
1554 }
1555
1556 fn try_get_inserter<'a, 'b>(
1557 &'b mut self,
1558 new_sharing_state: &'a SharingState,
1559 ) -> Result<Self::Inserter<'b>, IncompatibleError> {
1560 (self.sharing_state == *new_sharing_state)
1561 .then_some(&mut self.entries)
1562 .ok_or(IncompatibleError)
1563 }
1564
1565 fn could_insert(&self, new_sharing_state: &SharingState) -> Result<(), IncompatibleError> {
1566 (self.sharing_state == *new_sharing_state).then_some(()).ok_or(IncompatibleError)
1567 }
1568
1569 fn remove_by_id(&mut self, id: I) -> RemoveResult {
1570 let index = self.entries.iter().position(|i| i == &id).expect("did not find id");
1571 let _: I = self.entries.swap_remove(index);
1572 if self.entries.is_empty() { RemoveResult::IsLast } else { RemoveResult::Success }
1573 }
1574 }
1575
1576 impl<A: Into<AddrVec<Ipv4, FakeWeakDeviceId<FakeDeviceId>, FakeAddrSpec>> + Clone>
1577 SocketMapConflictPolicy<A, SharingState, Ipv4, FakeWeakDeviceId<FakeDeviceId>, FakeAddrSpec>
1578 for FakeSpec
1579 {
1580 fn check_insert_conflicts(
1581 new_sharing_state: &SharingState,
1582 addr: &A,
1583 socketmap: &SocketMap<
1584 AddrVec<Ipv4, FakeWeakDeviceId<FakeDeviceId>, FakeAddrSpec>,
1585 Bound<FakeSpec>,
1586 >,
1587 ) -> Result<(), InsertError> {
1588 let dest: AddrVec<_, _, _> = addr.clone().into();
1589 if dest.iter_shadows().any(|a| {
1590 let entry = socketmap.get(&a);
1591 match entry {
1592 Some(Bound::Listen(Multiple { sharing_state, .. }))
1593 | Some(Bound::Conn(Multiple { sharing_state, .. })) => {
1594 !sharing_state.can_share_with(new_sharing_state)
1595 }
1596 None => false,
1597 }
1598 }) {
1599 return Err(InsertError::ShadowAddrExists);
1600 }
1601
1602 match socketmap.get(&dest) {
1603 Some(Bound::Listen(Multiple { sharing_state, .. }))
1604 | Some(Bound::Conn(Multiple { sharing_state, .. })) => {
1605 if sharing_state != new_sharing_state {
1608 return Err(InsertError::Exists);
1609 }
1610 }
1611 None => (),
1612 }
1613
1614 if socketmap
1615 .descendant_counts(&dest)
1616 .any(|(sharing_state, _count)| !sharing_state.can_share_with(new_sharing_state))
1617 {
1618 Err(InsertError::ShadowerExists)
1619 } else {
1620 Ok(())
1621 }
1622 }
1623 }
1624
1625 impl<I: Eq> SocketMapAddrStateUpdateSharingSpec for Multiple<I> {
1626 fn try_update_sharing(
1627 &mut self,
1628 id: Self::Id,
1629 new_sharing_state: &Self::SharingState,
1630 ) -> Result<(), IncompatibleError> {
1631 if self.sharing_state == *new_sharing_state {
1632 return Ok(());
1633 }
1634
1635 if self.entries.len() != 1 {
1640 return Err(IncompatibleError);
1641 }
1642 assert!(self.entries.contains(&id));
1643 self.sharing_state = *new_sharing_state;
1644 Ok(())
1645 }
1646 }
1647
1648 impl<A: Into<AddrVec<Ipv4, FakeWeakDeviceId<FakeDeviceId>, FakeAddrSpec>> + Clone>
1649 SocketMapUpdateSharingPolicy<
1650 A,
1651 SharingState,
1652 Ipv4,
1653 FakeWeakDeviceId<FakeDeviceId>,
1654 FakeAddrSpec,
1655 > for FakeSpec
1656 {
1657 fn allows_sharing_update(
1658 _socketmap: &SocketMap<
1659 AddrVec<Ipv4, FakeWeakDeviceId<FakeDeviceId>, FakeAddrSpec>,
1660 Bound<Self>,
1661 >,
1662 _addr: &A,
1663 _old_sharing: &SharingState,
1664 _new_sharing_state: &SharingState,
1665 ) -> Result<(), UpdateSharingError> {
1666 Ok(())
1667 }
1668 }
1669
1670 const LISTENER_ADDR: ListenerAddr<
1671 ListenerIpAddr<Ipv4Addr, NonZeroU16>,
1672 FakeWeakDeviceId<FakeDeviceId>,
1673 > = ListenerAddr {
1674 ip: ListenerIpAddr {
1675 addr: Some(unsafe { SocketIpAddr::new_unchecked(net_ip_v4!("1.2.3.4")) }),
1676 identifier: NonZeroU16::new(1).unwrap(),
1677 },
1678 device: None,
1679 };
1680
1681 const CONN_ADDR: ConnAddr<
1682 ConnIpAddr<Ipv4Addr, NonZeroU16, ()>,
1683 FakeWeakDeviceId<FakeDeviceId>,
1684 > = ConnAddr {
1685 ip: ConnIpAddr {
1686 local: (
1687 unsafe { SocketIpAddr::new_unchecked(net_ip_v4!("5.6.7.8")) },
1688 NonZeroU16::new(1).unwrap(),
1689 ),
1690 remote: unsafe { (SocketIpAddr::new_unchecked(net_ip_v4!("8.7.6.5")), ()) },
1691 },
1692 device: None,
1693 };
1694
1695 #[test]
1696 fn bound_insert_get_remove_listener() {
1697 set_logger_for_test();
1698 let mut bound = FakeBoundSocketMap::default();
1699 let mut fake_id_gen = FakeSocketIdGen::default();
1700
1701 let addr = LISTENER_ADDR;
1702
1703 let id = {
1704 let entry = bound
1705 .listeners_mut()
1706 .try_insert(addr, SharingState::exclusive('v'), Listener(fake_id_gen.next()))
1707 .unwrap();
1708 assert_eq!(entry.get_addr(), &addr);
1709 entry.id().clone()
1710 };
1711
1712 assert_eq!(
1713 bound.listeners().get_by_addr(&addr),
1714 Some(&Multiple::new_exclusive('v', vec![id]))
1715 );
1716
1717 assert_eq!(bound.listeners_mut().remove(&id, &addr), Ok(()));
1718 assert_eq!(bound.listeners().get_by_addr(&addr), None);
1719 }
1720
1721 #[test]
1722 fn bound_insert_get_remove_conn() {
1723 set_logger_for_test();
1724 let mut bound = FakeBoundSocketMap::default();
1725 let mut fake_id_gen = FakeSocketIdGen::default();
1726
1727 let addr = CONN_ADDR;
1728
1729 let id = {
1730 let entry = bound
1731 .conns_mut()
1732 .try_insert(addr, SharingState::exclusive('v'), Conn(fake_id_gen.next()))
1733 .unwrap();
1734 assert_eq!(entry.get_addr(), &addr);
1735 entry.id().clone()
1736 };
1737
1738 assert_eq!(bound.conns().get_by_addr(&addr), Some(&Multiple::new_exclusive('v', vec![id])));
1739
1740 assert_eq!(bound.conns_mut().remove(&id, &addr), Ok(()));
1741 assert_eq!(bound.conns().get_by_addr(&addr), None);
1742 }
1743
1744 #[test]
1745 fn bound_iter_addrs() {
1746 set_logger_for_test();
1747 let mut bound = FakeBoundSocketMap::default();
1748 let mut fake_id_gen = FakeSocketIdGen::default();
1749
1750 let listener_addrs = [
1751 (Some(net_ip_v4!("1.1.1.1")), 1),
1752 (Some(net_ip_v4!("2.2.2.2")), 2),
1753 (Some(net_ip_v4!("1.1.1.1")), 3),
1754 (None, 4),
1755 ]
1756 .map(|(ip, identifier)| ListenerAddr {
1757 device: None,
1758 ip: ListenerIpAddr {
1759 addr: ip.map(|x| SocketIpAddr::new(x).unwrap()),
1760 identifier: NonZeroU16::new(identifier).unwrap(),
1761 },
1762 });
1763 let conn_addrs = [
1764 (net_ip_v4!("3.3.3.3"), 3, net_ip_v4!("4.4.4.4")),
1765 (net_ip_v4!("4.4.4.4"), 3, net_ip_v4!("3.3.3.3")),
1766 ]
1767 .map(|(local_ip, local_identifier, remote_ip)| ConnAddr {
1768 ip: ConnIpAddr {
1769 local: (
1770 SocketIpAddr::new(local_ip).unwrap(),
1771 NonZeroU16::new(local_identifier).unwrap(),
1772 ),
1773 remote: (SocketIpAddr::new(remote_ip).unwrap(), ()),
1774 },
1775 device: None,
1776 });
1777
1778 for addr in listener_addrs.iter().cloned() {
1779 let _entry = bound
1780 .listeners_mut()
1781 .try_insert(addr, SharingState::exclusive('a'), Listener(fake_id_gen.next()))
1782 .unwrap();
1783 }
1784 for addr in conn_addrs.iter().cloned() {
1785 let _entry = bound
1786 .conns_mut()
1787 .try_insert(addr, SharingState::exclusive('a'), Conn(fake_id_gen.next()))
1788 .unwrap();
1789 }
1790 let expected_addrs = listener_addrs
1791 .into_iter()
1792 .map(Into::into)
1793 .chain(conn_addrs.into_iter().map(Into::into))
1794 .collect::<HashSet<_>>();
1795
1796 assert_eq!(expected_addrs, bound.iter_addrs().cloned().collect());
1797 }
1798
1799 #[test]
1800 fn try_insert_with_callback_not_called_on_error() {
1801 set_logger_for_test();
1804 let mut bound = FakeBoundSocketMap::default();
1805 let addr = LISTENER_ADDR;
1806
1807 let _: &Listener = bound
1809 .listeners_mut()
1810 .try_insert(addr, SharingState::exclusive('a'), Listener(0))
1811 .unwrap()
1812 .id();
1813
1814 fn is_never_called<A, B, T>(_: A, _: B) -> (T, ()) {
1818 panic!("should never be called");
1819 }
1820
1821 assert_matches!(
1822 bound.listeners_mut().try_insert_with(
1823 addr,
1824 SharingState::exclusive('b'),
1825 is_never_called
1826 ),
1827 Err(InsertError::Exists)
1828 );
1829 assert_matches!(
1830 bound.listeners_mut().try_insert_with(
1831 ListenerAddr { device: Some(FakeWeakDeviceId(FakeDeviceId)), ..addr },
1832 SharingState::exclusive('b'),
1833 is_never_called
1834 ),
1835 Err(InsertError::ShadowAddrExists)
1836 );
1837 assert_matches!(
1838 bound.conns_mut().try_insert_with(
1839 ConnAddr {
1840 device: None,
1841 ip: ConnIpAddr {
1842 local: (addr.ip.addr.unwrap(), addr.ip.identifier),
1843 remote: (SocketIpAddr::new(net_ip_v4!("1.1.1.1")).unwrap(), ()),
1844 },
1845 },
1846 SharingState::exclusive('b'),
1847 is_never_called,
1848 ),
1849 Err(InsertError::ShadowAddrExists)
1850 );
1851 }
1852
1853 #[test]
1854 fn insert_listener_conflict_with_listener() {
1855 set_logger_for_test();
1856 let mut bound = FakeBoundSocketMap::default();
1857 let mut fake_id_gen = FakeSocketIdGen::default();
1858 let addr = LISTENER_ADDR;
1859
1860 let _: &Listener = bound
1861 .listeners_mut()
1862 .try_insert(addr, SharingState::exclusive('a'), Listener(fake_id_gen.next()))
1863 .unwrap()
1864 .id();
1865 assert_matches!(
1866 bound.listeners_mut().try_insert(
1867 addr,
1868 SharingState::exclusive('b'),
1869 Listener(fake_id_gen.next())
1870 ),
1871 Err(InsertError::Exists)
1872 );
1873 }
1874
1875 #[test]
1876 fn insert_listener_conflict_with_shadower() {
1877 set_logger_for_test();
1878 let mut bound = FakeBoundSocketMap::default();
1879 let mut fake_id_gen = FakeSocketIdGen::default();
1880 let addr = LISTENER_ADDR;
1881 let shadows_addr = {
1882 assert_eq!(addr.device, None);
1883 ListenerAddr { device: Some(FakeWeakDeviceId(FakeDeviceId)), ..addr }
1884 };
1885
1886 let _: &Listener = bound
1887 .listeners_mut()
1888 .try_insert(addr, SharingState::exclusive('a'), Listener(fake_id_gen.next()))
1889 .unwrap()
1890 .id();
1891 assert_matches!(
1892 bound.listeners_mut().try_insert(
1893 shadows_addr,
1894 SharingState::exclusive('b'),
1895 Listener(fake_id_gen.next())
1896 ),
1897 Err(InsertError::ShadowAddrExists)
1898 );
1899 }
1900
1901 #[test]
1902 fn insert_conn_conflict_with_listener() {
1903 set_logger_for_test();
1904 let mut bound = FakeBoundSocketMap::default();
1905 let mut fake_id_gen = FakeSocketIdGen::default();
1906 let addr = LISTENER_ADDR;
1907 let shadows_addr = ConnAddr {
1908 device: None,
1909 ip: ConnIpAddr {
1910 local: (addr.ip.addr.unwrap(), addr.ip.identifier),
1911 remote: (SocketIpAddr::new(net_ip_v4!("1.1.1.1")).unwrap(), ()),
1912 },
1913 };
1914
1915 let _: &Listener = bound
1916 .listeners_mut()
1917 .try_insert(addr, SharingState::exclusive('a'), Listener(fake_id_gen.next()))
1918 .unwrap()
1919 .id();
1920 assert_matches!(
1921 bound.conns_mut().try_insert(
1922 shadows_addr,
1923 SharingState::exclusive('b'),
1924 Conn(fake_id_gen.next())
1925 ),
1926 Err(InsertError::ShadowAddrExists)
1927 );
1928 }
1929
1930 #[test]
1931 fn insert_and_remove_listener() {
1932 set_logger_for_test();
1933 let mut bound = FakeBoundSocketMap::default();
1934 let mut fake_id_gen = FakeSocketIdGen::default();
1935 let addr = LISTENER_ADDR;
1936
1937 let a = bound
1938 .listeners_mut()
1939 .try_insert(addr, SharingState::exclusive('x'), Listener(fake_id_gen.next()))
1940 .unwrap()
1941 .id()
1942 .clone();
1943 let b = bound
1944 .listeners_mut()
1945 .try_insert(addr, SharingState::exclusive('x'), Listener(fake_id_gen.next()))
1946 .unwrap()
1947 .id()
1948 .clone();
1949 assert_ne!(a, b);
1950
1951 assert_eq!(bound.listeners_mut().remove(&a, &addr), Ok(()));
1952 assert_eq!(
1953 bound.listeners().get_by_addr(&addr),
1954 Some(&Multiple::new_exclusive('x', vec![b]))
1955 );
1956 }
1957
1958 #[test]
1959 fn insert_and_remove_conn() {
1960 set_logger_for_test();
1961 let mut bound = FakeBoundSocketMap::default();
1962 let mut fake_id_gen = FakeSocketIdGen::default();
1963 let addr = CONN_ADDR;
1964
1965 let a = bound
1966 .conns_mut()
1967 .try_insert(addr, SharingState::exclusive('x'), Conn(fake_id_gen.next()))
1968 .unwrap()
1969 .id()
1970 .clone();
1971 let b = bound
1972 .conns_mut()
1973 .try_insert(addr, SharingState::exclusive('x'), Conn(fake_id_gen.next()))
1974 .unwrap()
1975 .id()
1976 .clone();
1977 assert_ne!(a, b);
1978
1979 assert_eq!(bound.conns_mut().remove(&a, &addr), Ok(()));
1980 assert_eq!(bound.conns().get_by_addr(&addr), Some(&Multiple::new_exclusive('x', vec![b])));
1981 }
1982
1983 #[test]
1984 fn update_listener_to_shadowed_addr_fails() {
1985 let mut bound = FakeBoundSocketMap::default();
1986 let mut fake_id_gen = FakeSocketIdGen::default();
1987
1988 let first_addr = LISTENER_ADDR;
1989 let second_addr = ListenerAddr {
1990 ip: ListenerIpAddr {
1991 addr: Some(SocketIpAddr::new(net_ip_v4!("1.1.1.1")).unwrap()),
1992 ..LISTENER_ADDR.ip
1993 },
1994 ..LISTENER_ADDR
1995 };
1996 let both_shadow = ListenerAddr {
1997 ip: ListenerIpAddr { addr: None, identifier: first_addr.ip.identifier },
1998 device: None,
1999 };
2000
2001 let first = bound
2002 .listeners_mut()
2003 .try_insert(first_addr, SharingState::exclusive('a'), Listener(fake_id_gen.next()))
2004 .unwrap()
2005 .id()
2006 .clone();
2007 let second = bound
2008 .listeners_mut()
2009 .try_insert(second_addr, SharingState::exclusive('b'), Listener(fake_id_gen.next()))
2010 .unwrap()
2011 .id()
2012 .clone();
2013
2014 let (ExistsError, entry) = bound
2017 .listeners_mut()
2018 .entry(&second, &second_addr)
2019 .unwrap()
2020 .try_update_addr(both_shadow)
2021 .expect_err("update should fail");
2022
2023 assert_eq!(entry.id(), &second);
2025 drop(entry);
2026
2027 let (ExistsError, entry) = bound
2028 .listeners_mut()
2029 .entry(&first, &first_addr)
2030 .unwrap()
2031 .try_update_addr(both_shadow)
2032 .expect_err("update should fail");
2033 assert_eq!(entry.get_addr(), &first_addr);
2034 }
2035
2036 #[test]
2037 fn nonexistent_conn_entry() {
2038 let mut map = FakeBoundSocketMap::default();
2039 let mut fake_id_gen = FakeSocketIdGen::default();
2040 let addr = CONN_ADDR;
2041 let conn_id = map
2042 .conns_mut()
2043 .try_insert(addr.clone(), SharingState::exclusive('a'), Conn(fake_id_gen.next()))
2044 .expect("failed to insert")
2045 .id()
2046 .clone();
2047 assert_matches!(map.conns_mut().remove(&conn_id, &addr), Ok(()));
2048
2049 assert!(map.conns_mut().entry(&conn_id, &addr).is_none());
2050 }
2051
2052 #[test]
2053 fn update_conn_sharing() {
2054 let mut map = FakeBoundSocketMap::default();
2055 let mut fake_id_gen = FakeSocketIdGen::default();
2056 let addr = CONN_ADDR;
2057 let mut entry = map
2058 .conns_mut()
2059 .try_insert(addr.clone(), SharingState::exclusive('a'), Conn(fake_id_gen.next()))
2060 .expect("failed to insert");
2061
2062 entry
2063 .try_update_sharing(&SharingState::exclusive('a'), SharingState::exclusive('d'))
2064 .expect("worked");
2065 let mut second_conn = map
2068 .conns_mut()
2069 .try_insert(addr.clone(), SharingState::exclusive('d'), Conn(fake_id_gen.next()))
2070 .expect("can insert");
2071 assert_matches!(
2072 second_conn
2073 .try_update_sharing(&SharingState::exclusive('d'), SharingState::exclusive('e')),
2074 Err(UpdateSharingError)
2075 );
2076 }
2077
2078 #[test]
2079 fn lookup_connected() {
2080 let mut map = FakeBoundSocketMap::default();
2081 let mut fake_id_gen = FakeSocketIdGen::default();
2082
2083 let sharing_state = SharingState::shared('a');
2084
2085 let device_id = FakeWeakDeviceId(FakeDeviceId);
2086 let entry1 = map
2087 .conns_mut()
2088 .try_insert(CONN_ADDR, sharing_state, Conn(fake_id_gen.next()))
2089 .expect("failed to insert")
2090 .id()
2091 .clone();
2092 let conn = map
2093 .lookup_connected(CONN_ADDR.ip.remote, CONN_ADDR.ip.local, device_id)
2094 .expect("lookup should succeed");
2095 assert!(conn.contains_id(&entry1));
2096
2097 let addr_with_device = ConnAddr { device: Some(device_id), ..CONN_ADDR };
2100 let entry2 = map
2101 .conns_mut()
2102 .try_insert(addr_with_device, sharing_state, Conn(fake_id_gen.next()))
2103 .expect("failed to insert")
2104 .id()
2105 .clone();
2106 let conn = map
2107 .lookup_connected(CONN_ADDR.ip.remote, CONN_ADDR.ip.local, device_id)
2108 .expect("lookup should succeed");
2109 assert!(conn.contains_id(&entry2));
2110 }
2111}