usb_vsock/connection/
overflow_writer.rs

// Copyright 2025 The Fuchsia Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
4
5use futures::io::WriteHalf;
6use futures::lock::Mutex;
7use futures::{AsyncWrite, FutureExt};
8use std::collections::VecDeque;
9use std::future::Future;
10use std::io::{Error, ErrorKind, IoSlice};
11use std::sync::Weak;
12use std::task::{ready, Context, Poll, Waker};
13
14use crate::connection::{ConnectionStateWriter, ConnectionStateWriterFut};
15
16/// A wrapper around the write half of a socket that never fails due to the
17/// socket being full. Instead it puts an infinitely-growing buffer in front of
18/// the socket so it can always accept data.
19///
20/// When the buffer is non-empty, some async process must poll the
21/// [`OverflowWriter::poll_handle_overflow`] function to keep bytes moving from
22/// the buffer into the actual socket.
23pub struct OverflowWriter<S> {
24    inner: WriteHalf<S>,
25    overflow: VecDeque<u8>,
26}
27
/// Status returned from [`OverflowWriter::write_all`].
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum OverflowWriterStatus {
    /// The write did not cause the writer to *start* buffering: either every
    /// byte went straight into the socket, or the writer was already buffering
    /// and the data was appended to the existing backlog.
    NoOverflow,
    /// The writer had no buffered data before this write, but the socket
    /// stopped accepting bytes and the remainder had to be buffered.
    Overflow,
}

impl OverflowWriterStatus {
    /// Whether the status indicates that the socket overflowed during the write.
    pub fn overflowed(&self) -> bool {
        matches!(self, OverflowWriterStatus::Overflow)
    }
}
41
42impl<S: AsyncWrite + Send + 'static> OverflowWriter<S> {
43    /// Make a new [`OverflowWriter`], wrapping the given socket write half.
44    pub fn new(inner: WriteHalf<S>) -> Self {
45        OverflowWriter { inner, overflow: VecDeque::new() }
46    }
47
48    /// Write all data into the writer. As much data as possible will drain into
49    /// the socket, but leftover data will be buffered and written later.
50    ///
51    /// If this write caused the socket to overflow, that is, if there was no
52    /// buffered data but this write made us *start* buffering data, we will
53    /// return [`OverflowWriterStatus::Overflow`].
54    pub fn write_all(&mut self, mut data: &[u8]) -> Result<OverflowWriterStatus, Error> {
55        if !self.overflow.is_empty() {
56            self.overflow.extend(data.iter());
57            // We only report when the socket *starts* to overflow. Piling on
58            // more overflow data isn't signaled in the return value.
59            return Ok(OverflowWriterStatus::NoOverflow);
60        }
61
62        let mut cx = Context::from_waker(Waker::noop());
63        loop {
64            match std::pin::Pin::new(&mut self.inner).poll_write(&mut cx, data) {
65                Poll::Ready(res) => {
66                    let res = res?;
67
68                    if res == 0 {
69                        return Err(ErrorKind::WriteZero.into());
70                    }
71
72                    if res == data.len() {
73                        return Ok(OverflowWriterStatus::NoOverflow);
74                    }
75
76                    data = &data[res..];
77                }
78                Poll::Pending => {
79                    self.overflow.extend(data.iter());
80                    return Ok(OverflowWriterStatus::Overflow);
81                }
82            };
83        }
84    }
85}
86
87/// Future that drains any backlogged data from an OverflowWriter.
88pub struct OverflowHandleFut<S> {
89    writer: Weak<Mutex<ConnectionStateWriter<S>>>,
90    guard_storage: Option<ConnectionStateWriterFut<S>>,
91}
92
93impl<S> OverflowHandleFut<S> {
94    /// Create a new future to drain all backlogged data from the given writer.
95    pub fn new(writer: Weak<Mutex<ConnectionStateWriter<S>>>) -> Self {
96        OverflowHandleFut { writer, guard_storage: None }
97    }
98}
99
100impl<S: AsyncWrite + Send + 'static> Future for OverflowHandleFut<S> {
101    type Output = Result<(), Error>;
102
103    fn poll(mut self: std::pin::Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
104        let lock_fut = if let Some(lock_fut) = &mut self.guard_storage {
105            lock_fut
106        } else {
107            let Some(payload_socket) = self.writer.upgrade() else {
108                return std::task::Poll::Ready(Ok(()));
109            };
110            self.guard_storage.insert(ConnectionStateWriter::wait_available(payload_socket))
111        };
112
113        let mut lock = ready!(lock_fut.poll_unpin(ctx));
114        let lock = &mut *lock;
115        self.guard_storage = None;
116        let ConnectionStateWriter::Available(lock) = &mut *lock else {
117            unreachable!("ConnectionStateWriterFut didn't wait until writer was available!");
118        };
119
120        while !lock.overflow.is_empty() {
121            let (data_a, data_b) = lock.overflow.as_slices();
122            let res = ready!(std::pin::Pin::new(&mut lock.inner)
123                .poll_write_vectored(ctx, &[IoSlice::new(data_a), IoSlice::new(data_b)]))?;
124
125            if res == 0 {
126                return Poll::Ready(Err(ErrorKind::WriteZero.into()));
127            }
128
129            lock.overflow.drain(..res);
130        }
131
132        Poll::Ready(Ok(()))
133    }
134}