usb_vsock/connection/
overflow_writer.rs1use futures::io::WriteHalf;
6use futures::lock::Mutex;
7use futures::{AsyncWrite, FutureExt};
8use std::collections::VecDeque;
9use std::future::Future;
10use std::io::{Error, ErrorKind, IoSlice};
11use std::sync::Weak;
12use std::task::{ready, Context, Poll, Waker};
13
14use crate::connection::{ConnectionStateWriter, ConnectionStateWriterFut};
15
/// Wraps the write half of a stream and keeps, in an internal buffer, any bytes that
/// could not be written to it without blocking.
///
/// Buffered bytes are written out later by `OverflowHandleFut`.
pub struct OverflowWriter<S> {
    /// The underlying writer that data is ultimately written to.
    inner: WriteHalf<S>,
    /// Bytes accepted by `write_all` but not yet written to `inner`, in submission order.
    /// `OverflowHandleFut` drains this from the front as `inner` accepts data.
    overflow: VecDeque<u8>,
}
27
/// Outcome of a successful [`OverflowWriter::write_all`] call.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum OverflowWriterStatus {
    /// This call did not start buffering. Either all data was written immediately, or the
    /// overflow buffer was already non-empty and the data was appended behind it.
    NoOverflow,
    /// This call could not write all of its data without blocking, and the remainder was
    /// placed into a previously empty overflow buffer.
    Overflow,
}

impl OverflowWriterStatus {
    /// Returns `true` if this status is [`OverflowWriterStatus::Overflow`].
    pub fn overflowed(&self) -> bool {
        *self == OverflowWriterStatus::Overflow
    }
}
41
42impl<S: AsyncWrite + Send + 'static> OverflowWriter<S> {
43 pub fn new(inner: WriteHalf<S>) -> Self {
45 OverflowWriter { inner, overflow: VecDeque::new() }
46 }
47
48 pub fn write_all(&mut self, mut data: &[u8]) -> Result<OverflowWriterStatus, Error> {
55 if !self.overflow.is_empty() {
56 self.overflow.extend(data.iter());
57 return Ok(OverflowWriterStatus::NoOverflow);
60 }
61
62 let mut cx = Context::from_waker(Waker::noop());
63 loop {
64 match std::pin::Pin::new(&mut self.inner).poll_write(&mut cx, data) {
65 Poll::Ready(res) => {
66 let res = res?;
67
68 if res == 0 {
69 return Err(ErrorKind::WriteZero.into());
70 }
71
72 if res == data.len() {
73 return Ok(OverflowWriterStatus::NoOverflow);
74 }
75
76 data = &data[res..];
77 }
78 Poll::Pending => {
79 self.overflow.extend(data.iter());
80 return Ok(OverflowWriterStatus::Overflow);
81 }
82 };
83 }
84 }
85}
86
/// Future that writes the contents of the connection writer's overflow buffer to the
/// underlying stream.
///
/// It resolves to `Ok(())` once the buffer is empty, or once the connection writer has been
/// dropped. A write error or a zero-length write resolves it to `Err`.
pub struct OverflowHandleFut<S> {
    /// Non-owning handle to the connection writer state. If upgrading fails when a new wait
    /// for the writer is about to start, the future completes with `Ok(())`.
    writer: Weak<Mutex<ConnectionStateWriter<S>>>,
    /// The in-progress wait for the writer to become available, kept across polls.
    /// `None` when no wait is outstanding.
    guard_storage: Option<ConnectionStateWriterFut<S>>,
}
92
93impl<S> OverflowHandleFut<S> {
94 pub fn new(writer: Weak<Mutex<ConnectionStateWriter<S>>>) -> Self {
96 OverflowHandleFut { writer, guard_storage: None }
97 }
98}
99
/// Drains the overflow buffer of the connection's `OverflowWriter`.
///
/// Each poll:
/// 1. Obtains the writer via `ConnectionStateWriter::wait_available`, or resumes a wait
///    started by an earlier poll.
/// 2. Writes buffered bytes with vectored writes until the buffer is empty (`Ready(Ok(()))`),
///    the writer returns `Pending`, or an error occurs.
///
/// When the write returns `Pending`, the acquired guard is dropped on return. The next poll
/// starts a fresh wait. NOTE(review): presumably the guard holds the
/// `Mutex<ConnectionStateWriter<S>>` lock, so this lets other users of the writer run between
/// polls. Meanwhile `write_all` appends to the still non-empty buffer.
impl<S: AsyncWrite + Send + 'static> Future for OverflowHandleFut<S> {
    type Output = Result<(), Error>;

    fn poll(mut self: std::pin::Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
        // Resume the outstanding wait if there is one; otherwise start a new one.
        let lock_fut = if let Some(lock_fut) = &mut self.guard_storage {
            lock_fut
        } else {
            // The connection writer no longer exists, so there is nothing left to drain into.
            let Some(payload_socket) = self.writer.upgrade() else {
                return std::task::Poll::Ready(Ok(()));
            };
            self.guard_storage.insert(ConnectionStateWriter::wait_available(payload_socket))
        };

        let mut lock = ready!(lock_fut.poll_unpin(ctx));
        let lock = &mut *lock;
        // The wait has completed; clear it so a later poll starts a new one.
        self.guard_storage = None;
        let ConnectionStateWriter::Available(lock) = &mut *lock else {
            unreachable!("ConnectionStateWriterFut didn't wait until writer was available!");
        };

        while !lock.overflow.is_empty() {
            // A VecDeque may store its contents as two slices; submit both in one
            // vectored write.
            let (data_a, data_b) = lock.overflow.as_slices();
            // On `Pending`, `ready!` returns early and the guard above is dropped.
            let res = ready!(std::pin::Pin::new(&mut lock.inner)
                .poll_write_vectored(ctx, &[IoSlice::new(data_a), IoSlice::new(data_b)]))?;

            // Zero bytes accepted from a non-empty buffer: treat as a failed write.
            if res == 0 {
                return Poll::Ready(Err(ErrorKind::WriteZero.into()));
            }

            // Remove exactly the bytes that were written from the front of the buffer.
            lock.overflow.drain(..res);
        }

        Poll::Ready(Ok(()))
    }
}
134}