netstack3_ip/
fragmentation.rs

1// Copyright 2024 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5//! IP fragmentation support.
6
7use core::borrow::Borrow;
8use core::fmt::Debug;
9
10use alloc::vec::Vec;
11
12use explicit::UnreachableExt;
13use net_types::ip::{GenericOverIp, Ip, IpInvariant, Ipv4, Ipv6, Mtu};
14use netstack3_base::{Counter, RngContext, Uninstantiable};
15use netstack3_filter::ForwardedPacket;
16use packet::{
17    Buf, BufferMut, EmptyBuf, FragmentedBuffer as _, InnerPacketBuilder as _, Nested,
18    PacketBuilder, PacketConstraints, ParsablePacket, SerializeError, Serializer,
19};
20use packet_formats::ip::FragmentOffset;
21use packet_formats::ipv4::options::Ipv4Option;
22use packet_formats::ipv4::{
23    Ipv4Header as _, Ipv4PacketBuilder, Ipv4PacketBuilderWithOptions, Ipv4PacketRaw,
24};
25use packet_formats::ipv6::{
26    Ipv6PacketBuilder, Ipv6PacketBuilderBeforeFragment, Ipv6PacketBuilderWithFragmentHeader,
27};
28use rand::Rng;
29
/// The largest fragment offset, in bytes, that both the IPv4 and the IPv6
/// fragment headers can express.
///
/// Both headers store the offset in a 13-bit field counted in 8-octet units,
/// so the largest value is `0x1FFF * 8`. A fragmented body may extend past this
/// offset only by the bytes carried in the fragment that starts at it.
const MAX_FRAGMENT_OFFSET: usize = 0x1FFF * 8;
35
36pub trait FragmentationIpExt:
37    packet_formats::ip::IpExt<PacketBuilder: AsFragmentableIpPacketBuilder<Self>>
38{
39    /// The IP packet builder for a forwarded packet.
40    type ForwardedFragmentBuilder: FragmentableIpPacketBuilder<Self>;
41    /// An identifier generated at fragmentation time.
42    type FragmentationId: Copy + Debug;
43}
44
45impl FragmentationIpExt for Ipv4 {
46    type ForwardedFragmentBuilder = ForwardedIpv4PacketBuilder;
47    type FragmentationId = ();
48}
49
50impl FragmentationIpExt for Ipv6 {
51    // IPv6 never fragments forwarded packets, only the source node may
52    // fragment.
53    type ForwardedFragmentBuilder = Uninstantiable;
54    type FragmentationId = u32;
55}
56
57/// Fragmentation errors
58#[derive(Debug, Eq, PartialEq, GenericOverIp)]
59#[generic_over_ip()]
60pub enum FragmentationError {
61    /// Fragmentation not allowed.
62    NotAllowed,
63    /// MTU is too small, headers don't fit.
64    MtuTooSmall,
65    /// Body is too long to be fragmented.
66    BodyTooLong,
67    /// Inner serializer reported a size limited exceeded.
68    SizeLimitExceeded,
69}
70
71/// A [`Serializer`] capable of splitting itself into a packet builder and a
72/// pre-serialized body for fragmentation.
73// TODO(https://fxbug.dev/42148826): Ideally we'd be able to generate fragments
74// without requiring the IP body to be dumped into a Vec first. Update this when
75// that support is available in packet and packet_formats.
76pub trait FragmentableIpSerializer<I: FragmentationIpExt>: Serializer {
77    /// The builder for each fragment.
78    type Builder<'a>: FragmentableIpPacketBuilder<I>
79    where
80        Self: 'a;
81    /// The body to be fragmented.
82    ///
83    /// Note that this API is not attempting to reuse buffers in any way. There
84    /// are improvements that can be made here to perhaps avoid allocations and
85    /// yield out reusable bodies, but we're constrained to taking references to
86    /// the serializers here to avoid changing the body which could interfere
87    /// with the higher layers on errors.
88    type Body<'a>: AsRef<[u8]>
89    where
90        Self: 'a;
91
92    /// Returns the inner packet builder for this IP version and a serialized
93    /// body.
94    fn builder_and_body(&self) -> Result<(Self::Builder<'_>, Self::Body<'_>), FragmentationError>;
95}
96
97impl<I, S, B> FragmentableIpSerializer<I> for Nested<S, B>
98where
99    I: FragmentationIpExt,
100    S: Serializer,
101    B: AsFragmentableIpPacketBuilder<I> + PacketBuilder,
102{
103    type Builder<'a>
104        = B::Builder<'a>
105    where
106        Self: 'a;
107
108    type Body<'a>
109        = Buf<Vec<u8>>
110    where
111        Self: 'a;
112
113    fn builder_and_body(&self) -> Result<(Self::Builder<'_>, Self::Body<'_>), FragmentationError> {
114        let builder = self.outer().try_as_fragmentable()?;
115        let body = self
116            .inner()
117            .serialize_new_buf(PacketConstraints::UNCONSTRAINED, packet::new_buf_vec)
118            .map_err(|e| match e {
119                SerializeError::SizeLimitExceeded => FragmentationError::SizeLimitExceeded,
120            })?;
121        Ok((builder, body))
122    }
123}
124
/// Where a fragment sits in the sequence of fragments produced for a single
/// packet.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum FragmentPosition {
    /// The fragment at offset zero.
    First,
    /// A fragment that is neither the first nor the last.
    Middle,
    /// The final fragment, emitted without the more-fragments flag.
    Last,
}
131
/// IP header lengths reported by a [`FragmentableIpPacketBuilder`].
///
/// The first fragment's header can be longer than the others', for example
/// when IPv4 options are carried only in the first fragment.
pub struct HeaderSizes {
    /// Header length in bytes for the first fragment.
    first: usize,
    /// Header length in bytes for every fragment after the first.
    remaining: usize,
}
138
139/// A type that may be transformed into a fragmentable ip packet builder.
140pub trait AsFragmentableIpPacketBuilder<I: FragmentationIpExt> {
141    /// The fragmentable packet builder that can be constructed from this type.
142    type Builder<'a>: FragmentableIpPacketBuilder<I>
143    where
144        Self: 'a;
145
146    /// Attempts to extract a `FragmentableIpPacketBuilder` implementation from
147    /// this type, returning an error if it can't be fragmented.
148    fn try_as_fragmentable(&self) -> Result<Self::Builder<'_>, FragmentationError>;
149}
150
151/// An IP packet builder that can create IP fragments.
152pub trait FragmentableIpPacketBuilder<I: FragmentationIpExt> {
153    /// Returns the portion of the MTU occupied by IP headers.
154    fn header_sizes(&self) -> HeaderSizes;
155
156    /// Returns a builder for fragment at offset `offset`.
157    ///
158    /// `position` carries information if this is the first or last segment, which
159    /// require special logic.
160    fn builder_at(
161        &self,
162        offset: FragmentOffset,
163        position: FragmentPosition,
164        identifier: I::FragmentationId,
165    ) -> impl PacketBuilder + '_;
166}
167
168/// Blanket impl for everything that has a shape to fit in `Ipv4FragmentBuilder`
169/// as a provider for fragmentation.
170impl<B> AsFragmentableIpPacketBuilder<Ipv4> for B
171where
172    B: InnerIpv4FragmentBuilder,
173{
174    type Builder<'a>
175        = Ipv4FragmentBuilder<'a, Self>
176    where
177        Self: 'a;
178
179    fn try_as_fragmentable(&self) -> Result<Self::Builder<'_>, FragmentationError> {
180        can_fragment_ipv4(self.prefix())?;
181        Ok(Ipv4FragmentBuilder { builder: self })
182    }
183}
184
185/// A trait marking all the IPv4 builder types that can be fragmented with
186/// [`Ipv4FragmentBuilder`].
187trait InnerIpv4FragmentBuilder: PacketBuilder {
188    fn prefix(&self) -> &Ipv4PacketBuilder;
189    fn prefix_mut(&mut self) -> &mut Ipv4PacketBuilder;
190    fn clone_for_fragment(&self, position: FragmentPosition) -> impl InnerIpv4FragmentBuilder;
191    fn header_sizes(&self) -> HeaderSizes;
192}
193
194impl InnerIpv4FragmentBuilder for Ipv4PacketBuilder {
195    fn prefix(&self) -> &Ipv4PacketBuilder {
196        self
197    }
198
199    fn prefix_mut(&mut self) -> &mut Ipv4PacketBuilder {
200        self
201    }
202
203    fn clone_for_fragment(&self, _position: FragmentPosition) -> impl InnerIpv4FragmentBuilder {
204        self.clone()
205    }
206
207    fn header_sizes(&self) -> HeaderSizes {
208        let size = self.constraints().header_len();
209        HeaderSizes { first: size, remaining: size }
210    }
211}
212
213impl<'a, I> InnerIpv4FragmentBuilder for Ipv4PacketBuilderWithOptions<'a, I>
214where
215    I: Iterator<Item: Borrow<Ipv4Option<'a>>> + Clone,
216{
217    fn prefix(&self) -> &Ipv4PacketBuilder {
218        self.prefix_builder()
219    }
220
221    fn prefix_mut(&mut self) -> &mut Ipv4PacketBuilder {
222        self.prefix_builder_mut()
223    }
224
225    fn clone_for_fragment(&self, position: FragmentPosition) -> impl InnerIpv4FragmentBuilder {
226        self.clone().with_fragment_options(position == FragmentPosition::First)
227    }
228
229    fn header_sizes(&self) -> HeaderSizes {
230        let first = self.constraints().header_len();
231        let remaining = self.clone().with_fragment_options(false).constraints().header_len();
232        HeaderSizes { first, remaining }
233    }
234}
235
/// A [`FragmentableIpPacketBuilder`] for IPv4 that borrows the builder of the
/// unfragmented packet and derives every fragment header from it.
pub struct Ipv4FragmentBuilder<'a, B> {
    /// Builder for the original, unfragmented packet.
    builder: &'a B,
}
239
240impl<'a, B> FragmentableIpPacketBuilder<Ipv4> for Ipv4FragmentBuilder<'a, B>
241where
242    B: InnerIpv4FragmentBuilder,
243{
244    fn header_sizes(&self) -> HeaderSizes {
245        self.builder.header_sizes()
246    }
247
248    fn builder_at(
249        &self,
250        offset: FragmentOffset,
251        position: FragmentPosition,
252        (): (),
253    ) -> impl PacketBuilder + '_ {
254        let mut builder = self.builder.clone_for_fragment(position);
255        set_ipv4_fragment(builder.prefix_mut(), offset, position);
256        builder
257    }
258}
259
260impl<B> AsFragmentableIpPacketBuilder<Ipv6> for B
261where
262    for<'a> &'a B: Ipv6PacketBuilderBeforeFragment,
263{
264    type Builder<'a>
265        = Ipv6FragmentBuilder<'a, Self>
266    where
267        Self: 'a;
268
269    fn try_as_fragmentable(&self) -> Result<Self::Builder<'_>, FragmentationError> {
270        Ok(Ipv6FragmentBuilder { builder: self })
271    }
272}
273
/// A [`FragmentableIpPacketBuilder`] for IPv6 that wraps the borrowed builder's
/// headers in an [`Ipv6PacketBuilderWithFragmentHeader`] for each fragment.
pub struct Ipv6FragmentBuilder<'a, B> {
    /// Builder for the headers that precede the fragment header.
    builder: &'a B,
}
277
278impl<'a, B> FragmentableIpPacketBuilder<Ipv6> for Ipv6FragmentBuilder<'a, B>
279where
280    &'a B: Ipv6PacketBuilderBeforeFragment,
281{
282    fn header_sizes(&self) -> HeaderSizes {
283        // NB: We currently only support headers that need to be in all
284        // fragments, so we only need to calculate once. We might need to change
285        // the trait shape if that changes.
286        let header_len =
287            Ipv6PacketBuilderWithFragmentHeader::new(self.builder, FragmentOffset::ZERO, false, 0)
288                .constraints()
289                .header_len();
290        HeaderSizes { first: header_len, remaining: header_len }
291    }
292
293    fn builder_at(
294        &self,
295        offset: FragmentOffset,
296        position: FragmentPosition,
297        identifier: u32,
298    ) -> impl PacketBuilder + '_ {
299        Ipv6PacketBuilderWithFragmentHeader::new(
300            self.builder,
301            offset,
302            position != FragmentPosition::Last,
303            identifier,
304        )
305    }
306}
307
308impl<I, B> FragmentableIpSerializer<I> for ForwardedPacket<I, B>
309where
310    I: FragmentationIpExt,
311    B: BufferMut,
312{
313    type Builder<'a>
314        = I::ForwardedFragmentBuilder
315    where
316        Self: 'a;
317    type Body<'a>
318        = Buf<&'a [u8]>
319    where
320        Self: 'a;
321
322    fn builder_and_body(&self) -> Result<(Self::Builder<'_>, Self::Body<'_>), FragmentationError> {
323        #[derive(GenericOverIp)]
324        #[generic_over_ip(I, Ip)]
325        struct Out<I: FragmentationIpExt>(I::ForwardedFragmentBuilder);
326        I::map_ip::<_, Result<(Out<I>, IpInvariant<Buf<&[u8]>>), FragmentationError>>(
327            self,
328            |forwarded| {
329                // Parse an IPv4 packet from the forwarded packet. We can assert
330                // strongly on all of the parsing here because ForwardedPacket
331                // is guaranteed to have been parsed by the IP stack already.
332                let mut buffer = forwarded.buffer().as_ref();
333                let packet = Ipv4PacketRaw::parse(&mut buffer, ())
334                    .expect("ForwardedPacket must be parseable");
335                let builder = packet.builder();
336                can_fragment_ipv4(&builder)?;
337                let raw_options_bytes = packet
338                    .options()
339                    .as_ref()
340                    .complete()
341                    .expect("unexpected incomplete IP header")
342                    .bytes();
343
344                let mut raw_options = Buf::new(
345                    [0u8; packet_formats::ipv4::MAX_OPTIONS_LEN],
346                    ..raw_options_bytes.len(),
347                );
348                raw_options.as_mut().copy_from_slice(raw_options_bytes);
349                let body = Buf::new(
350                    packet.into_body().complete().expect("unexpected incomplete IP body"),
351                    ..,
352                );
353                Ok((Out(ForwardedIpv4PacketBuilder { builder, raw_options }), IpInvariant(body)))
354            },
355            |_forwarded| Err(FragmentationError::NotAllowed),
356        )
357        .map(|(Out(builder), IpInvariant(body))| (builder, body))
358    }
359}
360
361pub struct ForwardedIpv4PacketBuilder {
362    builder: Ipv4PacketBuilder,
363    raw_options: Buf<[u8; packet_formats::ipv4::MAX_OPTIONS_LEN]>,
364}
365
366impl FragmentableIpPacketBuilder<Ipv4> for ForwardedIpv4PacketBuilder {
367    fn header_sizes(&self) -> HeaderSizes {
368        let Self { builder, raw_options } = self;
369        if raw_options.is_empty() {
370            builder.header_sizes()
371        } else {
372            let options = packet_formats::ipv4::Options::parse(raw_options.as_ref())
373                .expect("must hold valid options");
374            Ipv4PacketBuilderWithOptions::new_with_records_iter(builder.clone(), options.iter())
375                .header_sizes()
376        }
377    }
378
379    fn builder_at(
380        &self,
381        offset: FragmentOffset,
382        position: FragmentPosition,
383        (): (),
384    ) -> impl PacketBuilder + '_ {
385        let Self { builder, raw_options } = self;
386        let mut builder = builder.clone();
387        set_ipv4_fragment(&mut builder, offset, position);
388        let options = packet_formats::ipv4::Options::parse(raw_options.as_ref())
389            .expect("must hold valid options");
390        Ipv4PacketBuilderWithOptions::new_with_records_iter(builder.clone(), options.into_iter())
391            .with_fragment_options(position == FragmentPosition::First)
392    }
393}
394
395impl<I: FragmentationIpExt> FragmentableIpPacketBuilder<I> for Uninstantiable {
396    fn header_sizes(&self) -> HeaderSizes {
397        self.uninstantiable_unreachable()
398    }
399
400    fn builder_at(
401        &self,
402        _offset: FragmentOffset,
403        _position: FragmentPosition,
404        _identifier: I::FragmentationId,
405    ) -> impl PacketBuilder + '_ {
406        self.uninstantiable_unreachable::<Ipv6PacketBuilder>()
407    }
408}
409
410/// Abstracts fragment ID generation for [`IpFragmenter`].
411///
412/// A blanket impl is provided for [`RngContext`] implementers, so the bindings
413/// context can be used to generate random IDs for IPv6.
414pub(crate) trait FragmentationIdGenContext {
415    fn generate_id<I: FragmentationIpExt>(&mut self) -> I::FragmentationId;
416}
417
418#[derive(GenericOverIp)]
419#[generic_over_ip(I, Ip)]
420struct WrapFragmentationId<I: FragmentationIpExt>(I::FragmentationId);
421
422impl<BC> FragmentationIdGenContext for BC
423where
424    BC: RngContext,
425{
426    fn generate_id<I: FragmentationIpExt>(&mut self) -> I::FragmentationId {
427        let WrapFragmentationId(identifier) = I::map_ip_out(
428            self,
429            |_| WrapFragmentationId(()),
430            |rng| {
431                // TODO(https://fxbug.dev/373428005): Perhaps we can do better
432                // than a simple RNG. This is currently copying what netstack2
433                // does. RFC 7739 calls out different strategies for fragment
434                // IDs in IPv6. We currently pick an option that is not doing a
435                // best effort to avoid collisions, but it guarantees that
436                // fragment IDs can't be tracked as an attack vector.
437                // We avoid a zero fragment ID like netstack2 does.
438                WrapFragmentationId(rng.rng().gen_range(1..=u32::MAX))
439            },
440        );
441        identifier
442    }
443}
444
445pub(crate) struct IpFragmenter<'a, I: FragmentationIpExt, S: FragmentableIpSerializer<I> + 'a> {
446    builder: S::Builder<'a>,
447    body: S::Body<'a>,
448    consumed: usize,
449    max_fragment_body_first: usize,
450    max_fragment_body_remaining: usize,
451    identifier: I::FragmentationId,
452}
453
/// Lets the opaque return type of [`IpFragmenter::next`] capture both the
/// fragmenter lifetime `'a` and the borrow lifetime `'b`.
// TODO(https://github.com/rust-lang/rust/issues/123432): Replace with `impl
// use<'a, 'b>` when available in tree.
pub trait Capture<'a, 'b> {}

impl<'a, 'b, T> Capture<'a, 'b> for T
where
    'a: 'b,
    T: 'b,
{
}
464
465/// Returns the biggest fragment body that can fit in `mtu` with a given IP
466/// `header` size.
467///
468/// The returned body size is rounded down to the nearest multiple of 8 to fit
469/// the IP header representation of fragment offsets.
470fn maximum_fragment_body_with_header_and_mtu(
471    mtu: Mtu,
472    header: usize,
473) -> Result<usize, FragmentationError> {
474    let v = usize::from(mtu).checked_sub(header).ok_or(FragmentationError::MtuTooSmall)?;
475    // Mask the final 8 bits since fragment offset is expressed in units
476    // of 8 octets for both IP versions.
477    let v = v & !0x07usize;
478
479    if v == 0 {
480        // Can't fragment if we don't have at least a single 8 octet
481        // of space.
482        return Err(FragmentationError::MtuTooSmall);
483    }
484    Ok(v)
485}
486
487impl<'a, I: FragmentationIpExt, S: FragmentableIpSerializer<I>> IpFragmenter<'a, I, S> {
488    /// Creates a new `IpFragmenter` with some `serializer` respecting a maximum
489    /// IP layer `mtu`.
490    pub(crate) fn new<C: FragmentationIdGenContext>(
491        id_ctx: &mut C,
492        serializer: &'a S,
493        mtu: Mtu,
494    ) -> Result<Self, FragmentationError> {
495        let (builder, body) = serializer.builder_and_body()?;
496        let HeaderSizes { first, remaining } = builder.header_sizes();
497        let max_fragment_body_first = maximum_fragment_body_with_header_and_mtu(mtu, first)?;
498        let max_fragment_body_remaining =
499            maximum_fragment_body_with_header_and_mtu(mtu, remaining)?;
500
501        if body.as_ref().len() > MAX_FRAGMENT_OFFSET + max_fragment_body_remaining {
502            return Err(FragmentationError::BodyTooLong);
503        }
504
505        let identifier = id_ctx.generate_id::<I>();
506
507        Ok(Self {
508            builder,
509            body,
510            consumed: 0,
511            max_fragment_body_first,
512            max_fragment_body_remaining,
513            identifier,
514        })
515    }
516
517    /// Returns the serializer for the next segment and a boolean indicating
518    /// whether more fragments are pending, or `None` if all segments have been
519    /// produced.
520    ///
521    /// # Panics
522    ///
523    /// Panics if fragmentation is not necessary for the `serializer` that
524    /// created this `IpFragmenter`.
525    pub(crate) fn next(
526        &mut self,
527    ) -> Option<(impl Serializer<Buffer = EmptyBuf> + Capture<'a, '_>, bool)> {
528        let Self {
529            builder,
530            body,
531            consumed,
532            max_fragment_body_first,
533            max_fragment_body_remaining,
534            identifier,
535        } = self;
536        let body = &AsRef::as_ref(body)[*consumed..];
537        if body.is_empty() {
538            return None;
539        }
540        let first = *consumed == 0;
541        let max_fragment_body =
542            if first { max_fragment_body_first } else { max_fragment_body_remaining };
543        let take = body.len().min(*max_fragment_body);
544        let last = take == body.len();
545        let position = match (first, last) {
546            (true, true) => {
547                panic!("unnecessary fragmentation");
548            }
549            (true, false) => FragmentPosition::First,
550            (false, false) => FragmentPosition::Middle,
551            (false, true) => FragmentPosition::Last,
552        };
553        // Upon construction IpFragmenter verifies that we won't go over the
554        // maximum offset since the body length is known.
555        let fragment_offset = u16::try_from(*consumed).expect("fragment offset too large");
556        // Care is taken above to always take 8-byte multiples to be added to
557        // consumed, so we should always have a good representation for
558        // FragmentOffset.
559        let fragment_offset =
560            FragmentOffset::new_with_bytes(fragment_offset).expect("invalid offset");
561        let fragment_builder = builder.builder_at(fragment_offset, position, *identifier);
562        let end = *consumed + take;
563        let has_more = body.len() > take;
564        let fragment_body = &body[..take];
565        *consumed = end;
566        Some((fragment_body.into_serializer().encapsulate(fragment_builder), has_more))
567    }
568}
569
570fn can_fragment_ipv4(builder: &Ipv4PacketBuilder) -> Result<(), FragmentationError> {
571    if builder.read_df_flag() {
572        return Err(FragmentationError::NotAllowed);
573    }
574    Ok(())
575}
576
577fn set_ipv4_fragment(
578    builder: &mut Ipv4PacketBuilder,
579    offset: FragmentOffset,
580    position: FragmentPosition,
581) {
582    builder.mf_flag(position != FragmentPosition::Last);
583    builder.fragment_offset(offset);
584}
585
586/// Counters kept by the IP stack pertaining to fragmentation.
587#[derive(Default, Debug)]
588#[cfg_attr(any(test, feature = "testutils"), derive(PartialEq))]
589pub struct FragmentationCounters<C = Counter> {
590    /// The number of IP frames requiring fragmentation on egress.
591    pub fragmentation_required: C,
592    /// The total number of fragments sent.
593    pub fragments: C,
594    /// The number of `NotAllowed` errors encountered.
595    pub error_not_allowed: C,
596    /// The number of `MtuTooSmall` errors encountered.
597    pub error_mtu_too_small: C,
598    /// The number of `BodyTooLong` errors encountered.
599    pub error_body_too_long: C,
600    /// The number of `SizeLimitExceeded` errors encountered.
601    pub error_inner_size_limit_exceeded: C,
602    /// Counts the number of times fragmentation was short-circuited due to a
603    /// fragment serialization error.
604    pub error_fragmented_serializer: C,
605}
606
607impl FragmentationCounters {
608    pub(crate) fn error_counter(&self, error: &FragmentationError) -> &Counter {
609        match error {
610            FragmentationError::NotAllowed => &self.error_not_allowed,
611            FragmentationError::MtuTooSmall => &self.error_mtu_too_small,
612            FragmentationError::BodyTooLong => &self.error_body_too_long,
613            FragmentationError::SizeLimitExceeded => &self.error_inner_size_limit_exceeded,
614        }
615    }
616}
617
#[cfg(any(test, feature = "testutils"))]
impl From<&FragmentationCounters> for FragmentationCounters<u64> {
    /// Takes a snapshot of every live counter's current value.
    fn from(counters: &FragmentationCounters) -> FragmentationCounters<u64> {
        let read = |counter: &Counter| counter.get();
        FragmentationCounters {
            fragmentation_required: read(&counters.fragmentation_required),
            fragments: read(&counters.fragments),
            error_not_allowed: read(&counters.error_not_allowed),
            error_mtu_too_small: read(&counters.error_mtu_too_small),
            error_body_too_long: read(&counters.error_body_too_long),
            error_inner_size_limit_exceeded: read(&counters.error_inner_size_limit_exceeded),
            error_fragmented_serializer: read(&counters.error_fragmented_serializer),
        }
    }
}
641
642#[cfg(test)]
643mod tests {
644    use super::*;
645
646    use assert_matches::assert_matches;
647    use net_types::Witness as _;
648    use netstack3_base::testutil::{TEST_ADDRS_V4, TEST_ADDRS_V6};
649    use netstack3_filter::FilterIpExt;
650    use packet::{Buffer, BufferView, GrowBuffer};
651    use packet_formats::ip::IpProto;
652    use packet_formats::ipv4::Ipv4Packet;
653    use packet_formats::ipv6::ext_hdrs::Ipv6ExtensionHeaderData;
654    use packet_formats::ipv6::{Ipv6Header, Ipv6Packet};
655    use test_case::test_case;
656
657    const TEST_MTU: Mtu = Ipv6::MINIMUM_LINK_MTU;
658
659    fn gen_body(len: usize) -> Vec<u8> {
660        // Cycle bytes until 251 which is the largest prime that can fit in a
661        // u8. Unlikely this aligns poorly and hides fragmentation bugs.
662        (0u8..=251).cycle().take(len).collect::<Vec<u8>>()
663    }
664
665    impl<'a, I: FragmentationIpExt, S: FragmentableIpSerializer<I>> IpFragmenter<'a, I, S> {
666        fn next_serialized(&mut self) -> Buf<Vec<u8>> {
667            self.next()
668                .expect("no more fragments")
669                .0
670                .serialize_vec_outer()
671                .map_err(|(err, _serializer)| err)
672                .unwrap()
673                .unwrap_b()
674        }
675    }
676
677    trait FragmentationTestEnv<I: FragmentationIpExt> {
678        fn new_serializer<'a>(
679            &self,
680            body: &'a [u8],
681        ) -> impl FragmentableIpSerializer<I, Buffer: Buffer> + 'a;
682        fn check_fragment(
683            &self,
684            fragment: &mut Buf<Vec<u8>>,
685            position: FragmentPosition,
686            offset: usize,
687        );
688    }
689
690    #[derive(Default)]
691    struct Ipv4TestEnv {
692        dont_frag: bool,
693    }
694
695    impl Ipv4TestEnv {
696        const fn dont_frag() -> Self {
697            Self { dont_frag: true }
698        }
699    }
700
701    const IPV4_ID: u16 = 0x1234;
702    fn new_ipv4_packet_builder(dont_frag: bool) -> Ipv4PacketBuilder {
703        let mut builder = Ipv4PacketBuilder::new(
704            TEST_ADDRS_V4.local_ip,
705            TEST_ADDRS_V4.remote_ip,
706            1,
707            IpProto::Udp.into(),
708        );
709        builder.id(IPV4_ID);
710        builder.df_flag(dont_frag);
711        builder
712    }
713
714    fn parse_and_check_ipv4_packet(
715        fragment: &mut Buf<Vec<u8>>,
716        position: FragmentPosition,
717        offset: usize,
718    ) -> Ipv4Packet<&[u8]> {
719        let packet = Ipv4Packet::parse(fragment.buffer_view(), ()).expect("parse fragment");
720        assert_eq!(packet.src_ip(), TEST_ADDRS_V4.local_ip.get());
721        assert_eq!(packet.dst_ip(), TEST_ADDRS_V4.remote_ip.get());
722        assert_eq!(packet.ttl(), 1);
723        assert_eq!(packet.id(), IPV4_ID);
724        assert_eq!(packet.proto(), IpProto::Udp.into());
725        assert_eq!(packet.mf_flag(), position != FragmentPosition::Last);
726        assert_eq!(usize::from(packet.fragment_offset().into_bytes()), offset);
727        packet
728    }
729
730    impl FragmentationTestEnv<Ipv4> for Ipv4TestEnv {
731        fn new_serializer<'a>(
732            &self,
733            body: &'a [u8],
734        ) -> impl FragmentableIpSerializer<Ipv4, Buffer: Buffer> + 'a {
735            let Self { dont_frag } = self;
736            body.into_serializer().encapsulate(new_ipv4_packet_builder(*dont_frag))
737        }
738
739        fn check_fragment(
740            &self,
741            fragment: &mut Buf<Vec<u8>>,
742            position: FragmentPosition,
743            offset: usize,
744        ) {
745            let _ = parse_and_check_ipv4_packet(fragment, position, offset);
746        }
747    }
748
749    #[derive(Default)]
750    struct Ipv4WithOptionsTestEnv(Ipv4TestEnv);
751
752    // The MSB of an option kind determines if it should be copied.
753    const FAKE_OPTION_COPIED_KIND: u8 = 255;
754    const FAKE_OPTION_COPIED: [u8; 1] = [255];
755    const FAKE_OPTION_NOT_COPIED_KIND: u8 = 127;
756    const FAKE_OPTION_NOT_COPIED: [u8; 1] = [127];
757
758    impl FragmentationTestEnv<Ipv4> for Ipv4WithOptionsTestEnv {
759        fn new_serializer<'a>(
760            &self,
761            body: &'a [u8],
762        ) -> impl FragmentableIpSerializer<Ipv4, Buffer: Buffer> + 'a {
763            let Self(Ipv4TestEnv { dont_frag }) = self;
764            body.into_serializer().encapsulate(
765                Ipv4PacketBuilderWithOptions::new(
766                    new_ipv4_packet_builder(*dont_frag),
767                    [
768                        Ipv4Option::Unrecognized {
769                            kind: FAKE_OPTION_COPIED_KIND,
770                            data: &FAKE_OPTION_COPIED[..],
771                        },
772                        Ipv4Option::Unrecognized {
773                            kind: FAKE_OPTION_NOT_COPIED_KIND,
774                            data: &FAKE_OPTION_NOT_COPIED[..],
775                        },
776                    ],
777                )
778                .unwrap(),
779            )
780        }
781
782        fn check_fragment(
783            &self,
784            fragment: &mut Buf<Vec<u8>>,
785            position: FragmentPosition,
786            offset: usize,
787        ) {
788            let packet = parse_and_check_ipv4_packet(fragment, position, offset);
789            let (copied, not_copied) = packet.iter_options().fold(
790                (false, false),
791                |(mut copied, mut not_copied), option| {
792                    let (kind, data) = assert_matches!(option,
793                        Ipv4Option::Unrecognized{ kind, data } => (kind, data)
794                    );
795                    assert_eq!(data.len(), 1);
796                    assert_eq!(data[0], kind);
797                    let seen = match kind {
798                        FAKE_OPTION_COPIED_KIND => &mut copied,
799                        FAKE_OPTION_NOT_COPIED_KIND => &mut not_copied,
800                        k => panic!("unexpected option {k}"),
801                    };
802                    assert_eq!(core::mem::replace(seen, true), false);
803                    (copied, not_copied)
804                },
805            );
806            assert_eq!(copied, true, "must be copied on all fragments {position:?}");
807            assert_eq!(
808                not_copied,
809                position == FragmentPosition::First,
810                "must only be in first fragment {position:?}"
811            );
812        }
813    }
814
815    struct ForwardingTestEnv<E>(E);
816    impl<I: FragmentationIpExt + FilterIpExt, E: FragmentationTestEnv<I>> FragmentationTestEnv<I>
817        for ForwardingTestEnv<E>
818    {
819        fn new_serializer<'a>(
820            &self,
821            body: &'a [u8],
822        ) -> impl FragmentableIpSerializer<I, Buffer: Buffer> + 'a {
823            use packet_formats::ip::IpPacket as _;
824            let Self(inner) = self;
825            let mut buffer = inner
826                .new_serializer(body)
827                .serialize_outer(packet::NoReuseBufferProvider(packet::new_buf_vec))
828                .map_err(|(err, _)| err)
829                .unwrap();
830            let packet =
831                <I::Packet<_> as ParsablePacket<_, _>>::parse(buffer.buffer_view(), ()).unwrap();
832            let src_addr = packet.src_ip();
833            let dst_addr = packet.dst_ip();
834            let proto = packet.proto();
835            let meta = packet.parse_metadata();
836            drop(packet);
837            ForwardedPacket::new(src_addr, dst_addr, proto, meta, buffer)
838        }
839        fn check_fragment(
840            &self,
841            fragment: &mut Buf<Vec<u8>>,
842            position: FragmentPosition,
843            offset: usize,
844        ) {
845            let Self(inner) = self;
846            inner.check_fragment(fragment, position, offset)
847        }
848    }
849
850    struct Ipv6TestEnv;
851
852    const IPV6_ID: u32 = 0x1234ABCD;
853
854    impl FragmentationTestEnv<Ipv6> for Ipv6TestEnv {
855        fn new_serializer<'a>(
856            &self,
857            body: &'a [u8],
858        ) -> impl FragmentableIpSerializer<Ipv6, Buffer: Buffer> + 'a {
859            body.into_serializer().encapsulate(Ipv6PacketBuilder::new(
860                TEST_ADDRS_V6.local_ip,
861                TEST_ADDRS_V6.remote_ip,
862                1,
863                IpProto::Udp.into(),
864            ))
865        }
866
867        fn check_fragment(
868            &self,
869            fragment: &mut Buf<Vec<u8>>,
870            position: FragmentPosition,
871            offset: usize,
872        ) {
873            let packet = Ipv6Packet::parse(fragment.buffer_view(), ()).unwrap();
874            assert_eq!(packet.src_ip(), TEST_ADDRS_V6.local_ip.get());
875            assert_eq!(packet.dst_ip(), TEST_ADDRS_V6.remote_ip.get());
876            assert_eq!(packet.hop_limit(), 1);
877            assert_eq!(packet.proto(), IpProto::Udp.into());
878            let fragment = packet
879                .iter_extension_hdrs()
880                .find_map(|h| match h.into_data() {
881                    Ipv6ExtensionHeaderData::Fragment { fragment_data } => Some(fragment_data),
882                    _ => None,
883                })
884                .expect("no fragment header");
885            assert_eq!(fragment.identification(), IPV6_ID);
886            assert_eq!(usize::from(fragment.fragment_offset().into_bytes()), offset);
887            assert_eq!(fragment.m_flag(), position != FragmentPosition::Last);
888        }
889    }
890
891    struct FixedIdContext;
892    impl FragmentationIdGenContext for FixedIdContext {
893        fn generate_id<I: FragmentationIpExt>(&mut self) -> I::FragmentationId {
894            let WrapFragmentationId(id) =
895                I::map_ip_out((), |()| WrapFragmentationId(()), |()| WrapFragmentationId(IPV6_ID));
896            id
897        }
898    }
899
900    #[test_case::test_matrix(
901        [
902            Ipv4TestEnv::default(),
903            Ipv4WithOptionsTestEnv::default(),
904            ForwardingTestEnv(Ipv4TestEnv::default()),
905            ForwardingTestEnv(Ipv4WithOptionsTestEnv::default()),
906            Ipv6TestEnv,
907        ],
908        0..=2
909    )]
910    fn fragment<I: FragmentationIpExt, E: FragmentationTestEnv<I>>(
911        env: E,
912        middle_fragments: usize,
913    ) {
914        // NB: We're using the fact that MTU is larger than the header sizes
915        // here to end up obtaining the right number of middle fragments as
916        // expected. This makes this test sensitive to the relation between the
917        // picked MTU and the header sizes for the multiple serializers.
918        let full_body = gen_body(usize::from(TEST_MTU) * (1 + middle_fragments));
919        let mut body_view = Buf::new(&full_body[..], ..);
920        let serializer = env.new_serializer(&full_body[..]);
921        let mut fragmenter = IpFragmenter::new(&mut FixedIdContext, &serializer, TEST_MTU)
922            .expect("create fragmenter");
923
924        let mut frag = fragmenter.next_serialized();
925        env.check_fragment(&mut frag, FragmentPosition::First, body_view.prefix_len());
926        assert_eq!(
927            frag.as_ref(),
928            body_view.buffer_view().take_front(fragmenter.max_fragment_body_first).unwrap()
929        );
930
931        for _ in 0..middle_fragments {
932            let mut frag = fragmenter.next_serialized();
933            env.check_fragment(&mut frag, FragmentPosition::Middle, body_view.prefix_len());
934            assert_eq!(
935                frag.as_ref(),
936                body_view.buffer_view().take_front(fragmenter.max_fragment_body_remaining).unwrap()
937            );
938        }
939
940        let mut frag = fragmenter.next_serialized();
941        env.check_fragment(&mut frag, FragmentPosition::Last, body_view.prefix_len());
942        assert_eq!(frag.as_ref(), body_view.buffer_view().into_rest());
943
944        // No more fragments.
945        assert!(fragmenter.next().is_none());
946    }
947
948    #[test_case(Ipv4TestEnv::dont_frag())]
949    #[test_case(Ipv4WithOptionsTestEnv(Ipv4TestEnv::dont_frag()))]
950    #[test_case(ForwardingTestEnv(Ipv4TestEnv::dont_frag()))]
951    #[test_case(ForwardingTestEnv(Ipv6TestEnv))]
952    fn not_allowed<I: FragmentationIpExt, E: FragmentationTestEnv<I>>(env: E) {
953        let body = gen_body(usize::from(TEST_MTU));
954        let serializer = env.new_serializer(&body[..]);
955        let result = IpFragmenter::new(&mut FixedIdContext, &serializer, TEST_MTU).map(|_| ());
956        assert_eq!(result, Err(FragmentationError::NotAllowed))
957    }
958
959    #[test_case(Ipv4TestEnv::default())]
960    #[test_case(Ipv4WithOptionsTestEnv::default())]
961    #[test_case(ForwardingTestEnv(Ipv4TestEnv::default()))]
962    #[test_case(Ipv6TestEnv)]
963    fn mtu_too_small<I: FragmentationIpExt, E: FragmentationTestEnv<I>>(env: E) {
964        let body = gen_body(usize::from(TEST_MTU));
965        let serializer = env.new_serializer(&body[..]);
966        let result = IpFragmenter::new(&mut FixedIdContext, &serializer, Mtu::new(10)).map(|_| ());
967        assert_eq!(result, Err(FragmentationError::MtuTooSmall));
968    }
969
970    #[test_case(Ipv4TestEnv::default())]
971    #[test_case(Ipv4WithOptionsTestEnv::default())]
972    #[test_case(Ipv6TestEnv)]
973    fn body_too_long<I: FragmentationIpExt, E: FragmentationTestEnv<I>>(env: E) {
974        let body = gen_body(MAX_FRAGMENT_OFFSET + usize::from(TEST_MTU));
975        let serializer = env.new_serializer(&body[..]);
976        let result = IpFragmenter::new(&mut FixedIdContext, &serializer, TEST_MTU).map(|_| ());
977        assert_eq!(result, Err(FragmentationError::BodyTooLong));
978    }
979}