netstack3_tcp/
congestion.rs

1// Copyright 2022 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5//! Implements loss-based congestion control algorithms.
6//!
//! The currently implemented algorithms are CUBIC from [RFC 8312], Reno-style
//! fast retransmit and fast recovery from [RFC 5681], and SACK-based loss
//! recovery from [RFC 6675].
//!
//! [RFC 8312]: https://www.rfc-editor.org/rfc/rfc8312
//! [RFC 5681]: https://www.rfc-editor.org/rfc/rfc5681
//! [RFC 6675]: https://www.rfc-editor.org/rfc/rfc6675
12
13mod cubic;
14
15use core::cmp::Ordering;
16use core::num::{NonZeroU32, NonZeroU8};
17use core::time::Duration;
18
19use netstack3_base::{Instant, Mss, SackBlocks, SeqNum, WindowSize};
20
21use crate::internal::sack_scoreboard::SackScoreboard;
22
/// The number of duplicate ACKs that signals a segment loss.
///
/// Per RFC 5681 (https://www.rfc-editor.org/rfc/rfc5681#section-3.2):
///
/// > The fast retransmit algorithm uses the arrival of 3 duplicate ACKs (...)
/// > as an indication that a segment has been lost.
pub(crate) const DUP_ACK_THRESHOLD: u8 = 3;
27
/// Holds the parameters of congestion control that are common to algorithms.
///
/// A single instance is shared, by mutable reference, between the loss-based
/// algorithm and the loss recovery algorithms, which read and update it in
/// place.
#[derive(Debug)]
struct CongestionControlParams {
    /// Slow start threshold, in bytes.
    ///
    /// Starts out at `u32::MAX` (see `with_mss`).
    ssthresh: u32,
    /// Congestion control window size, in bytes.
    cwnd: u32,
    /// Sender MSS (SMSS).
    mss: Mss,
}
38
39impl CongestionControlParams {
40    fn with_mss(mss: Mss) -> Self {
41        let mss_u32 = u32::from(mss);
42        // Per RFC 5681 (https://www.rfc-editor.org/rfc/rfc5681#page-5):
43        //   IW, the initial value of cwnd, MUST be set using the following
44        //   guidelines as an upper bound.
45        //   If SMSS > 2190 bytes:
46        //       IW = 2 * SMSS bytes and MUST NOT be more than 2 segments
47        //   If (SMSS > 1095 bytes) and (SMSS <= 2190 bytes):
48        //       IW = 3 * SMSS bytes and MUST NOT be more than 3 segments
49        //   if SMSS <= 1095 bytes:
50        //       IW = 4 * SMSS bytes and MUST NOT be more than 4 segments
51        let cwnd = if mss_u32 > 2190 {
52            mss_u32 * 2
53        } else if mss_u32 > 1095 {
54            mss_u32 * 3
55        } else {
56            mss_u32 * 4
57        };
58        Self { cwnd, ssthresh: u32::MAX, mss }
59    }
60
61    fn rounded_cwnd(&self) -> CongestionWindow {
62        CongestionWindow::new(self.cwnd, self.mss)
63    }
64}
65
66mod cwnd {
67    use super::*;
68    /// A witness type for a congestion window that is rounded to a multiple of
69    /// MSS.
70    ///
71    /// This type carries around the mss that was used to calculate it.
72    #[derive(Debug, Copy, Clone)]
73    #[cfg_attr(test, derive(Eq, PartialEq))]
74    pub(crate) struct CongestionWindow {
75        cwnd: u32,
76        mss: Mss,
77    }
78
79    impl CongestionWindow {
80        pub(super) fn new(cwnd: u32, mss: Mss) -> Self {
81            let mss_u32 = u32::from(mss);
82            Self { cwnd: cwnd / mss_u32 * mss_u32, mss }
83        }
84
85        pub(crate) fn cwnd(&self) -> u32 {
86            self.cwnd
87        }
88
89        pub(crate) fn mss(&self) -> Mss {
90            self.mss
91        }
92    }
93}
94pub(crate) use cwnd::CongestionWindow;
95
/// Congestion control with five intertwined algorithms.
///
/// - Slow start
/// - Congestion avoidance from a loss-based algorithm
/// - Fast retransmit
/// - Fast recovery: <https://datatracker.ietf.org/doc/html/rfc5681#section-3>
/// - SACK recovery: <https://datatracker.ietf.org/doc/html/rfc6675>
#[derive(Debug)]
pub(crate) struct CongestionControl<I> {
    /// Parameters (cwnd, ssthresh, MSS) shared by all the algorithms.
    params: CongestionControlParams,
    /// Selective acknowledgment scoreboard; also keeps the estimate of bytes
    /// in flight ("pipe").
    sack_scoreboard: SackScoreboard,
    /// The loss-based algorithm, which updates `params` on ACKs, detected
    /// losses and retransmission timeouts.
    algorithm: LossBasedAlgorithm<I>,
    /// Loss recovery state, created on the first duplicate ACK.
    ///
    /// Being [`Some`] does not by itself mean recovery is underway: fast
    /// recovery only engages once `DUP_ACK_THRESHOLD` duplicate ACKs arrive,
    /// and SACK recovery tracks whether it is actually recovering internally.
    loss_recovery: Option<LossRecovery>,
}
111
/// Available congestion control algorithms.
///
/// Only CUBIC is currently implemented.
#[derive(Debug)]
enum LossBasedAlgorithm<I> {
    /// CUBIC ([RFC 8312]) with fast convergence enabled.
    ///
    /// [RFC 8312]: https://www.rfc-editor.org/rfc/rfc8312
    Cubic(cubic::Cubic<I, true /* FAST_CONVERGENCE */>),
}
117
118impl<I: Instant> LossBasedAlgorithm<I> {
119    /// Called when there is a loss detected.
120    ///
121    /// Specifically, packet loss means
122    /// - either when the retransmission timer fired;
123    /// - or when we have received a certain amount of duplicate acks.
124    fn on_loss_detected(&mut self, params: &mut CongestionControlParams) {
125        match self {
126            LossBasedAlgorithm::Cubic(cubic) => cubic.on_loss_detected(params),
127        }
128    }
129
130    fn on_ack(
131        &mut self,
132        params: &mut CongestionControlParams,
133        bytes_acked: NonZeroU32,
134        now: I,
135        rtt: Duration,
136    ) {
137        match self {
138            LossBasedAlgorithm::Cubic(cubic) => cubic.on_ack(params, bytes_acked, now, rtt),
139        }
140    }
141
142    fn on_retransmission_timeout(&mut self, params: &mut CongestionControlParams) {
143        match self {
144            LossBasedAlgorithm::Cubic(cubic) => cubic.on_retransmission_timeout(params),
145        }
146    }
147}
148
149impl<I: Instant> CongestionControl<I> {
150    /// Preprocesses an ACK that may contain selective ack blocks.
151    ///
152    /// Returns `Some(true)` if this should be considered a duplicate ACK
153    /// according to the rules in [RFC 6675 section 2]. Returns `Some(false)`
154    /// otherwise.
155    ///
156    /// If the incoming ACK does not have SACK information, `None` is returned
157    /// and the caller should use the classic algorithm to determine if this is
158    /// a duplicate ACk.
159    ///
160    /// [RFC 6675 section 2]:
161    ///     https://datatracker.ietf.org/doc/html/rfc6675#section-2
162    pub(super) fn preprocess_ack(
163        &mut self,
164        seg_ack: SeqNum,
165        snd_nxt: SeqNum,
166        seg_sack_blocks: &SackBlocks,
167    ) -> Option<bool> {
168        let Self { params, algorithm: _, loss_recovery, sack_scoreboard } = self;
169        let high_rxt = loss_recovery.as_ref().and_then(|lr| match lr {
170            LossRecovery::FastRecovery(_) => None,
171            LossRecovery::SackRecovery(sack_recovery) => sack_recovery.high_rxt(),
172        });
173        let is_dup_ack =
174            sack_scoreboard.process_ack(seg_ack, snd_nxt, high_rxt, seg_sack_blocks, params.mss);
175        (!seg_sack_blocks.is_empty()).then_some(is_dup_ack)
176    }
177
178    /// Informs the congestion control algorithm that a segment of length
179    /// `seg_len` is being sent on the wire.
180    ///
181    /// This allows congestion control to keep the correct estimate of how many
182    /// bytes are in flight.
183    pub(super) fn on_will_send_segment(&mut self, seg_len: u32) {
184        let Self { params: _, sack_scoreboard, algorithm: _, loss_recovery: _ } = self;
185        // From RFC 6675:
186        //
187        //  (C.4) The estimate of the amount of data outstanding in the
188        //  network must be updated by incrementing pipe by the number of
189        //  octets transmitted in (C.1).
190        sack_scoreboard.increment_pipe(seg_len);
191    }
192
193    /// Called when there are previously unacknowledged bytes being acked.
194    ///
195    /// If a round-trip-time estimation is not available, `rtt` can be `None`,
196    /// but the loss-based algorithm is not updated in that case.
197    ///
198    /// Returns `true` if this ack signals a loss recovery.
199    pub(super) fn on_ack(
200        &mut self,
201        seg_ack: SeqNum,
202        bytes_acked: NonZeroU32,
203        now: I,
204        rtt: Option<Duration>,
205    ) -> bool {
206        let Self { params, algorithm, loss_recovery, sack_scoreboard: _ } = self;
207        // Exit fast recovery since there is an ACK that acknowledges new data.
208        let outcome = match loss_recovery {
209            None => LossRecoveryOnAckOutcome::None,
210            Some(LossRecovery::FastRecovery(fast_recovery)) => fast_recovery.on_ack(params),
211            Some(LossRecovery::SackRecovery(sack_recovery)) => sack_recovery.on_ack(seg_ack),
212        };
213
214        let recovered = match outcome {
215            LossRecoveryOnAckOutcome::None => false,
216            LossRecoveryOnAckOutcome::Discard { recovered } => {
217                *loss_recovery = None;
218                recovered
219            }
220        };
221
222        // It is possible, however unlikely, that we get here without an RTT
223        // estimation - in case the first data segment that we send out gets
224        // retransmitted. In that case, simply don't update the congestion
225        // parameters with the loss based algorithm which at worst causes slow
226        // start to take one extra step.
227        if let Some(rtt) = rtt {
228            algorithm.on_ack(params, bytes_acked, now, rtt);
229        }
230        recovered
231    }
232
233    /// Called when a duplicate ack is arrived.
234    ///
235    /// Returns `Some` if loss recovery was initiated as a result of this ACK,
236    /// informing which mode was triggered.
237    pub(super) fn on_dup_ack(
238        &mut self,
239        seg_ack: SeqNum,
240        snd_nxt: SeqNum,
241    ) -> Option<LossRecoveryMode> {
242        let Self { params, algorithm, loss_recovery, sack_scoreboard } = self;
243        match loss_recovery {
244            None => {
245                // If we have SACK information, prefer SACK recovery.
246                if sack_scoreboard.has_sack_info() {
247                    let mut sack_recovery = SackRecovery::new();
248                    let started_loss_recovery = sack_recovery
249                        .on_dup_ack(seg_ack, snd_nxt, sack_scoreboard)
250                        .apply(params, algorithm);
251                    *loss_recovery = Some(LossRecovery::SackRecovery(sack_recovery));
252                    started_loss_recovery.then_some(LossRecoveryMode::SackRecovery)
253                } else {
254                    *loss_recovery = Some(LossRecovery::FastRecovery(FastRecovery::new()));
255                    None
256                }
257            }
258            Some(LossRecovery::SackRecovery(sack_recovery)) => sack_recovery
259                .on_dup_ack(seg_ack, snd_nxt, sack_scoreboard)
260                .apply(params, algorithm)
261                .then_some(LossRecoveryMode::SackRecovery),
262            Some(LossRecovery::FastRecovery(fast_recovery)) => fast_recovery
263                .on_dup_ack(params, algorithm, seg_ack)
264                .then_some(LossRecoveryMode::FastRecovery),
265        }
266    }
267
268    /// Called upon a retransmission timeout.
269    ///
270    /// `snd_nxt` is the value of SND.NXT _before_ it is rewound to SND.UNA as
271    /// part of an RTO.
272    pub(super) fn on_retransmission_timeout(&mut self, snd_nxt: SeqNum) {
273        let Self { params, algorithm, loss_recovery, sack_scoreboard } = self;
274        sack_scoreboard.on_retransmission_timeout();
275        let discard_loss_recovery = match loss_recovery {
276            None | Some(LossRecovery::FastRecovery(_)) => true,
277            Some(LossRecovery::SackRecovery(sack_recovery)) => {
278                sack_recovery.on_retransmission_timeout(snd_nxt)
279            }
280        };
281        if discard_loss_recovery {
282            *loss_recovery = None;
283        }
284        algorithm.on_retransmission_timeout(params);
285    }
286
287    pub(super) fn slow_start_threshold(&self) -> u32 {
288        self.params.ssthresh
289    }
290
291    #[cfg(test)]
292    pub(super) fn pipe(&self) -> u32 {
293        self.sack_scoreboard.pipe()
294    }
295
296    /// Inflates the congestion window by `value` to facilitate testing.
297    #[cfg(test)]
298    pub(super) fn inflate_cwnd(&mut self, inflation: u32) {
299        self.params.cwnd += inflation;
300    }
301
302    pub(super) fn cubic_with_mss(mss: Mss) -> Self {
303        Self {
304            params: CongestionControlParams::with_mss(mss),
305            algorithm: LossBasedAlgorithm::Cubic(Default::default()),
306            loss_recovery: None,
307            sack_scoreboard: SackScoreboard::default(),
308        }
309    }
310
311    pub(super) fn mss(&self) -> Mss {
312        self.params.mss
313    }
314
315    pub(super) fn update_mss(&mut self, mss: Mss, snd_una: SeqNum, snd_nxt: SeqNum) {
316        let Self { params, sack_scoreboard, algorithm: _, loss_recovery } = self;
317        // From [RFC 5681 section 3.1]:
318        //
319        //    When initial congestion windows of more than one segment are
320        //    implemented along with Path MTU Discovery [RFC1191], and the MSS
321        //    being used is found to be too large, the congestion window cwnd
322        //    SHOULD be reduced to prevent large bursts of smaller segments.
323        //    Specifically, cwnd SHOULD be reduced by the ratio of the old segment
324        //    size to the new segment size.
325        //
326        // [RFC 5681 section 3.1]: https://datatracker.ietf.org/doc/html/rfc5681#section-3.1
327        if params.ssthresh == u32::MAX {
328            params.cwnd =
329                params.cwnd.saturating_div(u32::from(params.mss)).saturating_mul(u32::from(mss));
330        }
331        params.mss = mss;
332
333        // Given we'll retransmit after receiving this, we need to update the
334        // SACK scoreboard so pipe is recalculated based on this value of
335        // snd_nxt and mss.
336        let high_rxt = loss_recovery.as_ref().and_then(|lr| match lr {
337            LossRecovery::FastRecovery(_) => None,
338            LossRecovery::SackRecovery(sack_recovery) => sack_recovery.high_rxt(),
339        });
340        sack_scoreboard.on_mss_update(snd_una, snd_nxt, high_rxt, mss);
341    }
342
343    /// Returns the rounded unmodified by loss recovery window size.
344    ///
345    /// This is meant to be used for inspection only. Congestion calculation for
346    /// sending should use [`CongestionControl::poll_send`].
347    pub(super) fn inspect_cwnd(&self) -> CongestionWindow {
348        self.params.rounded_cwnd()
349    }
350
351    /// Returns the current loss recovery mode, if any.
352    ///
353    /// This method returns the current loss recovery mode.
354    ///
355    /// *NOTE* It's possible for [`CongestionControl`] to return a
356    /// [`LossRecoveryMode`] here even if there was no congestion events. Rely
357    /// on the return values from [`CongestionControl::poll_send`],
358    /// [`CongestionControl::on_dup_ack`] to catch entering into loss recovery
359    /// mode or determining if segments originate from a specific algorithm.
360    pub(super) fn inspect_loss_recovery_mode(&self) -> Option<LossRecoveryMode> {
361        self.loss_recovery.as_ref().map(|lr| lr.mode())
362    }
363
364    /// Returns true if this [`CongestionControl`] is in slow start.
365    pub(super) fn in_slow_start(&self) -> bool {
366        self.params.cwnd < self.params.ssthresh
367    }
368
369    /// Polls congestion control for the next segment to be sent out.
370    ///
371    /// Receives pertinent parameters from the sender state machine to allow for
372    /// this decision:
373    ///
374    /// - `snd_una` is SND.UNA the highest unacknowledged sequence number.
375    /// - `snd_nxt` is SND.NXT, the next sequence number the sender would send
376    ///   without loss recovery.
377    /// - `snd_wnd` is the total send window, i.e. the allowable receiver window
378    ///   after `snd_una`.
379    /// - `available_bytes` is the total number of bytes in the send buffer,
380    ///   starting at `snd_una`.
381    ///
382    /// Returns `None` if no segment should be sent right now.
383    pub(super) fn poll_send(
384        &mut self,
385        snd_una: SeqNum,
386        snd_nxt: SeqNum,
387        snd_wnd: WindowSize,
388        available_bytes: usize,
389    ) -> Option<CongestionControlSendOutcome> {
390        let Self { params, algorithm: _, loss_recovery, sack_scoreboard } = self;
391        let cwnd = params.rounded_cwnd();
392
393        match loss_recovery {
394            None => {
395                let pipe = sack_scoreboard.pipe();
396                let congestion_window = cwnd.cwnd();
397                let available_window = congestion_window.saturating_sub(pipe);
398                let congestion_limit = available_window.min(cwnd.mss().into());
399                Some(CongestionControlSendOutcome {
400                    next_seg: snd_nxt,
401                    congestion_limit,
402                    congestion_window,
403                    loss_recovery: LossRecoverySegment::No,
404                })
405            }
406            Some(LossRecovery::FastRecovery(fast_recovery)) => {
407                Some(fast_recovery.poll_send(cwnd, sack_scoreboard.pipe(), snd_nxt))
408            }
409            Some(LossRecovery::SackRecovery(sack_recovery)) => sack_recovery.poll_send(
410                cwnd,
411                snd_una,
412                snd_nxt,
413                snd_wnd,
414                available_bytes,
415                sack_scoreboard,
416            ),
417        }
418    }
419}
420
/// Indicates whether the segment yielded in [`CongestionControlSendOutcome`] is
/// a loss recovery segment.
#[derive(Debug)]
#[cfg_attr(test, derive(Copy, Clone, Eq, PartialEq))]
pub(super) enum LossRecoverySegment {
    /// Indicates the segment is a loss recovery segment.
    Yes {
        /// If true, the retransmit timer should be rearmed due to this loss
        /// recovery segment.
        ///
        /// Fast recovery always sets this to `false`.
        ///
        /// This is used in SACK recovery to prevent RTOs during retransmission,
        /// from [RFC 6675 section 6]:
        ///
        /// > Therefore, we give implementers the latitude to use the standard
        /// > [RFC6298]-style RTO management or, optionally, a more careful
        /// > variant that re-arms the RTO timer on each retransmission that is
        /// > sent during recovery MAY be used.  This provides a more
        /// > conservative timer than specified in [RFC6298], and so may not
        /// > always be an attractive alternative.  However, in some cases it
        /// > may prevent needless retransmissions, go-back-N transmission, and
        /// > further reduction of the congestion window.
        ///
        /// [RFC 6675 section 6]: https://datatracker.ietf.org/doc/html/rfc6675#section-6
        rearm_retransmit: bool,
        /// The recovery mode that caused this loss recovery segment.
        mode: LossRecoveryMode,
    },
    /// Indicates the segment is *not* a loss recovery segment.
    No,
}
451
/// The outcome of [`CongestionControl::poll_send`].
#[derive(Debug)]
#[cfg_attr(test, derive(Eq, PartialEq))]
pub(super) struct CongestionControlSendOutcome {
    /// The sequence number at which the next segment to be sent out starts.
    pub next_seg: SeqNum,
    /// The maximum number of bytes, starting at `next_seg`, that can be sent.
    ///
    /// This limit does not account for the unused/open window on the receiver.
    ///
    /// This is limited by the current congestion window and the sender MSS.
    pub congestion_limit: u32,
    /// The congestion window used to calculate `congestion_limit`.
    ///
    /// This is the estimated total congestion window, including loss
    /// recovery-based inflation.
    pub congestion_window: u32,
    /// Whether this is a loss recovery segment.
    pub loss_recovery: LossRecoverySegment,
}
472
/// The state of the active loss recovery algorithm.
#[derive(Debug)]
pub enum LossRecovery {
    /// Reno-style fast retransmit/fast recovery, chosen when no SACK
    /// information is available at the first duplicate ACK.
    FastRecovery(FastRecovery),
    /// SACK-based loss recovery, chosen when SACK information is available at
    /// the first duplicate ACK.
    SackRecovery(SackRecovery),
}
479
480impl LossRecovery {
481    fn mode(&self) -> LossRecoveryMode {
482        match self {
483            LossRecovery::FastRecovery(_) => LossRecoveryMode::FastRecovery,
484            LossRecovery::SackRecovery(_) => LossRecoveryMode::SackRecovery,
485        }
486    }
487}
488
/// An equivalent to [`LossRecovery`] that simply informs the loss recovery
/// mode, without carrying state.
#[derive(Debug)]
#[cfg_attr(test, derive(Copy, Clone, Eq, PartialEq))]
pub enum LossRecoveryMode {
    /// Corresponds to [`LossRecovery::FastRecovery`].
    FastRecovery,
    /// Corresponds to [`LossRecovery::SackRecovery`].
    SackRecovery,
}
497
/// The outcome of reporting an ACK of new data to a loss recovery algorithm.
#[derive(Debug)]
#[cfg_attr(test, derive(Eq, PartialEq))]
enum LossRecoveryOnAckOutcome {
    /// Keep the current loss recovery state.
    None,
    /// Discard the loss recovery state.
    Discard {
        /// `true` if the discarded state had actually entered recovery, i.e.
        /// the connection recovered from a loss.
        recovered: bool,
    },
}
504
/// Reno style Fast Recovery algorithm as described in
/// [RFC 5681](https://tools.ietf.org/html/rfc5681).
///
/// Before [`DUP_ACK_THRESHOLD`] duplicate ACKs are received, this state also
/// applies Limited Transmit ([RFC 3042]).
///
/// [RFC 3042]: https://www.rfc-editor.org/rfc/rfc3042
#[derive(Debug)]
pub struct FastRecovery {
    /// Holds the sequence number of the segment to fast retransmit, if any.
    ///
    /// Set when the duplicate ACK threshold is reached; taken by the next
    /// `poll_send`.
    fast_retransmit: Option<SeqNum>,
    /// The running count of consecutive duplicate ACKs we have received so far.
    ///
    /// Here we limit the maximum number of duplicate ACKs we track to 255, as
    /// per a note in RFC 5681:
    ///
    /// > Note: [SCWA99] discusses a receiver-based attack whereby many
    /// > bogus duplicate ACKs are sent to the data sender in order to
    /// > artificially inflate cwnd and cause a higher than appropriate
    /// > sending rate to be used.  A TCP MAY therefore limit the number of
    /// > times cwnd is artificially inflated during loss recovery to the
    /// > number of outstanding segments (or, an approximation thereof).
    ///
    /// [SCWA99]: https://homes.cs.washington.edu/~tom/pubs/CCR99.pdf
    dup_acks: NonZeroU8,
}
526
527impl FastRecovery {
528    fn new() -> Self {
529        Self { dup_acks: NonZeroU8::new(1).unwrap(), fast_retransmit: None }
530    }
531
532    fn poll_send(
533        &mut self,
534        cwnd: CongestionWindow,
535        used_congestion_window: u32,
536        snd_nxt: SeqNum,
537    ) -> CongestionControlSendOutcome {
538        let Self { fast_retransmit, dup_acks } = self;
539        // Per RFC 3042 (https://www.rfc-editor.org/rfc/rfc3042#section-2): ...
540        // the Limited Transmit algorithm, which calls for a TCP sender to
541        //   transmit new data upon the arrival of the first two consecutive
542        //   duplicate ACKs ... The amount of outstanding data would remain less
543        //   than or equal to the congestion window plus 2 segments.  In other
544        //   words, the sender can only send two segments beyond the congestion
545        //   window (cwnd).
546        //
547        // Note: We don't directly change cwnd in the loss-based algorithm
548        // because the RFC says one MUST NOT do that. We follow the requirement
549        // here by not changing the cwnd of the algorithm - if a new ACK is
550        // received after the two dup acks, the loss-based algorithm will
551        // continue to operate the same way as if the 2 SMSS is never added to
552        // cwnd.
553        let congestion_window = if dup_acks.get() < DUP_ACK_THRESHOLD {
554            cwnd.cwnd().saturating_add(u32::from(dup_acks.get()) * u32::from(cwnd.mss()))
555        } else {
556            cwnd.cwnd()
557        };
558
559        // Elect fast retransmit sequence number or snd_nxt if we don't have
560        // one.
561        let (next_seg, loss_recovery, congestion_limit) = match fast_retransmit.take() {
562            // From RFC 5861:
563            //
564            //  3. The lost segment starting at SND.UNA MUST be retransmitted
565            //     [...].
566            //
567            // So we always set the congestion limit to be just the mss.
568            Some(f) => (
569                f,
570                LossRecoverySegment::Yes {
571                    rearm_retransmit: false,
572                    mode: LossRecoveryMode::FastRecovery,
573                },
574                cwnd.mss().into(),
575            ),
576            // There's no fast retransmit pending, use snd_nxt applying the used
577            // congestion window.
578            None => (
579                snd_nxt,
580                LossRecoverySegment::No,
581                congestion_window.saturating_sub(used_congestion_window).min(cwnd.mss().into()),
582            ),
583        };
584        CongestionControlSendOutcome {
585            next_seg,
586            congestion_limit,
587            congestion_window,
588            loss_recovery,
589        }
590    }
591
592    fn on_ack(&mut self, params: &mut CongestionControlParams) -> LossRecoveryOnAckOutcome {
593        let recovered = self.dup_acks.get() >= DUP_ACK_THRESHOLD;
594        if recovered {
595            // Per RFC 5681 (https://www.rfc-editor.org/rfc/rfc5681#section-3.2):
596            //   When the next ACK arrives that acknowledges previously
597            //   unacknowledged data, a TCP MUST set cwnd to ssthresh (the value
598            //   set in step 2).  This is termed "deflating" the window.
599            params.cwnd = params.ssthresh;
600        }
601        LossRecoveryOnAckOutcome::Discard { recovered }
602    }
603
604    /// Processes a duplicate ack with sequence number `seg_ack`.
605    ///
606    /// Returns `true` if loss recovery is triggered.
607    fn on_dup_ack<I: Instant>(
608        &mut self,
609        params: &mut CongestionControlParams,
610        loss_based: &mut LossBasedAlgorithm<I>,
611        seg_ack: SeqNum,
612    ) -> bool {
613        self.dup_acks = self.dup_acks.saturating_add(1);
614
615        match self.dup_acks.get().cmp(&DUP_ACK_THRESHOLD) {
616            Ordering::Less => false,
617            Ordering::Equal => {
618                loss_based.on_loss_detected(params);
619                // Per RFC 5681 (https://www.rfc-editor.org/rfc/rfc5681#section-3.2):
620                //   The lost segment starting at SND.UNA MUST be retransmitted
621                //   and cwnd set to ssthresh plus 3*SMSS.  This artificially
622                //   "inflates" the congestion window by the number of segments
623                //   (three) that have left the network and which the receiver
624                //   has buffered.
625                self.fast_retransmit = Some(seg_ack);
626                params.cwnd =
627                    params.ssthresh + u32::from(DUP_ACK_THRESHOLD) * u32::from(params.mss);
628                true
629            }
630            Ordering::Greater => {
631                // Per RFC 5681 (https://www.rfc-editor.org/rfc/rfc5681#section-3.2):
632                //   For each additional duplicate ACK received (after the third),
633                //   cwnd MUST be incremented by SMSS. This artificially inflates
634                //   the congestion window in order to reflect the additional
635                //   segment that has left the network.
636                params.cwnd = params.cwnd.saturating_add(u32::from(params.mss));
637                false
638            }
639        }
640    }
641}
642
/// The state kept by [`SackRecovery`] indicating the recovery state.
#[derive(Debug)]
#[cfg_attr(test, derive(Eq, PartialEq, Copy, Clone))]
enum SackRecoveryState {
    /// SACK is currently in active recovery.
    InRecovery(SackInRecoveryState),
    /// SACK is holding off starting new recovery after an RTO.
    ///
    /// No new recovery phase may start until `recovery_point` is cumulatively
    /// acknowledged ([RFC 6675 section 5.1]).
    ///
    /// [RFC 6675 section 5.1]: https://datatracker.ietf.org/doc/html/rfc6675#section-5.1
    PostRto { recovery_point: SeqNum },
    /// SACK is not in active recovery.
    NotInRecovery,
}
654
/// The state kept by [`SackRecoveryState::InRecovery`].
#[derive(Debug)]
#[cfg_attr(test, derive(Eq, PartialEq, Copy, Clone))]
struct SackInRecoveryState {
    /// The sequence number that marks the end of the current loss recovery
    /// phase.
    ///
    /// Set to SND.NXT when recovery starts; recovery ends once an ACK at or
    /// beyond this sequence number arrives.
    recovery_point: SeqNum,
    /// The highest retransmitted sequence number during the current loss
    /// recovery phase.
    ///
    /// Tracks the "HighRxt" variable defined in [RFC 6675 section 2].
    ///
    /// [RFC 6675 section 2]: https://datatracker.ietf.org/doc/html/rfc6675#section-2
    high_rxt: SeqNum,
    /// The highest sequence number that has been optimistically retransmitted.
    ///
    /// Tracks the "RescueRxt" variable defined in [RFC 6675 section 2].
    ///
    /// [RFC 6675 section 2]: https://datatracker.ietf.org/doc/html/rfc6675#section-2
    rescue_rxt: Option<SeqNum>,
}
676
/// Implements the SACK based recovery from [RFC 6675].
///
/// [RFC 6675]: https://datatracker.ietf.org/doc/html/rfc6675
#[derive(Debug)]
pub(crate) struct SackRecovery {
    /// Keeps track of the number of duplicate ACKs received during SACK
    /// recovery.
    ///
    /// Tracks the "DupAcks" variable defined in [RFC 6675 section 2].
    ///
    /// [RFC 6675 section 2]: https://datatracker.ietf.org/doc/html/rfc6675#section-2
    dup_acks: u8,
    /// Statekeeping for loss recovery.
    ///
    /// Holds [`SackRecoveryState::InRecovery`] while actively recovering,
    /// [`SackRecoveryState::PostRto`] while holding off new recovery after a
    /// retransmission timeout, and [`SackRecoveryState::NotInRecovery`]
    /// otherwise.
    recovery: SackRecoveryState,
}
694
695impl SackRecovery {
696    fn new() -> Self {
697        Self {
698            // Unlike FastRecovery, we start with zero duplicate ACKs,
699            // congestion control calls on_dup_ack after creation.
700            dup_acks: 0,
701            recovery: SackRecoveryState::NotInRecovery,
702        }
703    }
704
705    fn high_rxt(&self) -> Option<SeqNum> {
706        match &self.recovery {
707            SackRecoveryState::InRecovery(SackInRecoveryState {
708                recovery_point: _,
709                high_rxt,
710                rescue_rxt: _,
711            }) => Some(*high_rxt),
712            SackRecoveryState::PostRto { recovery_point: _ } | SackRecoveryState::NotInRecovery => {
713                None
714            }
715        }
716    }
717
718    fn on_ack(&mut self, seg_ack: SeqNum) -> LossRecoveryOnAckOutcome {
719        let Self { dup_acks, recovery } = self;
720        match recovery {
721            SackRecoveryState::InRecovery(SackInRecoveryState {
722                recovery_point,
723                high_rxt: _,
724                rescue_rxt: _,
725            })
726            | SackRecoveryState::PostRto { recovery_point } => {
727                // From RFC 6675:
728                //  An incoming cumulative ACK for a sequence number greater than
729                //  RecoveryPoint signals the end of loss recovery, and the loss
730                //  recovery phase MUST be terminated.
731                if seg_ack.after_or_eq(*recovery_point) {
732                    LossRecoveryOnAckOutcome::Discard {
733                        recovered: matches!(recovery, SackRecoveryState::InRecovery(_)),
734                    }
735                } else {
736                    // From RFC 6675:
737                    //  If the incoming ACK is a cumulative acknowledgment, the
738                    //  TCP MUST reset DupAcks to zero.
739                    *dup_acks = 0;
740                    LossRecoveryOnAckOutcome::None
741                }
742            }
743            SackRecoveryState::NotInRecovery => {
744                // We're not in loss recovery, we seem to have moved things
745                // forward. Discard loss recovery information.
746                LossRecoveryOnAckOutcome::Discard { recovered: false }
747            }
748        }
749    }
750
751    /// Processes a duplicate acknowledgement.
752    fn on_dup_ack(
753        &mut self,
754        seq_ack: SeqNum,
755        snd_nxt: SeqNum,
756        sack_scoreboard: &SackScoreboard,
757    ) -> SackDupAckOutcome {
758        let Self { dup_acks, recovery } = self;
759        match recovery {
760            SackRecoveryState::InRecovery(_) | SackRecoveryState::PostRto { .. } => {
761                // Already in recovery mode, nothing to do.
762                return SackDupAckOutcome(false);
763            }
764            SackRecoveryState::NotInRecovery => (),
765        }
766        *dup_acks += 1;
767        // From RFC 6675:
768        //  (1) If DupAcks >= DupThresh, [...].
769        //  (2) If DupAcks < DupThresh but IsLost (HighACK + 1) returns true
770        //  [...]
771        if *dup_acks >= DUP_ACK_THRESHOLD || sack_scoreboard.is_first_hole_lost() {
772            // Enter loss recovery:
773            //  (4.1) RecoveryPoint = HighData
774            //  When the TCP sender receives a cumulative ACK for this data
775            //  octet, the loss recovery phase is terminated.
776            *recovery = SackRecoveryState::InRecovery(SackInRecoveryState {
777                recovery_point: snd_nxt,
778                high_rxt: seq_ack,
779                rescue_rxt: None,
780            });
781            SackDupAckOutcome(true)
782        } else {
783            SackDupAckOutcome(false)
784        }
785    }
786
787    /// Updates SACK recovery to account for a retransmission timeout during
788    /// recovery.
789    ///
790    /// From [RFC 6675 section 5.1]:
791    ///
792    /// > If an RTO occurs during loss recovery as specified in this document,
793    /// > RecoveryPoint MUST be set to HighData.  Further, the new value of
794    /// > RecoveryPoint MUST be preserved and the loss recovery algorithm
795    /// > outlined in this document MUST be terminated.  In addition, a new
796    /// > recovery phase (as described in Section 5) MUST NOT be initiated until
797    /// > HighACK is greater than or equal to the new value of RecoveryPoint.
798    ///
799    /// [RFC 6675 section 5.1]: https://datatracker.ietf.org/doc/html/rfc6675#section-5.1
800    ///
801    /// Returns `true` iff we can clear all recovery state due to the timeout.
802    pub(crate) fn on_retransmission_timeout(&mut self, snd_nxt: SeqNum) -> bool {
803        let Self { dup_acks: _, recovery } = self;
804        match recovery {
805            SackRecoveryState::InRecovery(SackInRecoveryState { .. }) => {
806                *recovery = SackRecoveryState::PostRto { recovery_point: snd_nxt };
807                false
808            }
809            SackRecoveryState::PostRto { recovery_point: _ } => {
810                // NB: The RFC is not exactly clear on what to do here, but the
811                // best interpretation is that we should maintain the old
812                // recovery point until we've hit that point and don't update to
813                // the new (assumedly rewound) snd_nxt.
814                false
815            }
816            SackRecoveryState::NotInRecovery => {
817                // Not in recovery we can reset our state.
818                true
819            }
820        }
821    }
822
    /// SACK recovery based congestion control next segment selection.
    ///
    /// Argument semantics are the same as [`CongestionControl::poll_send`].
    ///
    /// Returns `None` when the available congestion window (`cwnd - pipe`) is
    /// smaller than one MSS.
    ///
    /// Outside of active recovery (including the post-RTO phase), the
    /// returned segment always starts at `snd_nxt` and is not marked as loss
    /// recovery. `snd_wnd` and `available_bytes` are not consulted in that
    /// case.
    ///
    /// During active recovery this runs RFC 6675's NextSeg() rules 1 through
    /// 4 in order. It may update HighRxt (rules 1 and 3) and RescueRxt (rules
    /// 1 and 4), and returns `None` if no rule yields a segment.
    fn poll_send(
        &mut self,
        cwnd: CongestionWindow,
        snd_una: SeqNum,
        snd_nxt: SeqNum,
        snd_wnd: WindowSize,
        available_bytes: usize,
        sack_scoreboard: &SackScoreboard,
    ) -> Option<CongestionControlSendOutcome> {
        let Self { dup_acks: _, recovery } = self;

        let pipe = sack_scoreboard.pipe();
        let congestion_window = cwnd.cwnd();
        let available_window = congestion_window.saturating_sub(pipe);
        // Don't send anything if we can't send at least full MSS, following the
        // RFC. All outcomes require at least one MSS of available window:
        //
        // (3.3) If (cwnd - pipe) >= 1 SMSS [...]
        // (C) If cwnd - pipe >= 1 SMSS [...]
        if available_window < cwnd.mss().into() {
            return None;
        }
        // Given the check above, this always evaluates to exactly one MSS.
        let congestion_limit = available_window.min(cwnd.mss().into());

        // If we're not in recovery, use the regular congestion calculation,
        // adjusting the congestion window with the pipe value.
        //
        // From RFC 6675:
        //
        //  (3.3) If (cwnd - pipe) >= 1 SMSS, there exists previously unsent
        //  data, and the receiver's advertised window allows, transmit up
        //  to 1 SMSS of data starting with the octet HighData+1 and update
        //  HighData to reflect this transmission, then return to (3.2).
        let SackInRecoveryState { recovery_point, high_rxt, rescue_rxt } = match recovery {
            SackRecoveryState::InRecovery(sack_in_recovery_state) => sack_in_recovery_state,
            SackRecoveryState::PostRto { recovery_point: _ } | SackRecoveryState::NotInRecovery => {
                return Some(CongestionControlSendOutcome {
                    next_seg: snd_nxt,
                    congestion_limit,
                    congestion_window,
                    loss_recovery: LossRecoverySegment::No,
                });
            }
        };

        // From RFC 6675 section 6:
        //
        //  we give implementers the latitude to use the standard
        //  [RFC6298]-style RTO management or, optionally, a more careful
        //  variant that re-arms the RTO timer on each retransmission that is
        //  sent during recovery MAY be used.  This provides a more conservative
        //  timer than specified in [RFC6298].
        //
        // As a local decision, we only rearm the retransmit timer for rules 1
        // and 3 (regular retransmissions) when the next segment trying to be
        // sent out is _before_ the recovery point that initiated this loss
        // recovery. Given the recovery algorithm greedily keeps sending more
        // data as long as it's available to keep the ACK clock running, there's
        // a catastrophic scenario where the data sent past the recovery point
        // creates new holes in the sack scoreboard that are filled by rules 1
        // and 3 and rearm the RTO, even if the retransmissions from holes
        // before RecoveryPoint might be lost themselves. Hence, once the
        // algorithm has moved on to repairing holes at or past the
        // RecoveryPoint, we stop rearming the RTO in case the ACK for
        // RecoveryPoint never arrives.
        //
        // Note that rule 4 always rearms the retransmission timer because it
        // sends only a single segment per entry into recovery. Rule 2 (new
        // data) never rearms it.
        let rearm_retransmit = |next_seg: SeqNum| next_seg.before(*recovery_point);

        // run NextSeg() as defined in RFC 6675.

        // (1) If there exists a smallest unSACKed sequence number 'S2' that
        //   meets the following three criteria for determining loss, the
        //   sequence range of one segment of up to SMSS octets starting
        //   with S2 MUST be returned.
        //
        //   (1.a) S2 is greater than HighRxt.
        //   (1.b) S2 is less than the highest octet covered by any received
        //         SACK.
        //   (1.c) IsLost (S2) returns true.
        //
        // `high_rxt` is stored as one past the last retransmitted octet (see
        // the updates below), so starting the search at it (or at SND.UNA,
        // whichever is later) implements (1.a).

        let first_unsacked_range =
            sack_scoreboard.first_unsacked_range_from(snd_una.latest(*high_rxt));

        if let Some(first_hole) = &first_unsacked_range {
            // Meta is the IsLost value.
            if *first_hole.meta() {
                let hole_size = first_hole.len();
                let congestion_limit = congestion_limit.min(hole_size);
                *high_rxt = first_hole.start() + congestion_limit;

                // If we haven't set RescueRxt yet, set it to prevent eager
                // rescue. From RFC 6675:
                //
                //  Retransmit the first data segment presumed dropped --
                //  the segment starting with sequence number HighACK + 1.
                //  To prevent repeated retransmission of the same data or a
                //  premature rescue retransmission, set both HighRxt and
                //  RescueRxt to the highest sequence number in the
                //  retransmitted segment.
                if rescue_rxt.is_none() {
                    *rescue_rxt = Some(*high_rxt);
                }

                return Some(CongestionControlSendOutcome {
                    next_seg: first_hole.start(),
                    congestion_limit,
                    congestion_window,
                    loss_recovery: LossRecoverySegment::Yes {
                        rearm_retransmit: rearm_retransmit(first_hole.start()),
                        mode: LossRecoveryMode::SackRecovery,
                    },
                });
            }
        }

        // Run next rule, from RFC 6675:
        //
        // (2) If no sequence number 'S2' per rule (1) exists but there
        // exists available unsent data and the receiver's advertised window
        // allows, the sequence range of one segment of up to SMSS octets of
        // previously unsent data starting with sequence number HighData+1
        // MUST be returned.
        //
        // Both `available_bytes` and `snd_wnd` are measured from SND.UNA, so
        // they are compared against the amount already sent past SND.UNA.
        let total_sent = u32::try_from(snd_nxt - snd_una).unwrap();
        if available_bytes > usize::try_from(total_sent).unwrap() && u32::from(snd_wnd) > total_sent
        {
            return Some(CongestionControlSendOutcome {
                next_seg: snd_nxt,
                // We only need to send out the congestion limit, the window
                // limit is applied by the sender state machine.
                congestion_limit,
                congestion_window,
                // NB: even though we're sending new bytes, we're still
                // signaling that we're in loss recovery. Our goal here is
                // to keep the ACK clock running and prevent an RTO, so we
                // don't want this segment to be delayed by anything.
                loss_recovery: LossRecoverySegment::Yes {
                    rearm_retransmit: false,
                    mode: LossRecoveryMode::SackRecovery,
                },
            });
        }

        // Run next rule, from RFC 6675:
        //
        //  (3) If the conditions for rules (1) and (2) fail, but there
        //  exists an unSACKed sequence number 'S3' that meets the criteria
        //  for detecting loss given in steps (1.a) and (1.b) above
        //  (specifically excluding step (1.c)), then one segment of up to
        //  SMSS octets starting with S3 SHOULD be returned.
        if let Some(first_hole) = first_unsacked_range {
            let hole_size = first_hole.len();
            let congestion_limit = congestion_limit.min(hole_size);
            *high_rxt = first_hole.start() + congestion_limit;

            return Some(CongestionControlSendOutcome {
                next_seg: first_hole.start(),
                congestion_limit,
                congestion_window,
                loss_recovery: LossRecoverySegment::Yes {
                    rearm_retransmit: rearm_retransmit(first_hole.start()),
                    mode: LossRecoveryMode::SackRecovery,
                },
            });
        }

        // Run next rule, from RFC 6675:
        //
        //  (4) If the conditions for (1), (2), and (3) fail, but there
        //  exists outstanding unSACKed data, we provide the opportunity for
        //  a single "rescue" retransmission per entry into loss recovery.
        //  If HighACK is greater than RescueRxt (or RescueRxt is
        //  undefined), then one segment of up to SMSS octets that MUST
        //  include the highest outstanding unSACKed sequence number SHOULD
        //  be returned, and RescueRxt set to RecoveryPoint. HighRxt MUST
        //  NOT be updated.
        //
        // NOTE(review): `rescue_rxt` holds either one past the last
        // retransmitted octet (set by rule 1) or `recovery_point` (SND.NXT at
        // recovery entry), while the RFC's RescueRxt names the last octet
        // itself; confirm that `after_or_eq` against SND.UNA matches the RFC's
        // "HighACK is greater than RescueRxt" under that offset.
        if rescue_rxt.is_none_or(|rescue_rxt| snd_una.after_or_eq(rescue_rxt)) {
            if let Some(right_edge) = sack_scoreboard.right_edge() {
                let left = right_edge.latest(snd_nxt - congestion_limit);
                // This can't send any new data, so figure out how much space we
                // have left. If SND.NXT got rewound and is now before the right
                // edge, unwrap the calculation to zero to avoid sending the
                // rescue segment.
                let congestion_limit = u32::try_from(snd_nxt - left).unwrap_or(0);
                if congestion_limit > 0 {
                    *rescue_rxt = Some(*recovery_point);
                    return Some(CongestionControlSendOutcome {
                        next_seg: left,
                        congestion_limit,
                        congestion_window,
                        // NB: Rescue retransmissions can only happen once per
                        // entry into recovery, so always rearm the RTO.
                        loss_recovery: LossRecoverySegment::Yes {
                            rearm_retransmit: true,
                            mode: LossRecoveryMode::SackRecovery,
                        },
                    });
                }
            }
        }

        None
    }
1030}
1031
/// The value returned by [`SackRecovery::on_dup_ack`].
///
/// Wraps a boolean that is `true` iff the duplicate ACK caused SACK loss
/// recovery to start. Callers consume it with [`SackDupAckOutcome::apply`].
/// That method returns the boolean and, when it is `true`, notifies the
/// loss-based congestion control algorithm so it can update the congestion
/// parameters.
///
/// This is its own type so [`SackRecovery::on_dup_ack`] can be tested in
/// isolation from [`LossBasedAlgorithm`].
#[derive(Debug)]
#[cfg_attr(test, derive(Eq, PartialEq))]
struct SackDupAckOutcome(bool);
1044
1045impl SackDupAckOutcome {
1046    /// Consumes this outcome, notifying `algorithm` that loss was detected if
1047    /// needed.
1048    ///
1049    /// Returns the inner boolean indicating whether loss recovery started.
1050    fn apply<I: Instant>(
1051        self,
1052        params: &mut CongestionControlParams,
1053        algorithm: &mut LossBasedAlgorithm<I>,
1054    ) -> bool {
1055        let Self(loss_recovery) = self;
1056        if loss_recovery {
1057            algorithm.on_loss_detected(params);
1058        }
1059        loss_recovery
1060    }
1061}
1062
1063#[cfg(test)]
1064mod test {
1065    use core::ops::Range;
1066
1067    use assert_matches::assert_matches;
1068    use netstack3_base::testutil::FakeInstant;
1069    use netstack3_base::SackBlock;
1070    use test_case::{test_case, test_matrix};
1071
1072    use super::*;
1073    use crate::internal::testutil::{
1074        self, DEFAULT_IPV4_MAXIMUM_SEGMENT_SIZE, DEFAULT_IPV6_MAXIMUM_SEGMENT_SIZE,
1075    };
1076
1077    const MSS_1: Mss = DEFAULT_IPV4_MAXIMUM_SEGMENT_SIZE;
1078    const MSS_2: Mss = DEFAULT_IPV6_MAXIMUM_SEGMENT_SIZE;
1079
1080    enum StartingAck {
1081        One,
1082        Wraparound,
1083        WraparoundAfter(u32),
1084    }
1085
1086    impl StartingAck {
1087        fn into_seqnum(self, mss: Mss) -> SeqNum {
1088            let mss = u32::from(mss);
1089            match self {
1090                StartingAck::One => SeqNum::new(1),
1091                StartingAck::Wraparound => SeqNum::new((mss / 2).wrapping_sub(mss)),
1092                StartingAck::WraparoundAfter(n) => SeqNum::new((mss / 2).wrapping_sub(n * mss)),
1093            }
1094        }
1095    }
1096
1097    impl SackRecovery {
1098        #[track_caller]
1099        fn assert_in_recovery(&mut self) -> &mut SackInRecoveryState {
1100            assert_matches!(&mut self.recovery, SackRecoveryState::InRecovery(s) => s)
1101        }
1102    }
1103
1104    impl<I> CongestionControl<I> {
1105        #[track_caller]
1106        fn assert_sack_recovery(&mut self) -> &mut SackRecovery {
1107            assert_matches!(&mut self.loss_recovery, Some(LossRecovery::SackRecovery(s)) => s)
1108        }
1109    }
1110
1111    fn nth_segment_from(base: SeqNum, mss: Mss, n: u32) -> Range<SeqNum> {
1112        let mss = u32::from(mss);
1113        let start = base + n * mss;
1114        Range { start, end: start + mss }
1115    }
1116
1117    fn nth_range(base: SeqNum, mss: Mss, range: Range<u32>) -> Range<SeqNum> {
1118        let mss = u32::from(mss);
1119        let Range { start, end } = range;
1120        let start = base + start * mss;
1121        let end = base + end * mss;
1122        Range { start, end }
1123    }
1124
1125    #[test]
1126    fn no_recovery_before_reaching_threshold() {
1127        let mut congestion_control =
1128            CongestionControl::cubic_with_mss(DEFAULT_IPV4_MAXIMUM_SEGMENT_SIZE);
1129        let old_cwnd = congestion_control.params.cwnd;
1130        assert_eq!(congestion_control.params.ssthresh, u32::MAX);
1131        assert_eq!(congestion_control.on_dup_ack(SeqNum::new(0), SeqNum::new(1)), None);
1132        assert!(!congestion_control.on_ack(
1133            SeqNum::new(1),
1134            NonZeroU32::new(1).unwrap(),
1135            FakeInstant::from(Duration::from_secs(0)),
1136            Some(Duration::from_secs(1)),
1137        ));
1138        // We have only received one duplicate ack, receiving a new ACK should
1139        // not mean "loss recovery" - we should not bump our cwnd to initial
1140        // ssthresh (u32::MAX) and then overflow.
1141        assert_eq!(old_cwnd + 1, congestion_control.params.cwnd);
1142    }
1143
1144    #[test]
1145    fn preprocess_ack_result() {
1146        let ack = SeqNum::new(1);
1147        let snd_nxt = SeqNum::new(100);
1148        let mut congestion_control =
1149            CongestionControl::<FakeInstant>::cubic_with_mss(DEFAULT_IPV4_MAXIMUM_SEGMENT_SIZE);
1150        assert_eq!(congestion_control.preprocess_ack(ack, snd_nxt, &SackBlocks::EMPTY), None);
1151        assert_eq!(
1152            congestion_control.preprocess_ack(ack, snd_nxt, &testutil::sack_blocks([10..20])),
1153            Some(true)
1154        );
1155        assert_eq!(congestion_control.preprocess_ack(ack, snd_nxt, &SackBlocks::EMPTY), None);
1156        assert_eq!(
1157            congestion_control.preprocess_ack(ack, snd_nxt, &testutil::sack_blocks([10..20])),
1158            Some(false)
1159        );
1160        assert_eq!(
1161            congestion_control.preprocess_ack(
1162                ack,
1163                snd_nxt,
1164                &testutil::sack_blocks([10..20, 20..30])
1165            ),
1166            Some(true)
1167        );
1168    }
1169
1170    #[test_case(DUP_ACK_THRESHOLD-1; "no loss")]
1171    #[test_case(DUP_ACK_THRESHOLD; "exact threshold")]
1172    #[test_case(DUP_ACK_THRESHOLD+1; "over threshold")]
1173    fn sack_recovery_enter_exit_loss_dupacks(dup_acks: u8) {
1174        let mut congestion_control =
1175            CongestionControl::cubic_with_mss(DEFAULT_IPV4_MAXIMUM_SEGMENT_SIZE);
1176        let mss = congestion_control.mss();
1177
1178        let ack = SeqNum::new(1);
1179        let snd_nxt = nth_segment_from(ack, mss, 10).end;
1180
1181        let expect_recovery =
1182            SackInRecoveryState { recovery_point: snd_nxt, high_rxt: ack, rescue_rxt: None };
1183
1184        let mut sack = SackBlock::try_from(nth_segment_from(ack, mss, 1)).unwrap();
1185        for n in 1..=dup_acks {
1186            assert_eq!(
1187                congestion_control.preprocess_ack(ack, snd_nxt, &[sack].into_iter().collect()),
1188                Some(true)
1189            );
1190            assert_eq!(
1191                congestion_control.on_dup_ack(ack, snd_nxt),
1192                (n == DUP_ACK_THRESHOLD).then_some(LossRecoveryMode::SackRecovery)
1193            );
1194            let sack_recovery = congestion_control.assert_sack_recovery();
1195            // We stop counting duplicate acks after the threshold.
1196            assert_eq!(sack_recovery.dup_acks, n.min(DUP_ACK_THRESHOLD));
1197
1198            let expect_recovery = if n >= DUP_ACK_THRESHOLD {
1199                SackRecoveryState::InRecovery(expect_recovery.clone())
1200            } else {
1201                SackRecoveryState::NotInRecovery
1202            };
1203            assert_eq!(congestion_control.assert_sack_recovery().recovery, expect_recovery);
1204
1205            let (start, end) = sack.into_parts();
1206            // Don't increase by full MSS to prove that duplicate ACKs alone are
1207            // putting us in this state.
1208            sack = SackBlock::try_new(start, end + u32::from(mss) / 4).unwrap();
1209        }
1210
1211        let end = sack.right();
1212        let bytes_acked = NonZeroU32::new(u32::try_from(end - ack).unwrap()).unwrap();
1213        let ack = end;
1214        assert_eq!(congestion_control.preprocess_ack(ack, snd_nxt, &SackBlocks::EMPTY), None);
1215
1216        let now = FakeInstant::default();
1217        let rtt = Some(Duration::from_millis(1));
1218
1219        // A cumulative ACK not covering the recovery point arrives.
1220        assert_eq!(congestion_control.on_ack(ack, bytes_acked, now, rtt), false);
1221        if dup_acks >= DUP_ACK_THRESHOLD {
1222            assert_eq!(
1223                congestion_control.assert_sack_recovery().recovery,
1224                SackRecoveryState::InRecovery(expect_recovery)
1225            );
1226        } else {
1227            assert_matches!(congestion_control.loss_recovery, None);
1228        }
1229
1230        // A cumulative ACK covering the recovery point arrives.
1231        let bytes_acked = NonZeroU32::new(u32::try_from(snd_nxt - ack).unwrap()).unwrap();
1232        let ack = snd_nxt;
1233        assert_eq!(
1234            congestion_control.on_ack(ack, bytes_acked, now, rtt),
1235            dup_acks >= DUP_ACK_THRESHOLD
1236        );
1237        assert_matches!(congestion_control.loss_recovery, None);
1238
1239        // A later cumulative ACK arrives.
1240        let snd_nxt = snd_nxt + 20;
1241        let ack = ack + 10;
1242        assert_eq!(congestion_control.preprocess_ack(ack, snd_nxt, &SackBlocks::EMPTY), None);
1243        assert_eq!(congestion_control.on_ack(ack, bytes_acked, now, rtt), false);
1244        assert_matches!(congestion_control.loss_recovery, None);
1245    }
1246
1247    #[test]
1248    fn sack_recovery_enter_loss_single_dupack() {
1249        let mut congestion_control =
1250            CongestionControl::<FakeInstant>::cubic_with_mss(DEFAULT_IPV4_MAXIMUM_SEGMENT_SIZE);
1251
1252        // SACK can enter recovery after a *single* duplicate ACK provided
1253        // enough information is in the scoreboard:
1254        let snd_nxt = SeqNum::new(100);
1255        let ack = SeqNum::new(0);
1256        assert_eq!(
1257            congestion_control.preprocess_ack(
1258                ack,
1259                snd_nxt,
1260                &testutil::sack_blocks([5..15, 25..35, 45..55])
1261            ),
1262            Some(true)
1263        );
1264        assert_eq!(
1265            congestion_control.on_dup_ack(ack, snd_nxt),
1266            Some(LossRecoveryMode::SackRecovery)
1267        );
1268        assert_eq!(
1269            congestion_control.assert_sack_recovery().recovery,
1270            SackRecoveryState::InRecovery(SackInRecoveryState {
1271                recovery_point: snd_nxt,
1272                high_rxt: ack,
1273                rescue_rxt: None
1274            })
1275        );
1276    }
1277
1278    #[test]
1279    fn sack_recovery_poll_send_not_recovery() {
1280        let mut scoreboard = SackScoreboard::default();
1281        let mut recovery = SackRecovery::new();
1282        let mss = DEFAULT_IPV4_MAXIMUM_SEGMENT_SIZE;
1283        let cwnd_mss = 10u32;
1284        let cwnd = CongestionWindow::new(cwnd_mss * u32::from(mss), mss);
1285        let snd_una = SeqNum::new(1);
1286
1287        let in_flight = 5u32;
1288        let snd_nxt = nth_segment_from(snd_una, mss, in_flight).start;
1289
1290        // When not in recovery, we delegate all the receiver window calculation
1291        // out. Prove that that's the case by telling SACK there's nothing to
1292        // send.
1293        let snd_wnd = WindowSize::ZERO;
1294        let available_bytes = 0;
1295
1296        let sack_block = SackBlock::try_from(nth_segment_from(snd_una, mss, 1)).unwrap();
1297        assert!(scoreboard.process_ack(
1298            snd_una,
1299            snd_nxt,
1300            None,
1301            &[sack_block].into_iter().collect(),
1302            mss
1303        ));
1304
1305        // With 1 SACK block this is how much window we expect to have available
1306        // in multiples of mss.
1307        let wnd_used = in_flight - 1;
1308        assert_eq!(scoreboard.pipe(), wnd_used * u32::from(mss));
1309        let wnd_available = cwnd_mss - wnd_used;
1310
1311        for i in 0..wnd_available {
1312            let snd_nxt = snd_nxt + i * u32::from(mss);
1313            assert_eq!(
1314                recovery.poll_send(cwnd, snd_una, snd_nxt, snd_wnd, available_bytes, &scoreboard),
1315                Some(CongestionControlSendOutcome {
1316                    next_seg: snd_nxt,
1317                    congestion_limit: mss.into(),
1318                    congestion_window: cwnd.cwnd(),
1319                    loss_recovery: LossRecoverySegment::No,
1320                })
1321            );
1322            scoreboard.increment_pipe(mss.into());
1323        }
1324
1325        // Used all of the window.
1326        assert_eq!(scoreboard.pipe(), cwnd.cwnd());
1327        // Poll send stops this round.
1328        assert_eq!(
1329            recovery.poll_send(
1330                cwnd,
1331                snd_una,
1332                snd_nxt + (wnd_used + 1) * u32::from(mss),
1333                snd_wnd,
1334                available_bytes,
1335                &scoreboard
1336            ),
1337            None
1338        );
1339    }
1340
1341    #[test_matrix(
1342        [MSS_1, MSS_2],
1343        [1, 3, 5],
1344        [StartingAck::One, StartingAck::Wraparound]
1345    )]
1346    fn sack_recovery_next_seg_rule_1(mss: Mss, lost_segments: u32, snd_una: StartingAck) {
1347        let mut scoreboard = SackScoreboard::default();
1348        let mut recovery = SackRecovery::new();
1349
1350        let snd_una = snd_una.into_seqnum(mss);
1351
1352        let sacked_segments = u32::from(DUP_ACK_THRESHOLD);
1353        let sacked_range = lost_segments..(lost_segments + sacked_segments);
1354        let in_flight = sacked_range.end + 5;
1355        let snd_nxt = nth_segment_from(snd_una, mss, in_flight).start;
1356
1357        // Define a congestion window that will only let us fill part of the
1358        // lost segments with rule 1.
1359        let cwnd_mss = in_flight - sacked_segments - 1;
1360        let cwnd = CongestionWindow::new(cwnd_mss * u32::from(mss), mss);
1361
1362        // Rule 1 should not care about available window size, since it's
1363        // retransmitting a lost segment.
1364        let snd_wnd = WindowSize::ZERO;
1365        let available_bytes = 0;
1366
1367        let sack_block = SackBlock::try_from(nth_range(snd_una, mss, sacked_range)).unwrap();
1368        assert!(scoreboard.process_ack(
1369            snd_una,
1370            snd_nxt,
1371            None,
1372            &[sack_block].into_iter().collect(),
1373            mss
1374        ));
1375
1376        // Verify that our set up math here is correct, we want recovery to be
1377        // able to fill only a part of the hole.
1378        assert_eq!(cwnd.cwnd() - scoreboard.pipe(), (lost_segments - 1) * u32::from(mss));
1379        // Enter recovery.
1380        assert_eq!(recovery.on_dup_ack(snd_una, snd_nxt, &scoreboard), SackDupAckOutcome(true));
1381
1382        for i in 0..(lost_segments - 1) {
1383            let next_seg = snd_una + i * u32::from(mss);
1384            assert_eq!(
1385                recovery.poll_send(cwnd, snd_una, snd_nxt, snd_wnd, available_bytes, &scoreboard),
1386                Some(CongestionControlSendOutcome {
1387                    next_seg,
1388                    congestion_limit: mss.into(),
1389                    congestion_window: cwnd.cwnd(),
1390                    loss_recovery: LossRecoverySegment::Yes {
1391                        rearm_retransmit: true,
1392                        mode: LossRecoveryMode::SackRecovery,
1393                    },
1394                })
1395            );
1396            scoreboard.increment_pipe(mss.into());
1397            assert_eq!(
1398                recovery.recovery,
1399                SackRecoveryState::InRecovery(SackInRecoveryState {
1400                    recovery_point: snd_nxt,
1401                    high_rxt: nth_segment_from(snd_una, mss, i).end,
1402                    // RescueRxt is always set to the first retransmitted
1403                    // segment.
1404                    rescue_rxt: Some(snd_una + u32::from(mss)),
1405                })
1406            );
1407        }
1408        // Ran out of CWND.
1409        assert_eq!(
1410            recovery.poll_send(cwnd, snd_una, snd_nxt, snd_wnd, available_bytes, &scoreboard),
1411            None
1412        );
1413    }
1414
1415    #[test_matrix(
1416        [MSS_1, MSS_2],
1417        [1, 3, 5],
1418        [StartingAck::One, StartingAck::Wraparound]
1419    )]
1420    fn sack_recovery_next_seg_rule_2(mss: Mss, expect_send: u32, snd_una: StartingAck) {
1421        let mut scoreboard = SackScoreboard::default();
1422        let mut recovery = SackRecovery::new();
1423
1424        let snd_una = snd_una.into_seqnum(mss);
1425
1426        let lost_segments = 1;
1427        let sacked_segments = u32::from(DUP_ACK_THRESHOLD);
1428        let sacked_range = lost_segments..(lost_segments + sacked_segments);
1429        let in_flight = sacked_range.end + 5;
1430        let mut snd_nxt = nth_segment_from(snd_una, mss, in_flight).start;
1431
1432        let sack_block = SackBlock::try_from(nth_range(snd_una, mss, sacked_range)).unwrap();
1433        assert!(scoreboard.process_ack(
1434            snd_una,
1435            snd_nxt,
1436            None,
1437            &[sack_block].into_iter().collect(),
1438            mss
1439        ));
1440
1441        // Define a congestion window that will allow us to send only the
1442        // desired segments.
1443        let cwnd = CongestionWindow::new(scoreboard.pipe() + expect_send * u32::from(mss), mss);
1444        // Enter recovery.
1445        assert_eq!(recovery.on_dup_ack(snd_una, snd_nxt, &scoreboard), SackDupAckOutcome(true));
1446        // Force HighRxt to the end of the lost block to skip rules 1 and 3.
1447        let recovery_state = recovery.assert_in_recovery();
1448        recovery_state.high_rxt = nth_segment_from(snd_una, mss, lost_segments - 1).end;
1449        // Force RecoveryRxt to skip rule 4.
1450        recovery_state.rescue_rxt = Some(snd_nxt);
1451        let state_snapshot = recovery_state.clone();
1452
1453        // Available bytes is always counted from SND.UNA.
1454        let baseline = u32::try_from(snd_nxt - snd_una).unwrap();
1455        // If there is no window or nothing to send, return.
1456        for (snd_wnd, available_bytes) in [(0, 0), (1, 0), (0, 1)] {
1457            let snd_wnd = WindowSize::from_u32(baseline + snd_wnd).unwrap();
1458            let available_bytes = usize::try_from(baseline + available_bytes).unwrap();
1459            assert_eq!(
1460                recovery.poll_send(cwnd, snd_una, snd_nxt, snd_wnd, available_bytes, &scoreboard),
1461                None
1462            );
1463            assert_eq!(recovery.recovery, SackRecoveryState::InRecovery(state_snapshot));
1464        }
1465
1466        let baseline = baseline + (expect_send - 1) * u32::from(mss) + 1;
1467        let snd_wnd = WindowSize::from_u32(baseline).unwrap();
1468        let available_bytes = usize::try_from(baseline).unwrap();
1469        for _ in 0..expect_send {
1470            assert_eq!(
1471                recovery.poll_send(cwnd, snd_una, snd_nxt, snd_wnd, available_bytes, &scoreboard),
1472                Some(CongestionControlSendOutcome {
1473                    next_seg: snd_nxt,
1474                    congestion_limit: mss.into(),
1475                    congestion_window: cwnd.cwnd(),
1476                    loss_recovery: LossRecoverySegment::Yes {
1477                        rearm_retransmit: false,
1478                        mode: LossRecoveryMode::SackRecovery,
1479                    },
1480                })
1481            );
1482            assert_eq!(recovery.recovery, SackRecoveryState::InRecovery(state_snapshot));
1483            scoreboard.increment_pipe(mss.into());
1484            snd_nxt = snd_nxt + u32::from(mss);
1485        }
1486        // Ran out of CWND.
1487        let snd_wnd = WindowSize::MAX;
1488        let available_bytes = usize::MAX;
1489        assert_eq!(
1490            recovery.poll_send(cwnd, snd_una, snd_nxt, snd_wnd, available_bytes, &scoreboard),
1491            None
1492        );
1493        assert_eq!(recovery.recovery, SackRecoveryState::InRecovery(state_snapshot));
1494    }
1495
    // Exercises NextSeg rule 3: retransmitting unSACKed segments that the
    // scoreboard does not consider lost. Segment layout, as indices relative
    // to SND.UNA:
    //
    //   [0]                                 lost, picked by rule 1
    //   [1, 1 + DUP_ACK_THRESHOLD)          SACKed
    //   next `not_lost_segments` segments   unSACKed and not lost, rule 3
    //   next segment                        SACKed
    //   last 5 segments                     unSACKed, in flight up to SND.NXT
    #[test_matrix(
        [MSS_1, MSS_2],
        [1, 3, 5],
        [StartingAck::One, StartingAck::Wraparound]
    )]
    fn sack_recovery_next_seg_rule_3(mss: Mss, not_lost_segments: u32, snd_una: StartingAck) {
        let mut scoreboard = SackScoreboard::default();
        let mut recovery = SackRecovery::new();

        let snd_una = snd_una.into_seqnum(mss);

        // A single lost segment at SND.UNA, followed by DUP_ACK_THRESHOLD
        // SACKed segments so that the scoreboard declares it lost.
        let first_lost_block = 1;
        let first_sacked_segments = u32::from(DUP_ACK_THRESHOLD);
        let first_sacked_range = first_lost_block..(first_lost_block + first_sacked_segments);

        // The "not_lost_segments" segments after the first SACK block are left
        // unSACKed, but the scoreboard will not consider them lost: only a
        // single SACKed segment follows them.
        let sacked_segments = 1;
        let sacked_range_start = first_sacked_range.end + not_lost_segments;
        let sacked_range = sacked_range_start..(sacked_range_start + sacked_segments);

        let in_flight = sacked_range.end + 5;
        let snd_nxt = nth_segment_from(snd_una, mss, in_flight).start;

        let sack_block1 =
            SackBlock::try_from(nth_range(snd_una, mss, first_sacked_range.clone())).unwrap();
        let sack_block2 = SackBlock::try_from(nth_range(snd_una, mss, sacked_range)).unwrap();
        assert!(scoreboard.process_ack(
            snd_una,
            snd_nxt,
            None,
            &[sack_block1, sack_block2].into_iter().collect(),
            mss
        ));

        // Define a congestion window that leaves room for exactly
        // `expect_send` rule 3 retransmissions. That is only part of the
        // not-lost hole, except when `not_lost_segments` is 1, where it is the
        // whole hole.
        let expect_send = (not_lost_segments - 1).max(1);
        let cwnd_mss =
            in_flight - first_sacked_segments - sacked_segments - first_lost_block + expect_send;
        let cwnd = CongestionWindow::new(cwnd_mss * u32::from(mss), mss);

        // Rule 3 is only hit if we don't have enough available data to send,
        // so close the send window and leave nothing new to send.
        let snd_wnd = WindowSize::ZERO;
        let available_bytes = 0;

        // Verify the setup math: the room left in cwnd beyond pipe must be
        // exactly `expect_send` segments.
        assert_eq!(cwnd.cwnd() - scoreboard.pipe(), expect_send * u32::from(mss));
        // Enter recovery.
        assert_eq!(recovery.on_dup_ack(snd_una, snd_nxt, &scoreboard), SackDupAckOutcome(true));
        // Poll while we expect to hit rule 1. Pipe is deliberately not
        // incremented here, so the whole cwnd budget computed above is left
        // for rule 3.
        for i in 0..first_lost_block {
            assert_eq!(
                recovery.poll_send(cwnd, snd_una, snd_nxt, snd_wnd, available_bytes, &scoreboard),
                Some(CongestionControlSendOutcome {
                    next_seg: nth_segment_from(snd_una, mss, i).start,
                    congestion_limit: mss.into(),
                    congestion_window: cwnd.cwnd(),
                    loss_recovery: LossRecoverySegment::Yes {
                        rearm_retransmit: true,
                        mode: LossRecoveryMode::SackRecovery,
                    },
                })
            );
        }
        // After rule 1, HighRxt points at the end of the lost segment, and
        // RescueRxt has been set from the first retransmitted segment.
        let expect_recovery = SackInRecoveryState {
            recovery_point: snd_nxt,
            high_rxt: nth_segment_from(snd_una, mss, first_sacked_range.start).start,
            rescue_rxt: Some(nth_segment_from(snd_una, mss, 0).end),
        };
        assert_eq!(recovery.recovery, SackRecoveryState::InRecovery(expect_recovery));

        // Rule 3 retransmits the not-lost hole right after the first SACK
        // block, one segment per poll. Each send advances HighRxt by one
        // segment and consumes one segment of cwnd through pipe.
        for i in 0..expect_send {
            let next_seg = snd_una + (first_sacked_range.end + i) * u32::from(mss);
            assert_eq!(
                recovery.poll_send(cwnd, snd_una, snd_nxt, snd_wnd, available_bytes, &scoreboard),
                Some(CongestionControlSendOutcome {
                    next_seg,
                    congestion_limit: mss.into(),
                    congestion_window: cwnd.cwnd(),
                    loss_recovery: LossRecoverySegment::Yes {
                        rearm_retransmit: true,
                        mode: LossRecoveryMode::SackRecovery,
                    },
                })
            );
            scoreboard.increment_pipe(mss.into());
            assert_eq!(
                recovery.recovery,
                SackRecoveryState::InRecovery(SackInRecoveryState {
                    high_rxt: next_seg + u32::from(mss),
                    ..expect_recovery
                })
            );
        }
        // Ran out of CWND.
        assert_eq!(
            recovery.poll_send(cwnd, snd_una, snd_nxt, snd_wnd, available_bytes, &scoreboard),
            None
        );
    }
1599
    // Exercises NextSeg rule 4 (the rescue retransmission). Initial segment
    // layout, as indices relative to the initial SND.UNA:
    //
    //   [0]                                 lost, picked by rule 1
    //   [1, 1 + DUP_ACK_THRESHOLD)          SACKed
    //   next `right_edge_segments + 2`      unSACKed, in flight up to SND.NXT
    //
    // A later ACK moves SND.UNA past the SACKed range and SACKs one more
    // segment. This leaves one hole at the new SND.UNA for rule 3, plus
    // `right_edge_segments` unSACKed segments above the new SACK block. Rule 4
    // is only expected to fire when that right edge is non-empty.
    #[test_matrix(
        [MSS_1, MSS_2],
        [0, 1, 3],
        [StartingAck::One, StartingAck::Wraparound]
    )]
    fn sack_recovery_next_seg_rule_4(mss: Mss, right_edge_segments: u32, snd_una: StartingAck) {
        let mut scoreboard = SackScoreboard::default();
        let mut recovery = SackRecovery::new();

        let snd_una = snd_una.into_seqnum(mss);

        let lost_segments = 1;
        let sacked_segments = u32::from(DUP_ACK_THRESHOLD);
        let sacked_range = lost_segments..(lost_segments + sacked_segments);
        let in_flight = sacked_range.end + right_edge_segments + 2;
        let snd_nxt = nth_segment_from(snd_una, mss, in_flight).start;

        // Rule 4 should only be hit if we don't have available data to send.
        let snd_wnd = WindowSize::ZERO;
        let available_bytes = 0;

        let sack_block =
            SackBlock::try_from(nth_range(snd_una, mss, sacked_range.clone())).unwrap();
        assert!(scoreboard.process_ack(
            snd_una,
            snd_nxt,
            None,
            &[sack_block].into_iter().collect(),
            mss
        ));

        // Define a very large congestion window, given rule 4 should only
        // retransmit a single segment.
        let cwnd = CongestionWindow::new((in_flight + 500) * u32::from(mss), mss);

        // Enter recovery.
        assert_eq!(recovery.on_dup_ack(snd_una, snd_nxt, &scoreboard), SackDupAckOutcome(true));
        // Send the segments that match rule 1. Don't increment pipe here; we
        // want to show that rule 4 stops even when cwnd is entirely open.
        for i in 0..lost_segments {
            let next_seg = snd_una + i * u32::from(mss);
            assert_eq!(
                recovery.poll_send(cwnd, snd_una, snd_nxt, snd_wnd, available_bytes, &scoreboard),
                Some(CongestionControlSendOutcome {
                    next_seg,
                    congestion_limit: mss.into(),
                    congestion_window: cwnd.cwnd(),
                    loss_recovery: LossRecoverySegment::Yes {
                        rearm_retransmit: true,
                        mode: LossRecoveryMode::SackRecovery,
                    },
                })
            );
        }
        let expect_recovery = SackInRecoveryState {
            recovery_point: snd_nxt,
            high_rxt: nth_segment_from(snd_una, mss, lost_segments).start,
            // RescueRxt is always set to the first retransmitted
            // segment.
            rescue_rxt: Some(nth_segment_from(snd_una, mss, 0).end),
        };
        assert_eq!(recovery.recovery, SackRecoveryState::InRecovery(expect_recovery));

        // Rule 4 should only hit after we receive an ACK past the first
        // RescueRxt value that was set, so nothing is sent yet.
        assert_eq!(
            recovery.poll_send(cwnd, snd_una, snd_nxt, snd_wnd, available_bytes, &scoreboard),
            None
        );
        // Acknowledge up to the end of the SACKed range, and SACK the second
        // segment after the new SND.UNA. The first segment after SND.UNA
        // becomes a hole that the scoreboard does not consider lost.
        let snd_una = nth_segment_from(snd_una, mss, sacked_range.end).start;
        let sack_block = SackBlock::try_from(nth_range(snd_una, mss, 1..2)).unwrap();
        assert!(scoreboard.process_ack(
            snd_una,
            snd_nxt,
            Some(expect_recovery.high_rxt),
            &[sack_block].into_iter().collect(),
            mss
        ));
        // The ACK does not end recovery (it is below the recovery point) and
        // leaves the recovery state untouched.
        assert_eq!(recovery.on_ack(snd_una), LossRecoveryOnAckOutcome::None);
        assert_eq!(recovery.recovery, SackRecoveryState::InRecovery(expect_recovery));
        // Rule 3 will hit once here because we have a single not lost segment.
        assert_eq!(
            recovery.poll_send(cwnd, snd_una, snd_nxt, snd_wnd, available_bytes, &scoreboard),
            Some(CongestionControlSendOutcome {
                next_seg: snd_una,
                congestion_limit: mss.into(),
                congestion_window: cwnd.cwnd(),
                loss_recovery: LossRecoverySegment::Yes {
                    rearm_retransmit: true,
                    mode: LossRecoveryMode::SackRecovery,
                },
            })
        );
        let expect_recovery =
            SackInRecoveryState { high_rxt: snd_una + u32::from(mss), ..expect_recovery };
        assert_eq!(recovery.recovery, SackRecoveryState::InRecovery(expect_recovery));

        // Now we should hit Rule 4, as long as there is unSACKed data above
        // the last SACK block. It retransmits the final segment before SND.NXT
        // and moves RescueRxt to the recovery point.
        if right_edge_segments > 0 {
            assert_eq!(
                recovery.poll_send(cwnd, snd_una, snd_nxt, snd_wnd, available_bytes, &scoreboard),
                Some(CongestionControlSendOutcome {
                    next_seg: snd_nxt - u32::from(mss),
                    congestion_limit: mss.into(),
                    congestion_window: cwnd.cwnd(),
                    loss_recovery: LossRecoverySegment::Yes {
                        rearm_retransmit: true,
                        mode: LossRecoveryMode::SackRecovery,
                    },
                })
            );
            assert_eq!(
                recovery.recovery,
                SackRecoveryState::InRecovery(SackInRecoveryState {
                    rescue_rxt: Some(expect_recovery.recovery_point),
                    ..expect_recovery
                })
            );
        }

        // Once we've done the rescue it can't happen again. With no right
        // edge, there was nothing to rescue in the first place.
        assert_eq!(
            recovery.poll_send(cwnd, snd_una, snd_nxt, snd_wnd, available_bytes, &scoreboard),
            None
        );
    }
1727
    // Drives a single recovery episode through NextSeg rules 1, 2, 3 and 4 in
    // that order, one poll per rule. It then checks that nothing else is sent
    // even though cwnd still has room. Segment layout, as indices relative to
    // SND.UNA:
    //
    //   [0]                                 lost, picked by rule 1
    //   [1, 1 + DUP_ACK_THRESHOLD)          SACKed
    //   [1 + DUP_ACK_THRESHOLD]             unSACKed and not lost, rule 3
    //   next segment                        SACKed
    //   next segment                        unSACKed, last before SND.NXT
    //
    // The `WraparoundAfter(n)` variants presumably place the sequence number
    // wraparound at different points within this layout (TODO confirm
    // against `StartingAck`).
    #[test_matrix(
        [MSS_1, MSS_2],
        [
            StartingAck::One,
            StartingAck::WraparoundAfter(1),
            StartingAck::WraparoundAfter(2),
            StartingAck::WraparoundAfter(3),
            StartingAck::WraparoundAfter(4)
        ]
    )]
    fn sack_recovery_all_rules(mss: Mss, snd_una: StartingAck) {
        let snd_una = snd_una.into_seqnum(mss);

        // Set up the scoreboard so we have 1 hole considered lost, that is hit
        // by Rule 1, and another that is not lost, hit by Rule 3.
        let mut scoreboard = SackScoreboard::default();
        let first_sacked_range = 1..(u32::from(DUP_ACK_THRESHOLD) + 1);
        let first_sack_block =
            SackBlock::try_from(nth_range(snd_una, mss, first_sacked_range.clone())).unwrap();

        let second_sacked_range = (first_sacked_range.end + 1)..(first_sacked_range.end + 2);
        let second_sack_block =
            SackBlock::try_from(nth_range(snd_una, mss, second_sacked_range.clone())).unwrap();

        let snd_nxt = nth_segment_from(snd_una, mss, second_sacked_range.end + 1).start;

        // To hit Rule 4 in one run, set up a recovery state that looks
        // like we've already tried to fill one hole with Rule 1 and
        // received an ack for it: RescueRxt is already at SND.UNA.
        let high_rxt = snd_una;
        let rescue_rxt = Some(snd_una);

        assert!(scoreboard.process_ack(
            snd_una,
            snd_nxt,
            Some(high_rxt),
            &[first_sack_block, second_sack_block].into_iter().collect(),
            mss
        ));

        // Create a situation where a single sequential round of calls to
        // poll_send will hit each rule.
        let recovery_state = SackInRecoveryState { recovery_point: snd_nxt, high_rxt, rescue_rxt };
        let mut recovery = SackRecovery {
            dup_acks: DUP_ACK_THRESHOLD,
            recovery: SackRecoveryState::InRecovery(recovery_state),
        };

        // Define a congestion window that allows sending a single segment.
        // We don't update the pipe variable at each call, so we should never
        // hit the congestion limit.
        let cwnd = CongestionWindow::new(scoreboard.pipe() + u32::from(mss), mss);

        // Make exactly one byte beyond SND.NXT available in the receiver
        // window and send buffer, so that we hit Rule 2 exactly once.
        let available = u32::try_from(snd_nxt - snd_una).unwrap() + 1;
        let snd_wnd = WindowSize::from_u32(available).unwrap();
        let available_bytes = usize::try_from(available).unwrap();

        // Hit Rule 1: retransmit the lost segment at SND.UNA.
        assert_eq!(
            recovery.poll_send(cwnd, snd_una, snd_nxt, snd_wnd, available_bytes, &scoreboard),
            Some(CongestionControlSendOutcome {
                next_seg: snd_una,
                congestion_limit: u32::from(mss),
                congestion_window: cwnd.cwnd(),
                loss_recovery: LossRecoverySegment::Yes {
                    rearm_retransmit: true,
                    mode: LossRecoveryMode::SackRecovery,
                },
            })
        );
        let recovery_state =
            SackInRecoveryState { high_rxt: snd_una + u32::from(mss), ..recovery_state };
        assert_eq!(recovery.recovery, SackRecoveryState::InRecovery(recovery_state));

        // Hit Rule 2: send new data at SND.NXT. This does not rearm the
        // retransmission timer.
        assert_eq!(
            recovery.poll_send(cwnd, snd_una, snd_nxt, snd_wnd, available_bytes, &scoreboard),
            Some(CongestionControlSendOutcome {
                next_seg: snd_nxt,
                congestion_limit: u32::from(mss),
                congestion_window: cwnd.cwnd(),
                loss_recovery: LossRecoverySegment::Yes {
                    rearm_retransmit: false,
                    mode: LossRecoveryMode::SackRecovery,
                },
            })
        );
        // snd_nxt should advance.
        let snd_nxt = snd_nxt + u32::from(mss);
        // No change to recovery state.
        assert_eq!(recovery.recovery, SackRecoveryState::InRecovery(recovery_state));

        // Hit Rule 3: retransmit the not-lost hole between the two SACK
        // blocks.
        assert_eq!(
            recovery.poll_send(cwnd, snd_una, snd_nxt, snd_wnd, available_bytes, &scoreboard),
            Some(CongestionControlSendOutcome {
                next_seg: nth_segment_from(snd_una, mss, first_sacked_range.end).start,
                congestion_limit: u32::from(mss),
                congestion_window: cwnd.cwnd(),
                loss_recovery: LossRecoverySegment::Yes {
                    rearm_retransmit: true,
                    mode: LossRecoveryMode::SackRecovery,
                },
            })
        );
        let recovery_state = SackInRecoveryState {
            high_rxt: nth_segment_from(snd_una, mss, second_sacked_range.start).start,
            ..recovery_state
        };
        assert_eq!(recovery.recovery, SackRecoveryState::InRecovery(recovery_state));

        // Hit Rule 4: rescue-retransmit the last segment before the (advanced)
        // SND.NXT, which moves RescueRxt to the recovery point.
        assert_eq!(
            recovery.poll_send(cwnd, snd_una, snd_nxt, snd_wnd, available_bytes, &scoreboard),
            Some(CongestionControlSendOutcome {
                next_seg: snd_nxt - u32::from(mss),
                congestion_limit: u32::from(mss),
                congestion_window: cwnd.cwnd(),
                loss_recovery: LossRecoverySegment::Yes {
                    rearm_retransmit: true,
                    mode: LossRecoveryMode::SackRecovery,
                },
            })
        );
        let recovery_state = SackInRecoveryState {
            rescue_rxt: Some(recovery_state.recovery_point),
            ..recovery_state
        };
        assert_eq!(recovery.recovery, SackRecoveryState::InRecovery(recovery_state));

        // All the rules have been hit. Nothing to send even though we still
        // have cwnd.
        assert_eq!(
            recovery.poll_send(cwnd, snd_una, snd_nxt, snd_wnd, available_bytes, &scoreboard),
            None
        );
        assert!(cwnd.cwnd() - scoreboard.pipe() >= u32::from(mss));
    }
1867
1868    #[test]
1869    fn sack_rto() {
1870        let mss = DEFAULT_IPV4_MAXIMUM_SEGMENT_SIZE;
1871        let mut congestion_control = CongestionControl::<FakeInstant>::cubic_with_mss(mss);
1872
1873        let rto_snd_nxt = SeqNum::new(50);
1874        // Set ourselves up not in recovery.
1875        congestion_control.loss_recovery = Some(LossRecovery::SackRecovery(SackRecovery {
1876            dup_acks: DUP_ACK_THRESHOLD - 1,
1877            recovery: SackRecoveryState::NotInRecovery,
1878        }));
1879        congestion_control.on_retransmission_timeout(rto_snd_nxt);
1880        assert_matches!(congestion_control.loss_recovery, None);
1881
1882        // Set ourselves up in loss recovery.
1883        congestion_control.loss_recovery = Some(LossRecovery::SackRecovery(SackRecovery {
1884            dup_acks: DUP_ACK_THRESHOLD,
1885            recovery: SackRecoveryState::InRecovery(SackInRecoveryState {
1886                recovery_point: SeqNum::new(10),
1887                high_rxt: SeqNum::new(0),
1888                rescue_rxt: None,
1889            }),
1890        }));
1891        congestion_control.on_retransmission_timeout(rto_snd_nxt);
1892        assert_eq!(
1893            congestion_control.assert_sack_recovery().recovery,
1894            SackRecoveryState::PostRto { recovery_point: rto_snd_nxt }
1895        );
1896
1897        let snd_una = SeqNum::new(0);
1898        let snd_nxt = SeqNum::new(10);
1899        // While in RTO held off state, we always send next data as if we were
1900        // not in recovery.
1901        assert_eq!(
1902            congestion_control.poll_send(snd_una, snd_nxt, WindowSize::ZERO, 0),
1903            Some(CongestionControlSendOutcome {
1904                next_seg: snd_nxt,
1905                congestion_limit: u32::from(mss),
1906                congestion_window: congestion_control.inspect_cwnd().cwnd(),
1907                loss_recovery: LossRecoverySegment::No,
1908            })
1909        );
1910        // Receiving duplicate acks does not enter recovery.
1911        for _ in 0..DUP_ACK_THRESHOLD {
1912            assert_eq!(congestion_control.on_dup_ack(snd_una, snd_nxt), None);
1913        }
1914
1915        let now = FakeInstant::default();
1916        let rtt = Some(Duration::from_millis(1));
1917
1918        // Receiving an ack before the RTO recovery point does not stop
1919        // recovery.
1920        let bytes_acked = NonZeroU32::new(u32::try_from(snd_nxt - snd_una).unwrap()).unwrap();
1921        let snd_una = snd_nxt;
1922        assert!(!congestion_control.on_ack(snd_una, bytes_acked, now, rtt));
1923        assert_eq!(
1924            congestion_control.assert_sack_recovery().recovery,
1925            SackRecoveryState::PostRto { recovery_point: rto_snd_nxt }
1926        );
1927
1928        // Covering the recovery point allows us to discard recovery state.
1929        let bytes_acked = NonZeroU32::new(u32::try_from(rto_snd_nxt - snd_una).unwrap()).unwrap();
1930        let snd_una = rto_snd_nxt;
1931
1932        // Not considered a recovery event since RTO is the thing that recovered
1933        // us.
1934        assert_eq!(congestion_control.on_ack(snd_una, bytes_acked, now, rtt), false);
1935        assert_matches!(congestion_control.loss_recovery, None);
1936    }
1937
1938    #[test]
1939    fn dont_rearm_rto_past_recovery_point() {
1940        let mut scoreboard = SackScoreboard::default();
1941        let mss = DEFAULT_IPV4_MAXIMUM_SEGMENT_SIZE;
1942        let snd_una = SeqNum::new(1);
1943
1944        let recovery_point = nth_segment_from(snd_una, mss, 100).start;
1945        let snd_nxt = recovery_point + 100 * u32::from(mss);
1946
1947        let mut recovery = SackRecovery {
1948            dup_acks: DUP_ACK_THRESHOLD,
1949            recovery: SackRecoveryState::InRecovery(SackInRecoveryState {
1950                recovery_point,
1951                high_rxt: recovery_point,
1952                rescue_rxt: Some(recovery_point),
1953            }),
1954        };
1955
1956        let block1 = nth_range(snd_una, mss, 101..110);
1957        let block2 = nth_range(snd_una, mss, 111..112);
1958        assert!(scoreboard.process_ack(
1959            snd_una,
1960            snd_nxt,
1961            recovery.high_rxt(),
1962            &[SackBlock::try_from(block1).unwrap(), SackBlock::try_from(block2).unwrap()]
1963                .into_iter()
1964                .collect(),
1965            mss,
1966        ));
1967
1968        let cwnd = CongestionWindow::new(u32::MAX, mss);
1969
1970        let snd_wnd = WindowSize::ZERO;
1971        let available_bytes = 0;
1972
1973        assert_eq!(
1974            recovery.poll_send(cwnd, snd_una, snd_nxt, snd_wnd, available_bytes, &scoreboard),
1975            Some(CongestionControlSendOutcome {
1976                next_seg: nth_segment_from(snd_una, mss, 100).start,
1977                congestion_limit: mss.into(),
1978                congestion_window: cwnd.cwnd(),
1979                loss_recovery: LossRecoverySegment::Yes {
1980                    rearm_retransmit: false,
1981                    mode: LossRecoveryMode::SackRecovery,
1982                }
1983            })
1984        );
1985        assert_eq!(
1986            recovery.poll_send(cwnd, snd_una, snd_nxt, snd_wnd, available_bytes, &scoreboard),
1987            Some(CongestionControlSendOutcome {
1988                next_seg: nth_segment_from(snd_una, mss, 110).start,
1989                congestion_limit: mss.into(),
1990                congestion_window: cwnd.cwnd(),
1991                loss_recovery: LossRecoverySegment::Yes {
1992                    rearm_retransmit: false,
1993                    mode: LossRecoveryMode::SackRecovery,
1994                }
1995            })
1996        );
1997    }
1998
1999    // Parts of the state machine may end up rewinding SND.NXT to SND.UNA.
2000    // Ensure that NextSeg rule 4 implementation in SackRecovery (which is
2001    // sensitive to SND.NXT) gracefully handles that.
2002    #[test]
2003    fn sack_snd_nxt_rewind() {
2004        let mut scoreboard = SackScoreboard::default();
2005        let mss = DEFAULT_IPV4_MAXIMUM_SEGMENT_SIZE;
2006        let snd_una = SeqNum::new(1);
2007
2008        let recovery_point = nth_segment_from(snd_una, mss, 100).start;
2009        let snd_nxt = nth_segment_from(recovery_point, mss, 100).start;
2010
2011        let mut recovery = SackRecovery {
2012            dup_acks: DUP_ACK_THRESHOLD,
2013            recovery: SackRecoveryState::InRecovery(SackInRecoveryState {
2014                recovery_point,
2015                high_rxt: recovery_point,
2016                rescue_rxt: None,
2017            }),
2018        };
2019        let sack_block = SackBlock::try_from(nth_range(snd_una, mss, 1..5)).unwrap();
2020        assert!(scoreboard.process_ack(
2021            snd_una,
2022            snd_nxt,
2023            Some(recovery_point),
2024            &[sack_block].into_iter().collect(),
2025            mss,
2026        ));
2027        // Rewind.
2028        let snd_nxt = snd_una;
2029
2030        let cwnd = CongestionWindow::new(u32::MAX, mss);
2031        let snd_wnd = WindowSize::ZERO;
2032        let available_bytes = 0;
2033
2034        assert_eq!(
2035            recovery.poll_send(cwnd, snd_una, snd_nxt, snd_wnd, available_bytes, &scoreboard),
2036            None
2037        );
2038    }
2039}