1use crate::{
8 scope::scope_and_block,
9 util::{good_chunk_size, into_chunks},
10};
11
12pub fn for_each<'l, T, F>(items: &'l [T], f: F)
13where
14 T: Sync,
15 F: Fn(&'l T) + Send + Sync,
16{
17 let len = items.len();
18 if len <= 1 {
19 for item in items {
20 f(item);
21 }
22 return;
23 }
24 let chunk_size = good_chunk_size(len);
25 let f = &f;
26 let _results = scope_and_block(len.div_ceil(chunk_size), |scope| {
27 for chunk in items.chunks(chunk_size) {
28 scope.spawn(async move {
29 for item in chunk {
30 f(item);
31 }
32 })
33 }
34 });
35}
36
37pub fn for_each_owned<T>(items: Vec<T>, f: impl Fn(T) + Send + Sync)
38where
39 T: Send + Sync,
40{
41 let len = items.len();
42 if len <= 1 {
43 for item in items {
44 f(item);
45 }
46 return;
47 }
48 let chunk_size = good_chunk_size(len);
49 let f = &f;
50 let _results = scope_and_block(len.div_ceil(chunk_size), |scope| {
51 for chunk in into_chunks(items, chunk_size) {
52 scope.spawn(async move {
53 for item in chunk {
55 f(item);
56 }
57 })
58 }
59 });
60}
61
62pub fn try_for_each<'l, T, E>(
63 items: &'l [T],
64 f: impl (Fn(&'l T) -> Result<(), E>) + Send + Sync,
65) -> Result<(), E>
66where
67 T: Sync,
68 E: Send + 'static,
69{
70 let len = items.len();
71 if len <= 1 {
72 for item in items {
73 f(item)?;
74 }
75 return Ok(());
76 }
77 let chunk_size = good_chunk_size(len);
78 let f = &f;
79 scope_and_block(len.div_ceil(chunk_size), |scope| {
80 for chunk in items.chunks(chunk_size) {
81 scope.spawn(async move {
82 for item in chunk {
83 f(item)?;
84 }
85 Ok(())
86 })
87 }
88 })
89 .collect::<Result<(), E>>()
90}
91
92pub fn try_for_each_mut<'l, T, E>(
93 items: &'l mut [T],
94 f: impl (Fn(&'l mut T) -> Result<(), E>) + Send + Sync,
95) -> Result<(), E>
96where
97 T: Send + Sync,
98 E: Send + 'static,
99{
100 let len = items.len();
101 if len <= 1 {
102 for item in items {
103 f(item)?;
104 }
105 return Ok(());
106 }
107 let chunk_size = good_chunk_size(len);
108 let f = &f;
109 scope_and_block(len.div_ceil(chunk_size), |scope| {
110 for chunk in items.chunks_mut(chunk_size) {
111 scope.spawn(async move {
112 for item in chunk {
113 f(item)?;
114 }
115 Ok(())
116 })
117 }
118 })
119 .collect::<Result<(), E>>()
120}
121
122pub fn try_for_each_owned<T, E>(
123 items: Vec<T>,
124 f: impl (Fn(T) -> Result<(), E>) + Send + Sync,
125) -> Result<(), E>
126where
127 T: Send + Sync,
128 E: Send + 'static,
129{
130 let len = items.len();
131 if len <= 1 {
132 for item in items {
133 f(item)?;
134 }
135 return Ok(());
136 }
137 let chunk_size = good_chunk_size(len);
138 let f = &f;
139 scope_and_block(len.div_ceil(chunk_size), |scope| {
140 for chunk in into_chunks(items, chunk_size) {
141 scope.spawn(async move {
142 for item in chunk {
143 f(item)?;
144 }
145 Ok(())
146 })
147 }
148 })
149 .collect::<Result<(), E>>()
150}
151
152pub fn map_collect<'l, Item, PerItemResult, Result>(
153 items: &'l [Item],
154 f: impl Fn(&'l Item) -> PerItemResult + Send + Sync,
155) -> Result
156where
157 Item: Sync,
158 PerItemResult: Send + Sync + 'l,
159 Result: FromIterator<PerItemResult>,
160{
161 let len = items.len();
162 if len == 0 {
163 return Result::from_iter(std::iter::empty()); }
166 let chunk_size = good_chunk_size(len);
167 let f = &f;
168 scope_and_block(len.div_ceil(chunk_size), |scope| {
169 for chunk in items.chunks(chunk_size) {
170 scope.spawn(async move { chunk.iter().map(f).collect::<Vec<_>>() })
171 }
172 })
173 .flatten()
174 .collect()
175}
176
177pub fn map_collect_owned<'l, Item, PerItemResult, Result>(
178 items: Vec<Item>,
179 f: impl Fn(Item) -> PerItemResult + Send + Sync,
180) -> Result
181where
182 Item: Send + Sync,
183 PerItemResult: Send + Sync + 'l,
184 Result: FromIterator<PerItemResult>,
185{
186 let len = items.len();
187 if len == 0 {
188 return Result::from_iter(std::iter::empty()); }
191 let chunk_size = good_chunk_size(len);
192 let f = &f;
193 scope_and_block(len.div_ceil(chunk_size), |scope| {
194 for chunk in into_chunks(items, chunk_size) {
195 scope.spawn(async move { chunk.map(f).collect::<Vec<_>>() })
196 }
197 })
198 .flatten()
199 .collect()
200}
201
#[cfg(test)]
mod tests {
    use std::{
        panic::{AssertUnwindSafe, catch_unwind},
        sync::atomic::{AtomicI32, Ordering},
    };

    use super::*;

    // Every test runs on a multi-threaded tokio runtime with two workers.

    /// Accumulating through `for_each` must yield the sum of the input.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_parallel_for_each() {
        let numbers = vec![1, 2, 3, 4, 5];
        let total = AtomicI32::new(0);
        for_each(&numbers, |&n| {
            total.fetch_add(n, Ordering::SeqCst);
        });
        assert_eq!(total.load(Ordering::SeqCst), 15);
    }

    /// `try_for_each` must report the error produced for the first element.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_parallel_try_for_each() {
        let numbers = vec![1, 2, 3, 4, 5];
        let outcome = try_for_each(&numbers, |&x| match x % 2 {
            0 => Ok(()),
            _ => Err(format!("Odd number {x} encountered")),
        });
        assert!(outcome.is_err());
        assert_eq!(outcome.unwrap_err(), "Odd number 1 encountered");
    }

    /// `try_for_each_mut` must report the first element's error while the
    /// mutation has still been applied to every element.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_parallel_try_for_each_mut() {
        let mut numbers = vec![1, 2, 3, 4, 5];
        let outcome = try_for_each_mut(&mut numbers, |slot| {
            *slot += 10;
            match *slot % 2 {
                0 => Ok(()),
                _ => Err(format!("Odd number {} encountered", *slot)),
            }
        });
        assert!(outcome.is_err());
        assert_eq!(outcome.unwrap_err(), "Odd number 11 encountered");
        assert_eq!(numbers, vec![11, 12, 13, 14, 15]);
    }

    /// Accumulating through `for_each_owned` must yield the sum of the input.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_parallel_for_each_owned() {
        let numbers = vec![1, 2, 3, 4, 5];
        let total = AtomicI32::new(0);
        for_each_owned(numbers, |n| {
            total.fetch_add(n, Ordering::SeqCst);
        });
        assert_eq!(total.load(Ordering::SeqCst), 15);
    }

    /// `map_collect` must return the mapped values in input order.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_parallel_map_collect() {
        let numbers = vec![1, 2, 3, 4, 5];
        let doubled: Vec<_> = map_collect(&numbers, |&n| n * 2);
        assert_eq!(doubled, vec![2, 4, 6, 8, 10]);
    }

    /// `map_collect_owned` must return the mapped values in input order.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_parallel_map_collect_owned() {
        let numbers = vec![1, 2, 3, 4, 5];
        let doubled: Vec<_> = map_collect_owned(numbers, |n| n * 2);
        assert_eq!(doubled, vec![2, 4, 6, 8, 10]);
    }

    /// `map_collect_owned` must map every element of a 1000-element input.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_parallel_map_collect_owned_many() {
        let numbers = vec![1; 1000];
        let doubled: Vec<_> = map_collect_owned(numbers, |n| n * 2);
        assert_eq!(doubled, vec![2; 1000]);
    }

    /// A panic raised inside the callback must surface from `for_each` on the
    /// calling thread with its original payload.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_panic_in_scope() {
        let outcome = catch_unwind(AssertUnwindSafe(|| {
            let mut numbers = vec![1; 1000];
            // A single marker element triggers the panic.
            numbers[744] = 2;
            for_each(&numbers, |&n| {
                if n == 2 {
                    panic!("Intentional panic");
                }
            });
            // Only reached if the callback's panic was swallowed.
            panic!("Should not get here")
        }));
        assert!(outcome.is_err());
        assert_eq!(
            outcome.unwrap_err().downcast_ref::<&str>(),
            Some(&"Intentional panic")
        );
    }
}