1use super::PolicyValidationContext;
6use super::parser::{PolicyCursor, PolicyData, PolicyOffset};
7use crate::policy::{Counted, Parse, Validate};
8use hashbrown::hash_table::HashTable;
9use std::fmt::Debug;
10use std::hash::{DefaultHasher, Hash, Hasher};
11use std::marker::PhantomData;
12use std::ops::Deref;
13use zerocopy::FromBytes;
14
15pub trait HasMetadata {
21    type Metadata: FromBytes + Sized;
23}
24
25pub trait Walk {
29    fn walk(policy_data: &PolicyData, offset: PolicyOffset) -> PolicyOffset;
33}
34
35#[derive(Debug, Clone, Copy)]
40pub struct View<T> {
41    phantom: PhantomData<T>,
42
43    start: PolicyOffset,
45
46    end: PolicyOffset,
48}
49
50impl<T> View<T> {
51    pub fn new(start: PolicyOffset, end: PolicyOffset) -> Self {
53        Self { phantom: PhantomData, start, end }
54    }
55}
56
57impl<T: Sized> View<T> {
58    pub fn at(start: PolicyOffset) -> Self {
62        let end = start + std::mem::size_of::<T>() as u32;
63        Self::new(start, end)
64    }
65}
66
67impl<T: FromBytes + Sized> View<T> {
68    pub fn read(&self, policy_data: &PolicyData) -> T {
75        debug_assert_eq!(self.end - self.start, std::mem::size_of::<T>() as u32);
76        let start = self.start as usize;
77        let end = self.end as usize;
78        T::read_from_bytes(&policy_data[start..end]).unwrap()
79    }
80}
81
82impl<T: HasMetadata> View<T> {
83    pub fn metadata(&self) -> View<T::Metadata> {
87        View::<T::Metadata>::at(self.start)
88    }
89
90    pub fn read_metadata(&self, policy_data: &PolicyData) -> T::Metadata {
92        self.metadata().read(policy_data)
93    }
94}
95
96impl<T: Parse> View<T> {
97    pub fn parse(&self, policy_data: &PolicyData) -> T {
103        let cursor = PolicyCursor::new_at(policy_data.clone(), self.start);
104        let (object, _) =
105            T::parse(cursor).map_err(Into::<anyhow::Error>::into).expect("policy should be valid");
106        object
107    }
108}
109
110impl<T: Validate + Parse> Validate for View<T> {
111    type Error = anyhow::Error;
112
113    fn validate(&self, context: &mut PolicyValidationContext) -> Result<(), Self::Error> {
114        let object = self.parse(&context.data);
115        object.validate(context).map_err(Into::<anyhow::Error>::into)
116    }
117}
118
119#[derive(Debug, Clone, Copy)]
124pub struct ArrayDataView<D> {
125    phantom: PhantomData<D>,
126    start: PolicyOffset,
127    count: u32,
128}
129
130impl<D> ArrayDataView<D> {
131    pub fn new(start: PolicyOffset, count: u32) -> Self {
133        Self { phantom: PhantomData, start, count }
134    }
135
136    pub fn iter(self, policy_data: &PolicyData) -> ArrayDataViewIter<D> {
143        ArrayDataViewIter::new(policy_data.clone(), self.start, self.count)
144    }
145}
146
147pub struct ArrayDataViewIter<D> {
152    phantom: PhantomData<D>,
153    policy_data: PolicyData,
154    offset: PolicyOffset,
155    remaining: u32,
156}
157
158impl<T> ArrayDataViewIter<T> {
159    pub(crate) fn new(policy_data: PolicyData, offset: PolicyOffset, remaining: u32) -> Self {
161        Self { phantom: PhantomData, policy_data, offset, remaining }
162    }
163}
164
165impl<D: Walk> std::iter::Iterator for ArrayDataViewIter<D> {
166    type Item = View<D>;
167
168    fn next(&mut self) -> Option<Self::Item> {
169        if self.remaining > 0 {
170            let start = self.offset;
171            self.offset = D::walk(&self.policy_data, start);
172            self.remaining -= 1;
173            Some(View::new(start, self.offset))
174        } else {
175            None
176        }
177    }
178}
179
180#[derive(Debug, Clone, Copy)]
185pub(crate) struct ArrayView<M, D> {
186    phantom: PhantomData<(M, D)>,
187    start: PolicyOffset,
188    count: u32,
189}
190
191impl<M, D> ArrayView<M, D> {
192    pub fn new(start: PolicyOffset, count: u32) -> Self {
194        Self { phantom: PhantomData, start, count }
195    }
196}
197
198impl<M: Sized, D> ArrayView<M, D> {
199    pub fn metadata(&self) -> View<M> {
201        View::<M>::at(self.start)
202    }
203
204    pub fn data(&self) -> ArrayDataView<D> {
206        ArrayDataView::new(self.metadata().end, self.count)
207    }
208}
209
210fn parse_array_data<D: Parse>(
211    cursor: PolicyCursor,
212    count: u32,
213) -> Result<PolicyCursor, anyhow::Error> {
214    let mut tail = cursor;
215    for _ in 0..count {
216        let (_, next) = D::parse(tail).map_err(Into::<anyhow::Error>::into)?;
217        tail = next;
218    }
219    Ok(tail)
220}
221
222impl<M: Counted + Parse + Sized, D: Parse> Parse for ArrayView<M, D> {
223    type Error = anyhow::Error;
226
227    fn parse(cursor: PolicyCursor) -> Result<(Self, PolicyCursor), Self::Error> {
228        let start = cursor.offset();
229        let (metadata, cursor) = M::parse(cursor).map_err(Into::<anyhow::Error>::into)?;
230        let count = metadata.count();
231        let cursor = parse_array_data::<D>(cursor, count)?;
232        Ok((Self::new(start, count), cursor))
233    }
234}
235
236#[derive(Debug, Clone)]
241pub(crate) struct HashedArrayView<M, D: HasMetadata> {
242    array_view: ArrayView<M, D>,
243    index: HashTable<PolicyOffset>,
244}
245
246impl<M, D: HasMetadata> Deref for HashedArrayView<M, D> {
247    type Target = ArrayView<M, D>;
248
249    fn deref(&self) -> &Self::Target {
250        &self.array_view
251    }
252}
253
254impl<D: HasMetadata, M> HashedArrayView<M, D>
255where
256    D::Metadata: Hash,
257{
258    fn metadata_hash(metadata: &D::Metadata) -> u64 {
259        let mut hasher = DefaultHasher::new();
260        metadata.hash(&mut hasher);
261        hasher.finish()
262    }
263}
264
265impl<D: Parse + HasMetadata, M> HashedArrayView<M, D>
266where
267    D::Metadata: Eq + PartialEq + Hash + Debug,
268{
269    pub fn find(&self, key: D::Metadata, policy_data: &PolicyData) -> Option<D> {
271        let key_hash = Self::metadata_hash(&key);
272        let offset = self.index.find(key_hash, |&offset| {
273            let element = View::<D>::at(offset);
274            key == element.read_metadata(policy_data)
275        })?;
276        let element = View::<D>::at(*offset);
277        Some(element.parse(policy_data))
278    }
279}
280
281impl<M: Counted + Parse + Sized, D: Parse + HasMetadata> Parse for HashedArrayView<M, D>
282where
283    D::Metadata: Eq + Debug + PartialEq + Parse + Hash,
284{
285    type Error = anyhow::Error;
288
289    fn parse(cursor: PolicyCursor) -> Result<(Self, PolicyCursor), Self::Error> {
290        let (array_view, _) =
292            ArrayView::<M, D>::parse(cursor.clone()).map_err(Into::<anyhow::Error>::into)?;
293
294        let mut index = HashTable::with_capacity(array_view.count as usize);
296
297        let (_, mut cursor) = M::parse(cursor).map_err(Into::<anyhow::Error>::into)?;
299        for _ in 0..array_view.count {
300            let (metadata, _) =
301                D::Metadata::parse(cursor.clone()).map_err(Into::<anyhow::Error>::into)?;
302            let (_, next) = D::parse(cursor.clone()).map_err(Into::<anyhow::Error>::into)?;
303
304            index.insert_unique(Self::metadata_hash(&metadata), cursor.offset(), |&offset| {
305                let element = View::<D>::at(offset);
306                Self::metadata_hash(&element.read_metadata(cursor.data()))
307            });
308
309            cursor = next;
310        }
311
312        Ok((Self { array_view, index }, cursor))
313    }
314}