1use std::collections::HashMap;
45use std::{
46 future::Future,
47 hash::Hash,
48 sync::{
49 Arc, Mutex,
50 atomic::{AtomicUsize, Ordering},
51 },
52};
53
54use tokio::sync::Notify;
55
/// Outcome of [`Group::call_async`].
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Call<V> {
    /// Value produced by the call.
    pub val: V,
    /// `true` when `val` came from another caller's in-flight call rather
    /// than from running this caller's own future.
    pub shared: bool,
}
66
67struct Entry<K, V> {
68 key: K,
69 val: Mutex<Option<V>>,
70 notify: Notify,
71 dups: AtomicUsize,
72}
73
74pub struct Group<K, V> {
85 inner: Arc<Mutex<HashMap<K, Arc<Entry<K, V>>>>>,
86}
87
88impl<K, V> Group<K, V> {
91 pub fn new() -> Self {
93 Group {
94 inner: Arc::new(Mutex::new(HashMap::new())),
95 }
96 }
97}
98
99impl<K, V> Default for Group<K, V> {
100 fn default() -> Self {
101 Self::new()
102 }
103}
104
105impl<K, V> Clone for Group<K, V> {
108 fn clone(&self) -> Self {
109 Group {
110 inner: self.inner.clone(),
111 }
112 }
113}
114
115impl<K, V> Group<K, V>
118where
119 K: Eq + Hash,
120{
121 pub fn forget(&self, key: &K) {
127 let _ = self.inner.lock().unwrap().remove(key);
128 }
129}
130
131impl<K, V> Group<K, V>
134where
135 K: Eq + Hash + Clone + Send + Sync,
136 V: Clone + Send,
137{
138 pub async fn call_async<F, Fut>(&self, key: K, f: F) -> Call<V>
151 where
152 F: FnOnce() -> Fut + Send,
153 Fut: Future<Output = V> + Send,
154 {
155 let existing = {
157 let map = self.inner.lock().unwrap();
158 map.get(&key).map(|e| e.clone())
159 };
160
161 if let Some(entry) = existing {
162 entry.dups.fetch_add(1, Ordering::Relaxed);
163 loop {
164 let notified = entry.notify.notified();
165 if let Some(val) = entry.val.lock().unwrap().clone() {
166 return Call {
167 val,
168 shared: true,
169 };
170 }
171 notified.await;
172 }
173 }
174
175 let entry = Arc::new(Entry {
178 key, val: Mutex::new(None),
180 notify: Notify::new(),
181 dups: AtomicUsize::new(0),
182 });
183
184 {
186 let mut map = self.inner.lock().unwrap();
187 map.insert(entry.key.clone(), entry.clone());
188 }
189
190 let result = f().await;
192
193 *entry.val.lock().unwrap() = Some(result.clone());
195 entry.notify.notify_waiters();
196
197 {
199 let mut map = self.inner.lock().unwrap();
200 map.remove(&entry.key);
201 }
202
203 Call {
204 val: result,
205 shared: false,
206 }
207 }
208}
209
#[cfg(test)]
mod test {
    use super::*;

    /// A lone caller runs its own future and is reported as not shared.
    #[tokio::test]
    async fn single_caller_is_not_shared() {
        let group = Group::<&str, u32>::new();
        let call = group.call_async("a", || async { 42 }).await;
        assert_eq!(call.val, 42);
        assert!(!call.shared);
    }

    /// Ten callers released together on one key: exactly one runs the future
    /// and the other nine receive its value with `shared == true`.
    #[tokio::test]
    async fn duplicate_callers_share_result() {
        let group = Group::<String, String>::new();

        let barrier = Arc::new(tokio::sync::Barrier::new(10));
        let started = Arc::new(AtomicUsize::new(0));

        let mut handles = Vec::new();
        for _ in 0..10 {
            let group = group.clone();
            let barrier = barrier.clone();
            let started = started.clone();
            handles.push(tokio::spawn(async move {
                barrier.wait().await;
                group
                    .call_async("key".to_string(), || async {
                        started.fetch_add(1, Ordering::SeqCst);
                        // Keep the call in flight long enough for every other
                        // caller to join it.
                        tokio::time::sleep(std::time::Duration::from_millis(50)).await;
                        "hello".to_string()
                    })
                    .await
            }));
        }

        let mut originals = 0;
        let mut duplicates = 0;
        for h in handles {
            let call = h.await.unwrap();
            assert_eq!(call.val, "hello");
            if call.shared {
                duplicates += 1;
            } else {
                originals += 1;
            }
        }

        assert_eq!(originals, 1);
        assert_eq!(duplicates, 9);
        assert_eq!(started.load(Ordering::SeqCst), 1);
    }

    /// Concurrent calls on distinct keys each run their own future.
    #[tokio::test]
    async fn different_keys_dont_interfere() {
        let group = Group::<String, u32>::new();

        let (tx, mut rx) = tokio::sync::mpsc::channel(8);

        for k in ["a", "b"] {
            let group = group.clone();
            let tx = tx.clone();
            tokio::spawn(async move {
                let call = group
                    .call_async(k.to_string(), || async {
                        tokio::time::sleep(std::time::Duration::from_millis(50)).await;
                        1
                    })
                    .await;
                tx.send((k.to_string(), call.shared)).await.unwrap();
            });
        }
        // Drop the original sender so `recv` yields `None` once both tasks are done.
        drop(tx);

        let mut results: HashMap<String, Vec<bool>> = HashMap::new();
        while let Some((k, shared)) = rx.recv().await {
            results.entry(k).or_default().push(shared);
        }

        assert_eq!(results.len(), 2);
        for shareds in results.values() {
            assert_eq!(shareds.len(), 1);
            assert!(!shareds[0]);
        }
    }

    /// Sequential calls with a `forget` between them each run their own
    /// future. The first call has already completed when `forget` runs; the
    /// mid-flight case is covered by
    /// `forget_during_flight_lets_new_callers_start_fresh`.
    #[tokio::test]
    async fn forget_stops_dedup() {
        let group = Group::<&str, u32>::new();

        let first = group
            .call_async("k", || async {
                tokio::time::sleep(std::time::Duration::from_millis(50)).await;
                1
            })
            .await;
        assert_eq!(first.val, 1);
        assert!(!first.shared);

        group.forget(&"k");

        let second = group
            .call_async("k", || async {
                tokio::time::sleep(std::time::Duration::from_millis(10)).await;
                2
            })
            .await;
        assert_eq!(second.val, 2);
        assert!(!second.shared);
    }

    /// `forget` on a key that was never called does not panic and does not
    /// affect later calls.
    #[tokio::test]
    async fn forget_unknown_key_is_noop() {
        let group = Group::<&str, u32>::new();
        group.forget(&"nonexistent");

        let call = group.call_async("k", || async { 7 }).await;
        assert_eq!(call.val, 7);
        assert!(!call.shared);
    }

    /// Once a call has returned, its key is no longer in flight: the very next
    /// call for the same key runs its own future, with no delay in between.
    #[tokio::test]
    async fn call_after_completion_starts_fresh() {
        let group = Group::<&str, u32>::new();

        let counter = Arc::new(AtomicUsize::new(0));

        let c = counter.clone();
        let first = group
            .call_async("k", || async {
                c.fetch_add(1, Ordering::SeqCst);
                10
            })
            .await;
        assert_eq!(first.val, 10);
        assert!(!first.shared);

        let c = counter.clone();
        let second = group
            .call_async("k", || async {
                c.fetch_add(1, Ordering::SeqCst);
                20
            })
            .await;
        assert_eq!(second.val, 20);
        assert!(!second.shared);

        assert_eq!(counter.load(Ordering::SeqCst), 2);
    }

    /// 200 callers spread over 10 keys: at most ten futures run in total, and
    /// every caller receives a value produced for its own key.
    #[tokio::test]
    async fn high_concurrency_stress() {
        let group = Group::<String, usize>::new();
        let n = 200_usize;
        let started = Arc::new(AtomicUsize::new(0));

        let mut handles = Vec::with_capacity(n);
        for i in 0..n {
            let group = group.clone();
            let started = started.clone();
            let handle = tokio::spawn(async move {
                let key = (i % 10).to_string();
                group
                    .call_async(key, || async {
                        started.fetch_add(1, Ordering::SeqCst);
                        tokio::time::sleep(std::time::Duration::from_micros(100)).await;
                        i
                    })
                    .await
            });
            handles.push((i, handle));
        }

        for (i, handle) in handles {
            let call = handle.await.unwrap();
            // The value comes from whichever caller ran this key's future, and
            // all callers of one key share the same `i % 10`.
            assert_eq!(call.val % 10, i % 10);
        }

        assert!(started.load(Ordering::SeqCst) <= 10);
    }

    /// `forget` while a call is in flight: a new caller for the same key runs
    /// its own future instead of joining, and the original call still
    /// completes with its own value.
    #[tokio::test]
    async fn forget_during_flight_lets_new_callers_start_fresh() {
        let group = Group::<&str, u32>::new();
        let run_count = Arc::new(AtomicUsize::new(0));

        let (tx, rx) = tokio::sync::oneshot::channel::<()>();
        let r1 = run_count.clone();

        let group_for_spawn = group.clone();
        let h1 = tokio::spawn(async move {
            group_for_spawn
                .call_async("k", || async {
                    r1.fetch_add(1, Ordering::SeqCst);
                    // Signal that the first call is registered and running.
                    let _ = tx.send(());
                    tokio::time::sleep(std::time::Duration::from_millis(100)).await;
                    42
                })
                .await
        });

        rx.await.unwrap();
        tokio::time::sleep(std::time::Duration::from_millis(5)).await;
        assert_eq!(run_count.load(Ordering::SeqCst), 1);

        group.forget(&"k");

        let r2 = run_count.clone();
        let group_for_spawn2 = group.clone();
        let h2 = tokio::spawn(async move {
            group_for_spawn2
                .call_async("k", || async {
                    r2.fetch_add(1, Ordering::SeqCst);
                    tokio::time::sleep(std::time::Duration::from_millis(10)).await;
                    99
                })
                .await
        });

        let c1 = h1.await.unwrap();
        let c2 = h2.await.unwrap();

        // Both calls ran their own future; neither received the other's value.
        assert_eq!(run_count.load(Ordering::SeqCst), 2);

        assert_eq!(c1.val, 42);
        assert!(!c1.shared);

        assert_eq!(c2.val, 99);
        assert!(!c2.shared);
    }
}