1use core::convert::Infallible as Never;
6use core::fmt::Debug;
7use core::num::NonZeroU32;
8
9use net_types::ip::{GenericOverIp, Ip, Ipv4, Ipv4SourceAddr, Ipv6, Ipv6SourceAddr, Mtu};
10use net_types::Witness;
11use packet_formats::icmp::{
12 IcmpDestUnreachable, Icmpv4DestUnreachableCode, Icmpv4ParameterProblemCode, Icmpv4RedirectCode,
13 Icmpv4TimeExceededCode, Icmpv6DestUnreachableCode, Icmpv6ParameterProblemCode,
14 Icmpv6TimeExceededCode,
15};
16use packet_formats::ip::IpProtoExt;
17use strum::{EnumCount as _, IntoEnumIterator as _};
18use strum_macros::{EnumCount, EnumIter};
19
20pub trait BroadcastIpExt: Ip {
22 type BroadcastMarker: Debug + Copy + Clone + PartialEq + Eq + Send + Sync + 'static;
25}
26
27impl BroadcastIpExt for Ipv4 {
28 type BroadcastMarker = ();
29}
30
31impl BroadcastIpExt for Ipv6 {
32 type BroadcastMarker = Never;
33}
34
35#[derive(GenericOverIp)]
38#[generic_over_ip(I, Ip)]
39pub struct WrapBroadcastMarker<I: BroadcastIpExt>(pub I::BroadcastMarker);
40
41#[derive(Clone, Copy, Debug, PartialEq, Eq)]
45pub struct Mms(NonZeroU32);
46
47impl Mms {
48 pub fn from_mtu<I: IpExt>(mtu: Mtu, options_size: u32) -> Option<Self> {
50 NonZeroU32::new(mtu.get().saturating_sub(I::IP_HEADER_LENGTH.get() + options_size))
51 .map(|mms| Self(mms.min(I::IP_MAX_PAYLOAD_LENGTH)))
52 }
53
54 pub fn get(&self) -> NonZeroU32 {
56 let Self(mms) = *self;
57 mms
58 }
59}
60
61#[derive(Copy, Clone, Debug, PartialEq)]
66#[allow(missing_docs)]
67pub enum Icmpv4ErrorCode {
68 DestUnreachable(Icmpv4DestUnreachableCode, IcmpDestUnreachable),
69 Redirect(Icmpv4RedirectCode),
70 TimeExceeded(Icmpv4TimeExceededCode),
71 ParameterProblem(Icmpv4ParameterProblemCode),
72}
73
74impl<I: IcmpIpExt> GenericOverIp<I> for Icmpv4ErrorCode {
75 type Type = I::ErrorCode;
76}
77
78#[derive(Copy, Clone, Debug, PartialEq)]
83#[allow(missing_docs)]
84pub enum Icmpv6ErrorCode {
85 DestUnreachable(Icmpv6DestUnreachableCode),
86 PacketTooBig(Mtu),
87 TimeExceeded(Icmpv6TimeExceededCode),
88 ParameterProblem(Icmpv6ParameterProblemCode),
89}
90
91impl<I: IcmpIpExt> GenericOverIp<I> for Icmpv6ErrorCode {
92 type Type = I::ErrorCode;
93}
94
95#[derive(Debug, Clone, Copy)]
97pub enum IcmpErrorCode {
98 V4(Icmpv4ErrorCode),
100 V6(Icmpv6ErrorCode),
102}
103
104impl From<Icmpv4ErrorCode> for IcmpErrorCode {
105 fn from(v4_err: Icmpv4ErrorCode) -> Self {
106 IcmpErrorCode::V4(v4_err)
107 }
108}
109
110impl From<Icmpv6ErrorCode> for IcmpErrorCode {
111 fn from(v6_err: Icmpv6ErrorCode) -> Self {
112 IcmpErrorCode::V6(v6_err)
113 }
114}
115
116pub trait IcmpIpExt: packet_formats::ip::IpExt + packet_formats::icmp::IcmpIpExt {
118 type ErrorCode: Debug
121 + Copy
122 + PartialEq
123 + GenericOverIp<Self, Type = Self::ErrorCode>
124 + GenericOverIp<Ipv4, Type = Icmpv4ErrorCode>
125 + GenericOverIp<Ipv6, Type = Icmpv6ErrorCode>
126 + Into<IcmpErrorCode>;
127}
128
129impl IcmpIpExt for Ipv4 {
130 type ErrorCode = Icmpv4ErrorCode;
131}
132
133impl IcmpIpExt for Ipv6 {
134 type ErrorCode = Icmpv6ErrorCode;
135}
136
137pub trait IpTypesIpExt: packet_formats::ip::IpExt {
139 type BroadcastMarker: Debug + Copy + Clone + PartialEq + Eq;
142}
143
144impl IpTypesIpExt for Ipv4 {
145 type BroadcastMarker = ();
146}
147
148impl IpTypesIpExt for Ipv6 {
149 type BroadcastMarker = Never;
150}
151
152pub trait IpExt: packet_formats::ip::IpExt + IcmpIpExt + BroadcastIpExt + IpProtoExt {
154 type RecvSrcAddr: Witness<Self::Addr> + Copy + Clone;
159 const IP_HEADER_LENGTH: NonZeroU32;
161 const IP_MAX_PAYLOAD_LENGTH: NonZeroU32;
163}
164
165impl IpExt for Ipv4 {
166 type RecvSrcAddr = Ipv4SourceAddr;
167 const IP_HEADER_LENGTH: NonZeroU32 =
168 NonZeroU32::new(packet_formats::ipv4::HDR_PREFIX_LEN as u32).unwrap();
169 const IP_MAX_PAYLOAD_LENGTH: NonZeroU32 =
170 NonZeroU32::new(u16::MAX as u32 - Self::IP_HEADER_LENGTH.get()).unwrap();
171}
172
173impl IpExt for Ipv6 {
174 type RecvSrcAddr = Ipv6SourceAddr;
175 const IP_HEADER_LENGTH: NonZeroU32 =
176 NonZeroU32::new(packet_formats::ipv6::IPV6_FIXED_HDR_LEN as u32).unwrap();
177 const IP_MAX_PAYLOAD_LENGTH: NonZeroU32 = NonZeroU32::new(u16::MAX as u32).unwrap();
178}
179
/// An optional 32-bit mark value; `Mark(None)` (the default) means unmarked.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct Mark(pub Option<u32>);

/// Wraps an optional raw mark value.
impl From<Option<u32>> for Mark {
    fn from(value: Option<u32>) -> Self {
        Mark(value)
    }
}
192
193#[derive(Debug, Clone, Copy, PartialEq, Eq, EnumCount, EnumIter)]
195pub enum MarkDomain {
196 Mark1,
198 Mark2,
200}
201
202const MARK_DOMAINS: usize = MarkDomain::COUNT;
203
204#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
206pub struct MarkStorage<T>([T; MARK_DOMAINS]);
207
208impl<T> MarkStorage<T> {
209 pub fn new<U, IntoIter>(iter: IntoIter) -> Self
216 where
217 IntoIter: IntoIterator<Item = (MarkDomain, U)>,
218 T: From<Option<U>> + Copy,
219 {
220 let mut storage = MarkStorage([None.into(); MARK_DOMAINS]);
221 for (domain, value) in iter.into_iter() {
222 *storage.get_mut(domain) = Some(value).into();
223 }
224 storage
225 }
226
227 fn domain_as_index(domain: MarkDomain) -> usize {
228 match domain {
229 MarkDomain::Mark1 => 0,
230 MarkDomain::Mark2 => 1,
231 }
232 }
233
234 pub fn get(&self, domain: MarkDomain) -> &T {
236 let Self(inner) = self;
237 &inner[Self::domain_as_index(domain)]
238 }
239
240 pub fn get_mut(&mut self, domain: MarkDomain) -> &mut T {
242 let Self(inner) = self;
243 &mut inner[Self::domain_as_index(domain)]
244 }
245
246 pub fn iter(&self) -> impl Iterator<Item = (MarkDomain, &T)> {
248 let Self(inner) = self;
249 MarkDomain::iter().map(move |domain| (domain, &inner[Self::domain_as_index(domain)]))
250 }
251
252 pub fn zip_with<'a, U>(
254 &'a self,
255 MarkStorage(other): &'a MarkStorage<U>,
256 ) -> impl Iterator<Item = (MarkDomain, &'a T, &'a U)> + 'a {
257 let Self(this) = self;
258 MarkDomain::iter().zip(this.iter().zip(other.iter())).map(|(d, (t, u))| (d, t, u))
259 }
260}
261
262pub type Marks = MarkStorage<Mark>;
264
265impl Marks {
266 pub const UNMARKED: Self = MarkStorage([Mark(None), Mark(None)]);
268}