1use std::{
9 env, io,
10 num::NonZeroUsize,
11 sync::{Arc, OnceLock},
12 thread,
13};
14
15use crate::{
16 scope::scope_and_block,
17 util::{Chunk, good_chunk_size, into_chunks},
18};
19
/// Returns the available parallelism, honoring an environment-variable override.
///
/// If `TURBO_TASKS_AVAILABLE_PARALLELISM` is set (and is valid Unicode), it is parsed as a
/// [`NonZeroUsize`] and returned. Otherwise the value comes from
/// [`std::thread::available_parallelism`]. Any I/O error from that lookup is wrapped in an
/// [`Arc`] so the cached result can be cloned.
///
/// The first successfully computed result (`Ok` or `Err`) is stored in a process-wide
/// [`OnceLock`]. Later calls return a clone of it, so changing the environment variable
/// afterwards has no effect.
///
/// # Panics
///
/// Panics if `TURBO_TASKS_AVAILABLE_PARALLELISM` is set but does not parse as a non-zero
/// integer.
pub fn available_parallelism() -> Result<NonZeroUsize, Arc<io::Error>> {
    static CACHED: OnceLock<Result<NonZeroUsize, Arc<io::Error>>> = OnceLock::new();

    fn detect() -> Result<NonZeroUsize, Arc<io::Error>> {
        let Ok(raw) = env::var("TURBO_TASKS_AVAILABLE_PARALLELISM") else {
            // No usable override: ask the standard library instead.
            return thread::available_parallelism().map_err(Arc::new);
        };
        match raw.parse::<NonZeroUsize>() {
            Ok(parallelism) => Ok(parallelism),
            Err(err) => panic!("Invalid TURBO_TASKS_AVAILABLE_PARALLELISM={raw:?}: {err}"),
        }
    }

    CACHED.get_or_init(detect).clone()
}
42
/// Describes how a sequence of items is split into contiguous chunks for parallel work.
struct Chunked {
    /// How many chunks are needed to cover every item.
    chunk_count: usize,
    /// Number of items per chunk; the final chunk may hold fewer.
    chunk_size: usize,
}
47
48fn get_chunked(len: usize) -> Option<Chunked> {
49 if len <= 1 {
50 return None;
51 }
52 let chunk_size = good_chunk_size(len);
53 let chunk_count = len.div_ceil(chunk_size);
54 if chunk_count <= 1 {
55 return None;
56 }
57 Some(Chunked {
58 chunk_size,
59 chunk_count,
60 })
61}
62
63pub fn for_each<'l, T, F>(items: &'l [T], f: F)
64where
65 T: Sync,
66 F: Fn(&'l T) + Send + Sync,
67{
68 let Some(Chunked {
69 chunk_size,
70 chunk_count,
71 }) = get_chunked(items.len())
72 else {
73 for item in items {
74 f(item);
75 }
76 return;
77 };
78 let f = &f;
79 let _results = scope_and_block(chunk_count, |scope| {
80 for chunk in items.chunks(chunk_size) {
81 scope.spawn(move || {
82 for item in chunk {
83 f(item);
84 }
85 })
86 }
87 });
88}
89
90pub fn for_each_owned<T>(items: Vec<T>, f: impl Fn(T) + Send + Sync)
91where
92 T: Send + Sync,
93{
94 let Some(Chunked {
95 chunk_size,
96 chunk_count,
97 }) = get_chunked(items.len())
98 else {
99 for item in items {
100 f(item);
101 }
102 return;
103 };
104 let f = &f;
105 let _results = scope_and_block(chunk_count, |scope| {
106 for chunk in into_chunks(items, chunk_size) {
107 scope.spawn(move || {
108 for item in chunk {
110 f(item);
111 }
112 })
113 }
114 });
115}
116
117pub fn try_for_each<'l, T, E>(
118 items: &'l [T],
119 f: impl (Fn(&'l T) -> Result<(), E>) + Send + Sync,
120) -> Result<(), E>
121where
122 T: Sync,
123 E: Send + 'static,
124{
125 let Some(Chunked {
126 chunk_size,
127 chunk_count,
128 }) = get_chunked(items.len())
129 else {
130 for item in items {
131 f(item)?;
132 }
133 return Ok(());
134 };
135 let f = &f;
136 scope_and_block(chunk_count, |scope| {
137 for chunk in items.chunks(chunk_size) {
138 scope.spawn(move || {
139 for item in chunk {
140 f(item)?;
141 }
142 Ok(())
143 })
144 }
145 })
146 .collect::<Result<(), E>>()
147}
148
149pub fn try_for_each_mut<'l, T, E>(
150 items: &'l mut [T],
151 f: impl (Fn(&'l mut T) -> Result<(), E>) + Send + Sync,
152) -> Result<(), E>
153where
154 T: Send + Sync,
155 E: Send + 'static,
156{
157 let Some(Chunked {
158 chunk_size,
159 chunk_count,
160 }) = get_chunked(items.len())
161 else {
162 for item in items {
163 f(item)?;
164 }
165 return Ok(());
166 };
167 let f = &f;
168 scope_and_block(chunk_count, |scope| {
169 for chunk in items.chunks_mut(chunk_size) {
170 scope.spawn(move || {
171 for item in chunk {
172 f(item)?;
173 }
174 Ok(())
175 })
176 }
177 })
178 .collect::<Result<(), E>>()
179}
180
181pub fn try_for_each_owned<T, E>(
182 items: Vec<T>,
183 f: impl (Fn(T) -> Result<(), E>) + Send + Sync,
184) -> Result<(), E>
185where
186 T: Send + Sync,
187 E: Send + 'static,
188{
189 let Some(Chunked {
190 chunk_size,
191 chunk_count,
192 }) = get_chunked(items.len())
193 else {
194 for item in items {
195 f(item)?;
196 }
197 return Ok(());
198 };
199 let f = &f;
200 scope_and_block(chunk_count, |scope| {
201 for chunk in into_chunks(items, chunk_size) {
202 scope.spawn(move || {
203 for item in chunk {
204 f(item)?;
205 }
206 Ok(())
207 })
208 }
209 })
210 .collect::<Result<(), E>>()
211}
212
213pub fn map_collect<'l, Item, PerItemResult, Result>(
214 items: &'l [Item],
215 f: impl Fn(&'l Item) -> PerItemResult + Send + Sync,
216) -> Result
217where
218 Item: Sync,
219 PerItemResult: Send + Sync + 'l,
220 Result: FromIterator<PerItemResult>,
221{
222 let Some(Chunked {
223 chunk_size,
224 chunk_count,
225 }) = get_chunked(items.len())
226 else {
227 return Result::from_iter(items.iter().map(f));
228 };
229 let f = &f;
230 scope_and_block(chunk_count, |scope| {
231 for chunk in items.chunks(chunk_size) {
232 scope.spawn(move || chunk.iter().map(f).collect::<Vec<_>>())
233 }
234 })
235 .flatten()
236 .collect()
237}
238
239pub fn map_collect_owned<'l, Item, PerItemResult, Result>(
240 items: Vec<Item>,
241 f: impl Fn(Item) -> PerItemResult + Send + Sync,
242) -> Result
243where
244 Item: Send + Sync,
245 PerItemResult: Send + Sync + 'l,
246 Result: FromIterator<PerItemResult>,
247{
248 let Some(Chunked {
249 chunk_size,
250 chunk_count,
251 }) = get_chunked(items.len())
252 else {
253 return Result::from_iter(items.into_iter().map(f));
254 };
255 let f = &f;
256 scope_and_block(chunk_count, |scope| {
257 for chunk in into_chunks(items, chunk_size) {
258 scope.spawn(move || chunk.map(f).collect::<Vec<_>>())
259 }
260 })
261 .flatten()
262 .collect()
263}
264
265pub fn map_collect_chunked_owned<'l, Item, PerItemResult, Result>(
266 items: Vec<Item>,
267 f: impl Fn(Chunk<Item>) -> PerItemResult + Send + Sync,
268) -> Result
269where
270 Item: Send + Sync,
271 PerItemResult: Send + Sync + 'l,
272 Result: FromIterator<PerItemResult>,
273{
274 let Some(Chunked {
275 chunk_size,
276 chunk_count,
277 }) = get_chunked(items.len())
278 else {
279 let len = items.len();
280 return Result::from_iter(into_chunks(items, len).map(f));
281 };
282 let f = &f;
283 scope_and_block(chunk_count, |scope| {
284 for chunk in into_chunks(items, chunk_size) {
285 scope.spawn(move || f(chunk))
286 }
287 })
288 .collect()
289}
290
// Unit tests for the parallel helpers above.
#[cfg(test)]
mod tests {
    use std::{
        panic::{AssertUnwindSafe, catch_unwind},
        sync::atomic::{AtomicI32, Ordering},
    };

    use super::*;

    // Every test runs inside a multi-threaded Tokio runtime with two worker threads.
    // NOTE(review): presumably `scope_and_block` needs a Tokio runtime context — confirm.

    // `for_each` must visit every element exactly once before returning: 1 + 2 + 3 + 4 + 5 = 15.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_parallel_for_each() {
        let input = vec![1, 2, 3, 4, 5];
        let sum = AtomicI32::new(0);
        for_each(&input, |&x| {
            sum.fetch_add(x, Ordering::SeqCst);
        });
        assert_eq!(sum.load(Ordering::SeqCst), 15);
    }

    // Odd elements fail. The reported error is the one from element 1, because chunk results
    // are combined in order and the first `Err` wins.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_parallel_try_for_each() {
        let input = vec![1, 2, 3, 4, 5];
        let result = try_for_each(&input, |&x| {
            if x % 2 == 0 {
                Ok(())
            } else {
                Err(format!("Odd number {x} encountered"))
            }
        });
        assert!(result.is_err());
        assert_eq!(result.unwrap_err(), "Odd number 1 encountered");
    }

    // Each element is bumped by 10 before the parity check, so 11, 13 and 15 fail.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_parallel_try_for_each_mut() {
        let mut input = vec![1, 2, 3, 4, 5];
        let result = try_for_each_mut(&mut input, |x| {
            *x += 10;
            if *x % 2 == 0 {
                Ok(())
            } else {
                Err(format!("Odd number {} encountered", *x))
            }
        });
        assert!(result.is_err());
        assert_eq!(result.unwrap_err(), "Odd number 11 encountered");
        // NOTE(review): a chunk stops at its first `Err`, so every element can only be
        // incremented if each odd element is the last one in its chunk. With 5 inputs that
        // appears to require single-element chunks, i.e. `good_chunk_size(5) == 1`. The
        // sequential fallback would stop after the first element. Confirm.
        assert_eq!(input, vec![11, 12, 13, 14, 15]);
    }

    // Same as `test_parallel_for_each`, but the elements are moved into the closure.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_parallel_for_each_owned() {
        let input = vec![1, 2, 3, 4, 5];
        let sum = AtomicI32::new(0);
        for_each_owned(input, |x| {
            sum.fetch_add(x, Ordering::SeqCst);
        });
        assert_eq!(sum.load(Ordering::SeqCst), 15);
    }

    // The mapped output must keep the input order.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_parallel_map_collect() {
        let input = vec![1, 2, 3, 4, 5];
        let result: Vec<_> = map_collect(&input, |&x| x * 2);
        assert_eq!(result, vec![2, 4, 6, 8, 10]);
    }

    // Owned variant of `test_parallel_map_collect`; input order must be preserved.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_parallel_map_collect_owned() {
        let input = vec![1, 2, 3, 4, 5];
        let result: Vec<_> = map_collect_owned(input, |x| x * 2);
        assert_eq!(result, vec![2, 4, 6, 8, 10]);
    }

    // A larger input, intended to exercise the owned chunk iteration across several chunks;
    // no element may be lost or duplicated.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_parallel_map_collect_owned_many() {
        let input = vec![1; 1000];
        let result: Vec<_> = map_collect_owned(input, |x| x * 2);
        assert_eq!(result, vec![2; 1000]);
    }

    // A panic inside one task must reach the caller with its original `&str` payload, and
    // `for_each` must not return normally ("Should not get here" must never be raised).
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_panic_in_scope() {
        let result = catch_unwind(AssertUnwindSafe(|| {
            let mut input = vec![1; 1000];
            input[744] = 2;
            for_each(&input, |x| {
                if *x == 2 {
                    panic!("Intentional panic");
                }
            });
            panic!("Should not get here")
        }));
        assert!(result.is_err());
        assert_eq!(
            result.unwrap_err().downcast_ref::<&str>(),
            Some(&"Intentional panic")
        );
    }
}