1use crate::{
9 scope::scope_and_block,
10 util::{Chunk, good_chunk_size, into_chunks},
11};
12
/// Plan for splitting a sequence of a given length into chunks that are
/// processed as separate tasks.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct Chunked {
    /// Maximum number of items per chunk; the final chunk may hold fewer.
    chunk_size: usize,
    /// Number of chunks the sequence is split into.
    chunk_count: usize,
}
17
18fn get_chunked(len: usize) -> Option<Chunked> {
19 if len <= 1 {
20 return None;
21 }
22 let chunk_size = good_chunk_size(len);
23 let chunk_count = len.div_ceil(chunk_size);
24 if chunk_count <= 1 {
25 return None;
26 }
27 Some(Chunked {
28 chunk_size,
29 chunk_count,
30 })
31}
32
33pub fn for_each<'l, T, F>(items: &'l [T], f: F)
34where
35 T: Sync,
36 F: Fn(&'l T) + Send + Sync,
37{
38 let Some(Chunked {
39 chunk_size,
40 chunk_count,
41 }) = get_chunked(items.len())
42 else {
43 for item in items {
44 f(item);
45 }
46 return;
47 };
48 let f = &f;
49 let _results = scope_and_block(chunk_count, |scope| {
50 for chunk in items.chunks(chunk_size) {
51 scope.spawn(move || {
52 for item in chunk {
53 f(item);
54 }
55 })
56 }
57 });
58}
59
60pub fn for_each_owned<T>(items: Vec<T>, f: impl Fn(T) + Send + Sync)
61where
62 T: Send + Sync,
63{
64 let Some(Chunked {
65 chunk_size,
66 chunk_count,
67 }) = get_chunked(items.len())
68 else {
69 for item in items {
70 f(item);
71 }
72 return;
73 };
74 let f = &f;
75 let _results = scope_and_block(chunk_count, |scope| {
76 for chunk in into_chunks(items, chunk_size) {
77 scope.spawn(move || {
78 for item in chunk {
80 f(item);
81 }
82 })
83 }
84 });
85}
86
87pub fn try_for_each<'l, T, E>(
88 items: &'l [T],
89 f: impl (Fn(&'l T) -> Result<(), E>) + Send + Sync,
90) -> Result<(), E>
91where
92 T: Sync,
93 E: Send + 'static,
94{
95 let Some(Chunked {
96 chunk_size,
97 chunk_count,
98 }) = get_chunked(items.len())
99 else {
100 for item in items {
101 f(item)?;
102 }
103 return Ok(());
104 };
105 let f = &f;
106 scope_and_block(chunk_count, |scope| {
107 for chunk in items.chunks(chunk_size) {
108 scope.spawn(move || {
109 for item in chunk {
110 f(item)?;
111 }
112 Ok(())
113 })
114 }
115 })
116 .collect::<Result<(), E>>()
117}
118
119pub fn try_for_each_mut<'l, T, E>(
120 items: &'l mut [T],
121 f: impl (Fn(&'l mut T) -> Result<(), E>) + Send + Sync,
122) -> Result<(), E>
123where
124 T: Send + Sync,
125 E: Send + 'static,
126{
127 let Some(Chunked {
128 chunk_size,
129 chunk_count,
130 }) = get_chunked(items.len())
131 else {
132 for item in items {
133 f(item)?;
134 }
135 return Ok(());
136 };
137 let f = &f;
138 scope_and_block(chunk_count, |scope| {
139 for chunk in items.chunks_mut(chunk_size) {
140 scope.spawn(move || {
141 for item in chunk {
142 f(item)?;
143 }
144 Ok(())
145 })
146 }
147 })
148 .collect::<Result<(), E>>()
149}
150
151pub fn try_for_each_owned<T, E>(
152 items: Vec<T>,
153 f: impl (Fn(T) -> Result<(), E>) + Send + Sync,
154) -> Result<(), E>
155where
156 T: Send + Sync,
157 E: Send + 'static,
158{
159 let Some(Chunked {
160 chunk_size,
161 chunk_count,
162 }) = get_chunked(items.len())
163 else {
164 for item in items {
165 f(item)?;
166 }
167 return Ok(());
168 };
169 let f = &f;
170 scope_and_block(chunk_count, |scope| {
171 for chunk in into_chunks(items, chunk_size) {
172 scope.spawn(move || {
173 for item in chunk {
174 f(item)?;
175 }
176 Ok(())
177 })
178 }
179 })
180 .collect::<Result<(), E>>()
181}
182
183pub fn map_collect<'l, Item, PerItemResult, Result>(
184 items: &'l [Item],
185 f: impl Fn(&'l Item) -> PerItemResult + Send + Sync,
186) -> Result
187where
188 Item: Sync,
189 PerItemResult: Send + Sync + 'l,
190 Result: FromIterator<PerItemResult>,
191{
192 let Some(Chunked {
193 chunk_size,
194 chunk_count,
195 }) = get_chunked(items.len())
196 else {
197 return Result::from_iter(items.iter().map(f));
198 };
199 let f = &f;
200 scope_and_block(chunk_count, |scope| {
201 for chunk in items.chunks(chunk_size) {
202 scope.spawn(move || chunk.iter().map(f).collect::<Vec<_>>())
203 }
204 })
205 .flatten()
206 .collect()
207}
208
209pub fn map_collect_owned<'l, Item, PerItemResult, Result>(
210 items: Vec<Item>,
211 f: impl Fn(Item) -> PerItemResult + Send + Sync,
212) -> Result
213where
214 Item: Send + Sync,
215 PerItemResult: Send + Sync + 'l,
216 Result: FromIterator<PerItemResult>,
217{
218 let Some(Chunked {
219 chunk_size,
220 chunk_count,
221 }) = get_chunked(items.len())
222 else {
223 return Result::from_iter(items.into_iter().map(f));
224 };
225 let f = &f;
226 scope_and_block(chunk_count, |scope| {
227 for chunk in into_chunks(items, chunk_size) {
228 scope.spawn(move || chunk.map(f).collect::<Vec<_>>())
229 }
230 })
231 .flatten()
232 .collect()
233}
234
235pub fn map_collect_chunked_owned<'l, Item, PerItemResult, Result>(
236 items: Vec<Item>,
237 f: impl Fn(Chunk<Item>) -> PerItemResult + Send + Sync,
238) -> Result
239where
240 Item: Send + Sync,
241 PerItemResult: Send + Sync + 'l,
242 Result: FromIterator<PerItemResult>,
243{
244 let Some(Chunked {
245 chunk_size,
246 chunk_count,
247 }) = get_chunked(items.len())
248 else {
249 let len = items.len();
250 return Result::from_iter(into_chunks(items, len).map(f));
251 };
252 let f = &f;
253 scope_and_block(chunk_count, |scope| {
254 for chunk in into_chunks(items, chunk_size) {
255 scope.spawn(move || f(chunk))
256 }
257 })
258 .collect()
259}
260
#[cfg(test)]
mod tests {
    use std::{
        panic::{AssertUnwindSafe, catch_unwind},
        sync::atomic::{AtomicI32, Ordering},
    };

    use super::*;

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_parallel_for_each() {
        let input = vec![1, 2, 3, 4, 5];
        let sum = AtomicI32::new(0);
        for_each(&input, |&x| {
            sum.fetch_add(x, Ordering::SeqCst);
        });
        assert_eq!(sum.load(Ordering::SeqCst), 15);
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_parallel_try_for_each() {
        let input = vec![1, 2, 3, 4, 5];
        let result = try_for_each(&input, |&x| {
            if x % 2 == 0 {
                Ok(())
            } else {
                Err(format!("Odd number {x} encountered"))
            }
        });
        assert!(result.is_err());
        assert_eq!(result.unwrap_err(), "Odd number 1 encountered");
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_parallel_try_for_each_mut() {
        let mut input = vec![1, 2, 3, 4, 5];
        let result = try_for_each_mut(&mut input, |x| {
            *x += 10;
            if *x % 2 == 0 {
                Ok(())
            } else {
                Err(format!("Odd number {} encountered", *x))
            }
        });
        assert!(result.is_err());
        assert_eq!(result.unwrap_err(), "Odd number 11 encountered");
        assert_eq!(input, vec![11, 12, 13, 14, 15]);
    }

    /// The consuming fallible variant reports the first error in chunk order,
    /// like `try_for_each`.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_parallel_try_for_each_owned() {
        let input = vec![1, 2, 3, 4, 5];
        let result = try_for_each_owned(input, |x| {
            if x % 2 == 0 {
                Ok(())
            } else {
                Err(format!("Odd number {x} encountered"))
            }
        });
        assert!(result.is_err());
        assert_eq!(result.unwrap_err(), "Odd number 1 encountered");
    }

    /// When no element fails, the parallel `try_*` helpers return `Ok(())`.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_parallel_try_for_each_all_ok() {
        let input: Vec<i32> = (0..1000).map(|x| x * 2).collect();
        let borrowed: Result<(), String> = try_for_each(&input, |&x| {
            if x % 2 == 0 {
                Ok(())
            } else {
                Err(format!("Odd number {x} encountered"))
            }
        });
        assert_eq!(borrowed, Ok(()));
        let owned: Result<(), String> = try_for_each_owned(input, |x| {
            if x % 2 == 0 {
                Ok(())
            } else {
                Err(format!("Odd number {x} encountered"))
            }
        });
        assert_eq!(owned, Ok(()));
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_parallel_for_each_owned() {
        let input = vec![1, 2, 3, 4, 5];
        let sum = AtomicI32::new(0);
        for_each_owned(input, |x| {
            sum.fetch_add(x, Ordering::SeqCst);
        });
        assert_eq!(sum.load(Ordering::SeqCst), 15);
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_parallel_map_collect() {
        let input = vec![1, 2, 3, 4, 5];
        let result: Vec<_> = map_collect(&input, |&x| x * 2);
        assert_eq!(result, vec![2, 4, 6, 8, 10]);
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_parallel_map_collect_owned() {
        let input = vec![1, 2, 3, 4, 5];
        let result: Vec<_> = map_collect_owned(input, |x| x * 2);
        assert_eq!(result, vec![2, 4, 6, 8, 10]);
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_parallel_map_collect_owned_many() {
        let input = vec![1; 1000];
        let result: Vec<_> = map_collect_owned(input, |x| x * 2);
        assert_eq!(result, vec![2; 1000]);
    }

    /// Concatenating the chunks handed to `f` must reproduce the input exactly,
    /// with every item present once and in order.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_parallel_map_collect_chunked_owned() {
        let input: Vec<i32> = (1..=1000).collect();
        let chunks: Vec<Vec<i32>> =
            map_collect_chunked_owned(input.clone(), |chunk| chunk.collect::<Vec<_>>());
        let rejoined: Vec<i32> = chunks.into_iter().flatten().collect();
        assert_eq!(rejoined, input);
    }

    /// Empty input never invokes the callback and yields empty or `Ok` results.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_parallel_empty_input() {
        let input: Vec<i32> = Vec::new();
        for_each(&input, |_| panic!("callback must not run for empty input"));
        for_each_owned(input.clone(), |_| {
            panic!("callback must not run for empty input")
        });
        let mapped: Vec<i32> = map_collect(&input, |&x| x);
        assert!(mapped.is_empty());
        let mapped_owned: Vec<i32> = map_collect_owned(input.clone(), |x| x);
        assert!(mapped_owned.is_empty());
        let borrowed: Result<(), String> =
            try_for_each(&input, |_| Err("unexpected call".to_string()));
        assert_eq!(borrowed, Ok(()));
        let owned: Result<(), String> =
            try_for_each_owned(input, |_| Err("unexpected call".to_string()));
        assert_eq!(owned, Ok(()));
    }

    /// A single element takes the sequential fallback path.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_parallel_single_item() {
        let mut input = vec![7];
        let result: Result<(), String> = try_for_each_mut(&mut input, |x| {
            *x *= 2;
            Ok(())
        });
        assert_eq!(result, Ok(()));
        assert_eq!(input, vec![14]);
        let mapped: Vec<i32> = map_collect(&input, |&x| x + 1);
        assert_eq!(mapped, vec![15]);
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_panic_in_scope() {
        let result = catch_unwind(AssertUnwindSafe(|| {
            let mut input = vec![1; 1000];
            input[744] = 2;
            for_each(&input, |x| {
                if *x == 2 {
                    panic!("Intentional panic");
                }
            });
            panic!("Should not get here")
        }));
        assert!(result.is_err());
        assert_eq!(
            result.unwrap_err().downcast_ref::<&str>(),
            Some(&"Intentional panic")
        );
    }
}