usb_vsock/connection/
overflow_writer.rs1use fuchsia_async::Socket;
6use futures::io::WriteHalf;
7use futures::lock::{Mutex, OwnedMutexLockFuture};
8use futures::{AsyncWrite, FutureExt};
9use std::collections::VecDeque;
10use std::future::Future;
11use std::io::{Error, ErrorKind, IoSlice};
12use std::sync::Weak;
13use std::task::{ready, Context, Poll, Waker};
14
15pub struct OverflowWriter {
23 inner: WriteHalf<Socket>,
24 overflow: VecDeque<u8>,
25}
26
/// Outcome of a successful `OverflowWriter::write_all` call.
///
/// `Debug` is derived so the status can be logged and used with
/// `assert_eq!`/`assert_ne!`.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum OverflowWriterStatus {
    /// The call did not start an overflow: either the socket took all of the
    /// data, or the data was appended to an overflow that already existed.
    NoOverflow,
    /// The socket stopped accepting data during this call and the unwritten
    /// remainder was buffered.
    Overflow,
}
33
34impl OverflowWriterStatus {
35 pub fn overflowed(&self) -> bool {
37 matches!(self, OverflowWriterStatus::Overflow)
38 }
39}
40
41impl OverflowWriter {
42 pub fn new(inner: WriteHalf<Socket>) -> Self {
44 OverflowWriter { inner, overflow: VecDeque::new() }
45 }
46
47 pub fn write_all(&mut self, mut data: &[u8]) -> Result<OverflowWriterStatus, Error> {
54 if !self.overflow.is_empty() {
55 self.overflow.extend(data.iter());
56 return Ok(OverflowWriterStatus::NoOverflow);
59 }
60
61 let mut cx = Context::from_waker(Waker::noop());
62 loop {
63 match std::pin::Pin::new(&mut self.inner).poll_write(&mut cx, data) {
64 Poll::Ready(res) => {
65 let res = res?;
66
67 if res == 0 {
68 return Err(ErrorKind::WriteZero.into());
69 }
70
71 if res == data.len() {
72 return Ok(OverflowWriterStatus::NoOverflow);
73 }
74
75 data = &data[res..];
76 }
77 Poll::Pending => {
78 self.overflow.extend(data.iter());
79 return Ok(OverflowWriterStatus::Overflow);
80 }
81 };
82 }
83 }
84}
85
86pub struct OverflowHandleFut {
88 writer: Weak<Mutex<OverflowWriter>>,
89 guard_storage: Option<OwnedMutexLockFuture<OverflowWriter>>,
90}
91
92impl OverflowHandleFut {
93 pub fn new(writer: Weak<Mutex<OverflowWriter>>) -> Self {
95 OverflowHandleFut { writer, guard_storage: None }
96 }
97}
98
99impl Future for OverflowHandleFut {
100 type Output = Result<(), Error>;
101
102 fn poll(mut self: std::pin::Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
103 let lock_fut = if let Some(lock_fut) = &mut self.guard_storage {
104 lock_fut
105 } else {
106 let Some(payload_socket) = self.writer.upgrade() else {
107 return std::task::Poll::Ready(Ok(()));
108 };
109 self.guard_storage.insert(payload_socket.lock_owned())
110 };
111
112 let mut lock = ready!(lock_fut.poll_unpin(ctx));
113 let lock = &mut *lock;
114 self.guard_storage = None;
115
116 while !lock.overflow.is_empty() {
117 let (data_a, data_b) = lock.overflow.as_slices();
118 let res = ready!(std::pin::Pin::new(&mut lock.inner)
119 .poll_write_vectored(ctx, &[IoSlice::new(data_a), IoSlice::new(data_b)]))?;
120
121 if res == 0 {
122 return Poll::Ready(Err(ErrorKind::WriteZero.into()));
123 }
124
125 lock.overflow.drain(..res);
126 }
127
128 Poll::Ready(Ok(()))
129 }
130}