Skip to main content

netstack3_base/socket/
base.rs

// Copyright 2020 The Fuchsia Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
4
//! General-purpose socket utilities common to device-layer and IP-layer
//! sockets.
7
8use core::convert::Infallible as Never;
9use core::fmt::Debug;
10use core::hash::Hash;
11use core::marker::PhantomData;
12use core::num::NonZeroU16;
13
14use derivative::Derivative;
15use net_types::ip::{GenericOverIp, Ip, IpAddress, IpVersionMarker, Ipv4, Ipv6};
16use net_types::{
17    AddrAndZone, MulticastAddress, ScopeableAddress, SpecifiedAddr, Witness, ZonedAddr,
18};
19use thiserror::Error;
20
21use crate::LocalAddressError;
22use crate::data_structures::socketmap::{
23    Entry, IterShadows, OccupiedEntry as SocketMapOccupiedEntry, SocketMap, Tagged,
24};
25use crate::device::{
26    DeviceIdentifier, EitherDeviceId, StrongDeviceIdentifier, WeakDeviceIdentifier,
27};
28use crate::error::{ExistsError, NotFoundError, ZonedAddressError};
29use crate::ip::BroadcastIpExt;
30use crate::socket::address::{
31    AddrVecIter, ConnAddr, ConnIpAddr, ListenerAddr, ListenerIpAddr, SocketIpAddr,
32};
33
34/// A dual stack IP extention trait that provides the `OtherVersion` associated
35/// type.
36pub trait DualStackIpExt: Ip {
37    /// The "other" IP version, e.g. [`Ipv4`] for [`Ipv6`] and vice-versa.
38    type OtherVersion: DualStackIpExt<OtherVersion = Self>;
39}
40
41impl DualStackIpExt for Ipv4 {
42    type OtherVersion = Ipv6;
43}
44
45impl DualStackIpExt for Ipv6 {
46    type OtherVersion = Ipv4;
47}
48
49/// A tuple of values for `T` for both `I` and `I::OtherVersion`.
50pub struct DualStackTuple<I: DualStackIpExt, T: GenericOverIp<I> + GenericOverIp<I::OtherVersion>> {
51    this_stack: <T as GenericOverIp<I>>::Type,
52    other_stack: <T as GenericOverIp<I::OtherVersion>>::Type,
53    _marker: IpVersionMarker<I>,
54}
55
56impl<I: DualStackIpExt, T: GenericOverIp<I> + GenericOverIp<I::OtherVersion>> DualStackTuple<I, T> {
57    /// Creates a new tuple with `this_stack` and `other_stack` values.
58    pub fn new(this_stack: T, other_stack: <T as GenericOverIp<I::OtherVersion>>::Type) -> Self
59    where
60        T: GenericOverIp<I, Type = T>,
61    {
62        Self { this_stack, other_stack, _marker: IpVersionMarker::new() }
63    }
64
65    /// Retrieves `(this_stack, other_stack)` from the tuple.
66    pub fn into_inner(
67        self,
68    ) -> (<T as GenericOverIp<I>>::Type, <T as GenericOverIp<I::OtherVersion>>::Type) {
69        let Self { this_stack, other_stack, _marker } = self;
70        (this_stack, other_stack)
71    }
72
73    /// Retrieves `this_stack` from the tuple.
74    pub fn into_this_stack(self) -> <T as GenericOverIp<I>>::Type {
75        self.this_stack
76    }
77
78    /// Borrows `this_stack` from the tuple.
79    pub fn this_stack(&self) -> &<T as GenericOverIp<I>>::Type {
80        &self.this_stack
81    }
82
83    /// Retrieves `other_stack` from the tuple.
84    pub fn into_other_stack(self) -> <T as GenericOverIp<I::OtherVersion>>::Type {
85        self.other_stack
86    }
87
88    /// Borrows `other_stack` from the tuple.
89    pub fn other_stack(&self) -> &<T as GenericOverIp<I::OtherVersion>>::Type {
90        &self.other_stack
91    }
92
93    /// Flips the types, making `this_stack` `other_stack` and vice-versa.
94    pub fn flip(self) -> DualStackTuple<I::OtherVersion, T> {
95        let Self { this_stack, other_stack, _marker } = self;
96        DualStackTuple {
97            this_stack: other_stack,
98            other_stack: this_stack,
99            _marker: IpVersionMarker::new(),
100        }
101    }
102
103    /// Casts to IP version `X`.
104    ///
105    /// Given `DualStackTuple` contains complete information for both IP
106    /// versions, it can be easily cast into an arbitrary `X` IP version.
107    ///
108    /// This can be used to tie together type parameters when dealing with dual
109    /// stack sockets. For example, a `DualStackTuple` defined for `SockI` can
110    /// be cast to any `WireI`.
111    pub fn cast<X>(self) -> DualStackTuple<X, T>
112    where
113        X: DualStackIpExt,
114        T: GenericOverIp<X>
115            + GenericOverIp<X::OtherVersion>
116            + GenericOverIp<Ipv4>
117            + GenericOverIp<Ipv6>,
118    {
119        I::map_ip_in(
120            self,
121            |v4| X::map_ip_out(v4, |t| t, |t| t.flip()),
122            |v6| X::map_ip_out(v6, |t| t.flip(), |t| t),
123        )
124    }
125}
126
127impl<
128    I: DualStackIpExt,
129    NewIp: DualStackIpExt,
130    T: GenericOverIp<NewIp>
131        + GenericOverIp<NewIp::OtherVersion>
132        + GenericOverIp<I>
133        + GenericOverIp<I::OtherVersion>,
134> GenericOverIp<NewIp> for DualStackTuple<I, T>
135{
136    type Type = DualStackTuple<NewIp, T>;
137}
138
139/// Extension trait for `Ip` providing socket-specific functionality.
140pub trait SocketIpExt: Ip {
141    /// `Self::LOOPBACK_ADDRESS`, but wrapped in the `SocketIpAddr` type.
142    const LOOPBACK_ADDRESS_AS_SOCKET_IP_ADDR: SocketIpAddr<Self::Addr> = unsafe {
143        // SAFETY: The loopback address is a valid SocketIpAddr, as verified
144        // in the `loopback_addr_is_valid_socket_addr` test.
145        SocketIpAddr::new_from_specified_unchecked(Self::LOOPBACK_ADDRESS)
146    };
147}
148
149impl<I: Ip> SocketIpExt for I {}
150
#[cfg(test)]
mod socket_ip_ext_test {
    use super::*;
    use ip_test_macro::ip_test;

    #[ip_test(I)]
    fn loopback_addr_is_valid_socket_addr<I: SocketIpExt>() {
        // `LOOPBACK_ADDRESS_AS_SOCKET_IP_ADDR` is defined with the "unchecked"
        // constructor (which supports const construction). Verify here that
        // the address actually passes the checked constructor, which protects
        // against changes made far away from this definition.
        let _addr = SocketIpAddr::new(I::LOOPBACK_ADDRESS_AS_SOCKET_IP_ADDR.addr())
            .expect("loopback address should be a valid SocketIpAddr");
    }
}
166
/// State belonging to either IP stack.
///
/// Like `either::Either`, but with more helpful variant names.
///
/// Note that this type is not optimally type-safe, because `T` and `O` are not
/// bound to an IP version `I` and `I::OtherVersion`, respectively. In many
/// cases it may be more appropriate to define a one-off enum parameterized
/// over `I: Ip`.
#[derive(Debug, PartialEq, Eq)]
pub enum EitherStack<T, O> {
    /// In the current stack version.
    ThisStack(T),
    /// In the other version of the stack.
    OtherStack(O),
}

// `Clone` is written by hand rather than derived so that the `instrumented`
// feature can attach `#[track_caller]` to `clone`. The bounds match what
// `#[derive(Clone)]` would generate.
impl<T, O> Clone for EitherStack<T, O>
where
    T: Clone,
    O: Clone,
{
    #[cfg_attr(feature = "instrumented", track_caller)]
    fn clone(&self) -> Self {
        match self {
            Self::ThisStack(t) => Self::ThisStack(t.clone()),
            Self::OtherStack(o) => Self::OtherStack(o.clone()),
        }
    }
}
195
/// Control flow type containing either a dual-stack or non-dual-stack context.
///
/// This type exists to provide nice names to the result of
/// [`BoundStateContext::dual_stack_context`], and to allow generic code to
/// match on when checking whether a socket protocol and IP version support
/// dual-stack operation. If dual-stack operation is supported, a
/// [`MaybeDualStack::DualStack`] value will be held, otherwise a
/// [`MaybeDualStack::NotDualStack`] value.
///
/// Note that the type parameters do not have trait bounds; those are provided
/// by the trait with the `dual_stack_context` function.
///
/// In monomorphized code, this type frequently has exactly one type parameter
/// that is uninstantiable (it contains an instance of
/// [`core::convert::Infallible`] or some other empty enum, or a reference to
/// the same)! That lets the compiler optimize it out completely, creating no
/// actual runtime overhead.
#[derive(Debug)]
pub enum MaybeDualStack<DS, NDS> {
    /// Dual-stack operation is supported; holds the dual-stack context.
    DualStack(DS),
    /// Dual-stack operation is not supported; holds the non-dual-stack
    /// context.
    NotDualStack(NDS),
}
219
220// Implement `GenericOverIp` for a `MaybeDualStack` whose `DS` and `NDS` also
221// implement `GenericOverIp`.
222impl<I: DualStackIpExt, DS: GenericOverIp<I>, NDS: GenericOverIp<I>> GenericOverIp<I>
223    for MaybeDualStack<DS, NDS>
224{
225    type Type = MaybeDualStack<<DS as GenericOverIp<I>>::Type, <NDS as GenericOverIp<I>>::Type>;
226}
227
228/// An error encountered while enabling or disabling dual-stack operation.
229#[derive(Copy, Clone, Debug, Eq, GenericOverIp, PartialEq, Error)]
230#[generic_over_ip()]
231pub enum SetDualStackEnabledError {
232    /// A socket can only have dual stack enabled or disabled while unbound.
233    #[error("a socket can only have dual stack enabled or disabled while unbound")]
234    SocketIsBound,
235    /// The socket's protocol is not dual stack capable.
236    #[error(transparent)]
237    NotCapable(#[from] NotDualStackCapableError),
238}
239
240/// An error encountered when attempting to perform dual stack operations on
241/// socket with a non dual stack capable protocol.
242#[derive(Copy, Clone, Debug, Eq, GenericOverIp, PartialEq, Error)]
243#[generic_over_ip()]
244#[error("socket's protocol is not dual-stack capable")]
245pub struct NotDualStackCapableError;
246
/// The direction(s) of a socket's data path that have been shut down.
#[derive(Copy, Clone, Debug, Default, Eq, PartialEq)]
pub struct Shutdown {
    /// Whether the send path is shut down. When `true`, the owning socket
    /// should not be able to send packets.
    pub send: bool,
    /// Whether the receive path is shut down. When `true`, the owning socket
    /// should not be able to receive packets.
    pub receive: bool,
}
259
260/// Which direction(s) to shut down for a socket.
261#[derive(Copy, Clone, Debug, Eq, GenericOverIp, PartialEq)]
262#[generic_over_ip()]
263pub enum ShutdownType {
264    /// Prevent sending packets on the socket.
265    Send,
266    /// Prevent receiving packets on the socket.
267    Receive,
268    /// Prevent sending and receiving packets on the socket.
269    SendAndReceive,
270}
271
272impl ShutdownType {
273    /// Returns a tuple of booleans for `(shutdown_send, shutdown_receive)`.
274    pub fn to_send_receive(&self) -> (bool, bool) {
275        match self {
276            Self::Send => (true, false),
277            Self::Receive => (false, true),
278            Self::SendAndReceive => (true, true),
279        }
280    }
281
282    /// Creates a [`ShutdownType`] from a pair of bools for send and receive.
283    pub fn from_send_receive(send: bool, receive: bool) -> Option<Self> {
284        match (send, receive) {
285            (true, false) => Some(Self::Send),
286            (false, true) => Some(Self::Receive),
287            (true, true) => Some(Self::SendAndReceive),
288            (false, false) => None,
289        }
290    }
291}
292
293/// Extensions to IP Address witnesses useful in the context of sockets.
294pub trait SocketIpAddrExt<A: IpAddress>: Witness<A> + ScopeableAddress {
295    /// Determines whether the provided address is underspecified by itself.
296    ///
297    /// Some addresses are ambiguous and so must have a zone identifier in order
298    /// to be used in a socket address. This function returns true for IPv6
299    /// link-local addresses and false for all others.
300    fn must_have_zone(&self) -> bool
301    where
302        Self: Copy,
303    {
304        self.try_into_null_zoned().is_some()
305    }
306
307    /// Converts into a [`AddrAndZone<A, ()>`] if the address requires a zone.
308    ///
309    /// Otherwise returns `None`.
310    fn try_into_null_zoned(self) -> Option<AddrAndZone<Self, ()>> {
311        if self.get().is_loopback() {
312            return None;
313        }
314        AddrAndZone::new(self, ())
315    }
316}
317
318impl<A: IpAddress, W: Witness<A> + ScopeableAddress> SocketIpAddrExt<A> for W {}
319
320/// An extention trait for [`ZonedAddr`].
321pub trait SocketZonedAddrExt<W, A, D> {
322    /// Returns the address and device that should be used for a socket.
323    ///
324    /// Given an address for a socket and an optional device that the socket is
325    /// already bound on, returns the address and device that should be used
326    /// for the socket. If `addr` and `device` require inconsistent devices,
327    /// or if `addr` requires a zone but there is none specified (by `addr` or
328    /// `device`), an error is returned.
329    fn resolve_addr_with_device(
330        self,
331        device: Option<D::Weak>,
332    ) -> Result<(W, Option<EitherDeviceId<D, D::Weak>>), ZonedAddressError>
333    where
334        D: StrongDeviceIdentifier;
335}
336
337impl<W, A, D> SocketZonedAddrExt<W, A, D> for ZonedAddr<W, D>
338where
339    W: ScopeableAddress + AsRef<SpecifiedAddr<A>>,
340    A: IpAddress,
341{
342    fn resolve_addr_with_device(
343        self,
344        device: Option<D::Weak>,
345    ) -> Result<(W, Option<EitherDeviceId<D, D::Weak>>), ZonedAddressError>
346    where
347        D: StrongDeviceIdentifier,
348    {
349        let (addr, zone) = self.into_addr_zone();
350        let device = match (zone, device) {
351            (Some(zone), Some(device)) => {
352                if device != zone {
353                    return Err(ZonedAddressError::DeviceZoneMismatch);
354                }
355                Some(EitherDeviceId::Strong(zone))
356            }
357            (Some(zone), None) => Some(EitherDeviceId::Strong(zone)),
358            (None, Some(device)) => Some(EitherDeviceId::Weak(device)),
359            (None, None) => {
360                if addr.as_ref().must_have_zone() {
361                    return Err(ZonedAddressError::RequiredZoneNotProvided);
362                } else {
363                    None
364                }
365            }
366        };
367        Ok((addr, device))
368    }
369}
370
371/// A helper type to verify if applying socket updates is allowed for a given
372/// current state.
373///
374/// The fields in `SocketDeviceUpdate` define the current state,
375/// [`SocketDeviceUpdate::try_update`] applies the verification logic.
376pub struct SocketDeviceUpdate<'a, A: IpAddress, D: WeakDeviceIdentifier> {
377    /// The current local IP address.
378    pub local_ip: Option<&'a SpecifiedAddr<A>>,
379    /// The current remote IP address.
380    pub remote_ip: Option<&'a SpecifiedAddr<A>>,
381    /// The currently bound device.
382    pub old_device: Option<&'a D>,
383}
384
385impl<'a, A: IpAddress, D: WeakDeviceIdentifier> SocketDeviceUpdate<'a, A, D> {
386    /// Checks if an update from `old_device` to `new_device` is allowed,
387    /// returning an error if not.
388    pub fn check_update<N>(
389        self,
390        new_device: Option<&N>,
391    ) -> Result<(), SocketDeviceUpdateNotAllowedError>
392    where
393        D: PartialEq<N>,
394    {
395        let Self { local_ip, remote_ip, old_device } = self;
396        let must_have_zone = local_ip.is_some_and(|a| a.must_have_zone())
397            || remote_ip.is_some_and(|a| a.must_have_zone());
398
399        if !must_have_zone {
400            return Ok(());
401        }
402
403        let old_device = old_device.unwrap_or_else(|| {
404            panic!("local_ip={:?} or remote_ip={:?} must have zone", local_ip, remote_ip)
405        });
406
407        if new_device.is_some_and(|new_device| old_device == new_device) {
408            Ok(())
409        } else {
410            Err(SocketDeviceUpdateNotAllowedError)
411        }
412    }
413}
414
415/// The device can't be updated on a socket.
416pub struct SocketDeviceUpdateNotAllowedError;
417
/// Bundles the local and remote identifier types that make up a protocol's
/// [`AddrVec`] keys.
///
/// This saves every user of those keys from naming both identifier types
/// separately.
pub trait SocketMapAddrSpec {
    /// The local identifier portion of a socket address. It must be
    /// convertible into a nonzero 16-bit value.
    type LocalIdentifier: Copy + Clone + Debug + Send + Sync + Hash + Eq + Into<NonZeroU16>;
    /// The remote identifier portion of a socket address.
    type RemoteIdentifier: Copy + Clone + Debug + Send + Sync + Hash + Eq;
}
428
429/// Information about the address in a [`ListenerAddr`].
430pub struct ListenerAddrInfo {
431    /// Whether the address has a device bound.
432    pub has_device: bool,
433    /// Whether the listener is on a specified address (as opposed to a blanket
434    /// listener).
435    pub specified_addr: bool,
436}
437
438impl<A: IpAddress, D: DeviceIdentifier, LI> ListenerAddr<ListenerIpAddr<A, LI>, D> {
439    pub(crate) fn info(&self) -> ListenerAddrInfo {
440        let Self { device, ip: ListenerIpAddr { addr, identifier: _ } } = self;
441        ListenerAddrInfo { has_device: device.is_some(), specified_addr: addr.is_some() }
442    }
443}
444
445/// Specifies the types parameters for [`BoundSocketMap`] state as a single bundle.
446pub trait SocketMapStateSpec {
447    /// The tag value of a socket address vector entry.
448    ///
449    /// These values are derived from [`Self::ListenerAddrState`] and
450    /// [`Self::ConnAddrState`].
451    type AddrVecTag: Eq + Copy + Debug + 'static;
452
453    /// Returns a the tag for a listener in the socket map.
454    fn listener_tag(info: ListenerAddrInfo, state: &Self::ListenerAddrState) -> Self::AddrVecTag;
455
456    /// Returns a the tag for a connected socket in the socket map.
457    fn connected_tag(has_device: bool, state: &Self::ConnAddrState) -> Self::AddrVecTag;
458
459    /// An identifier for a listening socket.
460    type ListenerId: Clone + Debug;
461    /// An identifier for a connected socket.
462    type ConnId: Clone + Debug;
463
464    /// The state stored for a listening socket that is used to determine
465    /// whether sockets can share an address.
466    type ListenerSharingState: Clone + Debug;
467
468    /// The state stored for a connected socket that is used to determine
469    /// whether sockets can share an address.
470    type ConnSharingState: Clone + Debug;
471
472    /// The state stored for a listener socket address.
473    type ListenerAddrState: SocketMapAddrStateSpec<Id = Self::ListenerId, SharingState = Self::ListenerSharingState>
474        + Debug;
475
476    /// The state stored for a connected socket address.
477    type ConnAddrState: SocketMapAddrStateSpec<Id = Self::ConnId, SharingState = Self::ConnSharingState>
478        + Debug;
479}
480
/// Signals, from a [`SocketMapAddrStateSpec`] implementation, that a requested
/// change to a socket map entry is incompatible with its current contents.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub struct IncompatibleError;
485
/// A one-shot handle for inserting an item into a [`SocketMap`] entry.
pub trait Inserter<T> {
    /// Inserts `item`.
    ///
    /// Taking `self` by value means each inserter can perform at most one
    /// insertion.
    fn insert(self, item: T);
}

// A mutable reference to any extendable collection is an inserter: the item
// is added by extending the collection with exactly one element.
impl<'a, T, E: Extend<T>> Inserter<T> for &'a mut E {
    fn insert(self, item: T) {
        self.extend(core::iter::once(item))
    }
}

// `Never` is uninhabited, so this impl can never actually run. It exists so
// that code whose inserter type is `Never` still satisfies `Inserter`.
impl<T> Inserter<T> for Never {
    fn insert(self, _item: T) {
        match self {}
    }
}
506
507/// Describes an entry in a [`SocketMap`] for a listener or connection address.
508pub trait SocketMapAddrStateSpec {
509    /// The type of ID that can be present at the address.
510    type Id;
511
512    /// The sharing state for the address.
513    ///
514    /// This can be used to determine whether a socket can be inserted at the
515    /// address. Every socket has its own sharing state associated with it,
516    /// though the sharing state is not necessarily stored in the address
517    /// entry.
518    type SharingState;
519
520    /// The type of inserter returned by [`SocketMapAddrStateSpec::try_get_inserter`].
521    type Inserter<'a>: Inserter<Self::Id> + 'a
522    where
523        Self: 'a,
524        Self::Id: 'a;
525
526    /// Creates a new `Self` holding the provided socket with the given new
527    /// sharing state at the specified address.
528    fn new(new_sharing_state: &Self::SharingState, id: Self::Id) -> Self;
529
530    /// Looks up the ID in self, returning `true` if it is present.
531    fn contains_id(&self, id: &Self::Id) -> bool;
532
533    /// Enables insertion in `self` for a new socket with the provided sharing
534    /// state.
535    ///
536    /// If the new state is incompatible with the existing socket(s),
537    /// implementations of this function should return `Err(IncompatibleError)`.
538    /// If `Ok(x)` is returned, calling `x.insert(y)` will insert `y` into
539    /// `self`.
540    fn try_get_inserter<'a, 'b>(
541        &'b mut self,
542        new_sharing_state: &'a Self::SharingState,
543    ) -> Result<Self::Inserter<'b>, IncompatibleError>;
544
545    /// Returns `Ok` if an entry with the given sharing state could be added
546    /// to `self`.
547    ///
548    /// If this returns `Ok`, `try_get_dest` should succeed.
549    fn could_insert(&self, new_sharing_state: &Self::SharingState)
550    -> Result<(), IncompatibleError>;
551
552    /// Removes the given socket from the existing state.
553    ///
554    /// Implementations should assume that `id` is contained in `self`.
555    fn remove_by_id(&mut self, id: Self::Id) -> RemoveResult;
556}
557
558/// Provides behavior on updating the sharing state of a [`SocketMap`] entry.
559pub trait SocketMapAddrStateUpdateSharingSpec: SocketMapAddrStateSpec {
560    /// Attempts to update the sharing state of the address state with id `id`
561    /// to `new_sharing_state`.
562    fn try_update_sharing(
563        &mut self,
564        id: Self::Id,
565        new_sharing_state: &Self::SharingState,
566    ) -> Result<(), IncompatibleError>;
567}
568
569/// Provides conflict detection for a [`SocketMapStateSpec`].
570pub trait SocketMapConflictPolicy<
571    Addr,
572    SharingState,
573    I: Ip,
574    D: DeviceIdentifier,
575    A: SocketMapAddrSpec,
576>: SocketMapStateSpec
577{
578    /// Checks whether a new socket with the provided state can be inserted at
579    /// the given address in the existing socket map, returning an error
580    /// otherwise.
581    ///
582    /// Implementations of this function should check for any potential
583    /// conflicts that would arise when inserting a socket with state
584    /// `new_sharing_state` into a new or existing entry at `addr` in
585    /// `socketmap`.
586    fn check_insert_conflicts(
587        new_sharing_state: &SharingState,
588        addr: &Addr,
589        socketmap: &SocketMap<AddrVec<I, D, A>, Bound<Self>>,
590    ) -> Result<(), InsertError>;
591}
592
593/// Defines the policy for updating the sharing state of entries in the
594/// [`SocketMap`].
595pub trait SocketMapUpdateSharingPolicy<Addr, SharingState, I: Ip, D: DeviceIdentifier, A>:
596    SocketMapConflictPolicy<Addr, SharingState, I, D, A>
597where
598    A: SocketMapAddrSpec,
599{
600    /// Returns whether the entry `addr` in `socketmap` allows the sharing state
601    /// to transition from `old_sharing` to `new_sharing`.
602    fn allows_sharing_update(
603        socketmap: &SocketMap<AddrVec<I, D, A>, Bound<Self>>,
604        addr: &Addr,
605        old_sharing: &SharingState,
606        new_sharing: &SharingState,
607    ) -> Result<(), UpdateSharingError>;
608}
609
610/// A bound socket state that is either a listener or a connection.
611#[derive(Derivative)]
612#[derivative(Debug(bound = "S::ListenerAddrState: Debug, S::ConnAddrState: Debug"))]
613#[allow(missing_docs)]
614pub enum Bound<S: SocketMapStateSpec + ?Sized> {
615    Listen(S::ListenerAddrState),
616    Conn(S::ConnAddrState),
617}
618
619/// An "address vector" type that can hold any address in a [`SocketMap`].
620///
621/// This is a "vector" in the mathematical sense, in that it denotes an address
622/// in a space. Here, the space is the possible addresses to which a socket
623/// receiving IP packets can be bound.
624///
625/// `AddrVec`s are used as keys for the `SocketMap` type. Since an incoming
626/// packet can match more than one address, for each incoming packet there is a
627/// set of possible `AddrVec` keys whose entries (sockets) in a `SocketMap`
628/// might receive the packet.
629///
630/// This set of keys can be ordered by precedence as described in the
631/// documentation for [`AddrVecIter`]. Calling [`IterShadows::iter_shadows`] on
632/// an instance will produce the sequence of addresses it has precedence over.
633#[derive(Derivative)]
634#[derivative(
635    Debug(bound = "D: Debug"),
636    Clone(bound = "D: Clone"),
637    Eq(bound = "D: Eq"),
638    PartialEq(bound = "D: PartialEq"),
639    Hash(bound = "D: Hash")
640)]
641#[allow(missing_docs)]
642pub enum AddrVec<I: Ip, D, A: SocketMapAddrSpec + ?Sized> {
643    Listen(ListenerAddr<ListenerIpAddr<I::Addr, A::LocalIdentifier>, D>),
644    Conn(ConnAddr<ConnIpAddr<I::Addr, A::LocalIdentifier, A::RemoteIdentifier>, D>),
645}
646
647impl<I: Ip, D: DeviceIdentifier, A: SocketMapAddrSpec, S: SocketMapStateSpec + ?Sized>
648    Tagged<AddrVec<I, D, A>> for Bound<S>
649{
650    type Tag = S::AddrVecTag;
651    fn tag(&self, address: &AddrVec<I, D, A>) -> Self::Tag {
652        match (self, address) {
653            (Bound::Listen(l), AddrVec::Listen(addr)) => S::listener_tag(addr.info(), l),
654            (Bound::Conn(c), AddrVec::Conn(ConnAddr { device, ip: _ })) => {
655                S::connected_tag(device.is_some(), c)
656            }
657            (Bound::Listen(_), AddrVec::Conn(_)) => {
658                unreachable!("found listen state for conn addr")
659            }
660            (Bound::Conn(_), AddrVec::Listen(_)) => {
661                unreachable!("found conn state for listen addr")
662            }
663        }
664    }
665}
666
667impl<I: Ip, D: DeviceIdentifier, A: SocketMapAddrSpec> IterShadows for AddrVec<I, D, A> {
668    type IterShadows = AddrVecIter<I, D, A>;
669
670    fn iter_shadows(&self) -> Self::IterShadows {
671        let (socket_ip_addr, device) = match self.clone() {
672            AddrVec::Conn(ConnAddr { ip, device }) => (ip.into(), device),
673            AddrVec::Listen(ListenerAddr { ip, device }) => (ip.into(), device),
674        };
675        let mut iter = match device {
676            Some(device) => AddrVecIter::with_device(socket_ip_addr, device),
677            None => AddrVecIter::without_device(socket_ip_addr),
678        };
679        // Skip the first element, which is always `*self`.
680        assert_eq!(iter.next().as_ref(), Some(self));
681        iter
682    }
683}
684
685/// How a socket is bound on the system.
686#[derive(Copy, Clone, Debug, Eq, Hash, PartialEq)]
687#[allow(missing_docs)]
688pub enum SocketAddrType {
689    AnyListener,
690    SpecificListener,
691    Connected,
692}
693
694impl<'a, A: IpAddress, LI> From<&'a ListenerIpAddr<A, LI>> for SocketAddrType {
695    fn from(ListenerIpAddr { addr, identifier: _ }: &'a ListenerIpAddr<A, LI>) -> Self {
696        match addr {
697            Some(_) => SocketAddrType::SpecificListener,
698            None => SocketAddrType::AnyListener,
699        }
700    }
701}
702
703impl<'a, A: IpAddress, LI, RI> From<&'a ConnIpAddr<A, LI, RI>> for SocketAddrType {
704    fn from(_: &'a ConnIpAddr<A, LI, RI>) -> Self {
705        SocketAddrType::Connected
706    }
707}
708
/// Outcome of removing one socket from a collection of sockets.
pub enum RemoveResult {
    /// The socket was removed and others remain in the collection.
    Success,
    /// The socket was the last one in the collection, so the caller should
    /// remove the entire collection.
    IsLast,
}
717
718#[derive(Derivative)]
719#[derivative(Clone(bound = "S::ListenerId: Clone, S::ConnId: Clone"), Debug(bound = ""))]
720pub enum SocketId<S: SocketMapStateSpec> {
721    Listener(S::ListenerId),
722    Connection(S::ConnId),
723}
724
725/// A map from socket addresses to sockets.
726///
727/// The types of keys and IDs is determined by the [`SocketMapStateSpec`]
728/// parameter. Each listener and connected socket stores additional state.
729/// Listener and connected sockets are keyed independently, but share the same
730/// address vector space. Conflicts are detected on attempted insertion of new
731/// sockets.
732///
733/// Listener addresses map to listener-address-specific state, and likewise
734/// with connected addresses. Depending on protocol (determined by the
735/// `SocketMapStateSpec` protocol), these address states can hold one or more
736/// socket identifiers (e.g. UDP sockets with `SO_REUSEPORT` set can share an
737/// address).
738#[derive(Derivative)]
739#[derivative(Default(bound = ""))]
740pub struct BoundSocketMap<I: Ip, D: DeviceIdentifier, A: SocketMapAddrSpec, S: SocketMapStateSpec> {
741    addr_to_state: SocketMap<AddrVec<I, D, A>, Bound<S>>,
742}
743
744impl<I: Ip, D: DeviceIdentifier, A: SocketMapAddrSpec, S: SocketMapStateSpec>
745    BoundSocketMap<I, D, A, S>
746{
747    /// Returns the number of entries in the map.
748    pub fn len(&self) -> usize {
749        self.addr_to_state.len()
750    }
751}
752
/// Tag type, never constructed, that selects listening sockets.
pub enum Listener {}
/// Tag type, never constructed, that selects connected sockets.
pub enum Connection {}

/// A view over one kind of socket in a [`BoundSocketMap`].
///
/// The first field is the borrowed address-to-state map. `SocketType` is a
/// zero-sized tag that selects which kind of socket the view operates on.
pub struct Sockets<AddrToStateMap, SocketType>(AddrToStateMap, PhantomData<SocketType>);
760
761impl<
762    'a,
763    I: Ip,
764    D: DeviceIdentifier,
765    SocketType: ConvertSocketMapState<I, D, A, S>,
766    A: SocketMapAddrSpec,
767    S: SocketMapStateSpec,
768> Sockets<&'a SocketMap<AddrVec<I, D, A>, Bound<S>>, SocketType>
769where
770    S: SocketMapConflictPolicy<SocketType::Addr, SocketType::SharingState, I, D, A>,
771{
772    /// Returns the state at an address, if there is any.
773    pub fn get_by_addr(self, addr: &SocketType::Addr) -> Option<&'a SocketType::AddrState> {
774        let Self(addr_to_state, _marker) = self;
775        addr_to_state.get(&SocketType::to_addr_vec(addr)).map(|state| {
776            SocketType::from_bound_ref(state)
777                .unwrap_or_else(|| unreachable!("found {:?} for address {:?}", state, addr))
778        })
779    }
780
781    /// Returns `Ok(())` if a socket could be inserted, otherwise an error.
782    ///
783    /// Goes through a dry run of inserting a socket at the given address and
784    /// with the given sharing state, returning `Ok(())` if the insertion would
785    /// succeed, otherwise the error that would be returned.
786    pub fn could_insert(
787        self,
788        addr: &SocketType::Addr,
789        sharing: &SocketType::SharingState,
790    ) -> Result<(), InsertError> {
791        let Self(addr_to_state, _) = self;
792        match self.get_by_addr(addr) {
793            Some(state) => {
794                state.could_insert(sharing).map_err(|IncompatibleError| InsertError::Exists)
795            }
796            None => S::check_insert_conflicts(&sharing, &addr, &addr_to_state),
797        }
798    }
799}
800
801/// A borrowed state entry in a [`SocketMap`].
802#[derive(Derivative)]
803#[derivative(Debug(bound = ""))]
804pub struct SocketStateEntry<
805    'a,
806    I: Ip,
807    D: DeviceIdentifier,
808    A: SocketMapAddrSpec,
809    S: SocketMapStateSpec,
810    SocketType,
811> {
812    id: SocketId<S>,
813    addr_entry: SocketMapOccupiedEntry<'a, AddrVec<I, D, A>, Bound<S>>,
814    _marker: PhantomData<SocketType>,
815}
816
817impl<
818    'a,
819    I: Ip,
820    D: DeviceIdentifier,
821    SocketType: ConvertSocketMapState<I, D, A, S>,
822    A: SocketMapAddrSpec,
823    S: SocketMapStateSpec
824        + SocketMapConflictPolicy<SocketType::Addr, SocketType::SharingState, I, D, A>,
825> Sockets<&'a mut SocketMap<AddrVec<I, D, A>, Bound<S>>, SocketType>
826where
827    SocketType::SharingState: Clone,
828    SocketType::Id: Clone,
829{
830    /// Attempts to insert a new entry into the [`SocketMap`] backing this
831    /// `Sockets`.
832    pub fn try_insert(
833        self,
834        socket_addr: SocketType::Addr,
835        tag_state: SocketType::SharingState,
836        id: SocketType::Id,
837    ) -> Result<SocketStateEntry<'a, I, D, A, S, SocketType>, InsertError> {
838        self.try_insert_with(socket_addr, tag_state, |_addr, _sharing| (id, ()))
839            .map(|(entry, ())| entry)
840    }
841
842    /// Like [`Sockets::try_insert`] but calls `make_id` to create a socket ID
843    /// before inserting into the map.
844    ///
845    /// `make_id` returns type `R` that is returned to the caller on success.
846    pub fn try_insert_with<R>(
847        self,
848        socket_addr: SocketType::Addr,
849        tag_state: SocketType::SharingState,
850        make_id: impl FnOnce(SocketType::Addr, SocketType::SharingState) -> (SocketType::Id, R),
851    ) -> Result<(SocketStateEntry<'a, I, D, A, S, SocketType>, R), InsertError> {
852        let Self(addr_to_state, _) = self;
853        S::check_insert_conflicts(&tag_state, &socket_addr, &addr_to_state)?;
854
855        let addr = SocketType::to_addr_vec(&socket_addr);
856
857        match addr_to_state.entry(addr) {
858            Entry::Occupied(mut o) => {
859                let (id, ret) = o.map_mut(|bound| {
860                    let bound = match SocketType::from_bound_mut(bound) {
861                        Some(bound) => bound,
862                        None => unreachable!("found {:?} for address {:?}", bound, socket_addr),
863                    };
864                    match <SocketType::AddrState as SocketMapAddrStateSpec>::try_get_inserter(
865                        bound, &tag_state,
866                    ) {
867                        Ok(v) => {
868                            let (id, ret) = make_id(socket_addr, tag_state);
869                            v.insert(id.clone());
870                            Ok((SocketType::to_socket_id(id), ret))
871                        }
872                        Err(IncompatibleError) => Err(InsertError::Exists),
873                    }
874                })?;
875                Ok((SocketStateEntry { id, addr_entry: o, _marker: Default::default() }, ret))
876            }
877            Entry::Vacant(v) => {
878                let (id, ret) = make_id(socket_addr, tag_state.clone());
879                let addr_entry = v.insert(SocketType::to_bound(SocketType::AddrState::new(
880                    &tag_state,
881                    id.clone(),
882                )));
883                let id = SocketType::to_socket_id(id);
884                Ok((SocketStateEntry { id, addr_entry, _marker: Default::default() }, ret))
885            }
886        }
887    }
888
889    /// Returns a borrowed entry at `id` and `addr`.
890    pub fn entry(
891        self,
892        id: &SocketType::Id,
893        addr: &SocketType::Addr,
894    ) -> Option<SocketStateEntry<'a, I, D, A, S, SocketType>> {
895        let Self(addr_to_state, _) = self;
896        let addr_entry = match addr_to_state.entry(SocketType::to_addr_vec(addr)) {
897            Entry::Vacant(_) => return None,
898            Entry::Occupied(o) => o,
899        };
900        let state = SocketType::from_bound_ref(addr_entry.get())?;
901
902        state.contains_id(id).then_some(SocketStateEntry {
903            id: SocketType::to_socket_id(id.clone()),
904            addr_entry,
905            _marker: PhantomData::default(),
906        })
907    }
908
909    /// Removes the entry with `id` and `addr`.
910    pub fn remove(self, id: &SocketType::Id, addr: &SocketType::Addr) -> Result<(), NotFoundError> {
911        self.entry(id, addr)
912            .map(|entry| {
913                entry.remove();
914            })
915            .ok_or(NotFoundError)
916    }
917}
918
/// The error returned when updating the sharing state for a [`SocketMap`] entry
/// fails.
///
/// Produced by [`SocketStateEntry::try_update_sharing`] in two cases: the
/// update policy rejects the change, or the state stored at the address
/// cannot accommodate it.
#[derive(Debug)]
pub struct UpdateSharingError;
923
924impl<
925    'a,
926    I: Ip,
927    D: DeviceIdentifier,
928    SocketType: ConvertSocketMapState<I, D, A, S>,
929    A: SocketMapAddrSpec,
930    S: SocketMapStateSpec,
931> SocketStateEntry<'a, I, D, A, S, SocketType>
932where
933    SocketType::Id: Clone,
934{
935    /// Returns this entry's address.
936    pub fn get_addr(&self) -> &SocketType::Addr {
937        let Self { id: _, addr_entry, _marker } = self;
938        SocketType::from_addr_vec_ref(addr_entry.key())
939    }
940
941    /// Returns this entry's identifier.
942    pub fn id(&self) -> &SocketType::Id {
943        let Self { id, addr_entry: _, _marker } = self;
944        SocketType::from_socket_id_ref(id)
945    }
946
947    /// Attempts to update the address for this entry.
948    pub fn try_update_addr(self, new_addr: SocketType::Addr) -> Result<Self, (ExistsError, Self)> {
949        let Self { id, addr_entry, _marker } = self;
950
951        let new_addrvec = SocketType::to_addr_vec(&new_addr);
952        let old_addr = addr_entry.key().clone();
953        let (addr_state, addr_to_state) = addr_entry.remove_from_map();
954        let addr_to_state = match addr_to_state.entry(new_addrvec) {
955            Entry::Occupied(o) => o.into_map(),
956            Entry::Vacant(v) => {
957                if v.descendant_counts().len() != 0 {
958                    v.into_map()
959                } else {
960                    let new_addr_entry = v.insert(addr_state);
961                    return Ok(SocketStateEntry { id, addr_entry: new_addr_entry, _marker });
962                }
963            }
964        };
965        let to_restore = addr_state;
966        // Restore the old state before returning an error.
967        let addr_entry = match addr_to_state.entry(old_addr) {
968            Entry::Occupied(_) => unreachable!("just-removed-from entry is occupied"),
969            Entry::Vacant(v) => v.insert(to_restore),
970        };
971        return Err((ExistsError, SocketStateEntry { id, addr_entry, _marker }));
972    }
973
974    /// Removes this entry from the map.
975    pub fn remove(self) {
976        let Self { id, mut addr_entry, _marker } = self;
977        let addr = addr_entry.key().clone();
978        match addr_entry.map_mut(|value| {
979            let value = match SocketType::from_bound_mut(value) {
980                Some(value) => value,
981                None => unreachable!("found {:?} for address {:?}", value, addr),
982            };
983            value.remove_by_id(SocketType::from_socket_id_ref(&id).clone())
984        }) {
985            RemoveResult::Success => (),
986            RemoveResult::IsLast => {
987                let _: Bound<S> = addr_entry.remove();
988            }
989        }
990    }
991
992    /// Attempts to update the sharing state for this entry.
993    pub fn try_update_sharing(
994        &mut self,
995        old_sharing_state: &SocketType::SharingState,
996        new_sharing_state: SocketType::SharingState,
997    ) -> Result<(), UpdateSharingError>
998    where
999        SocketType::AddrState: SocketMapAddrStateUpdateSharingSpec,
1000        S: SocketMapUpdateSharingPolicy<SocketType::Addr, SocketType::SharingState, I, D, A>,
1001    {
1002        let Self { id, addr_entry, _marker } = self;
1003        let addr = SocketType::from_addr_vec_ref(addr_entry.key());
1004
1005        S::allows_sharing_update(
1006            addr_entry.get_map(),
1007            addr,
1008            old_sharing_state,
1009            &new_sharing_state,
1010        )?;
1011
1012        addr_entry
1013            .map_mut(|value| {
1014                let value = match SocketType::from_bound_mut(value) {
1015                    Some(value) => value,
1016                    // We shouldn't ever be storing listener state in a bound
1017                    // address, or bound state in a listener address. Doing so means
1018                    // we've got a serious bug.
1019                    None => unreachable!("found invalid state {:?}", value),
1020                };
1021
1022                value.try_update_sharing(
1023                    SocketType::from_socket_id_ref(id).clone(),
1024                    &new_sharing_state,
1025                )
1026            })
1027            .map_err(|IncompatibleError| UpdateSharingError)
1028    }
1029}
1030
1031impl<I: Ip, D: DeviceIdentifier, A: SocketMapAddrSpec, S> BoundSocketMap<I, D, A, S>
1032where
1033    AddrVec<I, D, A>: IterShadows,
1034    S: SocketMapStateSpec,
1035{
1036    /// Returns an iterator over the listeners on the socket map.
1037    pub fn listeners(&self) -> Sockets<&SocketMap<AddrVec<I, D, A>, Bound<S>>, Listener>
1038    where
1039        S: SocketMapConflictPolicy<
1040                ListenerAddr<ListenerIpAddr<I::Addr, A::LocalIdentifier>, D>,
1041                <S as SocketMapStateSpec>::ListenerSharingState,
1042                I,
1043                D,
1044                A,
1045            >,
1046        S::ListenerAddrState:
1047            SocketMapAddrStateSpec<Id = S::ListenerId, SharingState = S::ListenerSharingState>,
1048    {
1049        let Self { addr_to_state } = self;
1050        Sockets(addr_to_state, Default::default())
1051    }
1052
1053    /// Returns a mutable iterator over the listeners on the socket map.
1054    pub fn listeners_mut(&mut self) -> Sockets<&mut SocketMap<AddrVec<I, D, A>, Bound<S>>, Listener>
1055    where
1056        S: SocketMapConflictPolicy<
1057                ListenerAddr<ListenerIpAddr<I::Addr, A::LocalIdentifier>, D>,
1058                <S as SocketMapStateSpec>::ListenerSharingState,
1059                I,
1060                D,
1061                A,
1062            >,
1063        S::ListenerAddrState:
1064            SocketMapAddrStateSpec<Id = S::ListenerId, SharingState = S::ListenerSharingState>,
1065    {
1066        let Self { addr_to_state } = self;
1067        Sockets(addr_to_state, Default::default())
1068    }
1069
1070    /// Returns an iterator over the connections on the socket map.
1071    pub fn conns(&self) -> Sockets<&SocketMap<AddrVec<I, D, A>, Bound<S>>, Connection>
1072    where
1073        S: SocketMapConflictPolicy<
1074                ConnAddr<ConnIpAddr<I::Addr, A::LocalIdentifier, A::RemoteIdentifier>, D>,
1075                <S as SocketMapStateSpec>::ConnSharingState,
1076                I,
1077                D,
1078                A,
1079            >,
1080        S::ConnAddrState:
1081            SocketMapAddrStateSpec<Id = S::ConnId, SharingState = S::ConnSharingState>,
1082    {
1083        let Self { addr_to_state } = self;
1084        Sockets(addr_to_state, Default::default())
1085    }
1086
1087    /// Returns a mutable iterator over the connections on the socket map.
1088    pub fn conns_mut(&mut self) -> Sockets<&mut SocketMap<AddrVec<I, D, A>, Bound<S>>, Connection>
1089    where
1090        S: SocketMapConflictPolicy<
1091                ConnAddr<ConnIpAddr<I::Addr, A::LocalIdentifier, A::RemoteIdentifier>, D>,
1092                <S as SocketMapStateSpec>::ConnSharingState,
1093                I,
1094                D,
1095                A,
1096            >,
1097        S::ConnAddrState:
1098            SocketMapAddrStateSpec<Id = S::ConnId, SharingState = S::ConnSharingState>,
1099    {
1100        let Self { addr_to_state } = self;
1101        Sockets(addr_to_state, Default::default())
1102    }
1103
1104    #[cfg(test)]
1105    pub(crate) fn iter_addrs(&self) -> impl Iterator<Item = &AddrVec<I, D, A>> {
1106        let Self { addr_to_state } = self;
1107        addr_to_state.iter().map(|(a, _v): (_, &Bound<S>)| a)
1108    }
1109
1110    /// Gets the number of shadower entries for `addr`.
1111    pub fn get_shadower_counts(&self, addr: &AddrVec<I, D, A>) -> usize {
1112        let Self { addr_to_state } = self;
1113        addr_to_state.descendant_counts(&addr).map(|(_sharing, size)| size.get()).sum()
1114    }
1115}
1116
/// The receivers found by [`BoundSocketMap::iter_receivers`].
///
/// `iter_receivers` returns this wrapped in an `Option`. A unicast lookup
/// yields `Some(Single(..))` or `None`. A multicast or broadcast lookup always
/// yields `Some(Multicast(..))`, whose iterator may be empty.
pub enum FoundSockets<A, It> {
    /// A single recipient was found for the address.
    Single(A),
    /// Indicates the lookup was for a multicast or broadcast destination, and
    /// holds an iterator of the found receivers.
    Multicast(It),
}
1125
1126/// A borrowed entry in a [`BoundSocketMap`].
1127#[allow(missing_docs)]
1128#[derive(Debug)]
1129pub enum AddrEntry<'a, I: Ip, D, A: SocketMapAddrSpec, S: SocketMapStateSpec> {
1130    Listen(&'a S::ListenerAddrState, ListenerAddr<ListenerIpAddr<I::Addr, A::LocalIdentifier>, D>),
1131    Conn(
1132        &'a S::ConnAddrState,
1133        ConnAddr<ConnIpAddr<I::Addr, A::LocalIdentifier, A::RemoteIdentifier>, D>,
1134    ),
1135}
1136
1137impl<I, D, A, S> BoundSocketMap<I, D, A, S>
1138where
1139    I: BroadcastIpExt<Addr: MulticastAddress>,
1140    D: DeviceIdentifier,
1141    A: SocketMapAddrSpec,
1142    S: SocketMapStateSpec
1143        + SocketMapConflictPolicy<
1144            ListenerAddr<ListenerIpAddr<I::Addr, A::LocalIdentifier>, D>,
1145            <S as SocketMapStateSpec>::ListenerSharingState,
1146            I,
1147            D,
1148            A,
1149        > + SocketMapConflictPolicy<
1150            ConnAddr<ConnIpAddr<I::Addr, A::LocalIdentifier, A::RemoteIdentifier>, D>,
1151            <S as SocketMapStateSpec>::ConnSharingState,
1152            I,
1153            D,
1154            A,
1155        >,
1156{
1157    /// Looks up a connected socket.
1158    ///
1159    /// This is a lightweight version of `iter_receivers()` that doesn't try to
1160    /// lookup listening sockets. It is used for early demux, which applies only
1161    /// to connected sockets.
1162    pub fn lookup_connected(
1163        &self,
1164        (src_ip, src_port): (SocketIpAddr<I::Addr>, A::RemoteIdentifier),
1165        (dst_ip, dst_port): (SocketIpAddr<I::Addr>, A::LocalIdentifier),
1166        device: D,
1167    ) -> Option<&'_ S::ConnAddrState> {
1168        let mut addr = ConnAddr {
1169            ip: ConnIpAddr { local: (dst_ip, dst_port), remote: (src_ip, src_port) },
1170            device: Some(device),
1171        };
1172        let entry = self.conns().get_by_addr(&addr);
1173        if entry.is_some() {
1174            return entry;
1175        }
1176        addr.device = None;
1177        self.conns().get_by_addr(&addr)
1178    }
1179
1180    /// Finds the socket(s) that should receive an incoming packet.
1181    ///
1182    /// Uses the provided addresses and receiving device to look up sockets that
1183    /// should receive a matching incoming packet. Returns `None` if no sockets
1184    /// were found, or the results of the lookup.
1185    pub fn iter_receivers(
1186        &self,
1187        (src_ip, src_port): (Option<SocketIpAddr<I::Addr>>, Option<A::RemoteIdentifier>),
1188        (dst_ip, dst_port): (SocketIpAddr<I::Addr>, A::LocalIdentifier),
1189        device: D,
1190        broadcast: Option<I::BroadcastMarker>,
1191    ) -> Option<
1192        FoundSockets<
1193            AddrEntry<'_, I, D, A, S>,
1194            impl Iterator<Item = AddrEntry<'_, I, D, A, S>> + '_,
1195        >,
1196    > {
1197        let mut matching_entries = AddrVecIter::with_device(
1198            match (src_ip, src_port) {
1199                (Some(specified_src_ip), Some(src_port)) => {
1200                    ConnIpAddr { local: (dst_ip, dst_port), remote: (specified_src_ip, src_port) }
1201                        .into()
1202                }
1203                _ => ListenerIpAddr { addr: Some(dst_ip), identifier: dst_port }.into(),
1204            },
1205            device,
1206        )
1207        .filter_map(move |addr: AddrVec<I, D, A>| match addr {
1208            AddrVec::Listen(l) => {
1209                self.listeners().get_by_addr(&l).map(|state| AddrEntry::Listen(state, l))
1210            }
1211            AddrVec::Conn(c) => self.conns().get_by_addr(&c).map(|state| AddrEntry::Conn(state, c)),
1212        });
1213
1214        if broadcast.is_some() || dst_ip.addr().is_multicast() {
1215            Some(FoundSockets::Multicast(matching_entries))
1216        } else {
1217            let single_entry: Option<_> = matching_entries.next();
1218            single_entry.map(FoundSockets::Single)
1219        }
1220    }
1221}
1222
/// Errors observed by [`SocketMapConflictPolicy`].
///
/// Every variant converts to [`LocalAddressError::AddressInUse`].
#[derive(Debug, Eq, PartialEq)]
pub enum InsertError {
    /// A shadow address exists for the entry.
    ShadowAddrExists,
    /// Entry already exists.
    Exists,
    /// A shadower exists for the entry.
    ShadowerExists,
    /// An indirect conflict was detected.
    IndirectConflict,
}
1235
1236impl From<InsertError> for LocalAddressError {
1237    fn from(value: InsertError) -> Self {
1238        match value {
1239            InsertError::ShadowAddrExists
1240            | InsertError::Exists
1241            | InsertError::IndirectConflict
1242            | InsertError::ShadowerExists => LocalAddressError::AddressInUse,
1243        }
1244    }
1245}
1246
/// Helper trait for converting between [`AddrVec`] and [`Bound`] and their
/// variants.
///
/// Implemented for [`Listener`] and [`Connection`]. Each implementation
/// selects the matching variant of [`AddrVec`], [`Bound`] and [`SocketId`].
pub trait ConvertSocketMapState<I: Ip, D, A: SocketMapAddrSpec, S: SocketMapStateSpec> {
    /// The socket ID type for this kind of socket.
    type Id;
    /// The sharing-state type for this kind of socket.
    type SharingState;
    /// The address type for this kind of socket.
    type Addr: Debug;
    /// The per-address state stored in the map for this kind of socket.
    type AddrState: SocketMapAddrStateSpec<Id = Self::Id, SharingState = Self::SharingState>;

    /// Wraps a copy of `addr` in the matching [`AddrVec`] variant.
    fn to_addr_vec(addr: &Self::Addr) -> AddrVec<I, D, A>;
    /// Borrows the address inside `addr`. The implementations in this file
    /// panic if `addr` is the other variant.
    fn from_addr_vec_ref(addr: &AddrVec<I, D, A>) -> &Self::Addr;
    /// Borrows the state inside `bound`, or `None` if it is the other variant.
    fn from_bound_ref(bound: &Bound<S>) -> Option<&Self::AddrState>;
    /// Mutably borrows the state inside `bound`, or `None` if it is the other
    /// variant.
    fn from_bound_mut(bound: &mut Bound<S>) -> Option<&mut Self::AddrState>;
    /// Wraps `state` in the matching [`Bound`] variant.
    fn to_bound(state: Self::AddrState) -> Bound<S>;
    /// Wraps `id` in the matching [`SocketId`] variant.
    fn to_socket_id(id: Self::Id) -> SocketId<S>;
    /// Borrows the ID inside `id`. The implementations in this file panic if
    /// `id` is the other variant.
    fn from_socket_id_ref(id: &SocketId<S>) -> &Self::Id;
}
1263
1264impl<I: Ip, D: DeviceIdentifier, A: SocketMapAddrSpec, S: SocketMapStateSpec>
1265    ConvertSocketMapState<I, D, A, S> for Listener
1266{
1267    type Id = S::ListenerId;
1268    type SharingState = S::ListenerSharingState;
1269    type Addr = ListenerAddr<ListenerIpAddr<I::Addr, A::LocalIdentifier>, D>;
1270    type AddrState = S::ListenerAddrState;
1271    fn to_addr_vec(addr: &Self::Addr) -> AddrVec<I, D, A> {
1272        AddrVec::Listen(addr.clone())
1273    }
1274
1275    fn from_addr_vec_ref(addr: &AddrVec<I, D, A>) -> &Self::Addr {
1276        match addr {
1277            AddrVec::Listen(l) => l,
1278            AddrVec::Conn(c) => unreachable!("conn addr for listener: {c:?}"),
1279        }
1280    }
1281
1282    fn from_bound_ref(bound: &Bound<S>) -> Option<&S::ListenerAddrState> {
1283        match bound {
1284            Bound::Listen(l) => Some(l),
1285            Bound::Conn(_c) => None,
1286        }
1287    }
1288
1289    fn from_bound_mut(bound: &mut Bound<S>) -> Option<&mut S::ListenerAddrState> {
1290        match bound {
1291            Bound::Listen(l) => Some(l),
1292            Bound::Conn(_c) => None,
1293        }
1294    }
1295
1296    fn to_bound(state: S::ListenerAddrState) -> Bound<S> {
1297        Bound::Listen(state)
1298    }
1299    fn from_socket_id_ref(id: &SocketId<S>) -> &Self::Id {
1300        match id {
1301            SocketId::Listener(id) => id,
1302            SocketId::Connection(_) => unreachable!("connection ID for listener"),
1303        }
1304    }
1305    fn to_socket_id(id: Self::Id) -> SocketId<S> {
1306        SocketId::Listener(id)
1307    }
1308}
1309
1310impl<I: Ip, D: DeviceIdentifier, A: SocketMapAddrSpec, S: SocketMapStateSpec>
1311    ConvertSocketMapState<I, D, A, S> for Connection
1312{
1313    type Id = S::ConnId;
1314    type SharingState = S::ConnSharingState;
1315    type Addr = ConnAddr<ConnIpAddr<I::Addr, A::LocalIdentifier, A::RemoteIdentifier>, D>;
1316    type AddrState = S::ConnAddrState;
1317    fn to_addr_vec(addr: &Self::Addr) -> AddrVec<I, D, A> {
1318        AddrVec::Conn(addr.clone())
1319    }
1320
1321    fn from_addr_vec_ref(addr: &AddrVec<I, D, A>) -> &Self::Addr {
1322        match addr {
1323            AddrVec::Conn(c) => c,
1324            AddrVec::Listen(l) => unreachable!("listener addr for conn: {l:?}"),
1325        }
1326    }
1327
1328    fn from_bound_ref(bound: &Bound<S>) -> Option<&S::ConnAddrState> {
1329        match bound {
1330            Bound::Listen(_l) => None,
1331            Bound::Conn(c) => Some(c),
1332        }
1333    }
1334
1335    fn from_bound_mut(bound: &mut Bound<S>) -> Option<&mut S::ConnAddrState> {
1336        match bound {
1337            Bound::Listen(_l) => None,
1338            Bound::Conn(c) => Some(c),
1339        }
1340    }
1341
1342    fn to_bound(state: S::ConnAddrState) -> Bound<S> {
1343        Bound::Conn(state)
1344    }
1345
1346    fn from_socket_id_ref(id: &SocketId<S>) -> &Self::Id {
1347        match id {
1348            SocketId::Connection(id) => id,
1349            SocketId::Listener(_) => unreachable!("listener ID for connection"),
1350        }
1351    }
1352    fn to_socket_id(id: Self::Id) -> SocketId<S> {
1353        SocketId::Connection(id)
1354    }
1355}
1356
/// An identifier of a sharing domain used for SO_REUSEPORT.
///
/// Only sockets in the same sharing domain may share a port.
#[derive(Debug, Eq, PartialEq, Clone, Copy, Hash)]
pub struct SharingDomain(u64);

impl SharingDomain {
    /// Creates a sharing domain identified by `id`.
    ///
    /// The caller is responsible for making sure that `id` uniquely identifies
    /// the sharing domain and that the client is authorized to use it. On
    /// Fuchsia, for example, `id` is the KOID of a handle provided by the
    /// client.
    pub const fn new(id: u64) -> Self {
        Self(id)
    }
}
1369
1370/// A value of the SO_REUSEPORT option. Also encodes the sharing domain, which allows
1371/// to ensure that only sockets in the same domain can share ports.
1372#[derive(Default, Debug, Eq, PartialEq, Clone, Copy, Hash)]
1373pub enum ReusePortOption {
1374    /// The option is disabled.
1375    #[default]
1376    Disabled,
1377
1378    /// The option is enabled: the port is shareable with other sockets in the
1379    /// same sharing domain.
1380    Enabled(SharingDomain),
1381}
1382
1383impl ReusePortOption {
1384    /// Returns `true` if the option is enabled.
1385    pub fn is_enabled(&self) -> bool {
1386        matches!(self, ReusePortOption::Enabled(_))
1387    }
1388
1389    /// Returns `true` if the socket is shareable with a socket with the
1390    /// specified value of the SO_REUSEPORT option.
1391    pub fn is_shareable_with(&self, other: &Self) -> bool {
1392        match (self, other) {
1393            (ReusePortOption::Enabled(domain1), ReusePortOption::Enabled(domain2)) => {
1394                domain1 == domain2
1395            }
1396            _ => false,
1397        }
1398    }
1399}
1400
1401#[cfg(test)]
1402mod tests {
1403    use alloc::vec;
1404    use alloc::vec::Vec;
1405
1406    use assert_matches::assert_matches;
1407    use net_declare::{net_ip_v4, net_ip_v6};
1408    use net_types::ip::{Ipv4Addr, Ipv6, Ipv6Addr};
1409    use netstack3_hashmap::HashSet;
1410    use test_case::test_case;
1411
1412    use crate::device::testutil::{FakeDeviceId, FakeWeakDeviceId};
1413    use crate::testutil::set_logger_for_test;
1414
1415    use super::*;
1416
1417    #[test_case(net_ip_v4!("8.8.8.8"))]
1418    #[test_case(net_ip_v4!("127.0.0.1"))]
1419    #[test_case(net_ip_v4!("127.0.8.9"))]
1420    #[test_case(net_ip_v4!("224.1.2.3"))]
1421    fn must_never_have_zone_ipv4(addr: Ipv4Addr) {
1422        // No IPv4 addresses are allowed to have a zone.
1423        let addr = SpecifiedAddr::new(addr).unwrap();
1424        assert_eq!(addr.must_have_zone(), false);
1425    }
1426
1427    #[test_case(net_ip_v6!("1::2:3"), false)]
1428    #[test_case(net_ip_v6!("::1"), false; "localhost")]
1429    #[test_case(net_ip_v6!("1::"), false)]
1430    #[test_case(net_ip_v6!("ff03:1:2:3::1"), false)]
1431    #[test_case(net_ip_v6!("ff02:1:2:3::1"), true)]
1432    #[test_case(Ipv6::ALL_NODES_LINK_LOCAL_MULTICAST_ADDRESS.get(), true)]
1433    #[test_case(net_ip_v6!("fe80::1"), true)]
1434    fn must_have_zone_ipv6(addr: Ipv6Addr, must_have: bool) {
1435        // Only link-local unicast and multicast addresses are allowed to have
1436        // zones.
1437        let addr = SpecifiedAddr::new(addr).unwrap();
1438        assert_eq!(addr.must_have_zone(), must_have);
1439    }
1440
1441    #[test]
1442    fn try_into_null_zoned_ipv6() {
1443        assert_eq!(Ipv6::LOOPBACK_ADDRESS.try_into_null_zoned(), None);
1444        let zoned = Ipv6::ALL_NODES_LINK_LOCAL_MULTICAST_ADDRESS.into_specified();
1445        const ZONE: u32 = 5;
1446        assert_eq!(
1447            zoned.try_into_null_zoned().map(|a| a.map_zone(|()| ZONE)),
1448            Some(AddrAndZone::new(zoned, ZONE).unwrap())
1449        );
1450    }
1451
1452    enum FakeSpec {}
1453
1454    #[derive(Copy, Clone, Eq, PartialEq, Debug, Hash)]
1455    struct Listener(usize);
1456
1457    #[derive(PartialEq, Eq, Debug, Copy, Clone)]
1458    struct SharingState {
1459        tag: char,
1460        shared: bool,
1461    }
1462
1463    impl SharingState {
1464        fn exclusive(tag: char) -> Self {
1465            Self { tag, shared: false }
1466        }
1467
1468        fn shared(tag: char) -> Self {
1469            Self { tag, shared: true }
1470        }
1471    }
1472
1473    impl SharingState {
1474        fn can_share_with(&self, other: &Self) -> bool {
1475            self.tag == other.tag && self.shared && other.shared
1476        }
1477    }
1478
1479    #[derive(PartialEq, Eq, Debug)]
1480    struct Multiple<T> {
1481        sharing_state: SharingState,
1482        entries: Vec<T>,
1483    }
1484
1485    impl<T> Multiple<T> {
1486        fn new_exclusive(tag: char, entries: Vec<T>) -> Self {
1487            Self { sharing_state: SharingState { tag, shared: false }, entries }
1488        }
1489    }
1490
1491    #[derive(Copy, Clone, Eq, PartialEq, Debug, Hash)]
1492    struct Conn(usize);
1493
1494    enum FakeAddrSpec {}
1495
1496    impl SocketMapAddrSpec for FakeAddrSpec {
1497        type LocalIdentifier = NonZeroU16;
1498        type RemoteIdentifier = ();
1499    }
1500
1501    impl SocketMapStateSpec for FakeSpec {
1502        type AddrVecTag = SharingState;
1503
1504        type ListenerId = Listener;
1505        type ConnId = Conn;
1506
1507        type ListenerSharingState = SharingState;
1508        type ConnSharingState = SharingState;
1509
1510        type ListenerAddrState = Multiple<Listener>;
1511        type ConnAddrState = Multiple<Conn>;
1512
1513        fn listener_tag(_: ListenerAddrInfo, state: &Self::ListenerAddrState) -> Self::AddrVecTag {
1514            state.sharing_state
1515        }
1516
1517        fn connected_tag(_has_device: bool, state: &Self::ConnAddrState) -> Self::AddrVecTag {
1518            state.sharing_state
1519        }
1520    }
1521
1522    type FakeBoundSocketMap =
1523        BoundSocketMap<Ipv4, FakeWeakDeviceId<FakeDeviceId>, FakeAddrSpec, FakeSpec>;
1524
1525    /// Generator for unique socket IDs that don't have any state.
1526    ///
1527    /// Calling [`FakeSocketIdGen::next`] returns a unique ID.
1528    #[derive(Default)]
1529    struct FakeSocketIdGen {
1530        next_id: usize,
1531    }
1532
1533    impl FakeSocketIdGen {
1534        fn next(&mut self) -> usize {
1535            let next_next_id = self.next_id + 1;
1536            core::mem::replace(&mut self.next_id, next_next_id)
1537        }
1538    }
1539
1540    impl<I: Eq> SocketMapAddrStateSpec for Multiple<I> {
1541        type Id = I;
1542        type SharingState = SharingState;
1543        type Inserter<'a>
1544            = &'a mut Vec<I>
1545        where
1546            I: 'a;
1547
1548        fn new(sharing_state: &SharingState, id: I) -> Self {
1549            Self { sharing_state: *sharing_state, entries: vec![id] }
1550        }
1551
1552        fn contains_id(&self, id: &Self::Id) -> bool {
1553            self.entries.contains(id)
1554        }
1555
1556        fn try_get_inserter<'a, 'b>(
1557            &'b mut self,
1558            new_sharing_state: &'a SharingState,
1559        ) -> Result<Self::Inserter<'b>, IncompatibleError> {
1560            (self.sharing_state == *new_sharing_state)
1561                .then_some(&mut self.entries)
1562                .ok_or(IncompatibleError)
1563        }
1564
1565        fn could_insert(&self, new_sharing_state: &SharingState) -> Result<(), IncompatibleError> {
1566            (self.sharing_state == *new_sharing_state).then_some(()).ok_or(IncompatibleError)
1567        }
1568
1569        fn remove_by_id(&mut self, id: I) -> RemoveResult {
1570            let index = self.entries.iter().position(|i| i == &id).expect("did not find id");
1571            let _: I = self.entries.swap_remove(index);
1572            if self.entries.is_empty() { RemoveResult::IsLast } else { RemoveResult::Success }
1573        }
1574    }
1575
1576    impl<A: Into<AddrVec<Ipv4, FakeWeakDeviceId<FakeDeviceId>, FakeAddrSpec>> + Clone>
1577        SocketMapConflictPolicy<A, SharingState, Ipv4, FakeWeakDeviceId<FakeDeviceId>, FakeAddrSpec>
1578        for FakeSpec
1579    {
1580        fn check_insert_conflicts(
1581            new_sharing_state: &SharingState,
1582            addr: &A,
1583            socketmap: &SocketMap<
1584                AddrVec<Ipv4, FakeWeakDeviceId<FakeDeviceId>, FakeAddrSpec>,
1585                Bound<FakeSpec>,
1586            >,
1587        ) -> Result<(), InsertError> {
1588            let dest: AddrVec<_, _, _> = addr.clone().into();
1589            if dest.iter_shadows().any(|a| {
1590                let entry = socketmap.get(&a);
1591                match entry {
1592                    Some(Bound::Listen(Multiple { sharing_state, .. }))
1593                    | Some(Bound::Conn(Multiple { sharing_state, .. })) => {
1594                        !sharing_state.can_share_with(new_sharing_state)
1595                    }
1596                    None => false,
1597                }
1598            }) {
1599                return Err(InsertError::ShadowAddrExists);
1600            }
1601
1602            match socketmap.get(&dest) {
1603                Some(Bound::Listen(Multiple { sharing_state, .. }))
1604                | Some(Bound::Conn(Multiple { sharing_state, .. })) => {
1605                    // Require that all sockets inserted in a `Multiple` entry
1606                    // have the same sharing state.
1607                    if sharing_state != new_sharing_state {
1608                        return Err(InsertError::Exists);
1609                    }
1610                }
1611                None => (),
1612            }
1613
1614            if socketmap
1615                .descendant_counts(&dest)
1616                .any(|(sharing_state, _count)| !sharing_state.can_share_with(new_sharing_state))
1617            {
1618                Err(InsertError::ShadowerExists)
1619            } else {
1620                Ok(())
1621            }
1622        }
1623    }
1624
1625    impl<I: Eq> SocketMapAddrStateUpdateSharingSpec for Multiple<I> {
1626        fn try_update_sharing(
1627            &mut self,
1628            id: Self::Id,
1629            new_sharing_state: &Self::SharingState,
1630        ) -> Result<(), IncompatibleError> {
1631            if self.sharing_state == *new_sharing_state {
1632                return Ok(());
1633            }
1634
1635            // Preserve the invariant that all sockets inserted in a `Multiple`
1636            // entry have the same sharing state. That means we can't change
1637            // the sharing state of all the sockets at the address unless there
1638            // is exactly one!
1639            if self.entries.len() != 1 {
1640                return Err(IncompatibleError);
1641            }
1642            assert!(self.entries.contains(&id));
1643            self.sharing_state = *new_sharing_state;
1644            Ok(())
1645        }
1646    }
1647
1648    impl<A: Into<AddrVec<Ipv4, FakeWeakDeviceId<FakeDeviceId>, FakeAddrSpec>> + Clone>
1649        SocketMapUpdateSharingPolicy<
1650            A,
1651            SharingState,
1652            Ipv4,
1653            FakeWeakDeviceId<FakeDeviceId>,
1654            FakeAddrSpec,
1655        > for FakeSpec
1656    {
1657        fn allows_sharing_update(
1658            _socketmap: &SocketMap<
1659                AddrVec<Ipv4, FakeWeakDeviceId<FakeDeviceId>, FakeAddrSpec>,
1660                Bound<Self>,
1661            >,
1662            _addr: &A,
1663            _old_sharing: &SharingState,
1664            _new_sharing_state: &SharingState,
1665        ) -> Result<(), UpdateSharingError> {
1666            Ok(())
1667        }
1668    }
1669
1670    const LISTENER_ADDR: ListenerAddr<
1671        ListenerIpAddr<Ipv4Addr, NonZeroU16>,
1672        FakeWeakDeviceId<FakeDeviceId>,
1673    > = ListenerAddr {
1674        ip: ListenerIpAddr {
1675            addr: Some(unsafe { SocketIpAddr::new_unchecked(net_ip_v4!("1.2.3.4")) }),
1676            identifier: NonZeroU16::new(1).unwrap(),
1677        },
1678        device: None,
1679    };
1680
1681    const CONN_ADDR: ConnAddr<
1682        ConnIpAddr<Ipv4Addr, NonZeroU16, ()>,
1683        FakeWeakDeviceId<FakeDeviceId>,
1684    > = ConnAddr {
1685        ip: ConnIpAddr {
1686            local: (
1687                unsafe { SocketIpAddr::new_unchecked(net_ip_v4!("5.6.7.8")) },
1688                NonZeroU16::new(1).unwrap(),
1689            ),
1690            remote: unsafe { (SocketIpAddr::new_unchecked(net_ip_v4!("8.7.6.5")), ()) },
1691        },
1692        device: None,
1693    };
1694
1695    #[test]
1696    fn bound_insert_get_remove_listener() {
1697        set_logger_for_test();
1698        let mut bound = FakeBoundSocketMap::default();
1699        let mut fake_id_gen = FakeSocketIdGen::default();
1700
1701        let addr = LISTENER_ADDR;
1702
1703        let id = {
1704            let entry = bound
1705                .listeners_mut()
1706                .try_insert(addr, SharingState::exclusive('v'), Listener(fake_id_gen.next()))
1707                .unwrap();
1708            assert_eq!(entry.get_addr(), &addr);
1709            entry.id().clone()
1710        };
1711
1712        assert_eq!(
1713            bound.listeners().get_by_addr(&addr),
1714            Some(&Multiple::new_exclusive('v', vec![id]))
1715        );
1716
1717        assert_eq!(bound.listeners_mut().remove(&id, &addr), Ok(()));
1718        assert_eq!(bound.listeners().get_by_addr(&addr), None);
1719    }
1720
1721    #[test]
1722    fn bound_insert_get_remove_conn() {
1723        set_logger_for_test();
1724        let mut bound = FakeBoundSocketMap::default();
1725        let mut fake_id_gen = FakeSocketIdGen::default();
1726
1727        let addr = CONN_ADDR;
1728
1729        let id = {
1730            let entry = bound
1731                .conns_mut()
1732                .try_insert(addr, SharingState::exclusive('v'), Conn(fake_id_gen.next()))
1733                .unwrap();
1734            assert_eq!(entry.get_addr(), &addr);
1735            entry.id().clone()
1736        };
1737
1738        assert_eq!(bound.conns().get_by_addr(&addr), Some(&Multiple::new_exclusive('v', vec![id])));
1739
1740        assert_eq!(bound.conns_mut().remove(&id, &addr), Ok(()));
1741        assert_eq!(bound.conns().get_by_addr(&addr), None);
1742    }
1743
1744    #[test]
1745    fn bound_iter_addrs() {
1746        set_logger_for_test();
1747        let mut bound = FakeBoundSocketMap::default();
1748        let mut fake_id_gen = FakeSocketIdGen::default();
1749
1750        let listener_addrs = [
1751            (Some(net_ip_v4!("1.1.1.1")), 1),
1752            (Some(net_ip_v4!("2.2.2.2")), 2),
1753            (Some(net_ip_v4!("1.1.1.1")), 3),
1754            (None, 4),
1755        ]
1756        .map(|(ip, identifier)| ListenerAddr {
1757            device: None,
1758            ip: ListenerIpAddr {
1759                addr: ip.map(|x| SocketIpAddr::new(x).unwrap()),
1760                identifier: NonZeroU16::new(identifier).unwrap(),
1761            },
1762        });
1763        let conn_addrs = [
1764            (net_ip_v4!("3.3.3.3"), 3, net_ip_v4!("4.4.4.4")),
1765            (net_ip_v4!("4.4.4.4"), 3, net_ip_v4!("3.3.3.3")),
1766        ]
1767        .map(|(local_ip, local_identifier, remote_ip)| ConnAddr {
1768            ip: ConnIpAddr {
1769                local: (
1770                    SocketIpAddr::new(local_ip).unwrap(),
1771                    NonZeroU16::new(local_identifier).unwrap(),
1772                ),
1773                remote: (SocketIpAddr::new(remote_ip).unwrap(), ()),
1774            },
1775            device: None,
1776        });
1777
1778        for addr in listener_addrs.iter().cloned() {
1779            let _entry = bound
1780                .listeners_mut()
1781                .try_insert(addr, SharingState::exclusive('a'), Listener(fake_id_gen.next()))
1782                .unwrap();
1783        }
1784        for addr in conn_addrs.iter().cloned() {
1785            let _entry = bound
1786                .conns_mut()
1787                .try_insert(addr, SharingState::exclusive('a'), Conn(fake_id_gen.next()))
1788                .unwrap();
1789        }
1790        let expected_addrs = listener_addrs
1791            .into_iter()
1792            .map(Into::into)
1793            .chain(conn_addrs.into_iter().map(Into::into))
1794            .collect::<HashSet<_>>();
1795
1796        assert_eq!(expected_addrs, bound.iter_addrs().cloned().collect());
1797    }
1798
1799    #[test]
1800    fn try_insert_with_callback_not_called_on_error() {
1801        // TODO(https://fxbug.dev/42076891): remove this test along with
1802        // try_insert_with.
1803        set_logger_for_test();
1804        let mut bound = FakeBoundSocketMap::default();
1805        let addr = LISTENER_ADDR;
1806
1807        // Insert a listener so that future calls can conflict.
1808        let _: &Listener = bound
1809            .listeners_mut()
1810            .try_insert(addr, SharingState::exclusive('a'), Listener(0))
1811            .unwrap()
1812            .id();
1813
1814        // All of the below try_insert_with calls should fail, but more
1815        // importantly, they should not call the `make_id` callback (because it
1816        // is only called once success is certain).
1817        fn is_never_called<A, B, T>(_: A, _: B) -> (T, ()) {
1818            panic!("should never be called");
1819        }
1820
1821        assert_matches!(
1822            bound.listeners_mut().try_insert_with(
1823                addr,
1824                SharingState::exclusive('b'),
1825                is_never_called
1826            ),
1827            Err(InsertError::Exists)
1828        );
1829        assert_matches!(
1830            bound.listeners_mut().try_insert_with(
1831                ListenerAddr { device: Some(FakeWeakDeviceId(FakeDeviceId)), ..addr },
1832                SharingState::exclusive('b'),
1833                is_never_called
1834            ),
1835            Err(InsertError::ShadowAddrExists)
1836        );
1837        assert_matches!(
1838            bound.conns_mut().try_insert_with(
1839                ConnAddr {
1840                    device: None,
1841                    ip: ConnIpAddr {
1842                        local: (addr.ip.addr.unwrap(), addr.ip.identifier),
1843                        remote: (SocketIpAddr::new(net_ip_v4!("1.1.1.1")).unwrap(), ()),
1844                    },
1845                },
1846                SharingState::exclusive('b'),
1847                is_never_called,
1848            ),
1849            Err(InsertError::ShadowAddrExists)
1850        );
1851    }
1852
1853    #[test]
1854    fn insert_listener_conflict_with_listener() {
1855        set_logger_for_test();
1856        let mut bound = FakeBoundSocketMap::default();
1857        let mut fake_id_gen = FakeSocketIdGen::default();
1858        let addr = LISTENER_ADDR;
1859
1860        let _: &Listener = bound
1861            .listeners_mut()
1862            .try_insert(addr, SharingState::exclusive('a'), Listener(fake_id_gen.next()))
1863            .unwrap()
1864            .id();
1865        assert_matches!(
1866            bound.listeners_mut().try_insert(
1867                addr,
1868                SharingState::exclusive('b'),
1869                Listener(fake_id_gen.next())
1870            ),
1871            Err(InsertError::Exists)
1872        );
1873    }
1874
1875    #[test]
1876    fn insert_listener_conflict_with_shadower() {
1877        set_logger_for_test();
1878        let mut bound = FakeBoundSocketMap::default();
1879        let mut fake_id_gen = FakeSocketIdGen::default();
1880        let addr = LISTENER_ADDR;
1881        let shadows_addr = {
1882            assert_eq!(addr.device, None);
1883            ListenerAddr { device: Some(FakeWeakDeviceId(FakeDeviceId)), ..addr }
1884        };
1885
1886        let _: &Listener = bound
1887            .listeners_mut()
1888            .try_insert(addr, SharingState::exclusive('a'), Listener(fake_id_gen.next()))
1889            .unwrap()
1890            .id();
1891        assert_matches!(
1892            bound.listeners_mut().try_insert(
1893                shadows_addr,
1894                SharingState::exclusive('b'),
1895                Listener(fake_id_gen.next())
1896            ),
1897            Err(InsertError::ShadowAddrExists)
1898        );
1899    }
1900
1901    #[test]
1902    fn insert_conn_conflict_with_listener() {
1903        set_logger_for_test();
1904        let mut bound = FakeBoundSocketMap::default();
1905        let mut fake_id_gen = FakeSocketIdGen::default();
1906        let addr = LISTENER_ADDR;
1907        let shadows_addr = ConnAddr {
1908            device: None,
1909            ip: ConnIpAddr {
1910                local: (addr.ip.addr.unwrap(), addr.ip.identifier),
1911                remote: (SocketIpAddr::new(net_ip_v4!("1.1.1.1")).unwrap(), ()),
1912            },
1913        };
1914
1915        let _: &Listener = bound
1916            .listeners_mut()
1917            .try_insert(addr, SharingState::exclusive('a'), Listener(fake_id_gen.next()))
1918            .unwrap()
1919            .id();
1920        assert_matches!(
1921            bound.conns_mut().try_insert(
1922                shadows_addr,
1923                SharingState::exclusive('b'),
1924                Conn(fake_id_gen.next())
1925            ),
1926            Err(InsertError::ShadowAddrExists)
1927        );
1928    }
1929
1930    #[test]
1931    fn insert_and_remove_listener() {
1932        set_logger_for_test();
1933        let mut bound = FakeBoundSocketMap::default();
1934        let mut fake_id_gen = FakeSocketIdGen::default();
1935        let addr = LISTENER_ADDR;
1936
1937        let a = bound
1938            .listeners_mut()
1939            .try_insert(addr, SharingState::exclusive('x'), Listener(fake_id_gen.next()))
1940            .unwrap()
1941            .id()
1942            .clone();
1943        let b = bound
1944            .listeners_mut()
1945            .try_insert(addr, SharingState::exclusive('x'), Listener(fake_id_gen.next()))
1946            .unwrap()
1947            .id()
1948            .clone();
1949        assert_ne!(a, b);
1950
1951        assert_eq!(bound.listeners_mut().remove(&a, &addr), Ok(()));
1952        assert_eq!(
1953            bound.listeners().get_by_addr(&addr),
1954            Some(&Multiple::new_exclusive('x', vec![b]))
1955        );
1956    }
1957
1958    #[test]
1959    fn insert_and_remove_conn() {
1960        set_logger_for_test();
1961        let mut bound = FakeBoundSocketMap::default();
1962        let mut fake_id_gen = FakeSocketIdGen::default();
1963        let addr = CONN_ADDR;
1964
1965        let a = bound
1966            .conns_mut()
1967            .try_insert(addr, SharingState::exclusive('x'), Conn(fake_id_gen.next()))
1968            .unwrap()
1969            .id()
1970            .clone();
1971        let b = bound
1972            .conns_mut()
1973            .try_insert(addr, SharingState::exclusive('x'), Conn(fake_id_gen.next()))
1974            .unwrap()
1975            .id()
1976            .clone();
1977        assert_ne!(a, b);
1978
1979        assert_eq!(bound.conns_mut().remove(&a, &addr), Ok(()));
1980        assert_eq!(bound.conns().get_by_addr(&addr), Some(&Multiple::new_exclusive('x', vec![b])));
1981    }
1982
1983    #[test]
1984    fn update_listener_to_shadowed_addr_fails() {
1985        let mut bound = FakeBoundSocketMap::default();
1986        let mut fake_id_gen = FakeSocketIdGen::default();
1987
1988        let first_addr = LISTENER_ADDR;
1989        let second_addr = ListenerAddr {
1990            ip: ListenerIpAddr {
1991                addr: Some(SocketIpAddr::new(net_ip_v4!("1.1.1.1")).unwrap()),
1992                ..LISTENER_ADDR.ip
1993            },
1994            ..LISTENER_ADDR
1995        };
1996        let both_shadow = ListenerAddr {
1997            ip: ListenerIpAddr { addr: None, identifier: first_addr.ip.identifier },
1998            device: None,
1999        };
2000
2001        let first = bound
2002            .listeners_mut()
2003            .try_insert(first_addr, SharingState::exclusive('a'), Listener(fake_id_gen.next()))
2004            .unwrap()
2005            .id()
2006            .clone();
2007        let second = bound
2008            .listeners_mut()
2009            .try_insert(second_addr, SharingState::exclusive('b'), Listener(fake_id_gen.next()))
2010            .unwrap()
2011            .id()
2012            .clone();
2013
2014        // Moving from (1, "aaa") to (1, None) should fail since it is shadowed
2015        // by (1, "yyy"), and vise versa.
2016        let (ExistsError, entry) = bound
2017            .listeners_mut()
2018            .entry(&second, &second_addr)
2019            .unwrap()
2020            .try_update_addr(both_shadow)
2021            .expect_err("update should fail");
2022
2023        // The entry should correspond to `second`.
2024        assert_eq!(entry.id(), &second);
2025        drop(entry);
2026
2027        let (ExistsError, entry) = bound
2028            .listeners_mut()
2029            .entry(&first, &first_addr)
2030            .unwrap()
2031            .try_update_addr(both_shadow)
2032            .expect_err("update should fail");
2033        assert_eq!(entry.get_addr(), &first_addr);
2034    }
2035
2036    #[test]
2037    fn nonexistent_conn_entry() {
2038        let mut map = FakeBoundSocketMap::default();
2039        let mut fake_id_gen = FakeSocketIdGen::default();
2040        let addr = CONN_ADDR;
2041        let conn_id = map
2042            .conns_mut()
2043            .try_insert(addr.clone(), SharingState::exclusive('a'), Conn(fake_id_gen.next()))
2044            .expect("failed to insert")
2045            .id()
2046            .clone();
2047        assert_matches!(map.conns_mut().remove(&conn_id, &addr), Ok(()));
2048
2049        assert!(map.conns_mut().entry(&conn_id, &addr).is_none());
2050    }
2051
2052    #[test]
2053    fn update_conn_sharing() {
2054        let mut map = FakeBoundSocketMap::default();
2055        let mut fake_id_gen = FakeSocketIdGen::default();
2056        let addr = CONN_ADDR;
2057        let mut entry = map
2058            .conns_mut()
2059            .try_insert(addr.clone(), SharingState::exclusive('a'), Conn(fake_id_gen.next()))
2060            .expect("failed to insert");
2061
2062        entry
2063            .try_update_sharing(&SharingState::exclusive('a'), SharingState::exclusive('d'))
2064            .expect("worked");
2065        // Updating sharing is only allowed if there are no other occupants at
2066        // the address.
2067        let mut second_conn = map
2068            .conns_mut()
2069            .try_insert(addr.clone(), SharingState::exclusive('d'), Conn(fake_id_gen.next()))
2070            .expect("can insert");
2071        assert_matches!(
2072            second_conn
2073                .try_update_sharing(&SharingState::exclusive('d'), SharingState::exclusive('e')),
2074            Err(UpdateSharingError)
2075        );
2076    }
2077
2078    #[test]
2079    fn lookup_connected() {
2080        let mut map = FakeBoundSocketMap::default();
2081        let mut fake_id_gen = FakeSocketIdGen::default();
2082
2083        let sharing_state = SharingState::shared('a');
2084
2085        let device_id = FakeWeakDeviceId(FakeDeviceId);
2086        let entry1 = map
2087            .conns_mut()
2088            .try_insert(CONN_ADDR, sharing_state, Conn(fake_id_gen.next()))
2089            .expect("failed to insert")
2090            .id()
2091            .clone();
2092        let conn = map
2093            .lookup_connected(CONN_ADDR.ip.remote, CONN_ADDR.ip.local, device_id)
2094            .expect("lookup should succeed");
2095        assert!(conn.contains_id(&entry1));
2096
2097        // Add a second entry with a device ID. This one should be preferred
2098        // over the first one.
2099        let addr_with_device = ConnAddr { device: Some(device_id), ..CONN_ADDR };
2100        let entry2 = map
2101            .conns_mut()
2102            .try_insert(addr_with_device, sharing_state, Conn(fake_id_gen.next()))
2103            .expect("failed to insert")
2104            .id()
2105            .clone();
2106        let conn = map
2107            .lookup_connected(CONN_ADDR.ip.remote, CONN_ADDR.ip.local, device_id)
2108            .expect("lookup should succeed");
2109        assert!(conn.contains_id(&entry2));
2110    }
2111}