usb_vsock/connection/overflow_writer.rs

// Copyright 2025 The Fuchsia Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

use fuchsia_async::Socket;
use futures::io::WriteHalf;
use futures::lock::{Mutex, OwnedMutexLockFuture};
use futures::{AsyncWrite, FutureExt};
use std::collections::VecDeque;
use std::future::Future;
use std::io::{Error, ErrorKind, IoSlice};
use std::sync::Weak;
use std::task::{ready, Context, Poll, Waker};

/// A wrapper around the write half of a socket that never fails due to the
/// socket being full. Instead, it puts an unbounded buffer in front of the
/// socket so it can always accept data.
///
/// When the buffer is non-empty, some async process must poll an
/// [`OverflowHandleFut`] to keep bytes moving from the buffer into the actual
/// socket.
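///
/// A minimal usage sketch (illustrative only; it assumes a
/// [`fuchsia_async::Socket`] named `socket` that is split via
/// `futures::AsyncReadExt::split`):
///
/// ```ignore
/// use futures::AsyncReadExt as _;
///
/// let (_read_half, write_half) = socket.split();
/// let mut writer = OverflowWriter::new(write_half);
///
/// // Writes never fail because the socket is full; they start buffering instead.
/// if writer.write_all(b"hello")?.overflowed() {
///     // The socket just filled up; an `OverflowHandleFut` must now be polled
///     // to drain the buffered bytes into the socket.
/// }
/// ```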
pub struct OverflowWriter {
    inner: WriteHalf<Socket>,
    overflow: VecDeque<u8>,
}

/// Status returned from [`OverflowWriter::write_all`].
#[derive(Copy, Clone, PartialEq, Eq)]
pub enum OverflowWriterStatus {
    NoOverflow,
    Overflow,
}

impl OverflowWriterStatus {
    /// Whether the status indicates that the socket overflowed during the write.
    pub fn overflowed(&self) -> bool {
        matches!(self, OverflowWriterStatus::Overflow)
    }
}

impl OverflowWriter {
    /// Make a new [`OverflowWriter`], wrapping the given socket write half.
    pub fn new(inner: WriteHalf<Socket>) -> Self {
        OverflowWriter { inner, overflow: VecDeque::new() }
    }

    /// Write all data into the writer. As much data as possible will drain into
    /// the socket, but leftover data will be buffered and written later.
    ///
    /// If this write caused the socket to overflow, that is, if there was no
    /// buffered data but this write made us *start* buffering data, we will
    /// return [`OverflowWriterStatus::Overflow`].
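    ///
    /// A sketch of the intended pattern (illustrative; `writer`, `spawn_drain`,
    /// and the chunks are placeholders):
    ///
    /// ```ignore
    /// if writer.write_all(first_chunk)?.overflowed() {
    ///     // The socket just filled up: start a drain task exactly once.
    ///     spawn_drain();
    /// }
    /// // While the backlog exists, further writes are appended to it and
    /// // report `NoOverflow`, so the drain task is not spawned again.
    /// assert!(!writer.write_all(second_chunk)?.overflowed());
    /// ```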
    pub fn write_all(&mut self, mut data: &[u8]) -> Result<OverflowWriterStatus, Error> {
        if !self.overflow.is_empty() {
            self.overflow.extend(data.iter());
            // We only report when the socket *starts* to overflow. Piling on
            // more overflow data isn't signaled in the return value.
            return Ok(OverflowWriterStatus::NoOverflow);
        }

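        // Poll the inner writer with a no-op waker: this method never suspends,
        // so there is nothing to wake. A `Pending` result simply means the
        // socket is full right now, in which case we buffer the remainder.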
        let mut cx = Context::from_waker(Waker::noop());
        loop {
            match std::pin::Pin::new(&mut self.inner).poll_write(&mut cx, data) {
                Poll::Ready(res) => {
                    let res = res?;

                    if res == 0 {
                        return Err(ErrorKind::WriteZero.into());
                    }

                    if res == data.len() {
                        return Ok(OverflowWriterStatus::NoOverflow);
                    }

                    data = &data[res..];
                }
                Poll::Pending => {
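                    // The socket has no room for the rest of the data. Stash it
                    // in the overflow buffer and tell the caller that overflow
                    // has just started.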
                    self.overflow.extend(data.iter());
                    return Ok(OverflowWriterStatus::Overflow);
                }
            };
        }
    }
}

/// Future that drains any backlogged data from an [`OverflowWriter`].
///
/// Each poll locks the writer and writes as much of the buffered data as the
/// socket will accept, completing once the buffer is empty. If the socket
/// fills up again, the lock is released until the socket becomes writable.
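///
/// A hedged sketch of driving the future on a background task (the
/// `fuchsia_async::Task` usage and the shared `writer` handle are illustrative
/// assumptions, not part of this module):
///
/// ```ignore
/// if writer.lock().await.write_all(data)?.overflowed() {
///     let drain = OverflowHandleFut::new(std::sync::Arc::downgrade(&writer));
///     fuchsia_async::Task::spawn(async move {
///         let _ = drain.await;
///     })
///     .detach();
/// }
/// ```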
pub struct OverflowHandleFut {
    writer: Weak<Mutex<OverflowWriter>>,
    guard_storage: Option<OwnedMutexLockFuture<OverflowWriter>>,
}

impl OverflowHandleFut {
    /// Create a new future to drain all backlogged data from the given writer.
    pub fn new(writer: Weak<Mutex<OverflowWriter>>) -> Self {
        OverflowHandleFut { writer, guard_storage: None }
    }
}

impl Future for OverflowHandleFut {
    type Output = Result<(), Error>;

    fn poll(mut self: std::pin::Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
        let lock_fut = if let Some(lock_fut) = &mut self.guard_storage {
            lock_fut
        } else {
            let Some(payload_socket) = self.writer.upgrade() else {
                return std::task::Poll::Ready(Ok(()));
            };
            self.guard_storage.insert(payload_socket.lock_owned())
        };

        let mut lock = ready!(lock_fut.poll_unpin(ctx));
        let lock = &mut *lock;
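        // The lock future has been consumed; clear it so the next poll starts a
        // fresh lock acquisition instead of re-polling a completed future.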
        self.guard_storage = None;

        while !lock.overflow.is_empty() {
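            // The VecDeque's storage may wrap around, so it exposes up to two
            // contiguous slices. Write both with a single vectored write.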
            let (data_a, data_b) = lock.overflow.as_slices();
            let res = ready!(std::pin::Pin::new(&mut lock.inner)
                .poll_write_vectored(ctx, &[IoSlice::new(data_a), IoSlice::new(data_b)]))?;

            if res == 0 {
                return Poll::Ready(Err(ErrorKind::WriteZero.into()));
            }

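            // Drop only the bytes the socket actually accepted; anything left
            // stays queued for the next loop iteration.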
            lock.overflow.drain(..res);
        }

        Poll::Ready(Ok(()))
    }
}