usb_vsock/connection/overflow_writer.rs

1// Copyright 2025 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use futures::io::WriteHalf;
6use futures::lock::{Mutex, OwnedMutexLockFuture};
7use futures::{AsyncWrite, FutureExt};
8use std::collections::VecDeque;
9use std::future::Future;
10use std::io::{Error, ErrorKind, IoSlice};
11use std::sync::Weak;
12use std::task::{ready, Context, Poll, Waker};
13
14/// A wrapper around the write half of a socket that never fails due to the
15/// socket being full. Instead it puts an infinitely-growing buffer in front of
16/// the socket so it can always accept data.
17///
18/// When the buffer is non-empty, some async process must poll the
19/// [`OverflowWriter::poll_handle_overflow`] function to keep bytes moving from
20/// the buffer into the actual socket.
21pub struct OverflowWriter<S> {
22    inner: WriteHalf<S>,
23    overflow: VecDeque<u8>,
24}
25
/// Status returned from [`OverflowWriter::write_all`].
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub enum OverflowWriterStatus {
    /// The write did not make the writer *start* buffering. Either all data
    /// went straight to the socket, or the writer was already buffering and
    /// the data was appended to the existing backlog.
    NoOverflow,
    /// The writer had no buffered data before this write, and this write made
    /// it start buffering.
    Overflow,
}

impl OverflowWriterStatus {
    /// Whether the status indicates that the socket overflowed during the write.
    pub fn overflowed(&self) -> bool {
        matches!(self, OverflowWriterStatus::Overflow)
    }
}
39
40impl<S: AsyncWrite + Send + 'static> OverflowWriter<S> {
41    /// Make a new [`OverflowWriter`], wrapping the given socket write half.
42    pub fn new(inner: WriteHalf<S>) -> Self {
43        OverflowWriter { inner, overflow: VecDeque::new() }
44    }
45
46    /// Write all data into the writer. As much data as possible will drain into
47    /// the socket, but leftover data will be buffered and written later.
48    ///
49    /// If this write caused the socket to overflow, that is, if there was no
50    /// buffered data but this write made us *start* buffering data, we will
51    /// return [`OverflowWriterStatus::Overflow`].
52    pub fn write_all(&mut self, mut data: &[u8]) -> Result<OverflowWriterStatus, Error> {
53        if !self.overflow.is_empty() {
54            self.overflow.extend(data.iter());
55            // We only report when the socket *starts* to overflow. Piling on
56            // more overflow data isn't signaled in the return value.
57            return Ok(OverflowWriterStatus::NoOverflow);
58        }
59
60        let mut cx = Context::from_waker(Waker::noop());
61        loop {
62            match std::pin::Pin::new(&mut self.inner).poll_write(&mut cx, data) {
63                Poll::Ready(res) => {
64                    let res = res?;
65
66                    if res == 0 {
67                        return Err(ErrorKind::WriteZero.into());
68                    }
69
70                    if res == data.len() {
71                        return Ok(OverflowWriterStatus::NoOverflow);
72                    }
73
74                    data = &data[res..];
75                }
76                Poll::Pending => {
77                    self.overflow.extend(data.iter());
78                    return Ok(OverflowWriterStatus::Overflow);
79                }
80            };
81        }
82    }
83}
84
85/// Future that drains any backlogged data from an OverflowWriter.
86pub struct OverflowHandleFut<S> {
87    writer: Weak<Mutex<OverflowWriter<S>>>,
88    guard_storage: Option<OwnedMutexLockFuture<OverflowWriter<S>>>,
89}
90
91impl<S> OverflowHandleFut<S> {
92    /// Create a new future to drain all backlogged data from the given writer.
93    pub fn new(writer: Weak<Mutex<OverflowWriter<S>>>) -> Self {
94        OverflowHandleFut { writer, guard_storage: None }
95    }
96}
97
98impl<S: AsyncWrite + Send + 'static> Future for OverflowHandleFut<S> {
99    type Output = Result<(), Error>;
100
101    fn poll(mut self: std::pin::Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
102        let lock_fut = if let Some(lock_fut) = &mut self.guard_storage {
103            lock_fut
104        } else {
105            let Some(payload_socket) = self.writer.upgrade() else {
106                return std::task::Poll::Ready(Ok(()));
107            };
108            self.guard_storage.insert(payload_socket.lock_owned())
109        };
110
111        let mut lock = ready!(lock_fut.poll_unpin(ctx));
112        let lock = &mut *lock;
113        self.guard_storage = None;
114
115        while !lock.overflow.is_empty() {
116            let (data_a, data_b) = lock.overflow.as_slices();
117            let res = ready!(std::pin::Pin::new(&mut lock.inner)
118                .poll_write_vectored(ctx, &[IoSlice::new(data_a), IoSlice::new(data_b)]))?;
119
120            if res == 0 {
121                return Poll::Ready(Err(ErrorKind::WriteZero.into()));
122            }
123
124            lock.overflow.drain(..res);
125        }
126
127        Poll::Ready(Ok(()))
128    }
129}