turbo_tasks/
parallel.rs

//! Parallel for each and map using tokio tasks.
//!
//! This avoids the problem of sleeping threads with mimalloc when using rayon in combination with
//! tokio. It also avoids having multiple thread pools.
//! See also <https://pwy.io/posts/mimalloc-cigarette/>.
6
7use crate::{
8    scope::scope_and_block,
9    util::{good_chunk_size, into_chunks},
10};
11
12pub fn for_each<'l, T, F>(items: &'l [T], f: F)
13where
14    T: Sync,
15    F: Fn(&'l T) + Send + Sync,
16{
17    let len = items.len();
18    if len <= 1 {
19        for item in items {
20            f(item);
21        }
22        return;
23    }
24    let chunk_size = good_chunk_size(len);
25    let f = &f;
26    let _results = scope_and_block(len.div_ceil(chunk_size), |scope| {
27        for chunk in items.chunks(chunk_size) {
28            scope.spawn(async move {
29                for item in chunk {
30                    f(item);
31                }
32            })
33        }
34    });
35}
36
37pub fn for_each_owned<T>(items: Vec<T>, f: impl Fn(T) + Send + Sync)
38where
39    T: Send + Sync,
40{
41    let len = items.len();
42    if len <= 1 {
43        for item in items {
44            f(item);
45        }
46        return;
47    }
48    let chunk_size = good_chunk_size(len);
49    let f = &f;
50    let _results = scope_and_block(len.div_ceil(chunk_size), |scope| {
51        for chunk in into_chunks(items, chunk_size) {
52            scope.spawn(async move {
53                // SAFETY: Even when f() panics we drop all items in the chunk.
54                for item in chunk {
55                    f(item);
56                }
57            })
58        }
59    });
60}
61
62pub fn try_for_each<'l, T, E>(
63    items: &'l [T],
64    f: impl (Fn(&'l T) -> Result<(), E>) + Send + Sync,
65) -> Result<(), E>
66where
67    T: Sync,
68    E: Send + 'static,
69{
70    let len = items.len();
71    if len <= 1 {
72        for item in items {
73            f(item)?;
74        }
75        return Ok(());
76    }
77    let chunk_size = good_chunk_size(len);
78    let f = &f;
79    scope_and_block(len.div_ceil(chunk_size), |scope| {
80        for chunk in items.chunks(chunk_size) {
81            scope.spawn(async move {
82                for item in chunk {
83                    f(item)?;
84                }
85                Ok(())
86            })
87        }
88    })
89    .collect::<Result<(), E>>()
90}
91
92pub fn try_for_each_mut<'l, T, E>(
93    items: &'l mut [T],
94    f: impl (Fn(&'l mut T) -> Result<(), E>) + Send + Sync,
95) -> Result<(), E>
96where
97    T: Send + Sync,
98    E: Send + 'static,
99{
100    let len = items.len();
101    if len <= 1 {
102        for item in items {
103            f(item)?;
104        }
105        return Ok(());
106    }
107    let chunk_size = good_chunk_size(len);
108    let f = &f;
109    scope_and_block(len.div_ceil(chunk_size), |scope| {
110        for chunk in items.chunks_mut(chunk_size) {
111            scope.spawn(async move {
112                for item in chunk {
113                    f(item)?;
114                }
115                Ok(())
116            })
117        }
118    })
119    .collect::<Result<(), E>>()
120}
121
122pub fn try_for_each_owned<T, E>(
123    items: Vec<T>,
124    f: impl (Fn(T) -> Result<(), E>) + Send + Sync,
125) -> Result<(), E>
126where
127    T: Send + Sync,
128    E: Send + 'static,
129{
130    let len = items.len();
131    if len <= 1 {
132        for item in items {
133            f(item)?;
134        }
135        return Ok(());
136    }
137    let chunk_size = good_chunk_size(len);
138    let f = &f;
139    scope_and_block(len.div_ceil(chunk_size), |scope| {
140        for chunk in into_chunks(items, chunk_size) {
141            scope.spawn(async move {
142                for item in chunk {
143                    f(item)?;
144                }
145                Ok(())
146            })
147        }
148    })
149    .collect::<Result<(), E>>()
150}
151
152pub fn map_collect<'l, Item, PerItemResult, Result>(
153    items: &'l [Item],
154    f: impl Fn(&'l Item) -> PerItemResult + Send + Sync,
155) -> Result
156where
157    Item: Sync,
158    PerItemResult: Send + Sync + 'l,
159    Result: FromIterator<PerItemResult>,
160{
161    let len = items.len();
162    if len == 0 {
163        return Result::from_iter(std::iter::empty()); // No items to process, return empty
164        // collection
165    }
166    let chunk_size = good_chunk_size(len);
167    let f = &f;
168    scope_and_block(len.div_ceil(chunk_size), |scope| {
169        for chunk in items.chunks(chunk_size) {
170            scope.spawn(async move { chunk.iter().map(f).collect::<Vec<_>>() })
171        }
172    })
173    .flatten()
174    .collect()
175}
176
177pub fn map_collect_owned<'l, Item, PerItemResult, Result>(
178    items: Vec<Item>,
179    f: impl Fn(Item) -> PerItemResult + Send + Sync,
180) -> Result
181where
182    Item: Send + Sync,
183    PerItemResult: Send + Sync + 'l,
184    Result: FromIterator<PerItemResult>,
185{
186    let len = items.len();
187    if len == 0 {
188        return Result::from_iter(std::iter::empty()); // No items to process, return empty
189        // collection;
190    }
191    let chunk_size = good_chunk_size(len);
192    let f = &f;
193    scope_and_block(len.div_ceil(chunk_size), |scope| {
194        for chunk in into_chunks(items, chunk_size) {
195            scope.spawn(async move { chunk.map(f).collect::<Vec<_>>() })
196        }
197    })
198    .flatten()
199    .collect()
200}
201
#[cfg(test)]
mod tests {
    use std::{
        panic::{AssertUnwindSafe, catch_unwind},
        sync::atomic::{AtomicI32, Ordering},
    };

    use super::*;

    /// `for_each` over a borrowed slice reaches every element: the running sum ends at 15.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_parallel_for_each() {
        let numbers = vec![1, 2, 3, 4, 5];
        let total = AtomicI32::new(0);
        for_each(&numbers, |&x| {
            total.fetch_add(x, Ordering::SeqCst);
        });
        assert_eq!(total.into_inner(), 15);
    }

    /// `try_for_each` reports the error raised for the first element.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_parallel_try_for_each() {
        let numbers = vec![1, 2, 3, 4, 5];
        let outcome = try_for_each(&numbers, |&x| match x % 2 {
            0 => Ok(()),
            _ => Err(format!("Odd number {x} encountered")),
        });
        assert_eq!(outcome, Err(String::from("Odd number 1 encountered")));
    }

    /// `try_for_each_mut` reports the error for the first element, and the test expects all
    /// elements to have been incremented regardless.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_parallel_try_for_each_mut() {
        let mut numbers = vec![1, 2, 3, 4, 5];
        let outcome = try_for_each_mut(&mut numbers, |x| {
            *x += 10;
            match *x % 2 {
                0 => Ok(()),
                _ => Err(format!("Odd number {} encountered", *x)),
            }
        });
        assert_eq!(outcome, Err(String::from("Odd number 11 encountered")));
        assert_eq!(numbers, [11, 12, 13, 14, 15]);
    }

    /// `for_each_owned` over an owned vector reaches every element: the running sum ends at 15.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_parallel_for_each_owned() {
        let numbers = vec![1, 2, 3, 4, 5];
        let total = AtomicI32::new(0);
        for_each_owned(numbers, |x| {
            total.fetch_add(x, Ordering::SeqCst);
        });
        assert_eq!(total.into_inner(), 15);
    }

    /// `map_collect` doubles every borrowed element and keeps the input order.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_parallel_map_collect() {
        let numbers = vec![1, 2, 3, 4, 5];
        let doubled: Vec<_> = map_collect(&numbers, |&x| x * 2);
        assert_eq!(doubled, [2, 4, 6, 8, 10]);
    }

    /// `map_collect_owned` doubles every owned element and keeps the input order.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_parallel_map_collect_owned() {
        let numbers = vec![1, 2, 3, 4, 5];
        let doubled: Vec<_> = map_collect_owned(numbers, |x| x * 2);
        assert_eq!(doubled, [2, 4, 6, 8, 10]);
    }

    /// `map_collect_owned` returns one output per input for a 1000-element vector.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_parallel_map_collect_owned_many() {
        let ones = vec![1; 1000];
        let doubled: Vec<_> = map_collect_owned(ones, |x| x * 2);
        assert_eq!(doubled, vec![2; 1000]);
    }

    /// A panic raised inside the callback propagates out of `for_each` with its payload intact.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_panic_in_scope() {
        let outcome = catch_unwind(AssertUnwindSafe(|| {
            // All ones, except for a single `2` that triggers the panic.
            let input: Vec<i32> = (0..1000).map(|i| if i == 744 { 2 } else { 1 }).collect();
            for_each(&input, |x| {
                if *x == 2 {
                    panic!("Intentional panic");
                }
            });
            panic!("Should not get here")
        }));
        assert!(outcome.is_err());
        let payload = outcome.unwrap_err();
        assert_eq!(
            payload.downcast_ref::<&str>().copied(),
            Some("Intentional panic")
        );
    }
}