usb_vsock/connection/
pause_state.rs

1// Copyright 2025 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use std::future::Future;
6use std::pin::pin;
7use std::sync::{Arc, Mutex};
8use std::task::{Poll, Waker};
9
/// Maintains whether a connection is paused. A paused connection should not
/// send any more data to the peer.
#[derive(Debug)]
pub struct PauseState(Mutex<PauseStateInner>);

/// Mutex-protected interior of [`PauseState`].
#[derive(Debug)]
struct PauseStateInner {
    /// `true` while the connection is paused.
    paused: bool,
    /// Wakers recorded by [`PauseState::while_unpaused`]; drained and woken by
    /// [`PauseState::set_paused`].
    wakers: Vec<Waker>,
}
19
20impl PauseState {
21    /// Create a new [`PauseState`]. The initial state is un-paused.
22    pub fn new() -> Arc<Self> {
23        Arc::new(PauseState(Mutex::new(PauseStateInner { paused: false, wakers: Vec::new() })))
24    }
25
26    /// Polls the given future, but pauses polling when we are in the paused
27    /// state.
28    pub async fn while_unpaused<T>(&self, f: impl Future<Output = T>) -> T {
29        let mut f = pin!(f);
30        futures::future::poll_fn(move |ctx| {
31            {
32                let mut this = self.0.lock().unwrap();
33
34                if this.wakers.iter().all(|x| !x.will_wake(ctx.waker())) {
35                    this.wakers.push(ctx.waker().clone());
36                }
37
38                if this.paused {
39                    return Poll::Pending;
40                }
41            }
42
43            f.as_mut().poll(ctx)
44        })
45        .await
46    }
47
48    /// Set the paused state.
49    pub fn set_paused(&self, paused: bool) {
50        let mut this = self.0.lock().unwrap();
51
52        this.paused = paused;
53        this.wakers.drain(..).for_each(Waker::wake);
54    }
55}
56
#[cfg(test)]
mod test {
    use super::*;
    use futures::{Stream, StreamExt};
    use std::task::Context;

    /// Streams consecutive integers through [`PauseState::while_unpaused`].
    /// Checks that items stop flowing while paused (including across a
    /// redundant second pause) and resume, in order, once un-paused.
    #[fuchsia::test]
    async fn test_pause() {
        let pause_state = PauseState::new();
        let gate = Arc::clone(&pause_state);
        let gated = futures::stream::iter(1..)
            .then(|item| gate.while_unpaused(futures::future::ready(item)));
        let mut gated = pin!(gated);
        let mut cx = Context::from_waker(&Waker::noop());

        // Un-paused: the first five items come straight through.
        for expected in 1..=5 {
            assert_eq!(Poll::Ready(Some(expected)), gated.as_mut().poll_next(&mut cx));
        }

        // Pause twice in a row; every poll after each pause must stay pending.
        for _ in 0..2 {
            pause_state.set_paused(true);
            for _ in 0..5 {
                assert_eq!(Poll::Pending, gated.as_mut().poll_next(&mut cx));
            }
        }

        // Un-pausing resumes exactly where the stream left off.
        pause_state.set_paused(false);
        for expected in 6..=10 {
            assert_eq!(Poll::Ready(Some(expected)), gated.as_mut().poll_next(&mut cx));
        }
    }
}