usb_vsock/connection/
overflow_writer.rs1use futures::io::WriteHalf;
6use futures::lock::{Mutex, OwnedMutexLockFuture};
7use futures::{AsyncWrite, FutureExt};
8use std::collections::VecDeque;
9use std::future::Future;
10use std::io::{Error, ErrorKind, IoSlice};
11use std::sync::Weak;
12use std::task::{ready, Context, Poll, Waker};
13
14pub struct OverflowWriter<S> {
22 inner: WriteHalf<S>,
23 overflow: VecDeque<u8>,
24}
25
/// Result of a single `OverflowWriter::write_all` call.
#[derive(Copy, Clone, PartialEq, Eq)]
pub enum OverflowWriterStatus {
    /// The call did not start a new overflow: either the inner writer took
    /// everything, or the bytes were queued behind data that was already
    /// waiting in the overflow buffer.
    NoOverflow,
    /// The inner writer was not ready, so the call began buffering bytes in
    /// the overflow buffer.
    Overflow,
}

impl OverflowWriterStatus {
    /// Reports whether this status is [`OverflowWriterStatus::Overflow`].
    pub fn overflowed(&self) -> bool {
        match self {
            Self::Overflow => true,
            Self::NoOverflow => false,
        }
    }
}
39
40impl<S: AsyncWrite + Send + 'static> OverflowWriter<S> {
41 pub fn new(inner: WriteHalf<S>) -> Self {
43 OverflowWriter { inner, overflow: VecDeque::new() }
44 }
45
46 pub fn write_all(&mut self, mut data: &[u8]) -> Result<OverflowWriterStatus, Error> {
53 if !self.overflow.is_empty() {
54 self.overflow.extend(data.iter());
55 return Ok(OverflowWriterStatus::NoOverflow);
58 }
59
60 let mut cx = Context::from_waker(Waker::noop());
61 loop {
62 match std::pin::Pin::new(&mut self.inner).poll_write(&mut cx, data) {
63 Poll::Ready(res) => {
64 let res = res?;
65
66 if res == 0 {
67 return Err(ErrorKind::WriteZero.into());
68 }
69
70 if res == data.len() {
71 return Ok(OverflowWriterStatus::NoOverflow);
72 }
73
74 data = &data[res..];
75 }
76 Poll::Pending => {
77 self.overflow.extend(data.iter());
78 return Ok(OverflowWriterStatus::Overflow);
79 }
80 };
81 }
82 }
83}
84
85pub struct OverflowHandleFut<S> {
87 writer: Weak<Mutex<OverflowWriter<S>>>,
88 guard_storage: Option<OwnedMutexLockFuture<OverflowWriter<S>>>,
89}
90
91impl<S> OverflowHandleFut<S> {
92 pub fn new(writer: Weak<Mutex<OverflowWriter<S>>>) -> Self {
94 OverflowHandleFut { writer, guard_storage: None }
95 }
96}
97
98impl<S: AsyncWrite + Send + 'static> Future for OverflowHandleFut<S> {
99 type Output = Result<(), Error>;
100
101 fn poll(mut self: std::pin::Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
102 let lock_fut = if let Some(lock_fut) = &mut self.guard_storage {
103 lock_fut
104 } else {
105 let Some(payload_socket) = self.writer.upgrade() else {
106 return std::task::Poll::Ready(Ok(()));
107 };
108 self.guard_storage.insert(payload_socket.lock_owned())
109 };
110
111 let mut lock = ready!(lock_fut.poll_unpin(ctx));
112 let lock = &mut *lock;
113 self.guard_storage = None;
114
115 while !lock.overflow.is_empty() {
116 let (data_a, data_b) = lock.overflow.as_slices();
117 let res = ready!(std::pin::Pin::new(&mut lock.inner)
118 .poll_write_vectored(ctx, &[IoSlice::new(data_a), IoSlice::new(data_b)]))?;
119
120 if res == 0 {
121 return Poll::Ready(Err(ErrorKind::WriteZero.into()));
122 }
123
124 lock.overflow.drain(..res);
125 }
126
127 Poll::Ready(Ok(()))
128 }
129}