1use std::{
4 any::Any,
5 collections::VecDeque,
6 marker::PhantomData,
7 panic::{self, AssertUnwindSafe, catch_unwind},
8 sync::{
9 Arc,
10 atomic::{AtomicUsize, Ordering},
11 },
12 thread::{self, Thread, available_parallelism},
13 time::{Duration, Instant},
14};
15
16use once_cell::sync::Lazy;
17use parking_lot::{Condvar, Mutex};
18use tokio::{runtime::Handle, task::block_in_place};
19use tracing::{Span, info_span};
20
21use crate::{TurboTasksApi, manager::try_turbo_tasks, turbo_tasks_scope};
22
23static WORKER_TASKS: Lazy<usize> = Lazy::new(|| available_parallelism().map_or(0, |n| n.get() - 1));
26
/// An entry in the shared work queue of a scope.
enum WorkQueueJob {
    /// A pending job: the spawn index of the task and the boxed closure that runs it.
    // `Box<dyn Trait>` defaults to a `'static` object bound, so it is not spelled out.
    Job(usize, Box<dyn FnOnce() + Send>),
    /// Marker telling workers that no further jobs will be queued.
    End,
}
31
32struct ScopeInner {
33 main_thread: Thread,
34 remaining_tasks: AtomicUsize,
35 panic: Mutex<Option<(Box<dyn Any + Send + 'static>, usize)>>,
38 work_queue: Mutex<VecDeque<WorkQueueJob>>,
40 work_queue_condition_var: Condvar,
42}
43
44impl ScopeInner {
45 fn on_task_finished(&self, panic: Option<(Box<dyn Any + Send + 'static>, usize)>) {
46 if let Some((err, index)) = panic {
47 let mut old_panic = self.panic.lock();
48 if old_panic.as_ref().is_none_or(|&(_, i)| i > index) {
49 *old_panic = Some((err, index));
50 }
51 }
52 if self.remaining_tasks.fetch_sub(1, Ordering::Release) == 1 {
53 self.main_thread.unpark();
54 }
55 }
56
57 fn wait(&self) {
58 if self.remaining_tasks.load(Ordering::Acquire) == 0 {
59 return;
60 }
61
62 let _span = info_span!("blocking").entered();
63
64 const TIMEOUT: Duration = Duration::from_millis(1);
66 let beginning_park = Instant::now();
67
68 let mut timeout_remaining = TIMEOUT;
69 loop {
70 thread::park_timeout(timeout_remaining);
71 if self.remaining_tasks.load(Ordering::Acquire) == 0 {
72 return;
73 }
74 let elapsed = beginning_park.elapsed();
75 if elapsed >= TIMEOUT {
76 break;
77 }
78 timeout_remaining = TIMEOUT - elapsed;
79 }
80
81 block_in_place(|| {
83 while self.remaining_tasks.load(Ordering::Acquire) != 0 {
84 thread::park();
85 }
86 });
87 }
88
89 fn wait_and_rethrow_panic(&self) {
90 self.wait();
91 if let Some((err, _)) = self.panic.lock().take() {
92 panic::resume_unwind(err);
93 }
94 }
95
96 fn worker(&self, first_job_index: usize, first_job: Box<dyn FnOnce() + Send + 'static>) {
97 let mut current_job_index = first_job_index;
98 let mut current_job = first_job;
99 loop {
100 let result = catch_unwind(AssertUnwindSafe(current_job));
101 let panic = result.err().map(|e| (e, current_job_index));
102 self.on_task_finished(panic);
103 let Some((index, job)) = self.pick_job_from_work_queue() else {
104 return;
105 };
106 current_job_index = index;
107 current_job = job;
108 }
109 }
110
111 fn pick_job_from_work_queue(&self) -> Option<(usize, Box<dyn FnOnce() + Send + 'static>)> {
112 let mut work_queue = self.work_queue.lock();
113 let job = loop {
114 if let Some(job) = work_queue.pop_front() {
115 break job;
116 } else {
117 self.work_queue_condition_var.wait(&mut work_queue);
118 };
119 };
120 match job {
121 WorkQueueJob::Job(index, job) => {
122 drop(work_queue);
123 Some((index, job))
124 }
125 WorkQueueJob::End => {
126 work_queue.push_front(WorkQueueJob::End);
127 drop(work_queue);
128 self.work_queue_condition_var.notify_all();
129 None
130 }
131 }
132 }
133
134 fn end_and_help_complete(&self) {
135 let job;
136 {
137 let mut work_queue = self.work_queue.lock();
138 job = work_queue.pop_front();
139 work_queue.push_back(WorkQueueJob::End);
140 }
141 self.work_queue_condition_var.notify_all();
142 if let Some(WorkQueueJob::Job(index, job)) = job {
143 self.worker(index, job);
144 }
145 }
146}
147
148pub struct Scope<'scope, 'env: 'scope, R: Send + 'env> {
152 results: &'scope [Mutex<Option<R>>],
153 index: AtomicUsize,
154 inner: Arc<ScopeInner>,
155 handle: Handle,
156 turbo_tasks: Option<Arc<dyn TurboTasksApi>>,
157 span: Span,
158 env: PhantomData<&'env mut &'env ()>,
163}
164
165impl<'scope, 'env: 'scope, R: Send + 'env> Scope<'scope, 'env, R> {
166 unsafe fn new(results: &'scope [Mutex<Option<R>>]) -> Self {
172 Self {
173 results,
174 index: AtomicUsize::new(0),
175 inner: Arc::new(ScopeInner {
176 main_thread: thread::current(),
177 remaining_tasks: AtomicUsize::new(0),
178 panic: Mutex::new(None),
179 work_queue: Mutex::new(VecDeque::new()),
180 work_queue_condition_var: Condvar::new(),
181 }),
182 handle: Handle::current(),
183 turbo_tasks: try_turbo_tasks(),
184 span: Span::current(),
185 env: PhantomData,
186 }
187 }
188
189 pub fn spawn<F>(&self, f: F)
191 where
192 F: FnOnce() -> R + Send + 'env,
193 {
194 let index = self.index.fetch_add(1, Ordering::Relaxed);
195 assert!(index < self.results.len(), "Too many tasks spawned");
196 let result_cell: &Mutex<Option<R>> = &self.results[index];
197
198 let f: Box<dyn FnOnce() + Send + 'scope> = Box::new(|| {
199 let result = f();
200 *result_cell.lock() = Some(result);
201 });
202 let f: *mut (dyn FnOnce() + Send + 'scope) = Box::into_raw(f);
203 #[allow(
206 clippy::unnecessary_cast,
207 reason = "Clippy thinks this is unnecessary, but it actually changes the lifetime"
208 )]
209 let f = f as *mut (dyn FnOnce() + Send + 'static);
210 let f = unsafe { Box::from_raw(f) };
212
213 let turbo_tasks = self.turbo_tasks.clone();
214 let span = self.span.clone();
215
216 self.inner.remaining_tasks.fetch_add(1, Ordering::Relaxed);
217
218 if (1..=*WORKER_TASKS).contains(&index) {
222 let inner = self.inner.clone();
223 self.handle.spawn(async move {
225 let _span = span.entered();
226 if let Some(turbo_tasks) = turbo_tasks {
227 turbo_tasks_scope(turbo_tasks, || {
229 inner.worker(index, f);
230 });
231 } else {
232 inner.worker(index, f);
234 }
235 });
236 } else {
237 self.inner
239 .work_queue
240 .lock()
241 .push_back(WorkQueueJob::Job(index, f));
242 self.inner.work_queue_condition_var.notify_one();
243 }
244 }
245}
246
247impl<'scope, 'env: 'scope, R: Send + 'env> Drop for Scope<'scope, 'env, R> {
248 fn drop(&mut self) {
249 self.inner.end_and_help_complete();
250 self.inner.wait_and_rethrow_panic();
251 }
252}
253
254pub fn scope_and_block<'env, F, R>(number_of_tasks: usize, f: F) -> impl Iterator<Item = R>
262where
263 R: Send + 'env,
264 F: for<'scope> FnOnce(&'scope Scope<'scope, 'env, R>) + 'env,
265{
266 let mut results = Vec::with_capacity(number_of_tasks);
267 for _ in 0..number_of_tasks {
268 results.push(Mutex::new(None));
269 }
270 let results = results.into_boxed_slice();
271 let result = {
272 let scope = unsafe { Scope::new(&results) };
274 catch_unwind(AssertUnwindSafe(|| f(&scope)))
275 };
276 if let Err(panic) = result {
277 panic::resume_unwind(panic);
278 }
279 results.into_iter().map(|mutex| {
280 mutex
281 .into_inner()
282 .expect("All values are set when the scope returns without panic")
283 })
284}
285
#[cfg(test)]
mod tests {
    use std::panic::{AssertUnwindSafe, catch_unwind};

    use super::*;

    /// Results are complete and arrive in spawn order.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_scope() {
        let results: Vec<usize> = scope_and_block(1000, |scope| {
            for i in 0..1000 {
                scope.spawn(move || i);
            }
        })
        .collect();
        for (position, value) in results.iter().copied().enumerate() {
            assert_eq!(value, position);
        }
        assert_eq!(results.len(), 1000);
    }

    /// A scope without tasks yields no results.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_empty_scope() {
        let count = scope_and_block(0, |scope| {
            // Never executed; it only fixes the result type of the scope.
            if false {
                scope.spawn(|| 42);
            }
        })
        .count();
        assert_eq!(count, 0);
    }

    /// A single task yields a single result.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_single_task() {
        let results: Vec<_> = scope_and_block(1, |scope| {
            scope.spawn(|| 42);
        })
        .collect();
        assert_eq!(results, [42]);
    }

    /// The scope closure keeps running for 100 ms after spawning its only task.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_task_finish_before_scope() {
        let results: Vec<_> = scope_and_block(1, |scope| {
            scope.spawn(|| 42);
            thread::sleep(std::time::Duration::from_millis(100));
        })
        .collect();
        assert_eq!(results, [42]);
    }

    /// The only task sleeps for 100 ms while the scope closure returns immediately.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_task_finish_after_scope() {
        let results: Vec<_> = scope_and_block(1, |scope| {
            scope.spawn(|| {
                thread::sleep(std::time::Duration::from_millis(100));
                42
            });
        })
        .collect();
        assert_eq!(results, [42]);
    }

    /// A panic in the scope closure reaches the caller of `scope_and_block`.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_panic_in_scope_factory() {
        let outcome = catch_unwind(AssertUnwindSafe(|| {
            let _results = scope_and_block(1000, |scope| {
                for i in 0..500 {
                    scope.spawn(move || i);
                }
                panic!("Intentional panic");
            });
            unreachable!();
        }));
        assert!(outcome.is_err());
        let payload = outcome.unwrap_err();
        assert_eq!(
            payload.downcast_ref::<&str>().copied(),
            Some("Intentional panic")
        );
    }

    /// When several tasks panic, the panic of the task with the lowest index is reported.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_panic_in_scope_task() {
        let outcome = catch_unwind(AssertUnwindSafe(|| {
            let _results = scope_and_block(1000, |scope| {
                for i in 0..1000 {
                    scope.spawn(move || match i {
                        500 => panic!("Intentional panic"),
                        501 => panic!("Wrong intentional panic"),
                        _ => i,
                    });
                }
            });
            unreachable!();
        }));
        assert!(outcome.is_err());
        let payload = outcome.unwrap_err();
        assert_eq!(
            payload.downcast_ref::<&str>().copied(),
            Some("Intentional panic")
        );
    }
}