Skip to main content

singleflight/
singleflight.rs

//! Deduplicate concurrent async function calls by key.
//!
//! A [`Group`] ensures that only one execution is in flight for a given key
//! at a time. Duplicate callers wait for the original to finish and receive
//! a clone of the same result. This is useful for cache-stampede prevention,
//! expensive lookups, or any scenario where concurrent identical work should
//! be coalesced into a single operation.
//!
//! # Example
//!
//! ```rust
//! use singleflight::Group;
//!
//! # async fn example() {
//! let group = Group::<String, String>::new();
//!
//! let (tx, mut rx) = tokio::sync::mpsc::channel(16);
//!
//! for _ in 0..10 {
//!     let group = group.clone();
//!     let tx = tx.clone();
//!     tokio::spawn(async move {
//!         let call = group
//!             .call_async("key".to_string(), || async {
//!                 tokio::time::sleep(std::time::Duration::from_millis(10)).await;
//!                 "computed".to_string()
//!             })
//!             .await;
//!         tx.send(call.shared).await.unwrap();
//!     });
//! }
//! drop(tx);
//!
//! // When all ten calls overlap, exactly one caller runs the function
//! // (shared = false) and the other nine receive its result (shared = true).
//! let mut originals = 0;
//! let mut duplicates = 0;
//! while let Some(shared) = rx.recv().await {
//!     if shared { duplicates += 1; } else { originals += 1; }
//! }
//! assert_eq!(originals, 1);
//! assert_eq!(duplicates, 9);
//! # }
//! ```
44use std::collections::HashMap;
45use std::{
46    future::Future,
47    hash::Hash,
48    sync::{
49        Arc, Mutex,
50        atomic::{AtomicUsize, Ordering},
51    },
52};
53
54use tokio::sync::Notify;
55
/// Outcome of a single [`Group::call_async`] invocation.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Call<V> {
    /// The value produced by the function that actually executed.
    pub val: V,
    /// `false` for the caller that executed its own function; `true` for a
    /// caller that received a value computed on behalf of another caller.
    pub shared: bool,
}
66
67struct Entry<K, V> {
68    key: K,
69    val: Mutex<Option<V>>,
70    notify: Notify,
71    dups: AtomicUsize,
72}
73
74/// A `Group` represents a namespace in which work can be deduplicated by key.
75///
76/// Only one caller per key executes the function; concurrent callers for the
77/// same key receive the same result once it is ready.
78///
79/// # Panics
80///
81/// The function passed to [`call_async`][Group::call_async] **must not panic**.
82/// If it panics, all waiting callers will hang forever because the value will
83/// never be stored and they will never be notified.
84pub struct Group<K, V> {
85    inner: Arc<Mutex<HashMap<K, Arc<Entry<K, V>>>>>,
86}
87
88// --- inherent impls independent of Eq/Hash ----------------------------------
89
90impl<K, V> Group<K, V> {
91    /// Creates a new empty `Group`.
92    pub fn new() -> Self {
93        Group {
94            inner: Arc::new(Mutex::new(HashMap::new())),
95        }
96    }
97}
98
99impl<K, V> Default for Group<K, V> {
100    fn default() -> Self {
101        Self::new()
102    }
103}
104
105// --- Clone ------------------------------------------------------------------
106
107impl<K, V> Clone for Group<K, V> {
108    fn clone(&self) -> Self {
109        Group {
110            inner: self.inner.clone(),
111        }
112    }
113}
114
115// --- API requiring Eq + Hash on K -------------------------------------------
116
117impl<K, V> Group<K, V>
118where
119    K: Eq + Hash,
120{
121    /// Tells the `Group` to forget about a key. Future calls to [`call_async`]
122    /// for this key will execute the function rather than waiting for an
123    /// earlier call to complete.
124    ///
125    /// If no call is in-flight for the given key, this is a no-op.
126    pub fn forget(&self, key: &K) {
127        let _ = self.inner.lock().unwrap().remove(key);
128    }
129}
130
131// --- call_async -------------------------------------------------------------
132
133impl<K, V> Group<K, V>
134where
135    K: Eq + Hash + Clone + Send + Sync,
136    V: Clone + Send,
137{
138    /// Executes `f` for the given `key`, coalescing duplicate callers.
139    ///
140    /// If no other call is in-flight for `key`, this call executes `f` and
141    /// returns its result with `shared = false`.
142    ///
143    /// If another call for the same key is already in-flight, this call
144    /// awaits that call's completion and returns the same result with
145    /// `shared = true`.
146    ///
147    /// # Panics
148    ///
149    /// `f` **must not panic**. See the [type-level docs][Group#panics].
150    pub async fn call_async<F, Fut>(&self, key: K, f: F) -> Call<V>
151    where
152        F: FnOnce() -> Fut + Send,
153        Fut: Future<Output = V> + Send,
154    {
155        // Scoped map access so the MutexGuard is dropped before any await.
156        let existing = {
157            let map = self.inner.lock().unwrap();
158            map.get(&key).map(|e| e.clone())
159        };
160
161        if let Some(entry) = existing {
162            entry.dups.fetch_add(1, Ordering::Relaxed);
163            loop {
164                let notified = entry.notify.notified();
165                if let Some(val) = entry.val.lock().unwrap().clone() {
166                    return Call {
167                        val,
168                        shared: true,
169                    };
170                }
171                notified.await;
172            }
173        }
174
175        // First caller – store key inside the entry so we don't hold `key`
176        // across the await point (avoids requiring `'static` on K).
177        let entry = Arc::new(Entry {
178            key, // moved into the entry
179            val: Mutex::new(None),
180            notify: Notify::new(),
181            dups: AtomicUsize::new(0),
182        });
183
184        // Insert into the map (scoped).
185        {
186            let mut map = self.inner.lock().unwrap();
187            map.insert(entry.key.clone(), entry.clone());
188        }
189
190        // **Must not panic.** If this panics, waiters block forever.
191        let result = f().await;
192
193        // Store the result and wake every waiter.
194        *entry.val.lock().unwrap() = Some(result.clone());
195        entry.notify.notify_waiters();
196
197        // Remove the entry from the map (scoped).
198        {
199            let mut map = self.inner.lock().unwrap();
200            map.remove(&entry.key);
201        }
202
203        Call {
204            val: result,
205            shared: false,
206        }
207    }
208}
209
210// ----------------------------------------------------------------------------
211// Tests
212// ----------------------------------------------------------------------------
213
#[cfg(test)]
mod test {
    use super::*;

    #[tokio::test]
    async fn single_caller_is_not_shared() {
        let group = Group::<&str, u32>::new();
        let call = group.call_async("a", || async { 42 }).await;
        assert_eq!(call.val, 42);
        assert!(!call.shared);
    }

    #[tokio::test]
    async fn duplicate_callers_share_result() {
        let group = Group::<String, String>::new();

        // The barrier releases all ten tasks together so their calls overlap
        // the leader's 50 ms of work.
        let barrier = Arc::new(tokio::sync::Barrier::new(10));
        let started = Arc::new(AtomicUsize::new(0));

        let mut handles = Vec::with_capacity(10);
        for _ in 0..10 {
            let group = group.clone();
            let barrier = barrier.clone();
            let started = started.clone();
            handles.push(tokio::spawn(async move {
                barrier.wait().await;
                group
                    .call_async("key".to_string(), || async {
                        started.fetch_add(1, Ordering::SeqCst);
                        tokio::time::sleep(std::time::Duration::from_millis(50)).await;
                        "hello".to_string()
                    })
                    .await
            }));
        }

        let mut originals = 0;
        let mut duplicates = 0;
        for h in handles {
            let call = h.await.unwrap();
            assert_eq!(call.val, "hello");
            if call.shared {
                duplicates += 1;
            } else {
                originals += 1;
            }
        }

        assert_eq!(originals, 1);
        assert_eq!(duplicates, 9);
        // The function should have only executed once.
        assert_eq!(started.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn different_keys_dont_interfere() {
        let group = Group::<String, u32>::new();

        let (tx, mut rx) = tokio::sync::mpsc::channel(8);

        for k in ["a", "b"] {
            let group = group.clone();
            let tx = tx.clone();
            tokio::spawn(async move {
                let call = group
                    .call_async(k.to_string(), || async {
                        tokio::time::sleep(std::time::Duration::from_millis(50)).await;
                        1
                    })
                    .await;
                tx.send((k.to_string(), call.shared)).await.unwrap();
            });
        }
        drop(tx);

        let mut results: HashMap<String, Vec<bool>> = HashMap::new();
        while let Some((k, shared)) = rx.recv().await {
            results.entry(k).or_default().push(shared);
        }

        assert_eq!(results.len(), 2);
        for shareds in results.values() {
            assert_eq!(shareds.len(), 1);
            assert!(!shareds[0]); // each was the original for its key
        }
    }

    /// The first call has already completed (and unregistered its key) when
    /// `forget` runs, so this checks that `forget` leaves the group usable.
    /// Forgetting an in-flight call is covered by
    /// `forget_during_flight_lets_new_callers_start_fresh`.
    #[tokio::test]
    async fn forget_stops_dedup() {
        let group = Group::<&str, u32>::new();

        let first = group
            .call_async("k", || async {
                tokio::time::sleep(std::time::Duration::from_millis(50)).await;
                1
            })
            .await;
        assert_eq!(first.val, 1);
        assert!(!first.shared);

        group.forget(&"k");

        let second = group
            .call_async("k", || async {
                tokio::time::sleep(std::time::Duration::from_millis(10)).await;
                2
            })
            .await;
        assert_eq!(second.val, 2);
        assert!(!second.shared);
    }

    #[tokio::test]
    async fn forget_unknown_key_is_noop() {
        let group = Group::<&str, u32>::new();
        group.forget(&"nonexistent");

        let call = group.call_async("k", || async { 7 }).await;
        assert_eq!(call.val, 7);
    }

    #[tokio::test]
    async fn call_after_completion_starts_fresh() {
        let group = Group::<&str, u32>::new();

        let counter = Arc::new(AtomicUsize::new(0));

        let c = counter.clone();
        let first = group
            .call_async("k", || async {
                c.fetch_add(1, Ordering::SeqCst);
                10
            })
            .await;
        assert_eq!(first.val, 10);

        // `call_async` unregisters the key before returning, so no delay is
        // needed before the next call starts fresh.
        let c = counter.clone();
        let second = group
            .call_async("k", || async {
                c.fetch_add(1, Ordering::SeqCst);
                20
            })
            .await;
        assert_eq!(second.val, 20);
        assert!(!second.shared);

        assert_eq!(counter.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn high_concurrency_stress() {
        let group = Group::<String, usize>::new();
        let n = 200_usize;
        let started = Arc::new(AtomicUsize::new(0));

        let mut handles = Vec::with_capacity(n);
        for i in 0..n {
            let group = group.clone();
            let started = started.clone();
            handles.push(tokio::spawn(async move {
                let key = (i % 10).to_string();
                let call = group
                    .call_async(key, || async {
                        started.fetch_add(1, Ordering::SeqCst);
                        tokio::time::sleep(std::time::Duration::from_micros(100)).await;
                        i
                    })
                    .await;
                (i, call)
            }));
        }

        for h in handles {
            let (i, call) = h.await.unwrap();
            // `val` is the `i` of whichever caller led this key's execution,
            // which shares this caller's key and therefore its residue mod 10.
            assert_eq!(call.val % 10, i % 10);
        }

        // Every key runs at least once. A key can run more than once when one
        // execution finishes before later callers for that key arrive, so the
        // only upper bound is one execution per caller.
        let runs = started.load(Ordering::SeqCst);
        assert!((10..=n).contains(&runs), "unexpected execution count {runs}");
    }

    /// Ensures that forgetting a key while a call is in-flight doesn't
    /// hang the original caller (they hold an `Arc`), and a subsequent
    /// call for the same key starts a fresh execution.
    #[tokio::test]
    async fn forget_during_flight_lets_new_callers_start_fresh() {
        let group = Group::<&str, u32>::new();
        let run_count = Arc::new(AtomicUsize::new(0));

        let (tx, rx) = tokio::sync::oneshot::channel::<()>();
        let r1 = run_count.clone();

        let group_for_spawn = group.clone();
        let h1 = tokio::spawn(async move {
            group_for_spawn
                .call_async("k", || async {
                    r1.fetch_add(1, Ordering::SeqCst);
                    let _ = tx.send(());
                    tokio::time::sleep(std::time::Duration::from_millis(100)).await;
                    42
                })
                .await
        });

        // The counter is bumped before the signal is sent, so once `rx`
        // resolves the first caller has started but not finished.
        rx.await.unwrap();
        assert_eq!(run_count.load(Ordering::SeqCst), 1);

        // Forget the key while the call is in-flight.
        group.forget(&"k");

        // Start a second caller for the same key after forget.
        // It should execute its own function rather than waiting.
        let r2 = run_count.clone();
        let group_for_spawn2 = group.clone();
        let h2 = tokio::spawn(async move {
            group_for_spawn2
                .call_async("k", || async {
                    r2.fetch_add(1, Ordering::SeqCst);
                    tokio::time::sleep(std::time::Duration::from_millis(10)).await;
                    99
                })
                .await
        });

        let c1 = h1.await.unwrap();
        let c2 = h2.await.unwrap();

        // Both functions executed independently.
        assert_eq!(run_count.load(Ordering::SeqCst), 2);

        assert_eq!(c1.val, 42);
        assert!(!c1.shared);

        assert_eq!(c2.val, 99);
        assert!(!c2.shared);
    }
}