netstack3_base/tcp/
base.rs

1// Copyright 2024 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5//! The Transmission Control Protocol (TCP).
6
7use core::iter::FromIterator;
8use core::ops::Range;
9
10use alloc::vec::Vec;
11use core::mem::MaybeUninit;
12use core::num::NonZeroU16;
13use net_types::ip::{Ip, IpVersion};
14use packet::InnerPacketBuilder;
15use static_assertions::const_assert;
16
17use crate::ip::Mms;
18use crate::tcp::segment::{Payload, PayloadLen, SegmentOptions};
19
/// Control flags that can alter the state of a TCP control block.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Control {
    /// The SYN bit of a TCP segment.
    SYN,
    /// The FIN bit of a TCP segment.
    FIN,
    /// The RST bit of a TCP segment.
    RST,
}

impl Control {
    /// Reports whether this control flag occupies one byte of the sequence
    /// number space.
    ///
    /// SYN and FIN each consume a sequence number; RST does not.
    pub fn has_sequence_no(self) -> bool {
        match self {
            Self::RST => false,
            Self::SYN | Self::FIN => true,
        }
    }
}
41
42const TCP_HEADER_LEN: u32 = packet_formats::tcp::HDR_PREFIX_LEN as u32;
43
44/// Maximum segment size, that is the maximum TCP payload one segment can carry.
45///
46/// `Mss` also acts as a witness that the contained value is >= `Mss::MIN`.
47#[derive(Clone, Copy, PartialEq, Eq, Debug, PartialOrd, Ord)]
48pub struct Mss(u16);
49
50const_assert!(Mss::MIN.get() <= Mss::DEFAULT_IPV4.get());
51const_assert!(Mss::MIN.get() <= Mss::DEFAULT_IPV6.get());
52const_assert!(Mss::MIN.get() as usize >= packet_formats::tcp::MAX_OPTIONS_LEN);
53
54impl Mss {
55    /// The minimum MSS allowed by TCP.
56    ///
57    /// Although enforcing a minimum MSS is outside the recommendations of any
58    /// RFC, it is a common practice on other platforms and has multiple
59    /// benefits:
60    ///   1) Ensures there is enough space to transmit TCP Options & IP Options.
61    ///      See RFC 6691 section 2, which clarifies that
62    ///          The TCP MSS OPTION [...] SHOULD NOT be decreased to account for
63    ///          any possible IP or TCP options; conversely, the sender MUST
64    ///          reduce the TCP data length to account for any IP or TCP options
65    ///          that it is including in the packets that it sends.
66    ///   2) Protects against DOS attacks in which the attacker initiates TCP
67    ///      connections with an intentionally small MSS to incur additional
68    ///      packet processing overhead on the victim. See
69    ///      * FreeBSD: https://www.cve.org/CVERecord?id=CVE-2004-0002
70    ///      * Linux: https://www.cve.org/CVERecord?id=CVE-2019-11479
71    ///
72    /// Here, the value 216 is inspired by FreeBSD. It's large enough to satisfy
73    /// points 1 & 2 from above, while remaining small enough to support all
74    /// link-layer technologies on the open Internet.
75    pub const MIN: Mss = Mss(216);
76
77    /// Per RFC 9293 Section 3.7.1:
78    ///  If an MSS Option is not received at connection setup, TCP
79    ///  implementations MUST assume a default send MSS of 536 (576 - 40) for
80    ///  IPv4.
81    pub const DEFAULT_IPV4: Mss = Mss(536);
82
83    /// Per RFC 9293 Section 3.7.1:
84    ///  If an MSS Option is not received at connection setup, TCP
85    ///  implementations MUST assume a default send MSS of [...] 1220
86    /// (1280 - 60) for IPv6 (MUST-15).
87    pub const DEFAULT_IPV6: Mss = Mss(1220);
88
89    /// Creates `Mss`, provided the given value satisfies the requirements.
90    pub const fn new(mss: u16) -> Option<Self> {
91        if mss < Self::MIN.get() { None } else { Some(Mss(mss)) }
92    }
93
94    /// Creates MSS from the maximum message size of the IP layer.
95    pub fn from_mms(mms: Mms) -> Option<Self> {
96        let mss = u16::try_from(mms.get().get().saturating_sub(TCP_HEADER_LEN)).unwrap_or(u16::MAX);
97        Self::new(mss)
98    }
99
100    /// Create a new [`Mss`] with the IP-version default value, as defined by RFC 9293.
101    pub const fn default<I: Ip>() -> Self {
102        match I::VERSION {
103            IpVersion::V4 => Self::DEFAULT_IPV4,
104            IpVersion::V6 => Self::DEFAULT_IPV6,
105        }
106    }
107
108    /// Gets the numeric value of the MSS.
109    pub const fn get(&self) -> u16 {
110        let Self(mss) = *self;
111        mss
112    }
113}
114
115/// Like [`Mss`], but smaller to account for fixed-size TCP Options.
116///
117/// This corresponds to the "effective send MSS" as defined in RFC 9293 section
118/// 3.7.1:
119///   Eff.snd.MSS = min(SendMSS+20, MMS_S) - TCPhdrsize - IPoptionsize
120///   where:
121///     [...]
122///     * TCPhdrsize is the size of the fixed TCP header and any options.
123///
124/// Both [`Mss`] and [`EffectiveMss`] have their place in TCP. For example,
125/// the TCP MSS option has [`Mss`] semantics, while the MSS used to calculate
126/// receive windows & congestion windows has [`EffectiveMss`] semantics. When
127/// implementing a TCP feature, refer to the feature's RFC to determine which
128/// MSS semantics are appropriate to use.
129///
130/// Note: this implementation accounts for all fixed-sized TCP Options that are
131/// part of [`SegmentOptions`]. SACK blocks are ignored, because they are
132/// variable sized. Variable sized options pose a problem when calculating the
133/// [`EffectiveMss`] because they vary from segment to segment, whereas the
134/// [`EffectiveMss`] should be stable throughout the lifetime of the connection.
135/// While, no RFC explicitly states how to account for variable sized options,
136/// we take inspiration from Linux's TCP implementation and choose to ignore
137/// them until it comes time to actually calculate payload sizes for a given
138/// segment.
139// TODO(https://fxbug.dev/441271979): Account for fixed-size IP Options.
140#[derive(Clone, Copy, PartialEq, Eq, Debug)]
141pub struct EffectiveMss {
142    mss: Mss,
143    fixed_tcp_options_size: u16,
144}
145
146impl EffectiveMss {
147    /// Per RFC 7323 Section 3.2, the TCP Timestamp option has a length of
148    /// 10 bytes:
149    ///   +-------+-------+---------------------+---------------------+
150    ///   |Kind=8 |  10   |   TS Value (TSval)  |TS Echo Reply (TSecr)|
151    ///   +-------+-------+---------------------+---------------------+
152    ///      1       1              4                     4
153    ///
154    /// However, once aligned, it will occupy 12 bytes.
155    const ALIGNED_TIMESTAMP_OPTION_LENGTH: u16 = 12;
156
157    /// Constructs an [`EffectiveMss`] from an [`Mss`]
158    pub const fn from_mss(mss: Mss, size_limits: MssSizeLimiters) -> Self {
159        let MssSizeLimiters { timestamp_enabled } = size_limits;
160        // NB: When adding additional fixed size options in the future, authors
161        // should take care to account for the alignment only once.
162        let fixed_tcp_options_size =
163            if timestamp_enabled { Self::ALIGNED_TIMESTAMP_OPTION_LENGTH } else { 0 };
164        EffectiveMss { mss, fixed_tcp_options_size }
165    }
166
167    /// Computes the amount of payload data to include in a segment.
168    ///
169    /// Accounts for the size of any variable-sized options present in the
170    /// segment.
171    pub fn payload_size(&self, options: &SegmentOptions) -> NonZeroU16 {
172        // NB: Ignore the fixed TCP options size, it will be accounted for by
173        // `options`.
174        let Self { mss, fixed_tcp_options_size: _ } = self;
175        // NB: Safe to unwrap here because TCP options have a fixed maximum
176        // size < u16::MAX.
177        let tcp_options_len =
178            u16::try_from(packet_formats::tcp::aligned_options_length(options.iter())).unwrap();
179        // NB: Safe to unwrap here because MSS has a minimum value large enough
180        // to fit all TCP options.
181        NonZeroU16::new(mss.get() - tcp_options_len).unwrap()
182    }
183
184    /// Returns the original [`Mss`] used to compute this [`EffectiveMss`].
185    pub fn mss(&self) -> &Mss {
186        &self.mss
187    }
188
189    /// Replaces the held [`Mss`] with a new value.
190    pub fn update_mss(&mut self, new: Mss) {
191        self.mss = new
192    }
193
194    /// Gets the numeric value of the MSS.
195    pub const fn get(&self) -> u16 {
196        let Self { mss, fixed_tcp_options_size } = *self;
197        mss.get() - fixed_tcp_options_size
198    }
199}
200
/// Factors that may limit the space available from the MSS.
///
/// Passed to [`EffectiveMss::from_mss`] to decide which fixed-size TCP
/// options must be deducted from the MSS.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct MssSizeLimiters {
    /// True if the TCP Timestamp Option is enabled.
    pub timestamp_enabled: bool,
}
206
207impl From<EffectiveMss> for u32 {
208    fn from(mss: EffectiveMss) -> Self {
209        u32::from(mss.get())
210    }
211}
212
213impl From<EffectiveMss> for usize {
214    fn from(mss: EffectiveMss) -> Self {
215        usize::from(mss.get())
216    }
217}
218
/// An implementation of [`Payload`] backed by up to `N` byte slices.
///
/// Only the slices in `storage[start..end]` make up the payload; entries
/// outside that window are ignored when reading the payload.
///
/// Note that the derived `PartialEq` compares the entire `storage` array as
/// well as `start` and `end`, not just the visible bytes.
#[derive(Copy, Clone, Debug, PartialEq)]
pub struct FragmentedPayload<'a, const N: usize> {
    // Backing storage for the payload's slices.
    storage: [&'a [u8]; N],
    // NB: Not using `Range` because it is not `Copy`.
    //
    // Start is inclusive, end is exclusive; so this is equivalent to
    // `start..end` ranges.
    start: usize,
    end: usize,
}
230
231/// Creates a new `FragmentedPayload` possibly without using the entire
232/// storage capacity `N`.
233///
234/// # Panics
235///
236/// Panics if the iterator contains more than `N` items.
237impl<'a, const N: usize> FromIterator<&'a [u8]> for FragmentedPayload<'a, N> {
238    fn from_iter<T>(iter: T) -> Self
239    where
240        T: IntoIterator<Item = &'a [u8]>,
241    {
242        let Self { storage, start, end } = Self::new_empty();
243        let (storage, end) = iter.into_iter().fold((storage, end), |(mut storage, end), sl| {
244            storage[end] = sl;
245            (storage, end + 1)
246        });
247        Self { storage, start, end }
248    }
249}
250
251impl<'a, const N: usize> FragmentedPayload<'a, N> {
252    /// Creates a new `FragmentedPayload` with the slices in `values`.
253    pub fn new(values: [&'a [u8]; N]) -> Self {
254        Self { storage: values, start: 0, end: N }
255    }
256
257    /// Creates a new `FragmentedPayload` with a single contiguous slice.
258    pub fn new_contiguous(value: &'a [u8]) -> Self {
259        core::iter::once(value).collect()
260    }
261
262    /// Converts this [`FragmentedPayload`] into an owned `Vec`.
263    pub fn to_vec(self) -> Vec<u8> {
264        self.slices().concat()
265    }
266
267    fn slices(&self) -> &[&'a [u8]] {
268        let Self { storage, start, end } = self;
269        &storage[*start..*end]
270    }
271
272    /// Extracted function to implement [`Payload::partial_copy`] and
273    /// [`Payload::partial_copy_uninit`].
274    fn apply_copy<T, F: Fn(&[u8], &mut [T])>(
275        &self,
276        mut offset: usize,
277        mut dst: &mut [T],
278        apply: F,
279    ) {
280        let mut slices = self.slices().into_iter();
281        while let Some(sl) = slices.next() {
282            let l = sl.len();
283            if offset >= l {
284                offset -= l;
285                continue;
286            }
287            let sl = &sl[offset..];
288            let cp = sl.len().min(dst.len());
289            let (target, new_dst) = dst.split_at_mut(cp);
290            apply(&sl[..cp], target);
291
292            // We're done.
293            if new_dst.len() == 0 {
294                return;
295            }
296
297            dst = new_dst;
298            offset = 0;
299        }
300        assert_eq!(dst.len(), 0, "failed to fill dst");
301    }
302}
303
304impl<'a, const N: usize> PayloadLen for FragmentedPayload<'a, N> {
305    fn len(&self) -> usize {
306        self.slices().into_iter().map(|s| s.len()).sum()
307    }
308}
309
310impl<'a, const N: usize> Payload for FragmentedPayload<'a, N> {
311    fn slice(self, byte_range: Range<u32>) -> Self {
312        let Self { mut storage, start: mut self_start, end: mut self_end } = self;
313        let Range { start: byte_start, end: byte_end } = byte_range;
314        let byte_start =
315            usize::try_from(byte_start).expect("range start index out of range for usize");
316        let byte_end = usize::try_from(byte_end).expect("range end index out of range for usize");
317        assert!(byte_end >= byte_start);
318        let mut storage_iter =
319            (&mut storage[self_start..self_end]).into_iter().scan(0, |total_len, slice| {
320                let slice_len = slice.len();
321                let item = Some((*total_len, slice));
322                *total_len += slice_len;
323                item
324            });
325
326        // Keep track of whether the start was inside the range, we should panic
327        // even on an empty range out of start bounds.
328        let mut start_offset = None;
329        let mut final_len = 0;
330        while let Some((sl_offset, sl)) = storage_iter.next() {
331            let orig_len = sl.len();
332
333            // Advance until the start of the specified range, discarding unused
334            // slices.
335            if sl_offset + orig_len < byte_start {
336                *sl = &[];
337                self_start += 1;
338                continue;
339            }
340            // Discard any empty slices at the end.
341            if sl_offset >= byte_end {
342                *sl = &[];
343                self_end -= 1;
344                continue;
345            }
346
347            let sl_start = byte_start.saturating_sub(sl_offset);
348            let sl_end = sl.len().min(byte_end - sl_offset);
349            *sl = &sl[sl_start..sl_end];
350
351            match start_offset {
352                Some(_) => (),
353                None => {
354                    // Keep track of the start offset of the first slice.
355                    start_offset = Some(sl_offset + sl_start);
356                    // Avoid producing an empty slice if we haven't added
357                    // anything yet.
358                    if sl.len() == 0 {
359                        self_start += 1;
360                    }
361                }
362            }
363            final_len += sl.len();
364        }
365        // Verify that the entire range was consumed.
366        assert_eq!(
367            // If we didn't use start_offset the only valid value for
368            // `byte_start` is zero.
369            start_offset.unwrap_or(0),
370            byte_start,
371            "range start index out of range {byte_range:?}"
372        );
373        assert_eq!(byte_start + final_len, byte_end, "range end index out of range {byte_range:?}");
374
375        // Canonicalize an empty payload.
376        if self_start == self_end {
377            self_start = 0;
378            self_end = 0;
379        }
380        Self { storage, start: self_start, end: self_end }
381    }
382
383    fn new_empty() -> Self {
384        Self { storage: [&[]; N], start: 0, end: 0 }
385    }
386
387    fn partial_copy(&self, offset: usize, dst: &mut [u8]) {
388        self.apply_copy(offset, dst, |src, dst| {
389            dst.copy_from_slice(src);
390        });
391    }
392
393    fn partial_copy_uninit(&self, offset: usize, dst: &mut [MaybeUninit<u8>]) {
394        self.apply_copy(offset, dst, |src, dst| {
395            // TODO(https://github.com/rust-lang/rust/issues/79995): Replace unsafe
396            // with copy_from_slice when stabiliized.
397            // SAFETY: &[T] and &[MaybeUninit<T>] have the same layout.
398            let uninit_src: &[MaybeUninit<u8>] = unsafe { core::mem::transmute(src) };
399            dst.copy_from_slice(&uninit_src);
400        });
401    }
402}
403
404impl<'a, const N: usize> InnerPacketBuilder for FragmentedPayload<'a, N> {
405    fn bytes_len(&self) -> usize {
406        self.len()
407    }
408
409    fn serialize(&self, buffer: &mut [u8]) {
410        self.partial_copy(0, buffer);
411    }
412}
413
#[cfg(any(test, feature = "testutils"))]
mod testutil {
    use super::*;

    /// Exposes the raw MSS value as a `u32` in test builds.
    impl From<Mss> for u32 {
        fn from(mss: Mss) -> Self {
            mss.get().into()
        }
    }

    /// Exposes the raw MSS value as a `usize` in test builds.
    impl From<Mss> for usize {
        fn from(mss: Mss) -> Self {
            mss.get().into()
        }
    }
}
430
#[cfg(test)]
mod test {
    use super::*;
    use alloc::format;

    use packet::Serializer as _;
    use proptest::test_runner::Config;
    use proptest::{prop_assert_eq, proptest};
    use proptest_support::failed_seeds_no_std;
    use test_case::test_case;

    use crate::{SackBlock, SackBlocks, SeqNum, Timestamp, TimestampOption};

    const EXAMPLE_DATA: [u8; 10] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9];
    #[test_case(FragmentedPayload::new([&EXAMPLE_DATA[..]]); "contiguous")]
    #[test_case(FragmentedPayload::new([&EXAMPLE_DATA[0..2], &EXAMPLE_DATA[2..]]); "split once")]
    #[test_case(FragmentedPayload::new([
        &EXAMPLE_DATA[0..2],
        &EXAMPLE_DATA[2..5],
        &EXAMPLE_DATA[5..],
    ]); "split twice")]
    #[test_case(FragmentedPayload::<4>::from_iter([
        &EXAMPLE_DATA[0..2],
        &EXAMPLE_DATA[2..5],
        &EXAMPLE_DATA[5..],
    ]); "partial twice")]
    fn fragmented_payload_serializer_data<const N: usize>(payload: FragmentedPayload<'_, N>) {
        let serialized = payload
            .into_serializer()
            .serialize_vec_outer()
            .expect("should serialize")
            .unwrap_b()
            .into_inner();
        assert_eq!(&serialized[..], EXAMPLE_DATA);
    }

    #[test]
    #[should_panic(expected = "range start index out of range")]
    fn slice_start_out_of_bounds() {
        let len = u32::try_from(EXAMPLE_DATA.len()).unwrap();
        let bad_len = len + 1;
        // Like for standard slices, this shouldn't succeed if the start length
        // is out of bounds, even if the total range is empty.
        let _ = FragmentedPayload::<2>::new_contiguous(&EXAMPLE_DATA).slice(bad_len..bad_len);
    }

    #[test]
    #[should_panic(expected = "range end index out of range")]
    fn slice_end_out_of_bounds() {
        let len = u32::try_from(EXAMPLE_DATA.len()).unwrap();
        let bad_len = len + 1;
        let _ = FragmentedPayload::<2>::new_contiguous(&EXAMPLE_DATA).slice(0..bad_len);
    }

    #[test]
    fn canon_empty_payload() {
        let len = u32::try_from(EXAMPLE_DATA.len()).unwrap();
        assert_eq!(
            FragmentedPayload::<1>::new_contiguous(&EXAMPLE_DATA).slice(len..len),
            FragmentedPayload::new_empty()
        );
        assert_eq!(
            FragmentedPayload::<2>::new_contiguous(&EXAMPLE_DATA).slice(len..len),
            FragmentedPayload::new_empty()
        );
        assert_eq!(
            FragmentedPayload::<2>::new_contiguous(&EXAMPLE_DATA).slice(2..2),
            FragmentedPayload::new_empty()
        );
    }

    const TEST_BYTES: &'static [u8] = b"Hello World!";
    proptest! {
        #![proptest_config(Config {
            // Add all failed seeds here.
            failure_persistence: failed_seeds_no_std!(),
            ..Config::default()
        })]

        #[test]
        fn fragmented_payload_to_vec(payload in fragmented_payload::with_payload()) {
            prop_assert_eq!(payload.to_vec(), &TEST_BYTES[..]);
        }

        #[test]
        fn fragmented_payload_len(payload in fragmented_payload::with_payload()) {
            prop_assert_eq!(payload.len(), TEST_BYTES.len())
        }

        #[test]
        fn fragmented_payload_slice((payload, (start, end)) in fragmented_payload::with_range()) {
            let want = &TEST_BYTES[start..end];
            let start = u32::try_from(start).unwrap();
            let end = u32::try_from(end).unwrap();
            prop_assert_eq!(payload.clone().slice(start..end).to_vec(), want);
        }

        #[test]
        fn fragmented_payload_partial_copy((payload, (start, end)) in fragmented_payload::with_range()) {
            let mut buffer = [0; TEST_BYTES.len()];
            let buffer = &mut buffer[0..(end-start)];
            payload.partial_copy(start, buffer);
            prop_assert_eq!(buffer, &TEST_BYTES[start..end]);
        }
    }

    mod fragmented_payload {
        use super::*;

        use proptest::strategy::{Just, Strategy};
        use rand::Rng as _;

        const TEST_STORAGE: usize = 5;
        type TestFragmentedPayload = FragmentedPayload<'static, TEST_STORAGE>;

        /// Generates payloads holding `TEST_BYTES` split into between 1 and
        /// `TEST_STORAGE` slices of random (possibly zero) length.
        pub(super) fn with_payload() -> impl Strategy<Value = TestFragmentedPayload> {
            (1..=TEST_STORAGE).prop_perturb(|slices, mut rng| {
                (0..slices)
                    .scan(0, |st, slice| {
                        let len = if slice == slices - 1 {
                            TEST_BYTES.len() - *st
                        } else {
                            rng.random_range(0..=(TEST_BYTES.len() - *st))
                        };
                        let start = *st;
                        *st += len;
                        Some(&TEST_BYTES[start..*st])
                    })
                    .collect()
            })
        }

        /// Generates a payload plus a `(start, end)` byte range satisfying
        /// `start <= end <= TEST_BYTES.len()`.
        ///
        /// The bounds are inclusive so that ranges ending at the end of the
        /// payload (including the empty `len..len` range) are exercised.
        pub(super) fn with_range() -> impl Strategy<Value = (TestFragmentedPayload, (usize, usize))>
        {
            (
                with_payload(),
                (0..=TEST_BYTES.len())
                    .prop_flat_map(|start| (Just(start), start..=TEST_BYTES.len())),
            )
        }
    }

    #[test_case(true; "timestamp_enabled")]
    #[test_case(false; "timestamp_disabled")]
    fn effective_mss_accounts_for_fixed_size_tcp_options(timestamp_enabled: bool) {
        const SIZE: u16 = 1000;
        let mss =
            EffectiveMss::from_mss(Mss::new(SIZE).unwrap(), MssSizeLimiters { timestamp_enabled });
        if timestamp_enabled {
            assert_eq!(mss.get(), SIZE - EffectiveMss::ALIGNED_TIMESTAMP_OPTION_LENGTH)
        } else {
            assert_eq!(mss.get(), SIZE);
        }
    }

    #[test_case(SegmentOptions {sack_blocks: SackBlocks::EMPTY, timestamp: None}; "empty")]
    #[test_case(SegmentOptions {
        sack_blocks: SackBlocks::from_iter([
            SackBlock::try_new(SeqNum::new(1), SeqNum::new(2)).unwrap(),
            SackBlock::try_new(SeqNum::new(4), SeqNum::new(6)).unwrap(),
        ]),
        timestamp: None
    }; "sack_blocks")]
    #[test_case(SegmentOptions {
        sack_blocks: SackBlocks::EMPTY,
        timestamp: Some(TimestampOption {
            ts_val: Timestamp::new(12345), ts_echo_reply: Timestamp::new(54321)
        }),
    }; "timestamp")]
    #[test_case(SegmentOptions {
        sack_blocks: SackBlocks::from_iter([
            SackBlock::try_new(SeqNum::new(1), SeqNum::new(2)).unwrap(),
            SackBlock::try_new(SeqNum::new(4), SeqNum::new(6)).unwrap(),
        ]),
        timestamp: Some(TimestampOption {
            ts_val: Timestamp::new(12345), ts_echo_reply: Timestamp::new(54321)
        }),
    }; "sack_blocks_and_timestamp")]
    fn effective_mss_accounts_for_variable_size_tcp_options(options: SegmentOptions) {
        const SIZE: u16 = 1000;
        let timestamp_enabled = options.timestamp.is_some();
        let mss =
            EffectiveMss::from_mss(Mss::new(SIZE).unwrap(), MssSizeLimiters { timestamp_enabled });
        let options_len =
            u16::try_from(packet_formats::tcp::aligned_options_length(options.iter())).unwrap();
        assert_eq!(mss.payload_size(&options).get(), SIZE - options_len);
    }
}