netstack3_tcp/congestion.rs
1// Copyright 2022 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5//! Implements loss-based congestion control algorithms.
6//!
7//! The currently implemented algorithms are CUBIC from [RFC 8312] and RENO
8//! style fast retransmit and fast recovery from [RFC 5681].
9//!
10//! [RFC 8312]: https://www.rfc-editor.org/rfc/rfc8312
11//! [RFC 5681]: https://www.rfc-editor.org/rfc/rfc5681
12
13mod cubic;
14
15use core::cmp::Ordering;
16use core::num::{NonZeroU32, NonZeroU8};
17use core::time::Duration;
18
19use netstack3_base::{Instant, Mss, SackBlocks, SeqNum, WindowSize};
20
21use crate::internal::sack_scoreboard::SackScoreboard;
22
/// Number of duplicate ACKs that is taken as a signal of segment loss.
///
/// Per [RFC 5681 section 3.2]: "The fast retransmit algorithm uses the
/// arrival of 3 duplicate ACKs (...) as an indication that a segment has
/// been lost."
///
/// [RFC 5681 section 3.2]: https://www.rfc-editor.org/rfc/rfc5681#section-3.2
pub(crate) const DUP_ACK_THRESHOLD: u8 = 3;
27
28/// Holds the parameters of congestion control that are common to algorithms.
29#[derive(Debug)]
30struct CongestionControlParams {
31 /// Slow start threshold.
32 ssthresh: u32,
33 /// Congestion control window size, in bytes.
34 cwnd: u32,
35 /// Sender MSS.
36 mss: Mss,
37}
38
39impl CongestionControlParams {
40 fn with_mss(mss: Mss) -> Self {
41 let mss_u32 = u32::from(mss);
42 // Per RFC 5681 (https://www.rfc-editor.org/rfc/rfc5681#page-5):
43 // IW, the initial value of cwnd, MUST be set using the following
44 // guidelines as an upper bound.
45 // If SMSS > 2190 bytes:
46 // IW = 2 * SMSS bytes and MUST NOT be more than 2 segments
47 // If (SMSS > 1095 bytes) and (SMSS <= 2190 bytes):
48 // IW = 3 * SMSS bytes and MUST NOT be more than 3 segments
49 // if SMSS <= 1095 bytes:
50 // IW = 4 * SMSS bytes and MUST NOT be more than 4 segments
51 let cwnd = if mss_u32 > 2190 {
52 mss_u32 * 2
53 } else if mss_u32 > 1095 {
54 mss_u32 * 3
55 } else {
56 mss_u32 * 4
57 };
58 Self { cwnd, ssthresh: u32::MAX, mss }
59 }
60
61 fn rounded_cwnd(&self) -> WindowSize {
62 let mss_u32 = u32::from(self.mss);
63 WindowSize::from_u32(self.cwnd / mss_u32 * mss_u32).unwrap_or(WindowSize::MAX)
64 }
65}
66
67/// Congestion control with five intertwined algorithms.
68///
69/// - Slow start
70/// - Congestion avoidance from a loss-based algorithm
71/// - Fast retransmit
72/// - Fast recovery: https://datatracker.ietf.org/doc/html/rfc5681#section-3
73/// - SACK recovery: https://datatracker.ietf.org/doc/html/rfc6675
74#[derive(Debug)]
75pub(crate) struct CongestionControl<I> {
76 params: CongestionControlParams,
77 sack_scoreboard: SackScoreboard,
78 algorithm: LossBasedAlgorithm<I>,
79 /// The connection is in fast recovery when this field is a [`Some`].
80 fast_recovery: Option<FastRecovery>,
81}
82
83/// Available congestion control algorithms.
84#[derive(Debug)]
85enum LossBasedAlgorithm<I> {
86 Cubic(cubic::Cubic<I, true /* FAST_CONVERGENCE */>),
87}
88
89impl<I: Instant> LossBasedAlgorithm<I> {
90 /// Called when there is a loss detected.
91 ///
92 /// Specifically, packet loss means
93 /// - either when the retransmission timer fired;
94 /// - or when we have received a certain amount of duplicate acks.
95 fn on_loss_detected(&mut self, params: &mut CongestionControlParams) {
96 match self {
97 LossBasedAlgorithm::Cubic(cubic) => cubic.on_loss_detected(params),
98 }
99 }
100
101 /// Called when we recovered from packet loss when receiving an ACK that
102 /// acknowledges new data.
103 fn on_loss_recovered(&mut self, params: &mut CongestionControlParams) {
104 // Per RFC 5681 (https://www.rfc-editor.org/rfc/rfc5681#section-3.2):
105 // When the next ACK arrives that acknowledges previously
106 // unacknowledged data, a TCP MUST set cwnd to ssthresh (the value
107 // set in step 2). This is termed "deflating" the window.
108 params.cwnd = params.ssthresh;
109 }
110
111 fn on_ack(
112 &mut self,
113 params: &mut CongestionControlParams,
114 bytes_acked: NonZeroU32,
115 now: I,
116 rtt: Duration,
117 ) {
118 match self {
119 LossBasedAlgorithm::Cubic(cubic) => cubic.on_ack(params, bytes_acked, now, rtt),
120 }
121 }
122
123 fn on_retransmission_timeout(&mut self, params: &mut CongestionControlParams) {
124 match self {
125 LossBasedAlgorithm::Cubic(cubic) => cubic.on_retransmission_timeout(params),
126 }
127 }
128}
129
130impl<I: Instant> CongestionControl<I> {
131 /// Preprocesses an ACK that may contain selective ack blocks.
132 ///
133 /// Returns true if this should be considered a duplicate ACK according to
134 /// the rules in [RFC 6675 section 2].
135 ///
136 /// [RFC 6675 section 2]: https://datatracker.ietf.org/doc/html/rfc6675#section-2
137 pub(super) fn preprocess_ack(
138 &mut self,
139 seg_ack: SeqNum,
140 snd_nxt: SeqNum,
141 seg_sack_blocks: &SackBlocks,
142 ) -> bool {
143 let Self { params, algorithm: _, fast_recovery: _, sack_scoreboard } = self;
144 // TODO(https://fxbug.dev/42078221): Take HighRxt from loss recovery.
145 let high_rxt = None;
146 sack_scoreboard.process_ack(seg_ack, snd_nxt, high_rxt, seg_sack_blocks, params.mss)
147 }
148
149 /// Called when there are previously unacknowledged bytes being acked.
150 pub(super) fn on_ack(&mut self, bytes_acked: NonZeroU32, now: I, rtt: Duration) {
151 let Self { params, algorithm, fast_recovery, sack_scoreboard: _ } = self;
152 // Exit fast recovery since there is an ACK that acknowledges new data.
153 if let Some(fast_recovery) = fast_recovery.take() {
154 if fast_recovery.dup_acks.get() >= DUP_ACK_THRESHOLD {
155 algorithm.on_loss_recovered(params);
156 }
157 };
158 algorithm.on_ack(params, bytes_acked, now, rtt);
159 }
160
161 /// Called when a duplicate ack is arrived.
162 ///
163 /// Returns `true` if fast recovery was initiated as a result of this ACK.
164 pub(super) fn on_dup_ack(&mut self, seg_ack: SeqNum) -> bool {
165 let Self { params, algorithm, fast_recovery, sack_scoreboard: _ } = self;
166 match fast_recovery {
167 None => {
168 *fast_recovery = Some(FastRecovery::new());
169 true
170 }
171 Some(fast_recovery) => {
172 fast_recovery.on_dup_ack(params, algorithm, seg_ack);
173 false
174 }
175 }
176 }
177
178 /// Called upon a retransmission timeout.
179 pub(super) fn on_retransmission_timeout(&mut self) {
180 let Self { params, algorithm, fast_recovery, sack_scoreboard: _ } = self;
181 *fast_recovery = None;
182 algorithm.on_retransmission_timeout(params);
183 }
184
185 /// Gets the current congestion window size in bytes.
186 ///
187 /// This normally just returns whatever value the loss-based algorithm tells
188 /// us, with the exception that in limited transmit case, the cwnd is
189 /// inflated by dup_ack_cnt * mss, to allow unsent data packets to enter the
190 /// network and trigger more duplicate ACKs to enter fast retransmit. Note
191 /// that this still conforms to the RFC because we don't change the cwnd of
192 /// our algorithm, the algorithm is not aware of this "inflation".
193 pub(super) fn cwnd(&self) -> WindowSize {
194 let Self { params, algorithm: _, fast_recovery, sack_scoreboard: _ } = self;
195 let cwnd = params.rounded_cwnd();
196 if let Some(fast_recovery) = fast_recovery {
197 // Per RFC 3042 (https://www.rfc-editor.org/rfc/rfc3042#section-2):
198 // ... the Limited Transmit algorithm, which calls for a TCP
199 // sender to transmit new data upon the arrival of the first two
200 // consecutive duplicate ACKs ...
201 // The amount of outstanding data would remain less than or equal
202 // to the congestion window plus 2 segments. In other words, the
203 // sender can only send two segments beyond the congestion window
204 // (cwnd).
205 // Note: We don't directly change cwnd in the loss-based algorithm
206 // because the RFC says one MUST NOT do that. We follow the
207 // requirement here by not changing the cwnd of the algorithm - if
208 // a new ACK is received after the two dup acks, the loss-based
209 // algorithm will continue to operate the same way as if the 2 SMSS
210 // is never added to cwnd.
211 if fast_recovery.dup_acks.get() < DUP_ACK_THRESHOLD {
212 return cwnd.saturating_add(
213 u32::from(fast_recovery.dup_acks.get()) * u32::from(params.mss),
214 );
215 }
216 }
217 cwnd
218 }
219
220 pub(super) fn slow_start_threshold(&self) -> u32 {
221 self.params.ssthresh
222 }
223
224 /// Returns the starting sequence number of the segment that needs to be
225 /// retransmitted, if any.
226 pub(super) fn fast_retransmit(&mut self) -> Option<SeqNum> {
227 self.fast_recovery.as_mut().and_then(|r| r.fast_retransmit.take())
228 }
229
230 pub(super) fn cubic_with_mss(mss: Mss) -> Self {
231 Self {
232 params: CongestionControlParams::with_mss(mss),
233 algorithm: LossBasedAlgorithm::Cubic(Default::default()),
234 fast_recovery: None,
235 sack_scoreboard: SackScoreboard::default(),
236 }
237 }
238
239 pub(super) fn mss(&self) -> Mss {
240 self.params.mss
241 }
242
243 pub(super) fn update_mss(&mut self, mss: Mss) {
244 // From [RFC 5681 section 3.1]:
245 //
246 // When initial congestion windows of more than one segment are
247 // implemented along with Path MTU Discovery [RFC1191], and the MSS
248 // being used is found to be too large, the congestion window cwnd
249 // SHOULD be reduced to prevent large bursts of smaller segments.
250 // Specifically, cwnd SHOULD be reduced by the ratio of the old segment
251 // size to the new segment size.
252 //
253 // [RFC 5681 section 3.1]: https://datatracker.ietf.org/doc/html/rfc5681#section-3.1
254 if self.params.ssthresh == u32::MAX {
255 self.params.cwnd = self
256 .params
257 .cwnd
258 .saturating_div(u32::from(self.params.mss))
259 .saturating_mul(u32::from(mss));
260 }
261 self.params.mss = mss;
262 }
263
264 /// Returns true if this [`CongestionControl`] is in fast recovery.
265 pub(super) fn in_fast_recovery(&self) -> bool {
266 self.fast_recovery.is_some()
267 }
268
269 /// Returns true if this [`CongestionControl`] is in slow start.
270 pub(super) fn in_slow_start(&self) -> bool {
271 self.params.cwnd < self.params.ssthresh
272 }
273}
274
275/// Reno style Fast Recovery algorithm as described in
276/// [RFC 5681](https://tools.ietf.org/html/rfc5681).
277#[derive(Debug)]
278pub struct FastRecovery {
279 /// Holds the sequence number of the segment to fast retransmit, if any.
280 fast_retransmit: Option<SeqNum>,
281 /// The running count of consecutive duplicate ACKs we have received so far.
282 ///
283 /// Here we limit the maximum number of duplicate ACKS we track to 255, as
284 /// per a note in the RFC:
285 ///
286 /// Note: [SCWA99] discusses a receiver-based attack whereby many
287 /// bogus duplicate ACKs are sent to the data sender in order to
288 /// artificially inflate cwnd and cause a higher than appropriate
289 /// sending rate to be used. A TCP MAY therefore limit the number of
290 /// times cwnd is artificially inflated during loss recovery to the
291 /// number of outstanding segments (or, an approximation thereof).
292 ///
293 /// [SCWA99]: https://homes.cs.washington.edu/~tom/pubs/CCR99.pdf
294 dup_acks: NonZeroU8,
295}
296
297impl FastRecovery {
298 fn new() -> Self {
299 Self { dup_acks: NonZeroU8::new(1).unwrap(), fast_retransmit: None }
300 }
301
302 fn on_dup_ack<I: Instant>(
303 &mut self,
304 params: &mut CongestionControlParams,
305 loss_based: &mut LossBasedAlgorithm<I>,
306 seg_ack: SeqNum,
307 ) {
308 self.dup_acks = self.dup_acks.saturating_add(1);
309
310 match self.dup_acks.get().cmp(&DUP_ACK_THRESHOLD) {
311 Ordering::Less => {}
312 Ordering::Equal => {
313 loss_based.on_loss_detected(params);
314 // Per RFC 5681 (https://www.rfc-editor.org/rfc/rfc5681#section-3.2):
315 // The lost segment starting at SND.UNA MUST be retransmitted
316 // and cwnd set to ssthresh plus 3*SMSS. This artificially
317 // "inflates" the congestion window by the number of segments
318 // (three) that have left the network and which the receiver
319 // has buffered.
320 self.fast_retransmit = Some(seg_ack);
321 params.cwnd =
322 params.ssthresh + u32::from(DUP_ACK_THRESHOLD) * u32::from(params.mss);
323 }
324 Ordering::Greater => {
325 // Per RFC 5681 (https://www.rfc-editor.org/rfc/rfc5681#section-3.2):
326 // For each additional duplicate ACK received (after the third),
327 // cwnd MUST be incremented by SMSS. This artificially inflates
328 // the congestion window in order to reflect the additional
329 // segment that has left the network.
330 params.cwnd = params.cwnd.saturating_add(u32::from(params.mss));
331 }
332 }
333 }
334}
335
#[cfg(test)]
mod test {
    use netstack3_base::testutil::FakeInstant;

    use super::*;
    use crate::internal::base::testutil::DEFAULT_IPV4_MAXIMUM_SEGMENT_SIZE;

    #[test]
    fn no_recovery_before_reaching_threshold() {
        let mut cc = CongestionControl::cubic_with_mss(DEFAULT_IPV4_MAXIMUM_SEGMENT_SIZE);
        let cwnd_before = cc.params.cwnd;
        assert_eq!(cc.params.ssthresh, u32::MAX);

        // A single duplicate ACK starts tracking (limited transmit only).
        assert!(cc.on_dup_ack(SeqNum::new(0)));

        // An ACK of new data after just one duplicate ACK must not count as
        // loss recovery. Deflating cwnd to the initial ssthresh (u32::MAX)
        // would make the subsequent increase overflow.
        let one_byte = NonZeroU32::new(1).unwrap();
        let now = FakeInstant::from(Duration::from_secs(0));
        cc.on_ack(one_byte, now, Duration::from_secs(1));

        assert_eq!(cwnd_before + 1, cc.params.cwnd);
    }
}