cm_util/
abortable_scope.rs

1// Copyright 2024 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use fuchsia_sync::Mutex;
6use futures::channel::oneshot::{self, Canceled};
7use futures::future::{FutureExt, Shared};
8use futures::task::Poll;
9use futures::Future;
10use pin_project::pin_project;
11use std::fmt::Debug;
12use std::pin::Pin;
13use std::sync::Arc;
14
15/// [`AbortableScope`] transforms futures into abortable futures.
16///
17/// The wrapped futures can then be aborted via the [`AbortHandle`].
18/// When [`AbortHandle::abort`] is called, the wrapped future will
19/// complete immediately without making further progress.
20#[derive(Debug)]
21pub struct AbortableScope {
22    rx: Shared<oneshot::Receiver<()>>,
23}
24
/// The error returned by [`AbortableScope::run`] when the scope is aborted.
///
/// Implements [`std::error::Error`] so it can be boxed or propagated with `?`
/// alongside other error types.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub struct AbortError;

impl std::fmt::Display for AbortError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("future was aborted")
    }
}

impl std::error::Error for AbortError {}
28
29/// [`AbortHandle`] allows aborting a future.
30#[derive(Debug, Clone)]
31pub struct AbortHandle {
32    tx: Arc<Mutex<Option<oneshot::Sender<()>>>>,
33}
34
35impl AbortHandle {
36    /// Interrupt the future as soon as possible.
37    pub fn abort(&self) {
38        let _ = self.tx.lock().take().map(|tx| tx.send(()));
39    }
40}
41
42impl AbortableScope {
43    /// Creates a scope and a handle to abort futures running in the scope.
44    pub fn new() -> (AbortableScope, AbortHandle) {
45        let (tx, rx) = oneshot::channel();
46        (AbortableScope { rx: rx.shared() }, AbortHandle { tx: Arc::new(Mutex::new(Some(tx))) })
47    }
48
49    /// Runs a future in this scope.
50    ///
51    /// Returns an abort error if [`AbortHandle::abort`] is called.
52    pub async fn run<T, Fut>(&self, future: Fut) -> Result<T, AbortError>
53    where
54        Fut: Future<Output = T> + Send,
55    {
56        AbortableFuture { future, abort_rx: self.rx.clone(), tx_dropped: false }.await
57    }
58}
59
60pub trait AbortFutureExt<T>: Future<Output = T> + Send {
61    /// Causes the future to complete with an [`InterruptError`] if the scope is aborted.
62    ///
63    /// Syntax sugar for `scope.run(future)`:
64    ///
65    /// ```
66    /// let (scope, handle) = AbortableScope::new();
67    /// handle.abort();
68    /// some_future.with(&scope).await;   // Result<T, InterruptError>
69    /// ```
70    fn with(self, scope: &AbortableScope) -> impl Future<Output = Result<T, AbortError>>;
71}
72
73impl<Fut, T> AbortFutureExt<T> for Fut
74where
75    Fut: Future<Output = T> + Send,
76{
77    fn with(self, scope: &AbortableScope) -> impl Future<Output = Result<T, AbortError>> {
78        scope.run(self)
79    }
80}
81
82#[pin_project]
83struct AbortableFuture<T, Fut: Future<Output = T> + Send, InterruptFut> {
84    #[pin]
85    future: Fut,
86    #[pin]
87    abort_rx: InterruptFut,
88    tx_dropped: bool,
89}
90
91impl<T, Fut: Future<Output = T> + Send, InterruptFut: Future<Output = Result<(), Canceled>>> Future
92    for AbortableFuture<T, Fut, InterruptFut>
93{
94    type Output = Result<T, AbortError>;
95
96    fn poll(
97        self: Pin<&mut Self>,
98        cx: &mut std::task::Context<'_>,
99    ) -> std::task::Poll<Self::Output> {
100        let this = self.project();
101
102        if !*this.tx_dropped {
103            match this.abort_rx.poll(cx) {
104                Poll::Ready(Ok(())) => return Poll::Ready(Err(AbortError)),
105                Poll::Ready(Err(Canceled)) => {
106                    *this.tx_dropped = true;
107                }
108                Poll::Pending => {}
109            }
110        }
111
112        match this.future.poll(cx) {
113            Poll::Ready(output) => Poll::Ready(Ok(output)),
114            Poll::Pending => Poll::Pending,
115        }
116    }
117}
118
#[cfg(test)]
pub mod tests {
    use super::*;
    use fuchsia_async as fasync;

    /// What a unit-output future run in an aborted scope is expected to yield.
    const ABORTED: Poll<Result<(), AbortError>> = Poll::Ready(Err(AbortError));

    /// A future that never completes on its own is ended by `abort`.
    #[test]
    fn abort_a_future_pending() {
        let mut ex = fasync::TestExecutor::new();
        let (scope, handle) = AbortableScope::new();
        let mut fut = std::pin::pin!(scope.run(std::future::pending::<()>()));

        assert!(ex.run_until_stalled(&mut fut).is_pending());
        handle.abort();
        assert_eq!(ex.run_until_stalled(&mut fut), ABORTED);
    }

    /// Aborting before the first poll wins over an immediately ready future.
    #[test]
    fn abort_a_future_ready() {
        let mut ex = fasync::TestExecutor::new();
        let (scope, handle) = AbortableScope::new();
        let mut fut = std::pin::pin!(scope.run(std::future::ready(())));

        handle.abort();
        assert_eq!(ex.run_until_stalled(&mut fut), ABORTED);
    }

    /// One abort affects every future in the scope, including ones started afterwards.
    #[test]
    fn abort_many_futures() {
        let mut ex = fasync::TestExecutor::new();
        let (scope, handle) = AbortableScope::new();

        let mut fut1 = std::pin::pin!(scope.run(std::future::ready(())));
        let mut fut2 = std::pin::pin!(scope.run(std::future::ready(())));

        handle.abort();

        assert_eq!(ex.run_until_stalled(&mut fut1), ABORTED);
        assert_eq!(ex.run_until_stalled(&mut fut2), ABORTED);

        // A future started after the abort is aborted as well.
        let mut fut3 = std::pin::pin!(scope.run(std::future::ready(())));
        assert_eq!(ex.run_until_stalled(&mut fut3), ABORTED);
    }

    /// Dropping the handle without aborting leaves the wrapped future running.
    #[test]
    fn abort_a_future_handle_dropped() {
        let mut ex = fasync::TestExecutor::new();
        let (scope, handle) = AbortableScope::new();
        let mut fut = std::pin::pin!(scope.run(std::future::pending::<()>()));

        assert!(ex.run_until_stalled(&mut fut).is_pending());
        drop(handle);
        // Polled twice on purpose: the second poll happens after the dropped
        // sender has already been observed.
        assert!(ex.run_until_stalled(&mut fut).is_pending());
        assert!(ex.run_until_stalled(&mut fut).is_pending());
    }

    /// Without an abort the wrapped future's output is passed through as `Ok`.
    #[test]
    fn not_aborting_a_future() {
        let mut ex = fasync::TestExecutor::new();
        let (scope, _handle) = AbortableScope::new();
        let mut fut = std::pin::pin!(scope.run(std::future::ready(())));

        assert_eq!(ex.run_until_stalled(&mut fut), Poll::Ready(Ok(())));
    }
}