Skip to main content

errgroup/
errgroup.rs

//! Async task group with error propagation and optional concurrency limiting.
//!
//! A [`Group`] spawns multiple concurrent tasks. If any task returns an error,
//! the group remembers the first one. When all tasks have finished, [`wait`]
//! returns that error (or `Ok(())` if every task succeeded).
//!
//! Optionally, [`set_limit`] restricts how many tasks run concurrently. Tasks
//! beyond the limit wait asynchronously, inside their own spawned task, until a
//! slot opens; the caller that spawned them is never blocked.
//!
//! # Panics
//!
//! Tasks **should not panic**. If a task panics, the panic is not propagated:
//! the underlying [`JoinSet`] reports it to the group as a join error, which
//! [`wait`] discards. The overall result is unaffected, but the panic is lost.
//!
//! # Example
//!
//! ```rust
//! use errgroup::Group;
//!
//! # async fn example() -> Result<(), String> {
//! let group = Group::new();
//!
//! for url in ["http://example.com/a", "http://example.com/b"] {
//!     group.spawn(move || async move {
//!         // simulate a fetch
//!         if url.contains("b") {
//!             Err("failed".to_string())
//!         } else {
//!             Ok(())
//!         }
//!     });
//! }
//!
//! let result = group.wait().await;
//! assert!(result.is_err());
//! # Ok(())
//! # }
//! ```
//!
//! [`wait`]: Group::wait
//! [`set_limit`]: Group::set_limit
43use std::{
44    future::Future,
45    sync::{Arc, Mutex},
46};
47
48use tokio::{sync::Semaphore, task::JoinSet};
49
50/// A group of concurrent tasks that propagates the first error.
51///
52/// Create a new group with [`Group::new`], spawn tasks with [`spawn`] or
53/// [`try_spawn`], and collect results with [`wait`].
54///
55/// [`spawn`]: Group::spawn
56/// [`try_spawn`]: Group::try_spawn
57/// [`wait`]: Group::wait
58pub struct Group<E> {
59    tasks: Mutex<JoinSet<Result<(), E>>>,
60    sem: Option<Arc<Semaphore>>,
61}
62
63// ---------------------------------------------------------------------------
64// Construction
65// ---------------------------------------------------------------------------
66
67impl<E> Group<E>
68where
69    E: Send + 'static,
70{
71    /// Creates a new `Group` with no concurrency limit.
72    pub fn new() -> Self {
73        Group {
74            tasks: Mutex::new(JoinSet::new()),
75            sem: None,
76        }
77    }
78
79    /// Limits the number of tasks that may run concurrently to at most `n`.
80    ///
81    /// A limit of `0` prevents any new task from running. A negative value
82    /// would indicate no limit, but since this is `usize`, no limit is the
83    /// default (created by [`new`][Group::new]).
84    ///
85    /// # Panics
86    ///
87    /// Must not be called while any tasks are active.
88    pub fn set_limit(&mut self, n: usize) {
89        self.sem = Some(Arc::new(Semaphore::new(n)));
90    }
91}
92
93impl<E> Default for Group<E>
94where
95    E: Send + 'static,
96{
97    fn default() -> Self {
98        Self::new()
99    }
100}
101
102// ---------------------------------------------------------------------------
103// Spawning
104// ---------------------------------------------------------------------------
105
106impl<E> Group<E>
107where
108    E: Send + 'static,
109{
110    /// Spawns a new task.
111    ///
112    /// If a concurrency limit was set with [`set_limit`] and the maximum
113    /// number of tasks is already running, the newly spawned task will block
114    /// **inside** itself (not the caller) until a slot opens up.
115    ///
116    /// This method returns immediately and never blocks the caller.
117    ///
118    /// # Panics
119    ///
120    /// The task **must not panic**. See the [crate-level docs][crate#panics].
121    ///
122    /// [`set_limit`]: Group::set_limit
123    pub fn spawn<F, Fut>(&self, f: F)
124    where
125        F: FnOnce() -> Fut + Send + 'static,
126        Fut: Future<Output = Result<(), E>> + Send + 'static,
127    {
128        let sem = self.sem.clone();
129        self.tasks.lock().unwrap().spawn(async move {
130            let _permit = match sem {
131                Some(ref sem) => Some(sem.clone().acquire_owned().await.expect("semaphore closed")),
132                None => None,
133            };
134            f().await
135        });
136    }
137
138    /// Tries to spawn a new task without waiting for a concurrency slot.
139    ///
140    /// If a limit is set and the maximum number of tasks is running, returns
141    /// `false` without spawning the task. Otherwise spawns the task and
142    /// returns `true`.
143    ///
144    /// If no limit was set, this is equivalent to [`spawn`][Group::spawn]
145    /// and always returns `true`.
146    pub fn try_spawn<F, Fut>(&self, f: F) -> bool
147    where
148        F: FnOnce() -> Fut + Send + 'static,
149        Fut: Future<Output = Result<(), E>> + Send + 'static,
150    {
151        let permit = match self.sem.as_ref() {
152            Some(sem) => match sem.clone().try_acquire_owned() {
153                Ok(p) => Some(p),
154                Err(_) => return false,
155            },
156            None => None,
157        };
158
159        self.tasks.lock().unwrap().spawn(async move {
160            let _permit = permit;
161            f().await
162        });
163        true
164    }
165}
166
167// ---------------------------------------------------------------------------
168// Wait
169// ---------------------------------------------------------------------------
170
171impl<E> Group<E>
172where
173    E: Send + 'static,
174{
175    /// Waits for all spawned tasks to complete.
176    ///
177    /// Returns the first error encountered, or `Ok(())` if all tasks succeeded.
178    ///
179    /// Consumes the group — no more tasks can be spawned after calling this.
180    pub async fn wait(self) -> Result<(), E> {
181        let mut tasks = self.tasks.into_inner().unwrap();
182        let mut first_error = None;
183
184        while let Some(result) = tasks.join_next().await {
185            match result {
186                Ok(Ok(())) => { /* success – nothing to do */ }
187                Ok(Err(e)) => {
188                    if first_error.is_none() {
189                        first_error = Some(e);
190                    }
191                }
192                Err(_) => {
193                    // Task panicked – silently ignored (documented).
194                }
195            }
196        }
197
198        match first_error {
199            Some(e) => Err(e),
200            None => Ok(()),
201        }
202    }
203}
204
205// ---------------------------------------------------------------------------
206// Tests
207// ---------------------------------------------------------------------------
208
// Unit tests for `Group`. Each test builds its own Tokio runtime via
// `#[tokio::test]`; several rely on `tokio::time::sleep` timings, so the
// durations below are part of what each test checks.
#[cfg(test)]
mod test {
    use std::sync::atomic::{AtomicUsize, Ordering};

    use super::*;

    /// An empty group waits successfully.
    #[tokio::test]
    async fn no_tasks_returns_ok() {
        let group = Group::<&str>::new();
        assert_eq!(group.wait().await, Ok(()));
    }

    /// A single successful task yields `Ok(())`.
    #[tokio::test]
    async fn single_task_ok() {
        let group = Group::<&str>::new();
        group.spawn(|| async { Ok(()) });
        assert_eq!(group.wait().await, Ok(()));
    }

    /// A single failing task's error is returned by `wait`.
    #[tokio::test]
    async fn single_task_error() {
        // `E` is inferred as `&str` from the task's error value.
        let group = Group::new();
        group.spawn(|| async { Err("oops") });
        assert_eq!(group.wait().await, Err("oops"));
    }

    /// Many successful tasks without a limit yield `Ok(())`.
    #[tokio::test]
    async fn multiple_tasks_all_ok() {
        let group = Group::<&str>::new();
        for _ in 0..10 {
            group.spawn(|| async { Ok(()) });
        }
        assert_eq!(group.wait().await, Ok(()));
    }

    /// One failing task among nine successful ones: its error is propagated.
    ///
    /// NOTE(review): only task 5 fails, so this checks error propagation
    /// among successes rather than ordering between multiple errors.
    #[tokio::test]
    async fn first_error_wins() {
        let group = Group::<&str>::new();

        // All 10 tasks rendezvous here so they start their work together.
        let barrier = Arc::new(tokio::sync::Barrier::new(10));

        for i in 0..10 {
            let b = barrier.clone();
            group.spawn(move || async move {
                b.wait().await;
                if i == 5 {
                    Err("task 5 failed")
                } else {
                    // Sleep tasks long enough so the error doesn't race
                    tokio::time::sleep(std::time::Duration::from_millis(50)).await;
                    Ok(())
                }
            });
        }

        let result = group.wait().await;
        assert_eq!(result, Err("task 5 failed"));
    }

    /// With `set_limit(3)`, no more than 3 of 20 tasks are ever observed
    /// running at the same time.
    #[tokio::test]
    async fn concurrency_limit_is_enforced() {
        let mut group = Group::<String>::new();
        group.set_limit(3);

        // `running` counts tasks currently inside their body; `max_running`
        // records the peak value of `running` seen by any task.
        let running = Arc::new(AtomicUsize::new(0));
        let max_running = Arc::new(AtomicUsize::new(0));

        for _ in 0..20 {
            let running = running.clone();
            let max_running = max_running.clone();
            group.spawn(move || async move {
                let prev = running.fetch_add(1, Ordering::SeqCst);
                max_running.fetch_max(prev + 1, Ordering::SeqCst);

                // Stay "running" long enough for other tasks to overlap.
                tokio::time::sleep(std::time::Duration::from_millis(30)).await;

                running.fetch_sub(1, Ordering::SeqCst);
                Ok(())
            });
        }

        group.wait().await.unwrap();
        assert!(
            max_running.load(Ordering::SeqCst) <= 3,
            "max was {}",
            max_running.load(Ordering::SeqCst)
        );
    }

    /// `try_spawn` reserves slots eagerly, so a third call is rejected while
    /// the first two (sleeping) tasks hold both slots of a limit of 2.
    #[tokio::test]
    async fn try_spawn_rejected_at_limit() {
        let mut group = Group::<&str>::new();
        group.set_limit(2);

        // First two should succeed
        assert!(group.try_spawn(|| async {
            tokio::time::sleep(std::time::Duration::from_millis(100)).await;
            Ok(())
        }));

        assert!(group.try_spawn(|| async {
            tokio::time::sleep(std::time::Duration::from_millis(100)).await;
            Ok(())
        }));

        // Third should be rejected
        assert!(!group.try_spawn(|| async { Ok(()) }));

        group.wait().await.unwrap();
    }

    /// Without a limit, `try_spawn` always spawns and returns `true`.
    #[tokio::test]
    async fn try_spawn_without_limit_always_succeeds() {
        let group = Group::<&str>::new();
        assert!(group.try_spawn(|| async { Ok(()) }));
        assert!(group.try_spawn(|| async { Ok(()) }));
        group.wait().await.unwrap();
    }

    /// 500 short tasks through a limit of 10 all run exactly once.
    #[tokio::test]
    async fn high_concurrency_stress() {
        let mut group = Group::<String>::new();
        group.set_limit(10);

        let n = 500;
        let counter = Arc::new(AtomicUsize::new(0));

        for _i in 0..n {
            let c = counter.clone();
            group.spawn(move || async move {
                tokio::time::sleep(std::time::Duration::from_micros(10)).await;
                c.fetch_add(1, Ordering::SeqCst);
                Ok::<(), String>(())
            });
        }

        group.wait().await.unwrap();
        assert_eq!(counter.load(Ordering::SeqCst), n);
    }

    /// When a task panics, the group should not propagate the panic
    /// and should still be able to wait for remaining tasks.
    #[tokio::test]
    async fn panicking_task_is_ignored() {
        // `E` is inferred as `&str` from the second task's annotation.
        let group = Group::new();

        group.spawn(|| async {
            panic!("this should be caught");
        });

        group.spawn(|| async {
            tokio::time::sleep(std::time::Duration::from_millis(10)).await;
            Ok::<(), &str>(())
        });

        // Should not propagate the panic.
        let result = group.wait().await;
        assert_eq!(result, Ok(()));
    }
}