netstack3_tcp/
congestion.rs

1// Copyright 2022 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5//! Implements loss-based congestion control algorithms.
6//!
7//! The currently implemented algorithms are CUBIC from [RFC 8312] and RENO
8//! style fast retransmit and fast recovery from [RFC 5681].
9//!
10//! [RFC 8312]: https://www.rfc-editor.org/rfc/rfc8312
11//! [RFC 5681]: https://www.rfc-editor.org/rfc/rfc5681
12
13mod cubic;
14
15use core::cmp::Ordering;
16use core::num::{NonZeroU32, NonZeroU8};
17use core::time::Duration;
18
19use netstack3_base::{Instant, Mss, SackBlocks, SeqNum, WindowSize};
20
21use crate::internal::sack_scoreboard::SackScoreboard;
22
/// The number of duplicate ACKs at which fast retransmit is triggered.
///
/// Per RFC 5681 (<https://www.rfc-editor.org/rfc/rfc5681#section-3.2>):
///   The fast retransmit algorithm uses the arrival of 3 duplicate ACKs (...)
///   as an indication that a segment has been lost.
pub(crate) const DUP_ACK_THRESHOLD: u8 = 3;
27
/// Holds the parameters of congestion control that are common to algorithms.
///
/// These values are read and updated both by the loss-based algorithm and by
/// the fast recovery state machine.
#[derive(Debug)]
struct CongestionControlParams {
    /// Slow start threshold.
    ///
    /// Starts out as `u32::MAX` (see [`CongestionControlParams::with_mss`]).
    ssthresh: u32,
    /// Congestion control window size, in bytes.
    cwnd: u32,
    /// Sender MSS.
    mss: Mss,
}
38
39impl CongestionControlParams {
40    fn with_mss(mss: Mss) -> Self {
41        let mss_u32 = u32::from(mss);
42        // Per RFC 5681 (https://www.rfc-editor.org/rfc/rfc5681#page-5):
43        //   IW, the initial value of cwnd, MUST be set using the following
44        //   guidelines as an upper bound.
45        //   If SMSS > 2190 bytes:
46        //       IW = 2 * SMSS bytes and MUST NOT be more than 2 segments
47        //   If (SMSS > 1095 bytes) and (SMSS <= 2190 bytes):
48        //       IW = 3 * SMSS bytes and MUST NOT be more than 3 segments
49        //   if SMSS <= 1095 bytes:
50        //       IW = 4 * SMSS bytes and MUST NOT be more than 4 segments
51        let cwnd = if mss_u32 > 2190 {
52            mss_u32 * 2
53        } else if mss_u32 > 1095 {
54            mss_u32 * 3
55        } else {
56            mss_u32 * 4
57        };
58        Self { cwnd, ssthresh: u32::MAX, mss }
59    }
60
61    fn rounded_cwnd(&self) -> WindowSize {
62        let mss_u32 = u32::from(self.mss);
63        WindowSize::from_u32(self.cwnd / mss_u32 * mss_u32).unwrap_or(WindowSize::MAX)
64    }
65}
66
/// Congestion control with five intertwined algorithms.
///
/// - Slow start
/// - Congestion avoidance from a loss-based algorithm
/// - Fast retransmit
/// - Fast recovery: <https://datatracker.ietf.org/doc/html/rfc5681#section-3>
/// - SACK recovery: <https://datatracker.ietf.org/doc/html/rfc6675>
#[derive(Debug)]
pub(crate) struct CongestionControl<I> {
    /// Parameters (cwnd, ssthresh, MSS) shared between the loss-based
    /// algorithm and fast recovery.
    params: CongestionControlParams,
    /// Scoreboard that is updated with every ACK passed to `preprocess_ack`.
    sack_scoreboard: SackScoreboard,
    /// The loss-based algorithm that drives `params`.
    algorithm: LossBasedAlgorithm<I>,
    /// The connection is in fast recovery when this field is a [`Some`].
    ///
    /// It becomes [`Some`] on the first duplicate ACK and is reset to [`None`]
    /// when new data is acknowledged or the retransmission timer fires.
    fast_recovery: Option<FastRecovery>,
}
82
/// Available congestion control algorithms.
///
/// Each variant wraps the state of one loss-based algorithm; the methods on
/// this type dispatch to the wrapped implementation.
#[derive(Debug)]
enum LossBasedAlgorithm<I> {
    /// CUBIC, instantiated with fast convergence enabled.
    Cubic(cubic::Cubic<I, true /* FAST_CONVERGENCE */>),
}
88
89impl<I: Instant> LossBasedAlgorithm<I> {
90    /// Called when there is a loss detected.
91    ///
92    /// Specifically, packet loss means
93    /// - either when the retransmission timer fired;
94    /// - or when we have received a certain amount of duplicate acks.
95    fn on_loss_detected(&mut self, params: &mut CongestionControlParams) {
96        match self {
97            LossBasedAlgorithm::Cubic(cubic) => cubic.on_loss_detected(params),
98        }
99    }
100
101    /// Called when we recovered from packet loss when receiving an ACK that
102    /// acknowledges new data.
103    fn on_loss_recovered(&mut self, params: &mut CongestionControlParams) {
104        // Per RFC 5681 (https://www.rfc-editor.org/rfc/rfc5681#section-3.2):
105        //   When the next ACK arrives that acknowledges previously
106        //   unacknowledged data, a TCP MUST set cwnd to ssthresh (the value
107        //   set in step 2).  This is termed "deflating" the window.
108        params.cwnd = params.ssthresh;
109    }
110
111    fn on_ack(
112        &mut self,
113        params: &mut CongestionControlParams,
114        bytes_acked: NonZeroU32,
115        now: I,
116        rtt: Duration,
117    ) {
118        match self {
119            LossBasedAlgorithm::Cubic(cubic) => cubic.on_ack(params, bytes_acked, now, rtt),
120        }
121    }
122
123    fn on_retransmission_timeout(&mut self, params: &mut CongestionControlParams) {
124        match self {
125            LossBasedAlgorithm::Cubic(cubic) => cubic.on_retransmission_timeout(params),
126        }
127    }
128}
129
130impl<I: Instant> CongestionControl<I> {
131    /// Preprocesses an ACK that may contain selective ack blocks.
132    ///
133    /// Returns true if this should be considered a duplicate ACK according to
134    /// the rules in [RFC 6675 section 2].
135    ///
136    /// [RFC 6675 section 2]: https://datatracker.ietf.org/doc/html/rfc6675#section-2
137    pub(super) fn preprocess_ack(
138        &mut self,
139        seg_ack: SeqNum,
140        snd_nxt: SeqNum,
141        seg_sack_blocks: &SackBlocks,
142    ) -> bool {
143        let Self { params, algorithm: _, fast_recovery: _, sack_scoreboard } = self;
144        // TODO(https://fxbug.dev/42078221): Take HighRxt from loss recovery.
145        let high_rxt = None;
146        sack_scoreboard.process_ack(seg_ack, snd_nxt, high_rxt, seg_sack_blocks, params.mss)
147    }
148
149    /// Called when there are previously unacknowledged bytes being acked.
150    pub(super) fn on_ack(&mut self, bytes_acked: NonZeroU32, now: I, rtt: Duration) {
151        let Self { params, algorithm, fast_recovery, sack_scoreboard: _ } = self;
152        // Exit fast recovery since there is an ACK that acknowledges new data.
153        if let Some(fast_recovery) = fast_recovery.take() {
154            if fast_recovery.dup_acks.get() >= DUP_ACK_THRESHOLD {
155                algorithm.on_loss_recovered(params);
156            }
157        };
158        algorithm.on_ack(params, bytes_acked, now, rtt);
159    }
160
161    /// Called when a duplicate ack is arrived.
162    ///
163    /// Returns `true` if fast recovery was initiated as a result of this ACK.
164    pub(super) fn on_dup_ack(&mut self, seg_ack: SeqNum) -> bool {
165        let Self { params, algorithm, fast_recovery, sack_scoreboard: _ } = self;
166        match fast_recovery {
167            None => {
168                *fast_recovery = Some(FastRecovery::new());
169                true
170            }
171            Some(fast_recovery) => {
172                fast_recovery.on_dup_ack(params, algorithm, seg_ack);
173                false
174            }
175        }
176    }
177
178    /// Called upon a retransmission timeout.
179    pub(super) fn on_retransmission_timeout(&mut self) {
180        let Self { params, algorithm, fast_recovery, sack_scoreboard: _ } = self;
181        *fast_recovery = None;
182        algorithm.on_retransmission_timeout(params);
183    }
184
185    /// Gets the current congestion window size in bytes.
186    ///
187    /// This normally just returns whatever value the loss-based algorithm tells
188    /// us, with the exception that in limited transmit case, the cwnd is
189    /// inflated by dup_ack_cnt * mss, to allow unsent data packets to enter the
190    /// network and trigger more duplicate ACKs to enter fast retransmit. Note
191    /// that this still conforms to the RFC because we don't change the cwnd of
192    /// our algorithm, the algorithm is not aware of this "inflation".
193    pub(super) fn cwnd(&self) -> WindowSize {
194        let Self { params, algorithm: _, fast_recovery, sack_scoreboard: _ } = self;
195        let cwnd = params.rounded_cwnd();
196        if let Some(fast_recovery) = fast_recovery {
197            // Per RFC 3042 (https://www.rfc-editor.org/rfc/rfc3042#section-2):
198            //   ... the Limited Transmit algorithm, which calls for a TCP
199            //   sender to transmit new data upon the arrival of the first two
200            //   consecutive duplicate ACKs ...
201            //   The amount of outstanding data would remain less than or equal
202            //   to the congestion window plus 2 segments.  In other words, the
203            //   sender can only send two segments beyond the congestion window
204            //   (cwnd).
205            // Note: We don't directly change cwnd in the loss-based algorithm
206            // because the RFC says one MUST NOT do that. We follow the
207            // requirement here by not changing the cwnd of the algorithm - if
208            // a new ACK is received after the two dup acks, the loss-based
209            // algorithm will continue to operate the same way as if the 2 SMSS
210            // is never added to cwnd.
211            if fast_recovery.dup_acks.get() < DUP_ACK_THRESHOLD {
212                return cwnd.saturating_add(
213                    u32::from(fast_recovery.dup_acks.get()) * u32::from(params.mss),
214                );
215            }
216        }
217        cwnd
218    }
219
220    pub(super) fn slow_start_threshold(&self) -> u32 {
221        self.params.ssthresh
222    }
223
224    /// Returns the starting sequence number of the segment that needs to be
225    /// retransmitted, if any.
226    pub(super) fn fast_retransmit(&mut self) -> Option<SeqNum> {
227        self.fast_recovery.as_mut().and_then(|r| r.fast_retransmit.take())
228    }
229
230    pub(super) fn cubic_with_mss(mss: Mss) -> Self {
231        Self {
232            params: CongestionControlParams::with_mss(mss),
233            algorithm: LossBasedAlgorithm::Cubic(Default::default()),
234            fast_recovery: None,
235            sack_scoreboard: SackScoreboard::default(),
236        }
237    }
238
239    pub(super) fn mss(&self) -> Mss {
240        self.params.mss
241    }
242
243    pub(super) fn update_mss(&mut self, mss: Mss) {
244        // From [RFC 5681 section 3.1]:
245        //
246        //    When initial congestion windows of more than one segment are
247        //    implemented along with Path MTU Discovery [RFC1191], and the MSS
248        //    being used is found to be too large, the congestion window cwnd
249        //    SHOULD be reduced to prevent large bursts of smaller segments.
250        //    Specifically, cwnd SHOULD be reduced by the ratio of the old segment
251        //    size to the new segment size.
252        //
253        // [RFC 5681 section 3.1]: https://datatracker.ietf.org/doc/html/rfc5681#section-3.1
254        if self.params.ssthresh == u32::MAX {
255            self.params.cwnd = self
256                .params
257                .cwnd
258                .saturating_div(u32::from(self.params.mss))
259                .saturating_mul(u32::from(mss));
260        }
261        self.params.mss = mss;
262    }
263
264    /// Returns true if this [`CongestionControl`] is in fast recovery.
265    pub(super) fn in_fast_recovery(&self) -> bool {
266        self.fast_recovery.is_some()
267    }
268
269    /// Returns true if this [`CongestionControl`] is in slow start.
270    pub(super) fn in_slow_start(&self) -> bool {
271        self.params.cwnd < self.params.ssthresh
272    }
273}
274
/// Reno style Fast Recovery algorithm as described in
/// [RFC 5681](https://tools.ietf.org/html/rfc5681).
///
/// A value of this type is created when the first duplicate ACK is seen, so
/// it also covers the period before the duplicate ACK threshold is reached.
#[derive(Debug)]
pub struct FastRecovery {
    /// Holds the sequence number of the segment to fast retransmit, if any.
    fast_retransmit: Option<SeqNum>,
    /// The running count of consecutive duplicate ACKs we have received so far.
    ///
    /// Here we limit the maximum number of duplicate ACKS we track to 255, as
    /// per a note in the RFC:
    ///
    /// > Note: [SCWA99] discusses a receiver-based attack whereby many
    /// > bogus duplicate ACKs are sent to the data sender in order to
    /// > artificially inflate cwnd and cause a higher than appropriate
    /// > sending rate to be used.  A TCP MAY therefore limit the number of
    /// > times cwnd is artificially inflated during loss recovery to the
    /// > number of outstanding segments (or, an approximation thereof).
    ///
    /// [SCWA99]: https://homes.cs.washington.edu/~tom/pubs/CCR99.pdf
    dup_acks: NonZeroU8,
}
296
297impl FastRecovery {
298    fn new() -> Self {
299        Self { dup_acks: NonZeroU8::new(1).unwrap(), fast_retransmit: None }
300    }
301
302    fn on_dup_ack<I: Instant>(
303        &mut self,
304        params: &mut CongestionControlParams,
305        loss_based: &mut LossBasedAlgorithm<I>,
306        seg_ack: SeqNum,
307    ) {
308        self.dup_acks = self.dup_acks.saturating_add(1);
309
310        match self.dup_acks.get().cmp(&DUP_ACK_THRESHOLD) {
311            Ordering::Less => {}
312            Ordering::Equal => {
313                loss_based.on_loss_detected(params);
314                // Per RFC 5681 (https://www.rfc-editor.org/rfc/rfc5681#section-3.2):
315                //   The lost segment starting at SND.UNA MUST be retransmitted
316                //   and cwnd set to ssthresh plus 3*SMSS.  This artificially
317                //   "inflates" the congestion window by the number of segments
318                //   (three) that have left the network and which the receiver
319                //   has buffered.
320                self.fast_retransmit = Some(seg_ack);
321                params.cwnd =
322                    params.ssthresh + u32::from(DUP_ACK_THRESHOLD) * u32::from(params.mss);
323            }
324            Ordering::Greater => {
325                // Per RFC 5681 (https://www.rfc-editor.org/rfc/rfc5681#section-3.2):
326                //   For each additional duplicate ACK received (after the third),
327                //   cwnd MUST be incremented by SMSS. This artificially inflates
328                //   the congestion window in order to reflect the additional
329                //   segment that has left the network.
330                params.cwnd = params.cwnd.saturating_add(u32::from(params.mss));
331            }
332        }
333    }
334}
335
#[cfg(test)]
mod test {
    use netstack3_base::testutil::FakeInstant;

    use super::*;
    use crate::internal::base::testutil::DEFAULT_IPV4_MAXIMUM_SEGMENT_SIZE;

    /// A single duplicate ACK followed by an ACK for new data must not be
    /// treated as a recovery from loss.
    #[test]
    fn no_recovery_before_reaching_threshold() {
        let mut cc = CongestionControl::cubic_with_mss(DEFAULT_IPV4_MAXIMUM_SEGMENT_SIZE);
        let initial_cwnd = cc.params.cwnd;
        assert_eq!(cc.params.ssthresh, u32::MAX);

        // The very first duplicate ACK starts tracking fast recovery state.
        let recovery_started = cc.on_dup_ack(SeqNum::new(0));
        assert!(recovery_started);

        let one_byte = NonZeroU32::new(1).unwrap();
        cc.on_ack(one_byte, FakeInstant::from(Duration::from_secs(0)), Duration::from_secs(1));

        // We have only received one duplicate ack, receiving a new ACK should
        // not mean "loss recovery" - we should not bump our cwnd to initial
        // ssthresh (u32::MAX) and then overflow.
        assert_eq!(initial_cwnd + 1, cc.params.cwnd);
    }
}