fidl_next_codec/wire/vec/
required.rs

1// Copyright 2024 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use core::marker::PhantomData;
6use core::mem::{MaybeUninit, forget, needs_drop};
7use core::ops::Deref;
8use core::ptr::{NonNull, copy_nonoverlapping};
9use core::{fmt, slice};
10
11use munge::munge;
12
13use super::raw::RawWireVector;
14use crate::{
15    Chunk, Constrained, Decode, DecodeError, Decoder, DecoderExt as _, Encodable, Encode,
16    EncodeError, EncodeRef, Encoder, EncoderExt as _, FromWire, FromWireRef, IntoNatural, Slot,
17    ValidationError, Wire, WirePointer,
18};
19
20/// A FIDL vector
21#[repr(transparent)]
22pub struct WireVector<'de, T> {
23    raw: RawWireVector<'de, T>,
24}
25
26unsafe impl<T: Wire> Wire for WireVector<'static, T> {
27    type Decoded<'de> = WireVector<'de, T::Decoded<'de>>;
28
29    #[inline]
30    fn zero_padding(out: &mut MaybeUninit<Self>) {
31        munge!(let Self { raw } = out);
32        RawWireVector::<T>::zero_padding(raw);
33    }
34}
35
36impl<T> Drop for WireVector<'_, T> {
37    fn drop(&mut self) {
38        if needs_drop::<T>() {
39            unsafe {
40                self.raw.as_slice_ptr().drop_in_place();
41            }
42        }
43    }
44}
45
46impl<T> WireVector<'_, T> {
47    /// Encodes that a vector is present in a slot.
48    pub fn encode_present(out: &mut MaybeUninit<Self>, len: u64) {
49        munge!(let Self { raw } = out);
50        RawWireVector::encode_present(raw, len);
51    }
52
53    /// Returns the length of the vector in elements.
54    pub fn len(&self) -> usize {
55        self.raw.len() as usize
56    }
57
58    /// Returns whether the vector is empty.
59    pub fn is_empty(&self) -> bool {
60        self.len() == 0
61    }
62
63    /// Returns a pointer to the elements of the vector.
64    fn as_slice_ptr(&self) -> NonNull<[T]> {
65        unsafe { NonNull::new_unchecked(self.raw.as_slice_ptr()) }
66    }
67
68    /// Returns a slice of the elements of the vector.
69    pub fn as_slice(&self) -> &[T] {
70        unsafe { self.as_slice_ptr().as_ref() }
71    }
72
73    /// Decodes a wire vector which contains raw data.
74    ///
75    /// # Safety
76    ///
77    /// The elements of the wire vector must not need to be individually decoded, and must always be
78    /// valid.
79    pub unsafe fn decode_raw<D>(
80        mut slot: Slot<'_, Self>,
81        mut decoder: &mut D,
82        max_len: u64,
83    ) -> Result<(), DecodeError>
84    where
85        D: Decoder + ?Sized,
86        T: Decode<D>,
87    {
88        munge!(let Self { raw: RawWireVector { len, mut ptr } } = slot.as_mut());
89
90        if !WirePointer::is_encoded_present(ptr.as_mut())? {
91            return Err(DecodeError::RequiredValueAbsent);
92        }
93
94        if **len > max_len {
95            return Err(DecodeError::Validation(ValidationError::VectorTooLong {
96                count: **len,
97                limit: max_len,
98            }));
99        }
100
101        let mut slice = decoder.take_slice_slot::<T>(**len as usize)?;
102        WirePointer::set_decoded(ptr, slice.as_mut_ptr().cast());
103
104        Ok(())
105    }
106
107    /// Validate that this vector's length falls within the limit.
108    pub(crate) fn validate_max_len(
109        slot: Slot<'_, Self>,
110        limit: u64,
111    ) -> Result<(), crate::ValidationError> {
112        munge!(let Self { raw: RawWireVector { len, ptr:_ } } = slot);
113        let count: u64 = **len;
114        if count > limit { Err(ValidationError::VectorTooLong { count, limit }) } else { Ok(()) }
115    }
116}
117
118type VectorConstraint<T> = (u64, <T as Constrained>::Constraint);
119
120impl<T: Constrained> Constrained for WireVector<'_, T> {
121    type Constraint = VectorConstraint<T>;
122
123    fn validate(slot: Slot<'_, Self>, constraint: Self::Constraint) -> Result<(), ValidationError> {
124        let (limit, _) = constraint;
125
126        munge!(let Self { raw: RawWireVector { len, ptr:_ } } = slot);
127        let count = **len;
128        if count > limit {
129            return Err(ValidationError::VectorTooLong { count, limit });
130        }
131
132        Ok(())
133    }
134}
135
136/// An iterator over the items of a `WireVector`.
137pub struct IntoIter<'de, T> {
138    current: *mut T,
139    remaining: usize,
140    _phantom: PhantomData<&'de mut [Chunk]>,
141}
142
143impl<T> Drop for IntoIter<'_, T> {
144    fn drop(&mut self) {
145        for i in 0..self.remaining {
146            unsafe {
147                self.current.add(i).drop_in_place();
148            }
149        }
150    }
151}
152
153impl<T> Iterator for IntoIter<'_, T> {
154    type Item = T;
155
156    fn next(&mut self) -> Option<Self::Item> {
157        if self.remaining == 0 {
158            None
159        } else {
160            let result = unsafe { self.current.read() };
161            self.current = unsafe { self.current.add(1) };
162            self.remaining -= 1;
163            Some(result)
164        }
165    }
166}
167
168impl<'de, T> IntoIterator for WireVector<'de, T> {
169    type IntoIter = IntoIter<'de, T>;
170    type Item = T;
171
172    fn into_iter(self) -> Self::IntoIter {
173        let current = self.raw.as_ptr();
174        let remaining = self.len();
175        forget(self);
176
177        IntoIter { current, remaining, _phantom: PhantomData }
178    }
179}
180
181impl<T> Deref for WireVector<'_, T> {
182    type Target = [T];
183
184    fn deref(&self) -> &Self::Target {
185        self.as_slice()
186    }
187}
188
189impl<T: fmt::Debug> fmt::Debug for WireVector<'_, T> {
190    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
191        self.as_slice().fmt(f)
192    }
193}
194
195unsafe impl<D: Decoder + ?Sized, T: Decode<D>> Decode<D> for WireVector<'static, T> {
196    fn decode(
197        mut slot: Slot<'_, Self>,
198        mut decoder: &mut D,
199        constraint: <Self as Constrained>::Constraint,
200    ) -> Result<(), DecodeError> {
201        munge!(let Self { raw: RawWireVector { len, mut ptr } } = slot.as_mut());
202
203        let (length_constraint, member_constraint) = constraint;
204
205        if **len > length_constraint {
206            return Err(DecodeError::Validation(ValidationError::VectorTooLong {
207                count: **len,
208                limit: length_constraint,
209            }));
210        }
211
212        if !WirePointer::is_encoded_present(ptr.as_mut())? {
213            return Err(DecodeError::RequiredValueAbsent);
214        }
215
216        let mut slice = decoder.take_slice_slot::<T>(**len as usize)?;
217        for i in 0..**len as usize {
218            T::decode(slice.index(i), decoder, member_constraint)?;
219        }
220        WirePointer::set_decoded(ptr, slice.as_mut_ptr().cast());
221
222        Ok(())
223    }
224}
225
226#[inline]
227fn encode_to_vector<V, E, T>(
228    value: V,
229    encoder: &mut E,
230    out: &mut MaybeUninit<WireVector<'_, T::Encoded>>,
231    constraint: VectorConstraint<T::Encoded>,
232) -> Result<(), EncodeError>
233where
234    V: AsRef<[T]> + IntoIterator,
235    V::IntoIter: ExactSizeIterator,
236    V::Item: Encode<E, Encoded = T::Encoded>,
237    E: Encoder + ?Sized,
238    T: Encodable,
239{
240    let len = value.as_ref().len();
241    let (_length_constraint, member_constraint) = constraint;
242    if T::COPY_OPTIMIZATION.is_enabled() {
243        let slice = value.as_ref();
244        // SAFETY: `T` has copy optimization enabled, which guarantees that it has no uninit bytes
245        // and can be copied directly to the output instead of calling `encode`. This means that we
246        // may cast `&[T]` to `&[u8]` and write those bytes.
247        let bytes = unsafe { slice::from_raw_parts(slice.as_ptr().cast(), size_of_val(slice)) };
248        encoder.write(bytes);
249    } else {
250        encoder.encode_next_iter(value.into_iter(), member_constraint)?;
251    }
252    WireVector::encode_present(out, len as u64);
253    Ok(())
254}
255
256impl<T: Encodable> Encodable for Vec<T> {
257    type Encoded = WireVector<'static, T::Encoded>;
258}
259
260unsafe impl<E, T> Encode<E> for Vec<T>
261where
262    E: Encoder + ?Sized,
263    T: Encode<E>,
264{
265    fn encode(
266        self,
267        encoder: &mut E,
268        out: &mut MaybeUninit<Self::Encoded>,
269        constraint: <Self::Encoded as Constrained>::Constraint,
270    ) -> Result<(), EncodeError> {
271        encode_to_vector(self, encoder, out, constraint)?;
272
273        munge! (let Self::Encoded { raw } = out);
274
275        let raw_ptr = unsafe { &*raw.as_ptr() };
276
277        for _member in unsafe { raw_ptr.as_slice_ptr().as_ref() }.unwrap() {
278            // member.validate_in_line()
279        }
280
281        Ok(())
282    }
283}
284
285unsafe impl<E, T> EncodeRef<E> for Vec<T>
286where
287    E: Encoder + ?Sized,
288    T: EncodeRef<E>,
289{
290    fn encode_ref(
291        &self,
292        encoder: &mut E,
293        out: &mut MaybeUninit<Self::Encoded>,
294        constraint: <Self::Encoded as Constrained>::Constraint,
295    ) -> Result<(), EncodeError> {
296        encode_to_vector(self, encoder, out, constraint)
297    }
298}
299
300impl<T: Encodable> Encodable for [T] {
301    type Encoded = WireVector<'static, T::Encoded>;
302}
303
304unsafe impl<E, T> EncodeRef<E> for [T]
305where
306    E: Encoder + ?Sized,
307    T: EncodeRef<E>,
308{
309    fn encode_ref(
310        &self,
311        encoder: &mut E,
312        out: &mut MaybeUninit<Self::Encoded>,
313        constraint: <Self::Encoded as Constrained>::Constraint,
314    ) -> Result<(), EncodeError> {
315        encode_to_vector(self, encoder, out, constraint)
316    }
317}
318
319impl<T: FromWire<W>, W> FromWire<WireVector<'_, W>> for Vec<T> {
320    fn from_wire(wire: WireVector<'_, W>) -> Self {
321        let mut result = Vec::<T>::with_capacity(wire.len());
322        if T::COPY_OPTIMIZATION.is_enabled() {
323            unsafe {
324                copy_nonoverlapping(wire.as_ptr().cast(), result.as_mut_ptr(), wire.len());
325            }
326            unsafe {
327                result.set_len(wire.len());
328            }
329            forget(wire);
330        } else {
331            for item in wire.into_iter() {
332                result.push(T::from_wire(item));
333            }
334        }
335        result
336    }
337}
338
339impl<T: IntoNatural> IntoNatural for WireVector<'_, T> {
340    type Natural = Vec<T::Natural>;
341}
342
343impl<T: FromWireRef<W>, W> FromWireRef<WireVector<'_, W>> for Vec<T> {
344    fn from_wire_ref(wire: &WireVector<'_, W>) -> Self {
345        let mut result = Vec::<T>::with_capacity(wire.len());
346        if T::COPY_OPTIMIZATION.is_enabled() {
347            unsafe {
348                copy_nonoverlapping(wire.as_ptr().cast(), result.as_mut_ptr(), wire.len());
349            }
350            unsafe {
351                result.set_len(wire.len());
352            }
353        } else {
354            for item in wire.iter() {
355                result.push(T::from_wire_ref(item));
356            }
357        }
358        result
359    }
360}