Skip to main content

internet_checksum/
lib.rs

1// Copyright 2019 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5//! RFC 1071 "internet checksum" computation.
6//!
7//! This crate implements the "internet checksum" defined in [RFC 1071] and
8//! updated in [RFC 1141] and [RFC 1624], which is used by many different
9//! protocols' packet formats. The checksum operates by computing the 1s
10//! complement of the 1s complement sum of successive 16-bit words of the input.
11//!
12//! [RFC 1071]: https://tools.ietf.org/html/rfc1071
13//! [RFC 1141]: https://tools.ietf.org/html/rfc1141
14//! [RFC 1624]: https://tools.ietf.org/html/rfc1624
15
16// Optimizations applied:
17//
// 0. Byte-order independence: as described in RFC 1071 section 2.(B),
//    the sum of 16-bit integers can be computed in either byte order,
//    which saves us from unnecessary byte swapping on a little-endian
//    machine. As profiled on a gLinux workstation, that swapping can
//    account for ~20% of the runtime.
23//
// 1. Widen the accumulator: this lets us process a bigger chunk of data
//    at a time, achieving a kind of poor man's SIMD. Currently a u128
//    accumulator is used on x86-64 and aarch64, and a u64 is used
//    conservatively on other architectures.
28//
// 2. Process more at a time: the old implementation used a u32
//    accumulator but only added one u16 at a time, to implement deferred
//    carries. The current implementation processes a u128 at a time on
//    x86-64 and aarch64, which is 8 u16s. On other platforms, it processes
//    a u64 at a time, which is 4 u16s.
34//
// 3. Induce the compiler to produce the `adc` instruction: this is a very
//    useful instruction for implementing 1's complement addition, and it is
//    available on both x86 and ARM. The `adc_uXX` functions exist for this.
38//
// 4. Eliminate branching as much as possible: the old implementation had
//    if statements for detecting overflow of the u32 accumulator, which
//    are not needed when we can access the carry flag with `adc`. The old
//    `normalize` function used to have a while loop to fold the u32;
//    however, we can unroll that loop because we know ahead of time how
//    many additions we need.
45//
// 5. In the loop of `add_bytes`, `adc_uXX` is not used; instead,
//    `overflowing_add` is used directly. In `adc_uXX`, the carry that is
//    added back comes from the very addition being performed, whereas in
//    the slightly convoluted version in `add_bytes`, each addition consumes
//    the carry produced by the previous one. Checked in release mode, this
//    issues 3 instructions instead of 4 on x86 and should theoretically be
//    beneficial; however, measurements showed that it helps only a little.
//    So this trick is used only in the bulk loop of `add_bytes`; the final
//    combining steps (e.g. in `update`) use the plain `adc_uXX` helpers.
54//
55// Results:
56//
// Micro-benchmarks were run on an x86-64 gLinux workstation. In summary,
// compared to baseline 0, which predates the byte-order-independence
// patch, there is a ~4x speedup.
60//
// TODO: run this optimization on other platforms. I would expect the
// situation on ARM to be a bit different because I am not sure how much
// penalty there is for misaligned reads on ARM, or whether they are even
// supported (on x86 there is generally no penalty for misaligned reads).
// If there are penalties, we should consider alignment as an optimization
// opportunity on ARM.
67
68// TODO(joshlf): Right-justify the columns above
69
70// TODO(joshlf):
71// - Investigate optimizations proposed in RFC 1071 Section 2. The most
72//   promising on modern hardware is probably (C) Parallel Summation, although
73//   that needs to be balanced against (1) Deferred Carries. Benchmarks will
74//   need to be performed to determine which is faster in practice, and under
75//   what scenarios.
76
77/// Compute the checksum of "bytes".
78///
79/// `checksum(bytes)` is shorthand for:
80///
81/// ```rust
82/// # use internet_checksum::Checksum;
83/// # let bytes = &[];
84/// # let _ = {
85/// let mut c = Checksum::new();
86/// c.add_bytes(bytes);
87/// c.checksum()
88/// # };
89/// ```
90#[inline]
91pub fn checksum(bytes: &[u8]) -> [u8; 2] {
92    let mut c = Checksum::new();
93    c.add_bytes(bytes);
94    c.checksum()
95}
96
/// Integer type in which the running 1s complement sum is accumulated.
///
/// A wider accumulator lets more input be folded in per addition. A 64-bit
/// accumulator is used by default, and a 128-bit one on x86-64 and aarch64.
#[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
type Accumulator = u64;
#[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))]
type Accumulator = u128;
101
102/// Updates bytes in an existing checksum.
103///
104/// `update` updates a checksum to reflect that the already-checksummed bytes
105/// `old` have been updated to contain the values in `new`. It implements the
106/// algorithm described in Equation 3 in [RFC 1624]. The first byte must be at
107/// an even number offset in the original input. If an odd number offset byte
108/// needs to be updated, the caller should simply include the preceding byte as
109/// well. If an odd number of bytes is given, it is assumed that these are the
110/// last bytes of the input. If an odd number of bytes in the middle of the
111/// input needs to be updated, the preceding or following byte of the input
112/// should be added to make an even number of bytes.
113///
114/// # Panics
115///
116/// `update` panics if `old.len() != new.len()`.
117///
118/// [RFC 1624]: https://tools.ietf.org/html/rfc1624
119#[inline]
120pub fn update(checksum: [u8; 2], old: &[u8], new: &[u8]) -> [u8; 2] {
121    assert_eq!(old.len(), new.len());
122    // We compute on the sum, not the one's complement of the sum. checksum
123    // is the one's complement of the sum, so we need to get back to the
124    // sum. Thus, we negate checksum.
125    // HC' = ~HC
126    let mut sum = !u16::from_ne_bytes(checksum) as Accumulator;
127
128    // Let's reuse `Checksum::add_bytes` to update our checksum
129    // so that we can get the speedup for free. Using
130    // [RFC 1071 Eqn. 3], we can efficiently update our new checksum.
131    let mut c1 = Checksum::new();
132    let mut c2 = Checksum::new();
133    c1.add_bytes(old);
134    c2.add_bytes(new);
135
136    // Note, `c1.checksum_inner()` is actually ~m in [Eqn. 3]
137    // `c2.checksum_inner()` is actually ~m' in [Eqn. 3]
138    // so we have to negate `c2.checksum_inner()` first to get m'.
139    // HC' += ~m, c1.checksum_inner() == ~m.
140    sum = adc_accumulator(sum, c1.checksum_inner() as Accumulator);
141    // HC' += m', c2.checksum_inner() == ~m'.
142    sum = adc_accumulator(sum, !c2.checksum_inner() as Accumulator);
143    // HC' = ~HC.
144    (!normalize(sum)).to_ne_bytes()
145}
146
147/// RFC 1071 "internet checksum" computation.
148///
149/// `Checksum` implements the "internet checksum" defined in [RFC 1071] and
150/// updated in [RFC 1141] and [RFC 1624], which is used by many different
151/// protocols' packet formats. The checksum operates by computing the 1s
152/// complement of the 1s complement sum of successive 16-bit words of the input.
153///
154/// [RFC 1071]: https://tools.ietf.org/html/rfc1071
155/// [RFC 1141]: https://tools.ietf.org/html/rfc1141
156/// [RFC 1624]: https://tools.ietf.org/html/rfc1624
157#[derive(Default)]
158pub struct Checksum {
159    sum: Accumulator,
160    // Since odd-length inputs are treated specially, we store the trailing byte
161    // for use in future calls to add_bytes(), and only treat it as a true
162    // trailing byte in checksum().
163    trailing_byte: Option<u8>,
164}
165
166impl Checksum {
167    /// Initialize a new checksum.
168    #[inline]
169    pub const fn new() -> Self {
170        Checksum { sum: 0, trailing_byte: None }
171    }
172
173    /// Add bytes to the checksum.
174    ///
175    /// If `bytes` does not contain an even number of bytes, a single zero byte
176    /// will be added to the end before updating the checksum.
177    ///
178    /// Note that `add_bytes` has some fixed overhead regardless of the size of
179    /// `bytes`. Where performance is a concern, prefer fewer calls to
180    /// `add_bytes` with larger input over more calls with smaller input.
181    #[inline]
182    pub fn add_bytes(&mut self, mut bytes: &[u8]) {
183        if bytes.is_empty() {
184            return;
185        }
186
187        let mut sum = self.sum;
188        let mut carry = false;
189
190        // We are not using `adc_uXX` functions here, instead, we manually track
191        // the carry flag. This is because in `adc_uXX` functions, the carry
192        // flag depends on addition itself. So the assembly for that function
193        // reads as follows:
194        //
195        // mov %rdi, %rcx
196        // mov %rsi, %rax
197        // add %rcx, %rsi -- waste! only used to generate CF.
198        // adc %rdi, $rax -- the real useful instruction.
199        //
200        // So we had better to make us depend on the CF generated by the
201        // addition of the previous 16-bit word. The ideal assembly should look
202        // like:
203        //
204        // add 0(%rdi), %rax
205        // adc 8(%rdi), %rax
206        // adc 16(%rdi), %rax
207        // .... and so on ...
208        //
209        // Sadly, there are too many instructions that can affect the carry
210        // flag, and LLVM is not that optimized to find out the pattern and let
211        // all these adc instructions not interleaved. However, doing so results
212        // in 3 instructions instead of the original 4 instructions (the two
213        // mov's are still there) and it makes a difference on input size like
214        // 1023.
215        macro_rules! update_sum_carry {
216            ($ty: ident, $chunk: expr) => {
217                let (s, c) = sum.overflowing_add($ty::from_ne_bytes($chunk) as Accumulator);
218                sum = s.wrapping_add(carry as Accumulator);
219                carry = c;
220            };
221        }
222
223        const ACCUMULATOR_BYTES: usize = (Accumulator::BITS / 8) as usize;
224        while let Some(chunk) = bytes.first_chunk::<ACCUMULATOR_BYTES>() {
225            update_sum_carry!(Accumulator, *chunk);
226            bytes = &bytes[ACCUMULATOR_BYTES..];
227        }
228
229        // Handle the tail.
230        if let Some(chunk) = bytes.first_chunk::<8>() {
231            update_sum_carry!(u64, *chunk);
232            bytes = &bytes[8..];
233        }
234        if let Some(chunk) = bytes.first_chunk::<4>() {
235            update_sum_carry!(u32, *chunk);
236            bytes = &bytes[4..];
237        }
238        if let Some(chunk) = bytes.first_chunk::<2>() {
239            update_sum_carry!(u16, *chunk);
240            bytes = &bytes[2..];
241        }
242        if bytes.len() == 1 {
243            if let Some(existing) = self.trailing_byte.take() {
244                // We already had a trailing byte. Deal with them both.
245                update_sum_carry!(u16, [existing, bytes[0]]);
246            } else {
247                // Otherwise, stash the trailing byte.
248                self.trailing_byte = Some(bytes[0])
249            }
250        }
251
252        self.sum = sum + (carry as Accumulator);
253    }
254
255    /// Computes the checksum, but in big endian byte order.
256    fn checksum_inner(&self) -> u16 {
257        let mut sum = self.sum;
258        if let Some(byte) = self.trailing_byte {
259            sum = adc_accumulator(sum, u16::from_ne_bytes([byte, 0]) as Accumulator);
260        }
261        !normalize(sum)
262    }
263
264    /// Computes the one's complement sum and returns the array representation.
265    ///
266    /// `partial_checksum` returns the one's complement sum of all data added
267    /// using `add_bytes` so far. Calling `partial_checksum` does *not* reset
268    /// the checksum. More bytes may be added after calling `partial_checksum`,
269    /// and they will be added to the checksum as expected.
270    ///
271    /// `partial_checksum` will return `None` if an odd number of bytes have
272    /// been added so far.
273    pub fn partial_checksum(&self) -> Option<[u8; 2]> {
274        if self.trailing_byte.is_some() {
275            return None;
276        }
277        Some(normalize(self.sum).to_ne_bytes())
278    }
279
280    /// Computes the checksum, and returns the array representation.
281    ///
282    /// `checksum` returns the checksum of all data added using `add_bytes` so
283    /// far. Calling `checksum` does *not* reset the checksum. More bytes may be
284    /// added after calling `checksum`, and they will be added to the checksum
285    /// as expected.
286    ///
287    /// If an odd number of bytes have been added so far, the checksum will be
288    /// computed as though a single 0 byte had been added at the end in order to
289    /// even out the length of the input.
290    #[inline]
291    pub fn checksum(&self) -> [u8; 2] {
292        self.checksum_inner().to_ne_bytes()
293    }
294}
295
/// Defines a function `$name` performing 1s complement (end-around carry)
/// addition on the unsigned integer type `$t`.
macro_rules! impl_adc {
    ($name: ident, $t: ty) => {
        // A plain `///` comment here would be emitted literally (with `$t`
        // unsubstituted), so the doc string is built with `concat!`.
        #[doc = concat!(
            "Implements 1's complement addition for `",
            stringify!($t),
            "`, exploiting the carry flag on a 2's complement machine.\n\n",
            "The carry out of the top bit is added back in (end-around carry); ",
            "in practice the compiler can lower this to an `adc` instruction."
        )]
        fn $name(a: $t, b: $t) -> $t {
            let (s, c) = a.overflowing_add(b);
            // If the addition overflowed, `s` is at most `MAX - 1`, so adding
            // the carry back in cannot overflow again.
            s + (c as $t)
        }
    };
}

impl_adc!(adc_u16, u16);
impl_adc!(adc_u32, u32);
#[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))]
impl_adc!(adc_u64, u64);
312impl_adc!(adc_accumulator, Accumulator);
313
314/// Normalizes the accumulator by mopping up the
315/// overflow until it fits in a `u16`.
316fn normalize(a: Accumulator) -> u16 {
317    #[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))]
318    return normalize_64(adc_u64(a as u64, (a >> 64) as u64));
319    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
320    return normalize_64(a);
321}
322
323fn normalize_64(a: u64) -> u16 {
324    let t = adc_u32(a as u32, (a >> 32) as u32);
325    adc_u16(t as u16, (t >> 16) as u16)
326}
327
#[cfg(test)]
mod tests {
    use rand::{Rng, SeedableRng};

    use rand_xorshift::XorShiftRng;

    use super::*;

    /// Create a new deterministic RNG from a seed.
    fn new_rng(mut seed: u128) -> XorShiftRng {
        if seed == 0 {
            // XorShiftRng can't take 0 seeds
            seed = 1;
        }
        XorShiftRng::from_seed(seed.to_ne_bytes())
    }

    #[test]
    fn test_checksum() {
        for buf in IPV4_HEADERS {
            // compute the checksum as normal
            let mut c = Checksum::new();
            c.add_bytes(buf);
            assert_eq!(c.checksum(), [0u8; 2]);
            // compute the checksum one byte at a time to make sure our
            // trailing_byte logic works
            let mut c = Checksum::new();
            for &byte in *buf {
                c.add_bytes(&[byte]);
            }
            assert_eq!(c.checksum(), [0u8; 2]);
        }
    }

    #[test]
    fn test_checksum_u32_overflow() {
        // Make sure that it works even if the sum overflows a u32 (the width
        // of the accumulator in an earlier implementation). Adding the word
        // 0xFFFF 2 * 2^16 times totals more than 2^32, so such an overflow is
        // guaranteed. 0xFFFF is negative zero in 1s complement arithmetic, so
        // the checksum must still come out as zero.
        let mut c = Checksum::new();
        for _ in 0..(2 * (1 << 16)) {
            c.add_bytes(&[0xFF, 0xFF]);
        }
        assert_eq!(c.checksum(), [0u8; 2]);
    }

    #[test]
    fn test_partial_checksum() {
        for buf in IPV4_HEADERS {
            // Partial checksum should compute for even length slices.
            for i in (0..buf.len()).step_by(2) {
                let mut part = Checksum::new();
                part.add_bytes(&buf[..i]);

                let mut c = Checksum::new();
                c.add_bytes(
                    &part
                        .partial_checksum()
                        .expect("partial checksum should compute for even length slices"),
                );
                c.add_bytes(&buf[i..]);
                assert_eq!(c.checksum(), [0u8; 2]);
            }
            // Partial checksum should not compute for odd length slices.
            for i in (1..buf.len()).step_by(2) {
                let mut part = Checksum::new();
                part.add_bytes(&buf[..i]);
                assert_eq!(part.partial_checksum(), None);
            }
            // Partial checksum should be the complement of the checksum.
            let mut c = Checksum::new();
            c.add_bytes(buf);
            assert_eq!(c.partial_checksum(), Some([0xFF; 2]));
        }
    }

    #[test]
    fn test_update() {
        for b in IPV4_HEADERS {
            let mut buf = b.to_vec();

            let mut c = Checksum::new();
            c.add_bytes(&buf);
            assert_eq!(c.checksum(), [0u8; 2]);

            // replace the destination IP with the loopback address
            let old = [buf[16], buf[17], buf[18], buf[19]];
            buf[16..20].copy_from_slice(&[127, 0, 0, 1]);
            let updated = update(c.checksum(), &old, &[127, 0, 0, 1]);
            let from_scratch = {
                let mut c = Checksum::new();
                c.add_bytes(&buf);
                c.checksum()
            };
            assert_eq!(updated, from_scratch);
        }
    }

    #[test]
    fn test_update_noop() {
        for b in IPV4_HEADERS {
            let buf = b.to_vec();

            let mut c = Checksum::new();
            c.add_bytes(&buf);
            assert_eq!(c.checksum(), [0u8; 2]);

            // Replace the destination IP with the same address. I.e. this
            // update should be a no-op.
            let old = [buf[16], buf[17], buf[18], buf[19]];
            let updated = update(c.checksum(), &old, &old);
            let from_scratch = {
                let mut c = Checksum::new();
                c.add_bytes(&buf);
                c.checksum()
            };
            assert_eq!(updated, from_scratch);
        }
    }

    #[test]
    fn test_smoke_update() {
        let mut rng = new_rng(70_812_476_915_813);

        for _ in 0..2048 {
            // use an odd length so we test the odd length logic
            const BUF_LEN: usize = 31;
            let buf: [u8; BUF_LEN] = rng.random();
            let mut c = Checksum::new();
            c.add_bytes(&buf);

            let (begin, end) = loop {
                let begin = rng.random_range(0..BUF_LEN);
                let end = begin + (rng.random_range(0..(BUF_LEN + 1 - begin)));
                // update requires that begin is even and end is either even or
                // the end of the input
                if begin % 2 == 0 && (end % 2 == 0 || end == BUF_LEN) {
                    break (begin, end);
                }
            };

            let mut new_buf = buf;
            for byte in &mut new_buf[begin..end] {
                *byte = rng.random();
            }
            let updated = update(c.checksum(), &buf[begin..end], &new_buf[begin..end]);
            let from_scratch = {
                let mut c = Checksum::new();
                c.add_bytes(&new_buf);
                c.checksum()
            };
            assert_eq!(updated, from_scratch);
        }
    }

    /// IPv4 headers.
    ///
    /// This data was obtained by capturing live network traffic.
    const IPV4_HEADERS: &[&[u8]] = &[
        &[
            0x45, 0x00, 0x00, 0x34, 0x00, 0x00, 0x40, 0x00, 0x40, 0x06, 0xae, 0xea, 0xc0, 0xa8,
            0x01, 0x0f, 0xc0, 0xb8, 0x09, 0x6a,
        ],
        &[
            0x45, 0x20, 0x00, 0x74, 0x5b, 0x6e, 0x40, 0x00, 0x37, 0x06, 0x5c, 0x1c, 0xc0, 0xb8,
            0x09, 0x6a, 0xc0, 0xa8, 0x01, 0x0f,
        ],
        &[
            0x45, 0x20, 0x02, 0x8f, 0x00, 0x00, 0x40, 0x00, 0x3b, 0x11, 0xc9, 0x3f, 0xac, 0xd9,
            0x05, 0x6e, 0xc0, 0xa8, 0x01, 0x0f,
        ],
    ];

    // This test checks that an input, found by a fuzzer, no longer causes a crash due to addition
    // overflow. NOTE: it only checks for the absence of a panic; the resulting checksum value is
    // not verified here.
    #[test]
    fn test_large_buffer_addition_overflow() {
        let mut sum = Checksum::new();
        let bytes = [
            0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
            0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
            255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
        ];
        sum.add_bytes(&bytes[..]);
    }
}