Skip to main content

turbo_tasks/
parallel.rs

1//! Parallel for each and map using tokio tasks.
2//!
//! This avoids the problem of sleeping threads with mimalloc when using rayon in combination with
//! tokio. It also avoids having multiple thread pools.
5//!
6//! See also: <https://pwy.io/posts/mimalloc-cigarette/>
7
8use crate::{
9    scope::scope_and_block,
10    util::{Chunk, good_chunk_size, into_chunks},
11};
12
/// Chunking plan computed by `get_chunked`: how a sequence is cut up for parallel processing.
struct Chunked {
    /// Number of items handed to each spawned task; the final chunk may hold fewer.
    chunk_size: usize,
    /// Total number of chunks, i.e. how many tasks get spawned.
    chunk_count: usize,
}
17
18fn get_chunked(len: usize) -> Option<Chunked> {
19    if len <= 1 {
20        return None;
21    }
22    let chunk_size = good_chunk_size(len);
23    let chunk_count = len.div_ceil(chunk_size);
24    if chunk_count <= 1 {
25        return None;
26    }
27    Some(Chunked {
28        chunk_size,
29        chunk_count,
30    })
31}
32
33pub fn for_each<'l, T, F>(items: &'l [T], f: F)
34where
35    T: Sync,
36    F: Fn(&'l T) + Send + Sync,
37{
38    let Some(Chunked {
39        chunk_size,
40        chunk_count,
41    }) = get_chunked(items.len())
42    else {
43        for item in items {
44            f(item);
45        }
46        return;
47    };
48    let f = &f;
49    let _results = scope_and_block(chunk_count, |scope| {
50        for chunk in items.chunks(chunk_size) {
51            scope.spawn(move || {
52                for item in chunk {
53                    f(item);
54                }
55            })
56        }
57    });
58}
59
60pub fn for_each_owned<T>(items: Vec<T>, f: impl Fn(T) + Send + Sync)
61where
62    T: Send + Sync,
63{
64    let Some(Chunked {
65        chunk_size,
66        chunk_count,
67    }) = get_chunked(items.len())
68    else {
69        for item in items {
70            f(item);
71        }
72        return;
73    };
74    let f = &f;
75    let _results = scope_and_block(chunk_count, |scope| {
76        for chunk in into_chunks(items, chunk_size) {
77            scope.spawn(move || {
78                // SAFETY: Even when f() panics we drop all items in the chunk.
79                for item in chunk {
80                    f(item);
81                }
82            })
83        }
84    });
85}
86
87pub fn try_for_each<'l, T, E>(
88    items: &'l [T],
89    f: impl (Fn(&'l T) -> Result<(), E>) + Send + Sync,
90) -> Result<(), E>
91where
92    T: Sync,
93    E: Send + 'static,
94{
95    let Some(Chunked {
96        chunk_size,
97        chunk_count,
98    }) = get_chunked(items.len())
99    else {
100        for item in items {
101            f(item)?;
102        }
103        return Ok(());
104    };
105    let f = &f;
106    scope_and_block(chunk_count, |scope| {
107        for chunk in items.chunks(chunk_size) {
108            scope.spawn(move || {
109                for item in chunk {
110                    f(item)?;
111                }
112                Ok(())
113            })
114        }
115    })
116    .collect::<Result<(), E>>()
117}
118
119pub fn try_for_each_mut<'l, T, E>(
120    items: &'l mut [T],
121    f: impl (Fn(&'l mut T) -> Result<(), E>) + Send + Sync,
122) -> Result<(), E>
123where
124    T: Send + Sync,
125    E: Send + 'static,
126{
127    let Some(Chunked {
128        chunk_size,
129        chunk_count,
130    }) = get_chunked(items.len())
131    else {
132        for item in items {
133            f(item)?;
134        }
135        return Ok(());
136    };
137    let f = &f;
138    scope_and_block(chunk_count, |scope| {
139        for chunk in items.chunks_mut(chunk_size) {
140            scope.spawn(move || {
141                for item in chunk {
142                    f(item)?;
143                }
144                Ok(())
145            })
146        }
147    })
148    .collect::<Result<(), E>>()
149}
150
151pub fn try_for_each_owned<T, E>(
152    items: Vec<T>,
153    f: impl (Fn(T) -> Result<(), E>) + Send + Sync,
154) -> Result<(), E>
155where
156    T: Send + Sync,
157    E: Send + 'static,
158{
159    let Some(Chunked {
160        chunk_size,
161        chunk_count,
162    }) = get_chunked(items.len())
163    else {
164        for item in items {
165            f(item)?;
166        }
167        return Ok(());
168    };
169    let f = &f;
170    scope_and_block(chunk_count, |scope| {
171        for chunk in into_chunks(items, chunk_size) {
172            scope.spawn(move || {
173                for item in chunk {
174                    f(item)?;
175                }
176                Ok(())
177            })
178        }
179    })
180    .collect::<Result<(), E>>()
181}
182
183pub fn map_collect<'l, Item, PerItemResult, Result>(
184    items: &'l [Item],
185    f: impl Fn(&'l Item) -> PerItemResult + Send + Sync,
186) -> Result
187where
188    Item: Sync,
189    PerItemResult: Send + Sync + 'l,
190    Result: FromIterator<PerItemResult>,
191{
192    let Some(Chunked {
193        chunk_size,
194        chunk_count,
195    }) = get_chunked(items.len())
196    else {
197        return Result::from_iter(items.iter().map(f));
198    };
199    let f = &f;
200    scope_and_block(chunk_count, |scope| {
201        for chunk in items.chunks(chunk_size) {
202            scope.spawn(move || chunk.iter().map(f).collect::<Vec<_>>())
203        }
204    })
205    .flatten()
206    .collect()
207}
208
209pub fn map_collect_owned<'l, Item, PerItemResult, Result>(
210    items: Vec<Item>,
211    f: impl Fn(Item) -> PerItemResult + Send + Sync,
212) -> Result
213where
214    Item: Send + Sync,
215    PerItemResult: Send + Sync + 'l,
216    Result: FromIterator<PerItemResult>,
217{
218    let Some(Chunked {
219        chunk_size,
220        chunk_count,
221    }) = get_chunked(items.len())
222    else {
223        return Result::from_iter(items.into_iter().map(f));
224    };
225    let f = &f;
226    scope_and_block(chunk_count, |scope| {
227        for chunk in into_chunks(items, chunk_size) {
228            scope.spawn(move || chunk.map(f).collect::<Vec<_>>())
229        }
230    })
231    .flatten()
232    .collect()
233}
234
235pub fn map_collect_chunked_owned<'l, Item, PerItemResult, Result>(
236    items: Vec<Item>,
237    f: impl Fn(Chunk<Item>) -> PerItemResult + Send + Sync,
238) -> Result
239where
240    Item: Send + Sync,
241    PerItemResult: Send + Sync + 'l,
242    Result: FromIterator<PerItemResult>,
243{
244    let Some(Chunked {
245        chunk_size,
246        chunk_count,
247    }) = get_chunked(items.len())
248    else {
249        let len = items.len();
250        return Result::from_iter(into_chunks(items, len).map(f));
251    };
252    let f = &f;
253    scope_and_block(chunk_count, |scope| {
254        for chunk in into_chunks(items, chunk_size) {
255            scope.spawn(move || f(chunk))
256        }
257    })
258    .collect()
259}
260
#[cfg(test)]
mod tests {
    use std::{
        panic::{AssertUnwindSafe, catch_unwind},
        sync::atomic::{AtomicI32, Ordering},
    };

    use super::*;

    /// Every element is visited exactly once.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_parallel_for_each() {
        let input = vec![1, 2, 3, 4, 5];
        let sum = AtomicI32::new(0);
        for_each(&input, |&x| {
            sum.fetch_add(x, Ordering::SeqCst);
        });
        assert_eq!(sum.load(Ordering::SeqCst), 15);
    }

    /// The first error in slice order is reported.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_parallel_try_for_each() {
        let input = vec![1, 2, 3, 4, 5];
        let result = try_for_each(&input, |&x| {
            if x % 2 == 0 {
                Ok(())
            } else {
                Err(format!("Odd number {x} encountered"))
            }
        });
        assert!(result.is_err());
        assert_eq!(result.unwrap_err(), "Odd number 1 encountered");
    }

    /// Mutations are visible to the caller, and the first error in slice order is reported.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_parallel_try_for_each_mut() {
        let mut input = vec![1, 2, 3, 4, 5];
        let result = try_for_each_mut(&mut input, |x| {
            *x += 10;
            if *x % 2 == 0 {
                Ok(())
            } else {
                Err(format!("Odd number {} encountered", *x))
            }
        });
        assert!(result.is_err());
        assert_eq!(result.unwrap_err(), "Odd number 11 encountered");
        assert_eq!(input, vec![11, 12, 13, 14, 15]);
    }

    /// Every owned element is passed to `f` exactly once.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_parallel_for_each_owned() {
        let input = vec![1, 2, 3, 4, 5];
        let sum = AtomicI32::new(0);
        for_each_owned(input, |x| {
            sum.fetch_add(x, Ordering::SeqCst);
        });
        assert_eq!(sum.load(Ordering::SeqCst), 15);
    }

    /// A fully successful owned run visits every element and returns `Ok(())`.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_parallel_try_for_each_owned() {
        let input = vec![1, 2, 3, 4, 5];
        let sum = AtomicI32::new(0);
        let result: Result<(), String> = try_for_each_owned(input, |x| {
            sum.fetch_add(x, Ordering::SeqCst);
            Ok(())
        });
        assert_eq!(result, Ok(()));
        assert_eq!(sum.load(Ordering::SeqCst), 15);
    }

    /// A single failing element in a large (chunked) input is reported to the caller.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_parallel_try_for_each_owned_error() {
        let mut input = vec![0; 1000];
        input[744] = 1;
        let result = try_for_each_owned(input, |x| {
            if x == 1 {
                Err(format!("Bad item {x}"))
            } else {
                Ok(())
            }
        });
        assert_eq!(result, Err("Bad item 1".to_string()));
    }

    /// Outputs come back in input order.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_parallel_map_collect() {
        let input = vec![1, 2, 3, 4, 5];
        let result: Vec<_> = map_collect(&input, |&x| x * 2);
        assert_eq!(result, vec![2, 4, 6, 8, 10]);
    }

    /// Owned outputs come back in input order.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_parallel_map_collect_owned() {
        let input = vec![1, 2, 3, 4, 5];
        let result: Vec<_> = map_collect_owned(input, |x| x * 2);
        assert_eq!(result, vec![2, 4, 6, 8, 10]);
    }

    /// A large owned input is fully mapped.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_parallel_map_collect_owned_many() {
        let input = vec![1; 1000];
        let result: Vec<_> = map_collect_owned(input, |x| x * 2);
        assert_eq!(result, vec![2; 1000]);
    }

    /// Every item lands in exactly one chunk, for both short and large inputs.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_parallel_map_collect_chunked_owned() {
        let few: Vec<i32> = vec![1, 2, 3];
        let few_sums: Vec<i32> = map_collect_chunked_owned(few, |chunk| chunk.sum::<i32>());
        assert_eq!(few_sums.iter().sum::<i32>(), 6);

        let many: Vec<i32> = (1..=1000).collect();
        let many_sums: Vec<i32> = map_collect_chunked_owned(many, |chunk| chunk.sum::<i32>());
        assert!(!many_sums.is_empty());
        assert_eq!(many_sums.iter().sum::<i32>(), 500_500);
    }

    /// A panic inside a spawned chunk reaches the caller with its original payload.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_panic_in_scope() {
        let result = catch_unwind(AssertUnwindSafe(|| {
            let mut input = vec![1; 1000];
            input[744] = 2;
            for_each(&input, |x| {
                if *x == 2 {
                    panic!("Intentional panic");
                }
            });
            panic!("Should not get here")
        }));
        assert!(result.is_err());
        assert_eq!(
            result.unwrap_err().downcast_ref::<&str>(),
            Some(&"Intentional panic")
        );
    }
}
359}