stream_processor_test/
output_validator.rs

1// Copyright 2019 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use crate::FatalError;
6use anyhow::Error;
7use async_trait::async_trait;
8use fidl_fuchsia_media::{FormatDetails, StreamOutputFormat};
9use fidl_table_validation::*;
10use fuchsia_stream_processors::*;
11use hex::{decode, encode};
12use log::info;
13use mundane::hash::{Digest, Hasher, Sha256};
14use num_traits::PrimInt;
15use std::fmt;
16use std::io::Write;
17use std::rc::Rc;
18
/// Counterpart of the FIDL `StreamOutputFormat` table, generated through `ValidFidlTable`.
///
/// Both fields are held as plain (non-`Option`) values, so code in this file can read
/// `format_details` directly.
#[derive(ValidFidlTable, Debug, PartialEq)]
#[fidl_table_src(StreamOutputFormat)]
pub struct ValidStreamOutputFormat {
    /// Ordinal of the stream this format belongs to.
    pub stream_lifetime_ordinal: u64,
    /// Format details; `FormatValidator` compares these and `BytesValidator` reads the
    /// `oob_bytes` from them.
    pub format_details: FormatDetails,
}
25
/// An output packet from the stream.
#[derive(Debug, PartialEq)]
pub struct OutputPacket {
    /// Payload bytes of the packet; this is what the validators count, hash and compare.
    pub data: Vec<u8>,
    /// Output format associated with this packet. Held in an `Rc`, so the same format value
    /// can be shared by several packets.
    pub format: Rc<ValidStreamOutputFormat>,
    /// Packet metadata (`ValidPacket`); not inspected by the validators in this file.
    pub packet: ValidPacket,
}
33
34/// Returns all the packets in the output with preserved order.
35pub fn output_packets(output: &[Output]) -> impl Iterator<Item = &OutputPacket> {
36    output.iter().filter_map(|output| match output {
37        Output::Packet(packet) => Some(packet),
38        _ => None,
39    })
40}
41
/// Output represents any output from a stream we might want to validate programmatically.
///
/// This may extend to contain not just explicit events but certain stream control behaviors or
/// even errors.
#[derive(Debug, PartialEq)]
pub enum Output {
    /// A data packet produced by the stream.
    Packet(OutputPacket),
    /// End-of-stream marker for the stream identified by `stream_lifetime_ordinal`
    /// (presumably recorded when the stream processor reports end of stream).
    Eos { stream_lifetime_ordinal: u64 },
    /// Presumably recorded when the codec channel is closed; confirm against the code that
    /// produces `Output` values.
    CodecChannelClose,
}
52
/// Checks all output packets, which are provided to the validator in the order in which they
/// were received from the stream processor.
///
/// Failure should be indicated by returning an error, not by panic, so that the full context of
/// the error will be available in the failure output.
#[async_trait(?Send)]
pub trait OutputValidator {
    /// Inspects `output` and returns `Ok(())` when it satisfies this validator, or an error
    /// describing the mismatch otherwise.
    async fn validate(&self, output: &[Output]) -> Result<(), Error>;
}
62
/// Validates that the output contains the expected number of packets.
pub struct OutputPacketCountValidator {
    /// Exact number of `Output::Packet` entries the output must contain.
    pub expected_output_packet_count: usize,
}
67
68#[async_trait(?Send)]
69impl OutputValidator for OutputPacketCountValidator {
70    async fn validate(&self, output: &[Output]) -> Result<(), Error> {
71        let actual_output_packet_count: usize = output
72            .iter()
73            .filter(|output| match output {
74                Output::Packet(_) => true,
75                _ => false,
76            })
77            .count();
78
79        if actual_output_packet_count != self.expected_output_packet_count {
80            return Err(FatalError(format!(
81                "actual output packet count: {}; expected output packet count: {}",
82                actual_output_packet_count, self.expected_output_packet_count
83            ))
84            .into());
85        }
86
87        Ok(())
88    }
89}
90
/// Validates that the output contains the expected number of bytes.
pub struct OutputDataSizeValidator {
    /// Exact total of `data.len()` over all packets in the output, in bytes.
    pub expected_output_data_size: usize,
}
95
96#[async_trait(?Send)]
97impl OutputValidator for OutputDataSizeValidator {
98    async fn validate(&self, output: &[Output]) -> Result<(), Error> {
99        let actual_output_data_size: usize = output
100            .iter()
101            .map(|output| match output {
102                Output::Packet(p) => p.data.len(),
103                _ => 0,
104            })
105            .sum();
106
107        if actual_output_data_size != self.expected_output_data_size {
108            return Err(FatalError(format!(
109                "actual output data size: {}; expected output data size: {}",
110                actual_output_data_size, self.expected_output_data_size
111            ))
112            .into());
113        }
114
115        Ok(())
116    }
117}
118
/// Validates that the last output of a stream equals an expected terminal output
/// (for example an `Output::Eos` with a particular stream lifetime ordinal).
pub struct TerminatesWithValidator {
    /// The output that must appear last; compared with `==` against the final entry.
    pub expected_terminal_output: Output,
}
123
124#[async_trait(?Send)]
125impl OutputValidator for TerminatesWithValidator {
126    async fn validate(&self, output: &[Output]) -> Result<(), Error> {
127        let actual_terminal_output = output.last().ok_or(FatalError(format!(
128            "In terminal output: expected {:?}; found: None",
129            Some(&self.expected_terminal_output)
130        )))?;
131
132        if *actual_terminal_output == self.expected_terminal_output {
133            Ok(())
134        } else {
135            Err(FatalError(format!(
136                "In terminal output: expected {:?}; found: {:?}",
137                Some(&self.expected_terminal_output),
138                actual_terminal_output
139            ))
140            .into())
141        }
142    }
143}
144
/// Validates that the format details attached to the first output packet match the expected
/// format details.
pub struct FormatValidator {
    /// Format details the first packet's format must be equal to.
    pub expected_format: FormatDetails,
}
149
150#[async_trait(?Send)]
151impl OutputValidator for FormatValidator {
152    async fn validate(&self, output: &[Output]) -> Result<(), Error> {
153        let packets: Vec<&OutputPacket> = output_packets(output).collect();
154        let format = &packets
155            .first()
156            .ok_or(FatalError(String::from("No packets in output")))?
157            .format
158            .format_details;
159
160        if self.expected_format != *format {
161            return Err(FatalError(format!(
162                "Expected {:?}; got {:?}",
163                self.expected_format, format
164            ))
165            .into());
166        }
167
168        Ok(())
169    }
170}
171
/// Validates that an output's data exactly matches an expected hash, including oob_bytes
pub struct BytesValidator {
    /// When `Some`, the packet payloads are also written to a file created at this path.
    pub output_file: Option<&'static str>,
    /// Acceptable digests; validation passes when the computed digest equals any one of them.
    pub expected_digests: Vec<ExpectedDigest>,
}
177
178impl BytesValidator {
179    fn write_and_hash(
180        &self,
181        mut writer: impl Write,
182        oob: &[u8],
183        packets: &[&OutputPacket],
184    ) -> Result<(), Error> {
185        let mut hasher = Sha256::default();
186
187        hasher.update(oob);
188
189        for packet in packets {
190            writer.write_all(&packet.data)?;
191            hasher.update(&packet.data);
192        }
193        writer.flush()?;
194
195        let digest = hasher.finish().bytes();
196
197        if let None = self.expected_digests.iter().find(|e| e.bytes == digest) {
198            return Err(FatalError(format!(
199                "Expected one of {:?}; got {}",
200                self.expected_digests,
201                encode(digest)
202            ))
203            .into());
204        }
205
206        Ok(())
207    }
208}
209
210fn output_writer(output_file: Option<&'static str>) -> Result<impl Write, Error> {
211    Ok(if let Some(file) = output_file {
212        Box::new(std::fs::File::create(file)?) as Box<dyn Write>
213    } else {
214        Box::new(std::io::sink()) as Box<dyn Write>
215    })
216}
217
218#[async_trait(?Send)]
219impl OutputValidator for BytesValidator {
220    async fn validate(&self, output: &[Output]) -> Result<(), Error> {
221        let packets: Vec<&OutputPacket> = output_packets(output).collect();
222        let oob = packets
223            .first()
224            .ok_or(FatalError(String::from("No packets in output")))?
225            .format
226            .format_details
227            .oob_bytes
228            .clone()
229            .unwrap_or(vec![]);
230
231        self.write_and_hash(output_writer(self.output_file)?, oob.as_slice(), &packets)
232    }
233}
234
/// A labelled SHA-256 digest that validated output is allowed to hash to.
#[derive(Clone)]
pub struct ExpectedDigest {
    /// Human-readable name for this digest; shown in the `Debug`/`Display` output.
    pub label: &'static str,
    /// Digest of the whole output; this is the value `BytesValidator` compares against.
    pub bytes: <<Sha256 as Hasher>::Digest as Digest>::Bytes,
    /// Optional list of per-frame digests. Not read by any validator in this file.
    pub per_frame_bytes: Option<Vec<<<Sha256 as Hasher>::Digest as Digest>::Bytes>>,
}
241
242impl ExpectedDigest {
243    pub fn new(label: &'static str, hex: impl AsRef<[u8]>) -> Self {
244        Self {
245            label,
246            bytes: decode(hex)
247                .expect("Decoding static compile-time test hash as valid hex")
248                .as_slice()
249                .try_into()
250                .expect("Taking 32 bytes from compile-time test hash"),
251            per_frame_bytes: None,
252        }
253    }
254    pub fn new_with_per_frame_digest(
255        label: &'static str,
256        hex: impl AsRef<[u8]>,
257        per_frame_hexen: Vec<impl AsRef<[u8]>>,
258    ) -> Self {
259        Self {
260            per_frame_bytes: Some(
261                per_frame_hexen
262                    .into_iter()
263                    .map(|per_frame_hex| {
264                        decode(per_frame_hex)
265                            .expect("Decoding static compile-time test hash as valid hex")
266                            .as_slice()
267                            .try_into()
268                            .expect("Taking 32 bytes from compile-time test hash")
269                    })
270                    .collect(),
271            ),
272            ..Self::new(label, hex)
273        }
274    }
275
276    pub fn new_from_raw(label: &'static str, raw_data: Vec<u8>) -> Self {
277        Self {
278            label,
279            bytes: <Sha256 as Hasher>::hash(raw_data.as_slice()).bytes(),
280            per_frame_bytes: None,
281        }
282    }
283}
284
285impl fmt::Display for ExpectedDigest {
286    fn fmt(&self, w: &mut fmt::Formatter<'_>) -> fmt::Result {
287        write!(w, "{:?}", self)
288    }
289}
290
291impl fmt::Debug for ExpectedDigest {
292    fn fmt(&self, w: &mut fmt::Formatter<'_>) -> fmt::Result {
293        write!(w, "ExpectedDigest {{\n")?;
294        write!(w, "\tlabel: {}", self.label)?;
295        write!(w, "\tbytes: {}", encode(self.bytes))?;
296        write!(w, "}}")
297    }
298}
299
/// Validates that the RMSE of output data and the expected data
/// falls within an acceptable range.
#[allow(unused)]
pub struct RmseValidator<T> {
    /// When `Some`, the raw packet payloads are also written to a file created at this path.
    pub output_file: Option<&'static str>,
    /// Reference samples the converted output is compared against.
    pub expected_data: Vec<T>,
    /// RMSE value the comparison is expected to produce.
    pub expected_rmse: f64,
    // Maximum allowed absolute difference between the calculated RMSE and
    // `expected_rmse` (the validator compares `|rmse - expected_rmse|`
    // against this value; it is not a percentage).
    pub rmse_diff_tolerance: f64,
    /// Maximum allowed difference, in samples, between the lengths of the expected data and
    /// the converted output data.
    pub data_len_diff_tolerance: u32,
    /// Converts the concatenated raw output bytes into samples of type `T`.
    pub output_converter: fn(Vec<u8>) -> Vec<T>,
}
313
314pub fn calculate_rmse<T: PrimInt + std::fmt::Debug>(
315    expected_data: &[T],
316    actual_data: &[T],
317    acceptable_len_diff: u32,
318) -> Result<f64, Error> {
319    // There could be a slight difference to the length of the expected data
320    // and the actual data due to the way some codecs deal with left over data
321    // at the end of the stream. This can be caused by minimum block size and
322    // how some codecs may choose to pad out the last block or insert a silence
323    // data at the start. Ensure the difference in length between expected and
324    // actual data is not too much.
325    let compare_len = std::cmp::min(expected_data.len(), actual_data.len());
326    if std::cmp::max(expected_data.len(), actual_data.len()) - compare_len
327        > acceptable_len_diff.try_into().unwrap()
328    {
329        return Err(FatalError(format!(
330            "Expected data (len {}) and the actual data (len {}) have significant length difference and cannot be compared.",
331            expected_data.len(), actual_data.len(),
332        )).into());
333    }
334    let expected_data = &expected_data[..compare_len];
335    let actual_data = &actual_data[..compare_len];
336
337    let mut rmse = 0.0;
338    let mut n = 0;
339    for data in std::iter::zip(actual_data.iter(), expected_data.iter()) {
340        let b1: f64 = num_traits::cast::cast(*data.0).unwrap();
341        let b2: f64 = num_traits::cast::cast(*data.1).unwrap();
342        rmse += (b1 - b2).powi(2);
343        n += 1;
344    }
345    Ok((rmse / n as f64).sqrt())
346}
347
348impl<T: PrimInt + std::fmt::Debug> RmseValidator<T> {
349    fn write_and_calc_rsme(
350        &self,
351        mut writer: impl Write,
352        packets: &[&OutputPacket],
353    ) -> Result<(), Error> {
354        let mut output_data: Vec<u8> = Vec::new();
355        for packet in packets {
356            writer.write_all(&packet.data)?;
357            packet.data.iter().for_each(|item| output_data.push(*item));
358        }
359
360        let actual_data = (self.output_converter)(output_data);
361
362        let rmse = calculate_rmse(
363            self.expected_data.as_slice(),
364            actual_data.as_slice(),
365            self.data_len_diff_tolerance,
366        )?;
367        info!("RMSE is {}", rmse);
368        if (rmse - self.expected_rmse).abs() > self.rmse_diff_tolerance {
369            return Err(FatalError(format!(
370                "expected rmse: {}; actual rmse: {}; rmse diff tolerance {}",
371                self.expected_rmse, rmse, self.rmse_diff_tolerance,
372            ))
373            .into());
374        }
375        Ok(())
376    }
377}
378
379#[async_trait(?Send)]
380impl<T: PrimInt + std::fmt::Debug> OutputValidator for RmseValidator<T> {
381    async fn validate(&self, output: &[Output]) -> Result<(), Error> {
382        let packets: Vec<&OutputPacket> = output_packets(output).collect();
383        self.write_and_calc_rsme(output_writer(self.output_file)?, &packets)
384    }
385}