delivery_blob/
compression.rs

1// Copyright 2023 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5//! Implementation of chunked-compression library in Rust. Archives can be created by making a new
6//! [`ChunkedArchive`] and serializing/writing it. An archive's header can be verified and seek
7//! table decoded using [`decode_archive`].
8
9use crc::Hasher32;
10use itertools::Itertools;
11use rayon::prelude::*;
12use std::ops::Range;
13use thiserror::Error;
14use zerocopy::byteorder::{LE, U16, U32, U64};
15use zerocopy::{FromBytes, Immutable, IntoBytes, KnownLayout, Ref, Unaligned};
16
/// Errors that can occur while creating, validating, or decompressing a chunked archive.
// *NOTE*: Use caution when using the `#[source]` attribute or naming fields `source`. Some callers
// attempt to downcast library errors into the concrete type of the root cause.
// See https://docs.rs/thiserror/latest/thiserror/ for more information.
#[derive(Debug, Error)]
pub enum ChunkedArchiveError {
    #[error("Invalid or unsupported archive version.")]
    InvalidVersion,

    #[error("Archive header has incorrect magic.")]
    BadMagic,

    #[error("Integrity checks failed (e.g. incorrect CRC, inconsistent header fields).")]
    IntegrityError,

    #[error("Value is out of range or cannot be represented in specified type.")]
    OutOfRange,

    #[error("Error invoking Zstd function: `{0:?}`.")]
    ZstdError(std::io::Error),

    #[error("Error decompressing chunk {index}: `{error}`.")]
    DecompressionError { index: usize, error: std::io::Error },

    #[error("Error compressing chunk {index}: `{error}`.")]
    CompressionError { index: usize, error: std::io::Error },
}
45
/// Validated chunk information from an archive. Compressed ranges are relative to the start of
/// compressed data (i.e. they start after the header and seek table).
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ChunkInfo {
    /// Byte range this chunk occupies in the decompressed output.
    pub decompressed_range: Range<usize>,
    /// Byte range of this chunk's compressed bytes, relative to the start of chunk data.
    pub compressed_range: Range<usize>,
}
53
/// Decode a chunked archive header. Returns validated seek table and start of chunk data. Ranges
/// in resulting chunks are relative to start of returned slice. Returns `Ok(None)` if `data` is not
/// large enough to decode the archive header & seek table.
///
/// `archive_length` is the total size of the archive in bytes; extents referenced by the seek
/// table are validated against it.
pub fn decode_archive(
    data: &[u8],
    archive_length: usize,
) -> Result<Option<(Vec<ChunkInfo>, /*archive_data*/ &[u8])>, ChunkedArchiveError> {
    // Split the fixed-size header off the front of `data`; a size error only means the caller
    // hasn't provided enough bytes yet, which is not a validation failure.
    match Ref::<_, ChunkedArchiveHeader>::from_prefix(data).map_err(Into::into) {
        Ok((header, data)) => header.decode_seek_table(data, archive_length as u64),
        Err(zerocopy::SizeError { .. }) => Ok(None), // Not enough data.
    }
}
66
67impl ChunkInfo {
68    fn from_entry(
69        entry: &SeekTableEntry,
70        header_length: usize,
71    ) -> Result<Self, ChunkedArchiveError> {
72        let decompressed_start = entry.decompressed_offset.get() as usize;
73        let decompressed_size = entry.decompressed_size.get() as usize;
74        let decompressed_range = decompressed_start
75            ..decompressed_start
76                .checked_add(decompressed_size)
77                .ok_or(ChunkedArchiveError::OutOfRange)?;
78
79        let compressed_offset = entry.compressed_offset.get() as usize;
80        let compressed_start = compressed_offset
81            .checked_sub(header_length)
82            .ok_or(ChunkedArchiveError::IntegrityError)?;
83        let compressed_size = entry.compressed_size.get() as usize;
84        let compressed_range = compressed_start
85            ..compressed_start
86                .checked_add(compressed_size)
87                .ok_or(ChunkedArchiveError::OutOfRange)?;
88
89        Ok(Self { decompressed_range, compressed_range })
90    }
91}
92
/// Chunked archive header. All multi-byte fields are little-endian.
#[derive(IntoBytes, KnownLayout, FromBytes, Immutable, Unaligned, Clone, Copy, Debug)]
#[repr(C)]
struct ChunkedArchiveHeader {
    /// Must equal `CHUNKED_ARCHIVE_MAGIC`.
    magic: [u8; 8],
    /// Archive format version; must equal `CHUNKED_ARCHIVE_VERSION`.
    version: U16<LE>,
    reserved_0: U16<LE>,
    /// Number of seek table entries immediately following this header.
    num_entries: U32<LE>,
    /// CRC32 of the header (excluding this field's bytes) and all seek table entries.
    checksum: U32<LE>,
    reserved_1: U32<LE>,
    reserved_2: U64<LE>,
}
105
/// Chunked archive seek table entry. All fields are byte counts, little-endian. Compressed
/// offsets are absolute within the archive (i.e. they include the header and seek table).
#[derive(IntoBytes, KnownLayout, FromBytes, Immutable, Unaligned, Clone, Copy, Debug)]
#[repr(C)]
struct SeekTableEntry {
    /// Offset of this chunk within the decompressed data.
    decompressed_offset: U64<LE>,
    /// Size of this chunk when decompressed.
    decompressed_size: U64<LE>,
    /// Absolute offset of this chunk's compressed bytes within the archive.
    compressed_offset: U64<LE>,
    /// Size of this chunk's compressed bytes.
    compressed_size: U64<LE>,
}
115
116impl ChunkedArchiveHeader {
117    const CHUNKED_ARCHIVE_MAGIC: [u8; 8] = [0x46, 0x9b, 0x78, 0xef, 0x0f, 0xd0, 0xb2, 0x03];
118    const CHUNKED_ARCHIVE_VERSION: u16 = 2;
119    const CHUNKED_ARCHIVE_MAX_FRAMES: usize = 1023;
120    const CHUNKED_ARCHIVE_CHECKSUM_OFFSET: usize = 16;
121
122    fn new(seek_table: &[SeekTableEntry]) -> Result<Self, ChunkedArchiveError> {
123        let header: ChunkedArchiveHeader = Self {
124            magic: Self::CHUNKED_ARCHIVE_MAGIC,
125            version: Self::CHUNKED_ARCHIVE_VERSION.into(),
126            reserved_0: 0.into(),
127            num_entries: TryInto::<u32>::try_into(seek_table.len())
128                .or(Err(ChunkedArchiveError::OutOfRange))?
129                .into(),
130            checksum: 0.into(), // `checksum` is calculated below.
131            reserved_1: 0.into(),
132            reserved_2: 0.into(),
133        };
134        Ok(Self { checksum: header.checksum(seek_table).into(), ..header })
135    }
136
137    /// Calculate the checksum of the header + all seek table entries.
138    fn checksum(&self, entries: &[SeekTableEntry]) -> u32 {
139        let mut first_crc = crc::crc32::Digest::new(crc::crc32::IEEE);
140        first_crc.write(&self.as_bytes()[..Self::CHUNKED_ARCHIVE_CHECKSUM_OFFSET]);
141        let mut crc = crc::crc32::Digest::new_with_initial(crc::crc32::IEEE, first_crc.sum32());
142        crc.write(
143            &self.as_bytes()
144                [Self::CHUNKED_ARCHIVE_CHECKSUM_OFFSET + self.checksum.as_bytes().len()..],
145        );
146        crc.write(entries.as_bytes());
147        crc.sum32()
148    }
149
150    /// Calculate the total header length of an archive *including* all seek table entries.
151    fn header_length(num_entries: usize) -> usize {
152        std::mem::size_of::<ChunkedArchiveHeader>()
153            + (std::mem::size_of::<SeekTableEntry>() * num_entries)
154    }
155
156    /// Decode seek table for this archive. Returns validated seek table and start of chunk data.
157    /// `data` must point to the start of the seek table. Returns `Ok(None)` if `data` is not large
158    /// enough to decode all seek table entries.
159    fn decode_seek_table(
160        self,
161        data: &[u8],
162        archive_length: u64,
163    ) -> Result<Option<(Vec<ChunkInfo>, /*chunk_data*/ &[u8])>, ChunkedArchiveError> {
164        // Deserialize seek table.
165        let num_entries = self.num_entries.get() as usize;
166        let Ok((entries, chunk_data)) =
167            Ref::<_, [SeekTableEntry]>::from_prefix_with_elems(data, num_entries)
168        else {
169            return Ok(None);
170        };
171        let entries: &[SeekTableEntry] = Ref::into_ref(entries);
172
173        // Validate archive header.
174        if self.magic != Self::CHUNKED_ARCHIVE_MAGIC {
175            return Err(ChunkedArchiveError::BadMagic);
176        }
177        if self.version.get() != Self::CHUNKED_ARCHIVE_VERSION {
178            return Err(ChunkedArchiveError::InvalidVersion);
179        }
180        if self.checksum.get() != self.checksum(entries) {
181            return Err(ChunkedArchiveError::IntegrityError);
182        }
183        if entries.len() > Self::CHUNKED_ARCHIVE_MAX_FRAMES {
184            return Err(ChunkedArchiveError::IntegrityError);
185        }
186
187        // Validate seek table using invariants I0 through I5.
188
189        // I0: The first seek table entry, if any, must have decompressed offset 0.
190        if !entries.is_empty() && entries[0].decompressed_offset.get() != 0 {
191            return Err(ChunkedArchiveError::IntegrityError);
192        }
193
194        // I1: The compressed offsets of all seek table entries must not overlap with the header.
195        let header_length = Self::header_length(entries.len());
196        if entries.iter().any(|entry| entry.compressed_offset.get() < header_length as u64) {
197            return Err(ChunkedArchiveError::IntegrityError);
198        }
199
200        // I2: Each entry's decompressed offset must be equal to the end of the previous frame
201        //     (i.e. to the previous frame's decompressed offset + length).
202        for (prev, curr) in entries.iter().tuple_windows() {
203            if (prev.decompressed_offset.get() + prev.decompressed_size.get())
204                != curr.decompressed_offset.get()
205            {
206                return Err(ChunkedArchiveError::IntegrityError);
207            }
208        }
209
210        // I3: Each entry's compressed offset must be greater than or equal to the end of the
211        //     previous frame (i.e. to the previous frame's compressed offset + length).
212        for (prev, curr) in entries.iter().tuple_windows() {
213            if (prev.compressed_offset.get() + prev.compressed_size.get())
214                > curr.compressed_offset.get()
215            {
216                return Err(ChunkedArchiveError::IntegrityError);
217            }
218        }
219
220        // I4: Each entry must have a non-zero decompressed and compressed length.
221        for entry in entries.iter() {
222            if entry.decompressed_size.get() == 0 || entry.compressed_size.get() == 0 {
223                return Err(ChunkedArchiveError::IntegrityError);
224            }
225        }
226
227        // I5: Data referenced by each entry must fit within the specified file size.
228        for entry in entries.iter() {
229            let compressed_end = entry.compressed_offset.get() + entry.compressed_size.get();
230            if compressed_end > archive_length {
231                return Err(ChunkedArchiveError::IntegrityError);
232            }
233        }
234
235        let seek_table = entries
236            .into_iter()
237            .map(|entry| ChunkInfo::from_entry(entry, header_length))
238            .try_collect()?;
239        Ok(Some((seek_table, chunk_data)))
240    }
241}
242
/// In-memory representation of a single compressed chunk of an archive.
pub struct CompressedChunk {
    /// Compressed data for this chunk.
    pub compressed_data: Vec<u8>,
    /// Size of this chunk when decompressed.
    pub decompressed_size: usize,
}
250
/// In-memory representation of a compressed chunked archive.
pub struct ChunkedArchive {
    /// Chunks this archive contains, in order. Right now we only allow creating archives with
    /// contiguous compressed and decompressed space.
    chunks: Vec<CompressedChunk>,
    /// Size used to chunk input when creating this archive. Last chunk may be smaller than this
    /// amount.
    chunk_size: usize,
}
259
260impl ChunkedArchive {
261    const MAX_CHUNKS: usize = ChunkedArchiveHeader::CHUNKED_ARCHIVE_MAX_FRAMES;
262    const TARGET_CHUNK_SIZE: usize = 32 * 1024;
263    const COMPRESSION_LEVEL: i32 = 14;
264
265    /// Create a ChunkedArchive for `data` compressing each chunk in parallel. This function uses
266    /// the `rayon` crate for parallelism. By default compression happens in the global thread pool,
267    /// but this function can also be executed within a locally scoped pool.
268    pub fn new(data: &[u8], chunk_alignment: usize) -> Result<Self, ChunkedArchiveError> {
269        let chunk_size = ChunkedArchive::chunk_size_for(data.len(), chunk_alignment);
270        let mut chunks: Vec<Result<CompressedChunk, ChunkedArchiveError>> = vec![];
271        data.par_chunks(chunk_size)
272            .enumerate()
273            .map(|(index, chunk)| {
274                // Creating and destroying zstd::bulk::Compressor objects is expensive. A single
275                // `Compressor` is created for each `rayon` thread and is reused across chunks.
276                thread_local! {
277                    static COMPRESSOR: std::cell::RefCell<zstd::bulk::Compressor<'static>> =
278                        std::cell::RefCell::new({
279                            let mut compressor =
280                                zstd::bulk::Compressor::new(ChunkedArchive::COMPRESSION_LEVEL)
281                                    .unwrap();
282                            compressor
283                                .set_parameter(zstd::zstd_safe::CParameter::ChecksumFlag(true))
284                                .unwrap();
285                            compressor
286                        });
287                }
288                let compressed_data = COMPRESSOR.with(|compressor| {
289                    let mut compressor = compressor.borrow_mut();
290                    compressor
291                        .compress(chunk)
292                        .map_err(|error| ChunkedArchiveError::CompressionError { index, error })
293                })?;
294                Ok(CompressedChunk { compressed_data, decompressed_size: chunk.len() })
295            })
296            .collect_into_vec(&mut chunks);
297        let chunks: Vec<_> = chunks.into_iter().try_collect()?;
298        Ok(ChunkedArchive { chunks, chunk_size })
299    }
300
301    /// Accessor for compressed chunk data.
302    pub fn chunks(&self) -> &Vec<CompressedChunk> {
303        &self.chunks
304    }
305
306    /// The chunk size calculated for this archive during compression. Represents how input data
307    /// was chunked for compression. Note that the final chunk may be smaller than this amount
308    /// when decompressed.
309    pub fn chunk_size(&self) -> usize {
310        self.chunk_size
311    }
312
313    /// Sum of sizes of all compressed chunks.
314    pub fn compressed_data_size(&self) -> usize {
315        self.chunks.iter().map(|chunk| chunk.compressed_data.len()).sum()
316    }
317
318    /// Total size of the archive in bytes.
319    pub fn serialized_size(&self) -> usize {
320        ChunkedArchiveHeader::header_length(self.chunks.len()) + self.compressed_data_size()
321    }
322
323    /// Write the archive to `writer`.
324    pub fn write(self, mut writer: impl std::io::Write) -> Result<(), std::io::Error> {
325        let seek_table = self.make_seek_table();
326        let header = ChunkedArchiveHeader::new(&seek_table).unwrap();
327        writer.write_all(header.as_bytes())?;
328        writer.write_all(seek_table.as_slice().as_bytes())?;
329        for chunk in self.chunks {
330            writer.write_all(&chunk.compressed_data)?;
331        }
332        Ok(())
333    }
334
335    /// Calculate how large chunks must be for a given uncompressed buffer.
336    fn chunk_size_for(uncompressed_length: usize, chunk_alignment: usize) -> usize {
337        if uncompressed_length <= (Self::MAX_CHUNKS * Self::TARGET_CHUNK_SIZE) {
338            return Self::TARGET_CHUNK_SIZE;
339        }
340        // TODO(https://github.com/rust-lang/rust/issues/88581): Replace with
341        // `{integer}::div_ceil()` when `int_roundings` is available.
342        let chunk_size =
343            round_up(uncompressed_length, ChunkedArchive::MAX_CHUNKS) / ChunkedArchive::MAX_CHUNKS;
344        return round_up(chunk_size, chunk_alignment);
345    }
346
347    /// Create the seek table for this archive.
348    fn make_seek_table(&self) -> Vec<SeekTableEntry> {
349        let header_length = ChunkedArchiveHeader::header_length(self.chunks.len());
350        let mut seek_table = vec![];
351        seek_table.reserve(self.chunks.len());
352        let mut compressed_size: usize = 0;
353        let mut decompressed_offset: usize = 0;
354        for chunk in &self.chunks {
355            seek_table.push(SeekTableEntry {
356                decompressed_offset: (decompressed_offset as u64).into(),
357                decompressed_size: (chunk.decompressed_size as u64).into(),
358                compressed_offset: ((header_length + compressed_size) as u64).into(),
359                compressed_size: (chunk.compressed_data.len() as u64).into(),
360            });
361            compressed_size += chunk.compressed_data.len();
362            decompressed_offset += chunk.decompressed_size;
363        }
364        seek_table
365    }
366}
367
/// Streaming decompressor for chunked archives. Example:
/// ```
/// // Create a chunked archive:
/// let data: Vec<u8> = vec![3; 1024];
/// let mut compressed: Vec<u8> = vec![];
/// ChunkedArchive::new(&data, /*chunk_alignment*/ 8192).unwrap().write(&mut compressed).unwrap();
/// // Verify the header + decode the seek table:
/// let (seek_table, archive_data) = decode_archive(&compressed, compressed.len())?.unwrap();
/// let mut decompressed: Vec<u8> = vec![];
/// let mut on_chunk = |data: &[u8]| { decompressed.extend_from_slice(data); };
/// let mut decompressor = ChunkedDecompressor::new(seek_table)?;
/// // `on_chunk` is invoked as each slice is made available. Archive can be provided as chunks.
/// decompressor.update(archive_data, &mut on_chunk)?;
/// assert_eq!(data.as_slice(), decompressed.as_slice());
/// ```
pub struct ChunkedDecompressor {
    /// Validated seek table describing each chunk's compressed/decompressed extents.
    seek_table: Vec<ChunkInfo>,
    /// Buffered bytes of a partially received chunk, carried between `update` calls.
    buffer: Vec<u8>,
    /// Total number of compressed bytes received so far across all `update` calls.
    data_written: usize,
    /// Index into `seek_table` of the next chunk to decompress.
    curr_chunk: usize,
    /// End offset of the final chunk (i.e. total compressed payload size).
    total_compressed_size: usize,
    decompressor: zstd::bulk::Decompressor<'static>,
    /// Scratch buffer one decompressed chunk is written into. Sized from the first chunk's
    /// decompressed length — assumes no later chunk decompresses larger than the first
    /// (TODO confirm; holds for archives produced by `ChunkedArchive`).
    decompressed_buffer: Vec<u8>,
    /// Optional callback invoked when a chunk fails to decompress.
    error_handler: Option<ErrorHandler>,
}
392
/// Callback invoked when a chunk fails to decompress, receiving the chunk's index, its seek table
/// entry, and the compressed bytes that failed.
type ErrorHandler = Box<dyn Fn(usize, ChunkInfo, &[u8]) -> () + Send + 'static>;
394
395impl ChunkedDecompressor {
396    /// Create a new decompressor to decode an archive from a validated seek table.
397    pub fn new(seek_table: Vec<ChunkInfo>) -> Result<Self, ChunkedArchiveError> {
398        let total_compressed_size =
399            seek_table.last().map(|last_chunk| last_chunk.compressed_range.end).unwrap_or(0);
400        let decompressed_buffer =
401            vec![0u8; seek_table.first().map(|c| c.decompressed_range.len()).unwrap_or(0)];
402        let decompressor =
403            zstd::bulk::Decompressor::new().map_err(ChunkedArchiveError::ZstdError)?;
404        Ok(Self {
405            seek_table,
406            buffer: vec![],
407            data_written: 0,
408            curr_chunk: 0,
409            total_compressed_size,
410            decompressor,
411            decompressed_buffer,
412            error_handler: None,
413        })
414    }
415
416    /// Creates a new decompressor with an additional error handler invoked when a chunk fails to be
417    /// decompressed.
418    pub fn new_with_error_handler(
419        seek_table: Vec<ChunkInfo>,
420        error_handler: ErrorHandler,
421    ) -> Result<Self, ChunkedArchiveError> {
422        Ok(Self { error_handler: Some(error_handler), ..Self::new(seek_table)? })
423    }
424
425    pub fn seek_table(&self) -> &Vec<ChunkInfo> {
426        &self.seek_table
427    }
428
429    fn finish_chunk(
430        &mut self,
431        data: &[u8],
432        chunk_callback: &mut impl FnMut(&[u8]) -> (),
433    ) -> Result<(), ChunkedArchiveError> {
434        debug_assert_eq!(data.len(), self.seek_table[self.curr_chunk].compressed_range.len());
435        let chunk = &self.seek_table[self.curr_chunk];
436        let decompressed_size = self
437            .decompressor
438            .decompress_to_buffer(data, self.decompressed_buffer.as_mut_slice())
439            .map_err(|error| {
440                if let Some(ref error_handler) = self.error_handler {
441                    error_handler(self.curr_chunk, chunk.clone(), data.as_bytes());
442                }
443                ChunkedArchiveError::DecompressionError { index: self.curr_chunk, error }
444            })?;
445        if decompressed_size != chunk.decompressed_range.len() {
446            return Err(ChunkedArchiveError::IntegrityError);
447        }
448        chunk_callback(&self.decompressed_buffer[..decompressed_size]);
449        self.curr_chunk += 1;
450        Ok(())
451    }
452
453    /// Update the decompressor with more data.
454    pub fn update(
455        &mut self,
456        mut data: &[u8],
457        chunk_callback: &mut impl FnMut(&[u8]) -> (),
458    ) -> Result<(), ChunkedArchiveError> {
459        // Caller must not provide too much data.
460        if self.data_written + data.len() > self.total_compressed_size {
461            return Err(ChunkedArchiveError::OutOfRange);
462        }
463        self.data_written += data.len();
464
465        // If we had leftover data from a previous read, append until we've filled a chunk.
466        if !self.buffer.is_empty() {
467            let to_read = std::cmp::min(
468                data.len(),
469                self.seek_table[self.curr_chunk]
470                    .compressed_range
471                    .len()
472                    .checked_sub(self.buffer.len())
473                    .unwrap(),
474            );
475            self.buffer.extend_from_slice(&data[..to_read]);
476            if self.buffer.len() == self.seek_table[self.curr_chunk].compressed_range.len() {
477                // Take self.buffer temporarily (so we don't have to split borrows).
478                // That way we don't have to re-commit the pages we've already used in the buffer
479                // for next time.
480                let full_chunk = std::mem::take(&mut self.buffer);
481                self.finish_chunk(&full_chunk[..], chunk_callback)?;
482                self.buffer = full_chunk;
483                // Draining the buffer will set the length to 0 but keep the capacity the same.
484                self.buffer.drain(..);
485            }
486            data = &data[to_read..];
487        }
488
489        // Decode as many full chunks as we can.
490        while !data.is_empty()
491            && self.curr_chunk < self.seek_table.len()
492            && self.seek_table[self.curr_chunk].compressed_range.len() <= data.len()
493        {
494            let len = self.seek_table[self.curr_chunk].compressed_range.len();
495            self.finish_chunk(&data[..len], chunk_callback)?;
496            data = &data[len..];
497        }
498
499        // Buffer the rest for the next call.
500        if !data.is_empty() {
501            debug_assert!(self.curr_chunk < self.seek_table.len());
502            debug_assert!(self.data_written < self.total_compressed_size);
503            self.buffer.extend_from_slice(data);
504        }
505
506        debug_assert!(
507            self.data_written < self.total_compressed_size
508                || self.curr_chunk == self.seek_table.len()
509        );
510
511        Ok(())
512    }
513}
514
/// Round `value` up to the next multiple of `multiple`. Panics if `multiple` is zero or the
/// rounded result overflows `usize`.
/// TODO(https://github.com/rust-lang/rust/issues/88581): Replace with
/// `{integer}::checked_next_multiple_of()` when `int_roundings` is available.
fn round_up(value: usize, multiple: usize) -> usize {
    match value % multiple {
        0 => value,
        remainder => value.checked_add(multiple - remainder).unwrap(),
    }
}
525
526#[cfg(test)]
527mod tests {
528
529    use super::*;
530    use rand::Rng;
531    use std::matches;
532
    /// Chunk alignment used throughout these tests.
    const BLOCK_SIZE: usize = 8192;

    /// Create a compressed archive and ensure we can decode it as a valid archive that passes all
    /// required integrity checks.
    #[test]
    fn compress_simple() {
        let data: Vec<u8> = vec![0; 32 * 1024 * 16];
        let archive = ChunkedArchive::new(&data, BLOCK_SIZE).unwrap();
        // This data is highly compressible, so the result should be smaller than the original.
        let mut compressed: Vec<u8> = vec![];
        archive.write(&mut compressed).unwrap();
        assert!(compressed.len() <= data.len());
        // We should be able to decode and verify the archive's integrity in-place.
        assert!(decode_archive(&compressed, compressed.len()).unwrap().is_some());
    }
548
    /// Generate a header + seek table for verifying invariants/integrity checks. Produces
    /// `num_entries` contiguous frames of fixed compressed/decompressed sizes, plus the total
    /// archive length implied by those frames.
    fn generate_archive(
        num_entries: usize,
    ) -> (ChunkedArchiveHeader, Vec<SeekTableEntry>, /*archive_length*/ u64) {
        let mut seek_table = vec![];
        seek_table.reserve(num_entries);
        let header_length = ChunkedArchiveHeader::header_length(num_entries) as u64;
        const COMPRESSED_CHUNK_SIZE: u64 = 1024;
        const DECOMPRESSED_CHUNK_SIZE: u64 = 2048;
        for n in 0..(num_entries as u64) {
            seek_table.push(SeekTableEntry {
                compressed_offset: (header_length + (n * COMPRESSED_CHUNK_SIZE)).into(),
                compressed_size: COMPRESSED_CHUNK_SIZE.into(),
                decompressed_offset: (n * DECOMPRESSED_CHUNK_SIZE).into(),
                decompressed_size: DECOMPRESSED_CHUNK_SIZE.into(),
            });
        }
        let header = ChunkedArchiveHeader::new(&seek_table).unwrap();
        let archive_length: u64 = header_length + (num_entries as u64 * COMPRESSED_CHUNK_SIZE);
        (header, seek_table, archive_length)
    }
570
    /// A freshly generated archive must pass its own validation.
    #[test]
    fn should_validate_self() {
        let (header, seek_table, archive_length) = generate_archive(4);
        let serialized_table = seek_table.as_slice().as_bytes();
        assert!(header.decode_seek_table(serialized_table, archive_length).unwrap().is_some());
    }

    /// An archive with zero entries is valid.
    #[test]
    fn should_validate_empty() {
        let (header, _, archive_length) = generate_archive(0);
        assert!(header.decode_seek_table(&[], archive_length).unwrap().is_some());
    }
583
    /// Flipping a magic byte must yield `BadMagic`.
    #[test]
    fn should_detect_bad_magic() {
        let (header, seek_table, archive_length) = generate_archive(4);
        let mut corrupt_magic = ChunkedArchiveHeader::CHUNKED_ARCHIVE_MAGIC;
        corrupt_magic[0] = !corrupt_magic[0];
        let bad_magic = ChunkedArchiveHeader { magic: corrupt_magic, ..header };
        let serialized_table = seek_table.as_slice().as_bytes();
        assert!(matches!(
            bad_magic.decode_seek_table(serialized_table, archive_length).unwrap_err(),
            ChunkedArchiveError::BadMagic
        ));
    }
    /// An unsupported version number must yield `InvalidVersion`.
    #[test]
    fn should_detect_wrong_version() {
        let (header, seek_table, archive_length) = generate_archive(4);
        let wrong_version = ChunkedArchiveHeader {
            version: (ChunkedArchiveHeader::CHUNKED_ARCHIVE_VERSION + 1).into(),
            ..header
        };
        let serialized_table = seek_table.as_slice().as_bytes();
        assert!(matches!(
            wrong_version.decode_seek_table(serialized_table, archive_length).unwrap_err(),
            ChunkedArchiveError::InvalidVersion
        ));
    }

    /// A corrupted checksum field must yield `IntegrityError`.
    #[test]
    fn should_detect_corrupt_checksum() {
        let (header, seek_table, archive_length) = generate_archive(4);
        let corrupt_checksum =
            ChunkedArchiveHeader { checksum: (!header.checksum.get()).into(), ..header };
        let serialized_table = seek_table.as_slice().as_bytes();
        assert!(matches!(
            corrupt_checksum.decode_seek_table(serialized_table, archive_length).unwrap_err(),
            ChunkedArchiveError::IntegrityError
        ));
    }

    /// More than `CHUNKED_ARCHIVE_MAX_FRAMES` entries must be rejected.
    #[test]
    fn should_reject_too_many_entries() {
        let (too_many_entries, seek_table, archive_length) =
            generate_archive(ChunkedArchiveHeader::CHUNKED_ARCHIVE_MAX_FRAMES + 1);

        let serialized_table = seek_table.as_slice().as_bytes();
        assert!(matches!(
            too_many_entries.decode_seek_table(serialized_table, archive_length).unwrap_err(),
            ChunkedArchiveError::IntegrityError
        ));
    }
633
    /// I0: a nonzero decompressed offset in the first entry is rejected.
    #[test]
    fn invariant_i0_first_entry_zero() {
        let (header, mut seek_table, archive_length) = generate_archive(4);
        assert_eq!(seek_table[0].decompressed_offset.get(), 0);
        seek_table[0].decompressed_offset = 1.into();

        let serialized_table = seek_table.as_slice().as_bytes();
        assert!(matches!(
            header.decode_seek_table(serialized_table, archive_length).unwrap_err(),
            ChunkedArchiveError::IntegrityError
        ));
    }

    /// I1: a compressed offset overlapping the header is rejected.
    #[test]
    fn invariant_i1_no_header_overlap() {
        let (header, mut seek_table, archive_length) = generate_archive(4);
        let header_end = ChunkedArchiveHeader::header_length(seek_table.len()) as u64;
        assert!(seek_table[0].compressed_offset.get() >= header_end);
        seek_table[0].compressed_offset = (header_end - 1).into();
        let serialized_table = seek_table.as_slice().as_bytes();
        assert!(matches!(
            header.decode_seek_table(serialized_table, archive_length).unwrap_err(),
            ChunkedArchiveError::IntegrityError
        ));
    }

    /// I2: a gap or overlap in the decompressed space is rejected.
    #[test]
    fn invariant_i2_decompressed_monotonic() {
        let (header, mut seek_table, archive_length) = generate_archive(4);
        assert_eq!(
            seek_table[0].decompressed_offset.get() + seek_table[0].decompressed_size.get(),
            seek_table[1].decompressed_offset.get()
        );
        seek_table[1].decompressed_offset = (seek_table[1].decompressed_offset.get() - 1).into();
        let serialized_table = seek_table.as_slice().as_bytes();
        assert!(matches!(
            header.decode_seek_table(serialized_table, archive_length).unwrap_err(),
            ChunkedArchiveError::IntegrityError
        ));
    }

    /// I3: overlapping compressed frames are rejected.
    #[test]
    fn invariant_i3_compressed_monotonic() {
        let (header, mut seek_table, archive_length) = generate_archive(4);
        assert!(
            (seek_table[0].compressed_offset.get() + seek_table[0].compressed_size.get())
                <= seek_table[1].compressed_offset.get()
        );
        seek_table[1].compressed_offset = (seek_table[1].compressed_offset.get() - 1).into();
        let serialized_table = seek_table.as_slice().as_bytes();
        assert!(matches!(
            header.decode_seek_table(serialized_table, archive_length).unwrap_err(),
            ChunkedArchiveError::IntegrityError
        ));
    }

    /// I4: a zero-length compressed frame is rejected.
    #[test]
    fn invariant_i4_nonzero_compressed_size() {
        let (header, mut seek_table, archive_length) = generate_archive(4);
        assert!(seek_table[0].compressed_size.get() > 0);
        seek_table[0].compressed_size = 0.into();
        let serialized_table = seek_table.as_slice().as_bytes();
        assert!(matches!(
            header.decode_seek_table(serialized_table, archive_length).unwrap_err(),
            ChunkedArchiveError::IntegrityError
        ));
    }

    /// I4: a zero-length decompressed frame is rejected.
    #[test]
    fn invariant_i4_nonzero_decompressed_size() {
        let (header, mut seek_table, archive_length) = generate_archive(4);
        assert!(seek_table[0].decompressed_size.get() > 0);
        seek_table[0].decompressed_size = 0.into();
        let serialized_table = seek_table.as_slice().as_bytes();
        assert!(matches!(
            header.decode_seek_table(serialized_table, archive_length).unwrap_err(),
            ChunkedArchiveError::IntegrityError
        ));
    }

    /// I5: a frame extending past the stated archive length is rejected.
    #[test]
    fn invariant_i5_within_archive() {
        let (header, mut seek_table, archive_length) = generate_archive(4);
        let last_entry = seek_table.last_mut().unwrap();
        assert!(
            (last_entry.compressed_offset.get() + last_entry.compressed_size.get())
                <= archive_length
        );
        last_entry.compressed_offset = (archive_length + 1).into();
        let serialized_table = seek_table.as_slice().as_bytes();
        assert!(matches!(
            header.decode_seek_table(serialized_table, archive_length).unwrap_err(),
            ChunkedArchiveError::IntegrityError
        ));
    }
729
730    #[test]
731    fn max_chunks() {
732        assert_eq!(
733            ChunkedArchive::chunk_size_for(
734                ChunkedArchive::MAX_CHUNKS * ChunkedArchive::TARGET_CHUNK_SIZE,
735                BLOCK_SIZE,
736            ),
737            ChunkedArchive::TARGET_CHUNK_SIZE
738        );
739        assert_eq!(
740            ChunkedArchive::chunk_size_for(
741                ChunkedArchive::MAX_CHUNKS * ChunkedArchive::TARGET_CHUNK_SIZE + 1,
742                BLOCK_SIZE,
743            ),
744            ChunkedArchive::TARGET_CHUNK_SIZE + BLOCK_SIZE
745        );
746    }
747
748    #[test]
749    fn test_decompressor_empty_archive() {
750        let mut compressed: Vec<u8> = vec![];
751        ChunkedArchive::new(&[], BLOCK_SIZE)
752            .expect("compress")
753            .write(&mut compressed)
754            .expect("write archive");
755        let (seek_table, chunk_data) =
756            decode_archive(&compressed, compressed.len()).unwrap().unwrap();
757        assert!(seek_table.is_empty());
758        let mut decompressor = ChunkedDecompressor::new(seek_table).unwrap();
759        let mut chunk_callback = |_chunk: &[u8]| panic!("Archive doesn't have any chunks.");
760        // Stream data into the decompressor in small chunks to exhaust more edge cases.
761        chunk_data
762            .chunks(4)
763            .for_each(|data| decompressor.update(data, &mut chunk_callback).unwrap());
764    }
765
766    #[test]
767    fn test_decompressor() {
768        const UNCOMPRESSED_LENGTH: usize = 3_000_000;
769        let data: Vec<u8> = {
770            let range = rand::distributions::Uniform::<u8>::new_inclusive(0, 255);
771            rand::thread_rng().sample_iter(&range).take(UNCOMPRESSED_LENGTH).collect()
772        };
773        let mut compressed: Vec<u8> = vec![];
774        ChunkedArchive::new(&data, BLOCK_SIZE)
775            .expect("compress")
776            .write(&mut compressed)
777            .expect("write archive");
778        let (seek_table, chunk_data) =
779            decode_archive(&compressed, compressed.len()).unwrap().unwrap();
780
781        // Make sure we have multiple chunks for this test.
782        let num_chunks = seek_table.len();
783        assert!(num_chunks > 1);
784
785        let mut decompressor = ChunkedDecompressor::new(seek_table).unwrap();
786
787        let mut decoded_chunks: usize = 0;
788        let mut decompressed_offset: usize = 0;
789        let mut chunk_callback = |decompressed_chunk: &[u8]| {
790            assert!(
791                decompressed_chunk
792                    == &data[decompressed_offset..decompressed_offset + decompressed_chunk.len()]
793            );
794            decompressed_offset += decompressed_chunk.len();
795            decoded_chunks += 1;
796        };
797
798        // Stream data into the decompressor in small chunks to exhaust more edge cases.
799        chunk_data
800            .chunks(4)
801            .for_each(|data| decompressor.update(data, &mut chunk_callback).unwrap());
802        assert_eq!(decoded_chunks, num_chunks);
803    }
804
805    #[test]
806    fn test_decompressor_corrupt_decompressed_size() {
807        let data = vec![0; 3_000_000];
808        let mut compressed: Vec<u8> = vec![];
809        ChunkedArchive::new(&data, BLOCK_SIZE)
810            .expect("compress")
811            .write(&mut compressed)
812            .expect("write archive");
813        let (mut seek_table, chunk_data) =
814            decode_archive(&compressed, compressed.len()).unwrap().unwrap();
815
816        // Corrupt the decompressed size of the chunk.
817        seek_table[0].decompressed_range =
818            seek_table[0].decompressed_range.start..seek_table[0].decompressed_range.end + 1;
819
820        let mut decompressor = ChunkedDecompressor::new(seek_table).unwrap();
821        assert!(matches!(
822            decompressor.update(&chunk_data, &mut |_chunk| {}),
823            Err(ChunkedArchiveError::IntegrityError)
824        ));
825    }
826
827    #[test]
828    fn test_decompressor_corrupt_compressed_size() {
829        let data = vec![0; 3_000_000];
830        let mut compressed: Vec<u8> = vec![];
831        ChunkedArchive::new(&data, BLOCK_SIZE)
832            .expect("compress")
833            .write(&mut compressed)
834            .expect("write archive");
835        let (mut seek_table, chunk_data) =
836            decode_archive(&compressed, compressed.len()).unwrap().unwrap();
837
838        // Corrupt the compressed size of the chunk.
839        seek_table[0].compressed_range =
840            seek_table[0].compressed_range.start..seek_table[0].compressed_range.end - 1;
841        let first_chunk_info = seek_table[0].clone();
842        let error_handler = move |chunk_index: usize, chunk_info: ChunkInfo, chunk_data: &[u8]| {
843            assert_eq!(chunk_index, 0);
844            assert_eq!(chunk_info, first_chunk_info);
845            assert_eq!(chunk_data.len(), chunk_info.compressed_range.len());
846        };
847
848        let mut decompressor =
849            ChunkedDecompressor::new_with_error_handler(seek_table, Box::new(error_handler))
850                .unwrap();
851        assert!(matches!(
852            decompressor.update(&chunk_data, &mut |_chunk| {}),
853            Err(ChunkedArchiveError::DecompressionError { index: 0, .. })
854        ));
855    }
856}