test_diagnostics/
zstd_compress.rs

1// Copyright 2024 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use futures::channel::mpsc;
6use futures::SinkExt;
7use std::cell::RefCell;
8use thiserror::Error;
9use zstd::stream::raw::Operation;
10
// Size, in bytes, of the per-thread scratch buffer that codec output is written into.
const BUFFER_SIZE: usize = 1024 * 1024 * 4; // 4 MB
// Buffer size handed to `mpsc::channel`. This counts queued `Vec<u8>` chunks, not bytes.
const CHANNEL_SIZE: usize = 10; // 10 chunks
13
// Per-thread scratch space of `BUFFER_SIZE` zeroed bytes. The encoder and decoder below hand
// it to zstd as the output buffer and copy the produced bytes into a fresh `Vec<u8>` before
// sending, so the borrow is never held across an `.await`.
thread_local! {
    static BUFFER: RefCell<Vec<u8>> = RefCell::new(vec![0; BUFFER_SIZE]);
}
17
/// Error encountered during compression or decompression.
///
/// Every variant wraps the underlying cause and exposes it as the error source.
#[derive(Debug, Error)]
pub enum Error {
    /// Error decompressing bytes.
    ///
    /// Fields, in order: the underlying I/O error, the offset into the input slice at which
    /// the failing codec call started (`pos`), and the total length of the input slice (`len`).
    #[error("Error Decompressing bytes:  pos: {1}, len: {2}, error: {0:?}")]
    Decompress(#[source] std::io::Error, usize, usize),

    /// Error compressing bytes.
    ///
    /// Fields, in order: the underlying I/O error, the offset into the input slice at which
    /// the failing codec call started (`pos`), and the total length of the input slice (`len`).
    #[error("Error compressing bytes:  pos: {1}, len: {2}, error: {0:?}")]
    Compress(#[source] std::io::Error, usize, usize),

    /// Error decompressing while flushing the decoder.
    #[error("Error Decompressing while flushing:  {0:?}")]
    DecompressFinish(#[source] std::io::Error),

    /// Error compressing while flushing the encoder.
    #[error("Error compressing while flushing:  {0:?}")]
    CompressFinish(#[source] std::io::Error),

    /// Error while sending data on the mpsc channel.
    #[error("Error while sending on mpsc channel:  {0:?}")]
    Send(#[source] mpsc::SendError),
}
41
/// A decoder that decompresses data using the Zstandard algorithm.
///
/// Decompressed output is delivered as `Vec<u8>` chunks over the channel whose receiving half
/// is returned by `Decoder::new`.
pub struct Decoder<'a> {
    /// Sending half of the channel that carries decompressed chunks.
    sender: mpsc::Sender<Vec<u8>>,
    /// Raw streaming zstd decoder holding the decompression state.
    decoder: zstd::stream::raw::Decoder<'a>,
}
47
48impl Decoder<'static> {
49    /// Creates a new `Decoder` and returns a receiver for the decompressed data.
50    ///
51    /// The `Decoder` will decompress data in chunks and send the decompressed chunks
52    /// over the returned `mpsc::Receiver`.
53    pub fn new() -> (Self, mpsc::Receiver<Vec<u8>>) {
54        let (sender, receiver) = mpsc::channel(CHANNEL_SIZE);
55        let decoder = Self { sender: sender, decoder: zstd::stream::raw::Decoder::new().unwrap() };
56        (decoder, receiver)
57    }
58
59    /// Decompresses the given bytes and sends the decompressed data over the channel.
60    ///
61    /// This method decompresses the input data in chunks, using a thread-local buffer
62    /// to store the decompressed data. The decompressed chunks are then sent over
63    /// the channel to the receiver.
64    pub async fn decompress(&mut self, bytes: &[u8]) -> Result<(), Error> {
65        let len = bytes.len();
66        let mut pos = 0;
67        while pos != len {
68            let decoded_bytes = BUFFER.with_borrow_mut(|buf| {
69                let status = self
70                    .decoder
71                    .run_on_buffers(&bytes[pos..], buf.as_mut_slice())
72                    .map_err(|e| Error::Decompress(e, pos, len))?;
73                pos += status.bytes_read;
74                Ok::<Vec<u8>, Error>(buf[..status.bytes_written].to_vec())
75            })?;
76            self.sender.send(decoded_bytes).await.map_err(Error::Send)?;
77        }
78        Ok(())
79    }
80
81    /// Flushes the decoder and sends any remaining decompressed data over the channel.
82    ///
83    /// This method should always be called after all input data has been decompressed to ensure
84    /// that all decompressed data is sent to the receiver.
85    pub async fn finish(mut self) -> Result<(), Error> {
86        loop {
87            let (remaining_bytes, decoded_bytes) = BUFFER.with_borrow_mut(|buf| {
88                let mut out_buffer = zstd::stream::raw::OutBuffer::around(buf.as_mut_slice());
89                let remaining_bytes =
90                    self.decoder.flush(&mut out_buffer).map_err(Error::DecompressFinish)?;
91                Ok::<(usize, Vec<u8>), Error>((remaining_bytes, out_buffer.as_slice().to_vec()))
92            })?;
93            if !decoded_bytes.is_empty() {
94                self.sender.send(decoded_bytes).await.map_err(Error::Send)?;
95            }
96            if remaining_bytes == 0 {
97                break;
98            }
99        }
100        Ok(())
101    }
102}
103
/// An encoder that compresses data using the Zstandard algorithm.
///
/// Compressed output is delivered as `Vec<u8>` chunks over the channel whose receiving half
/// is returned by `Encoder::new`.
pub struct Encoder<'a> {
    /// Sending half of the channel that carries compressed chunks.
    sender: mpsc::Sender<Vec<u8>>,
    /// Raw streaming zstd encoder holding the compression state.
    encoder: zstd::stream::raw::Encoder<'a>,
}
109
110impl Encoder<'static> {
111    /// Creates a new `Encoder` with the given compression level and returns a receiver
112    /// for the compressed data.
113    ///
114    /// The `Encoder` will compress data in chunks and send the compressed chunks
115    /// over the returned `mpsc::Receiver`.
116    pub fn new(level: i32) -> (Self, mpsc::Receiver<Vec<u8>>) {
117        let (sender, receiver) = mpsc::channel(CHANNEL_SIZE);
118        let decoder =
119            Self { sender: sender, encoder: zstd::stream::raw::Encoder::new(level).unwrap() };
120        (decoder, receiver)
121    }
122
123    /// Compresses the given bytes and sends the compressed data over the channel.
124    ///
125    /// This method compresses the input data in chunks, using a thread-local buffer
126    /// to store the compressed data. The compressed chunks are then sent over
127    /// the channel to the receiver.
128    pub async fn compress(&mut self, bytes: &[u8]) -> Result<(), Error> {
129        let len = bytes.len();
130        let mut pos = 0;
131        while pos != len {
132            let encoded_bytes = BUFFER.with_borrow_mut(|buf| {
133                let status = self
134                    .encoder
135                    .run_on_buffers(&bytes[pos..], buf.as_mut_slice())
136                    .map_err(|e| Error::Compress(e, pos, len))?;
137                pos += status.bytes_read;
138                Ok::<Vec<u8>, Error>(buf[..status.bytes_written].to_vec())
139            })?;
140            self.sender.send(encoded_bytes).await.map_err(Error::Send)?;
141        }
142        Ok(())
143    }
144
145    /// Flushes the encoder and sends any remaining compressed data over the channel.
146    ///
147    /// This method should be called after all input data has been compressed to ensure
148    /// that all compressed data is sent to the receiver.
149    pub async fn finish(mut self) -> Result<(), Error> {
150        loop {
151            let (remaining_bytes, encoded_bytes) = BUFFER.with_borrow_mut(|buf| {
152                let mut out_buffer = zstd::stream::raw::OutBuffer::around(buf.as_mut_slice());
153                let remaining_bytes =
154                    self.encoder.finish(&mut out_buffer, true).map_err(Error::CompressFinish)?;
155                Ok::<(usize, Vec<u8>), Error>((remaining_bytes, out_buffer.as_slice().to_vec()))
156            })?;
157            if !encoded_bytes.is_empty() {
158                self.sender.send(encoded_bytes).await.map_err(Error::Send)?;
159            }
160            if remaining_bytes == 0 {
161                break;
162            }
163        }
164        Ok(())
165    }
166}
167
#[cfg(test)]
mod tests {
    use super::*;
    use assert_matches::assert_matches;
    use futures::StreamExt;
    use rand::RngCore;
    use test_case::test_case;

    /// Number of input bytes handed to the encoder per `compress` call in the streaming tests.
    const INPUT_CHUNK_SIZE: usize = 2 * 1024 * 1024; // 2 MB

    /// Drains `receiver` until its channel closes and returns every chunk concatenated.
    async fn collect_all(mut receiver: mpsc::Receiver<Vec<u8>>) -> Vec<u8> {
        let mut collected = Vec::new();
        while let Some(chunk) = receiver.next().await {
            collected.extend_from_slice(&chunk);
        }
        collected
    }

    /// Pipes `original_data` through an encoder and a decoder that run concurrently, feeding
    /// the encoder `chunk_size` bytes at a time.
    ///
    /// Returns the total number of compressed bytes that passed between the two and the fully
    /// decompressed output.
    async fn round_trip_chunked(original_data: &[u8], chunk_size: usize) -> (usize, Vec<u8>) {
        let (mut encoder, mut rx) = Encoder::new(0);
        let (mut decoder, mut drx) = Decoder::new();

        // Producer: compress the input piece by piece, then close the stream.
        let compress_fut = async {
            for piece in original_data.chunks(chunk_size) {
                encoder.compress(piece).await.unwrap();
            }
            encoder.finish().await.unwrap();
        };

        // Middle stage: forward every compressed chunk straight into the decoder.
        let mut compressed_len = 0;
        let decompress_fut = async {
            while let Some(compressed_chunk) = rx.next().await {
                compressed_len += compressed_chunk.len();
                decoder.decompress(&compressed_chunk).await.unwrap();
            }
            decoder.finish().await.unwrap();
        };

        // Consumer: gather the decompressed chunks.
        let mut decompressed_data = Vec::new();
        let collect_final_data = async {
            while let Some(chunk) = drx.next().await {
                decompressed_data.extend_from_slice(&chunk);
            }
        };

        futures::join!(compress_fut, decompress_fut, collect_final_data);

        (compressed_len, decompressed_data)
    }

    #[test_case(Vec::from(b"This is a test string"); "normal test string")]
    #[test_case(Vec::from(b""); "empty string")]
    #[fuchsia::test]
    async fn test_compress_decompress(original_data: Vec<u8>) {
        let (mut encoder, rx) = Encoder::new(0);
        let (mut decoder, drx) = Decoder::new();

        // Compress the whole input, then gather what the encoder produced.
        encoder.compress(original_data.as_slice()).await.unwrap();
        encoder.finish().await.unwrap();
        let compressed_data = collect_all(rx).await;

        assert_ne!(compressed_data.len(), original_data.len());

        // Run the compressed bytes back through the decoder.
        decoder.decompress(&compressed_data).await.unwrap();
        decoder.finish().await.unwrap();
        let decompressed_data = collect_all(drx).await;

        // The round trip must reproduce the input exactly.
        assert_eq!(original_data.as_slice(), &decompressed_data[..]);
    }

    #[fuchsia::test]
    async fn test_compress_decompress_large_chunked() {
        let original_data = vec![b'a'; BUFFER_SIZE * 10 + 100];

        let (compressed_len, decompressed_data) =
            round_trip_chunked(&original_data, INPUT_CHUNK_SIZE).await;

        assert!(compressed_len < original_data.len());
        assert_eq!(original_data, decompressed_data);
    }

    #[fuchsia::test]
    async fn test_compress_decompress_random_chunked() {
        let mut original_data = vec![0u8; BUFFER_SIZE * 5 + 100];
        rand::thread_rng().fill_bytes(&mut original_data); // Fill with random data

        let (compressed_len, decompressed_data) =
            round_trip_chunked(&original_data, INPUT_CHUNK_SIZE).await;

        assert_ne!(compressed_len, original_data.len());
        assert_eq!(original_data, decompressed_data);
    }

    #[fuchsia::test]
    async fn test_invalid_input() {
        // The receiver is kept alive so a send failure cannot mask the decode failure.
        let (mut decoder, _drx) = Decoder::new();
        let invalid_data = vec![0xff; 1024];

        let result = decoder.decompress(&invalid_data).await;

        assert_matches!(result, Err(Error::Decompress(..)));
    }

    #[fuchsia::test]
    async fn test_send_error() {
        let data = b"some_text";

        // Encoder whose receiver is already gone.
        let (mut encoder, rx) = Encoder::new(0);
        drop(rx);
        let result = encoder.compress(data).await;
        assert_matches!(result, Err(Error::Send(..)));

        // Produce valid compressed bytes, then decode them into a closed channel.
        let (mut encoder, rx) = Encoder::new(0);
        let (mut decoder, drx) = Decoder::new();
        encoder.compress(data).await.unwrap();
        encoder.finish().await.unwrap();
        drop(drx);
        let compressed_data = collect_all(rx).await;

        let result = decoder.decompress(&compressed_data).await;
        assert_matches!(result, Err(Error::Send(..)));
    }
}