1use core::iter::FromIterator;
8use core::ops::Range;
9
10use alloc::vec::Vec;
11use core::mem::MaybeUninit;
12use core::num::NonZeroU16;
13use net_types::ip::{Ip, IpVersion};
14use packet::InnerPacketBuilder;
15use static_assertions::const_assert;
16
17use crate::ip::Mms;
18use crate::tcp::segment::{Payload, PayloadLen, SegmentOptions};
19
/// The exceptional TCP control flags a segment may carry.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Control {
    /// Synchronize sequence numbers (connection establishment).
    SYN,
    /// No more data from the sender (connection teardown).
    FIN,
    /// Reset the connection.
    RST,
}

impl Control {
    /// Returns whether this control flag consumes one unit of sequence
    /// number space: SYN and FIN do, RST does not.
    pub fn has_sequence_no(self) -> bool {
        !matches!(self, Control::RST)
    }
}
41
42const TCP_HEADER_LEN: u32 = packet_formats::tcp::HDR_PREFIX_LEN as u32;
43
/// A TCP maximum segment size (MSS): the largest amount of data, in bytes,
/// that can be carried in a single segment (excluding the TCP header).
#[derive(Clone, Copy, PartialEq, Eq, Debug, PartialOrd, Ord)]
pub struct Mss(u16);

// Compile-time sanity checks: the minimum MSS must not exceed either
// protocol default, and must be large enough that a maximally-sized
// options block still fits.
const_assert!(Mss::MIN.get() <= Mss::DEFAULT_IPV4.get());
const_assert!(Mss::MIN.get() <= Mss::DEFAULT_IPV6.get());
const_assert!(Mss::MIN.get() as usize >= packet_formats::tcp::MAX_OPTIONS_LEN);
53
54impl Mss {
55 pub const MIN: Mss = Mss(216);
76
77 pub const DEFAULT_IPV4: Mss = Mss(536);
82
83 pub const DEFAULT_IPV6: Mss = Mss(1220);
88
89 pub const fn new(mss: u16) -> Option<Self> {
91 if mss < Self::MIN.get() { None } else { Some(Mss(mss)) }
92 }
93
94 pub fn from_mms(mms: Mms) -> Option<Self> {
96 let mss = u16::try_from(mms.get().get().saturating_sub(TCP_HEADER_LEN)).unwrap_or(u16::MAX);
97 Self::new(mss)
98 }
99
100 pub const fn default<I: Ip>() -> Self {
102 match I::VERSION {
103 IpVersion::V4 => Self::DEFAULT_IPV4,
104 IpVersion::V6 => Self::DEFAULT_IPV6,
105 }
106 }
107
108 pub const fn get(&self) -> u16 {
110 let Self(mss) = *self;
111 mss
112 }
113}
114
/// An MSS paired with the per-segment overhead of fixed-size TCP options.
///
/// `get()` reports `mss` minus the fixed options overhead, while
/// `payload_size` subtracts the actual (aligned) length of a segment's
/// options instead.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub struct EffectiveMss {
    // The negotiated maximum segment size.
    mss: Mss,
    // Bytes reserved in every segment for fixed-size TCP options; set from
    // `MssSizeLimiters` in `from_mss` (currently only the timestamp option).
    fixed_tcp_options_size: u16,
}
145
146impl EffectiveMss {
147 const ALIGNED_TIMESTAMP_OPTION_LENGTH: u16 = 12;
156
157 pub const fn from_mss(mss: Mss, size_limits: MssSizeLimiters) -> Self {
159 let MssSizeLimiters { timestamp_enabled } = size_limits;
160 let fixed_tcp_options_size =
163 if timestamp_enabled { Self::ALIGNED_TIMESTAMP_OPTION_LENGTH } else { 0 };
164 EffectiveMss { mss, fixed_tcp_options_size }
165 }
166
167 pub fn payload_size(&self, options: &SegmentOptions) -> NonZeroU16 {
172 let Self { mss, fixed_tcp_options_size: _ } = self;
175 let tcp_options_len =
178 u16::try_from(packet_formats::tcp::aligned_options_length(options.iter())).unwrap();
179 NonZeroU16::new(mss.get() - tcp_options_len).unwrap()
182 }
183
184 pub fn mss(&self) -> &Mss {
186 &self.mss
187 }
188
189 pub fn update_mss(&mut self, new: Mss) {
191 self.mss = new
192 }
193
194 pub const fn get(&self) -> u16 {
196 let Self { mss, fixed_tcp_options_size } = *self;
197 mss.get() - fixed_tcp_options_size
198 }
199}
200
/// Inputs that limit the usable size of an [`EffectiveMss`].
pub struct MssSizeLimiters {
    // Whether the TCP timestamp option is enabled on the connection, which
    // reserves a fixed amount of options space in every segment.
    pub timestamp_enabled: bool,
}
206
207impl From<EffectiveMss> for u32 {
208 fn from(mss: EffectiveMss) -> Self {
209 u32::from(mss.get())
210 }
211}
212
213impl From<EffectiveMss> for usize {
214 fn from(mss: EffectiveMss) -> Self {
215 usize::from(mss.get())
216 }
217}
218
/// A TCP payload split across up to `N` borrowed byte slices.
#[derive(Copy, Clone, Debug, PartialEq)]
pub struct FragmentedPayload<'a, const N: usize> {
    // Backing array of fragments; only `storage[start..end]` is occupied.
    storage: [&'a [u8]; N],
    // Index of the first occupied slot in `storage`.
    start: usize,
    // One past the last occupied slot in `storage`.
    end: usize,
}
230
231impl<'a, const N: usize> FromIterator<&'a [u8]> for FragmentedPayload<'a, N> {
238 fn from_iter<T>(iter: T) -> Self
239 where
240 T: IntoIterator<Item = &'a [u8]>,
241 {
242 let Self { storage, start, end } = Self::new_empty();
243 let (storage, end) = iter.into_iter().fold((storage, end), |(mut storage, end), sl| {
244 storage[end] = sl;
245 (storage, end + 1)
246 });
247 Self { storage, start, end }
248 }
249}
250
251impl<'a, const N: usize> FragmentedPayload<'a, N> {
252 pub fn new(values: [&'a [u8]; N]) -> Self {
254 Self { storage: values, start: 0, end: N }
255 }
256
257 pub fn new_contiguous(value: &'a [u8]) -> Self {
259 core::iter::once(value).collect()
260 }
261
262 pub fn to_vec(self) -> Vec<u8> {
264 self.slices().concat()
265 }
266
267 fn slices(&self) -> &[&'a [u8]] {
268 let Self { storage, start, end } = self;
269 &storage[*start..*end]
270 }
271
272 fn apply_copy<T, F: Fn(&[u8], &mut [T])>(
275 &self,
276 mut offset: usize,
277 mut dst: &mut [T],
278 apply: F,
279 ) {
280 let mut slices = self.slices().into_iter();
281 while let Some(sl) = slices.next() {
282 let l = sl.len();
283 if offset >= l {
284 offset -= l;
285 continue;
286 }
287 let sl = &sl[offset..];
288 let cp = sl.len().min(dst.len());
289 let (target, new_dst) = dst.split_at_mut(cp);
290 apply(&sl[..cp], target);
291
292 if new_dst.len() == 0 {
294 return;
295 }
296
297 dst = new_dst;
298 offset = 0;
299 }
300 assert_eq!(dst.len(), 0, "failed to fill dst");
301 }
302}
303
304impl<'a, const N: usize> PayloadLen for FragmentedPayload<'a, N> {
305 fn len(&self) -> usize {
306 self.slices().into_iter().map(|s| s.len()).sum()
307 }
308}
309
impl<'a, const N: usize> Payload for FragmentedPayload<'a, N> {
    /// Restricts the payload to `byte_range`, trimming fragments in place
    /// and narrowing the `start..end` storage window.
    ///
    /// Panics if `byte_range` is not fully contained within the payload.
    fn slice(self, byte_range: Range<u32>) -> Self {
        let Self { mut storage, start: mut self_start, end: mut self_end } = self;
        let Range { start: byte_start, end: byte_end } = byte_range;
        let byte_start =
            usize::try_from(byte_start).expect("range start index out of range for usize");
        let byte_end = usize::try_from(byte_end).expect("range end index out of range for usize");
        assert!(byte_end >= byte_start);
        // Pair each occupied fragment with its byte offset from the start
        // of the payload.
        let mut storage_iter =
            (&mut storage[self_start..self_end]).into_iter().scan(0, |total_len, slice| {
                let slice_len = slice.len();
                let item = Some((*total_len, slice));
                *total_len += slice_len;
                item
            });

        // Byte offset where the sliced payload actually begins; used to
        // validate `byte_start` after the loop.
        let mut start_offset = None;
        let mut final_len = 0;
        while let Some((sl_offset, sl)) = storage_iter.next() {
            let orig_len = sl.len();

            // Fragment ends strictly before the range start: drop it from
            // the front of the window. (A fragment ending exactly at
            // `byte_start` is kept so it can record `start_offset` below.)
            if sl_offset + orig_len < byte_start {
                *sl = &[];
                self_start += 1;
                continue;
            }
            // Fragment begins at or after the range end: drop it from the
            // back of the window.
            if sl_offset >= byte_end {
                *sl = &[];
                self_end -= 1;
                continue;
            }

            // Trim the fragment to its overlap with the requested range.
            let sl_start = byte_start.saturating_sub(sl_offset);
            let sl_end = sl.len().min(byte_end - sl_offset);
            *sl = &sl[sl_start..sl_end];

            match start_offset {
                Some(_) => (),
                None => {
                    // First overlapping fragment: record where the slice
                    // begins. If the overlap is empty, drop the fragment
                    // from the front as well.
                    start_offset = Some(sl_offset + sl_start);
                    if sl.len() == 0 {
                        self_start += 1;
                    }
                }
            }
            final_len += sl.len();
        }
        // Verify the requested range was fully covered by the payload.
        assert_eq!(
            start_offset.unwrap_or(0),
            byte_start,
            "range start index out of range {byte_range:?}"
        );
        assert_eq!(byte_start + final_len, byte_end, "range end index out of range {byte_range:?}");

        // Canonicalize the empty payload so `PartialEq` against
        // `new_empty()` holds regardless of where the window collapsed.
        if self_start == self_end {
            self_start = 0;
            self_end = 0;
        }
        Self { storage, start: self_start, end: self_end }
    }

    /// Returns a payload with no occupied fragments.
    fn new_empty() -> Self {
        Self { storage: [&[]; N], start: 0, end: 0 }
    }

    /// Copies `dst.len()` bytes starting at `offset` into `dst`.
    fn partial_copy(&self, offset: usize, dst: &mut [u8]) {
        self.apply_copy(offset, dst, |src, dst| {
            dst.copy_from_slice(src);
        });
    }

    /// Like `partial_copy`, but writes into uninitialized memory.
    fn partial_copy_uninit(&self, offset: usize, dst: &mut [MaybeUninit<u8>]) {
        self.apply_copy(offset, dst, |src, dst| {
            // SAFETY: `MaybeUninit<u8>` is guaranteed to have the same
            // size, alignment, and ABI as `u8`, so reinterpreting an
            // initialized `&[u8]` as `&[MaybeUninit<u8>]` is sound.
            let uninit_src: &[MaybeUninit<u8>] = unsafe { core::mem::transmute(src) };
            dst.copy_from_slice(&uninit_src);
        });
    }
}
403
impl<'a, const N: usize> InnerPacketBuilder for FragmentedPayload<'a, N> {
    /// The serialized length is the total payload length.
    fn bytes_len(&self) -> usize {
        self.len()
    }

    /// Serializes by copying every fragment into `buffer` from offset 0.
    fn serialize(&self, buffer: &mut [u8]) {
        self.partial_copy(0, buffer);
    }
}
413
#[cfg(any(test, feature = "testutils"))]
mod testutil {
    //! Test-only widening conversions for [`Mss`].
    use super::*;

    impl From<Mss> for u32 {
        fn from(Mss(mss): Mss) -> Self {
            u32::from(mss)
        }
    }

    impl From<Mss> for usize {
        fn from(Mss(mss): Mss) -> Self {
            usize::from(mss)
        }
    }
}
430
#[cfg(test)]
mod test {
    use super::*;
    use alloc::format;

    use packet::Serializer as _;
    use proptest::test_runner::Config;
    use proptest::{prop_assert_eq, proptest};
    use proptest_support::failed_seeds_no_std;
    use test_case::test_case;

    use crate::{SackBlock, SackBlocks, SeqNum, Timestamp, TimestampOption};

    const EXAMPLE_DATA: [u8; 10] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9];
    // Serializing must reassemble the original bytes regardless of how the
    // payload is fragmented.
    #[test_case(FragmentedPayload::new([&EXAMPLE_DATA[..]]); "contiguous")]
    #[test_case(FragmentedPayload::new([&EXAMPLE_DATA[0..2], &EXAMPLE_DATA[2..]]); "split once")]
    #[test_case(FragmentedPayload::new([
        &EXAMPLE_DATA[0..2],
        &EXAMPLE_DATA[2..5],
        &EXAMPLE_DATA[5..],
    ]); "split twice")]
    #[test_case(FragmentedPayload::<4>::from_iter([
        &EXAMPLE_DATA[0..2],
        &EXAMPLE_DATA[2..5],
        &EXAMPLE_DATA[5..],
    ]); "partial twice")]
    fn fragmented_payload_serializer_data<const N: usize>(payload: FragmentedPayload<'_, N>) {
        let serialized = payload
            .into_serializer()
            .serialize_vec_outer()
            .expect("should serialize")
            .unwrap_b()
            .into_inner();
        assert_eq!(&serialized[..], EXAMPLE_DATA);
    }

    #[test]
    #[should_panic(expected = "range start index out of range")]
    fn slice_start_out_of_bounds() {
        let len = u32::try_from(EXAMPLE_DATA.len()).unwrap();
        let bad_len = len + 1;
        let _ = FragmentedPayload::<2>::new_contiguous(&EXAMPLE_DATA).slice(bad_len..bad_len);
    }

    #[test]
    #[should_panic(expected = "range end index out of range")]
    fn slice_end_out_of_bounds() {
        let len = u32::try_from(EXAMPLE_DATA.len()).unwrap();
        let bad_len = len + 1;
        let _ = FragmentedPayload::<2>::new_contiguous(&EXAMPLE_DATA).slice(0..bad_len);
    }

    // Slicing down to nothing must produce the canonical empty payload so
    // equality with `new_empty()` holds.
    #[test]
    fn canon_empty_payload() {
        let len = u32::try_from(EXAMPLE_DATA.len()).unwrap();
        assert_eq!(
            FragmentedPayload::<1>::new_contiguous(&EXAMPLE_DATA).slice(len..len),
            FragmentedPayload::new_empty()
        );
        assert_eq!(
            FragmentedPayload::<2>::new_contiguous(&EXAMPLE_DATA).slice(len..len),
            FragmentedPayload::new_empty()
        );
        assert_eq!(
            FragmentedPayload::<2>::new_contiguous(&EXAMPLE_DATA).slice(2..2),
            FragmentedPayload::new_empty()
        );
    }

    const TEST_BYTES: &'static [u8] = b"Hello World!";
    proptest! {
        #![proptest_config(Config {
            failure_persistence: failed_seeds_no_std!(),
            ..Config::default()
        })]

        #[test]
        fn fragmented_payload_to_vec(payload in fragmented_payload::with_payload()) {
            prop_assert_eq!(payload.to_vec(), &TEST_BYTES[..]);
        }

        #[test]
        fn fragmented_payload_len(payload in fragmented_payload::with_payload()) {
            prop_assert_eq!(payload.len(), TEST_BYTES.len())
        }

        #[test]
        fn fragmented_payload_slice((payload, (start, end)) in fragmented_payload::with_range()) {
            let want = &TEST_BYTES[start..end];
            let start = u32::try_from(start).unwrap();
            let end = u32::try_from(end).unwrap();
            prop_assert_eq!(payload.clone().slice(start..end).to_vec(), want);
        }

        #[test]
        fn fragmented_payload_partial_copy((payload, (start, end)) in fragmented_payload::with_range()) {
            let mut buffer = [0; TEST_BYTES.len()];
            let buffer = &mut buffer[0..(end-start)];
            payload.partial_copy(start, buffer);
            prop_assert_eq!(buffer, &TEST_BYTES[start..end]);
        }
    }

    // Proptest strategies that split TEST_BYTES into 1..=TEST_STORAGE
    // contiguous fragments covering the whole byte string.
    mod fragmented_payload {
        use super::*;

        use proptest::strategy::{Just, Strategy};
        use rand::Rng as _;

        const TEST_STORAGE: usize = 5;
        type TestFragmentedPayload = FragmentedPayload<'static, TEST_STORAGE>;
        pub(super) fn with_payload() -> impl Strategy<Value = TestFragmentedPayload> {
            (1..=TEST_STORAGE).prop_perturb(|slices, mut rng| {
                (0..slices)
                    .scan(0, |st, slice| {
                        // The last fragment takes whatever remains; earlier
                        // fragments take a random prefix of the remainder.
                        let len = if slice == slices - 1 {
                            TEST_BYTES.len() - *st
                        } else {
                            rng.random_range(0..=(TEST_BYTES.len() - *st))
                        };
                        let start = *st;
                        *st += len;
                        Some(&TEST_BYTES[start..*st])
                    })
                    .collect()
            })
        }

        pub(super) fn with_range() -> impl Strategy<Value = (TestFragmentedPayload, (usize, usize))>
        {
            (
                with_payload(),
                (0..TEST_BYTES.len()).prop_flat_map(|start| (Just(start), start..TEST_BYTES.len())),
            )
        }
    }

    #[test_case(true; "timestamp_enabled")]
    #[test_case(false; "timestamp_disabled")]
    fn effective_mss_accounts_for_fixed_size_tcp_options(timestamp_enabled: bool) {
        const SIZE: u16 = 1000;
        let mss =
            EffectiveMss::from_mss(Mss::new(SIZE).unwrap(), MssSizeLimiters { timestamp_enabled });
        if timestamp_enabled {
            assert_eq!(mss.get(), SIZE - EffectiveMss::ALIGNED_TIMESTAMP_OPTION_LENGTH)
        } else {
            assert_eq!(mss.get(), SIZE);
        }
    }

    #[test_case(SegmentOptions {sack_blocks: SackBlocks::EMPTY, timestamp: None}; "empty")]
    #[test_case(SegmentOptions {
        sack_blocks: SackBlocks::from_iter([
            SackBlock::try_new(SeqNum::new(1), SeqNum::new(2)).unwrap(),
            SackBlock::try_new(SeqNum::new(4), SeqNum::new(6)).unwrap(),
        ]),
        timestamp: None
    }; "sack_blocks")]
    #[test_case(SegmentOptions {
        sack_blocks: SackBlocks::EMPTY,
        timestamp: Some(TimestampOption {
            ts_val: Timestamp::new(12345), ts_echo_reply: Timestamp::new(54321)
        }),
    }; "timestamp")]
    #[test_case(SegmentOptions {
        sack_blocks: SackBlocks::from_iter([
            SackBlock::try_new(SeqNum::new(1), SeqNum::new(2)).unwrap(),
            SackBlock::try_new(SeqNum::new(4), SeqNum::new(6)).unwrap(),
        ]),
        timestamp: Some(TimestampOption {
            ts_val: Timestamp::new(12345), ts_echo_reply: Timestamp::new(54321)
        }),
    }; "sack_blocks_and_timestamp")]

    fn effective_mss_accounts_for_variable_size_tcp_options(options: SegmentOptions) {
        const SIZE: u16 = 1000;
        let timestamp_enabled = options.timestamp.is_some();
        let mss =
            EffectiveMss::from_mss(Mss::new(SIZE).unwrap(), MssSizeLimiters { timestamp_enabled });
        let options_len =
            u16::try_from(packet_formats::tcp::aligned_options_length(options.iter())).unwrap();
        assert_eq!(mss.payload_size(&options).get(), SIZE - options_len);
    }
}
617}