netstack3_base/socket/
base.rs

1// Copyright 2020 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5//! General-purpose socket utilities common to device layer and IP layer
6//! sockets.
7
8use core::convert::Infallible as Never;
9use core::fmt::Debug;
10use core::hash::Hash;
11use core::marker::PhantomData;
12use core::num::NonZeroU16;
13
14use derivative::Derivative;
15use net_types::ip::{GenericOverIp, Ip, IpAddress, IpVersionMarker, Ipv4, Ipv6};
16use net_types::{
17    AddrAndZone, MulticastAddress, ScopeableAddress, SpecifiedAddr, Witness, ZonedAddr,
18};
19
20use crate::data_structures::socketmap::{
21    Entry, IterShadows, OccupiedEntry as SocketMapOccupiedEntry, SocketMap, Tagged,
22};
23use crate::device::{
24    DeviceIdentifier, EitherDeviceId, StrongDeviceIdentifier, WeakDeviceIdentifier,
25};
26use crate::error::{ExistsError, NotFoundError, ZonedAddressError};
27use crate::ip::BroadcastIpExt;
28use crate::socket::address::{
29    AddrVecIter, ConnAddr, ConnIpAddr, ListenerAddr, ListenerIpAddr, SocketIpAddr,
30};
31
32/// A dual stack IP extention trait that provides the `OtherVersion` associated
33/// type.
34pub trait DualStackIpExt: Ip {
35    /// The "other" IP version, e.g. [`Ipv4`] for [`Ipv6`] and vice-versa.
36    type OtherVersion: DualStackIpExt<OtherVersion = Self>;
37}
38
39impl DualStackIpExt for Ipv4 {
40    type OtherVersion = Ipv6;
41}
42
43impl DualStackIpExt for Ipv6 {
44    type OtherVersion = Ipv4;
45}
46
47/// A tuple of values for `T` for both `I` and `I::OtherVersion`.
48pub struct DualStackTuple<I: DualStackIpExt, T: GenericOverIp<I> + GenericOverIp<I::OtherVersion>> {
49    this_stack: <T as GenericOverIp<I>>::Type,
50    other_stack: <T as GenericOverIp<I::OtherVersion>>::Type,
51    _marker: IpVersionMarker<I>,
52}
53
54impl<I: DualStackIpExt, T: GenericOverIp<I> + GenericOverIp<I::OtherVersion>> DualStackTuple<I, T> {
55    /// Creates a new tuple with `this_stack` and `other_stack` values.
56    pub fn new(this_stack: T, other_stack: <T as GenericOverIp<I::OtherVersion>>::Type) -> Self
57    where
58        T: GenericOverIp<I, Type = T>,
59    {
60        Self { this_stack, other_stack, _marker: IpVersionMarker::new() }
61    }
62
63    /// Retrieves `(this_stack, other_stack)` from the tuple.
64    pub fn into_inner(
65        self,
66    ) -> (<T as GenericOverIp<I>>::Type, <T as GenericOverIp<I::OtherVersion>>::Type) {
67        let Self { this_stack, other_stack, _marker } = self;
68        (this_stack, other_stack)
69    }
70
71    /// Retrieves `this_stack` from the tuple.
72    pub fn into_this_stack(self) -> <T as GenericOverIp<I>>::Type {
73        self.this_stack
74    }
75
76    /// Borrows `this_stack` from the tuple.
77    pub fn this_stack(&self) -> &<T as GenericOverIp<I>>::Type {
78        &self.this_stack
79    }
80
81    /// Retrieves `other_stack` from the tuple.
82    pub fn into_other_stack(self) -> <T as GenericOverIp<I::OtherVersion>>::Type {
83        self.other_stack
84    }
85
86    /// Borrows `other_stack` from the tuple.
87    pub fn other_stack(&self) -> &<T as GenericOverIp<I::OtherVersion>>::Type {
88        &self.other_stack
89    }
90
91    /// Flips the types, making `this_stack` `other_stack` and vice-versa.
92    pub fn flip(self) -> DualStackTuple<I::OtherVersion, T> {
93        let Self { this_stack, other_stack, _marker } = self;
94        DualStackTuple {
95            this_stack: other_stack,
96            other_stack: this_stack,
97            _marker: IpVersionMarker::new(),
98        }
99    }
100
101    /// Casts to IP version `X`.
102    ///
103    /// Given `DualStackTuple` contains complete information for both IP
104    /// versions, it can be easily cast into an arbitrary `X` IP version.
105    ///
106    /// This can be used to tie together type parameters when dealing with dual
107    /// stack sockets. For example, a `DualStackTuple` defined for `SockI` can
108    /// be cast to any `WireI`.
109    pub fn cast<X>(self) -> DualStackTuple<X, T>
110    where
111        X: DualStackIpExt,
112        T: GenericOverIp<X>
113            + GenericOverIp<X::OtherVersion>
114            + GenericOverIp<Ipv4>
115            + GenericOverIp<Ipv6>,
116    {
117        I::map_ip_in(
118            self,
119            |v4| X::map_ip_out(v4, |t| t, |t| t.flip()),
120            |v6| X::map_ip_out(v6, |t| t.flip(), |t| t),
121        )
122    }
123}
124
125impl<
126        I: DualStackIpExt,
127        NewIp: DualStackIpExt,
128        T: GenericOverIp<NewIp>
129            + GenericOverIp<NewIp::OtherVersion>
130            + GenericOverIp<I>
131            + GenericOverIp<I::OtherVersion>,
132    > GenericOverIp<NewIp> for DualStackTuple<I, T>
133{
134    type Type = DualStackTuple<NewIp, T>;
135}
136
137/// Extension trait for `Ip` providing socket-specific functionality.
138pub trait SocketIpExt: Ip {
139    /// `Self::LOOPBACK_ADDRESS`, but wrapped in the `SocketIpAddr` type.
140    const LOOPBACK_ADDRESS_AS_SOCKET_IP_ADDR: SocketIpAddr<Self::Addr> = unsafe {
141        // SAFETY: The loopback address is a valid SocketIpAddr, as verified
142        // in the `loopback_addr_is_valid_socket_addr` test.
143        SocketIpAddr::new_from_specified_unchecked(Self::LOOPBACK_ADDRESS)
144    };
145}
146
147impl<I: Ip> SocketIpExt for I {}
148
#[cfg(test)]
mod socket_ip_ext_test {
    use super::*;
    use ip_test_macro::ip_test;

    #[ip_test(I)]
    fn loopback_addr_is_valid_socket_addr<I: SocketIpExt>() {
        // `LOOPBACK_ADDRESS_AS_SOCKET_IP_ADDR` is defined with the "unchecked"
        // constructor (which supports const construction). Verify here that the
        // address actually satisfies all the requirements of the checked
        // constructor, protecting against changes made far away from that
        // definition.
        let _addr = SocketIpAddr::new(I::LOOPBACK_ADDRESS_AS_SOCKET_IP_ADDR.addr())
            .expect("loopback address should be a valid SocketIpAddr");
    }
}
164
/// State belonging to either IP stack.
///
/// Like `either::Either`, but with more helpful variant names.
///
/// Note that this type is not optimally type-safe, because `T` and `O` are not
/// bound to some `I: Ip` and its `I::OtherVersion`, respectively. In many
/// cases it may be more appropriate to define a one-off enum parameterized
/// over `I: Ip`.
#[derive(Debug, PartialEq, Eq)]
pub enum EitherStack<T, O> {
    /// In the current stack version.
    ThisStack(T),
    /// In the other version of the stack.
    OtherStack(O),
}
179
180impl<T, O> Clone for EitherStack<T, O>
181where
182    T: Clone,
183    O: Clone,
184{
185    #[cfg_attr(feature = "instrumented", track_caller)]
186    fn clone(&self) -> Self {
187        match self {
188            Self::ThisStack(t) => Self::ThisStack(t.clone()),
189            Self::OtherStack(t) => Self::OtherStack(t.clone()),
190        }
191    }
192}
193
/// Control flow type containing either a dual-stack or non-dual-stack context.
///
/// This type exists to provide nice names to the result of
/// [`BoundStateContext::dual_stack_context`], and to allow generic code to
/// match on when checking whether a socket protocol and IP version support
/// dual-stack operation. If dual-stack operation is supported, a
/// [`MaybeDualStack::DualStack`] value will be held, otherwise a
/// [`MaybeDualStack::NotDualStack`] value.
///
/// Note that the type parameters do not have trait bounds; those are provided
/// by the trait with the `dual_stack_context` function.
///
/// In monomorphized code, this type frequently has exactly one type parameter
/// that is uninstantiable (it contains an instance of
/// [`core::convert::Infallible`] or some other empty enum, or a reference to
/// the same)! That lets the compiler optimize it out completely, creating no
/// actual runtime overhead.
#[derive(Debug)]
pub enum MaybeDualStack<DS, NDS> {
    /// Dual-stack operation is supported; holds the dual-stack context.
    DualStack(DS),
    /// Dual-stack operation is not supported; holds the non-dual-stack
    /// context.
    NotDualStack(NDS),
}
217
218// Implement `GenericOverIp` for a `MaybeDualStack` whose `DS` and `NDS` also
219// implement `GenericOverIp`.
220impl<I: DualStackIpExt, DS: GenericOverIp<I>, NDS: GenericOverIp<I>> GenericOverIp<I>
221    for MaybeDualStack<DS, NDS>
222{
223    type Type = MaybeDualStack<<DS as GenericOverIp<I>>::Type, <NDS as GenericOverIp<I>>::Type>;
224}
225
226/// An error encountered while enabling or disabling dual-stack operation.
227#[derive(Copy, Clone, Debug, Eq, GenericOverIp, PartialEq)]
228#[generic_over_ip()]
229pub enum SetDualStackEnabledError {
230    /// A socket can only have dual stack enabled or disabled while unbound.
231    SocketIsBound,
232    /// Similar to [`NotDualStackCapableError`]; the socket's protocol is not
233    /// dual stack capable.
234    NotCapable,
235}
236
237/// An error encountered when attempting to perform dual stack operations on
238/// socket with a non dual stack capable protocol.
239#[derive(Copy, Clone, Debug, Eq, GenericOverIp, PartialEq)]
240#[generic_over_ip()]
241pub struct NotDualStackCapableError;
242
/// Records which direction(s) of a socket's data path have been shut down.
///
/// The `Default` value has neither direction shut down.
#[derive(Copy, Clone, Debug, Default, Eq, PartialEq)]
pub struct Shutdown {
    /// Whether the send path is shut down; while `true`, the owning socket
    /// should not be able to send packets.
    pub send: bool,
    /// Whether the receive path is shut down; while `true`, the owning socket
    /// should not be able to receive packets.
    pub receive: bool,
}
255
256/// Which direction(s) to shut down for a socket.
257#[derive(Copy, Clone, Debug, Eq, GenericOverIp, PartialEq)]
258#[generic_over_ip()]
259pub enum ShutdownType {
260    /// Prevent sending packets on the socket.
261    Send,
262    /// Prevent receiving packets on the socket.
263    Receive,
264    /// Prevent sending and receiving packets on the socket.
265    SendAndReceive,
266}
267
268impl ShutdownType {
269    /// Returns a tuple of booleans for `(shutdown_send, shutdown_receive)`.
270    pub fn to_send_receive(&self) -> (bool, bool) {
271        match self {
272            Self::Send => (true, false),
273            Self::Receive => (false, true),
274            Self::SendAndReceive => (true, true),
275        }
276    }
277
278    /// Creates a [`ShutdownType`] from a pair of bools for send and receive.
279    pub fn from_send_receive(send: bool, receive: bool) -> Option<Self> {
280        match (send, receive) {
281            (true, false) => Some(Self::Send),
282            (false, true) => Some(Self::Receive),
283            (true, true) => Some(Self::SendAndReceive),
284            (false, false) => None,
285        }
286    }
287}
288
289/// Extensions to IP Address witnesses useful in the context of sockets.
290pub trait SocketIpAddrExt<A: IpAddress>: Witness<A> + ScopeableAddress {
291    /// Determines whether the provided address is underspecified by itself.
292    ///
293    /// Some addresses are ambiguous and so must have a zone identifier in order
294    /// to be used in a socket address. This function returns true for IPv6
295    /// link-local addresses and false for all others.
296    fn must_have_zone(&self) -> bool
297    where
298        Self: Copy,
299    {
300        self.try_into_null_zoned().is_some()
301    }
302
303    /// Converts into a [`AddrAndZone<A, ()>`] if the address requires a zone.
304    ///
305    /// Otherwise returns `None`.
306    fn try_into_null_zoned(self) -> Option<AddrAndZone<Self, ()>> {
307        if self.get().is_loopback() {
308            return None;
309        }
310        AddrAndZone::new(self, ())
311    }
312}
313
314impl<A: IpAddress, W: Witness<A> + ScopeableAddress> SocketIpAddrExt<A> for W {}
315
316/// An extention trait for [`ZonedAddr`].
317pub trait SocketZonedAddrExt<W, A, D> {
318    /// Returns the address and device that should be used for a socket.
319    ///
320    /// Given an address for a socket and an optional device that the socket is
321    /// already bound on, returns the address and device that should be used
322    /// for the socket. If `addr` and `device` require inconsistent devices,
323    /// or if `addr` requires a zone but there is none specified (by `addr` or
324    /// `device`), an error is returned.
325    fn resolve_addr_with_device(
326        self,
327        device: Option<D::Weak>,
328    ) -> Result<(W, Option<EitherDeviceId<D, D::Weak>>), ZonedAddressError>
329    where
330        D: StrongDeviceIdentifier;
331}
332
333impl<W, A, D> SocketZonedAddrExt<W, A, D> for ZonedAddr<W, D>
334where
335    W: ScopeableAddress + AsRef<SpecifiedAddr<A>>,
336    A: IpAddress,
337{
338    fn resolve_addr_with_device(
339        self,
340        device: Option<D::Weak>,
341    ) -> Result<(W, Option<EitherDeviceId<D, D::Weak>>), ZonedAddressError>
342    where
343        D: StrongDeviceIdentifier,
344    {
345        let (addr, zone) = self.into_addr_zone();
346        let device = match (zone, device) {
347            (Some(zone), Some(device)) => {
348                if device != zone {
349                    return Err(ZonedAddressError::DeviceZoneMismatch);
350                }
351                Some(EitherDeviceId::Strong(zone))
352            }
353            (Some(zone), None) => Some(EitherDeviceId::Strong(zone)),
354            (None, Some(device)) => Some(EitherDeviceId::Weak(device)),
355            (None, None) => {
356                if addr.as_ref().must_have_zone() {
357                    return Err(ZonedAddressError::RequiredZoneNotProvided);
358                } else {
359                    None
360                }
361            }
362        };
363        Ok((addr, device))
364    }
365}
366
367/// A helper type to verify if applying socket updates is allowed for a given
368/// current state.
369///
370/// The fields in `SocketDeviceUpdate` define the current state,
371/// [`SocketDeviceUpdate::try_update`] applies the verification logic.
372pub struct SocketDeviceUpdate<'a, A: IpAddress, D: WeakDeviceIdentifier> {
373    /// The current local IP address.
374    pub local_ip: Option<&'a SpecifiedAddr<A>>,
375    /// The current remote IP address.
376    pub remote_ip: Option<&'a SpecifiedAddr<A>>,
377    /// The currently bound device.
378    pub old_device: Option<&'a D>,
379}
380
381impl<'a, A: IpAddress, D: WeakDeviceIdentifier> SocketDeviceUpdate<'a, A, D> {
382    /// Checks if an update from `old_device` to `new_device` is allowed,
383    /// returning an error if not.
384    pub fn check_update<N>(
385        self,
386        new_device: Option<&N>,
387    ) -> Result<(), SocketDeviceUpdateNotAllowedError>
388    where
389        D: PartialEq<N>,
390    {
391        let Self { local_ip, remote_ip, old_device } = self;
392        let must_have_zone = local_ip.is_some_and(|a| a.must_have_zone())
393            || remote_ip.is_some_and(|a| a.must_have_zone());
394
395        if !must_have_zone {
396            return Ok(());
397        }
398
399        let old_device = old_device.unwrap_or_else(|| {
400            panic!("local_ip={:?} or remote_ip={:?} must have zone", local_ip, remote_ip)
401        });
402
403        if new_device.is_some_and(|new_device| old_device == new_device) {
404            Ok(())
405        } else {
406            Err(SocketDeviceUpdateNotAllowedError)
407        }
408    }
409}
410
/// The device can't be updated on a socket.
///
/// Returned by [`SocketDeviceUpdate::check_update`].
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub struct SocketDeviceUpdateNotAllowedError;
413
/// Bundles the local and remote identifier types used by a protocol's
/// addresses in an [`AddrVec`].
pub trait SocketMapAddrSpec {
    /// Identifier for the local end of a socket address; it must be
    /// convertible into a [`NonZeroU16`].
    type LocalIdentifier: Clone + Copy + Debug + Eq + Hash + Send + Sync + Into<NonZeroU16>;
    /// Identifier for the remote end of a socket address.
    type RemoteIdentifier: Clone + Copy + Debug + Eq + Hash + Send + Sync;
}
424
/// Summary of a [`ListenerAddr`], used when computing its socket map tag.
pub struct ListenerAddrInfo {
    /// `true` when the listener address is bound to a device.
    pub has_device: bool,
    /// `true` when the listener is bound to a specified IP address; `false`
    /// for a blanket listener.
    pub specified_addr: bool,
}
433
434impl<A: IpAddress, D: DeviceIdentifier, LI> ListenerAddr<ListenerIpAddr<A, LI>, D> {
435    pub(crate) fn info(&self) -> ListenerAddrInfo {
436        let Self { device, ip: ListenerIpAddr { addr, identifier: _ } } = self;
437        ListenerAddrInfo { has_device: device.is_some(), specified_addr: addr.is_some() }
438    }
439}
440
441/// Specifies the types parameters for [`BoundSocketMap`] state as a single bundle.
442pub trait SocketMapStateSpec {
443    /// The tag value of a socket address vector entry.
444    ///
445    /// These values are derived from [`Self::ListenerAddrState`] and
446    /// [`Self::ConnAddrState`].
447    type AddrVecTag: Eq + Copy + Debug + 'static;
448
449    /// Returns a the tag for a listener in the socket map.
450    fn listener_tag(info: ListenerAddrInfo, state: &Self::ListenerAddrState) -> Self::AddrVecTag;
451
452    /// Returns a the tag for a connected socket in the socket map.
453    fn connected_tag(has_device: bool, state: &Self::ConnAddrState) -> Self::AddrVecTag;
454
455    /// An identifier for a listening socket.
456    type ListenerId: Clone + Debug;
457    /// An identifier for a connected socket.
458    type ConnId: Clone + Debug;
459
460    /// The state stored for a listening socket that is used to determine
461    /// whether sockets can share an address.
462    type ListenerSharingState: Clone + Debug;
463
464    /// The state stored for a connected socket that is used to determine
465    /// whether sockets can share an address.
466    type ConnSharingState: Clone + Debug;
467
468    /// The state stored for a listener socket address.
469    type ListenerAddrState: SocketMapAddrStateSpec<Id = Self::ListenerId, SharingState = Self::ListenerSharingState>
470        + Debug;
471
472    /// The state stored for a connected socket address.
473    type ConnAddrState: SocketMapAddrStateSpec<Id = Self::ConnId, SharingState = Self::ConnSharingState>
474        + Debug;
475}
476
/// Signals that a requested change to a [`SocketMap`] entry conflicts with the
/// entry's current contents.
///
/// Produced by implementations of [`SocketMapAddrStateSpec`].
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub struct IncompatibleError;
481
/// A one-shot handle for inserting an item into a [`SocketMap`] entry.
pub trait Inserter<T> {
    /// Inserts `item`, consuming the inserter so that no further insertions
    /// can be made through it.
    fn insert(self, item: T);
}

// Any mutable reference to an `Extend` collection can act as an inserter; the
// single item is appended through `Extend::extend`.
impl<'a, T, E: Extend<T>> Inserter<T> for &'a mut E {
    fn insert(self, item: T) {
        self.extend(core::iter::once(item))
    }
}

// The uninstantiable `Never` type trivially implements `Inserter`; its
// `insert` can never actually be called.
impl<T> Inserter<T> for Never {
    fn insert(self, _: T) {
        match self {}
    }
}
502
503/// Describes an entry in a [`SocketMap`] for a listener or connection address.
504pub trait SocketMapAddrStateSpec {
505    /// The type of ID that can be present at the address.
506    type Id;
507
508    /// The sharing state for the address.
509    ///
510    /// This can be used to determine whether a socket can be inserted at the
511    /// address. Every socket has its own sharing state associated with it,
512    /// though the sharing state is not necessarily stored in the address
513    /// entry.
514    type SharingState;
515
516    /// The type of inserter returned by [`SocketMapAddrStateSpec::try_get_inserter`].
517    type Inserter<'a>: Inserter<Self::Id> + 'a
518    where
519        Self: 'a,
520        Self::Id: 'a;
521
522    /// Creates a new `Self` holding the provided socket with the given new
523    /// sharing state at the specified address.
524    fn new(new_sharing_state: &Self::SharingState, id: Self::Id) -> Self;
525
526    /// Looks up the ID in self, returning `true` if it is present.
527    fn contains_id(&self, id: &Self::Id) -> bool;
528
529    /// Enables insertion in `self` for a new socket with the provided sharing
530    /// state.
531    ///
532    /// If the new state is incompatible with the existing socket(s),
533    /// implementations of this function should return `Err(IncompatibleError)`.
534    /// If `Ok(x)` is returned, calling `x.insert(y)` will insert `y` into
535    /// `self`.
536    fn try_get_inserter<'a, 'b>(
537        &'b mut self,
538        new_sharing_state: &'a Self::SharingState,
539    ) -> Result<Self::Inserter<'b>, IncompatibleError>;
540
541    /// Returns `Ok` if an entry with the given sharing state could be added
542    /// to `self`.
543    ///
544    /// If this returns `Ok`, `try_get_dest` should succeed.
545    fn could_insert(&self, new_sharing_state: &Self::SharingState)
546        -> Result<(), IncompatibleError>;
547
548    /// Removes the given socket from the existing state.
549    ///
550    /// Implementations should assume that `id` is contained in `self`.
551    fn remove_by_id(&mut self, id: Self::Id) -> RemoveResult;
552}
553
554/// Provides behavior on updating the sharing state of a [`SocketMap`] entry.
555pub trait SocketMapAddrStateUpdateSharingSpec: SocketMapAddrStateSpec {
556    /// Attempts to update the sharing state of the address state with id `id`
557    /// to `new_sharing_state`.
558    fn try_update_sharing(
559        &mut self,
560        id: Self::Id,
561        new_sharing_state: &Self::SharingState,
562    ) -> Result<(), IncompatibleError>;
563}
564
565/// Provides conflict detection for a [`SocketMapStateSpec`].
566pub trait SocketMapConflictPolicy<
567    Addr,
568    SharingState,
569    I: Ip,
570    D: DeviceIdentifier,
571    A: SocketMapAddrSpec,
572>: SocketMapStateSpec
573{
574    /// Checks whether a new socket with the provided state can be inserted at
575    /// the given address in the existing socket map, returning an error
576    /// otherwise.
577    ///
578    /// Implementations of this function should check for any potential
579    /// conflicts that would arise when inserting a socket with state
580    /// `new_sharing_state` into a new or existing entry at `addr` in
581    /// `socketmap`.
582    fn check_insert_conflicts(
583        new_sharing_state: &SharingState,
584        addr: &Addr,
585        socketmap: &SocketMap<AddrVec<I, D, A>, Bound<Self>>,
586    ) -> Result<(), InsertError>;
587}
588
589/// Defines the policy for updating the sharing state of entries in the
590/// [`SocketMap`].
591pub trait SocketMapUpdateSharingPolicy<Addr, SharingState, I: Ip, D: DeviceIdentifier, A>:
592    SocketMapConflictPolicy<Addr, SharingState, I, D, A>
593where
594    A: SocketMapAddrSpec,
595{
596    /// Returns whether the entry `addr` in `socketmap` allows the sharing state
597    /// to transition from `old_sharing` to `new_sharing`.
598    fn allows_sharing_update(
599        socketmap: &SocketMap<AddrVec<I, D, A>, Bound<Self>>,
600        addr: &Addr,
601        old_sharing: &SharingState,
602        new_sharing: &SharingState,
603    ) -> Result<(), UpdateSharingError>;
604}
605
606/// A bound socket state that is either a listener or a connection.
607#[derive(Derivative)]
608#[derivative(Debug(bound = "S::ListenerAddrState: Debug, S::ConnAddrState: Debug"))]
609#[allow(missing_docs)]
610pub enum Bound<S: SocketMapStateSpec + ?Sized> {
611    Listen(S::ListenerAddrState),
612    Conn(S::ConnAddrState),
613}
614
615/// An "address vector" type that can hold any address in a [`SocketMap`].
616///
617/// This is a "vector" in the mathematical sense, in that it denotes an address
618/// in a space. Here, the space is the possible addresses to which a socket
619/// receiving IP packets can be bound.
620///
621/// `AddrVec`s are used as keys for the `SocketMap` type. Since an incoming
622/// packet can match more than one address, for each incoming packet there is a
623/// set of possible `AddrVec` keys whose entries (sockets) in a `SocketMap`
624/// might receive the packet.
625///
626/// This set of keys can be ordered by precedence as described in the
627/// documentation for [`AddrVecIter`]. Calling [`IterShadows::iter_shadows`] on
628/// an instance will produce the sequence of addresses it has precedence over.
629#[derive(Derivative)]
630#[derivative(
631    Debug(bound = "D: Debug"),
632    Clone(bound = "D: Clone"),
633    Eq(bound = "D: Eq"),
634    PartialEq(bound = "D: PartialEq"),
635    Hash(bound = "D: Hash")
636)]
637#[allow(missing_docs)]
638pub enum AddrVec<I: Ip, D, A: SocketMapAddrSpec + ?Sized> {
639    Listen(ListenerAddr<ListenerIpAddr<I::Addr, A::LocalIdentifier>, D>),
640    Conn(ConnAddr<ConnIpAddr<I::Addr, A::LocalIdentifier, A::RemoteIdentifier>, D>),
641}
642
643impl<I: Ip, D: DeviceIdentifier, A: SocketMapAddrSpec, S: SocketMapStateSpec + ?Sized>
644    Tagged<AddrVec<I, D, A>> for Bound<S>
645{
646    type Tag = S::AddrVecTag;
647    fn tag(&self, address: &AddrVec<I, D, A>) -> Self::Tag {
648        match (self, address) {
649            (Bound::Listen(l), AddrVec::Listen(addr)) => S::listener_tag(addr.info(), l),
650            (Bound::Conn(c), AddrVec::Conn(ConnAddr { device, ip: _ })) => {
651                S::connected_tag(device.is_some(), c)
652            }
653            (Bound::Listen(_), AddrVec::Conn(_)) => {
654                unreachable!("found listen state for conn addr")
655            }
656            (Bound::Conn(_), AddrVec::Listen(_)) => {
657                unreachable!("found conn state for listen addr")
658            }
659        }
660    }
661}
662
663impl<I: Ip, D: DeviceIdentifier, A: SocketMapAddrSpec> IterShadows for AddrVec<I, D, A> {
664    type IterShadows = AddrVecIter<I, D, A>;
665
666    fn iter_shadows(&self) -> Self::IterShadows {
667        let (socket_ip_addr, device) = match self.clone() {
668            AddrVec::Conn(ConnAddr { ip, device }) => (ip.into(), device),
669            AddrVec::Listen(ListenerAddr { ip, device }) => (ip.into(), device),
670        };
671        let mut iter = match device {
672            Some(device) => AddrVecIter::with_device(socket_ip_addr, device),
673            None => AddrVecIter::without_device(socket_ip_addr),
674        };
675        // Skip the first element, which is always `*self`.
676        assert_eq!(iter.next().as_ref(), Some(self));
677        iter
678    }
679}
680
/// How a socket is bound on the system.
#[derive(Copy, Clone, Debug, Eq, Hash, PartialEq)]
pub enum SocketAddrType {
    /// A listener bound without a specific local IP address.
    AnyListener,
    /// A listener bound to a specific local IP address.
    SpecificListener,
    /// A connected socket.
    Connected,
}
689
690impl<'a, A: IpAddress, LI> From<&'a ListenerIpAddr<A, LI>> for SocketAddrType {
691    fn from(ListenerIpAddr { addr, identifier: _ }: &'a ListenerIpAddr<A, LI>) -> Self {
692        match addr {
693            Some(_) => SocketAddrType::SpecificListener,
694            None => SocketAddrType::AnyListener,
695        }
696    }
697}
698
699impl<'a, A: IpAddress, LI, RI> From<&'a ConnIpAddr<A, LI, RI>> for SocketAddrType {
700    fn from(_: &'a ConnIpAddr<A, LI, RI>) -> Self {
701        SocketAddrType::Connected
702    }
703}
704
/// Outcome of removing a socket from a collection of sockets.
pub enum RemoveResult {
    /// The socket was removed and the collection still holds other sockets.
    Success,
    /// The socket is the collection's last one, so the caller should remove
    /// the whole collection.
    IsLast,
}
713
714#[derive(Derivative)]
715#[derivative(Clone(bound = "S::ListenerId: Clone, S::ConnId: Clone"), Debug(bound = ""))]
716pub enum SocketId<S: SocketMapStateSpec> {
717    Listener(S::ListenerId),
718    Connection(S::ConnId),
719}
720
721/// A map from socket addresses to sockets.
722///
723/// The types of keys and IDs is determined by the [`SocketMapStateSpec`]
724/// parameter. Each listener and connected socket stores additional state.
725/// Listener and connected sockets are keyed independently, but share the same
726/// address vector space. Conflicts are detected on attempted insertion of new
727/// sockets.
728///
729/// Listener addresses map to listener-address-specific state, and likewise
730/// with connected addresses. Depending on protocol (determined by the
731/// `SocketMapStateSpec` protocol), these address states can hold one or more
732/// socket identifiers (e.g. UDP sockets with `SO_REUSEPORT` set can share an
733/// address).
734#[derive(Derivative)]
735#[derivative(Default(bound = ""))]
736pub struct BoundSocketMap<I: Ip, D: DeviceIdentifier, A: SocketMapAddrSpec, S: SocketMapStateSpec> {
737    addr_to_state: SocketMap<AddrVec<I, D, A>, Bound<S>>,
738}
739
740impl<I: Ip, D: DeviceIdentifier, A: SocketMapAddrSpec, S: SocketMapStateSpec>
741    BoundSocketMap<I, D, A, S>
742{
743    /// Returns the number of entries in the map.
744    pub fn len(&self) -> usize {
745        self.addr_to_state.len()
746    }
747}
748
/// Tag type (never instantiated) that denotes listening sockets.
pub enum Listener {}
/// Tag type (never instantiated) that denotes connected sockets.
pub enum Connection {}
753
/// A view over the sockets of a single `SocketType` (e.g. [`Listener`] or
/// [`Connection`]) within a [`BoundSocketMap`].
pub struct Sockets<AddrToStateMap, SocketType>(AddrToStateMap, PhantomData<SocketType>);
756
757impl<
758        'a,
759        I: Ip,
760        D: DeviceIdentifier,
761        SocketType: ConvertSocketMapState<I, D, A, S>,
762        A: SocketMapAddrSpec,
763        S: SocketMapStateSpec,
764    > Sockets<&'a SocketMap<AddrVec<I, D, A>, Bound<S>>, SocketType>
765where
766    S: SocketMapConflictPolicy<SocketType::Addr, SocketType::SharingState, I, D, A>,
767{
768    /// Returns the state at an address, if there is any.
769    pub fn get_by_addr(self, addr: &SocketType::Addr) -> Option<&'a SocketType::AddrState> {
770        let Self(addr_to_state, _marker) = self;
771        addr_to_state.get(&SocketType::to_addr_vec(addr)).map(|state| {
772            SocketType::from_bound_ref(state)
773                .unwrap_or_else(|| unreachable!("found {:?} for address {:?}", state, addr))
774        })
775    }
776
777    /// Returns `Ok(())` if a socket could be inserted, otherwise an error.
778    ///
779    /// Goes through a dry run of inserting a socket at the given address and
780    /// with the given sharing state, returning `Ok(())` if the insertion would
781    /// succeed, otherwise the error that would be returned.
782    pub fn could_insert(
783        self,
784        addr: &SocketType::Addr,
785        sharing: &SocketType::SharingState,
786    ) -> Result<(), InsertError> {
787        let Self(addr_to_state, _) = self;
788        match self.get_by_addr(addr) {
789            Some(state) => {
790                state.could_insert(sharing).map_err(|IncompatibleError| InsertError::Exists)
791            }
792            None => S::check_insert_conflicts(&sharing, &addr, &addr_to_state),
793        }
794    }
795}
796
797/// A borrowed state entry in a [`SocketMap`].
798#[derive(Derivative)]
799#[derivative(Debug(bound = ""))]
800pub struct SocketStateEntry<
801    'a,
802    I: Ip,
803    D: DeviceIdentifier,
804    A: SocketMapAddrSpec,
805    S: SocketMapStateSpec,
806    SocketType,
807> {
808    id: SocketId<S>,
809    addr_entry: SocketMapOccupiedEntry<'a, AddrVec<I, D, A>, Bound<S>>,
810    _marker: PhantomData<SocketType>,
811}
812
813impl<
814        'a,
815        I: Ip,
816        D: DeviceIdentifier,
817        SocketType: ConvertSocketMapState<I, D, A, S>,
818        A: SocketMapAddrSpec,
819        S: SocketMapStateSpec
820            + SocketMapConflictPolicy<SocketType::Addr, SocketType::SharingState, I, D, A>,
821    > Sockets<&'a mut SocketMap<AddrVec<I, D, A>, Bound<S>>, SocketType>
822where
823    SocketType::SharingState: Clone,
824    SocketType::Id: Clone,
825{
826    /// Attempts to insert a new entry into the [`SocketMap`] backing this
827    /// `Sockets`.
828    pub fn try_insert(
829        self,
830        socket_addr: SocketType::Addr,
831        tag_state: SocketType::SharingState,
832        id: SocketType::Id,
833    ) -> Result<SocketStateEntry<'a, I, D, A, S, SocketType>, (InsertError, SocketType::SharingState)>
834    {
835        self.try_insert_with(socket_addr, tag_state, |_addr, _sharing| (id, ()))
836            .map(|(entry, ())| entry)
837    }
838
839    /// Like [`Sockets::try_insert`] but calls `make_id` to create a socket ID
840    /// before inserting into the map.
841    ///
842    /// `make_id` returns type `R` that is returned to the caller on success.
843    pub fn try_insert_with<R>(
844        self,
845        socket_addr: SocketType::Addr,
846        tag_state: SocketType::SharingState,
847        make_id: impl FnOnce(SocketType::Addr, SocketType::SharingState) -> (SocketType::Id, R),
848    ) -> Result<
849        (SocketStateEntry<'a, I, D, A, S, SocketType>, R),
850        (InsertError, SocketType::SharingState),
851    > {
852        let Self(addr_to_state, _) = self;
853        match S::check_insert_conflicts(&tag_state, &socket_addr, &addr_to_state) {
854            Err(e) => return Err((e, tag_state)),
855            Ok(()) => (),
856        };
857
858        let addr = SocketType::to_addr_vec(&socket_addr);
859
860        match addr_to_state.entry(addr) {
861            Entry::Occupied(mut o) => {
862                let (id, ret) = o.map_mut(|bound| {
863                    let bound = match SocketType::from_bound_mut(bound) {
864                        Some(bound) => bound,
865                        None => unreachable!("found {:?} for address {:?}", bound, socket_addr),
866                    };
867                    match <SocketType::AddrState as SocketMapAddrStateSpec>::try_get_inserter(
868                        bound, &tag_state,
869                    ) {
870                        Ok(v) => {
871                            let (id, ret) = make_id(socket_addr, tag_state);
872                            v.insert(id.clone());
873                            Ok((SocketType::to_socket_id(id), ret))
874                        }
875                        Err(IncompatibleError) => Err((InsertError::Exists, tag_state)),
876                    }
877                })?;
878                Ok((SocketStateEntry { id, addr_entry: o, _marker: Default::default() }, ret))
879            }
880            Entry::Vacant(v) => {
881                let (id, ret) = make_id(socket_addr, tag_state.clone());
882                let addr_entry = v.insert(SocketType::to_bound(SocketType::AddrState::new(
883                    &tag_state,
884                    id.clone(),
885                )));
886                let id = SocketType::to_socket_id(id);
887                Ok((SocketStateEntry { id, addr_entry, _marker: Default::default() }, ret))
888            }
889        }
890    }
891
892    /// Returns a borrowed entry at `id` and `addr`.
893    pub fn entry(
894        self,
895        id: &SocketType::Id,
896        addr: &SocketType::Addr,
897    ) -> Option<SocketStateEntry<'a, I, D, A, S, SocketType>> {
898        let Self(addr_to_state, _) = self;
899        let addr_entry = match addr_to_state.entry(SocketType::to_addr_vec(addr)) {
900            Entry::Vacant(_) => return None,
901            Entry::Occupied(o) => o,
902        };
903        let state = SocketType::from_bound_ref(addr_entry.get())?;
904
905        state.contains_id(id).then_some(SocketStateEntry {
906            id: SocketType::to_socket_id(id.clone()),
907            addr_entry,
908            _marker: PhantomData::default(),
909        })
910    }
911
912    /// Removes the entry with `id` and `addr`.
913    pub fn remove(self, id: &SocketType::Id, addr: &SocketType::Addr) -> Result<(), NotFoundError> {
914        self.entry(id, addr)
915            .map(|entry| {
916                entry.remove();
917            })
918            .ok_or(NotFoundError)
919    }
920}
921
/// Error returned by [`SocketStateEntry::try_update_sharing`] when the
/// sharing state of a [`SocketMap`] entry cannot be changed.
///
/// The change can be refused by the map-wide update-sharing policy or by the
/// state stored at the entry's address.
#[derive(Debug)]
pub struct UpdateSharingError;
926
927impl<
928        'a,
929        I: Ip,
930        D: DeviceIdentifier,
931        SocketType: ConvertSocketMapState<I, D, A, S>,
932        A: SocketMapAddrSpec,
933        S: SocketMapStateSpec,
934    > SocketStateEntry<'a, I, D, A, S, SocketType>
935where
936    SocketType::Id: Clone,
937{
938    /// Returns this entry's address.
939    pub fn get_addr(&self) -> &SocketType::Addr {
940        let Self { id: _, addr_entry, _marker } = self;
941        SocketType::from_addr_vec_ref(addr_entry.key())
942    }
943
944    /// Returns this entry's identifier.
945    pub fn id(&self) -> &SocketType::Id {
946        let Self { id, addr_entry: _, _marker } = self;
947        SocketType::from_socket_id_ref(id)
948    }
949
950    /// Attempts to update the address for this entry.
951    pub fn try_update_addr(self, new_addr: SocketType::Addr) -> Result<Self, (ExistsError, Self)> {
952        let Self { id, addr_entry, _marker } = self;
953
954        let new_addrvec = SocketType::to_addr_vec(&new_addr);
955        let old_addr = addr_entry.key().clone();
956        let (addr_state, addr_to_state) = addr_entry.remove_from_map();
957        let addr_to_state = match addr_to_state.entry(new_addrvec) {
958            Entry::Occupied(o) => o.into_map(),
959            Entry::Vacant(v) => {
960                if v.descendant_counts().len() != 0 {
961                    v.into_map()
962                } else {
963                    let new_addr_entry = v.insert(addr_state);
964                    return Ok(SocketStateEntry { id, addr_entry: new_addr_entry, _marker });
965                }
966            }
967        };
968        let to_restore = addr_state;
969        // Restore the old state before returning an error.
970        let addr_entry = match addr_to_state.entry(old_addr) {
971            Entry::Occupied(_) => unreachable!("just-removed-from entry is occupied"),
972            Entry::Vacant(v) => v.insert(to_restore),
973        };
974        return Err((ExistsError, SocketStateEntry { id, addr_entry, _marker }));
975    }
976
977    /// Removes this entry from the map.
978    pub fn remove(self) {
979        let Self { id, mut addr_entry, _marker } = self;
980        let addr = addr_entry.key().clone();
981        match addr_entry.map_mut(|value| {
982            let value = match SocketType::from_bound_mut(value) {
983                Some(value) => value,
984                None => unreachable!("found {:?} for address {:?}", value, addr),
985            };
986            value.remove_by_id(SocketType::from_socket_id_ref(&id).clone())
987        }) {
988            RemoveResult::Success => (),
989            RemoveResult::IsLast => {
990                let _: Bound<S> = addr_entry.remove();
991            }
992        }
993    }
994
995    /// Attempts to update the sharing state for this entry.
996    pub fn try_update_sharing(
997        &mut self,
998        old_sharing_state: &SocketType::SharingState,
999        new_sharing_state: SocketType::SharingState,
1000    ) -> Result<(), UpdateSharingError>
1001    where
1002        SocketType::AddrState: SocketMapAddrStateUpdateSharingSpec,
1003        S: SocketMapUpdateSharingPolicy<SocketType::Addr, SocketType::SharingState, I, D, A>,
1004    {
1005        let Self { id, addr_entry, _marker } = self;
1006        let addr = SocketType::from_addr_vec_ref(addr_entry.key());
1007
1008        S::allows_sharing_update(
1009            addr_entry.get_map(),
1010            addr,
1011            old_sharing_state,
1012            &new_sharing_state,
1013        )?;
1014
1015        addr_entry
1016            .map_mut(|value| {
1017                let value = match SocketType::from_bound_mut(value) {
1018                    Some(value) => value,
1019                    // We shouldn't ever be storing listener state in a bound
1020                    // address, or bound state in a listener address. Doing so means
1021                    // we've got a serious bug.
1022                    None => unreachable!("found invalid state {:?}", value),
1023                };
1024
1025                value.try_update_sharing(
1026                    SocketType::from_socket_id_ref(id).clone(),
1027                    &new_sharing_state,
1028                )
1029            })
1030            .map_err(|IncompatibleError| UpdateSharingError)
1031    }
1032}
1033
1034impl<I: Ip, D: DeviceIdentifier, A: SocketMapAddrSpec, S> BoundSocketMap<I, D, A, S>
1035where
1036    AddrVec<I, D, A>: IterShadows,
1037    S: SocketMapStateSpec,
1038{
1039    /// Returns an iterator over the listeners on the socket map.
1040    pub fn listeners(&self) -> Sockets<&SocketMap<AddrVec<I, D, A>, Bound<S>>, Listener>
1041    where
1042        S: SocketMapConflictPolicy<
1043            ListenerAddr<ListenerIpAddr<I::Addr, A::LocalIdentifier>, D>,
1044            <S as SocketMapStateSpec>::ListenerSharingState,
1045            I,
1046            D,
1047            A,
1048        >,
1049        S::ListenerAddrState:
1050            SocketMapAddrStateSpec<Id = S::ListenerId, SharingState = S::ListenerSharingState>,
1051    {
1052        let Self { addr_to_state } = self;
1053        Sockets(addr_to_state, Default::default())
1054    }
1055
1056    /// Returns a mutable iterator over the listeners on the socket map.
1057    pub fn listeners_mut(&mut self) -> Sockets<&mut SocketMap<AddrVec<I, D, A>, Bound<S>>, Listener>
1058    where
1059        S: SocketMapConflictPolicy<
1060            ListenerAddr<ListenerIpAddr<I::Addr, A::LocalIdentifier>, D>,
1061            <S as SocketMapStateSpec>::ListenerSharingState,
1062            I,
1063            D,
1064            A,
1065        >,
1066        S::ListenerAddrState:
1067            SocketMapAddrStateSpec<Id = S::ListenerId, SharingState = S::ListenerSharingState>,
1068    {
1069        let Self { addr_to_state } = self;
1070        Sockets(addr_to_state, Default::default())
1071    }
1072
1073    /// Returns an iterator over the connections on the socket map.
1074    pub fn conns(&self) -> Sockets<&SocketMap<AddrVec<I, D, A>, Bound<S>>, Connection>
1075    where
1076        S: SocketMapConflictPolicy<
1077            ConnAddr<ConnIpAddr<I::Addr, A::LocalIdentifier, A::RemoteIdentifier>, D>,
1078            <S as SocketMapStateSpec>::ConnSharingState,
1079            I,
1080            D,
1081            A,
1082        >,
1083        S::ConnAddrState:
1084            SocketMapAddrStateSpec<Id = S::ConnId, SharingState = S::ConnSharingState>,
1085    {
1086        let Self { addr_to_state } = self;
1087        Sockets(addr_to_state, Default::default())
1088    }
1089
1090    /// Returns a mutable iterator over the connections on the socket map.
1091    pub fn conns_mut(&mut self) -> Sockets<&mut SocketMap<AddrVec<I, D, A>, Bound<S>>, Connection>
1092    where
1093        S: SocketMapConflictPolicy<
1094            ConnAddr<ConnIpAddr<I::Addr, A::LocalIdentifier, A::RemoteIdentifier>, D>,
1095            <S as SocketMapStateSpec>::ConnSharingState,
1096            I,
1097            D,
1098            A,
1099        >,
1100        S::ConnAddrState:
1101            SocketMapAddrStateSpec<Id = S::ConnId, SharingState = S::ConnSharingState>,
1102    {
1103        let Self { addr_to_state } = self;
1104        Sockets(addr_to_state, Default::default())
1105    }
1106
1107    #[cfg(test)]
1108    pub(crate) fn iter_addrs(&self) -> impl Iterator<Item = &AddrVec<I, D, A>> {
1109        let Self { addr_to_state } = self;
1110        addr_to_state.iter().map(|(a, _v): (_, &Bound<S>)| a)
1111    }
1112
1113    /// Gets the number of shadower entries for `addr`.
1114    pub fn get_shadower_counts(&self, addr: &AddrVec<I, D, A>) -> usize {
1115        let Self { addr_to_state } = self;
1116        addr_to_state.descendant_counts(&addr).map(|(_sharing, size)| size.get()).sum()
1117    }
1118}
1119
/// The result of a [`BoundSocketMap::iter_receivers`] lookup.
///
/// `A` is the type of a single matching entry and `It` is an iterator over
/// such entries.
pub enum FoundSockets<A, It> {
    /// The packet is unicast, so only the first matching entry receives it.
    Single(A),
    /// The packet is multicast or broadcast. Holds an iterator over every
    /// matching entry, all of which should receive it.
    Multicast(It),
}
1128
1129/// A borrowed entry in a [`BoundSocketMap`].
1130#[allow(missing_docs)]
1131#[derive(Debug)]
1132pub enum AddrEntry<'a, I: Ip, D, A: SocketMapAddrSpec, S: SocketMapStateSpec> {
1133    Listen(&'a S::ListenerAddrState, ListenerAddr<ListenerIpAddr<I::Addr, A::LocalIdentifier>, D>),
1134    Conn(
1135        &'a S::ConnAddrState,
1136        ConnAddr<ConnIpAddr<I::Addr, A::LocalIdentifier, A::RemoteIdentifier>, D>,
1137    ),
1138}
1139
1140impl<I, D, A, S> BoundSocketMap<I, D, A, S>
1141where
1142    I: BroadcastIpExt<Addr: MulticastAddress>,
1143    D: DeviceIdentifier,
1144    A: SocketMapAddrSpec,
1145    S: SocketMapStateSpec
1146        + SocketMapConflictPolicy<
1147            ListenerAddr<ListenerIpAddr<I::Addr, A::LocalIdentifier>, D>,
1148            <S as SocketMapStateSpec>::ListenerSharingState,
1149            I,
1150            D,
1151            A,
1152        > + SocketMapConflictPolicy<
1153            ConnAddr<ConnIpAddr<I::Addr, A::LocalIdentifier, A::RemoteIdentifier>, D>,
1154            <S as SocketMapStateSpec>::ConnSharingState,
1155            I,
1156            D,
1157            A,
1158        >,
1159{
1160    /// Finds the socket(s) that should receive an incoming packet.
1161    ///
1162    /// Uses the provided addresses and receiving device to look up sockets that
1163    /// should receive a matching incoming packet. Returns `None` if no sockets
1164    /// were found, or the results of the lookup.
1165    pub fn iter_receivers(
1166        &self,
1167        (src_ip, src_port): (Option<SocketIpAddr<I::Addr>>, Option<A::RemoteIdentifier>),
1168        (dst_ip, dst_port): (SocketIpAddr<I::Addr>, A::LocalIdentifier),
1169        device: D,
1170        broadcast: Option<I::BroadcastMarker>,
1171    ) -> Option<
1172        FoundSockets<
1173            AddrEntry<'_, I, D, A, S>,
1174            impl Iterator<Item = AddrEntry<'_, I, D, A, S>> + '_,
1175        >,
1176    > {
1177        let mut matching_entries = AddrVecIter::with_device(
1178            match (src_ip, src_port) {
1179                (Some(specified_src_ip), Some(src_port)) => {
1180                    ConnIpAddr { local: (dst_ip, dst_port), remote: (specified_src_ip, src_port) }
1181                        .into()
1182                }
1183                _ => ListenerIpAddr { addr: Some(dst_ip), identifier: dst_port }.into(),
1184            },
1185            device,
1186        )
1187        .filter_map(move |addr: AddrVec<I, D, A>| match addr {
1188            AddrVec::Listen(l) => {
1189                self.listeners().get_by_addr(&l).map(|state| AddrEntry::Listen(state, l))
1190            }
1191            AddrVec::Conn(c) => self.conns().get_by_addr(&c).map(|state| AddrEntry::Conn(state, c)),
1192        });
1193
1194        if broadcast.is_some() || dst_ip.addr().is_multicast() {
1195            Some(FoundSockets::Multicast(matching_entries))
1196        } else {
1197            let single_entry: Option<_> = matching_entries.next();
1198            single_entry.map(FoundSockets::Single)
1199        }
1200    }
1201}
1202
/// Reasons a [`SocketMapConflictPolicy`] (or the address state itself) can
/// reject inserting a socket into a socket map.
#[derive(Debug, Eq, PartialEq)]
pub enum InsertError {
    /// An entry already exists at an address that the new entry's address
    /// shadows.
    ShadowAddrExists,
    /// An entry already exists at the requested address and does not accept
    /// the new socket's sharing state.
    Exists,
    /// An entry exists at an address that shadows the requested one.
    ShadowerExists,
    /// The insertion conflicts with an existing entry indirectly.
    IndirectConflict,
}
1215
1216/// Helper trait for converting between [`AddrVec`] and [`Bound`] and their
1217/// variants.
1218pub trait ConvertSocketMapState<I: Ip, D, A: SocketMapAddrSpec, S: SocketMapStateSpec> {
1219    type Id;
1220    type SharingState;
1221    type Addr: Debug;
1222    type AddrState: SocketMapAddrStateSpec<Id = Self::Id, SharingState = Self::SharingState>;
1223
1224    fn to_addr_vec(addr: &Self::Addr) -> AddrVec<I, D, A>;
1225    fn from_addr_vec_ref(addr: &AddrVec<I, D, A>) -> &Self::Addr;
1226    fn from_bound_ref(bound: &Bound<S>) -> Option<&Self::AddrState>;
1227    fn from_bound_mut(bound: &mut Bound<S>) -> Option<&mut Self::AddrState>;
1228    fn to_bound(state: Self::AddrState) -> Bound<S>;
1229    fn to_socket_id(id: Self::Id) -> SocketId<S>;
1230    fn from_socket_id_ref(id: &SocketId<S>) -> &Self::Id;
1231}
1232
1233impl<I: Ip, D: DeviceIdentifier, A: SocketMapAddrSpec, S: SocketMapStateSpec>
1234    ConvertSocketMapState<I, D, A, S> for Listener
1235{
1236    type Id = S::ListenerId;
1237    type SharingState = S::ListenerSharingState;
1238    type Addr = ListenerAddr<ListenerIpAddr<I::Addr, A::LocalIdentifier>, D>;
1239    type AddrState = S::ListenerAddrState;
1240    fn to_addr_vec(addr: &Self::Addr) -> AddrVec<I, D, A> {
1241        AddrVec::Listen(addr.clone())
1242    }
1243
1244    fn from_addr_vec_ref(addr: &AddrVec<I, D, A>) -> &Self::Addr {
1245        match addr {
1246            AddrVec::Listen(l) => l,
1247            AddrVec::Conn(c) => unreachable!("conn addr for listener: {c:?}"),
1248        }
1249    }
1250
1251    fn from_bound_ref(bound: &Bound<S>) -> Option<&S::ListenerAddrState> {
1252        match bound {
1253            Bound::Listen(l) => Some(l),
1254            Bound::Conn(_c) => None,
1255        }
1256    }
1257
1258    fn from_bound_mut(bound: &mut Bound<S>) -> Option<&mut S::ListenerAddrState> {
1259        match bound {
1260            Bound::Listen(l) => Some(l),
1261            Bound::Conn(_c) => None,
1262        }
1263    }
1264
1265    fn to_bound(state: S::ListenerAddrState) -> Bound<S> {
1266        Bound::Listen(state)
1267    }
1268    fn from_socket_id_ref(id: &SocketId<S>) -> &Self::Id {
1269        match id {
1270            SocketId::Listener(id) => id,
1271            SocketId::Connection(_) => unreachable!("connection ID for listener"),
1272        }
1273    }
1274    fn to_socket_id(id: Self::Id) -> SocketId<S> {
1275        SocketId::Listener(id)
1276    }
1277}
1278
1279impl<I: Ip, D: DeviceIdentifier, A: SocketMapAddrSpec, S: SocketMapStateSpec>
1280    ConvertSocketMapState<I, D, A, S> for Connection
1281{
1282    type Id = S::ConnId;
1283    type SharingState = S::ConnSharingState;
1284    type Addr = ConnAddr<ConnIpAddr<I::Addr, A::LocalIdentifier, A::RemoteIdentifier>, D>;
1285    type AddrState = S::ConnAddrState;
1286    fn to_addr_vec(addr: &Self::Addr) -> AddrVec<I, D, A> {
1287        AddrVec::Conn(addr.clone())
1288    }
1289
1290    fn from_addr_vec_ref(addr: &AddrVec<I, D, A>) -> &Self::Addr {
1291        match addr {
1292            AddrVec::Conn(c) => c,
1293            AddrVec::Listen(l) => unreachable!("listener addr for conn: {l:?}"),
1294        }
1295    }
1296
1297    fn from_bound_ref(bound: &Bound<S>) -> Option<&S::ConnAddrState> {
1298        match bound {
1299            Bound::Listen(_l) => None,
1300            Bound::Conn(c) => Some(c),
1301        }
1302    }
1303
1304    fn from_bound_mut(bound: &mut Bound<S>) -> Option<&mut S::ConnAddrState> {
1305        match bound {
1306            Bound::Listen(_l) => None,
1307            Bound::Conn(c) => Some(c),
1308        }
1309    }
1310
1311    fn to_bound(state: S::ConnAddrState) -> Bound<S> {
1312        Bound::Conn(state)
1313    }
1314
1315    fn from_socket_id_ref(id: &SocketId<S>) -> &Self::Id {
1316        match id {
1317            SocketId::Connection(id) => id,
1318            SocketId::Listener(_) => unreachable!("listener ID for connection"),
1319        }
1320    }
1321    fn to_socket_id(id: Self::Id) -> SocketId<S> {
1322        SocketId::Connection(id)
1323    }
1324}
1325
1326#[cfg(test)]
1327mod tests {
1328    use alloc::collections::HashSet;
1329    use alloc::vec;
1330    use alloc::vec::Vec;
1331
1332    use assert_matches::assert_matches;
1333    use net_declare::{net_ip_v4, net_ip_v6};
1334    use net_types::ip::{Ipv4Addr, Ipv6, Ipv6Addr};
1335    use test_case::test_case;
1336
1337    use crate::device::testutil::{FakeDeviceId, FakeWeakDeviceId};
1338    use crate::testutil::set_logger_for_test;
1339
1340    use super::*;
1341
1342    #[test_case(net_ip_v4!("8.8.8.8"))]
1343    #[test_case(net_ip_v4!("127.0.0.1"))]
1344    #[test_case(net_ip_v4!("127.0.8.9"))]
1345    #[test_case(net_ip_v4!("224.1.2.3"))]
1346    fn must_never_have_zone_ipv4(addr: Ipv4Addr) {
1347        // No IPv4 addresses are allowed to have a zone.
1348        let addr = SpecifiedAddr::new(addr).unwrap();
1349        assert_eq!(addr.must_have_zone(), false);
1350    }
1351
1352    #[test_case(net_ip_v6!("1::2:3"), false)]
1353    #[test_case(net_ip_v6!("::1"), false; "localhost")]
1354    #[test_case(net_ip_v6!("1::"), false)]
1355    #[test_case(net_ip_v6!("ff03:1:2:3::1"), false)]
1356    #[test_case(net_ip_v6!("ff02:1:2:3::1"), true)]
1357    #[test_case(Ipv6::ALL_NODES_LINK_LOCAL_MULTICAST_ADDRESS.get(), true)]
1358    #[test_case(net_ip_v6!("fe80::1"), true)]
1359    fn must_have_zone_ipv6(addr: Ipv6Addr, must_have: bool) {
1360        // Only link-local unicast and multicast addresses are allowed to have
1361        // zones.
1362        let addr = SpecifiedAddr::new(addr).unwrap();
1363        assert_eq!(addr.must_have_zone(), must_have);
1364    }
1365
1366    #[test]
1367    fn try_into_null_zoned_ipv6() {
1368        assert_eq!(Ipv6::LOOPBACK_ADDRESS.try_into_null_zoned(), None);
1369        let zoned = Ipv6::ALL_NODES_LINK_LOCAL_MULTICAST_ADDRESS.into_specified();
1370        const ZONE: u32 = 5;
1371        assert_eq!(
1372            zoned.try_into_null_zoned().map(|a| a.map_zone(|()| ZONE)),
1373            Some(AddrAndZone::new(zoned, ZONE).unwrap())
1374        );
1375    }
1376
1377    enum FakeSpec {}
1378
1379    #[derive(Copy, Clone, Eq, PartialEq, Debug, Hash)]
1380    struct Listener(usize);
1381
1382    #[derive(PartialEq, Eq, Debug)]
1383    struct Multiple<T>(char, Vec<T>);
1384
1385    impl<T> Multiple<T> {
1386        fn tag(&self) -> char {
1387            let Multiple(c, _) = self;
1388            *c
1389        }
1390    }
1391
1392    #[derive(Copy, Clone, Eq, PartialEq, Debug, Hash)]
1393    struct Conn(usize);
1394
1395    enum FakeAddrSpec {}
1396
1397    impl SocketMapAddrSpec for FakeAddrSpec {
1398        type LocalIdentifier = NonZeroU16;
1399        type RemoteIdentifier = ();
1400    }
1401
1402    impl SocketMapStateSpec for FakeSpec {
1403        type AddrVecTag = char;
1404
1405        type ListenerId = Listener;
1406        type ConnId = Conn;
1407
1408        type ListenerSharingState = char;
1409        type ConnSharingState = char;
1410
1411        type ListenerAddrState = Multiple<Listener>;
1412        type ConnAddrState = Multiple<Conn>;
1413
1414        fn listener_tag(_: ListenerAddrInfo, state: &Self::ListenerAddrState) -> Self::AddrVecTag {
1415            state.tag()
1416        }
1417
1418        fn connected_tag(_has_device: bool, state: &Self::ConnAddrState) -> Self::AddrVecTag {
1419            state.tag()
1420        }
1421    }
1422
1423    type FakeBoundSocketMap =
1424        BoundSocketMap<Ipv4, FakeWeakDeviceId<FakeDeviceId>, FakeAddrSpec, FakeSpec>;
1425
1426    /// Generator for unique socket IDs that don't have any state.
1427    ///
1428    /// Calling [`FakeSocketIdGen::next`] returns a unique ID.
1429    #[derive(Default)]
1430    struct FakeSocketIdGen {
1431        next_id: usize,
1432    }
1433
1434    impl FakeSocketIdGen {
1435        fn next(&mut self) -> usize {
1436            let next_next_id = self.next_id + 1;
1437            core::mem::replace(&mut self.next_id, next_next_id)
1438        }
1439    }
1440
1441    impl<I: Eq> SocketMapAddrStateSpec for Multiple<I> {
1442        type Id = I;
1443        type SharingState = char;
1444        type Inserter<'a>
1445            = &'a mut Vec<I>
1446        where
1447            I: 'a;
1448
1449        fn new(new_sharing_state: &char, id: I) -> Self {
1450            Self(*new_sharing_state, vec![id])
1451        }
1452
1453        fn contains_id(&self, id: &Self::Id) -> bool {
1454            self.1.contains(id)
1455        }
1456
1457        fn try_get_inserter<'a, 'b>(
1458            &'b mut self,
1459            new_state: &'a char,
1460        ) -> Result<Self::Inserter<'b>, IncompatibleError> {
1461            let Self(c, v) = self;
1462            (new_state == c).then_some(v).ok_or(IncompatibleError)
1463        }
1464
1465        fn could_insert(
1466            &self,
1467            new_sharing_state: &Self::SharingState,
1468        ) -> Result<(), IncompatibleError> {
1469            let Self(c, _) = self;
1470            (new_sharing_state == c).then_some(()).ok_or(IncompatibleError)
1471        }
1472
1473        fn remove_by_id(&mut self, id: I) -> RemoveResult {
1474            let Self(_, v) = self;
1475            let index = v.iter().position(|i| i == &id).expect("did not find id");
1476            let _: I = v.swap_remove(index);
1477            if v.is_empty() {
1478                RemoveResult::IsLast
1479            } else {
1480                RemoveResult::Success
1481            }
1482        }
1483    }
1484
1485    impl<A: Into<AddrVec<Ipv4, FakeWeakDeviceId<FakeDeviceId>, FakeAddrSpec>> + Clone>
1486        SocketMapConflictPolicy<A, char, Ipv4, FakeWeakDeviceId<FakeDeviceId>, FakeAddrSpec>
1487        for FakeSpec
1488    {
1489        fn check_insert_conflicts(
1490            new_state: &char,
1491            addr: &A,
1492            socketmap: &SocketMap<
1493                AddrVec<Ipv4, FakeWeakDeviceId<FakeDeviceId>, FakeAddrSpec>,
1494                Bound<FakeSpec>,
1495            >,
1496        ) -> Result<(), InsertError> {
1497            let dest = addr.clone().into();
1498            if dest.iter_shadows().any(|a| socketmap.get(&a).is_some()) {
1499                return Err(InsertError::ShadowAddrExists);
1500            }
1501            match socketmap.get(&dest) {
1502                Some(Bound::Listen(Multiple(c, _))) | Some(Bound::Conn(Multiple(c, _))) => {
1503                    // Require that all sockets inserted in a `Multiple` entry
1504                    // have the same sharing state.
1505                    if c != new_state {
1506                        return Err(InsertError::Exists);
1507                    }
1508                }
1509                None => (),
1510            }
1511            if socketmap.descendant_counts(&dest).len() != 0 {
1512                Err(InsertError::ShadowerExists)
1513            } else {
1514                Ok(())
1515            }
1516        }
1517    }
1518
1519    impl<I: Eq> SocketMapAddrStateUpdateSharingSpec for Multiple<I> {
1520        fn try_update_sharing(
1521            &mut self,
1522            id: Self::Id,
1523            new_sharing_state: &Self::SharingState,
1524        ) -> Result<(), IncompatibleError> {
1525            let Self(sharing, v) = self;
1526            if new_sharing_state == sharing {
1527                return Ok(());
1528            }
1529
1530            // Preserve the invariant that all sockets inserted in a `Multiple`
1531            // entry have the same sharing state. That means we can't change
1532            // the sharing state of all the sockets at the address unless there
1533            // is exactly one!
1534            if v.len() != 1 {
1535                return Err(IncompatibleError);
1536            }
1537            assert!(v.contains(&id));
1538            *sharing = *new_sharing_state;
1539            Ok(())
1540        }
1541    }
1542
1543    impl<A: Into<AddrVec<Ipv4, FakeWeakDeviceId<FakeDeviceId>, FakeAddrSpec>> + Clone>
1544        SocketMapUpdateSharingPolicy<A, char, Ipv4, FakeWeakDeviceId<FakeDeviceId>, FakeAddrSpec>
1545        for FakeSpec
1546    {
1547        fn allows_sharing_update(
1548            _socketmap: &SocketMap<
1549                AddrVec<Ipv4, FakeWeakDeviceId<FakeDeviceId>, FakeAddrSpec>,
1550                Bound<Self>,
1551            >,
1552            _addr: &A,
1553            _old_sharing: &char,
1554            _new_sharing_state: &char,
1555        ) -> Result<(), UpdateSharingError> {
1556            Ok(())
1557        }
1558    }
1559
1560    const LISTENER_ADDR: ListenerAddr<
1561        ListenerIpAddr<Ipv4Addr, NonZeroU16>,
1562        FakeWeakDeviceId<FakeDeviceId>,
1563    > = ListenerAddr {
1564        ip: ListenerIpAddr {
1565            addr: Some(unsafe { SocketIpAddr::new_unchecked(net_ip_v4!("1.2.3.4")) }),
1566            identifier: NonZeroU16::new(1).unwrap(),
1567        },
1568        device: None,
1569    };
1570
1571    const CONN_ADDR: ConnAddr<
1572        ConnIpAddr<Ipv4Addr, NonZeroU16, ()>,
1573        FakeWeakDeviceId<FakeDeviceId>,
1574    > = ConnAddr {
1575        ip: ConnIpAddr {
1576            local: (
1577                unsafe { SocketIpAddr::new_unchecked(net_ip_v4!("5.6.7.8")) },
1578                NonZeroU16::new(1).unwrap(),
1579            ),
1580            remote: unsafe { (SocketIpAddr::new_unchecked(net_ip_v4!("8.7.6.5")), ()) },
1581        },
1582        device: None,
1583    };
1584
1585    #[test]
1586    fn bound_insert_get_remove_listener() {
1587        set_logger_for_test();
1588        let mut bound = FakeBoundSocketMap::default();
1589        let mut fake_id_gen = FakeSocketIdGen::default();
1590
1591        let addr = LISTENER_ADDR;
1592
1593        let id = {
1594            let entry =
1595                bound.listeners_mut().try_insert(addr, 'v', Listener(fake_id_gen.next())).unwrap();
1596            assert_eq!(entry.get_addr(), &addr);
1597            entry.id().clone()
1598        };
1599
1600        assert_eq!(bound.listeners().get_by_addr(&addr), Some(&Multiple('v', vec![id])));
1601
1602        assert_eq!(bound.listeners_mut().remove(&id, &addr), Ok(()));
1603        assert_eq!(bound.listeners().get_by_addr(&addr), None);
1604    }
1605
1606    #[test]
1607    fn bound_insert_get_remove_conn() {
1608        set_logger_for_test();
1609        let mut bound = FakeBoundSocketMap::default();
1610        let mut fake_id_gen = FakeSocketIdGen::default();
1611
1612        let addr = CONN_ADDR;
1613
1614        let id = {
1615            let entry = bound.conns_mut().try_insert(addr, 'v', Conn(fake_id_gen.next())).unwrap();
1616            assert_eq!(entry.get_addr(), &addr);
1617            entry.id().clone()
1618        };
1619
1620        assert_eq!(bound.conns().get_by_addr(&addr), Some(&Multiple('v', vec![id])));
1621
1622        assert_eq!(bound.conns_mut().remove(&id, &addr), Ok(()));
1623        assert_eq!(bound.conns().get_by_addr(&addr), None);
1624    }
1625
1626    #[test]
1627    fn bound_iter_addrs() {
1628        set_logger_for_test();
1629        let mut bound = FakeBoundSocketMap::default();
1630        let mut fake_id_gen = FakeSocketIdGen::default();
1631
1632        let listener_addrs = [
1633            (Some(net_ip_v4!("1.1.1.1")), 1),
1634            (Some(net_ip_v4!("2.2.2.2")), 2),
1635            (Some(net_ip_v4!("1.1.1.1")), 3),
1636            (None, 4),
1637        ]
1638        .map(|(ip, identifier)| ListenerAddr {
1639            device: None,
1640            ip: ListenerIpAddr {
1641                addr: ip.map(|x| SocketIpAddr::new(x).unwrap()),
1642                identifier: NonZeroU16::new(identifier).unwrap(),
1643            },
1644        });
1645        let conn_addrs = [
1646            (net_ip_v4!("3.3.3.3"), 3, net_ip_v4!("4.4.4.4")),
1647            (net_ip_v4!("4.4.4.4"), 3, net_ip_v4!("3.3.3.3")),
1648        ]
1649        .map(|(local_ip, local_identifier, remote_ip)| ConnAddr {
1650            ip: ConnIpAddr {
1651                local: (
1652                    SocketIpAddr::new(local_ip).unwrap(),
1653                    NonZeroU16::new(local_identifier).unwrap(),
1654                ),
1655                remote: (SocketIpAddr::new(remote_ip).unwrap(), ()),
1656            },
1657            device: None,
1658        });
1659
1660        for addr in listener_addrs.iter().cloned() {
1661            let _entry =
1662                bound.listeners_mut().try_insert(addr, 'a', Listener(fake_id_gen.next())).unwrap();
1663        }
1664        for addr in conn_addrs.iter().cloned() {
1665            let _entry = bound.conns_mut().try_insert(addr, 'a', Conn(fake_id_gen.next())).unwrap();
1666        }
1667        let expected_addrs = listener_addrs
1668            .into_iter()
1669            .map(Into::into)
1670            .chain(conn_addrs.into_iter().map(Into::into))
1671            .collect::<HashSet<_>>();
1672
1673        assert_eq!(expected_addrs, bound.iter_addrs().cloned().collect());
1674    }
1675
1676    #[test]
1677    fn try_insert_with_callback_not_called_on_error() {
1678        // TODO(https://fxbug.dev/42076891): remove this test along with
1679        // try_insert_with.
1680        set_logger_for_test();
1681        let mut bound = FakeBoundSocketMap::default();
1682        let addr = LISTENER_ADDR;
1683
1684        // Insert a listener so that future calls can conflict.
1685        let _: &Listener = bound.listeners_mut().try_insert(addr, 'a', Listener(0)).unwrap().id();
1686
1687        // All of the below try_insert_with calls should fail, but more
1688        // importantly, they should not call the `make_id` callback (because it
1689        // is only called once success is certain).
1690        fn is_never_called<A, B, T>(_: A, _: B) -> (T, ()) {
1691            panic!("should never be called");
1692        }
1693
1694        assert_matches!(
1695            bound.listeners_mut().try_insert_with(addr, 'b', is_never_called),
1696            Err((InsertError::Exists, _))
1697        );
1698        assert_matches!(
1699            bound.listeners_mut().try_insert_with(
1700                ListenerAddr { device: Some(FakeWeakDeviceId(FakeDeviceId)), ..addr },
1701                'b',
1702                is_never_called
1703            ),
1704            Err((InsertError::ShadowAddrExists, _))
1705        );
1706        assert_matches!(
1707            bound.conns_mut().try_insert_with(
1708                ConnAddr {
1709                    device: None,
1710                    ip: ConnIpAddr {
1711                        local: (addr.ip.addr.unwrap(), addr.ip.identifier),
1712                        remote: (SocketIpAddr::new(net_ip_v4!("1.1.1.1")).unwrap(), ()),
1713                    },
1714                },
1715                'b',
1716                is_never_called,
1717            ),
1718            Err((InsertError::ShadowAddrExists, _))
1719        );
1720    }
1721
1722    #[test]
1723    fn insert_listener_conflict_with_listener() {
1724        set_logger_for_test();
1725        let mut bound = FakeBoundSocketMap::default();
1726        let mut fake_id_gen = FakeSocketIdGen::default();
1727        let addr = LISTENER_ADDR;
1728
1729        let _: &Listener =
1730            bound.listeners_mut().try_insert(addr, 'a', Listener(fake_id_gen.next())).unwrap().id();
1731        assert_matches!(
1732            bound.listeners_mut().try_insert(addr, 'b', Listener(fake_id_gen.next())),
1733            Err((InsertError::Exists, 'b'))
1734        );
1735    }
1736
1737    #[test]
1738    fn insert_listener_conflict_with_shadower() {
1739        set_logger_for_test();
1740        let mut bound = FakeBoundSocketMap::default();
1741        let mut fake_id_gen = FakeSocketIdGen::default();
1742        let addr = LISTENER_ADDR;
1743        let shadows_addr = {
1744            assert_eq!(addr.device, None);
1745            ListenerAddr { device: Some(FakeWeakDeviceId(FakeDeviceId)), ..addr }
1746        };
1747
1748        let _: &Listener =
1749            bound.listeners_mut().try_insert(addr, 'a', Listener(fake_id_gen.next())).unwrap().id();
1750        assert_matches!(
1751            bound.listeners_mut().try_insert(shadows_addr, 'b', Listener(fake_id_gen.next())),
1752            Err((InsertError::ShadowAddrExists, 'b'))
1753        );
1754    }
1755
1756    #[test]
1757    fn insert_conn_conflict_with_listener() {
1758        set_logger_for_test();
1759        let mut bound = FakeBoundSocketMap::default();
1760        let mut fake_id_gen = FakeSocketIdGen::default();
1761        let addr = LISTENER_ADDR;
1762        let shadows_addr = ConnAddr {
1763            device: None,
1764            ip: ConnIpAddr {
1765                local: (addr.ip.addr.unwrap(), addr.ip.identifier),
1766                remote: (SocketIpAddr::new(net_ip_v4!("1.1.1.1")).unwrap(), ()),
1767            },
1768        };
1769
1770        let _: &Listener =
1771            bound.listeners_mut().try_insert(addr, 'a', Listener(fake_id_gen.next())).unwrap().id();
1772        assert_matches!(
1773            bound.conns_mut().try_insert(shadows_addr, 'b', Conn(fake_id_gen.next())),
1774            Err((InsertError::ShadowAddrExists, 'b'))
1775        );
1776    }
1777
1778    #[test]
1779    fn insert_and_remove_listener() {
1780        set_logger_for_test();
1781        let mut bound = FakeBoundSocketMap::default();
1782        let mut fake_id_gen = FakeSocketIdGen::default();
1783        let addr = LISTENER_ADDR;
1784
1785        let a = bound
1786            .listeners_mut()
1787            .try_insert(addr, 'x', Listener(fake_id_gen.next()))
1788            .unwrap()
1789            .id()
1790            .clone();
1791        let b = bound
1792            .listeners_mut()
1793            .try_insert(addr, 'x', Listener(fake_id_gen.next()))
1794            .unwrap()
1795            .id()
1796            .clone();
1797        assert_ne!(a, b);
1798
1799        assert_eq!(bound.listeners_mut().remove(&a, &addr), Ok(()));
1800        assert_eq!(bound.listeners().get_by_addr(&addr), Some(&Multiple('x', vec![b])));
1801    }
1802
1803    #[test]
1804    fn insert_and_remove_conn() {
1805        set_logger_for_test();
1806        let mut bound = FakeBoundSocketMap::default();
1807        let mut fake_id_gen = FakeSocketIdGen::default();
1808        let addr = CONN_ADDR;
1809
1810        let a =
1811            bound.conns_mut().try_insert(addr, 'x', Conn(fake_id_gen.next())).unwrap().id().clone();
1812        let b =
1813            bound.conns_mut().try_insert(addr, 'x', Conn(fake_id_gen.next())).unwrap().id().clone();
1814        assert_ne!(a, b);
1815
1816        assert_eq!(bound.conns_mut().remove(&a, &addr), Ok(()));
1817        assert_eq!(bound.conns().get_by_addr(&addr), Some(&Multiple('x', vec![b])));
1818    }
1819
1820    #[test]
1821    fn update_listener_to_shadowed_addr_fails() {
1822        let mut bound = FakeBoundSocketMap::default();
1823        let mut fake_id_gen = FakeSocketIdGen::default();
1824
1825        let first_addr = LISTENER_ADDR;
1826        let second_addr = ListenerAddr {
1827            ip: ListenerIpAddr {
1828                addr: Some(SocketIpAddr::new(net_ip_v4!("1.1.1.1")).unwrap()),
1829                ..LISTENER_ADDR.ip
1830            },
1831            ..LISTENER_ADDR
1832        };
1833        let both_shadow = ListenerAddr {
1834            ip: ListenerIpAddr { addr: None, identifier: first_addr.ip.identifier },
1835            device: None,
1836        };
1837
1838        let first = bound
1839            .listeners_mut()
1840            .try_insert(first_addr, 'a', Listener(fake_id_gen.next()))
1841            .unwrap()
1842            .id()
1843            .clone();
1844        let second = bound
1845            .listeners_mut()
1846            .try_insert(second_addr, 'b', Listener(fake_id_gen.next()))
1847            .unwrap()
1848            .id()
1849            .clone();
1850
1851        // Moving from (1, "aaa") to (1, None) should fail since it is shadowed
1852        // by (1, "yyy"), and vise versa.
1853        let (ExistsError, entry) = bound
1854            .listeners_mut()
1855            .entry(&second, &second_addr)
1856            .unwrap()
1857            .try_update_addr(both_shadow)
1858            .expect_err("update should fail");
1859
1860        // The entry should correspond to `second`.
1861        assert_eq!(entry.id(), &second);
1862        drop(entry);
1863
1864        let (ExistsError, entry) = bound
1865            .listeners_mut()
1866            .entry(&first, &first_addr)
1867            .unwrap()
1868            .try_update_addr(both_shadow)
1869            .expect_err("update should fail");
1870        assert_eq!(entry.get_addr(), &first_addr);
1871    }
1872
1873    #[test]
1874    fn nonexistent_conn_entry() {
1875        let mut map = FakeBoundSocketMap::default();
1876        let mut fake_id_gen = FakeSocketIdGen::default();
1877        let addr = CONN_ADDR;
1878        let conn_id = map
1879            .conns_mut()
1880            .try_insert(addr.clone(), 'a', Conn(fake_id_gen.next()))
1881            .expect("failed to insert")
1882            .id()
1883            .clone();
1884        assert_matches!(map.conns_mut().remove(&conn_id, &addr), Ok(()));
1885
1886        assert!(map.conns_mut().entry(&conn_id, &addr).is_none());
1887    }
1888
1889    #[test]
1890    fn update_conn_sharing() {
1891        let mut map = FakeBoundSocketMap::default();
1892        let mut fake_id_gen = FakeSocketIdGen::default();
1893        let addr = CONN_ADDR;
1894        let mut entry = map
1895            .conns_mut()
1896            .try_insert(addr.clone(), 'a', Conn(fake_id_gen.next()))
1897            .expect("failed to insert");
1898
1899        entry.try_update_sharing(&'a', 'd').expect("worked");
1900        // Updating sharing is only allowed if there are no other occupants at
1901        // the address.
1902        let mut second_conn = map
1903            .conns_mut()
1904            .try_insert(addr.clone(), 'd', Conn(fake_id_gen.next()))
1905            .expect("can insert");
1906        assert_matches!(second_conn.try_update_sharing(&'d', 'e'), Err(UpdateSharingError));
1907    }
1908}