// internet_checksum/lib.rs

1// Copyright 2019 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5//! RFC 1071 "internet checksum" computation.
6//!
7//! This crate implements the "internet checksum" defined in [RFC 1071] and
8//! updated in [RFC 1141] and [RFC 1624], which is used by many different
9//! protocols' packet formats. The checksum operates by computing the 1s
10//! complement of the 1s complement sum of successive 16-bit words of the input.
11//!
12//! # Benchmarks
13//!
14//! ## [`Checksum::add_bytes`]
15//!
16//! The following microbenchmarks were performed on a 2018 Google Pixelbook. Each benchmark
17//! constructs a [`Checksum`] object, calls [`Checksum::add_bytes`] with an input of the given
18//! number of bytes, and then calls [`Checksum::checksum`] to finalize. Average values were
19//! calculated over 3 trials.
20//!
21//! Bytes |    Time    |    Rate
22//! ----- | ---------- | ----------
23//!    20 |   2,649 ns |  7.55 MB/s
24//!    31 |   3,826 ns |  8.10 MB/s
25//!    32 |   3,871 ns |  8.27 MB/s
26//!    64 |   1,433 ns |  44.7 MB/s
27//!   128 |   2,225 ns |  57.5 MB/s
28//!   256 |   3,829 ns |  66.9 MB/s
29//!  1023 |  13,802 ns |  74.1 MB/s
30//!  1024 |  13,535 ns |  75.7 MB/s
31//!
32//! ## [`Checksum::add_bytes_small`]
33//!
34//! The following microbenchmarks were performed on a 2018 Google Pixelbook. Each benchmark
35//! constructs a [`Checksum`] object, calls [`Checksum::add_bytes_small`] with an input of the
36//! given number of bytes, and then calls [`Checksum::checksum`] to finalize. Average values
37//! were calculated over 3 trials.
38//!
39//! Bytes |    Time    |    Rate
40//! ----- | ---------- | ----------
41//!    20 |   2,639 ns |  7.57 MB/s
42//!    31 |   3,806 ns |  8.15 MB/s
43//!
44//! ## [`update`]
45//!
46//! The following microbenchmarks were performed on a 2018 Google Pixelbook. Each benchmark
47//! calls [`update`] with an original 2 byte checksum, and byteslices of specified lengths
48//! to be updated. Average values were calculated over 3 trials.
49//!
50//! Bytes |    Time    |    Rate
51//! ----- | ---------- | ----------
52//!     2 |   1,550 ns |  1.29 MB/s
53//!     4 |   1,972 ns |  2.03 MB/s
54//!     8 |   2,892 ns |  2.77 MB/s
55//!
56//! [RFC 1071]: https://tools.ietf.org/html/rfc1071
57//! [RFC 1141]: https://tools.ietf.org/html/rfc1141
58//! [RFC 1624]: https://tools.ietf.org/html/rfc1624
59
60// Optimizations applied:
61//
62// 0. Byteorder independence: as described in RFC 1071 section 2.(B)
63//    The sum of 16-bit integers can be computed in either byte order,
64//    so this actually saves us from the unnecessary byte swapping on
65//    an LE machine. As perfed on a gLinux workstation, that swapping
66//    can account for ~20% of the runtime.
67//
// 1. Widen the accumulator: doing so enables us to process a bigger
//    chunk of data at a time, achieving some kind of poor man's
//    SIMD. Currently a u128 counter is used on x86-64 and a u64 is
//    used conservatively on other architectures.
72//
73// 2. Process more at a time: the old implementation uses a u32 accumulator
74//    but it only adds one u16 each time to implement deferred carry. In
//    the current implementation we are processing a u128 at a time
76//    on x86-64, which is 8 u16's. On other platforms, we are processing
77//    a u64 at a time, which is 4 u16's.
78//
79// 3. Induce the compiler to produce `adc` instruction: this is a very
80//    useful instruction to implement 1's complement addition and available
81//    on both x86 and ARM. The functions `adc_uXX` are for this use.
82//
83// 4. Eliminate branching as much as possible: the old implementation has
84//    if statements for detecting overflow of the u32 accumulator which
85//    is not needed when we can access the carry flag with `adc`. The old
86//    `normalize` function used to have a while loop to fold the u32,
87//    however, we can unroll that loop because we know ahead of time how
//    many additions we need.
89//
90// 5. In the loop of `add_bytes`, the `adc_u64` is not used, instead,
91//    the `overflowing_add` is directly used. `adc_u64`'s carry flag
92//    comes from the current number being added while the slightly
93//    convoluted version in `add_bytes`, adding each number depends on
94//    the carry flag of the previous computation. I checked under release
95//    mode this issues 3 instructions instead of 4 for x86 and it should
96//    theoretically be beneficial, however, measurement showed me that it
97//    helps only a little. So this trick is not used for `update`.
98//
99// 6. When the input is small, fallback to deferred carry method. Deferred
100//    carry turns out to be very efficient when dealing with small buffers:
101//    If the input is small, the cost to deal with the tail may already
102//    outweigh the benefit of the unrolling itself. Some measurement
103//    confirms this theory.
104//
105// Results:
106//
107// Micro-benchmarks are run on an x86-64 gLinux workstation. In summary,
// compared to baseline 0, which is prior to the byteorder independence
109// patch, there is a ~4x speedup.
110//
111// TODO: run this optimization on other platforms. I would expect
112// the situation on ARM a bit different because I am not sure
113// how much penalty there will be for misaligned read on ARM, or
114// whether it is even supported (On x86 there is generally no
115// penalty for misaligned read). If there will be penalties, we
116// should consider alignment as an optimization opportunity on ARM.
117
118// TODO(joshlf): Right-justify the columns above
119
120#![cfg_attr(feature = "benchmark", feature(test))]
121
122#[cfg(all(test, feature = "benchmark"))]
123extern crate test;
124
125// TODO(joshlf):
126// - Investigate optimizations proposed in RFC 1071 Section 2. The most
127//   promising on modern hardware is probably (C) Parallel Summation, although
128//   that needs to be balanced against (1) Deferred Carries. Benchmarks will
129//   need to be performed to determine which is faster in practice, and under
130//   what scenarios.
131
132/// Compute the checksum of "bytes".
133///
134/// `checksum(bytes)` is shorthand for:
135///
136/// ```rust
137/// # use internet_checksum::Checksum;
138/// # let bytes = &[];
139/// # let _ = {
140/// let mut c = Checksum::new();
141/// c.add_bytes(bytes);
142/// c.checksum()
143/// # };
144/// ```
145#[inline]
146pub fn checksum(bytes: &[u8]) -> [u8; 2] {
147    let mut c = Checksum::new();
148    c.add_bytes(bytes);
149    c.checksum()
150}
151
// The wide accumulator holding the running sum. On x86-64 a u128 is used
// (the main loop consumes 16 input bytes per addition); elsewhere we
// conservatively use u64 (8 bytes per addition). See optimization note 1
// at the top of this file.
#[cfg(target_arch = "x86_64")]
type Accumulator = u128;
#[cfg(not(target_arch = "x86_64"))]
type Accumulator = u64;

/// The threshold for small buffers, if the buffer is too small,
/// fall back to the normal deferred carry method where a wide
/// accumulator is used but one `u16` is added once at a time.
// TODO: `64` works fine on x86_64, but this value may be different
// on other platforms.
const SMALL_BUF_THRESHOLD: usize = 64;
163
/// The following macro unrolls operations on u16's to wider integers.
///
/// # Arguments
///
/// * `$arr`  - The byte slice being processed.
/// * `$body` - The operation to operate on the wider integer. It should
///             be a macro because functions are not options here (the
///             body must expand in place so it can mutate the caller's
///             local `sum`/`carry`/`bytes` bindings).
///
///
/// This macro will choose the "wide integer" for you, on x86-64,
/// it will choose u128 as the "wide integer" and u64 anywhere else.
macro_rules! loop_unroll {
    // Main loop on x86-64: consume 16 bytes (one u128) per iteration,
    // then let `unroll_tail` deal with the < 16 leftover bytes.
    (@inner $arr: ident, 16, $body:ident) => {
        while $arr.len() >= 16 {
            $body!(16, u128);
        }
        unroll_tail!($arr, 16, $body);
    };

    // Main loop elsewhere: consume 8 bytes (one u64) per iteration.
    (@inner $arr: ident, 8, $body:ident) => {
        while $arr.len() >= 8 {
            $body!(8, u64);
        }
        unroll_tail!($arr, 8, $body);
    };

    // Entry point: pick the widest stride the target supports.
    ($arr: ident, $body: ident) => {
        #[cfg(target_arch = "x86_64")]
        loop_unroll!(@inner $arr, 16, $body);
        #[cfg(not(target_arch = "x86_64"))]
        loop_unroll!(@inner $arr, 8, $body);
    };
}
197
/// At the end of loop unrolling, we have to take care of bytes
/// that are left over. For example, `unroll_tail!(bytes, 4, body)`
/// expands to
/// ```text
/// if bytes.len() & 2 != 0 {
///   body!(2, u16);
/// }
/// ```
macro_rules! unroll_tail {
    // Generic arm: if bit `$n` of the remaining length is set, consume
    // `$n` bytes by reading one `$read`-sized integer.
    ($arr: ident, $n: literal, $read: ident, $body: ident) => {
        if $arr.len() & $n != 0 {
            $body!($n, $read);
        }
    };

    // After a 4-byte stride at most 3 bytes remain: handle 2 of them
    // here; a final odd byte is left for the caller's trailing-byte
    // logic.
    ($arr: ident, 4, $body: ident) => {
        unroll_tail!($arr, 2, u16, $body);
    };

    // After an 8-byte stride: peel off 4 bytes, then recurse.
    ($arr: ident, 8, $body: ident) => {
        unroll_tail!($arr, 4, u32, $body);
        unroll_tail!($arr, 4, $body);
    };

    // After a 16-byte stride: peel off 8 bytes, then recurse.
    ($arr: ident, 16, $body: ident) => {
        unroll_tail!($arr, 8, u64, $body);
        unroll_tail!($arr, 8, $body);
    };
}
227
228/// Updates bytes in an existing checksum.
229///
230/// `update` updates a checksum to reflect that the already-checksummed bytes
231/// `old` have been updated to contain the values in `new`. It implements the
232/// algorithm described in Equation 3 in [RFC 1624]. The first byte must be at
233/// an even number offset in the original input. If an odd number offset byte
234/// needs to be updated, the caller should simply include the preceding byte as
235/// well. If an odd number of bytes is given, it is assumed that these are the
236/// last bytes of the input. If an odd number of bytes in the middle of the
237/// input needs to be updated, the preceding or following byte of the input
238/// should be added to make an even number of bytes.
239///
240/// # Panics
241///
242/// `update` panics if `old.len() != new.len()`.
243///
244/// [RFC 1624]: https://tools.ietf.org/html/rfc1624
245#[inline]
246pub fn update(checksum: [u8; 2], old: &[u8], new: &[u8]) -> [u8; 2] {
247    assert_eq!(old.len(), new.len());
248    // We compute on the sum, not the one's complement of the sum. checksum
249    // is the one's complement of the sum, so we need to get back to the
250    // sum. Thus, we negate checksum.
251    // HC' = ~HC
252    let mut sum = !u16::from_ne_bytes(checksum) as Accumulator;
253
254    // Let's reuse `Checksum::add_bytes` to update our checksum
255    // so that we can get the speedup for free. Using
256    // [RFC 1071 Eqn. 3], we can efficiently update our new checksum.
257    let mut c1 = Checksum::new();
258    let mut c2 = Checksum::new();
259    c1.add_bytes(old);
260    c2.add_bytes(new);
261
262    // Note, `c1.checksum_inner()` is actually ~m in [Eqn. 3]
263    // `c2.checksum_inner()` is actually ~m' in [Eqn. 3]
264    // so we have to negate `c2.checksum_inner()` first to get m'.
265    // HC' += ~m, c1.checksum_inner() == ~m.
266    sum = adc_accumulator(sum, c1.checksum_inner() as Accumulator);
267    // HC' += m', c2.checksum_inner() == ~m'.
268    sum = adc_accumulator(sum, !c2.checksum_inner() as Accumulator);
269    // HC' = ~HC.
270    (!normalize(sum)).to_ne_bytes()
271}
272
/// RFC 1071 "internet checksum" computation.
///
/// `Checksum` implements the "internet checksum" defined in [RFC 1071] and
/// updated in [RFC 1141] and [RFC 1624], which is used by many different
/// protocols' packet formats. The checksum operates by computing the 1s
/// complement of the 1s complement sum of successive 16-bit words of the input.
///
/// [RFC 1071]: https://tools.ietf.org/html/rfc1071
/// [RFC 1141]: https://tools.ietf.org/html/rfc1141
/// [RFC 1624]: https://tools.ietf.org/html/rfc1624
#[derive(Default)]
pub struct Checksum {
    // The running, not-yet-folded, not-yet-complemented sum of 16-bit
    // words; folded to 16 bits and complemented only in checksum().
    sum: Accumulator,
    // Since odd-length inputs are treated specially, we store the trailing byte
    // for use in future calls to add_bytes(), and only treat it as a true
    // trailing byte in checksum().
    trailing_byte: Option<u8>,
}
291
292impl Checksum {
293    /// Initialize a new checksum.
294    #[inline]
295    pub const fn new() -> Self {
296        Checksum { sum: 0, trailing_byte: None }
297    }
298
299    /// Add bytes to the checksum.
300    ///
301    /// If `bytes` does not contain an even number of bytes, a single zero byte
302    /// will be added to the end before updating the checksum.
303    ///
304    /// Note that `add_bytes` has some fixed overhead regardless of the size of
305    /// `bytes`. Where performance is a concern, prefer fewer calls to
306    /// `add_bytes` with larger input over more calls with smaller input.
307    #[inline]
308    pub fn add_bytes(&mut self, mut bytes: &[u8]) {
309        if bytes.len() < SMALL_BUF_THRESHOLD {
310            self.add_bytes_small(bytes);
311            return;
312        }
313
314        let mut sum = self.sum;
315        let mut carry = false;
316
317        // We are not using `adc_uXX` functions here, instead, we manually track
318        // the carry flag. This is because in `adc_uXX` functions, the carry
319        // flag depends on addition itself. So the assembly for that function
320        // reads as follows:
321        //
322        // mov %rdi, %rcx
323        // mov %rsi, %rax
324        // add %rcx, %rsi -- waste! only used to generate CF.
325        // adc %rdi, $rax -- the real useful instruction.
326        //
327        // So we had better to make us depend on the CF generated by the
328        // addition of the previous 16-bit word. The ideal assembly should look
329        // like:
330        //
331        // add 0(%rdi), %rax
332        // adc 8(%rdi), %rax
333        // adc 16(%rdi), %rax
334        // .... and so on ...
335        //
336        // Sadly, there are too many instructions that can affect the carry
337        // flag, and LLVM is not that optimized to find out the pattern and let
338        // all these adc instructions not interleaved. However, doing so results
339        // in 3 instructions instead of the original 4 instructions (the two
340        // mov's are still there) and it makes a difference on input size like
341        // 1023.
342
343        // The following macro is used as a `body` when invoking a `loop_unroll`
344        // macro. `$step` means how many bytes to handle at once; `$read` is
345        // supposed to be `u16`, `u32` and so on, it is used to get an unsigned
346        // integer of `$step` width from a byte slice; `$bytes` is the byte
347        // slice mentioned before, if omitted, it defaults to be `bytes`, which
348        // is the argument of the surrounding function.
349        macro_rules! update_sum_carry {
350            ($step: literal, $ty: ident, $bytes: expr) => {
351                let (s, c) = sum
352                    .overflowing_add($ty::from_ne_bytes($bytes.try_into().unwrap()) as Accumulator);
353                sum = s.wrapping_add(carry as Accumulator);
354                carry = c;
355                bytes = &bytes[$step..];
356            };
357            ($step: literal, $ty: ident) => {
358                update_sum_carry!($step, $ty, bytes[..$step]);
359            };
360        }
361
362        // if there's a trailing byte, consume it first
363        if let Some(byte) = self.trailing_byte {
364            update_sum_carry!(1, u16, [byte, bytes[0]]);
365            self.trailing_byte = None;
366        }
367
368        loop_unroll!(bytes, update_sum_carry);
369
370        if bytes.len() == 1 {
371            self.trailing_byte = Some(bytes[0]);
372        }
373
374        self.sum = sum + (carry as Accumulator);
375    }
376
377    /// The efficient fallback when the buffer is small.
378    ///
379    /// In this implementation, one `u16` is added once a
380    /// time, so we don't waste time on dealing with the
381    /// tail of the buffer. Besides, given that the accumulator
382    /// is large enough, when inputs are small, there should
383    /// hardly be overflows, so for any modern architecture,
384    /// there is little chance in misprediction.
385    // The inline attribute is needed here, micro benchmarks showed
386    // that it speeds up things.
387    #[inline(always)]
388    fn add_bytes_small(&mut self, mut bytes: &[u8]) {
389        if bytes.is_empty() {
390            return;
391        }
392
393        let mut sum = self.sum;
394        fn update_sum(acc: Accumulator, rhs: u16) -> Accumulator {
395            if let Some(updated) = acc.checked_add(rhs as Accumulator) {
396                updated
397            } else {
398                (normalize(acc) + rhs) as Accumulator
399            }
400        }
401
402        if let Some(byte) = self.trailing_byte {
403            sum = update_sum(sum, u16::from_ne_bytes([byte, bytes[0]]));
404            bytes = &bytes[1..];
405            self.trailing_byte = None;
406        }
407
408        bytes.chunks(2).for_each(|chunk| match chunk {
409            [byte] => self.trailing_byte = Some(*byte),
410            [first, second] => {
411                sum = update_sum(sum, u16::from_ne_bytes([*first, *second]));
412            }
413            bytes => unreachable!("{:?}", bytes),
414        });
415
416        self.sum = sum;
417    }
418
419    /// Computes the checksum, but in big endian byte order.
420    fn checksum_inner(&self) -> u16 {
421        let mut sum = self.sum;
422        if let Some(byte) = self.trailing_byte {
423            sum = adc_accumulator(sum, u16::from_ne_bytes([byte, 0]) as Accumulator);
424        }
425        !normalize(sum)
426    }
427
428    /// Computes the checksum, and returns the array representation.
429    ///
430    /// `checksum` returns the checksum of all data added using `add_bytes` so
431    /// far. Calling `checksum` does *not* reset the checksum. More bytes may be
432    /// added after calling `checksum`, and they will be added to the checksum
433    /// as expected.
434    ///
435    /// If an odd number of bytes have been added so far, the checksum will be
436    /// computed as though a single 0 byte had been added at the end in order to
437    /// even out the length of the input.
438    #[inline]
439    pub fn checksum(&self) -> [u8; 2] {
440        self.checksum_inner().to_ne_bytes()
441    }
442}
443
macro_rules! impl_adc {
    ($name: ident, $t: ty) => {
        /// implements 1's complement addition for $t,
        /// exploiting the carry flag on a 2's complement machine.
        /// In practice, the adc instruction will be generated.
        fn $name(a: $t, b: $t) -> $t {
            // `c` is set iff the addition wrapped; adding it back performs
            // the 1's complement "end-around carry". The `s + (c as $t)`
            // cannot overflow: when `c` is set, the wrapped `s` is at most
            // `$t::MAX - 1`.
            let (s, c) = a.overflowing_add(b);
            s + (c as $t)
        }
    };
}

impl_adc!(adc_u16, u16);
impl_adc!(adc_u32, u32);
// `adc_u64` is only needed by `normalize` on x86-64, where the u128
// accumulator must first be folded from 128 down to 64 bits.
#[cfg(target_arch = "x86_64")]
impl_adc!(adc_u64, u64);
impl_adc!(adc_accumulator, Accumulator);
461
/// Normalizes the accumulator by mopping up the
/// overflow until it fits in a `u16`.
fn normalize(a: Accumulator) -> u16 {
    // On x86-64 the accumulator is u128: fold the high 64 bits into the
    // low 64 bits with an end-around carry, then take the common
    // 64 -> 16 path below.
    #[cfg(target_arch = "x86_64")]
    return normalize_64(adc_u64(a as u64, (a >> 64) as u64));
    #[cfg(not(target_arch = "x86_64"))]
    return normalize_64(a);
}
470
471fn normalize_64(a: u64) -> u16 {
472    let t = adc_u32(a as u32, (a >> 32) as u32);
473    adc_u16(t as u16, (t >> 16) as u16)
474}
475
#[cfg(all(test, feature = "benchmark"))]
mod benchmarks {
    extern crate test;
    use super::*;

    /// Benchmark time to calculate checksum with a single call to `add_bytes`
    /// with 31 bytes.
    #[bench]
    fn bench_checksum_31(b: &mut test::Bencher) {
        b.iter(|| {
            let buf = test::black_box([0xFF; 31]);
            let mut c = Checksum::new();
            c.add_bytes(&buf);
            test::black_box(c.checksum());
        });
    }

    /// Benchmark time to calculate checksum with a single call to `add_bytes`
    /// with 32 bytes.
    #[bench]
    fn bench_checksum_32(b: &mut test::Bencher) {
        b.iter(|| {
            let buf = test::black_box([0xFF; 32]);
            let mut c = Checksum::new();
            c.add_bytes(&buf);
            test::black_box(c.checksum());
        });
    }

    /// Benchmark time to calculate checksum with a single call to `add_bytes`
    /// with 64 bytes.
    #[bench]
    fn bench_checksum_64(b: &mut test::Bencher) {
        b.iter(|| {
            let buf = test::black_box([0xFF; 64]);
            let mut c = Checksum::new();
            c.add_bytes(&buf);
            test::black_box(c.checksum());
        });
    }

    /// Benchmark time to calculate checksum with a single call to `add_bytes`
    /// with 128 bytes.
    #[bench]
    fn bench_checksum_128(b: &mut test::Bencher) {
        b.iter(|| {
            let buf = test::black_box([0xFF; 128]);
            let mut c = Checksum::new();
            c.add_bytes(&buf);
            test::black_box(c.checksum());
        });
    }

    /// Benchmark time to calculate checksum with a single call to `add_bytes`
    /// with 256 bytes.
    #[bench]
    fn bench_checksum_256(b: &mut test::Bencher) {
        b.iter(|| {
            let buf = test::black_box([0xFF; 256]);
            let mut c = Checksum::new();
            c.add_bytes(&buf);
            test::black_box(c.checksum());
        });
    }

    /// Benchmark time to calculate checksum with a single call to `add_bytes`
    /// with 1024 bytes.
    #[bench]
    fn bench_checksum_1024(b: &mut test::Bencher) {
        b.iter(|| {
            let buf = test::black_box([0xFF; 1024]);
            let mut c = Checksum::new();
            c.add_bytes(&buf);
            test::black_box(c.checksum());
        });
    }

    /// Benchmark time to calculate checksum with a single call to `add_bytes`
    /// with 1023 bytes.
    #[bench]
    fn bench_checksum_1023(b: &mut test::Bencher) {
        b.iter(|| {
            let buf = test::black_box([0xFF; 1023]);
            let mut c = Checksum::new();
            c.add_bytes(&buf);
            test::black_box(c.checksum());
        });
    }

    /// Benchmark time to calculate checksum with a single call to `add_bytes`
    /// with 20 bytes.
    #[bench]
    fn bench_checksum_20(b: &mut test::Bencher) {
        b.iter(|| {
            let buf = test::black_box([0xFF; 20]);
            let mut c = Checksum::new();
            c.add_bytes(&buf);
            test::black_box(c.checksum());
        });
    }

    /// Benchmark time to calculate checksum with a single direct call to
    /// `add_bytes_small` with 20 bytes.
    #[bench]
    fn bench_checksum_small_20(b: &mut test::Bencher) {
        b.iter(|| {
            let buf = test::black_box([0xFF; 20]);
            let mut c = Checksum::new();
            c.add_bytes_small(&buf);
            test::black_box(c.checksum());
        });
    }

    /// Benchmark time to calculate checksum with a single direct call to
    /// `add_bytes_small` with 31 bytes.
    #[bench]
    fn bench_checksum_small_31(b: &mut test::Bencher) {
        b.iter(|| {
            let buf = test::black_box([0xFF; 31]);
            let mut c = Checksum::new();
            c.add_bytes_small(&buf);
            test::black_box(c.checksum());
        });
    }

    /// Benchmark time to incrementally update a checksum with `update`,
    /// replacing 2 bytes.
    #[bench]
    fn bench_update_2(b: &mut test::Bencher) {
        b.iter(|| {
            let old = test::black_box([0x42; 2]);
            let new = test::black_box([0xa0; 2]);
            test::black_box(update([42; 2], &old[..], &new[..]));
        });
    }

    /// Benchmark time to incrementally update a checksum with `update`,
    /// replacing 4 bytes.
    #[bench]
    fn bench_update_4(b: &mut test::Bencher) {
        b.iter(|| {
            let old = test::black_box([0x42; 4]);
            let new = test::black_box([0xa0; 4]);
            test::black_box(update([42; 2], &old[..], &new[..]));
        });
    }

    /// Benchmark time to incrementally update a checksum with `update`,
    /// replacing 8 bytes.
    #[bench]
    fn bench_update_8(b: &mut test::Bencher) {
        b.iter(|| {
            let old = test::black_box([0x42; 8]);
            let new = test::black_box([0xa0; 8]);
            test::black_box(update([42; 2], &old[..], &new[..]));
        });
    }
}
622
#[cfg(test)]
mod tests {
    use rand::{Rng, SeedableRng};

    use rand_xorshift::XorShiftRng;

    use super::*;

    /// Create a new deterministic RNG from a seed.
    fn new_rng(mut seed: u128) -> XorShiftRng {
        if seed == 0 {
            // XorShiftRng can't take 0 seeds
            seed = 1;
        }
        XorShiftRng::from_seed(seed.to_ne_bytes())
    }

    // The IPV4_HEADERS fixtures below include their (valid) checksum field,
    // so checksumming a whole header must yield zero.
    #[test]
    fn test_checksum() {
        for buf in IPV4_HEADERS {
            // compute the checksum as normal
            let mut c = Checksum::new();
            c.add_bytes(&buf);
            assert_eq!(c.checksum(), [0u8; 2]);
            // compute the checksum one byte at a time to make sure our
            // trailing_byte logic works
            let mut c = Checksum::new();
            for byte in *buf {
                c.add_bytes(&[*byte]);
            }
            assert_eq!(c.checksum(), [0u8; 2]);

            // Make sure that it works even if we overflow u32. Performing this
            // loop 2 * 2^16 times is guaranteed to cause such an overflow
            // because 0xFFFF + 0xFFFF > 2^16, and we're effectively adding
            // (0xFFFF + 0xFFFF) 2^16 times. We verify the overflow as well by
            // making sure that, at least once, the sum gets smaller from one
            // loop iteration to the next.
            let mut c = Checksum::new();
            c.add_bytes(&[0xFF, 0xFF]);
            for _ in 0..((2 * (1 << 16)) - 1) {
                c.add_bytes(&[0xFF, 0xFF]);
            }
            assert_eq!(c.checksum(), [0u8; 2]);
        }
    }

    // Incremental update of a changed range must agree with recomputing the
    // checksum from scratch (RFC 1624 Eqn. 3).
    #[test]
    fn test_update() {
        for b in IPV4_HEADERS {
            let mut buf = Vec::new();
            buf.extend_from_slice(b);

            let mut c = Checksum::new();
            c.add_bytes(&buf);
            assert_eq!(c.checksum(), [0u8; 2]);

            // replace the destination IP with the loopback address
            let old = [buf[16], buf[17], buf[18], buf[19]];
            (&mut buf[16..20]).copy_from_slice(&[127, 0, 0, 1]);
            let updated = update(c.checksum(), &old, &[127, 0, 0, 1]);
            let from_scratch = {
                let mut c = Checksum::new();
                c.add_bytes(&buf);
                c.checksum()
            };
            assert_eq!(updated, from_scratch);
        }
    }

    #[test]
    fn test_update_noop() {
        for b in IPV4_HEADERS {
            let mut buf = Vec::new();
            buf.extend_from_slice(b);

            let mut c = Checksum::new();
            c.add_bytes(&buf);
            assert_eq!(c.checksum(), [0u8; 2]);

            // Replace the destination IP with the same address. I.e. this
            // update should be a no-op.
            let old = [buf[16], buf[17], buf[18], buf[19]];
            let updated = update(c.checksum(), &old, &old);
            let from_scratch = {
                let mut c = Checksum::new();
                c.add_bytes(&buf);
                c.checksum()
            };
            assert_eq!(updated, from_scratch);
        }
    }

    // Randomized smoke test: update random even-aligned ranges and compare
    // against a from-scratch checksum of the mutated buffer.
    #[test]
    fn test_smoke_update() {
        let mut rng = new_rng(70_812_476_915_813);

        for _ in 0..2048 {
            // use an odd length so we test the odd length logic
            const BUF_LEN: usize = 31;
            let buf: [u8; BUF_LEN] = rng.gen();
            let mut c = Checksum::new();
            c.add_bytes(&buf);

            let (begin, end) = loop {
                let begin = rng.gen::<usize>() % BUF_LEN;
                let end = begin + (rng.gen::<usize>() % (BUF_LEN + 1 - begin));
                // update requires that begin is even and end is either even or
                // the end of the input
                if begin % 2 == 0 && (end % 2 == 0 || end == BUF_LEN) {
                    break (begin, end);
                }
            };

            let mut new_buf = buf;
            for i in begin..end {
                new_buf[i] = rng.gen();
            }
            let updated = update(c.checksum(), &buf[begin..end], &new_buf[begin..end]);
            let from_scratch = {
                let mut c = Checksum::new();
                c.add_bytes(&new_buf);
                c.checksum()
            };
            assert_eq!(updated, from_scratch);
        }
    }

    #[test]
    fn test_add_bytes_small_prop_test() {
        // Since we have two independent implementations
        // Now it is time for us to write a property test
        // to ensure the checksum algorithm(s) are indeed correct.

        let mut rng = new_rng(123478012483);
        let mut c1 = Checksum::new();
        let mut c2 = Checksum::new();
        for len in 64..1_025 {
            for _ in 0..4 {
                let mut buf = vec![];
                for _ in 0..len {
                    buf.push(rng.gen());
                }
                c1.add_bytes(&buf[..]);
                c2.add_bytes_small(&buf[..]);
                assert_eq!(c1.checksum(), c2.checksum());
                let n1 = c1.checksum_inner();
                let n2 = c2.checksum_inner();
                assert_eq!(n1, n2);
                // Appending a message's own checksum to it must make the
                // overall checksum zero (the RFC 1071 receiver-side check).
                let mut t1 = Checksum::new();
                let mut t2 = Checksum::new();
                let mut t3 = Checksum::new();
                t3.add_bytes(&buf[..]);
                if buf.len() % 2 == 1 {
                    buf.push(0);
                }
                assert_eq!(buf.len() % 2, 0);
                buf.extend_from_slice(&t3.checksum());
                t1.add_bytes(&buf[..]);
                t2.add_bytes_small(&buf[..]);
                assert_eq!(t1.checksum(), [0, 0]);
                assert_eq!(t2.checksum(), [0, 0]);
            }
        }
    }

    /// IPv4 headers.
    ///
    /// This data was obtained by capturing live network traffic.
    const IPV4_HEADERS: &[&[u8]] = &[
        &[
            0x45, 0x00, 0x00, 0x34, 0x00, 0x00, 0x40, 0x00, 0x40, 0x06, 0xae, 0xea, 0xc0, 0xa8,
            0x01, 0x0f, 0xc0, 0xb8, 0x09, 0x6a,
        ],
        &[
            0x45, 0x20, 0x00, 0x74, 0x5b, 0x6e, 0x40, 0x00, 0x37, 0x06, 0x5c, 0x1c, 0xc0, 0xb8,
            0x09, 0x6a, 0xc0, 0xa8, 0x01, 0x0f,
        ],
        &[
            0x45, 0x20, 0x02, 0x8f, 0x00, 0x00, 0x40, 0x00, 0x3b, 0x11, 0xc9, 0x3f, 0xac, 0xd9,
            0x05, 0x6e, 0xc0, 0xa8, 0x01, 0x0f,
        ],
    ];

    // This test checks that an input, found by a fuzzer, no longer causes a crash due to addition
    // overflow.
    #[test]
    fn test_large_buffer_addition_overflow() {
        let mut sum = Checksum { sum: 0, trailing_byte: None };
        let bytes = [
            0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
            0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
            255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
        ];
        sum.add_bytes(&bytes[..]);
    }
}
819}