Skip to main content

turbo_tasks/
parallel.rs

//! Parallel for-each and map using tokio tasks.
//!
//! This avoids the problem of sleeping threads with mimalloc when using rayon in combination
//! with tokio. It also avoids having multiple thread pools.
//!
//! See also: <https://pwy.io/posts/mimalloc-cigarette/>
7
8use std::{
9    env, io,
10    num::NonZeroUsize,
11    sync::{Arc, OnceLock},
12    thread,
13};
14
15use crate::{
16    scope::scope_and_block,
17    util::{Chunk, good_chunk_size, into_chunks},
18};
19
/// Returns the recommended amount of parallelism for the current process. Typically the number of
/// available CPU cores.
///
/// This wraps [`std::thread::available_parallelism`] with a couple of extras:
///
/// - If the `TURBO_TASKS_AVAILABLE_PARALLELISM` env var is set, its value overrides the detected
///   one. A value that is not valid Unicode is treated as if the variable were unset.
/// - The resolved value is cached in a [`OnceLock`], so later calls return a clone of the result
///   computed by the first successful initialization.
///
/// # Errors
///
/// Returns the [`io::Error`] reported by [`std::thread::available_parallelism`] when no override
/// is in effect and detection fails. The error is wrapped in an [`Arc`] so the cached result can
/// be cloned.
///
/// # Panics
///
/// Panics if `TURBO_TASKS_AVAILABLE_PARALLELISM` is set but cannot be parsed as a non-zero
/// `usize`.
pub fn available_parallelism() -> Result<NonZeroUsize, Arc<io::Error>> {
    static CACHED: OnceLock<Result<NonZeroUsize, Arc<io::Error>>> = OnceLock::new();
    CACHED
        .get_or_init(|| match env::var("TURBO_TASKS_AVAILABLE_PARALLELISM") {
            // An explicit override wins; a malformed override is a configuration bug.
            Ok(raw) => Ok(raw.parse::<NonZeroUsize>().unwrap_or_else(|err| {
                panic!("Invalid TURBO_TASKS_AVAILABLE_PARALLELISM={raw:?}: {err}")
            })),
            // Unset (or not valid Unicode): fall back to what the standard library detects.
            Err(_) => thread::available_parallelism().map_err(Arc::new),
        })
        .clone()
}
42
/// Describes how a collection is split up for parallel processing.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct Chunked {
    /// Number of items handed to each spawned task (the last chunk may hold fewer).
    chunk_size: usize,
    /// Total number of chunks, i.e. the number of tasks that will be spawned.
    chunk_count: usize,
}
47
48fn get_chunked(len: usize) -> Option<Chunked> {
49    if len <= 1 {
50        return None;
51    }
52    let chunk_size = good_chunk_size(len);
53    let chunk_count = len.div_ceil(chunk_size);
54    if chunk_count <= 1 {
55        return None;
56    }
57    Some(Chunked {
58        chunk_size,
59        chunk_count,
60    })
61}
62
63pub fn for_each<'l, T, F>(items: &'l [T], f: F)
64where
65    T: Sync,
66    F: Fn(&'l T) + Send + Sync,
67{
68    let Some(Chunked {
69        chunk_size,
70        chunk_count,
71    }) = get_chunked(items.len())
72    else {
73        for item in items {
74            f(item);
75        }
76        return;
77    };
78    let f = &f;
79    let _results = scope_and_block(chunk_count, |scope| {
80        for chunk in items.chunks(chunk_size) {
81            scope.spawn(move || {
82                for item in chunk {
83                    f(item);
84                }
85            })
86        }
87    });
88}
89
90pub fn for_each_owned<T>(items: Vec<T>, f: impl Fn(T) + Send + Sync)
91where
92    T: Send + Sync,
93{
94    let Some(Chunked {
95        chunk_size,
96        chunk_count,
97    }) = get_chunked(items.len())
98    else {
99        for item in items {
100            f(item);
101        }
102        return;
103    };
104    let f = &f;
105    let _results = scope_and_block(chunk_count, |scope| {
106        for chunk in into_chunks(items, chunk_size) {
107            scope.spawn(move || {
108                // SAFETY: Even when f() panics we drop all items in the chunk.
109                for item in chunk {
110                    f(item);
111                }
112            })
113        }
114    });
115}
116
117pub fn try_for_each<'l, T, E>(
118    items: &'l [T],
119    f: impl (Fn(&'l T) -> Result<(), E>) + Send + Sync,
120) -> Result<(), E>
121where
122    T: Sync,
123    E: Send + 'static,
124{
125    let Some(Chunked {
126        chunk_size,
127        chunk_count,
128    }) = get_chunked(items.len())
129    else {
130        for item in items {
131            f(item)?;
132        }
133        return Ok(());
134    };
135    let f = &f;
136    scope_and_block(chunk_count, |scope| {
137        for chunk in items.chunks(chunk_size) {
138            scope.spawn(move || {
139                for item in chunk {
140                    f(item)?;
141                }
142                Ok(())
143            })
144        }
145    })
146    .collect::<Result<(), E>>()
147}
148
149pub fn try_for_each_mut<'l, T, E>(
150    items: &'l mut [T],
151    f: impl (Fn(&'l mut T) -> Result<(), E>) + Send + Sync,
152) -> Result<(), E>
153where
154    T: Send + Sync,
155    E: Send + 'static,
156{
157    let Some(Chunked {
158        chunk_size,
159        chunk_count,
160    }) = get_chunked(items.len())
161    else {
162        for item in items {
163            f(item)?;
164        }
165        return Ok(());
166    };
167    let f = &f;
168    scope_and_block(chunk_count, |scope| {
169        for chunk in items.chunks_mut(chunk_size) {
170            scope.spawn(move || {
171                for item in chunk {
172                    f(item)?;
173                }
174                Ok(())
175            })
176        }
177    })
178    .collect::<Result<(), E>>()
179}
180
181pub fn try_for_each_owned<T, E>(
182    items: Vec<T>,
183    f: impl (Fn(T) -> Result<(), E>) + Send + Sync,
184) -> Result<(), E>
185where
186    T: Send + Sync,
187    E: Send + 'static,
188{
189    let Some(Chunked {
190        chunk_size,
191        chunk_count,
192    }) = get_chunked(items.len())
193    else {
194        for item in items {
195            f(item)?;
196        }
197        return Ok(());
198    };
199    let f = &f;
200    scope_and_block(chunk_count, |scope| {
201        for chunk in into_chunks(items, chunk_size) {
202            scope.spawn(move || {
203                for item in chunk {
204                    f(item)?;
205                }
206                Ok(())
207            })
208        }
209    })
210    .collect::<Result<(), E>>()
211}
212
213pub fn map_collect<'l, Item, PerItemResult, Result>(
214    items: &'l [Item],
215    f: impl Fn(&'l Item) -> PerItemResult + Send + Sync,
216) -> Result
217where
218    Item: Sync,
219    PerItemResult: Send + Sync + 'l,
220    Result: FromIterator<PerItemResult>,
221{
222    let Some(Chunked {
223        chunk_size,
224        chunk_count,
225    }) = get_chunked(items.len())
226    else {
227        return Result::from_iter(items.iter().map(f));
228    };
229    let f = &f;
230    scope_and_block(chunk_count, |scope| {
231        for chunk in items.chunks(chunk_size) {
232            scope.spawn(move || chunk.iter().map(f).collect::<Vec<_>>())
233        }
234    })
235    .flatten()
236    .collect()
237}
238
239pub fn map_collect_owned<'l, Item, PerItemResult, Result>(
240    items: Vec<Item>,
241    f: impl Fn(Item) -> PerItemResult + Send + Sync,
242) -> Result
243where
244    Item: Send + Sync,
245    PerItemResult: Send + Sync + 'l,
246    Result: FromIterator<PerItemResult>,
247{
248    let Some(Chunked {
249        chunk_size,
250        chunk_count,
251    }) = get_chunked(items.len())
252    else {
253        return Result::from_iter(items.into_iter().map(f));
254    };
255    let f = &f;
256    scope_and_block(chunk_count, |scope| {
257        for chunk in into_chunks(items, chunk_size) {
258            scope.spawn(move || chunk.map(f).collect::<Vec<_>>())
259        }
260    })
261    .flatten()
262    .collect()
263}
264
265pub fn map_collect_chunked_owned<'l, Item, PerItemResult, Result>(
266    items: Vec<Item>,
267    f: impl Fn(Chunk<Item>) -> PerItemResult + Send + Sync,
268) -> Result
269where
270    Item: Send + Sync,
271    PerItemResult: Send + Sync + 'l,
272    Result: FromIterator<PerItemResult>,
273{
274    let Some(Chunked {
275        chunk_size,
276        chunk_count,
277    }) = get_chunked(items.len())
278    else {
279        let len = items.len();
280        return Result::from_iter(into_chunks(items, len).map(f));
281    };
282    let f = &f;
283    scope_and_block(chunk_count, |scope| {
284        for chunk in into_chunks(items, chunk_size) {
285            scope.spawn(move || f(chunk))
286        }
287    })
288    .collect()
289}
290
#[cfg(test)]
mod tests {
    //! Each test runs on a multi-threaded tokio runtime with two worker threads.

    use std::{
        panic::{AssertUnwindSafe, catch_unwind},
        sync::atomic::{AtomicI32, Ordering},
    };

    use super::*;

    /// Every borrowed item is visited exactly once.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_parallel_for_each() {
        let input = vec![1, 2, 3, 4, 5];
        let sum = AtomicI32::new(0);
        for_each(&input, |&x| {
            sum.fetch_add(x, Ordering::SeqCst);
        });
        assert_eq!(sum.load(Ordering::SeqCst), 15);
    }

    /// The error reported is the one for the first item of the input.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_parallel_try_for_each() {
        let input = vec![1, 2, 3, 4, 5];
        let result = try_for_each(&input, |&x| {
            if x % 2 == 0 {
                Ok(())
            } else {
                Err(format!("Odd number {x} encountered"))
            }
        });
        assert!(result.is_err());
        assert_eq!(result.unwrap_err(), "Odd number 1 encountered");
    }

    /// All items are mutated even though the first one already yields an error.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_parallel_try_for_each_mut() {
        let mut input = vec![1, 2, 3, 4, 5];
        let result = try_for_each_mut(&mut input, |x| {
            *x += 10;
            if *x % 2 == 0 {
                Ok(())
            } else {
                Err(format!("Odd number {} encountered", *x))
            }
        });
        assert!(result.is_err());
        assert_eq!(result.unwrap_err(), "Odd number 11 encountered");
        assert_eq!(input, vec![11, 12, 13, 14, 15]);
    }

    /// Every owned item is visited exactly once.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_parallel_for_each_owned() {
        let input = vec![1, 2, 3, 4, 5];
        let sum = AtomicI32::new(0);
        for_each_owned(input, |x| {
            sum.fetch_add(x, Ordering::SeqCst);
        });
        assert_eq!(sum.load(Ordering::SeqCst), 15);
    }

    /// Mapped outputs come back in input order.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_parallel_map_collect() {
        let input = vec![1, 2, 3, 4, 5];
        let result: Vec<_> = map_collect(&input, |&x| x * 2);
        assert_eq!(result, vec![2, 4, 6, 8, 10]);
    }

    /// Mapped outputs of owned items come back in input order.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_parallel_map_collect_owned() {
        let input = vec![1, 2, 3, 4, 5];
        let result: Vec<_> = map_collect_owned(input, |x| x * 2);
        assert_eq!(result, vec![2, 4, 6, 8, 10]);
    }

    /// Same as above with a larger input.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_parallel_map_collect_owned_many() {
        let input = vec![1; 1000];
        let result: Vec<_> = map_collect_owned(input, |x| x * 2);
        assert_eq!(result, vec![2; 1000]);
    }

    /// A panic inside the callback reaches the caller with its original payload.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_panic_in_scope() {
        let result = catch_unwind(AssertUnwindSafe(|| {
            let mut input = vec![1; 1000];
            input[744] = 2;
            for_each(&input, |x| {
                if *x == 2 {
                    panic!("Intentional panic");
                }
            });
            panic!("Should not get here")
        }));
        assert!(result.is_err());
        assert_eq!(
            result.unwrap_err().downcast_ref::<&str>(),
            Some(&"Intentional panic")
        );
    }
}