turbo_tasks/
parallel.rs

//! Parallel for-each and map using tokio tasks.
//!
//! This avoids the problem of sleeping threads with mimalloc when rayon is used in combination
//! with tokio. It also avoids having multiple thread pools.
//! See also <https://pwy.io/posts/mimalloc-cigarette/>.
6
7use crate::{
8    scope::scope_and_block,
9    util::{Chunk, good_chunk_size, into_chunks},
10};
11
/// Describes how an input of a given length is split up for parallel processing.
struct Chunked {
    /// Maximum number of items handled by a single spawned task.
    chunk_size: usize,
    /// Number of chunks (one task each) needed to cover the whole input.
    chunk_count: usize,
}
16
17fn get_chunked(len: usize) -> Option<Chunked> {
18    if len <= 1 {
19        return None;
20    }
21    let chunk_size = good_chunk_size(len);
22    let chunk_count = len.div_ceil(chunk_size);
23    if chunk_count <= 1 {
24        return None;
25    }
26    Some(Chunked {
27        chunk_size,
28        chunk_count,
29    })
30}
31
32pub fn for_each<'l, T, F>(items: &'l [T], f: F)
33where
34    T: Sync,
35    F: Fn(&'l T) + Send + Sync,
36{
37    let Some(Chunked {
38        chunk_size,
39        chunk_count,
40    }) = get_chunked(items.len())
41    else {
42        for item in items {
43            f(item);
44        }
45        return;
46    };
47    let f = &f;
48    let _results = scope_and_block(chunk_count, |scope| {
49        for chunk in items.chunks(chunk_size) {
50            scope.spawn(move || {
51                for item in chunk {
52                    f(item);
53                }
54            })
55        }
56    });
57}
58
59pub fn for_each_owned<T>(items: Vec<T>, f: impl Fn(T) + Send + Sync)
60where
61    T: Send + Sync,
62{
63    let Some(Chunked {
64        chunk_size,
65        chunk_count,
66    }) = get_chunked(items.len())
67    else {
68        for item in items {
69            f(item);
70        }
71        return;
72    };
73    let f = &f;
74    let _results = scope_and_block(chunk_count, |scope| {
75        for chunk in into_chunks(items, chunk_size) {
76            scope.spawn(move || {
77                // SAFETY: Even when f() panics we drop all items in the chunk.
78                for item in chunk {
79                    f(item);
80                }
81            })
82        }
83    });
84}
85
86pub fn try_for_each<'l, T, E>(
87    items: &'l [T],
88    f: impl (Fn(&'l T) -> Result<(), E>) + Send + Sync,
89) -> Result<(), E>
90where
91    T: Sync,
92    E: Send + 'static,
93{
94    let Some(Chunked {
95        chunk_size,
96        chunk_count,
97    }) = get_chunked(items.len())
98    else {
99        for item in items {
100            f(item)?;
101        }
102        return Ok(());
103    };
104    let f = &f;
105    scope_and_block(chunk_count, |scope| {
106        for chunk in items.chunks(chunk_size) {
107            scope.spawn(move || {
108                for item in chunk {
109                    f(item)?;
110                }
111                Ok(())
112            })
113        }
114    })
115    .collect::<Result<(), E>>()
116}
117
118pub fn try_for_each_mut<'l, T, E>(
119    items: &'l mut [T],
120    f: impl (Fn(&'l mut T) -> Result<(), E>) + Send + Sync,
121) -> Result<(), E>
122where
123    T: Send + Sync,
124    E: Send + 'static,
125{
126    let Some(Chunked {
127        chunk_size,
128        chunk_count,
129    }) = get_chunked(items.len())
130    else {
131        for item in items {
132            f(item)?;
133        }
134        return Ok(());
135    };
136    let f = &f;
137    scope_and_block(chunk_count, |scope| {
138        for chunk in items.chunks_mut(chunk_size) {
139            scope.spawn(move || {
140                for item in chunk {
141                    f(item)?;
142                }
143                Ok(())
144            })
145        }
146    })
147    .collect::<Result<(), E>>()
148}
149
150pub fn try_for_each_owned<T, E>(
151    items: Vec<T>,
152    f: impl (Fn(T) -> Result<(), E>) + Send + Sync,
153) -> Result<(), E>
154where
155    T: Send + Sync,
156    E: Send + 'static,
157{
158    let Some(Chunked {
159        chunk_size,
160        chunk_count,
161    }) = get_chunked(items.len())
162    else {
163        for item in items {
164            f(item)?;
165        }
166        return Ok(());
167    };
168    let f = &f;
169    scope_and_block(chunk_count, |scope| {
170        for chunk in into_chunks(items, chunk_size) {
171            scope.spawn(move || {
172                for item in chunk {
173                    f(item)?;
174                }
175                Ok(())
176            })
177        }
178    })
179    .collect::<Result<(), E>>()
180}
181
182pub fn map_collect<'l, Item, PerItemResult, Result>(
183    items: &'l [Item],
184    f: impl Fn(&'l Item) -> PerItemResult + Send + Sync,
185) -> Result
186where
187    Item: Sync,
188    PerItemResult: Send + Sync + 'l,
189    Result: FromIterator<PerItemResult>,
190{
191    let Some(Chunked {
192        chunk_size,
193        chunk_count,
194    }) = get_chunked(items.len())
195    else {
196        return Result::from_iter(items.iter().map(f));
197    };
198    let f = &f;
199    scope_and_block(chunk_count, |scope| {
200        for chunk in items.chunks(chunk_size) {
201            scope.spawn(move || chunk.iter().map(f).collect::<Vec<_>>())
202        }
203    })
204    .flatten()
205    .collect()
206}
207
208pub fn map_collect_owned<'l, Item, PerItemResult, Result>(
209    items: Vec<Item>,
210    f: impl Fn(Item) -> PerItemResult + Send + Sync,
211) -> Result
212where
213    Item: Send + Sync,
214    PerItemResult: Send + Sync + 'l,
215    Result: FromIterator<PerItemResult>,
216{
217    let Some(Chunked {
218        chunk_size,
219        chunk_count,
220    }) = get_chunked(items.len())
221    else {
222        return Result::from_iter(items.into_iter().map(f));
223    };
224    let f = &f;
225    scope_and_block(chunk_count, |scope| {
226        for chunk in into_chunks(items, chunk_size) {
227            scope.spawn(move || chunk.map(f).collect::<Vec<_>>())
228        }
229    })
230    .flatten()
231    .collect()
232}
233
234pub fn map_collect_chunked_owned<'l, Item, PerItemResult, Result>(
235    items: Vec<Item>,
236    f: impl Fn(Chunk<Item>) -> PerItemResult + Send + Sync,
237) -> Result
238where
239    Item: Send + Sync,
240    PerItemResult: Send + Sync + 'l,
241    Result: FromIterator<PerItemResult>,
242{
243    let Some(Chunked {
244        chunk_size,
245        chunk_count,
246    }) = get_chunked(items.len())
247    else {
248        let len = items.len();
249        return Result::from_iter(into_chunks(items, len).map(f));
250    };
251    let f = &f;
252    scope_and_block(chunk_count, |scope| {
253        for chunk in into_chunks(items, chunk_size) {
254            scope.spawn(move || f(chunk))
255        }
256    })
257    .collect()
258}
259
#[cfg(test)]
mod tests {
    use std::{
        panic::{AssertUnwindSafe, catch_unwind},
        sync::atomic::{AtomicI32, Ordering},
    };

    use super::*;

    /// Summing through `for_each` must visit each element exactly once.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_parallel_for_each() {
        let numbers: Vec<i32> = (1..=5).collect();
        let total = AtomicI32::new(0);
        for_each(&numbers, |&n| {
            total.fetch_add(n, Ordering::SeqCst);
        });
        assert_eq!(total.load(Ordering::SeqCst), 15);
    }

    /// The error for the earliest failing element is the one reported.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_parallel_try_for_each() {
        let numbers: Vec<i32> = (1..=5).collect();
        let outcome = try_for_each(&numbers, |&n| match n % 2 {
            0 => Ok(()),
            _ => Err(format!("Odd number {n} encountered")),
        });
        assert_eq!(outcome, Err(String::from("Odd number 1 encountered")));
    }

    /// Mutations land in place, and the first error is reported.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_parallel_try_for_each_mut() {
        let mut numbers: Vec<i32> = (1..=5).collect();
        let outcome = try_for_each_mut(&mut numbers, |n| {
            *n += 10;
            if *n % 2 != 0 {
                return Err(format!("Odd number {} encountered", *n));
            }
            Ok(())
        });
        assert_eq!(outcome, Err(String::from("Odd number 11 encountered")));
        assert_eq!(numbers, [11, 12, 13, 14, 15]);
    }

    /// Summing through `for_each_owned` must consume each element exactly once.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_parallel_for_each_owned() {
        let numbers: Vec<i32> = (1..=5).collect();
        let total = AtomicI32::new(0);
        for_each_owned(numbers, |n| {
            total.fetch_add(n, Ordering::SeqCst);
        });
        assert_eq!(total.load(Ordering::SeqCst), 15);
    }

    /// Results come back in input order.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_parallel_map_collect() {
        let numbers: Vec<i32> = (1..=5).collect();
        let doubled: Vec<i32> = map_collect(&numbers, |&n| n * 2);
        assert_eq!(doubled, [2, 4, 6, 8, 10]);
    }

    /// Owned results come back in input order.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_parallel_map_collect_owned() {
        let numbers: Vec<i32> = (1..=5).collect();
        let doubled: Vec<i32> = map_collect_owned(numbers, |n| n * 2);
        assert_eq!(doubled, [2, 4, 6, 8, 10]);
    }

    /// A larger input that is guaranteed to be split into several chunks.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_parallel_map_collect_owned_many() {
        let ones = vec![1; 1000];
        let doubled: Vec<i32> = map_collect_owned(ones, |n| n * 2);
        assert_eq!(doubled, vec![2; 1000]);
    }

    /// A panic inside one task must surface to the caller with its original payload.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_panic_in_scope() {
        let outcome = catch_unwind(AssertUnwindSafe(|| {
            let mut numbers = vec![1; 1000];
            numbers[744] = 2;
            for_each(&numbers, |&n| {
                if n == 2 {
                    panic!("Intentional panic");
                }
            });
            panic!("Should not get here")
        }));
        assert!(outcome.is_err());
        let payload = outcome.unwrap_err();
        assert_eq!(payload.downcast_ref::<&str>(), Some(&"Intentional panic"));
    }
}