errgroup/errgroup.rs
1/// Async task group with error propagation and optional concurrency limiting.
2///
3/// A [`Group`] spawns multiple concurrent tasks. If any task returns an error,
4/// the group remembers the first error. When all tasks finish, [`wait`] returns
5/// that error (or `Ok(())` if all succeeded).
6///
7/// Optionally, [`set_limit`] restricts how many tasks run concurrently. Tasks
8/// beyond the limit block inside the spawned task until a slot opens.
9///
10/// # Panics
11///
/// Tasks **must not panic**. A panicking task surfaces from the underlying
/// [`JoinSet`] as a join error, which [`wait`] silently discards: the overall
/// result is unaffected, but the panic is lost.
15///
16/// # Example
17///
18/// ```rust
19/// use errgroup::Group;
20///
21/// # async fn example() -> Result<(), String> {
22/// let group = Group::new();
23///
24/// for url in ["http://example.com/a", "http://example.com/b"] {
25/// group.spawn(move || async move {
26/// // simulate a fetch
27/// if url.contains("b") {
28/// Err("failed".to_string())
29/// } else {
30/// Ok(())
31/// }
32/// });
33/// }
34///
35/// let result = group.wait().await;
36/// assert!(result.is_err());
37/// # Ok(())
38/// # }
39/// ```
40///
41/// [`wait`]: Group::wait
42/// [`set_limit`]: Group::set_limit
43use std::{
44 future::Future,
45 sync::{Arc, Mutex},
46};
47
48use tokio::{sync::Semaphore, task::JoinSet};
49
50/// A group of concurrent tasks that propagates the first error.
51///
52/// Create a new group with [`Group::new`], spawn tasks with [`spawn`] or
53/// [`try_spawn`], and collect results with [`wait`].
54///
55/// [`spawn`]: Group::spawn
56/// [`try_spawn`]: Group::try_spawn
57/// [`wait`]: Group::wait
58pub struct Group<E> {
59 tasks: Mutex<JoinSet<Result<(), E>>>,
60 sem: Option<Arc<Semaphore>>,
61}
62
63// ---------------------------------------------------------------------------
64// Construction
65// ---------------------------------------------------------------------------
66
67impl<E> Group<E>
68where
69 E: Send + 'static,
70{
71 /// Creates a new `Group` with no concurrency limit.
72 pub fn new() -> Self {
73 Group {
74 tasks: Mutex::new(JoinSet::new()),
75 sem: None,
76 }
77 }
78
79 /// Limits the number of tasks that may run concurrently to at most `n`.
80 ///
81 /// A limit of `0` prevents any new task from running. A negative value
82 /// would indicate no limit, but since this is `usize`, no limit is the
83 /// default (created by [`new`][Group::new]).
84 ///
85 /// # Panics
86 ///
87 /// Must not be called while any tasks are active.
88 pub fn set_limit(&mut self, n: usize) {
89 self.sem = Some(Arc::new(Semaphore::new(n)));
90 }
91}
92
93impl<E> Default for Group<E>
94where
95 E: Send + 'static,
96{
97 fn default() -> Self {
98 Self::new()
99 }
100}
101
102// ---------------------------------------------------------------------------
103// Spawning
104// ---------------------------------------------------------------------------
105
106impl<E> Group<E>
107where
108 E: Send + 'static,
109{
110 /// Spawns a new task.
111 ///
112 /// If a concurrency limit was set with [`set_limit`] and the maximum
113 /// number of tasks is already running, the newly spawned task will block
114 /// **inside** itself (not the caller) until a slot opens up.
115 ///
116 /// This method returns immediately and never blocks the caller.
117 ///
118 /// # Panics
119 ///
120 /// The task **must not panic**. See the [crate-level docs][crate#panics].
121 ///
122 /// [`set_limit`]: Group::set_limit
123 pub fn spawn<F, Fut>(&self, f: F)
124 where
125 F: FnOnce() -> Fut + Send + 'static,
126 Fut: Future<Output = Result<(), E>> + Send + 'static,
127 {
128 let sem = self.sem.clone();
129 self.tasks.lock().unwrap().spawn(async move {
130 let _permit = match sem {
131 Some(ref sem) => Some(sem.clone().acquire_owned().await.expect("semaphore closed")),
132 None => None,
133 };
134 f().await
135 });
136 }
137
138 /// Tries to spawn a new task without waiting for a concurrency slot.
139 ///
140 /// If a limit is set and the maximum number of tasks is running, returns
141 /// `false` without spawning the task. Otherwise spawns the task and
142 /// returns `true`.
143 ///
144 /// If no limit was set, this is equivalent to [`spawn`][Group::spawn]
145 /// and always returns `true`.
146 pub fn try_spawn<F, Fut>(&self, f: F) -> bool
147 where
148 F: FnOnce() -> Fut + Send + 'static,
149 Fut: Future<Output = Result<(), E>> + Send + 'static,
150 {
151 let permit = match self.sem.as_ref() {
152 Some(sem) => match sem.clone().try_acquire_owned() {
153 Ok(p) => Some(p),
154 Err(_) => return false,
155 },
156 None => None,
157 };
158
159 self.tasks.lock().unwrap().spawn(async move {
160 let _permit = permit;
161 f().await
162 });
163 true
164 }
165}
166
167// ---------------------------------------------------------------------------
168// Wait
169// ---------------------------------------------------------------------------
170
171impl<E> Group<E>
172where
173 E: Send + 'static,
174{
175 /// Waits for all spawned tasks to complete.
176 ///
177 /// Returns the first error encountered, or `Ok(())` if all tasks succeeded.
178 ///
179 /// Consumes the group — no more tasks can be spawned after calling this.
180 pub async fn wait(self) -> Result<(), E> {
181 let mut tasks = self.tasks.into_inner().unwrap();
182 let mut first_error = None;
183
184 while let Some(result) = tasks.join_next().await {
185 match result {
186 Ok(Ok(())) => { /* success – nothing to do */ }
187 Ok(Err(e)) => {
188 if first_error.is_none() {
189 first_error = Some(e);
190 }
191 }
192 Err(_) => {
193 // Task panicked – silently ignored (documented).
194 }
195 }
196 }
197
198 match first_error {
199 Some(e) => Err(e),
200 None => Ok(()),
201 }
202 }
203}
204
205// ---------------------------------------------------------------------------
206// Tests
207// ---------------------------------------------------------------------------
208
#[cfg(test)]
mod test {
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::time::Duration;

    use super::*;

    /// Sleeps for `ms` milliseconds on the tokio timer.
    async fn pause_ms(ms: u64) {
        tokio::time::sleep(Duration::from_millis(ms)).await;
    }

    #[tokio::test]
    async fn no_tasks_returns_ok() {
        let g: Group<&str> = Group::new();
        assert_eq!(g.wait().await, Ok(()));
    }

    #[tokio::test]
    async fn single_task_ok() {
        let g: Group<&str> = Group::new();
        g.spawn(|| async { Ok(()) });
        assert_eq!(g.wait().await, Ok(()));
    }

    #[tokio::test]
    async fn single_task_error() {
        let g = Group::new();
        g.spawn(|| async { Err("oops") });
        assert_eq!(g.wait().await, Err("oops"));
    }

    #[tokio::test]
    async fn multiple_tasks_all_ok() {
        let g: Group<&str> = Group::new();
        for _ in 0..10 {
            g.spawn(|| async { Ok(()) });
        }
        assert_eq!(g.wait().await, Ok(()));
    }

    #[tokio::test]
    async fn first_error_wins() {
        let g: Group<&str> = Group::new();

        // All ten tasks rendezvous here before any of them proceeds.
        let gate = Arc::new(tokio::sync::Barrier::new(10));

        for idx in 0..10 {
            let gate = Arc::clone(&gate);
            g.spawn(move || async move {
                gate.wait().await;
                if idx != 5 {
                    // Keep the successful tasks alive long enough that the
                    // failing task does not race with them.
                    pause_ms(50).await;
                    Ok(())
                } else {
                    Err("task 5 failed")
                }
            });
        }

        assert_eq!(g.wait().await, Err("task 5 failed"));
    }

    #[tokio::test]
    async fn concurrency_limit_is_enforced() {
        let mut g: Group<String> = Group::new();
        g.set_limit(3);

        let active = Arc::new(AtomicUsize::new(0));
        let peak = Arc::new(AtomicUsize::new(0));

        for _ in 0..20 {
            let (active, peak) = (Arc::clone(&active), Arc::clone(&peak));
            g.spawn(move || async move {
                // Record how many tasks are inside the critical region,
                // counting this one, and remember the highest value seen.
                let now_active = active.fetch_add(1, Ordering::SeqCst) + 1;
                peak.fetch_max(now_active, Ordering::SeqCst);

                pause_ms(30).await;

                active.fetch_sub(1, Ordering::SeqCst);
                Ok(())
            });
        }

        g.wait().await.unwrap();
        let observed = peak.load(Ordering::SeqCst);
        assert!(observed <= 3, "max was {}", observed);
    }

    #[tokio::test]
    async fn try_spawn_rejected_at_limit() {
        let mut g: Group<&str> = Group::new();
        g.set_limit(2);

        // The first two long-running tasks take both slots.
        for _ in 0..2 {
            let accepted = g.try_spawn(|| async {
                pause_ms(100).await;
                Ok(())
            });
            assert!(accepted);
        }

        // With both slots taken, a third task must be turned away.
        let accepted = g.try_spawn(|| async { Ok(()) });
        assert!(!accepted);

        g.wait().await.unwrap();
    }

    #[tokio::test]
    async fn try_spawn_without_limit_always_succeeds() {
        let g: Group<&str> = Group::new();
        for _ in 0..2 {
            assert!(g.try_spawn(|| async { Ok(()) }));
        }
        g.wait().await.unwrap();
    }

    #[tokio::test]
    async fn high_concurrency_stress() {
        let mut g: Group<String> = Group::new();
        g.set_limit(10);

        let total = 500;
        let finished = Arc::new(AtomicUsize::new(0));

        for _ in 0..total {
            let finished = Arc::clone(&finished);
            g.spawn(move || async move {
                tokio::time::sleep(Duration::from_micros(10)).await;
                finished.fetch_add(1, Ordering::SeqCst);
                Ok::<(), String>(())
            });
        }

        g.wait().await.unwrap();
        assert_eq!(finished.load(Ordering::SeqCst), total);
    }

    /// A panicking task must not take the group down: `wait` still joins
    /// the remaining tasks and the panic does not show up in the result.
    #[tokio::test]
    async fn panicking_task_is_ignored() {
        let g = Group::new();

        g.spawn(|| async {
            panic!("this should be caught");
        });

        g.spawn(|| async {
            pause_ms(10).await;
            Ok::<(), &str>(())
        });

        // The panic must not surface here.
        let outcome = g.wait().await;
        assert_eq!(outcome, Ok(()));
    }
}
368}