// fuchsia_async/runtime/task_group.rs

// Copyright 2023 The Fuchsia Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
4
5use futures::Future;
6
7use super::Scope;
8
9/// Allows the user to spawn multiple Tasks and await them as a unit.
10///
11/// Tasks can be added to this group using [`TaskGroup::add`].
12/// All pending tasks in the group can be awaited using [`TaskGroup::join`].
13///
14/// New code should prefer to use [`Scope`] instead.
15pub struct TaskGroup {
16    scope: Scope,
17}
18
19impl Default for TaskGroup {
20    fn default() -> Self {
21        Self::new()
22    }
23}
24
25impl TaskGroup {
26    /// Creates a new TaskGroup.
27    ///
28    /// The TaskGroup can be used to await an arbitrary number of Tasks and may
29    /// consume an arbitrary amount of memory.
30    pub fn new() -> Self {
31        #[cfg(target_os = "fuchsia")]
32        return Self { scope: Scope::global().new_child() };
33        #[cfg(not(target_os = "fuchsia"))]
34        return Self { scope: Scope::new() };
35    }
36
37    /// Spawns a new task in this TaskGroup.
38    ///
39    /// To add a future that is not [`Send`] to this TaskGroup, use [`TaskGroup::local`].
40    ///
41    /// # Panics
42    ///
43    /// `spawn` may panic if not called in the context of an executor (e.g.
44    /// within a call to `run` or `run_singlethreaded`).
45    pub fn spawn(&mut self, future: impl Future<Output = ()> + Send + 'static) {
46        self.scope.spawn(future);
47    }
48
49    /// Spawns a new task in this TaskGroup.
50    ///
51    /// # Panics
52    ///
53    /// `spawn` may panic if not called in the context of a single threaded executor
54    /// (e.g. within a call to `run_singlethreaded`).
55    pub fn local(&mut self, future: impl Future<Output = ()> + 'static) {
56        self.scope.spawn_local(future);
57    }
58
59    /// Waits for all Tasks in this TaskGroup to finish.
60    ///
61    /// Call this only after all Tasks have been added.
62    pub async fn join(self) {
63        self.scope.on_no_tasks().await;
64    }
65}
66
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{SendExecutor, Task};
    use futures::channel::mpsc;
    use futures::StreamExt;
    use std::sync::atomic::{AtomicU64, Ordering};
    use std::sync::Arc;

    /// Guard that, when dropped, sends one message on `done` and then
    /// disconnects that sender, marking some operation as finished.
    #[derive(Clone)]
    struct DoneSignaler {
        done: mpsc::UnboundedSender<()>,
    }

    impl Drop for DoneSignaler {
        fn drop(&mut self) {
            let sender = &mut self.done;
            sender.unbounded_send(()).unwrap();
            sender.disconnect();
        }
    }

    /// Waits until every guard handed out by `WaitGroup::add_one` has been
    /// dropped. Hand out as many guards as needed, then call `wait`.
    struct WaitGroup {
        tx: mpsc::UnboundedSender<()>,
        rx: mpsc::UnboundedReceiver<()>,
    }

    impl WaitGroup {
        fn new() -> Self {
            let (tx, rx) = mpsc::unbounded();
            WaitGroup { tx, rx }
        }

        /// Returns a guard whose drop is tracked by this group.
        fn add_one(&self) -> impl Drop {
            let done = self.tx.clone();
            DoneSignaler { done }
        }

        /// Resolves once every outstanding guard has been dropped.
        async fn wait(self) {
            let WaitGroup { tx, mut rx } = self;
            // Release our own sender so the stream can end once all guards are gone.
            drop(tx);
            while rx.next().await.is_some() {}
        }
    }

    #[test]
    fn test_task_group_join_waits_for_tasks() {
        let num_tasks = 20;

        SendExecutor::new(num_tasks).run(async move {
            let counter = Arc::new(AtomicU64::new(0));
            let mut group = TaskGroup::new();

            (0..num_tasks).for_each(|_| {
                let counter = Arc::clone(&counter);
                group.spawn(async move {
                    counter.fetch_add(1, Ordering::Relaxed);
                });
            });

            group.join().await;
            assert_eq!(counter.load(Ordering::Relaxed), num_tasks as u64);
        });
    }

    #[test]
    fn test_task_group_empty_join_completes() {
        SendExecutor::new(1).run(async {
            let group = TaskGroup::new();
            group.join().await;
        });
    }

    #[test]
    fn test_task_group_added_tasks_are_cancelled_on_drop() {
        let drops = WaitGroup::new();
        let num_tasks = 10;

        SendExecutor::new(num_tasks).run(async move {
            let mut group = TaskGroup::new();

            for _ in 0..num_tasks {
                let guard = drops.add_one();
                // This task never finishes on its own; `guard` is only dropped
                // if the task gets cancelled.
                group.spawn(async move {
                    // Move the guard into the task so the task owns it.
                    let _guard = guard;
                    std::future::pending::<()>().await;
                });
            }

            // Dropping the group should cancel every pending task...
            drop(group);
            // ...which drops every guard and lets this wait finish.
            drops.wait().await;
        });
    }

    #[test]
    fn test_task_group_spawn() {
        let num_tasks = 3;
        SendExecutor::new(num_tasks).run(async move {
            let mut group = TaskGroup::new();

            // `spawn` accepts any `Send + 'static` future with `Output = ()`, e.g.
            // a bare future,
            group.spawn(std::future::ready(()));

            // a future produced by an async block,
            group.spawn(async move {
                std::future::ready(()).await;
            });

            // or another task.
            group.spawn(Task::spawn(std::future::ready(())));

            group.join().await;
        });
    }
}