netstack3_base/tcp/
segment.rs

1// Copyright 2024 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5//! The definition of a TCP segment.
6
7use crate::alloc::borrow::ToOwned;
8use core::borrow::Borrow;
9use core::convert::TryFrom as _;
10use core::fmt::Debug;
11use core::mem::MaybeUninit;
12use core::num::{NonZeroU16, TryFromIntError};
13use core::ops::Range;
14
15use arrayvec::ArrayVec;
16use log::info;
17use net_types::ip::IpAddress;
18use packet::records::options::OptionSequenceBuilder;
19use packet::InnerSerializer;
20use packet_formats::tcp::options::{TcpOption, TcpSackBlock};
21use packet_formats::tcp::{TcpSegment, TcpSegmentBuilder, TcpSegmentBuilderWithOptions};
22use thiserror::Error;
23
24use super::base::{Control, Mss};
25use super::seqnum::{SeqNum, UnscaledWindowSize, WindowScale, WindowSize};
26
/// A TCP segment: the non-payload [`SegmentHeader`] plus a payload of type
/// `P`.
///
/// [`Segment::new`] truncates the payload (dropping FIN if necessary) so that
/// the invariant documented on `data` holds for segments it creates.
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct Segment<P> {
    /// The non-payload information of the segment.
    header: SegmentHeader,
    /// The data carried by the segment.
    ///
    /// It is guaranteed that `data.len()` plus the length of the control flag
    /// (SYN or FIN) is <= [`MAX_PAYLOAD_AND_CONTROL_LEN`].
    ///
    /// NOTE(review): `Segment::map_payload` replaces this field without
    /// re-checking the invariant; callers of it must preserve the length.
    data: P,
}
38
/// All non-data portions of a TCP segment.
///
/// Paired with a payload, this forms a [`Segment`].
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct SegmentHeader {
    /// The sequence number of the segment.
    pub seq: SeqNum,
    /// The acknowledgment number of the segment. [`None`] if not present.
    pub ack: Option<SeqNum>,
    /// The advertised window size, as carried in the header (i.e. before any
    /// window scaling is applied).
    pub wnd: UnscaledWindowSize,
    /// The control flag of the segment (SYN, FIN or RST), if any.
    pub control: Option<Control>,
    /// Indicates whether the PSH bit is set.
    pub push: bool,
    /// Options carried by this segment.
    pub options: Options,
}
55
/// Contains all supported TCP options.
///
/// The variant reflects whether the segment is part of the handshake:
/// handshake-only options (MSS, window scale, SACK permitted) live in
/// [`HandshakeOptions`], while options for non-handshake segments (SACK
/// blocks) live in [`SegmentOptions`].
#[derive(Debug, PartialEq, Eq, Clone)]
pub enum Options {
    /// Options present in a handshake segment.
    Handshake(HandshakeOptions),
    /// Options present in a regular segment.
    Segment(SegmentOptions),
}
64
65impl Default for Options {
66    fn default() -> Self {
67        // Default to a non handshake options value, since those are more
68        // common.
69        Self::Segment(SegmentOptions::default())
70    }
71}
72
73impl From<HandshakeOptions> for Options {
74    fn from(value: HandshakeOptions) -> Self {
75        Self::Handshake(value)
76    }
77}
78
79impl From<SegmentOptions> for Options {
80    fn from(value: SegmentOptions) -> Self {
81        Self::Segment(value)
82    }
83}
84
85impl Options {
86    /// Returns an iterator over the contained options.
87    pub fn iter(&self) -> impl Iterator<Item = TcpOption<'_>> + Debug + Clone {
88        match self {
89            Options::Handshake(o) => either::Either::Left(o.iter()),
90            Options::Segment(o) => either::Either::Right(o.iter()),
91        }
92    }
93
94    fn as_handshake(&self) -> Option<&HandshakeOptions> {
95        match self {
96            Self::Handshake(h) => Some(h),
97            Self::Segment(_) => None,
98        }
99    }
100
101    fn as_handshake_mut(&mut self) -> Option<&mut HandshakeOptions> {
102        match self {
103            Self::Handshake(h) => Some(h),
104            Self::Segment(_) => None,
105        }
106    }
107
108    fn as_segment(&self) -> Option<&SegmentOptions> {
109        match self {
110            Self::Handshake(_) => None,
111            Self::Segment(s) => Some(s),
112        }
113    }
114
115    fn as_segment_mut(&mut self) -> Option<&mut SegmentOptions> {
116        match self {
117            Self::Handshake(_) => None,
118            Self::Segment(s) => Some(s),
119        }
120    }
121
122    /// Returns a new empty [`Options`] with the variant dictated by
123    /// `handshake`.
124    pub fn new_with_handshake(handshake: bool) -> Self {
125        if handshake {
126            Self::Handshake(Default::default())
127        } else {
128            Self::Segment(Default::default())
129        }
130    }
131
132    /// Creates a new [`Options`] from an iterator of TcpOption.
133    ///
134    /// If `handshake` is `true`, only the handshake options will be parsed.
135    /// Otherwise only the non-handshake options are parsed.
136    pub fn from_iter<'a>(handshake: bool, iter: impl IntoIterator<Item = TcpOption<'a>>) -> Self {
137        let mut options = Self::new_with_handshake(handshake);
138        for option in iter {
139            match option {
140                TcpOption::Mss(mss) => {
141                    if let Some(h) = options.as_handshake_mut() {
142                        h.mss = NonZeroU16::new(mss).map(Mss);
143                    }
144                }
145                TcpOption::WindowScale(ws) => {
146                    if let Some(h) = options.as_handshake_mut() {
147                        // Per RFC 7323 Section 2.3:
148                        //   If a Window Scale option is received with a shift.cnt
149                        //   value larger than 14, the TCP SHOULD log the error but
150                        //   MUST use 14 instead of the specified value.
151                        if ws > WindowScale::MAX.get() {
152                            info!(
153                                "received an out-of-range window scale: {}, want < {}",
154                                ws,
155                                WindowScale::MAX.get()
156                            );
157                        }
158                        h.window_scale = Some(WindowScale::new(ws).unwrap_or(WindowScale::MAX));
159                    }
160                }
161                TcpOption::SackPermitted => {
162                    if let Some(h) = options.as_handshake_mut() {
163                        h.sack_permitted = true;
164                    }
165                }
166                TcpOption::Sack(sack) => {
167                    if let Some(seg) = options.as_segment_mut() {
168                        seg.sack_blocks = SackBlocks::from_option(sack);
169                    }
170                }
171                // TODO(https://fxbug.dev/42072902): We don't support these yet.
172                TcpOption::Timestamp { ts_val: _, ts_echo_reply: _ } => {}
173            }
174        }
175        options
176    }
177
178    /// Reads the window scale if this is an [`Options::Handshake`].
179    pub fn window_scale(&self) -> Option<WindowScale> {
180        self.as_handshake().and_then(|h| h.window_scale)
181    }
182
183    /// Reads the mss option if this is an [`Options::Handshake`].
184    pub fn mss(&self) -> Option<Mss> {
185        self.as_handshake().and_then(|h| h.mss)
186    }
187
188    /// Returns true IFF this is an [`Options::Handshake`] and its
189    /// [`HandShakeOptions::sack_permitted`] is set.
190    pub fn sack_permitted(&self) -> bool {
191        self.as_handshake().is_some_and(|o| o.sack_permitted)
192    }
193
194    /// Returns the segment's selective ack blocks.
195    ///
196    /// Returns a reference to empty blocks if this is not [`Options::Segment`].
197    pub fn sack_blocks(&self) -> &SackBlocks {
198        const EMPTY_REF: &'static SackBlocks = &SackBlocks::EMPTY;
199        self.as_segment().map(|s| &s.sack_blocks).unwrap_or(EMPTY_REF)
200    }
201}
202
/// Segment options only set on handshake.
///
/// These are only populated by [`Options::from_iter`] when it is called with
/// `handshake = true`.
#[derive(Debug, Default, PartialEq, Eq, Clone, Copy)]
pub struct HandshakeOptions {
    /// The MSS option. A received MSS of zero is parsed as `None`.
    pub mss: Option<Mss>,

    /// The WS option. Out-of-range values received on the wire are clamped to
    /// [`WindowScale::MAX`] when parsed.
    pub window_scale: Option<WindowScale>,

    /// The SACK permitted option.
    pub sack_permitted: bool,
}
215
216impl HandshakeOptions {
217    /// Returns an iterator over the contained options.
218    pub fn iter(&self) -> impl Iterator<Item = TcpOption<'_>> + Debug + Clone {
219        let Self { mss, window_scale, sack_permitted } = self;
220        mss.map(|mss| TcpOption::Mss(mss.get().get()))
221            .into_iter()
222            .chain(window_scale.map(|ws| TcpOption::WindowScale(ws.get())))
223            .chain((*sack_permitted).then_some(TcpOption::SackPermitted))
224    }
225}
226
/// Segment options set on non-handshake segments.
///
/// These are only populated by [`Options::from_iter`] when it is called with
/// `handshake = false`.
#[derive(Debug, Default, PartialEq, Eq, Clone)]
pub struct SegmentOptions {
    /// The SACK option. Empty when no SACK option is present.
    pub sack_blocks: SackBlocks,
}
233
234impl SegmentOptions {
235    /// Returns an iterator over the contained options.
236    pub fn iter(&self) -> impl Iterator<Item = TcpOption<'_>> + Debug + Clone {
237        let Self { sack_blocks } = self;
238        sack_blocks.as_option().into_iter()
239    }
240}
241
// Capacity of `SackBlocks`; exposed publicly as `SackBlocks::MAX_BLOCKS`.
const MAX_SACK_BLOCKS: usize = 4;
/// Blocks of selective ACKs.
///
/// Stores at most [`SackBlocks::MAX_BLOCKS`] raw [`TcpSackBlock`]s in a
/// fixed-capacity `ArrayVec`. Blocks copied from a received option via
/// [`SackBlocks::from_option`] are not validated; use
/// [`SackBlocks::try_iter`] or [`SackBlocks::iter_skip_invalid`] to obtain
/// validated [`SackBlock`]s.
#[derive(Debug, Default, PartialEq, Eq, Clone)]
pub struct SackBlocks(ArrayVec<TcpSackBlock, MAX_SACK_BLOCKS>);
246
247impl SackBlocks {
248    /// A constant empty instance of SACK blocks.
249    pub const EMPTY: Self = SackBlocks(ArrayVec::new_const());
250
251    /// The maximum number of selective ack blocks that can be in a TCP segment.
252    ///
253    /// See [RFC 2018 section 3].
254    ///
255    /// [RFC 2018 section 3] https://www.rfc-editor.org/rfc/rfc2018#section-3
256    pub const MAX_BLOCKS: usize = MAX_SACK_BLOCKS;
257
258    /// Returns the contained selective ACKs as a TCP option.
259    ///
260    /// Returns `None` if this [`SackBlocks`] is empty.
261    pub fn as_option(&self) -> Option<TcpOption<'_>> {
262        let Self(inner) = self;
263        if inner.is_empty() {
264            return None;
265        }
266
267        Some(TcpOption::Sack(inner.as_slice()))
268    }
269
270    /// Returns an iterator over the *valid* [`SackBlock`]s contained in this
271    /// option.
272    pub fn iter_skip_invalid(&self) -> impl Iterator<Item = SackBlock> + '_ {
273        self.try_iter().filter_map(|r| match r {
274            Ok(s) => Some(s),
275            Err(InvalidSackBlockError(_, _)) => None,
276        })
277    }
278
279    /// Returns an iterator yielding the results of converting the blocks in
280    /// this option to valid [`SackBlock`]s.
281    pub fn try_iter(&self) -> impl Iterator<Item = Result<SackBlock, InvalidSackBlockError>> + '_ {
282        let Self(inner) = self;
283        inner.iter().map(|block| SackBlock::try_from(*block))
284    }
285
286    /// Creates a new [`SackBlocks`] option from a slice of blocks seen in a TCP
287    /// segment.
288    ///
289    /// Ignores any blocks past [`SackBlocks::MAX_BLOCKS`].
290    pub fn from_option(blocks: &[TcpSackBlock]) -> Self {
291        Self(blocks.iter().take(Self::MAX_BLOCKS).copied().collect())
292    }
293
294    /// Returns `true` if there are no blocks present.
295    pub fn is_empty(&self) -> bool {
296        let Self(inner) = self;
297        inner.is_empty()
298    }
299
300    /// Drops all blocks.
301    pub fn clear(&mut self) {
302        let Self(inner) = self;
303        inner.clear()
304    }
305}
306
307/// Creates a new [`SackBlocks`] option from an iterator of [`SackBlock`].
308///
309/// Ignores any blocks past [`SackBlocks::MAX_BLOCKS`].
310impl FromIterator<SackBlock> for SackBlocks {
311    fn from_iter<T: IntoIterator<Item = SackBlock>>(iter: T) -> Self {
312        Self(iter.into_iter().take(Self::MAX_BLOCKS).map(|b| b.into()).collect())
313    }
314}
315
316mod sack_block {
317    use super::*;
318
319    /// A selective ACK block.
320    ///
321    /// Contains the left and right markers for a received data segment. It is a
322    /// witness for a valid non empty open range of `SeqNum`.
323    #[derive(Debug, PartialEq, Eq, Clone, Copy)]
324    pub struct SackBlock {
325        // NB: We don't use core::ops::Range here because it doesn't implement Copy.
326        left: SeqNum,
327        right: SeqNum,
328    }
329
330    impl SackBlock {
331        /// Attempts to create a new [`SackBlock`] with the range `[left, right)`.
332        ///
333        /// Returns an error if `right` is at or before `left`.
334        pub fn try_new(left: SeqNum, right: SeqNum) -> Result<Self, InvalidSackBlockError> {
335            if right.after(left) {
336                Ok(Self { left, right })
337            } else {
338                Err(InvalidSackBlockError(left, right))
339            }
340        }
341
342        /// Creates a new [`SackBlock`] without checking that `right` is
343        /// strictly after `left`.
344        ///
345        /// # Safety
346        ///
347        /// Caller must guarantee that `right.after(left)`.
348        pub unsafe fn new_unchecked(left: SeqNum, right: SeqNum) -> Self {
349            Self { left, right }
350        }
351
352        /// Consumes this [`SackBlock`] returning a [`Range`] representation.
353        pub fn into_range(self) -> Range<SeqNum> {
354            let Self { left, right } = self;
355            Range { start: left, end: right }
356        }
357
358        /// Consumes this [`SackBlock`] returning a [`Range`] representation
359        /// unwrapping the [`SeqNum`] representation into `u32`.
360        pub fn into_range_u32(self) -> Range<u32> {
361            let Self { left, right } = self;
362            Range { start: left.into(), end: right.into() }
363        }
364
365        /// Returns the left (inclusive) edge of the block.
366        pub fn left(&self) -> SeqNum {
367            self.left
368        }
369
370        /// Returns the right (exclusive) edge of the block.
371        pub fn right(&self) -> SeqNum {
372            self.right
373        }
374
375        /// Returns a tuple of the left (inclusive) and right (exclusive) edges
376        /// of the block.
377        pub fn into_parts(self) -> (SeqNum, SeqNum) {
378            let Self { left, right } = self;
379            (left, right)
380        }
381    }
382
383    /// Error returned when attempting to create a [`SackBlock`] with an invalid
384    /// range (i.e. right edge <= left edge).
385    #[derive(Debug, Eq, PartialEq, Clone, Copy)]
386    pub struct InvalidSackBlockError(pub SeqNum, pub SeqNum);
387
388    impl From<SackBlock> for TcpSackBlock {
389        fn from(value: SackBlock) -> Self {
390            let SackBlock { left, right } = value;
391            TcpSackBlock::new(left.into(), right.into())
392        }
393    }
394
395    impl TryFrom<TcpSackBlock> for SackBlock {
396        type Error = InvalidSackBlockError;
397
398        fn try_from(value: TcpSackBlock) -> Result<Self, Self::Error> {
399            Self::try_new(value.left_edge().into(), value.right_edge().into())
400        }
401    }
402
403    impl From<SackBlock> for Range<SeqNum> {
404        fn from(value: SackBlock) -> Self {
405            value.into_range()
406        }
407    }
408
409    impl TryFrom<Range<SeqNum>> for SackBlock {
410        type Error = InvalidSackBlockError;
411
412        fn try_from(value: Range<SeqNum>) -> Result<Self, Self::Error> {
413            let Range { start, end } = value;
414            Self::try_new(start, end)
415        }
416    }
417}
418pub use sack_block::{InvalidSackBlockError, SackBlock};
419
/// The maximum number of sequence numbers that a segment's payload plus its
/// control flag (SYN or FIN) may occupy without the sequence number wrapping
/// around: 2^31.
pub const MAX_PAYLOAD_AND_CONTROL_LEN: usize = 0x8000_0000;
// `u32` mirror of `MAX_PAYLOAD_AND_CONTROL_LEN`. The `as` cast is lossless
// because 2^31 <= `u32::MAX`.
const MAX_PAYLOAD_AND_CONTROL_LEN_U32: u32 = MAX_PAYLOAD_AND_CONTROL_LEN as u32;
424
425impl<P: Payload> Segment<P> {
426    /// Creates a new segment with the provided header and data.
427    ///
428    /// Returns the segment along with how many bytes were removed to make sure
429    /// sequence numbers don't wrap around, i.e., `seq.before(seq + seg.len())`.
430    pub fn new(header: SegmentHeader, data: P) -> (Self, usize) {
431        let SegmentHeader { seq, ack, wnd, control, push, options } = header;
432        let has_control_len = control.map(Control::has_sequence_no).unwrap_or(false);
433
434        let data_len = data.len();
435        let discarded_len =
436            data_len.saturating_sub(MAX_PAYLOAD_AND_CONTROL_LEN - usize::from(has_control_len));
437
438        // Only keep the PSH bit if data is not empty.
439        let push = push && data_len != 0;
440
441        let (control, data) = if discarded_len > 0 {
442            // If we have to truncate the segment, the FIN flag must be removed
443            // because it is logically the last octet of the segment.
444            let (control, control_len) = if control == Some(Control::FIN) {
445                (None, 0)
446            } else {
447                (control, has_control_len.into())
448            };
449            // The following slice will not panic because `discarded_len > 0`,
450            // thus `data.len() > MAX_PAYLOAD_AND_CONTROL_LEN - control_len`.
451            (control, data.slice(0..MAX_PAYLOAD_AND_CONTROL_LEN_U32 - control_len))
452        } else {
453            (control, data)
454        };
455
456        (
457            Segment { header: SegmentHeader { seq, ack, wnd, control, push, options }, data: data },
458            discarded_len,
459        )
460    }
461
462    /// Returns a borrow of the segment's header.
463    pub fn header(&self) -> &SegmentHeader {
464        &self.header
465    }
466
467    /// Returns a borrow of the data payload in this segment.
468    pub fn data(&self) -> &P {
469        &self.data
470    }
471
472    /// Destructures self into its inner parts: The segment header and the data
473    /// payload.
474    pub fn into_parts(self) -> (SegmentHeader, P) {
475        let Self { header, data } = self;
476        (header, data)
477    }
478
479    /// Maps the payload in the segment with `f`.
480    pub fn map_payload<R, F: FnOnce(P) -> R>(self, f: F) -> Segment<R> {
481        let Segment { header, data } = self;
482        Segment { header, data: f(data) }
483    }
484
485    /// Returns the length of the segment in sequence number space.
486    ///
487    /// Per RFC 793 (https://tools.ietf.org/html/rfc793#page-25):
488    ///   SEG.LEN = the number of octets occupied by the data in the segment
489    ///   (counting SYN and FIN)
490    pub fn len(&self) -> u32 {
491        self.header.len(self.data.len())
492    }
493
494    /// Returns the part of the incoming segment within the receive window.
495    pub fn overlap(self, rnxt: SeqNum, rwnd: WindowSize) -> Option<Segment<P>> {
496        let len = self.len();
497        let Segment { header: SegmentHeader { seq, ack, wnd, control, options, push }, data } =
498            self;
499
500        // RFC 793 (https://tools.ietf.org/html/rfc793#page-69):
501        //   There are four cases for the acceptability test for an incoming
502        //   segment:
503        //       Segment Receive  Test
504        //       Length  Window
505        //       ------- -------  -------------------------------------------
506        //          0       0     SEG.SEQ = RCV.NXT
507        //          0      >0     RCV.NXT =< SEG.SEQ < RCV.NXT+RCV.WND
508        //         >0       0     not acceptable
509        //         >0      >0     RCV.NXT =< SEG.SEQ < RCV.NXT+RCV.WND
510        //                     or RCV.NXT =< SEG.SEQ+SEG.LEN-1 < RCV.NXT+RCV.WND
511        let overlap = match (len, rwnd) {
512            (0, WindowSize::ZERO) => seq == rnxt,
513            (0, rwnd) => !rnxt.after(seq) && seq.before(rnxt + rwnd),
514            (_len, WindowSize::ZERO) => false,
515            (len, rwnd) => {
516                (!rnxt.after(seq) && seq.before(rnxt + rwnd))
517                    // Note: here we use RCV.NXT <= SEG.SEQ+SEG.LEN instead of
518                    // the condition as quoted above because of the following
519                    // text immediately after the above table:
520                    //   One could tailor actual segments to fit this assumption by
521                    //   trimming off any portions that lie outside the window
522                    //   (including SYN and FIN), and only processing further if
523                    //   the segment then begins at RCV.NXT.
524                    // This is essential for TCP simultaneous open to work,
525                    // otherwise, the state machine would reject the SYN-ACK
526                    // sent by the peer.
527                    || (!(seq + len).before(rnxt) && !(seq + len).after(rnxt + rwnd))
528            }
529        };
530        overlap.then(move || {
531            // We deliberately don't define `PartialOrd` for `SeqNum`, so we use
532            // `cmp` below to utilize `cmp::{max,min}_by`.
533            let cmp = |lhs: &SeqNum, rhs: &SeqNum| (*lhs - *rhs).cmp(&0);
534            let new_seq = core::cmp::max_by(seq, rnxt, cmp);
535            let new_len = core::cmp::min_by(seq + len, rnxt + rwnd, cmp) - new_seq;
536            // The following unwrap won't panic because:
537            // 1. if `seq` is after `rnxt`, then `start` would be 0.
538            // 2. the interesting case is when `rnxt` is after `seq`, in that
539            // case, we have `rnxt - seq > 0`, thus `new_seq - seq > 0`.
540            let start = u32::try_from(new_seq - seq).unwrap();
541            // The following unwrap won't panic because:
542            // 1. The witness on `Segment` and `WindowSize` guarantees that
543            // `len <= 2^31` and `rwnd <= 2^30-1` thus
544            // `seq <= seq + len` and `rnxt <= rnxt + rwnd`.
545            // 2. We are in the closure because `overlap` is true which means
546            // `seq <= rnxt + rwnd` and `rnxt <= seq + len`.
547            // With these two conditions combined, `new_len` can't be negative
548            // so the unwrap can't panic.
549            let new_len = u32::try_from(new_len).unwrap();
550            let (new_control, new_data) = {
551                match control {
552                    Some(Control::SYN) => {
553                        if start == 0 {
554                            (Some(Control::SYN), data.slice(start..start + new_len - 1))
555                        } else {
556                            (None, data.slice(start - 1..start + new_len - 1))
557                        }
558                    }
559                    Some(Control::FIN) => {
560                        if len == start + new_len {
561                            if new_len > 0 {
562                                (Some(Control::FIN), data.slice(start..start + new_len - 1))
563                            } else {
564                                (None, data.slice(start - 1..start - 1))
565                            }
566                        } else {
567                            (None, data.slice(start..start + new_len))
568                        }
569                    }
570                    Some(Control::RST) | None => (control, data.slice(start..start + new_len)),
571                }
572            };
573            Segment {
574                header: SegmentHeader {
575                    seq: new_seq,
576                    ack,
577                    wnd,
578                    control: new_control,
579                    options,
580                    push,
581                },
582                data: new_data,
583            }
584        })
585    }
586
587    /// Creates a segment with no data.
588    pub fn new_empty(header: SegmentHeader) -> Self {
589        // All of the checks on lengths are optimized out:
590        // https://godbolt.org/z/KPd537G6Y
591        let (seg, truncated) = Self::new(header, P::new_empty());
592        debug_assert_eq!(truncated, 0);
593        seg
594    }
595
596    /// Creates an ACK segment.
597    pub fn ack(seq: SeqNum, ack: SeqNum, wnd: UnscaledWindowSize) -> Self {
598        Self::ack_with_options(seq, ack, wnd, Options::default())
599    }
600
601    /// Creates an ACK segment with options.
602    pub fn ack_with_options(
603        seq: SeqNum,
604        ack: SeqNum,
605        wnd: UnscaledWindowSize,
606        options: Options,
607    ) -> Self {
608        Segment::new_empty(SegmentHeader {
609            seq,
610            ack: Some(ack),
611            wnd,
612            control: None,
613            push: false,
614            options,
615        })
616    }
617
618    /// Creates a SYN segment.
619    pub fn syn(seq: SeqNum, wnd: UnscaledWindowSize, options: Options) -> Self {
620        Segment::new_empty(SegmentHeader {
621            seq,
622            ack: None,
623            wnd,
624            control: Some(Control::SYN),
625            push: false,
626            options,
627        })
628    }
629
630    /// Creates a SYN-ACK segment.
631    pub fn syn_ack(seq: SeqNum, ack: SeqNum, wnd: UnscaledWindowSize, options: Options) -> Self {
632        Segment::new_empty(SegmentHeader {
633            seq,
634            ack: Some(ack),
635            wnd,
636            control: Some(Control::SYN),
637            push: false,
638            options,
639        })
640    }
641
642    /// Creates a RST segment.
643    pub fn rst(seq: SeqNum) -> Self {
644        Segment::new_empty(SegmentHeader {
645            seq,
646            ack: None,
647            wnd: UnscaledWindowSize::from(0),
648            control: Some(Control::RST),
649            push: false,
650            options: Options::default(),
651        })
652    }
653
654    /// Creates a RST-ACK segment.
655    pub fn rst_ack(seq: SeqNum, ack: SeqNum) -> Self {
656        Segment::new_empty(SegmentHeader {
657            seq,
658            ack: Some(ack),
659            wnd: UnscaledWindowSize::from(0),
660            control: Some(Control::RST),
661            push: false,
662            options: Options::default(),
663        })
664    }
665}
666
667impl Segment<()> {
668    /// Converts this segment with `()` data into any `P` payload's `new_empty`
669    /// form.
670    pub fn into_empty<P: Payload>(self) -> Segment<P> {
671        self.map_payload(|()| P::new_empty())
672    }
673}
674
675impl SegmentHeader {
676    /// Returns the length of the segment in sequence number space.
677    ///
678    /// Per RFC 793 (https://tools.ietf.org/html/rfc793#page-25):
679    ///   SEG.LEN = the number of octets occupied by the data in the segment
680    ///   (counting SYN and FIN)
681    pub fn len(&self, payload_len: usize) -> u32 {
682        // The following unwrap and addition are fine because:
683        // - `u32::from(has_control_len)` is 0 or 1.
684        // - `self.data.len() <= 2^31`.
685        let has_control_len = self.control.map(Control::has_sequence_no).unwrap_or(false);
686        u32::try_from(payload_len).unwrap() + u32::from(has_control_len)
687    }
688
689    /// Create a `SegmentHeader` from the provided builder and data length.  The
690    /// options will be set to their default values.
691    pub fn from_builder<A: IpAddress>(
692        builder: &TcpSegmentBuilder<A>,
693    ) -> Result<Self, MalformedFlags> {
694        Self::from_builder_options(builder, Options::new_with_handshake(builder.syn_set()))
695    }
696
697    /// Create a `SegmentHeader` from the provided builder, options, and data length.
698    pub fn from_builder_options<A: IpAddress>(
699        builder: &TcpSegmentBuilder<A>,
700        options: Options,
701    ) -> Result<Self, MalformedFlags> {
702        Ok(SegmentHeader {
703            seq: SeqNum::new(builder.seq_num()),
704            ack: builder.ack_num().map(SeqNum::new),
705            control: Flags {
706                syn: builder.syn_set(),
707                fin: builder.fin_set(),
708                rst: builder.rst_set(),
709            }
710            .control()?,
711            wnd: UnscaledWindowSize::from(builder.window_size()),
712            push: builder.psh_set(),
713            options: options,
714        })
715    }
716}
717
/// A TCP payload that only allows for getting the length of the payload.
///
/// [`Payload`] extends this with slicing and copying.
pub trait PayloadLen {
    /// Returns the length of the payload, in bytes.
    fn len(&self) -> usize;
}
723
/// A TCP payload that operates around `u32` instead of `usize`.
///
/// Ranges passed to [`Payload::slice`] are `u32`-based, the same width that
/// [`SackBlock::into_range_u32`] produces from sequence numbers.
pub trait Payload: PayloadLen + Sized {
    /// Creates a slice of the payload, reducing it to only the bytes within
    /// `range`.
    ///
    /// # Panics
    ///
    /// Panics if the provided `range` is not within the bounds of this
    /// `Payload`, or if the range is nonsensical (the end precedes
    /// the start).
    fn slice(self, range: Range<u32>) -> Self;

    /// Copies part of the payload beginning at `offset` into `dst`, filling
    /// `dst` completely.
    ///
    /// # Panics
    ///
    /// Panics if offset is too large or we couldn't fill the `dst` slice.
    fn partial_copy(&self, offset: usize, dst: &mut [u8]);

    /// Copies part of the payload beginning at `offset` into the possibly
    /// uninitialized `dst`, filling it completely.
    ///
    /// # Panics
    ///
    /// Panics if offset is too large or we couldn't fill the `dst` slice.
    fn partial_copy_uninit(&self, offset: usize, dst: &mut [MaybeUninit<u8>]);

    /// Creates a new empty payload.
    ///
    /// An empty payload must report 0 as its length.
    fn new_empty() -> Self;
}
755
756impl PayloadLen for &[u8] {
757    fn len(&self) -> usize {
758        <[u8]>::len(self)
759    }
760}
761
762impl Payload for &[u8] {
763    fn slice(self, Range { start, end }: Range<u32>) -> Self {
764        // The following `unwrap`s are ok because:
765        // `usize::try_from(x)` fails when `x > usize::MAX`; given that
766        // `self.len() <= usize::MAX`, panic would be expected because `range`
767        // exceeds the bound of `self`.
768        let start = usize::try_from(start).unwrap_or_else(|TryFromIntError { .. }| {
769            panic!("range start index {} out of range for slice of length {}", start, self.len())
770        });
771        let end = usize::try_from(end).unwrap_or_else(|TryFromIntError { .. }| {
772            panic!("range end index {} out of range for slice of length {}", end, self.len())
773        });
774        &self[start..end]
775    }
776
777    fn partial_copy(&self, offset: usize, dst: &mut [u8]) {
778        dst.copy_from_slice(&self[offset..offset + dst.len()])
779    }
780
781    fn partial_copy_uninit(&self, offset: usize, dst: &mut [MaybeUninit<u8>]) {
782        // TODO(https://github.com/rust-lang/rust/issues/79995): Replace unsafe
783        // with copy_from_slice when stabiliized.
784        let src = &self[offset..offset + dst.len()];
785        // SAFETY: &[T] and &[MaybeUninit<T>] have the same layout.
786        let uninit_src: &[MaybeUninit<u8>] = unsafe { core::mem::transmute(src) };
787        dst.copy_from_slice(&uninit_src);
788    }
789
790    fn new_empty() -> Self {
791        &[]
792    }
793}
794
795impl PayloadLen for () {
796    fn len(&self) -> usize {
797        0
798    }
799}
800
801impl Payload for () {
802    fn slice(self, Range { start, end }: Range<u32>) -> Self {
803        if start != 0 {
804            panic!("range start index {} out of range for slice of length 0", start);
805        }
806        if end != 0 {
807            panic!("range end index {} out of range for slice of length 0", end);
808        }
809        ()
810    }
811
812    fn partial_copy(&self, offset: usize, dst: &mut [u8]) {
813        if dst.len() != 0 || offset != 0 {
814            panic!(
815                "source slice length (0) does not match destination slice length ({})",
816                dst.len()
817            );
818        }
819    }
820
821    fn partial_copy_uninit(&self, offset: usize, dst: &mut [MaybeUninit<u8>]) {
822        if dst.len() != 0 || offset != 0 {
823            panic!(
824                "source slice length (0) does not match destination slice length ({})",
825                dst.len()
826            );
827        }
828    }
829
830    fn new_empty() -> Self {
831        ()
832    }
833}
834
835impl<I: PayloadLen, B> PayloadLen for InnerSerializer<I, B> {
836    fn len(&self) -> usize {
837        PayloadLen::len(self.inner())
838    }
839}
840
/// Error produced when a TCP segment has more than one of the mutually
/// exclusive SYN, FIN, and RST flags set.
///
/// Each field records whether the corresponding flag was set on the
/// offending segment.
#[derive(Error, Debug, PartialEq, Eq)]
#[error("multiple mutually exclusive flags are set: syn: {syn}, fin: {fin}, rst: {rst}")]
pub struct MalformedFlags {
    /// Whether the SYN flag was set.
    syn: bool,
    /// Whether the FIN flag was set.
    fin: bool,
    /// Whether the RST flag was set.
    rst: bool,
}
848
849struct Flags {
850    syn: bool,
851    fin: bool,
852    rst: bool,
853}
854
855impl Flags {
856    fn control(&self) -> Result<Option<Control>, MalformedFlags> {
857        if usize::from(self.syn) + usize::from(self.fin) + usize::from(self.rst) > 1 {
858            return Err(MalformedFlags { syn: self.syn, fin: self.fin, rst: self.rst });
859        }
860
861        let syn = self.syn.then_some(Control::SYN);
862        let fin = self.fin.then_some(Control::FIN);
863        let rst = self.rst.then_some(Control::RST);
864
865        Ok(syn.or(fin).or(rst))
866    }
867}
868
869/// A TCP segment that has been verified to have valid flags. Can be converted
870/// to a `Segment`.
871pub struct VerifiedTcpSegment<'a> {
872    segment: TcpSegment<&'a [u8]>,
873    control: Option<Control>,
874}
875
876impl<'a> VerifiedTcpSegment<'a> {
877    /// Returns the underlying [`TcpSegment`].
878    pub fn tcp_segment(&self) -> &TcpSegment<&'a [u8]> {
879        &self.segment
880    }
881
882    /// Returns the control flag of the segment.
883    pub fn control(&self) -> Option<Control> {
884        self.control
885    }
886}
887
888impl<'a> TryFrom<TcpSegment<&'a [u8]>> for VerifiedTcpSegment<'a> {
889    type Error = MalformedFlags;
890
891    fn try_from(segment: TcpSegment<&'a [u8]>) -> Result<Self, Self::Error> {
892        let control =
893            Flags { syn: segment.syn(), fin: segment.fin(), rst: segment.rst() }.control()?;
894        Ok(VerifiedTcpSegment { segment, control })
895    }
896}
897
898impl<'a> From<&'a VerifiedTcpSegment<'a>> for Segment<&'a [u8]> {
899    fn from(from: &'a VerifiedTcpSegment<'a>) -> Segment<&'a [u8]> {
900        let options = Options::from_iter(from.segment.syn(), from.segment.iter_options());
901        let (to, discarded) = Segment::new(
902            SegmentHeader {
903                seq: from.segment.seq_num().into(),
904                ack: from.segment.ack_num().map(Into::into),
905                wnd: UnscaledWindowSize::from(from.segment.window_size()),
906                control: from.control,
907                push: from.segment.psh(),
908                options,
909            },
910            from.segment.body(),
911        );
912        debug_assert_eq!(discarded, 0);
913        to
914    }
915}
916
917impl<A> TryFrom<&TcpSegmentBuilder<A>> for SegmentHeader
918where
919    A: IpAddress,
920{
921    type Error = MalformedFlags;
922
923    fn try_from(from: &TcpSegmentBuilder<A>) -> Result<Self, Self::Error> {
924        SegmentHeader::from_builder(from)
925    }
926}
927
928impl<'a, A, I> TryFrom<&TcpSegmentBuilderWithOptions<A, OptionSequenceBuilder<TcpOption<'a>, I>>>
929    for SegmentHeader
930where
931    A: IpAddress,
932    I: Iterator + Clone,
933    I::Item: Borrow<TcpOption<'a>>,
934{
935    type Error = MalformedFlags;
936
937    fn try_from(
938        from: &TcpSegmentBuilderWithOptions<A, OptionSequenceBuilder<TcpOption<'a>, I>>,
939    ) -> Result<Self, Self::Error> {
940        let prefix_builder = from.prefix_builder();
941        let handshake = prefix_builder.syn_set();
942        Self::from_builder_options(
943            prefix_builder,
944            Options::from_iter(
945                handshake,
946                from.iter_options().map(|option| option.borrow().to_owned()),
947            ),
948        )
949    }
950}
951
#[cfg(any(test, feature = "testutils"))]
mod testutils {
    use super::*;

    /// Test-only default header: sequence number zero, no ACK, a zero window,
    /// no control flag, PSH cleared, and default options.
    impl Default for SegmentHeader {
        fn default() -> Self {
            Self {
                seq: SeqNum::new(0),
                ack: None,
                wnd: UnscaledWindowSize::from(0),
                control: None,
                push: false,
                options: Options::default(),
            }
        }
    }

    impl<P: Payload> Segment<P> {
        /// Like [`Segment::new`], but panics if any of `data` had to be
        /// discarded to fit in the segment.
        #[track_caller]
        pub fn new_assert_no_discard(header: SegmentHeader, data: P) -> Self {
            let (segment, discarded) = Self::new(header, data);
            assert_eq!(discarded, 0);
            segment
        }

        /// Builds an ACK-bearing segment carrying `data`, with no control flag,
        /// PSH cleared, and default options.
        pub fn with_data(seq: SeqNum, ack: SeqNum, wnd: UnscaledWindowSize, data: P) -> Segment<P> {
            Self::new_assert_no_discard(
                SegmentHeader { seq, ack: Some(ack), wnd, ..Default::default() },
                data,
            )
        }

        /// Builds an ACK-bearing FIN segment that also carries `data`.
        pub fn piggybacked_fin(
            seq: SeqNum,
            ack: SeqNum,
            wnd: UnscaledWindowSize,
            data: P,
        ) -> Segment<P> {
            let header = SegmentHeader {
                seq,
                ack: Some(ack),
                wnd,
                control: Some(Control::FIN),
                ..Default::default()
            };
            Self::new_assert_no_discard(header, data)
        }

        /// Builds an ACK-bearing FIN segment with no payload.
        pub fn fin(seq: SeqNum, ack: SeqNum, wnd: UnscaledWindowSize) -> Self {
            Self::new_empty(SegmentHeader {
                seq,
                ack: Some(ack),
                wnd,
                control: Some(Control::FIN),
                ..Default::default()
            })
        }
    }

    impl<'a> Segment<&'a [u8]> {
        /// Builds a data segment borrowing `data` and advertising an unscaled
        /// window of `u16::MAX`.
        pub fn with_fake_data(seq: SeqNum, ack: SeqNum, data: &'a [u8]) -> Self {
            Self::with_data(seq, ack, UnscaledWindowSize::from(u16::MAX), data)
        }
    }
}
1047
// Unit tests covering several areas: segment length accounting and payload
// truncation, trimming a segment to a receive window, SYN/FIN/RST validation,
// SACK block construction, and conversions from segment builders into
// `SegmentHeader`.
#[cfg(test)]
mod test {
    use assert_matches::assert_matches;
    use ip_test_macro::ip_test;
    use net_declare::{net_ip_v4, net_ip_v6};
    use net_types::ip::{Ipv4, Ipv6};
    use packet_formats::ip::IpExt;
    use test_case::test_case;

    use super::*;

    // `Segment::len` counts the payload plus one for a SYN or FIN, while RST
    // contributes nothing. Each case returns (segment length, payload).
    #[test_case(None, &[][..] => (0, &[][..]); "empty")]
    #[test_case(None, &[1][..] => (1, &[1][..]); "no control")]
    #[test_case(Some(Control::SYN), &[][..] => (1, &[][..]); "empty slice with syn")]
    #[test_case(Some(Control::SYN), &[1][..] => (2, &[1][..]); "non-empty slice with syn")]
    #[test_case(Some(Control::FIN), &[][..] => (1, &[][..]); "empty slice with fin")]
    #[test_case(Some(Control::FIN), &[1][..] => (2, &[1][..]); "non-empty slice with fin")]
    #[test_case(Some(Control::RST), &[][..] => (0, &[][..]); "empty slice with rst")]
    #[test_case(Some(Control::RST), &[1][..] => (1, &[1][..]); "non-empty slice with rst")]
    fn segment_len(control: Option<Control>, data: &[u8]) -> (u32, &[u8]) {
        let (seg, truncated) = Segment::new(
            SegmentHeader {
                seq: SeqNum::new(1),
                ack: Some(SeqNum::new(1)),
                wnd: UnscaledWindowSize::from(0),
                control,
                push: false,
                options: Options::default(),
            },
            data,
        );
        assert_eq!(truncated, 0);
        (seg.len(), seg.data)
    }

    // Slicing a payload and then copying the slice out yields the selected
    // bytes. Unused tail bytes of the 4-byte buffer stay zero.
    #[test_case(&[1, 2, 3, 4, 5][..], 0..4 => [1, 2, 3, 4])]
    #[test_case((), 0..0 => [0, 0, 0, 0])]
    fn payload_slice_copy(data: impl Payload, range: Range<u32>) -> [u8; 4] {
        let sliced = data.slice(range);
        let mut buffer = [0; 4];
        sliced.partial_copy(0, &mut buffer[..sliced.len()]);
        buffer
    }

    // A payload that only tracks an index range and carries no bytes. This
    // lets tests use lengths up to `u32::MAX` without allocating; the copy
    // methods are unimplemented.
    #[derive(Debug, PartialEq, Eq)]
    struct TestPayload(Range<u32>);

    impl TestPayload {
        fn new(len: usize) -> Self {
            Self(0..u32::try_from(len).unwrap())
        }
    }

    impl PayloadLen for TestPayload {
        fn len(&self) -> usize {
            self.0.len()
        }
    }

    impl Payload for TestPayload {
        fn slice(self, range: Range<u32>) -> Self {
            let Self(this) = self;
            assert!(range.start >= this.start && range.end <= this.end);
            TestPayload(range)
        }

        fn partial_copy(&self, _offset: usize, _dst: &mut [u8]) {
            unimplemented!("TestPayload doesn't carry any data");
        }

        fn partial_copy_uninit(&self, _offset: usize, _dst: &mut [MaybeUninit<u8>]) {
            unimplemented!("TestPayload doesn't carry any data");
        }

        fn new_empty() -> Self {
            Self(0..0)
        }
    }

    // `Segment::new` keeps the payload plus SYN/FIN within
    // `MAX_PAYLOAD_AND_CONTROL_LEN`. With SYN the payload is shortened. A FIN
    // that no longer fits is dropped instead. Each case returns (remaining
    // payload length, remaining control flag, discarded count reported by
    // `Segment::new`).
    #[test_case(100, Some(Control::SYN) => (100, Some(Control::SYN), 0))]
    #[test_case(100, Some(Control::FIN) => (100, Some(Control::FIN), 0))]
    #[test_case(100, Some(Control::RST) => (100, Some(Control::RST), 0))]
    #[test_case(100, None => (100, None, 0))]
    #[test_case(MAX_PAYLOAD_AND_CONTROL_LEN - 1, Some(Control::SYN)
    => (MAX_PAYLOAD_AND_CONTROL_LEN - 1, Some(Control::SYN), 0))]
    #[test_case(MAX_PAYLOAD_AND_CONTROL_LEN - 1, Some(Control::FIN)
    => (MAX_PAYLOAD_AND_CONTROL_LEN - 1, Some(Control::FIN), 0))]
    #[test_case(MAX_PAYLOAD_AND_CONTROL_LEN - 1, Some(Control::RST)
    => (MAX_PAYLOAD_AND_CONTROL_LEN - 1, Some(Control::RST), 0))]
    #[test_case(MAX_PAYLOAD_AND_CONTROL_LEN - 1, None
    => (MAX_PAYLOAD_AND_CONTROL_LEN - 1, None, 0))]
    #[test_case(MAX_PAYLOAD_AND_CONTROL_LEN, Some(Control::SYN)
    => (MAX_PAYLOAD_AND_CONTROL_LEN - 1, Some(Control::SYN), 1))]
    #[test_case(MAX_PAYLOAD_AND_CONTROL_LEN, Some(Control::FIN)
    => (MAX_PAYLOAD_AND_CONTROL_LEN, None, 1))]
    #[test_case(MAX_PAYLOAD_AND_CONTROL_LEN, Some(Control::RST)
    => (MAX_PAYLOAD_AND_CONTROL_LEN, Some(Control::RST), 0))]
    #[test_case(MAX_PAYLOAD_AND_CONTROL_LEN, None
    => (MAX_PAYLOAD_AND_CONTROL_LEN, None, 0))]
    #[test_case(MAX_PAYLOAD_AND_CONTROL_LEN + 1, Some(Control::SYN)
    => (MAX_PAYLOAD_AND_CONTROL_LEN - 1, Some(Control::SYN), 2))]
    #[test_case(MAX_PAYLOAD_AND_CONTROL_LEN + 1, Some(Control::FIN)
    => (MAX_PAYLOAD_AND_CONTROL_LEN, None, 2))]
    #[test_case(MAX_PAYLOAD_AND_CONTROL_LEN + 1, Some(Control::RST)
    => (MAX_PAYLOAD_AND_CONTROL_LEN, Some(Control::RST), 1))]
    #[test_case(MAX_PAYLOAD_AND_CONTROL_LEN + 1, None
    => (MAX_PAYLOAD_AND_CONTROL_LEN, None, 1))]
    #[test_case(u32::MAX as usize, Some(Control::SYN)
    => (MAX_PAYLOAD_AND_CONTROL_LEN - 1, Some(Control::SYN), 1 << 31))]
    fn segment_truncate(len: usize, control: Option<Control>) -> (usize, Option<Control>, usize) {
        let (seg, truncated) = Segment::new(
            SegmentHeader {
                seq: SeqNum::new(0),
                ack: None,
                wnd: UnscaledWindowSize::from(0),
                control,
                push: false,
                options: Options::default(),
            },
            TestPayload::new(len),
        );
        (seg.data.len(), seg.header.control, truncated)
    }

    // Inputs for `segment_overlap`. The first three fields describe the
    // incoming segment. `rcv_nxt` and `rcv_wnd` are passed to
    // `Segment::overlap` as the receive window's start and size.
    struct OverlapTestArgs {
        seg_seq: u32,
        control: Option<Control>,
        data_len: u32,
        rcv_nxt: u32,
        rcv_wnd: usize,
    }
    // `Segment::overlap` trims a segment to the receive window. Each case
    // returns the trimmed segment's (sequence number, control flag, payload
    // range). The result is `None` when `overlap` rejects the segment.
    #[test_case(OverlapTestArgs{
        seg_seq: 1,
        control: None,
        data_len: 0,
        rcv_nxt: 0,
        rcv_wnd: 0,
    } => None)]
    #[test_case(OverlapTestArgs{
        seg_seq: 1,
        control: None,
        data_len: 0,
        rcv_nxt: 1,
        rcv_wnd: 0,
    } => Some((SeqNum::new(1), None, 0..0)))]
    #[test_case(OverlapTestArgs{
        seg_seq: 1,
        control: None,
        data_len: 0,
        rcv_nxt: 2,
        rcv_wnd: 0,
    } => None)]
    #[test_case(OverlapTestArgs{
        seg_seq: 1,
        control: Some(Control::SYN),
        data_len: 0,
        rcv_nxt: 2,
        rcv_wnd: 0,
    } => None)]
    #[test_case(OverlapTestArgs{
        seg_seq: 1,
        control: Some(Control::SYN),
        data_len: 0,
        rcv_nxt: 1,
        rcv_wnd: 0,
    } => None)]
    #[test_case(OverlapTestArgs{
        seg_seq: 1,
        control: Some(Control::SYN),
        data_len: 0,
        rcv_nxt: 0,
        rcv_wnd: 0,
    } => None)]
    #[test_case(OverlapTestArgs{
        seg_seq: 1,
        control: Some(Control::FIN),
        data_len: 0,
        rcv_nxt: 2,
        rcv_wnd: 0,
    } => None)]
    #[test_case(OverlapTestArgs{
        seg_seq: 1,
        control: Some(Control::FIN),
        data_len: 0,
        rcv_nxt: 1,
        rcv_wnd: 0,
    } => None)]
    #[test_case(OverlapTestArgs{
        seg_seq: 1,
        control: Some(Control::FIN),
        data_len: 0,
        rcv_nxt: 0,
        rcv_wnd: 0,
    } => None)]
    #[test_case(OverlapTestArgs{
        seg_seq: 0,
        control: None,
        data_len: 0,
        rcv_nxt: 1,
        rcv_wnd: 1,
    } => None)]
    #[test_case(OverlapTestArgs{
        seg_seq: 1,
        control: None,
        data_len: 0,
        rcv_nxt: 1,
        rcv_wnd: 1,
    } => Some((SeqNum::new(1), None, 0..0)))]
    #[test_case(OverlapTestArgs{
        seg_seq: 2,
        control: None,
        data_len: 0,
        rcv_nxt: 1,
        rcv_wnd: 1,
    } => None)]
    #[test_case(OverlapTestArgs{
        seg_seq: 0,
        control: None,
        data_len: 1,
        rcv_nxt: 1,
        rcv_wnd: 1,
    } => Some((SeqNum::new(1), None, 1..1)))]
    #[test_case(OverlapTestArgs{
        seg_seq: 0,
        control: Some(Control::SYN),
        data_len: 0,
        rcv_nxt: 1,
        rcv_wnd: 1,
    } => Some((SeqNum::new(1), None, 0..0)))]
    #[test_case(OverlapTestArgs{
        seg_seq: 2,
        control: None,
        data_len: 1,
        rcv_nxt: 1,
        rcv_wnd: 1,
    } => None)]
    #[test_case(OverlapTestArgs{
        seg_seq: 0,
        control: None,
        data_len: 2,
        rcv_nxt: 1,
        rcv_wnd: 1,
    } => Some((SeqNum::new(1), None, 1..2)))]
    #[test_case(OverlapTestArgs{
        seg_seq: 1,
        control: None,
        data_len: 2,
        rcv_nxt: 1,
        rcv_wnd: 1,
    } => Some((SeqNum::new(1), None, 0..1)))]
    #[test_case(OverlapTestArgs{
        seg_seq: 0,
        control: Some(Control::SYN),
        data_len: 1,
        rcv_nxt: 1,
        rcv_wnd: 1,
    } => Some((SeqNum::new(1), None, 0..1)))]
    #[test_case(OverlapTestArgs{
        seg_seq: 1,
        control: Some(Control::SYN),
        data_len: 1,
        rcv_nxt: 1,
        rcv_wnd: 1,
    } => Some((SeqNum::new(1), Some(Control::SYN), 0..0)))]
    #[test_case(OverlapTestArgs{
        seg_seq: 0,
        control: Some(Control::FIN),
        data_len: 1,
        rcv_nxt: 1,
        rcv_wnd: 1,
    } => Some((SeqNum::new(1), Some(Control::FIN), 1..1)))]
    #[test_case(OverlapTestArgs{
        seg_seq: 1,
        control: Some(Control::FIN),
        data_len: 1,
        rcv_nxt: 1,
        rcv_wnd: 1,
    } => Some((SeqNum::new(1), None, 0..1)))]
    #[test_case(OverlapTestArgs{
        seg_seq: 1,
        control: None,
        data_len: MAX_PAYLOAD_AND_CONTROL_LEN_U32,
        rcv_nxt: 1,
        rcv_wnd: 10,
    } => Some((SeqNum::new(1), None, 0..10)))]
    #[test_case(OverlapTestArgs{
        seg_seq: 10,
        control: None,
        data_len: MAX_PAYLOAD_AND_CONTROL_LEN_U32,
        rcv_nxt: 1,
        rcv_wnd: 10,
    } => Some((SeqNum::new(10), None, 0..1)))]
    #[test_case(OverlapTestArgs{
        seg_seq: 1,
        control: None,
        data_len: 10,
        rcv_nxt: 1,
        rcv_wnd: WindowSize::MAX.into(),
    } => Some((SeqNum::new(1), None, 0..10)))]
    #[test_case(OverlapTestArgs{
        seg_seq: 10,
        control: None,
        data_len: 10,
        rcv_nxt: 1,
        rcv_wnd: WindowSize::MAX.into(),
    } => Some((SeqNum::new(10), None, 0..10)))]
    #[test_case(OverlapTestArgs{
        seg_seq: 1,
        control: Some(Control::FIN),
        data_len: 1,
        rcv_nxt: 3,
        rcv_wnd: 10,
    } => Some((SeqNum::new(3), None, 1..1)); "regression test for https://fxbug.dev/42061750")]
    fn segment_overlap(
        OverlapTestArgs { seg_seq, control, data_len, rcv_nxt, rcv_wnd }: OverlapTestArgs,
    ) -> Option<(SeqNum, Option<Control>, Range<u32>)> {
        let (seg, discarded) = Segment::new(
            SegmentHeader {
                seq: SeqNum::new(seg_seq),
                ack: None,
                control,
                wnd: UnscaledWindowSize::from(0),
                push: false,
                options: Options::default(),
            },
            TestPayload(0..data_len),
        );
        assert_eq!(discarded, 0);
        seg.overlap(SeqNum::new(rcv_nxt), WindowSize::new(rcv_wnd).unwrap()).map(
            |Segment { header: SegmentHeader { seq, control, .. }, data: TestPayload(range) }| {
                (seq, control, range)
            },
        )
    }

    // Per-IP-version source and destination addresses for the builder
    // conversion tests below.
    pub trait TestIpExt: IpExt {
        const SRC_IP: Self::Addr;
        const DST_IP: Self::Addr;
    }

    impl TestIpExt for Ipv4 {
        const SRC_IP: Self::Addr = net_ip_v4!("192.0.2.1");
        const DST_IP: Self::Addr = net_ip_v4!("192.0.2.2");
    }

    impl TestIpExt for Ipv6 {
        const SRC_IP: Self::Addr = net_ip_v6!("2001:db8::1");
        const DST_IP: Self::Addr = net_ip_v6!("2001:db8::2");
    }

    const SRC_PORT: NonZeroU16 = NonZeroU16::new(1234).unwrap();
    const DST_PORT: NonZeroU16 = NonZeroU16::new(9876).unwrap();

    // A SYN builder without options converts to a header with default
    // handshake options.
    #[ip_test(I)]
    fn from_segment_builder<I: TestIpExt>() {
        let mut builder =
            TcpSegmentBuilder::new(I::SRC_IP, I::DST_IP, SRC_PORT, DST_PORT, 1, Some(2), 3);
        builder.syn(true);

        let converted_header =
            SegmentHeader::try_from(&builder).expect("failed to convert serializer");

        let expected_header = SegmentHeader {
            seq: SeqNum::new(1),
            ack: Some(SeqNum::new(2)),
            wnd: UnscaledWindowSize::from(3u16),
            control: Some(Control::SYN),
            options: HandshakeOptions::default().into(),
            push: false,
        };

        assert_eq!(converted_header, expected_header);
    }

    // SYN and FIN set together on a builder are reported as `MalformedFlags`.
    #[ip_test(I)]
    fn from_segment_builder_failure<I: TestIpExt>() {
        let mut builder =
            TcpSegmentBuilder::new(I::SRC_IP, I::DST_IP, SRC_PORT, DST_PORT, 1, Some(2), 3);
        builder.syn(true);
        builder.fin(true);

        assert_matches!(
            SegmentHeader::try_from(&builder),
            Err(MalformedFlags { syn: true, fin: true, rst: false })
        );
    }

    // MSS, window scale, and SACK-permitted options on a SYN builder become
    // handshake options in the header.
    #[ip_test(I)]
    fn from_segment_builder_with_options_handshake<I: TestIpExt>() {
        let mut builder =
            TcpSegmentBuilder::new(I::SRC_IP, I::DST_IP, SRC_PORT, DST_PORT, 1, Some(2), 3);
        builder.syn(true);

        let builder = TcpSegmentBuilderWithOptions::new(
            builder,
            [TcpOption::Mss(1024), TcpOption::WindowScale(10), TcpOption::SackPermitted],
        )
        .expect("failed to create tcp segment builder");

        let converted_header =
            SegmentHeader::try_from(&builder).expect("failed to convert serializer");

        let expected_header = SegmentHeader {
            seq: SeqNum::new(1),
            ack: Some(SeqNum::new(2)),
            wnd: UnscaledWindowSize::from(3u16),
            control: Some(Control::SYN),
            push: false,
            options: HandshakeOptions {
                mss: Some(Mss(NonZeroU16::new(1024).unwrap())),
                window_scale: Some(WindowScale::new(10).unwrap()),
                sack_permitted: true,
            }
            .into(),
        };

        assert_eq!(converted_header, expected_header);
    }

    // SACK blocks on a non-SYN builder become segment options, and the PSH
    // bit is preserved.
    #[ip_test(I)]
    fn from_segment_builder_with_options_segment<I: TestIpExt>() {
        let mut builder =
            TcpSegmentBuilder::new(I::SRC_IP, I::DST_IP, SRC_PORT, DST_PORT, 1, Some(2), 3);
        builder.psh(true);

        let sack_blocks = [TcpSackBlock::new(1, 2), TcpSackBlock::new(4, 6)];
        let builder =
            TcpSegmentBuilderWithOptions::new(builder, [TcpOption::Sack(&sack_blocks[..])])
                .expect("failed to create tcp segment builder");

        let converted_header =
            SegmentHeader::try_from(&builder).expect("failed to convert serializer");

        let expected_header = SegmentHeader {
            seq: SeqNum::new(1),
            ack: Some(SeqNum::new(2)),
            wnd: UnscaledWindowSize::from(3u16),
            control: None,
            push: true,
            options: SegmentOptions {
                sack_blocks: SackBlocks::from_iter([
                    SackBlock::try_new(SeqNum::new(1), SeqNum::new(2)).unwrap(),
                    SackBlock::try_new(SeqNum::new(4), SeqNum::new(6)).unwrap(),
                ]),
            }
            .into(),
        };

        assert_eq!(converted_header, expected_header);
    }

    // Flag validation also applies to builders that carry options.
    #[ip_test(I)]
    fn from_segment_builder_with_options_failure<I: TestIpExt>() {
        let mut builder =
            TcpSegmentBuilder::new(I::SRC_IP, I::DST_IP, SRC_PORT, DST_PORT, 1, Some(2), 3);
        builder.syn(true);
        builder.fin(true);

        let builder = TcpSegmentBuilderWithOptions::new(
            builder,
            [TcpOption::Mss(1024), TcpOption::WindowScale(10)],
        )
        .expect("failed to create tcp segment builder");

        assert_matches!(
            SegmentHeader::try_from(&builder),
            Err(MalformedFlags { syn: true, fin: true, rst: false })
        );
    }

    // Every SYN/FIN/RST combination: at most one may be set, and any
    // combination of two or more is rejected with the raw bits echoed back.
    #[test_case(Flags {
            syn: false,
            fin: false,
            rst: false,
        } => Ok(None))]
    #[test_case(Flags {
            syn: true,
            fin: false,
            rst: false,
        } => Ok(Some(Control::SYN)))]
    #[test_case(Flags {
            syn: false,
            fin: true,
            rst: false,
        } => Ok(Some(Control::FIN)))]
    #[test_case(Flags {
            syn: false,
            fin: false,
            rst: true,
        } => Ok(Some(Control::RST)))]
    #[test_case(Flags {
            syn: true,
            fin: true,
            rst: false,
        } => Err(MalformedFlags {
            syn: true,
            fin: true,
            rst: false,
        }))]
    #[test_case(Flags {
            syn: true,
            fin: false,
            rst: true,
        } => Err(MalformedFlags {
            syn: true,
            fin: false,
            rst: true,
        }))]
    #[test_case(Flags {
            syn: false,
            fin: true,
            rst: true,
        } => Err(MalformedFlags {
            syn: false,
            fin: true,
            rst: true,
        }))]
    #[test_case(Flags {
            syn: true,
            fin: true,
            rst: true,
        } => Err(MalformedFlags {
            syn: true,
            fin: true,
            rst: true,
        }))]
    fn flags_to_control(input: Flags) -> Result<Option<Control>, MalformedFlags> {
        input.control()
    }

    // A SACK block's left edge must come strictly before its right edge in
    // sequence space. A block spanning the u32 wraparound is accepted, while
    // empty, reversed, and (0, u32::MAX) blocks are rejected.
    #[test]
    fn sack_block_try_new() {
        assert_matches!(SackBlock::try_new(SeqNum::new(1), SeqNum::new(2)), Ok(_));
        assert_matches!(
            SackBlock::try_new(SeqNum::new(0u32.wrapping_sub(1)), SeqNum::new(2)),
            Ok(_)
        );
        assert_eq!(
            SackBlock::try_new(SeqNum::new(1), SeqNum::new(1)),
            Err(InvalidSackBlockError(SeqNum::new(1), SeqNum::new(1)))
        );
        assert_eq!(
            SackBlock::try_new(SeqNum::new(2), SeqNum::new(1)),
            Err(InvalidSackBlockError(SeqNum::new(2), SeqNum::new(1)))
        );
        assert_eq!(
            SackBlock::try_new(SeqNum::new(0), SeqNum::new(0u32.wrapping_sub(1))),
            Err(InvalidSackBlockError(SeqNum::new(0), SeqNum::new(0u32.wrapping_sub(1))))
        );
    }

    // `Segment::new` clears PSH on a segment with an empty payload and keeps
    // it on a segment that carries data.
    #[test]
    fn psh_bit_cleared_if_no_data() {
        let seg =
            Segment::new_assert_no_discard(SegmentHeader { push: true, ..Default::default() }, ());
        assert_eq!(seg.header().push, false);
        let seg = Segment::new_assert_no_discard(
            SegmentHeader { push: true, ..Default::default() },
            &[1u8, 2, 3, 4][..],
        );
        assert_eq!(seg.header().push, true);
    }
}