fidl_next_codec/wire/vec/required.rs

1// Copyright 2024 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use core::marker::PhantomData;
6use core::mem::{forget, needs_drop, MaybeUninit};
7use core::ops::Deref;
8use core::ptr::{copy_nonoverlapping, NonNull};
9use core::{fmt, slice};
10
11use munge::munge;
12
13use super::raw::RawWireVector;
14use crate::{
15    Chunk, Decode, DecodeError, Decoder, DecoderExt as _, Encodable, Encode, EncodeError,
16    EncodeRef, Encoder, EncoderExt as _, FromWire, FromWireRef, Slot, Wire, WirePointer,
17};
18
/// A required FIDL vector in its wire form: an element count plus a pointer to the elements.
///
/// The pointer must be present; decoding a vector whose pointer is encoded as absent fails with
/// [`DecodeError::RequiredValueAbsent`]. Dropping a `WireVector` drops its elements in place.
// NOTE(review): `'de` presumably ties the elements to the decoder buffer they were decoded
// into — confirm against `RawWireVector`.
#[repr(transparent)]
pub struct WireVector<'de, T> {
    // `#[repr(transparent)]` gives `WireVector` exactly the layout of this single field.
    pub(crate) raw: RawWireVector<'de, T>,
}
24
// `WireVector` is itself a wire type. Its decoded form keeps the same vector shape but replaces
// each element type `T` with `T::Decoded<'de>`.
// NOTE(review): soundness of this `unsafe impl` presumably relies on `WireVector` being
// `#[repr(transparent)]` over `RawWireVector`, so padding handling can be delegated to that one
// field — confirm against the `Wire` trait's safety requirements.
unsafe impl<T: Wire> Wire for WireVector<'static, T> {
    type Decoded<'de> = WireVector<'de, T::Decoded<'de>>;

    /// Handles padding for `out` by projecting to its only field, `raw`, and delegating to
    /// `RawWireVector::zero_padding`.
    #[inline]
    fn zero_padding(out: &mut MaybeUninit<Self>) {
        munge!(let Self { raw } = out);
        RawWireVector::<T>::zero_padding(raw);
    }
}
34
35impl<T> Drop for WireVector<'_, T> {
36    fn drop(&mut self) {
37        if needs_drop::<T>() {
38            unsafe {
39                self.raw.as_slice_ptr().drop_in_place();
40            }
41        }
42    }
43}
44
45impl<T> WireVector<'_, T> {
46    /// Encodes that a vector is present in a slot.
47    pub fn encode_present(out: &mut MaybeUninit<Self>, len: u64) {
48        munge!(let Self { raw } = out);
49        RawWireVector::encode_present(raw, len);
50    }
51
52    /// Returns the length of the vector in elements.
53    pub fn len(&self) -> usize {
54        self.raw.len() as usize
55    }
56
57    /// Returns whether the vector is empty.
58    pub fn is_empty(&self) -> bool {
59        self.len() == 0
60    }
61
62    /// Returns a pointer to the elements of the vector.
63    fn as_slice_ptr(&self) -> NonNull<[T]> {
64        unsafe { NonNull::new_unchecked(self.raw.as_slice_ptr()) }
65    }
66
67    /// Returns a slice of the elements of the vector.
68    pub fn as_slice(&self) -> &[T] {
69        unsafe { self.as_slice_ptr().as_ref() }
70    }
71
72    /// Decodes a wire vector which contains raw data.
73    ///
74    /// # Safety
75    ///
76    /// The elements of the wire vector must not need to be individually decoded, and must always be
77    /// valid.
78    pub unsafe fn decode_raw<D>(
79        mut slot: Slot<'_, Self>,
80        mut decoder: &mut D,
81    ) -> Result<(), DecodeError>
82    where
83        D: Decoder + ?Sized,
84        T: Decode<D>,
85    {
86        munge!(let Self { raw: RawWireVector { len, mut ptr } } = slot.as_mut());
87
88        if !WirePointer::is_encoded_present(ptr.as_mut())? {
89            return Err(DecodeError::RequiredValueAbsent);
90        }
91
92        let mut slice = decoder.take_slice_slot::<T>(**len as usize)?;
93        WirePointer::set_decoded(ptr, slice.as_mut_ptr().cast());
94
95        Ok(())
96    }
97}
98
99/// An iterator over the items of a `WireVector`.
100pub struct IntoIter<'de, T> {
101    current: *mut T,
102    remaining: usize,
103    _phantom: PhantomData<&'de mut [Chunk]>,
104}
105
106impl<T> Drop for IntoIter<'_, T> {
107    fn drop(&mut self) {
108        for i in 0..self.remaining {
109            unsafe {
110                self.current.add(i).drop_in_place();
111            }
112        }
113    }
114}
115
116impl<T> Iterator for IntoIter<'_, T> {
117    type Item = T;
118
119    fn next(&mut self) -> Option<Self::Item> {
120        if self.remaining == 0 {
121            None
122        } else {
123            let result = unsafe { self.current.read() };
124            self.current = unsafe { self.current.add(1) };
125            self.remaining -= 1;
126            Some(result)
127        }
128    }
129}
130
131impl<'de, T> IntoIterator for WireVector<'de, T> {
132    type IntoIter = IntoIter<'de, T>;
133    type Item = T;
134
135    fn into_iter(self) -> Self::IntoIter {
136        let current = self.raw.as_ptr();
137        let remaining = self.len();
138        forget(self);
139
140        IntoIter { current, remaining, _phantom: PhantomData }
141    }
142}
143
144impl<T> Deref for WireVector<'_, T> {
145    type Target = [T];
146
147    fn deref(&self) -> &Self::Target {
148        self.as_slice()
149    }
150}
151
152impl<T: fmt::Debug> fmt::Debug for WireVector<'_, T> {
153    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
154        self.as_slice().fmt(f)
155    }
156}
157
158unsafe impl<D: Decoder + ?Sized, T: Decode<D>> Decode<D> for WireVector<'static, T> {
159    fn decode(mut slot: Slot<'_, Self>, mut decoder: &mut D) -> Result<(), DecodeError> {
160        munge!(let Self { raw: RawWireVector { len, mut ptr } } = slot.as_mut());
161
162        if !WirePointer::is_encoded_present(ptr.as_mut())? {
163            return Err(DecodeError::RequiredValueAbsent);
164        }
165
166        let mut slice = decoder.take_slice_slot::<T>(**len as usize)?;
167        for i in 0..**len as usize {
168            T::decode(slice.index(i), decoder)?;
169        }
170        WirePointer::set_decoded(ptr, slice.as_mut_ptr().cast());
171
172        Ok(())
173    }
174}
175
176#[inline]
177fn encode_to_vector<V, E, T>(
178    value: V,
179    encoder: &mut E,
180    out: &mut MaybeUninit<WireVector<'_, T::Encoded>>,
181) -> Result<(), EncodeError>
182where
183    V: AsRef<[T]> + IntoIterator,
184    V::IntoIter: ExactSizeIterator,
185    V::Item: Encode<E, Encoded = T::Encoded>,
186    E: Encoder + ?Sized,
187    T: Encode<E>,
188{
189    let len = value.as_ref().len();
190    if T::COPY_OPTIMIZATION.is_enabled() {
191        let slice = value.as_ref();
192        // SAFETY: `T` has copy optimization enabled, which guarantees that it has no uninit bytes
193        // and can be copied directly to the output instead of calling `encode`. This means that we
194        // may cast `&[T]` to `&[u8]` and write those bytes.
195        let bytes = unsafe { slice::from_raw_parts(slice.as_ptr().cast(), size_of_val(slice)) };
196        encoder.write(bytes);
197    } else {
198        encoder.encode_next_iter(value.into_iter())?;
199    }
200    WireVector::encode_present(out, len as u64);
201    Ok(())
202}
203
204impl<T: Encodable> Encodable for Vec<T> {
205    type Encoded = WireVector<'static, T::Encoded>;
206}
207
208unsafe impl<E, T> Encode<E> for Vec<T>
209where
210    E: Encoder + ?Sized,
211    T: Encode<E>,
212{
213    fn encode(
214        self,
215        encoder: &mut E,
216        out: &mut MaybeUninit<Self::Encoded>,
217    ) -> Result<(), EncodeError> {
218        encode_to_vector(self, encoder, out)
219    }
220}
221
222unsafe impl<E, T> EncodeRef<E> for Vec<T>
223where
224    E: Encoder + ?Sized,
225    T: EncodeRef<E>,
226{
227    fn encode_ref(
228        &self,
229        encoder: &mut E,
230        out: &mut MaybeUninit<Self::Encoded>,
231    ) -> Result<(), EncodeError> {
232        encode_to_vector(self, encoder, out)
233    }
234}
235
236impl<T: Encodable> Encodable for &[T] {
237    type Encoded = WireVector<'static, T::Encoded>;
238}
239
240unsafe impl<E, T> Encode<E> for &[T]
241where
242    E: Encoder + ?Sized,
243    T: EncodeRef<E>,
244{
245    fn encode(
246        self,
247        encoder: &mut E,
248        out: &mut MaybeUninit<Self::Encoded>,
249    ) -> Result<(), EncodeError> {
250        encode_to_vector(self, encoder, out)
251    }
252}
253
254impl<T: FromWire<W>, W> FromWire<WireVector<'_, W>> for Vec<T> {
255    fn from_wire(wire: WireVector<'_, W>) -> Self {
256        let mut result = Vec::<T>::with_capacity(wire.len());
257        if T::COPY_OPTIMIZATION.is_enabled() {
258            unsafe {
259                copy_nonoverlapping(wire.as_ptr().cast(), result.as_mut_ptr(), wire.len());
260            }
261            unsafe {
262                result.set_len(wire.len());
263            }
264            forget(wire);
265        } else {
266            for item in wire.into_iter() {
267                result.push(T::from_wire(item));
268            }
269        }
270        result
271    }
272}
273
274impl<T: FromWireRef<W>, W> FromWireRef<WireVector<'_, W>> for Vec<T> {
275    fn from_wire_ref(wire: &WireVector<'_, W>) -> Self {
276        let mut result = Vec::<T>::with_capacity(wire.len());
277        if T::COPY_OPTIMIZATION.is_enabled() {
278            unsafe {
279                copy_nonoverlapping(wire.as_ptr().cast(), result.as_mut_ptr(), wire.len());
280            }
281            unsafe {
282                result.set_len(wire.len());
283            }
284        } else {
285            for item in wire.iter() {
286                result.push(T::from_wire_ref(item));
287            }
288        }
289        result
290    }
291}