fuchsia_async/runtime/
task_group.rs1use crate::Task;
6
7use futures::channel::mpsc;
8use futures::Future;
9
10use super::Scope;
11
12#[derive(Debug, thiserror::Error)]
14enum Error {
15 #[error("Failed to add Task: {0}")]
17 GroupDropped(#[from] mpsc::TrySendError<Task<()>>),
18}
19
20pub struct TaskGroup {
27 scope: Scope,
28}
29
30impl Default for TaskGroup {
31 fn default() -> Self {
32 Self::new()
33 }
34}
35
36impl TaskGroup {
37 pub fn new() -> Self {
42 #[cfg(target_os = "fuchsia")]
43 return Self { scope: Scope::global().new_child() };
44 #[cfg(not(target_os = "fuchsia"))]
45 return Self { scope: Scope::new() };
46 }
47
48 pub fn spawn(&mut self, future: impl Future<Output = ()> + Send + 'static) {
57 self.scope.spawn(future);
58 }
59
60 pub fn local(&mut self, future: impl Future<Output = ()> + 'static) {
67 self.scope.spawn_local(future);
68 }
69
70 pub async fn join(self) {
74 self.scope.on_no_tasks().await;
75 }
76}
77
#[cfg(test)]
mod tests {
    use super::*;
    use crate::SendExecutor;
    use futures::StreamExt;
    use std::sync::atomic::{AtomicU64, Ordering};
    use std::sync::Arc;

    /// Guard that announces its own destruction: dropping it sends one `()`
    /// on `done` and then disconnects that sender.
    #[derive(Clone)]
    struct DoneSignaler {
        done: mpsc::UnboundedSender<()>,
    }

    impl Drop for DoneSignaler {
        fn drop(&mut self) {
            let Self { done } = self;
            // Report first, then hang up so the receiving side can terminate
            // once every sender is gone.
            done.unbounded_send(()).unwrap();
            done.disconnect();
        }
    }

    /// Lets a test block until every guard handed out by `add_one` has been
    /// dropped.
    struct WaitGroup {
        tx: mpsc::UnboundedSender<()>,
        rx: mpsc::UnboundedReceiver<()>,
    }

    impl WaitGroup {
        /// Creates a wait group with no outstanding guards.
        fn new() -> Self {
            let (sender, receiver) = mpsc::unbounded();
            WaitGroup { tx: sender, rx: receiver }
        }

        /// Hands out a guard that signals this wait group when it is dropped.
        fn add_one(&self) -> impl Drop {
            let done = self.tx.clone();
            DoneSignaler { done }
        }

        /// Resolves once the receiving stream ends, i.e. after every guard
        /// has been dropped.
        async fn wait(self) {
            let Self { tx, rx } = self;
            // Release our own sender; otherwise the stream could never end.
            drop(tx);
            rx.collect::<()>().await;
        }
    }

    /// `join` must not resolve before every spawned task has run.
    #[test]
    fn test_task_group_join_waits_for_tasks() {
        let task_count = 20;

        SendExecutor::new(task_count).run(async move {
            let mut group = TaskGroup::new();
            let counter = Arc::new(AtomicU64::new(0));

            for _ in 0..task_count {
                let counter = Arc::clone(&counter);
                group.spawn(async move {
                    counter.fetch_add(1, Ordering::Relaxed);
                });
            }

            group.join().await;

            // Every task incremented the counter exactly once.
            let observed = counter.load(Ordering::Relaxed);
            assert_eq!(observed, task_count as u64);
        });
    }

    /// Joining a group that never had a task completes.
    #[test]
    fn test_task_group_empty_join_completes() {
        SendExecutor::new(1).run(async move {
            let group = TaskGroup::new();
            group.join().await;
        });
    }

    /// Dropping the group must drop its still-pending tasks.
    #[test]
    fn test_task_group_added_tasks_are_cancelled_on_drop() {
        let wait_group = WaitGroup::new();
        let task_count = 10;

        SendExecutor::new(task_count).run(async move {
            let mut group = TaskGroup::new();

            for _ in 0..task_count {
                let guard = wait_group.add_one();

                group.spawn(async move {
                    // Move the guard into the task so it lives exactly as
                    // long as the task's future does.
                    let _guard = guard;
                    // Never completes on its own.
                    std::future::pending::<()>().await;
                });
            }

            drop(group);

            // This only resolves if every guard was dropped, which requires
            // the never-finishing futures above to have been dropped.
            wait_group.wait().await;
        });
    }

    /// `spawn` accepts a ready future, an async block, and a `Task`, and
    /// `join` completes after all of them.
    #[test]
    fn test_task_group_spawn() {
        let task_count = 3;

        SendExecutor::new(task_count).run(async move {
            let mut group = TaskGroup::new();

            // A future that is already complete.
            group.spawn(std::future::ready(()));

            // An async block.
            group.spawn(async move {
                std::future::ready(()).await;
            });

            // A `Task` that was spawned separately.
            group.spawn(Task::spawn(std::future::ready(())));

            group.join().await;
        });
    }
}