1use crate::file::{AsyncGetSize, AsyncReadAt};
6use fidl_fuchsia_io as fio;
7use pin_project::pin_project;
8use std::cmp::min;
9use std::convert::TryInto as _;
10use std::pin::Pin;
11use std::task::{Context, Poll};
12
/// Overflow-checked arithmetic on `usize` whose failures surface as [`std::io::Error`]s, so
/// callers inside I/O code can propagate them with `?`.
trait UsizeExt {
    /// Returns `self + rhs`, or an `io::Error` if the sum does not fit in a `usize`.
    fn add(self, rhs: usize) -> Result<usize, std::io::Error>;
}

impl UsizeExt for usize {
    fn add(self, rhs: usize) -> Result<usize, std::io::Error> {
        match self.checked_add(rhs) {
            Some(sum) => Ok(sum),
            None => Err(std::io::Error::other("usize addition overflowed")),
        }
    }
}
23
24fn u64_to_usize_safe(u: u64) -> usize {
25 let ret: usize = u.try_into().unwrap();
26 static_assertions::assert_eq_size_val!(u, ret);
27 ret
28}
29
/// Wraps a reader `T` and keeps the bytes returned by the most recent read forwarded to it.
///
/// Later reads whose starting offset falls inside that cached window are served from memory
/// without contacting `wrapped`.
#[pin_project]
pub struct BufferedAsyncReadAt<T> {
    /// The underlying object that reads are forwarded to on a cache miss.
    #[pin]
    wrapped: T,
    /// Offset within `wrapped` of the first cached byte (`cache[0]`).
    offset: usize,
    /// Number of valid bytes at the start of `cache`; 0 means nothing is cached.
    len: usize,
    /// Cache storage, one `fio::MAX_TRANSFER_SIZE` buffer. `None` until the first read
    /// allocates it.
    cache: Option<Box<[u8; fio::MAX_TRANSFER_SIZE as usize]>>,
}
51
52impl<T> BufferedAsyncReadAt<T> {
53 pub fn new(wrapped: T) -> Self {
54 Self { wrapped, offset: 0, len: 0, cache: None }
55 }
56}
57
58impl<T: AsyncReadAt> AsyncReadAt for BufferedAsyncReadAt<T> {
59 fn poll_read_at(
60 self: Pin<&mut Self>,
61 cx: &mut Context<'_>,
62 offset_u64: u64,
63 buf: &mut [u8],
64 ) -> Poll<std::io::Result<usize>> {
65 let this = self.project();
66 let offset = u64_to_usize_safe(offset_u64);
67
68 let cache =
69 this.cache.get_or_insert_with(|| Box::new([0u8; fio::MAX_TRANSFER_SIZE as usize]));
70
71 if *this.offset <= offset && this.offset.add(*this.len)? > offset {
72 let start = offset - *this.offset;
73 let n = min(buf.len(), *this.len - start);
74 let () = buf[..n].copy_from_slice(&cache[start..start + n]);
75 return Poll::Ready(Ok(n));
76 }
77
78 match this.wrapped.poll_read_at(cx, offset_u64, &mut cache[..]) {
81 Poll::Pending => return Poll::Pending,
82 Poll::Ready(Ok(len)) => {
83 *this.offset = offset;
84 *this.len = len;
85 let n = min(len, buf.len());
86 let () = buf[..n].copy_from_slice(&cache[..n]);
87 return Poll::Ready(Ok(n));
88 }
89 p @ Poll::Ready(_) => {
90 return p;
91 }
92 }
93 }
94}
95
96impl<T: AsyncGetSize> AsyncGetSize for BufferedAsyncReadAt<T> {
97 fn poll_get_size(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<u64>> {
98 let this = self.project();
99 this.wrapped.poll_get_size(cx)
100 }
101}
102
#[cfg(test)]
mod tests {
    use super::*;
    use crate::file::{AsyncGetSizeExt as _, AsyncReadAtExt as _};
    use assert_matches::assert_matches;
    use std::cell::RefCell;
    use std::convert::TryFrom as _;
    use std::rc::Rc;

    #[test]
    fn max_transfer_size_fits_in_usize() {
        assert_eq!(
            fio::MAX_TRANSFER_SIZE,
            u64::try_from(usize::try_from(fio::MAX_TRANSFER_SIZE).unwrap()).unwrap()
        );
    }

    #[test]
    fn usize_ext_add() {
        assert_eq!(0usize.add(1).unwrap(), 1);
        assert_matches!(usize::MAX.add(1), Err(_));
    }

    #[test]
    fn u64_to_usize_safe() {
        assert_eq!(super::u64_to_usize_safe(5u64), 5usize);
    }

    #[fuchsia_async::run_singlethreaded(test)]
    async fn poll_get_size_forwards() {
        struct Mock {
            called: bool,
        }
        impl AsyncGetSize for Mock {
            fn poll_get_size(
                mut self: Pin<&mut Self>,
                _: &mut Context<'_>,
            ) -> Poll<std::io::Result<u64>> {
                self.called = true;
                Poll::Ready(Ok(3))
            }
        }

        let mut reader = BufferedAsyncReadAt::new(Mock { called: false });

        assert_matches!(reader.get_size().await, Ok(3));
        assert!(reader.wrapped.called);
    }

    /// In-memory reader that records the offset of every read that reaches it.
    ///
    /// Tests use it to tell cache hits (nothing recorded) from forwarded reads.
    struct Mock {
        recorded_offsets: Rc<RefCell<Vec<u64>>>,
        content: Vec<u8>,
    }
    impl Mock {
        /// Returns the mock plus a shared handle to the offsets it records.
        fn new(content: Vec<u8>) -> (Self, Rc<RefCell<Vec<u64>>>) {
            let recorded_offsets = Rc::new(RefCell::new(vec![]));
            (Self { recorded_offsets: recorded_offsets.clone(), content }, recorded_offsets)
        }
    }
    impl AsyncReadAt for Mock {
        fn poll_read_at(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            offset: u64,
            buf: &mut [u8],
        ) -> Poll<std::io::Result<usize>> {
            self.recorded_offsets.borrow_mut().push(offset);
            let offset = super::u64_to_usize_safe(offset);
            // BufferedAsyncReadAt always hands over its full cache buffer.
            assert_eq!(buf.len(), usize::try_from(fio::MAX_TRANSFER_SIZE).unwrap());
            let start = std::cmp::min(offset, self.content.len());
            let n = std::cmp::min(buf.len(), self.content.len() - start);
            let end = start + n;
            buf[..n].copy_from_slice(&self.content[start..end]);
            Poll::Ready(Ok(n))
        }
    }

    #[fuchsia_async::run_singlethreaded(test)]
    async fn poll_read_at_uses_cache() {
        let (mock, offsets) = Mock::new(vec![0, 1, 2, 3, 4]);
        let mut reader = BufferedAsyncReadAt::new(mock);

        // The first read misses and fills the cache with content[1..].
        let mut buf = vec![5; 3];
        let bytes_read = reader.read_at(1, buf.as_mut_slice()).await.unwrap();

        assert_eq!(bytes_read, 3);
        assert_eq!(buf, vec![1, 2, 3]);
        assert_eq!(*offsets.borrow(), vec![1]);

        offsets.borrow_mut().clear();

        // Same offset, shorter buffer: served from the cache.
        let mut buf = vec![5; 2];
        let bytes_read = reader.read_at(1, buf.as_mut_slice()).await.unwrap();

        assert_eq!(bytes_read, 2);
        assert_eq!(buf, vec![1, 2]);
        assert_eq!(*offsets.borrow(), Vec::<u64>::new());

        // Offset strictly inside the cached window.
        let mut buf = vec![5; 2];
        let bytes_read = reader.read_at(2, buf.as_mut_slice()).await.unwrap();

        assert_eq!(bytes_read, 2);
        assert_eq!(buf, vec![2, 3]);
        assert_eq!(*offsets.borrow(), Vec::<u64>::new());

        // A larger buffer than the first request that still fits in the cached window.
        let mut buf = vec![5; 4];
        let bytes_read = reader.read_at(1, buf.as_mut_slice()).await.unwrap();

        assert_eq!(bytes_read, 4);
        assert_eq!(buf, vec![1, 2, 3, 4]);
        assert_eq!(*offsets.borrow(), Vec::<u64>::new());

        // The last cached byte.
        let mut buf = vec![5; 1];
        let bytes_read = reader.read_at(4, buf.as_mut_slice()).await.unwrap();

        assert_eq!(bytes_read, 1);
        assert_eq!(buf, vec![4]);
        assert_eq!(*offsets.borrow(), Vec::<u64>::new());

        // The request runs past the cached window: short read from the cache, nothing forwarded.
        let mut buf = vec![5; 3];
        let bytes_read = reader.read_at(3, buf.as_mut_slice()).await.unwrap();

        assert_eq!(bytes_read, 2);
        assert_eq!(buf, vec![3, 4, 5]);
        assert_eq!(*offsets.borrow(), Vec::<u64>::new());
    }

    #[fuchsia_async::run_singlethreaded(test)]
    async fn poll_read_at_forwards() {
        // One full transfer of the repeating pattern 0..8, followed by a single trailing 8.
        let content = (0u8..8)
            .cycle()
            .take(fio::MAX_TRANSFER_SIZE.try_into().unwrap())
            .chain([8])
            .collect();
        let (mock, offsets) = Mock::new(content);
        let mut reader = BufferedAsyncReadAt::new(mock);

        let mut buf = vec![9; 1];
        let bytes_read = reader.read_at(1, buf.as_mut_slice()).await.unwrap();

        assert_eq!(bytes_read, 1);
        assert_eq!(buf, vec![1]);
        assert_eq!(*offsets.borrow(), vec![1]);

        offsets.borrow_mut().clear();

        // An offset before the cached window is forwarded.
        let mut buf = vec![9; 1];
        let bytes_read = reader.read_at(0, buf.as_mut_slice()).await.unwrap();

        assert_eq!(bytes_read, 1);
        assert_eq!(buf, vec![0]);
        assert_eq!(*offsets.borrow(), vec![0]);

        offsets.borrow_mut().clear();

        // The offset just past the cached window is forwarded.
        let mut buf = vec![9; 1];
        let bytes_read = reader.read_at(fio::MAX_TRANSFER_SIZE, buf.as_mut_slice()).await.unwrap();

        assert_eq!(bytes_read, 1);
        assert_eq!(buf, vec![8]);
        assert_eq!(*offsets.borrow(), vec![fio::MAX_TRANSFER_SIZE]);
    }

    #[fuchsia_async::run_singlethreaded(test)]
    async fn poll_read_at_requested_range_ends_beyond_content() {
        let (mock, offsets) = Mock::new(vec![0, 1, 2]);
        let mut reader = BufferedAsyncReadAt::new(mock);

        let mut buf = vec![3; 5];
        let bytes_read = reader.read_at(0, buf.as_mut_slice()).await.unwrap();

        assert_eq!(bytes_read, 3);
        assert_eq!(buf, vec![0, 1, 2, 3, 3]);
        assert_eq!(*offsets.borrow(), vec![0]);
    }

    #[fuchsia_async::run_singlethreaded(test)]
    async fn poll_read_at_requested_range_starts_beyond_content() {
        let (mock, offsets) = Mock::new(vec![0, 1, 2]);
        let mut reader = BufferedAsyncReadAt::new(mock);

        let mut buf = vec![3; 5];
        let bytes_read = reader.read_at(3, buf.as_mut_slice()).await.unwrap();

        assert_eq!(bytes_read, 0);
        assert_eq!(buf, vec![3, 3, 3, 3, 3]);
        assert_eq!(*offsets.borrow(), vec![3]);
    }

    #[fuchsia_async::run_singlethreaded(test)]
    async fn poll_read_at_forwards_error() {
        struct Mock;
        impl AsyncReadAt for Mock {
            fn poll_read_at(
                self: Pin<&mut Self>,
                _cx: &mut Context<'_>,
                _offset: u64,
                _buf: &mut [u8],
            ) -> Poll<std::io::Result<usize>> {
                Poll::Ready(Err(std::io::Error::other("BufferedAsyncReadAt forwarded the error")))
            }
        }

        let mut reader = BufferedAsyncReadAt::new(Mock);

        let mut buf = vec![0u8; 1];
        let err = reader.read_at(0, buf.as_mut_slice()).await.unwrap_err();

        assert_eq!(err.to_string(), "BufferedAsyncReadAt forwarded the error");
    }

    #[fuchsia_async::run_singlethreaded(test)]
    async fn poll_read_at_forwards_pending() {
        struct Mock;
        impl AsyncReadAt for Mock {
            fn poll_read_at(
                self: Pin<&mut Self>,
                _cx: &mut Context<'_>,
                _offset: u64,
                _buf: &mut [u8],
            ) -> Poll<std::io::Result<usize>> {
                Poll::Pending
            }
        }

        /// Future that polls the reader exactly once and checks that `Pending` is passed through.
        #[pin_project]
        struct VerifyPending {
            #[pin]
            object_under_test: BufferedAsyncReadAt<Mock>,
        }
        impl futures::future::Future for VerifyPending {
            type Output = ();
            fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
                let this = self.project();
                let res = this.object_under_test.poll_read_at(cx, 0, &mut [0]);
                assert_matches!(res, Poll::Pending);
                Poll::Ready(())
            }
        }

        let reader = BufferedAsyncReadAt::new(Mock);
        let verifier = VerifyPending { object_under_test: reader };

        let () = verifier.await;
    }
}