1use std::cmp;
2use std::io;
3use std::io::prelude::*;
4
5#[cfg(feature = "tokio")]
6use futures::Poll;
7#[cfg(feature = "tokio")]
8use tokio_io::{AsyncRead, AsyncWrite};
9
10use super::bufread::{corrupt, read_gz_header};
11use super::{GzBuilder, GzHeader};
12use crate::crc::{Crc, CrcWriter};
13use crate::zio;
14use crate::{Compress, Compression, Decompress, Status};
15
/// A gzip streaming encoder.
///
/// Data written to this type is compressed and written to the underlying
/// writer, preceded by a serialized gzip header and followed by an 8-byte
/// trailer emitted by `try_finish`/`finish` (or on drop).
#[derive(Debug)]
pub struct GzEncoder<W: Write> {
    // Compressor that writes its output into the underlying writer `W`.
    inner: zio::Writer<W, Compress>,
    // Running checksum and byte count of the uncompressed input; both are
    // written into the trailer.
    crc: Crc,
    // How many of the 8 trailer bytes have been written so far. Nonzero means
    // finishing has started, after which `write`/`flush` panic.
    crc_bytes_written: usize,
    // Serialized header bytes not yet written; drained as they are written out.
    header: Vec<u8>,
}
45
46pub fn gz_encoder<W: Write>(header: Vec<u8>, w: W, lvl: Compression) -> GzEncoder<W> {
47 GzEncoder {
48 inner: zio::Writer::new(w, Compress::new(lvl, false)),
49 crc: Crc::new(),
50 header: header,
51 crc_bytes_written: 0,
52 }
53}
54
55impl<W: Write> GzEncoder<W> {
56 pub fn new(w: W, level: Compression) -> GzEncoder<W> {
64 GzBuilder::new().write(w, level)
65 }
66
67 pub fn get_ref(&self) -> &W {
69 self.inner.get_ref()
70 }
71
72 pub fn get_mut(&mut self) -> &mut W {
77 self.inner.get_mut()
78 }
79
80 pub fn try_finish(&mut self) -> io::Result<()> {
96 self.write_header()?;
97 self.inner.finish()?;
98
99 while self.crc_bytes_written < 8 {
100 let (sum, amt) = (self.crc.sum() as u32, self.crc.amount());
101 let buf = [
102 (sum >> 0) as u8,
103 (sum >> 8) as u8,
104 (sum >> 16) as u8,
105 (sum >> 24) as u8,
106 (amt >> 0) as u8,
107 (amt >> 8) as u8,
108 (amt >> 16) as u8,
109 (amt >> 24) as u8,
110 ];
111 let inner = self.inner.get_mut();
112 let n = inner.write(&buf[self.crc_bytes_written..])?;
113 self.crc_bytes_written += n;
114 }
115 Ok(())
116 }
117
118 pub fn finish(mut self) -> io::Result<W> {
132 self.try_finish()?;
133 Ok(self.inner.take_inner())
134 }
135
136 fn write_header(&mut self) -> io::Result<()> {
137 while self.header.len() > 0 {
138 let n = self.inner.get_mut().write(&self.header)?;
139 self.header.drain(..n);
140 }
141 Ok(())
142 }
143}
144
145impl<W: Write> Write for GzEncoder<W> {
146 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
147 assert_eq!(self.crc_bytes_written, 0);
148 self.write_header()?;
149 let n = self.inner.write(buf)?;
150 self.crc.update(&buf[..n]);
151 Ok(n)
152 }
153
154 fn flush(&mut self) -> io::Result<()> {
155 assert_eq!(self.crc_bytes_written, 0);
156 self.write_header()?;
157 self.inner.flush()
158 }
159}
160
#[cfg(feature = "tokio")]
impl<W: AsyncWrite> AsyncWrite for GzEncoder<W> {
    /// Completes the gzip stream (pending header, compressor output, trailer)
    /// and then shuts down the underlying writer.
    fn shutdown(&mut self) -> Poll<(), io::Error> {
        self.try_finish()
            .and_then(|()| self.get_mut().shutdown())
    }
}
168
169impl<R: Read + Write> Read for GzEncoder<R> {
170 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
171 self.get_mut().read(buf)
172 }
173}
174
#[cfg(feature = "tokio")]
impl<R> AsyncRead for GzEncoder<R> where R: AsyncRead + AsyncWrite {}
177
178impl<W: Write> Drop for GzEncoder<W> {
179 fn drop(&mut self) {
180 if self.inner.is_present() {
181 let _ = self.try_finish();
182 }
183 }
184}
185
/// A gzip streaming decoder that decodes data written to it.
///
/// Gzip-encoded bytes written to this type are parsed (header, deflate body,
/// 8-byte trailer), and the decompressed output is written to the underlying
/// writer. The trailer is verified by `try_finish`/`finish`.
#[derive(Debug)]
pub struct GzDecoder<W: Write> {
    // Decompressor writing into a `CrcWriter`, which tracks the checksum and
    // byte count of the decompressed output for the trailer check.
    inner: zio::Writer<CrcWriter<W>, Decompress>,
    // Trailer bytes (4-byte CRC + 4-byte length) collected after the deflate
    // stream ends.
    crc_bytes: Vec<u8>,
    // Parsed header; `None` until a complete header has been written.
    header: Option<GzHeader>,
    // Bytes of a still-incomplete header, buffered across `write` calls.
    header_buf: Vec<u8>,
}
225
/// Length of the gzip trailer: a 4-byte CRC followed by a 4-byte length field.
const CRC_BYTES_LEN: usize = 4 + 4;
227
228impl<W: Write> GzDecoder<W> {
229 pub fn new(w: W) -> GzDecoder<W> {
234 GzDecoder {
235 inner: zio::Writer::new(CrcWriter::new(w), Decompress::new(false)),
236 crc_bytes: Vec::with_capacity(CRC_BYTES_LEN),
237 header: None,
238 header_buf: Vec::new(),
239 }
240 }
241
242 pub fn header(&self) -> Option<&GzHeader> {
244 self.header.as_ref()
245 }
246
247 pub fn get_ref(&self) -> &W {
249 self.inner.get_ref().get_ref()
250 }
251
252 pub fn get_mut(&mut self) -> &mut W {
257 self.inner.get_mut().get_mut()
258 }
259
260 pub fn try_finish(&mut self) -> io::Result<()> {
276 self.finish_and_check_crc()?;
277 Ok(())
278 }
279
280 pub fn finish(mut self) -> io::Result<W> {
296 self.finish_and_check_crc()?;
297 Ok(self.inner.take_inner().into_inner())
298 }
299
300 fn finish_and_check_crc(&mut self) -> io::Result<()> {
301 self.inner.finish()?;
302
303 if self.crc_bytes.len() != 8 {
304 return Err(corrupt());
305 }
306
307 let crc = ((self.crc_bytes[0] as u32) << 0)
308 | ((self.crc_bytes[1] as u32) << 8)
309 | ((self.crc_bytes[2] as u32) << 16)
310 | ((self.crc_bytes[3] as u32) << 24);
311 let amt = ((self.crc_bytes[4] as u32) << 0)
312 | ((self.crc_bytes[5] as u32) << 8)
313 | ((self.crc_bytes[6] as u32) << 16)
314 | ((self.crc_bytes[7] as u32) << 24);
315 if crc != self.inner.get_ref().crc().sum() as u32 {
316 return Err(corrupt());
317 }
318 if amt != self.inner.get_ref().crc().amount() {
319 return Err(corrupt());
320 }
321 Ok(())
322 }
323}
324
/// A reader adapter that counts how many bytes have been read through it.
///
/// The `Read` requirement lives on the `Read` impl, not on the struct, so the
/// type itself places no bound on `T`.
struct Counter<T> {
    // The wrapped reader.
    inner: T,
    // Total number of bytes successfully read through this adapter so far.
    pos: usize,
}
329
330impl<T: Read> Read for Counter<T> {
331 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
332 let pos = self.inner.read(buf)?;
333 self.pos += pos;
334 Ok(pos)
335 }
336}
337
338impl<W: Write> Write for GzDecoder<W> {
339 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
340 if self.header.is_none() {
341 let (res, pos) = {
343 let mut counter = Counter {
344 inner: self.header_buf.chain(buf),
345 pos: 0,
346 };
347 let res = read_gz_header(&mut counter);
348 (res, counter.pos)
349 };
350
351 match res {
352 Err(err) => {
353 if err.kind() == io::ErrorKind::UnexpectedEof {
354 self.header_buf.extend(buf);
356 Ok(buf.len())
357 } else {
358 Err(err)
359 }
360 }
361 Ok(header) => {
362 self.header = Some(header);
363 let pos = pos - self.header_buf.len();
364 self.header_buf.truncate(0);
365 Ok(pos)
366 }
367 }
368 } else {
369 let (n, status) = self.inner.write_with_status(buf)?;
370
371 if status == Status::StreamEnd {
372 if n < buf.len() && self.crc_bytes.len() < 8 {
373 let remaining = buf.len() - n;
374 let crc_bytes = cmp::min(remaining, CRC_BYTES_LEN - self.crc_bytes.len());
375 self.crc_bytes.extend(&buf[n..n + crc_bytes]);
376 return Ok(n + crc_bytes);
377 }
378 }
379 Ok(n)
380 }
381 }
382
383 fn flush(&mut self) -> io::Result<()> {
384 self.inner.flush()
385 }
386}
387
#[cfg(feature = "tokio")]
impl<W: AsyncWrite> AsyncWrite for GzDecoder<W> {
    /// Finishes decoding and verifies the trailer, then shuts down the
    /// underlying writer.
    fn shutdown(&mut self) -> Poll<(), io::Error> {
        self.try_finish()
            .and_then(|()| self.get_mut().shutdown())
    }
}
395
396impl<W: Read + Write> Read for GzDecoder<W> {
397 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
398 self.inner.get_mut().get_mut().read(buf)
399 }
400}
401
#[cfg(feature = "tokio")]
impl<W> AsyncRead for GzDecoder<W> where W: AsyncRead + AsyncWrite {}
404
#[cfg(test)]
mod tests {
    use super::*;

    const STR: &str = "Hello World Hello World Hello World Hello World Hello World \
                       Hello World Hello World Hello World Hello World Hello World \
                       Hello World Hello World Hello World Hello World Hello World \
                       Hello World Hello World Hello World Hello World Hello World \
                       Hello World Hello World Hello World Hello World Hello World";

    /// Gzip-compresses `STR` at the default compression level.
    fn compressed_str() -> Vec<u8> {
        let mut e = GzEncoder::new(Vec::new(), Compression::default());
        // `write_all`, not `write`: a single `write` may accept only part of
        // the input.
        e.write_all(STR.as_bytes()).unwrap();
        e.finish().unwrap()
    }

    /// Finishes `decoder` (which verifies the trailer) and checks that the
    /// decompressed output is exactly `STR`.
    fn assert_decodes_to_str(decoder: GzDecoder<Vec<u8>>) {
        let writer = decoder.finish().unwrap();
        let return_string = String::from_utf8(writer).expect("String parsing error");
        assert_eq!(return_string, STR);
    }

    #[test]
    fn decode_writer_one_chunk() {
        let bytes = compressed_str();

        let mut decoder = GzDecoder::new(Vec::new());
        // The first write stops once the header has been parsed.
        let n = decoder.write(&bytes[..]).unwrap();
        decoder.write_all(&bytes[n..]).unwrap();
        decoder.try_finish().unwrap();
        assert_decodes_to_str(decoder);
    }

    #[test]
    fn decode_writer_partial_header() {
        let bytes = compressed_str();

        let mut decoder = GzDecoder::new(Vec::new());
        // Five bytes cannot hold a complete gzip header; they are buffered and
        // all reported as consumed.
        assert_eq!(decoder.write(&bytes[..5]).unwrap(), 5);
        decoder.write_all(&bytes[5..]).unwrap();
        assert_decodes_to_str(decoder);
    }

    #[test]
    fn decode_writer_exact_header() {
        let bytes = compressed_str();

        let mut decoder = GzDecoder::new(Vec::new());
        assert_eq!(decoder.write(&bytes[..10]).unwrap(), 10);
        decoder.write_all(&bytes[10..]).unwrap();
        assert_decodes_to_str(decoder);
    }

    #[test]
    fn decode_writer_partial_crc() {
        let bytes = compressed_str();
        // Hold back the last 5 bytes so the 8-byte trailer arrives in two
        // pieces: 3 bytes together with the compressed body, 5 on their own.
        let split = bytes.len() - 5;

        let mut decoder = GzDecoder::new(Vec::new());
        // This first write only consumes the header.
        let n = decoder.write(&bytes[..split]).unwrap();
        decoder.write_all(&bytes[n..split]).unwrap();
        decoder.write_all(&bytes[split..]).unwrap();
        assert_decodes_to_str(decoder);
    }
}