internet_checksum/lib.rs
1// Copyright 2019 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
//! RFC 1071 "internet checksum" computation.
//!
//! This crate implements the "internet checksum" defined in [RFC 1071] and
//! updated in [RFC 1141] and [RFC 1624], which is used by many different
//! protocols' packet formats. The checksum operates by computing the ones'
//! complement of the ones' complement sum of successive 16-bit words of the
//! input.
//!
//! [RFC 1071]: https://tools.ietf.org/html/rfc1071
//! [RFC 1141]: https://tools.ietf.org/html/rfc1141
//! [RFC 1624]: https://tools.ietf.org/html/rfc1624

// Optimizations applied:
//
// 0. Byte-order independence: as described in RFC 1071 section 2.(B), the
//    sum of 16-bit integers can be computed in either byte order, which saves
//    us unnecessary byte swapping on a little-endian machine. As profiled on
//    a gLinux workstation, that swapping can account for ~20% of the runtime.
//
// 1. Widen the accumulator: this lets us process a bigger chunk of data at a
//    time, achieving a kind of poor man's SIMD. Currently a u128 accumulator
//    is used on x86-64 and aarch64, and a u64 is used conservatively on other
//    architectures.
//
// 2. Process more at a time: the old implementation used a u32 accumulator
//    but added only one u16 at a time in order to defer carries. The current
//    implementation processes a u128 (8 u16s) at a time on x86-64 and
//    aarch64, and a u64 (4 u16s) at a time on other platforms.
//
// 3. Induce the compiler to emit the `adc` instruction: this instruction is
//    very useful for implementing ones' complement addition and is available
//    on both x86 and ARM. The `adc_uXX` functions exist for this purpose.
//
// 4. Eliminate branching as much as possible: the old implementation had if
//    statements to detect overflow of the u32 accumulator, which are not
//    needed when we can access the carry flag with `adc`. The old `normalize`
//    function used a while loop to fold the u32; we can unroll that loop
//    because we know ahead of time how many additions are needed.
//
// 5. In the loop of `add_bytes`, `adc_u64` is not used; instead,
//    `overflowing_add` is used directly. In `adc_u64` the carry flag comes
//    from the number currently being added, whereas in the slightly more
//    convoluted version in `add_bytes` each addition consumes the carry flag
//    of the previous one. I checked that in release mode this issues 3
//    instructions instead of 4 on x86, which should theoretically be
//    beneficial; however, measurements showed that it helps only a little,
//    so this trick is not used for `update`.
//
// Results:
//
// Micro-benchmarks were run on an x86-64 gLinux workstation. In summary,
// compared to baseline 0 (prior to the byte-order independence patch),
// there is a ~4x speedup.
//
// TODO: run this optimization on other platforms. I would expect the
// situation on ARM to be a bit different, because I am not sure how much
// penalty there is for misaligned reads on ARM, or whether they are even
// supported (on x86 there is generally no penalty for misaligned reads). If
// there are penalties, we should consider alignment as an optimization
// opportunity on ARM.

// TODO(joshlf): Right-justify the columns above

// TODO(joshlf):
// - Investigate optimizations proposed in RFC 1071 Section 2. The most
//   promising on modern hardware is probably (C) Parallel Summation, although
//   that needs to be balanced against (1) Deferred Carries. Benchmarks will
//   need to be performed to determine which is faster in practice, and under
//   what scenarios.

77/// Compute the checksum of "bytes".
78///
79/// `checksum(bytes)` is shorthand for:
80///
81/// ```rust
82/// # use internet_checksum::Checksum;
83/// # let bytes = &[];
84/// # let _ = {
85/// let mut c = Checksum::new();
86/// c.add_bytes(bytes);
87/// c.checksum()
88/// # };
89/// ```
90#[inline]
91pub fn checksum(bytes: &[u8]) -> [u8; 2] {
92 let mut c = Checksum::new();
93 c.add_bytes(bytes);
94 c.checksum()
95}
96
/// Integer type used to accumulate partial ones' complement sums.
///
/// A wider accumulator lets `Checksum::add_bytes` consume more input per
/// addition. x86-64 and aarch64 use a 128-bit accumulator.
#[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))]
type Accumulator = u128;

/// Integer type used to accumulate partial ones' complement sums.
///
/// Targets other than x86-64 and aarch64 conservatively use 64 bits.
#[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
type Accumulator = u64;

102/// Updates bytes in an existing checksum.
103///
104/// `update` updates a checksum to reflect that the already-checksummed bytes
105/// `old` have been updated to contain the values in `new`. It implements the
106/// algorithm described in Equation 3 in [RFC 1624]. The first byte must be at
107/// an even number offset in the original input. If an odd number offset byte
108/// needs to be updated, the caller should simply include the preceding byte as
109/// well. If an odd number of bytes is given, it is assumed that these are the
110/// last bytes of the input. If an odd number of bytes in the middle of the
111/// input needs to be updated, the preceding or following byte of the input
112/// should be added to make an even number of bytes.
113///
114/// # Panics
115///
116/// `update` panics if `old.len() != new.len()`.
117///
118/// [RFC 1624]: https://tools.ietf.org/html/rfc1624
119#[inline]
120pub fn update(checksum: [u8; 2], old: &[u8], new: &[u8]) -> [u8; 2] {
121 assert_eq!(old.len(), new.len());
122 // We compute on the sum, not the one's complement of the sum. checksum
123 // is the one's complement of the sum, so we need to get back to the
124 // sum. Thus, we negate checksum.
125 // HC' = ~HC
126 let mut sum = !u16::from_ne_bytes(checksum) as Accumulator;
127
128 // Let's reuse `Checksum::add_bytes` to update our checksum
129 // so that we can get the speedup for free. Using
130 // [RFC 1071 Eqn. 3], we can efficiently update our new checksum.
131 let mut c1 = Checksum::new();
132 let mut c2 = Checksum::new();
133 c1.add_bytes(old);
134 c2.add_bytes(new);
135
136 // Note, `c1.checksum_inner()` is actually ~m in [Eqn. 3]
137 // `c2.checksum_inner()` is actually ~m' in [Eqn. 3]
138 // so we have to negate `c2.checksum_inner()` first to get m'.
139 // HC' += ~m, c1.checksum_inner() == ~m.
140 sum = adc_accumulator(sum, c1.checksum_inner() as Accumulator);
141 // HC' += m', c2.checksum_inner() == ~m'.
142 sum = adc_accumulator(sum, !c2.checksum_inner() as Accumulator);
143 // HC' = ~HC.
144 (!normalize(sum)).to_ne_bytes()
145}
146
147/// RFC 1071 "internet checksum" computation.
148///
149/// `Checksum` implements the "internet checksum" defined in [RFC 1071] and
150/// updated in [RFC 1141] and [RFC 1624], which is used by many different
151/// protocols' packet formats. The checksum operates by computing the 1s
152/// complement of the 1s complement sum of successive 16-bit words of the input.
153///
154/// [RFC 1071]: https://tools.ietf.org/html/rfc1071
155/// [RFC 1141]: https://tools.ietf.org/html/rfc1141
156/// [RFC 1624]: https://tools.ietf.org/html/rfc1624
157#[derive(Default)]
158pub struct Checksum {
159 sum: Accumulator,
160 // Since odd-length inputs are treated specially, we store the trailing byte
161 // for use in future calls to add_bytes(), and only treat it as a true
162 // trailing byte in checksum().
163 trailing_byte: Option<u8>,
164}
165
166impl Checksum {
167 /// Initialize a new checksum.
168 #[inline]
169 pub const fn new() -> Self {
170 Checksum { sum: 0, trailing_byte: None }
171 }
172
173 /// Add bytes to the checksum.
174 ///
175 /// If `bytes` does not contain an even number of bytes, a single zero byte
176 /// will be added to the end before updating the checksum.
177 ///
178 /// Note that `add_bytes` has some fixed overhead regardless of the size of
179 /// `bytes`. Where performance is a concern, prefer fewer calls to
180 /// `add_bytes` with larger input over more calls with smaller input.
181 #[inline]
182 pub fn add_bytes(&mut self, mut bytes: &[u8]) {
183 if bytes.is_empty() {
184 return;
185 }
186
187 let mut sum = self.sum;
188 let mut carry = false;
189
190 // We are not using `adc_uXX` functions here, instead, we manually track
191 // the carry flag. This is because in `adc_uXX` functions, the carry
192 // flag depends on addition itself. So the assembly for that function
193 // reads as follows:
194 //
195 // mov %rdi, %rcx
196 // mov %rsi, %rax
197 // add %rcx, %rsi -- waste! only used to generate CF.
198 // adc %rdi, $rax -- the real useful instruction.
199 //
200 // So we had better to make us depend on the CF generated by the
201 // addition of the previous 16-bit word. The ideal assembly should look
202 // like:
203 //
204 // add 0(%rdi), %rax
205 // adc 8(%rdi), %rax
206 // adc 16(%rdi), %rax
207 // .... and so on ...
208 //
209 // Sadly, there are too many instructions that can affect the carry
210 // flag, and LLVM is not that optimized to find out the pattern and let
211 // all these adc instructions not interleaved. However, doing so results
212 // in 3 instructions instead of the original 4 instructions (the two
213 // mov's are still there) and it makes a difference on input size like
214 // 1023.
215 macro_rules! update_sum_carry {
216 ($ty: ident, $chunk: expr) => {
217 let (s, c) = sum.overflowing_add($ty::from_ne_bytes($chunk) as Accumulator);
218 sum = s.wrapping_add(carry as Accumulator);
219 carry = c;
220 };
221 }
222
223 const ACCUMULATOR_BYTES: usize = (Accumulator::BITS / 8) as usize;
224 while let Some(chunk) = bytes.first_chunk::<ACCUMULATOR_BYTES>() {
225 update_sum_carry!(Accumulator, *chunk);
226 bytes = &bytes[ACCUMULATOR_BYTES..];
227 }
228
229 // Handle the tail.
230 if let Some(chunk) = bytes.first_chunk::<8>() {
231 update_sum_carry!(u64, *chunk);
232 bytes = &bytes[8..];
233 }
234 if let Some(chunk) = bytes.first_chunk::<4>() {
235 update_sum_carry!(u32, *chunk);
236 bytes = &bytes[4..];
237 }
238 if let Some(chunk) = bytes.first_chunk::<2>() {
239 update_sum_carry!(u16, *chunk);
240 bytes = &bytes[2..];
241 }
242 if bytes.len() == 1 {
243 if let Some(existing) = self.trailing_byte.take() {
244 // We already had a trailing byte. Deal with them both.
245 update_sum_carry!(u16, [existing, bytes[0]]);
246 } else {
247 // Otherwise, stash the trailing byte.
248 self.trailing_byte = Some(bytes[0])
249 }
250 }
251
252 self.sum = sum + (carry as Accumulator);
253 }
254
255 /// Computes the checksum, but in big endian byte order.
256 fn checksum_inner(&self) -> u16 {
257 let mut sum = self.sum;
258 if let Some(byte) = self.trailing_byte {
259 sum = adc_accumulator(sum, u16::from_ne_bytes([byte, 0]) as Accumulator);
260 }
261 !normalize(sum)
262 }
263
264 /// Computes the one's complement sum and returns the array representation.
265 ///
266 /// `partial_checksum` returns the one's complement sum of all data added
267 /// using `add_bytes` so far. Calling `partial_checksum` does *not* reset
268 /// the checksum. More bytes may be added after calling `partial_checksum`,
269 /// and they will be added to the checksum as expected.
270 ///
271 /// `partial_checksum` will return `None` if an odd number of bytes have
272 /// been added so far.
273 pub fn partial_checksum(&self) -> Option<[u8; 2]> {
274 if self.trailing_byte.is_some() {
275 return None;
276 }
277 Some(normalize(self.sum).to_ne_bytes())
278 }
279
280 /// Computes the checksum, and returns the array representation.
281 ///
282 /// `checksum` returns the checksum of all data added using `add_bytes` so
283 /// far. Calling `checksum` does *not* reset the checksum. More bytes may be
284 /// added after calling `checksum`, and they will be added to the checksum
285 /// as expected.
286 ///
287 /// If an odd number of bytes have been added so far, the checksum will be
288 /// computed as though a single 0 byte had been added at the end in order to
289 /// even out the length of the input.
290 #[inline]
291 pub fn checksum(&self) -> [u8; 2] {
292 self.checksum_inner().to_ne_bytes()
293 }
294}
295
296macro_rules! impl_adc {
297 ($name: ident, $t: ty) => {
298 /// implements 1's complement addition for $t,
299 /// exploiting the carry flag on a 2's complement machine.
300 /// In practice, the adc instruction will be generated.
301 fn $name(a: $t, b: $t) -> $t {
302 let (s, c) = a.overflowing_add(b);
303 s + (c as $t)
304 }
305 };
306}
307
308impl_adc!(adc_u16, u16);
309impl_adc!(adc_u32, u32);
310#[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))]
311impl_adc!(adc_u64, u64);
312impl_adc!(adc_accumulator, Accumulator);
313
314/// Normalizes the accumulator by mopping up the
315/// overflow until it fits in a `u16`.
316fn normalize(a: Accumulator) -> u16 {
317 #[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))]
318 return normalize_64(adc_u64(a as u64, (a >> 64) as u64));
319 #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
320 return normalize_64(a);
321}
322
323fn normalize_64(a: u64) -> u16 {
324 let t = adc_u32(a as u32, (a >> 32) as u32);
325 adc_u16(t as u16, (t >> 16) as u16)
326}
327
#[cfg(test)]
mod tests {
    use rand::{Rng, SeedableRng};

    use rand_xorshift::XorShiftRng;

    use super::*;

    /// Create a new deterministic RNG from a seed.
    fn new_rng(mut seed: u128) -> XorShiftRng {
        if seed == 0 {
            // XorShiftRng can't take 0 seeds
            seed = 1;
        }
        XorShiftRng::from_seed(seed.to_ne_bytes())
    }

    #[test]
    fn test_checksum() {
        for buf in IPV4_HEADERS {
            // Compute the checksum as normal.
            let mut c = Checksum::new();
            c.add_bytes(buf);
            assert_eq!(c.checksum(), [0u8; 2]);

            // Compute the checksum one byte at a time to make sure our
            // trailing_byte logic works.
            let mut c = Checksum::new();
            for byte in *buf {
                c.add_bytes(&[*byte]);
            }
            assert_eq!(c.checksum(), [0u8; 2]);

            // The free function must agree with `Checksum`.
            assert_eq!(checksum(buf), [0u8; 2]);
        }
    }

    #[test]
    fn test_checksum_sum_exceeds_u32() {
        // Make sure that the checksum is still correct once the running sum no
        // longer fits in a u32. Adding the word 0xFFFF 2 * 2^16 times produces
        // a sum of 2^17 * 0xFFFF, which is greater than 2^32. Since 0xFFFF is
        // ones' complement (negative) zero, the resulting checksum is 0.
        let mut c = Checksum::new();
        for _ in 0..(2 * (1 << 16)) {
            c.add_bytes(&[0xFF, 0xFF]);
        }
        assert_eq!(c.checksum(), [0u8; 2]);
    }

    #[test]
    fn test_partial_checksum() {
        for buf in IPV4_HEADERS {
            // Partial checksum should compute for even length slices.
            for i in (0..buf.len()).step_by(2) {
                let mut part = Checksum::new();
                part.add_bytes(&buf[..i]);

                let mut c = Checksum::new();
                c.add_bytes(
                    &part
                        .partial_checksum()
                        .expect("partial checksum should compute for even length slices"),
                );
                c.add_bytes(&buf[i..]);
                assert_eq!(c.checksum(), [0u8; 2]);
            }
            // Partial checksum should not compute for odd length slices.
            for i in (1..buf.len()).step_by(2) {
                let mut part = Checksum::new();
                part.add_bytes(&buf[..i]);
                assert_eq!(part.partial_checksum(), None);
            }
            // Partial checksum should be the complement of the checksum.
            let mut c = Checksum::new();
            c.add_bytes(buf);
            assert_eq!(c.partial_checksum(), Some([0xFF; 2]));
        }
    }

    #[test]
    fn test_update() {
        for b in IPV4_HEADERS {
            let mut buf = b.to_vec();

            let mut c = Checksum::new();
            c.add_bytes(&buf);
            assert_eq!(c.checksum(), [0u8; 2]);

            // Replace the destination IP with the loopback address.
            let old = [buf[16], buf[17], buf[18], buf[19]];
            let new = [127, 0, 0, 1];
            buf[16..20].copy_from_slice(&new);
            let updated = update(c.checksum(), &old, &new);
            let from_scratch = {
                let mut c = Checksum::new();
                c.add_bytes(&buf);
                c.checksum()
            };
            assert_eq!(updated, from_scratch);
        }
    }

    #[test]
    fn test_update_noop() {
        for b in IPV4_HEADERS {
            let buf = b.to_vec();

            let mut c = Checksum::new();
            c.add_bytes(&buf);
            assert_eq!(c.checksum(), [0u8; 2]);

            // Replace the destination IP with the same address. I.e. this
            // update should be a no-op.
            let old = [buf[16], buf[17], buf[18], buf[19]];
            let updated = update(c.checksum(), &old, &old);
            let from_scratch = {
                let mut c = Checksum::new();
                c.add_bytes(&buf);
                c.checksum()
            };
            assert_eq!(updated, from_scratch);
        }
    }

    #[test]
    fn test_smoke_update() {
        let mut rng = new_rng(70_812_476_915_813);

        for _ in 0..2048 {
            // use an odd length so we test the odd length logic
            const BUF_LEN: usize = 31;
            let buf: [u8; BUF_LEN] = rng.random();
            let mut c = Checksum::new();
            c.add_bytes(&buf);

            let (begin, end) = loop {
                let begin = rng.random_range(0..BUF_LEN);
                let end = begin + (rng.random_range(0..(BUF_LEN + 1 - begin)));
                // update requires that begin is even and end is either even or
                // the end of the input
                if begin % 2 == 0 && (end % 2 == 0 || end == BUF_LEN) {
                    break (begin, end);
                }
            };

            let mut new_buf = buf;
            for i in begin..end {
                new_buf[i] = rng.random();
            }
            let updated = update(c.checksum(), &buf[begin..end], &new_buf[begin..end]);
            let from_scratch = {
                let mut c = Checksum::new();
                c.add_bytes(&new_buf);
                c.checksum()
            };
            assert_eq!(updated, from_scratch);
        }
    }

    /// IPv4 headers.
    ///
    /// This data was obtained by capturing live network traffic.
    const IPV4_HEADERS: &[&[u8]] = &[
        &[
            0x45, 0x00, 0x00, 0x34, 0x00, 0x00, 0x40, 0x00, 0x40, 0x06, 0xae, 0xea, 0xc0, 0xa8,
            0x01, 0x0f, 0xc0, 0xb8, 0x09, 0x6a,
        ],
        &[
            0x45, 0x20, 0x00, 0x74, 0x5b, 0x6e, 0x40, 0x00, 0x37, 0x06, 0x5c, 0x1c, 0xc0, 0xb8,
            0x09, 0x6a, 0xc0, 0xa8, 0x01, 0x0f,
        ],
        &[
            0x45, 0x20, 0x02, 0x8f, 0x00, 0x00, 0x40, 0x00, 0x3b, 0x11, 0xc9, 0x3f, 0xac, 0xd9,
            0x05, 0x6e, 0xc0, 0xa8, 0x01, 0x0f,
        ],
    ];

    // This test checks that an input, found by a fuzzer, no longer causes a crash due to addition
    // overflow, either while adding the bytes or while finalizing the checksum.
    #[test]
    fn test_large_buffer_addition_overflow() {
        let mut sum = Checksum::new();
        let bytes = [
            0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
            0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
            255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
        ];
        sum.add_bytes(&bytes[..]);
        let _ = sum.checksum();
    }
}