Skip to main content

fidl_next_codec/wire/vec/optional.rs

1// Copyright 2024 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use core::mem::{MaybeUninit, needs_drop};
6use core::{fmt, slice};
7
8use munge::munge;
9
10use super::raw::RawVector;
11use crate::{
12    Constrained, Decode, DecodeError, Decoder, DecoderExt as _, Encode, EncodeError, EncodeOption,
13    Encoder, EncoderExt as _, FromWire, FromWireOption, FromWireOptionRef, FromWireRef,
14    IntoNatural, Slot, ValidationError, Wire, wire,
15};
16
17/// An optional FIDL vector
18#[repr(transparent)]
19pub struct OptionalVector<'de, T> {
20    raw: RawVector<'de, T>,
21}
22
23// SAFETY: `OptionalVector` is `repr(transparent)` over `RawVector`, which implements `Wire`.
24// Lifetime erasure is safe since `OptionalVector` is covariant over its lifetime.
25unsafe impl<T: Wire> Wire for OptionalVector<'static, T> {
26    type Narrowed<'de> = OptionalVector<'de, T::Narrowed<'de>>;
27
28    #[inline]
29    fn zero_padding(out: &mut MaybeUninit<Self>) {
30        munge!(let Self { raw } = out);
31        RawVector::<T>::zero_padding(raw);
32    }
33}
34
35impl<T> Drop for OptionalVector<'_, T> {
36    fn drop(&mut self) {
37        if needs_drop::<T>() && self.is_some() {
38            // SAFETY: If the vector is present and `T` needs to be dropped, the pointer has
39            // been decoded and points to a valid slice of initialized `T` elements.
40            unsafe {
41                self.raw.as_slice_ptr().drop_in_place();
42            }
43        }
44    }
45}
46
47impl<'de, T> OptionalVector<'de, T> {
48    /// Encodes that a vector is present in a slot.
49    pub fn encode_present(out: &mut MaybeUninit<Self>, len: u64) {
50        munge!(let Self { raw } = out);
51        RawVector::encode_present(raw, len);
52    }
53
54    /// Encodes that a vector is absent in a slot.
55    pub fn encode_absent(out: &mut MaybeUninit<Self>) {
56        munge!(let Self { raw } = out);
57        RawVector::encode_absent(raw);
58    }
59
60    /// Returns whether the vector is present.
61    pub fn is_some(&self) -> bool {
62        !self.raw.as_ptr().is_null()
63    }
64
65    /// Returns whether the vector is absent.
66    pub fn is_none(&self) -> bool {
67        !self.is_some()
68    }
69
70    /// Gets a reference to the vector, if any.
71    pub fn as_ref(&self) -> Option<&wire::Vector<'_, T>> {
72        if self.is_some() {
73            // SAFETY: `OptionalVector` and `Vector` have the same layout (`repr(transparent)`
74            // over `RawVector`). Since `self.is_some()` is true, the underlying pointer is
75            // non-null, which satisfies the invariant of `Vector`.
76            Some(unsafe { &*(self as *const Self).cast() })
77        } else {
78            None
79        }
80    }
81
82    /// Converts the optional wire vector to an `Option<WireVector>`.
83    pub fn to_option(self) -> Option<wire::Vector<'de, T>> {
84        if self.is_some() {
85            // SAFETY: `OptionalVector` and `Vector` have the same layout. Since `self.is_some()`
86            // is true, the underlying pointer is non-null, which satisfies the invariant of
87            // `Vector`.
88            Some(unsafe { core::mem::transmute::<Self, wire::Vector<'de, T>>(self) })
89        } else {
90            None
91        }
92    }
93
94    /// Decodes a wire vector which contains raw data.
95    ///
96    /// # Safety
97    ///
98    /// The elements of the wire vector must not need to be individually decoded, and must always be
99    /// valid.
100    pub unsafe fn decode_raw<D>(
101        mut slot: Slot<'_, Self>,
102        decoder: &mut D,
103        max_len: u64,
104    ) -> Result<(), DecodeError>
105    where
106        D: Decoder<'de> + ?Sized,
107        T: Decode<D>,
108    {
109        munge!(let Self { raw: RawVector { len, mut ptr } } = slot.as_mut());
110
111        if wire::Pointer::is_encoded_present(ptr.as_mut())? {
112            if **len > max_len {
113                return Err(DecodeError::Validation(ValidationError::VectorTooLong {
114                    count: **len,
115                    limit: max_len,
116                }));
117            }
118
119            let slice = decoder.take_slice_slot::<T>(**len as usize)?;
120            wire::Pointer::set_decoded_slice(ptr, slice);
121        } else if *len != 0 {
122            return Err(DecodeError::InvalidOptionalSize(**len));
123        }
124
125        Ok(())
126    }
127
128    /// Validate that this vector's length falls within the limit.
129    pub(crate) fn validate_max_len(
130        slot: Slot<'_, Self>,
131        limit: u64,
132    ) -> Result<(), ValidationError> {
133        munge!(let Self { raw: RawVector { len, ptr } } = slot);
134        let count = **len;
135        let is_present = ptr.as_bytes() != [0; 8];
136        if is_present && count > limit {
137            Err(ValidationError::VectorTooLong { count, limit })
138        } else {
139            Ok(())
140        }
141    }
142}
143
144type VectorConstraint<T> = (u64, <T as Constrained>::Constraint);
145
146impl<T: Constrained> Constrained for OptionalVector<'_, T> {
147    type Constraint = VectorConstraint<T>;
148
149    fn validate(slot: Slot<'_, Self>, constraint: Self::Constraint) -> Result<(), ValidationError> {
150        let (limit, _member_constraint) = constraint;
151
152        Self::validate_max_len(slot, limit)
153    }
154}
155
156impl<T: fmt::Debug> fmt::Debug for OptionalVector<'_, T> {
157    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
158        self.as_ref().fmt(f)
159    }
160}
161
162impl<T, U> PartialEq<Option<U>> for OptionalVector<'_, T>
163where
164    for<'de> wire::Vector<'de, T>: PartialEq<U>,
165{
166    fn eq(&self, other: &Option<U>) -> bool {
167        match (self.as_ref(), other.as_ref()) {
168            (Some(lhs), Some(rhs)) => lhs == rhs,
169            (None, None) => true,
170            _ => false,
171        }
172    }
173}
174
175// SAFETY: If `decode` returns `Ok`, the `OptionalVector` has been successfully decoded.
176// If present, the pointer is updated to point to a successfully decoded slice of `T`
177// allocated by the decoder. If absent, the pointer remains null and the length is 0.
178unsafe impl<'de, D, T> Decode<D> for OptionalVector<'de, T>
179where
180    D: Decoder<'de> + ?Sized,
181    T: Decode<D>,
182{
183    fn decode(
184        mut slot: Slot<'_, Self>,
185        decoder: &mut D,
186        constraint: Self::Constraint,
187    ) -> Result<(), DecodeError> {
188        munge!(let Self { raw: RawVector { len, mut ptr } } = slot.as_mut());
189
190        let (length_constraint, member_constraint) = constraint;
191
192        if wire::Pointer::is_encoded_present(ptr.as_mut())? {
193            if **len > length_constraint {
194                return Err(DecodeError::Validation(ValidationError::VectorTooLong {
195                    count: **len,
196                    limit: length_constraint,
197                }));
198            }
199
200            let mut slice = decoder.take_slice_slot::<T>(**len as usize)?;
201            for i in 0..**len as usize {
202                T::decode(slice.index(i), decoder, member_constraint)?;
203            }
204            wire::Pointer::set_decoded_slice(ptr, slice);
205        } else if *len != 0 {
206            return Err(DecodeError::InvalidOptionalSize(**len));
207        }
208
209        Ok(())
210    }
211}
212
213#[inline]
214fn encode_to_optional_vector<V, W, E, T>(
215    value: Option<V>,
216    encoder: &mut E,
217    out: &mut MaybeUninit<OptionalVector<'static, W>>,
218    constraint: VectorConstraint<W>,
219) -> Result<(), EncodeError>
220where
221    V: AsRef<[T]> + IntoIterator,
222    V::IntoIter: ExactSizeIterator,
223    V::Item: Encode<W, E>,
224    W: Wire,
225    E: Encoder + ?Sized,
226    T: Encode<W, E>,
227{
228    let (length_constraint, member_constraint) = constraint;
229
230    if let Some(value) = value {
231        let len = value.as_ref().len();
232
233        if len as u64 > length_constraint {
234            return Err(EncodeError::Validation(ValidationError::VectorTooLong {
235                count: len as u64,
236                limit: length_constraint,
237            }));
238        }
239
240        if T::COPY_OPTIMIZATION.is_enabled() {
241            let slice = value.as_ref();
242            // SAFETY: `T` has copy optimization enabled, which guarantees that it has no uninit
243            // bytes and can be copied directly to the output instead of calling `encode`. This
244            // means that we may cast `&[T]` to `&[u8]` and write those bytes.
245            let bytes = unsafe { slice::from_raw_parts(slice.as_ptr().cast(), size_of_val(slice)) };
246            encoder.write(bytes);
247        } else {
248            encoder.encode_next_iter_with_constraint(value.into_iter(), member_constraint)?;
249        }
250        OptionalVector::encode_present(out, len as u64);
251    } else {
252        OptionalVector::encode_absent(out);
253    }
254    Ok(())
255}
256
257// SAFETY: `encode_option` delegates to `encode_to_optional_vector`, which initializes the output.
258unsafe impl<W, E, T> EncodeOption<OptionalVector<'static, W>, E> for Vec<T>
259where
260    W: Wire,
261    E: Encoder + ?Sized,
262    T: Encode<W, E>,
263{
264    fn encode_option(
265        this: Option<Self>,
266        encoder: &mut E,
267        out: &mut MaybeUninit<OptionalVector<'static, W>>,
268        constraint: VectorConstraint<W>,
269    ) -> Result<(), EncodeError> {
270        encode_to_optional_vector(this, encoder, out, constraint)
271    }
272}
273
274// SAFETY: `encode_option` delegates to `encode_to_optional_vector`, which initializes the output.
275unsafe impl<'a, W, E, T> EncodeOption<OptionalVector<'static, W>, E> for &'a Vec<T>
276where
277    W: Wire,
278    E: Encoder + ?Sized,
279    T: Encode<W, E>,
280    &'a T: Encode<W, E>,
281{
282    fn encode_option(
283        this: Option<Self>,
284        encoder: &mut E,
285        out: &mut MaybeUninit<OptionalVector<'static, W>>,
286        constraint: VectorConstraint<W>,
287    ) -> Result<(), EncodeError> {
288        encode_to_optional_vector(this, encoder, out, constraint)
289    }
290}
291
292// SAFETY: `encode_option` delegates to `encode_to_optional_vector`, which initializes the output.
293unsafe impl<W, E, T, const N: usize> EncodeOption<OptionalVector<'static, W>, E> for [T; N]
294where
295    W: Wire,
296    E: Encoder + ?Sized,
297    T: Encode<W, E>,
298{
299    fn encode_option(
300        this: Option<Self>,
301        encoder: &mut E,
302        out: &mut MaybeUninit<OptionalVector<'static, W>>,
303        constraint: VectorConstraint<W>,
304    ) -> Result<(), EncodeError> {
305        encode_to_optional_vector(this, encoder, out, constraint)
306    }
307}
308
309// SAFETY: `encode_option` delegates to `encode_to_optional_vector`, which initializes the output.
310unsafe impl<'a, W, E, T, const N: usize> EncodeOption<OptionalVector<'static, W>, E> for &'a [T; N]
311where
312    W: Wire,
313    E: Encoder + ?Sized,
314    T: Encode<W, E>,
315    &'a T: Encode<W, E>,
316{
317    fn encode_option(
318        this: Option<Self>,
319        encoder: &mut E,
320        out: &mut MaybeUninit<OptionalVector<'static, W>>,
321        constraint: VectorConstraint<W>,
322    ) -> Result<(), EncodeError> {
323        encode_to_optional_vector(this, encoder, out, constraint)
324    }
325}
326
327// SAFETY: `encode_option` delegates to `encode_to_optional_vector`, which initializes the output.
328unsafe impl<'a, W, E, T> EncodeOption<OptionalVector<'static, W>, E> for &'a [T]
329where
330    W: Wire,
331    E: Encoder + ?Sized,
332    T: Encode<W, E>,
333    &'a T: Encode<W, E>,
334{
335    fn encode_option(
336        this: Option<Self>,
337        encoder: &mut E,
338        out: &mut MaybeUninit<OptionalVector<'static, W>>,
339        constraint: VectorConstraint<W>,
340    ) -> Result<(), EncodeError> {
341        encode_to_optional_vector(this, encoder, out, constraint)
342    }
343}
344
345impl<T: FromWire<W>, W> FromWireOption<OptionalVector<'_, W>> for Vec<T> {
346    fn from_wire_option(wire: OptionalVector<'_, W>) -> Option<Self> {
347        wire.to_option().map(Vec::from_wire)
348    }
349}
350
351impl<T: IntoNatural> IntoNatural for OptionalVector<'_, T> {
352    type Natural = Option<Vec<T::Natural>>;
353}
354
355impl<T: FromWireRef<W>, W> FromWireOptionRef<OptionalVector<'_, W>> for Vec<T> {
356    fn from_wire_option_ref(wire: &OptionalVector<'_, W>) -> Option<Self> {
357        wire.as_ref().map(Vec::from_wire_ref)
358    }
359}
360
#[cfg(test)]
mod tests {
    use crate::{DecoderExt as _, EncoderExt as _, chunks, wire};

    #[test]
    fn decode_optional_vec() {
        assert_eq!(
            chunks![
                0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
                0x00, 0x00,
            ]
            .as_mut_slice()
            .decode_with_constraint::<wire::OptionalVector<'_, wire::Uint32>>((1000, ()))
            .unwrap()
            .as_ref(),
            None,
        );
    }

    #[test]
    fn decode_absent_optional_vec_with_nonzero_len_fails() {
        // A null (absent) pointer paired with a length of 1 must be rejected.
        assert!(
            chunks![
                0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
                0x00, 0x00,
            ]
            .as_mut_slice()
            .decode_with_constraint::<wire::OptionalVector<'_, wire::Uint32>>((1000, ()))
            .is_err()
        );
    }

    #[test]
    fn decode_optional_vec_over_limit_fails() {
        // A present vector of length 2 exceeds the limit of 1. The length check runs before any
        // out-of-line data is taken, so no element bytes are needed.
        assert!(
            chunks![
                0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
                0xff, 0xff,
            ]
            .as_mut_slice()
            .decode_with_constraint::<wire::OptionalVector<'_, wire::Uint32>>((1, ()))
            .is_err()
        );
    }

    #[test]
    fn encode_optional_vec() {
        assert_eq!(
            Vec::encode_with_constraint(None::<Vec<u32>>, (1000, ())).unwrap(),
            chunks![
                0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
                0x00, 0x00,
            ],
        );
    }
}