1use crate::{
8 scope::scope_and_block,
9 util::{Chunk, good_chunk_size, into_chunks},
10};
11
/// Parallelisation plan produced by `get_chunked`: how many chunks (tasks) the input is split
/// into, and how many items each chunk holds.
struct Chunked {
    /// Number of chunks, i.e. `ceil(len / chunk_size)`; `get_chunked` only returns plans where
    /// this is at least 2.
    chunk_count: usize,
    /// Maximum number of items per chunk; the final chunk may hold fewer.
    chunk_size: usize,
}
16
17fn get_chunked(len: usize) -> Option<Chunked> {
18 if len <= 1 {
19 return None;
20 }
21 let chunk_size = good_chunk_size(len);
22 let chunk_count = len.div_ceil(chunk_size);
23 if chunk_count <= 1 {
24 return None;
25 }
26 Some(Chunked {
27 chunk_size,
28 chunk_count,
29 })
30}
31
32pub fn for_each<'l, T, F>(items: &'l [T], f: F)
33where
34 T: Sync,
35 F: Fn(&'l T) + Send + Sync,
36{
37 let Some(Chunked {
38 chunk_size,
39 chunk_count,
40 }) = get_chunked(items.len())
41 else {
42 for item in items {
43 f(item);
44 }
45 return;
46 };
47 let f = &f;
48 let _results = scope_and_block(chunk_count, |scope| {
49 for chunk in items.chunks(chunk_size) {
50 scope.spawn(move || {
51 for item in chunk {
52 f(item);
53 }
54 })
55 }
56 });
57}
58
59pub fn for_each_owned<T>(items: Vec<T>, f: impl Fn(T) + Send + Sync)
60where
61 T: Send + Sync,
62{
63 let Some(Chunked {
64 chunk_size,
65 chunk_count,
66 }) = get_chunked(items.len())
67 else {
68 for item in items {
69 f(item);
70 }
71 return;
72 };
73 let f = &f;
74 let _results = scope_and_block(chunk_count, |scope| {
75 for chunk in into_chunks(items, chunk_size) {
76 scope.spawn(move || {
77 for item in chunk {
79 f(item);
80 }
81 })
82 }
83 });
84}
85
86pub fn try_for_each<'l, T, E>(
87 items: &'l [T],
88 f: impl (Fn(&'l T) -> Result<(), E>) + Send + Sync,
89) -> Result<(), E>
90where
91 T: Sync,
92 E: Send + 'static,
93{
94 let Some(Chunked {
95 chunk_size,
96 chunk_count,
97 }) = get_chunked(items.len())
98 else {
99 for item in items {
100 f(item)?;
101 }
102 return Ok(());
103 };
104 let f = &f;
105 scope_and_block(chunk_count, |scope| {
106 for chunk in items.chunks(chunk_size) {
107 scope.spawn(move || {
108 for item in chunk {
109 f(item)?;
110 }
111 Ok(())
112 })
113 }
114 })
115 .collect::<Result<(), E>>()
116}
117
118pub fn try_for_each_mut<'l, T, E>(
119 items: &'l mut [T],
120 f: impl (Fn(&'l mut T) -> Result<(), E>) + Send + Sync,
121) -> Result<(), E>
122where
123 T: Send + Sync,
124 E: Send + 'static,
125{
126 let Some(Chunked {
127 chunk_size,
128 chunk_count,
129 }) = get_chunked(items.len())
130 else {
131 for item in items {
132 f(item)?;
133 }
134 return Ok(());
135 };
136 let f = &f;
137 scope_and_block(chunk_count, |scope| {
138 for chunk in items.chunks_mut(chunk_size) {
139 scope.spawn(move || {
140 for item in chunk {
141 f(item)?;
142 }
143 Ok(())
144 })
145 }
146 })
147 .collect::<Result<(), E>>()
148}
149
150pub fn try_for_each_owned<T, E>(
151 items: Vec<T>,
152 f: impl (Fn(T) -> Result<(), E>) + Send + Sync,
153) -> Result<(), E>
154where
155 T: Send + Sync,
156 E: Send + 'static,
157{
158 let Some(Chunked {
159 chunk_size,
160 chunk_count,
161 }) = get_chunked(items.len())
162 else {
163 for item in items {
164 f(item)?;
165 }
166 return Ok(());
167 };
168 let f = &f;
169 scope_and_block(chunk_count, |scope| {
170 for chunk in into_chunks(items, chunk_size) {
171 scope.spawn(move || {
172 for item in chunk {
173 f(item)?;
174 }
175 Ok(())
176 })
177 }
178 })
179 .collect::<Result<(), E>>()
180}
181
182pub fn map_collect<'l, Item, PerItemResult, Result>(
183 items: &'l [Item],
184 f: impl Fn(&'l Item) -> PerItemResult + Send + Sync,
185) -> Result
186where
187 Item: Sync,
188 PerItemResult: Send + Sync + 'l,
189 Result: FromIterator<PerItemResult>,
190{
191 let Some(Chunked {
192 chunk_size,
193 chunk_count,
194 }) = get_chunked(items.len())
195 else {
196 return Result::from_iter(items.iter().map(f));
197 };
198 let f = &f;
199 scope_and_block(chunk_count, |scope| {
200 for chunk in items.chunks(chunk_size) {
201 scope.spawn(move || chunk.iter().map(f).collect::<Vec<_>>())
202 }
203 })
204 .flatten()
205 .collect()
206}
207
208pub fn map_collect_owned<'l, Item, PerItemResult, Result>(
209 items: Vec<Item>,
210 f: impl Fn(Item) -> PerItemResult + Send + Sync,
211) -> Result
212where
213 Item: Send + Sync,
214 PerItemResult: Send + Sync + 'l,
215 Result: FromIterator<PerItemResult>,
216{
217 let Some(Chunked {
218 chunk_size,
219 chunk_count,
220 }) = get_chunked(items.len())
221 else {
222 return Result::from_iter(items.into_iter().map(f));
223 };
224 let f = &f;
225 scope_and_block(chunk_count, |scope| {
226 for chunk in into_chunks(items, chunk_size) {
227 scope.spawn(move || chunk.map(f).collect::<Vec<_>>())
228 }
229 })
230 .flatten()
231 .collect()
232}
233
234pub fn map_collect_chunked_owned<'l, Item, PerItemResult, Result>(
235 items: Vec<Item>,
236 f: impl Fn(Chunk<Item>) -> PerItemResult + Send + Sync,
237) -> Result
238where
239 Item: Send + Sync,
240 PerItemResult: Send + Sync + 'l,
241 Result: FromIterator<PerItemResult>,
242{
243 let Some(Chunked {
244 chunk_size,
245 chunk_count,
246 }) = get_chunked(items.len())
247 else {
248 let len = items.len();
249 return Result::from_iter(into_chunks(items, len).map(f));
250 };
251 let f = &f;
252 scope_and_block(chunk_count, |scope| {
253 for chunk in into_chunks(items, chunk_size) {
254 scope.spawn(move || f(chunk))
255 }
256 })
257 .collect()
258}
259
// Tests for the parallel helpers. Several use 5-element inputs, so whether they take the
// chunked path depends on what `good_chunk_size(5)` returns; see the notes on individual tests.
#[cfg(test)]
mod tests {
    use std::{
        panic::{AssertUnwindSafe, catch_unwind},
        sync::atomic::{AtomicI32, Ordering},
    };

    use super::*;

    // Every element must be visited exactly once, so the atomic sum must be 1+2+3+4+5 = 15.
    // NOTE(review): all tests run on a multi-threaded Tokio runtime; presumably
    // `scope_and_block` requires one. Confirm.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_parallel_for_each() {
        let input = vec![1, 2, 3, 4, 5];
        let sum = AtomicI32::new(0);
        for_each(&input, |&x| {
            sum.fetch_add(x, Ordering::SeqCst);
        });
        assert_eq!(sum.load(Ordering::SeqCst), 15);
    }

    // Element 1 is both the first element and odd, so its error is expected to be reported
    // whether the input is processed sequentially or chunk by chunk.
    // NOTE(review): on the chunked path this relies on `scope_and_block` yielding chunk
    // results in spawn order. Confirm.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_parallel_try_for_each() {
        let input = vec![1, 2, 3, 4, 5];
        let result = try_for_each(&input, |&x| {
            if x % 2 == 0 {
                Ok(())
            } else {
                Err(format!("Odd number {x} encountered"))
            }
        });
        assert!(result.is_err());
        assert_eq!(result.unwrap_err(), "Odd number 1 encountered");
    }

    // After adding 10, the odd values (11, 13, 15) make `f` fail, yet the final assertion
    // expects every element to have been incremented. Each chunk, and the sequential path,
    // stops at its first `Err`, so this only holds if every element lands in its own chunk.
    // NOTE(review): this assumes `good_chunk_size(5) == 1`. If that value depends on the
    // machine, the last assertion is not portable. Confirm.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_parallel_try_for_each_mut() {
        let mut input = vec![1, 2, 3, 4, 5];
        let result = try_for_each_mut(&mut input, |x| {
            *x += 10;
            if *x % 2 == 0 {
                Ok(())
            } else {
                Err(format!("Odd number {} encountered", *x))
            }
        });
        assert!(result.is_err());
        assert_eq!(result.unwrap_err(), "Odd number 11 encountered");
        assert_eq!(input, vec![11, 12, 13, 14, 15]);
    }

    // Owned variant of the sum test: every element is moved into `f` exactly once.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_parallel_for_each_owned() {
        let input = vec![1, 2, 3, 4, 5];
        let sum = AtomicI32::new(0);
        for_each_owned(input, |x| {
            sum.fetch_add(x, Ordering::SeqCst);
        });
        assert_eq!(sum.load(Ordering::SeqCst), 15);
    }

    // The collected output is expected to preserve input order.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_parallel_map_collect() {
        let input = vec![1, 2, 3, 4, 5];
        let result: Vec<_> = map_collect(&input, |&x| x * 2);
        assert_eq!(result, vec![2, 4, 6, 8, 10]);
    }

    // Owned variant: the collected output is expected to preserve input order.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_parallel_map_collect_owned() {
        let input = vec![1, 2, 3, 4, 5];
        let result: Vec<_> = map_collect_owned(input, |x| x * 2);
        assert_eq!(result, vec![2, 4, 6, 8, 10]);
    }

    // 1000 elements, presumably enough to exercise the chunked path. All values are equal,
    // so this checks that no element is lost or duplicated, not that order is preserved.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_parallel_map_collect_owned_many() {
        let input = vec![1; 1000];
        let result: Vec<_> = map_collect_owned(input, |x| x * 2);
        assert_eq!(result, vec![2; 1000]);
    }

    // A panic in `f` for one element (index 744) must reach the caller with its original
    // payload. The trailing `panic!` produces a different payload if `for_each` returns
    // normally, so the payload check tells the two cases apart. `downcast_ref::<&str>` works
    // because `panic!` with only a string literal carries a `&'static str` payload.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_panic_in_scope() {
        let result = catch_unwind(AssertUnwindSafe(|| {
            let mut input = vec![1; 1000];
            input[744] = 2;
            for_each(&input, |x| {
                if *x == 2 {
                    panic!("Intentional panic");
                }
            });
            panic!("Should not get here")
        }));
        assert!(result.is_err());
        assert_eq!(
            result.unwrap_err().downcast_ref::<&str>(),
            Some(&"Intentional panic")
        );
    }
}