1use std::{
4 any::Any,
5 collections::VecDeque,
6 marker::PhantomData,
7 panic::{self, AssertUnwindSafe, catch_unwind},
8 sync::{
9 Arc, LazyLock,
10 atomic::{AtomicUsize, Ordering},
11 },
12 thread::{self, Thread},
13 time::{Duration, Instant},
14};
15
16use parking_lot::{Condvar, Mutex};
17use tokio::{runtime::Handle, task::block_in_place};
18use tracing::{Span, info_span};
19
20use crate::{
21 TurboTasksApi, manager::try_turbo_tasks, parallel::available_parallelism, turbo_tasks_scope,
22};
23
24static WORKER_TASKS: LazyLock<usize> =
27 LazyLock::new(|| available_parallelism().map_or(0, |n| n.get() - 1));
28
/// An entry in a scope's shared work queue.
enum WorkQueueJob {
    /// A spawned task: its index into the scope's result slots, and the type-erased closure
    /// that computes the task's value and stores it in that slot.
    Job(usize, Box<dyn FnOnce() + Send + 'static>),
    /// Sentinel appended once no more tasks will be spawned. A worker that pops it puts it
    /// back at the front of the queue, so every other worker sees it too, and then stops.
    End,
}
33
/// State shared (via `Arc`) between a `Scope` and the workers executing its tasks.
struct ScopeInner {
    /// Thread that created the scope; unparked when the last outstanding task finishes.
    main_thread: Thread,
    /// Number of spawned tasks that have not finished yet.
    remaining_tasks: AtomicUsize,
    /// Panic payload of the lowest-indexed task that panicked so far, together with that
    /// task's index.
    panic: Mutex<Option<(Box<dyn Any + Send + 'static>, usize)>>,
    /// Jobs waiting for a worker; terminated by `WorkQueueJob::End` once spawning is over.
    work_queue: Mutex<VecDeque<WorkQueueJob>>,
    /// Notified whenever an entry is pushed onto `work_queue`.
    work_queue_condition_var: Condvar,
}
45
impl ScopeInner {
    /// Records the completion of one task.
    ///
    /// `panic` carries the payload and index of a task that panicked. Only the payload of the
    /// lowest-indexed panicking task is kept. This means the panic that is eventually rethrown
    /// does not depend on which task happened to finish first. The call that brings
    /// `remaining_tasks` to zero unparks `main_thread`.
    fn on_task_finished(&self, panic: Option<(Box<dyn Any + Send + 'static>, usize)>) {
        if let Some((err, index)) = panic {
            let mut old_panic = self.panic.lock();
            // Replace the stored panic only if this task has a lower index.
            if old_panic.as_ref().is_none_or(|&(_, i)| i > index) {
                *old_panic = Some((err, index));
            }
        }
        // Release pairs with the Acquire loads in `wait`, so the task's effects are visible to
        // the thread that observes the counter reaching zero.
        if self.remaining_tasks.fetch_sub(1, Ordering::Release) == 1 {
            self.main_thread.unpark();
        }
    }

    /// Blocks until `remaining_tasks` reaches zero.
    ///
    /// Completion is signalled through `main_thread.unpark()`, so this must run on the thread
    /// that created the scope. It first parks with a short timeout, re-checking for about 1ms
    /// in total. Only then does it fall back to an unbounded park wrapped in tokio's
    /// `block_in_place`. Presumably this avoids the cost of `block_in_place` when the tasks
    /// are nearly done.
    fn wait(&self) {
        if self.remaining_tasks.load(Ordering::Acquire) == 0 {
            return;
        }

        let _span = info_span!("blocking").entered();

        // Short phase: park with a timeout and re-check the counter after every wake-up until
        // TIMEOUT has elapsed since the first park.
        const TIMEOUT: Duration = Duration::from_millis(1);
        let beginning_park = Instant::now();

        let mut timeout_remaining = TIMEOUT;
        loop {
            thread::park_timeout(timeout_remaining);
            if self.remaining_tasks.load(Ordering::Acquire) == 0 {
                return;
            }
            let elapsed = beginning_park.elapsed();
            if elapsed >= TIMEOUT {
                break;
            }
            timeout_remaining = TIMEOUT - elapsed;
        }

        // Long phase: `block_in_place` tells tokio this (possibly worker) thread is about to
        // block, so the runtime can move its other work to another thread.
        // NOTE: tokio's `block_in_place` panics when called from a current_thread runtime.
        block_in_place(|| {
            // `park` may return spuriously (or consume a stale unpark token), so loop on the
            // counter instead of trusting a single wake-up.
            while self.remaining_tasks.load(Ordering::Acquire) != 0 {
                thread::park();
            }
        });
    }

    /// Waits for all tasks to finish, then resumes unwinding with the recorded task panic
    /// (the one from the lowest-indexed panicking task), if there is one.
    fn wait_and_rethrow_panic(&self) {
        self.wait();
        if let Some((err, _)) = self.panic.lock().take() {
            panic::resume_unwind(err);
        }
    }

    /// Runs `first_job`, then keeps taking jobs from the work queue until it sees
    /// `WorkQueueJob::End`.
    ///
    /// Each job runs under `catch_unwind`. A panic is passed to `on_task_finished` together
    /// with the job's index instead of unwinding out of the worker.
    fn worker(&self, first_job_index: usize, first_job: Box<dyn FnOnce() + Send + 'static>) {
        let mut current_job_index = first_job_index;
        let mut current_job = first_job;
        loop {
            let result = catch_unwind(AssertUnwindSafe(current_job));
            let panic = result.err().map(|e| (e, current_job_index));
            self.on_task_finished(panic);
            let Some((index, job)) = self.pick_job_from_work_queue() else {
                return;
            };
            current_job_index = index;
            current_job = job;
        }
    }

    /// Blocks until the work queue is non-empty, then pops its front entry.
    ///
    /// Returns `Some((index, job))` for a job. On `WorkQueueJob::End` it puts the sentinel
    /// back at the front, wakes all waiters so they also observe it, and returns `None`.
    fn pick_job_from_work_queue(&self) -> Option<(usize, Box<dyn FnOnce() + Send + 'static>)> {
        let mut work_queue = self.work_queue.lock();
        // The emptiness check happens under the lock, so a push plus notify cannot be missed.
        let job = loop {
            if let Some(job) = work_queue.pop_front() {
                break job;
            } else {
                self.work_queue_condition_var.wait(&mut work_queue);
            };
        };
        match job {
            WorkQueueJob::Job(index, job) => {
                drop(work_queue);
                Some((index, job))
            }
            WorkQueueJob::End => {
                // Leave the sentinel in place for the remaining workers.
                work_queue.push_front(WorkQueueJob::End);
                drop(work_queue);
                self.work_queue_condition_var.notify_all();
                None
            }
        }
    }

    /// Marks the end of spawning and lends the calling thread to the remaining work.
    ///
    /// Takes the front job (if any) before appending `WorkQueueJob::End`, then wakes all
    /// waiting workers. It then runs that job, and any further queued jobs, on this thread
    /// via `worker`.
    fn end_and_help_complete(&self) {
        let job;
        {
            let mut work_queue = self.work_queue.lock();
            job = work_queue.pop_front();
            work_queue.push_back(WorkQueueJob::End);
        }
        self.work_queue_condition_var.notify_all();
        if let Some(WorkQueueJob::Job(index, job)) = job {
            self.worker(index, job);
        }
    }
}
149
/// Handle passed to the closure given to `scope_and_block`. It is used to spawn tasks whose
/// return values are written into a preallocated slice of result slots.
pub struct Scope<'scope, 'env: 'scope, R: Send + 'env> {
    /// One slot per task, indexed by spawn order.
    results: &'scope [Mutex<Option<R>>],
    /// Index assigned to the next call to `spawn`.
    index: AtomicUsize,
    /// State shared with the workers running this scope's tasks.
    inner: Arc<ScopeInner>,
    /// Runtime handle captured at creation; some tasks are spawned onto it.
    handle: Handle,
    /// Turbo-tasks context captured at creation; when present, each task runs inside
    /// `turbo_tasks_scope` with it.
    turbo_tasks: Option<Arc<dyn TurboTasksApi>>,
    /// Tracing span captured at creation and entered by tasks started on the runtime.
    span: Span,
    /// Makes `'env` invariant (`&'env mut &'env ()` is invariant in `'env`). This is the same
    /// marker `std::thread::Scope` uses.
    env: PhantomData<&'env mut &'env ()>,
}
166
impl<'scope, 'env: 'scope, R: Send + 'env> Scope<'scope, 'env, R> {
    /// Creates a scope that writes task results into `results`.
    ///
    /// Captures the current thread, the current tokio runtime handle, the current turbo-tasks
    /// context (if any) and the current tracing span.
    ///
    /// # Panics
    ///
    /// `Handle::current()` panics when called outside a tokio runtime context.
    ///
    /// # Safety
    ///
    /// `spawn` erases the `'scope`/`'env` lifetimes of task closures. The caller must ensure
    /// that:
    /// - the returned `Scope` is dropped (not leaked) before `results` or any data borrowed
    ///   for `'env` goes away, because its `Drop` is what waits for every spawned task; and
    /// - it is dropped on the same thread that created it, because completion is signalled
    ///   by unparking that thread.
    unsafe fn new(results: &'scope [Mutex<Option<R>>]) -> Self {
        Self {
            results,
            index: AtomicUsize::new(0),
            inner: Arc::new(ScopeInner {
                main_thread: thread::current(),
                remaining_tasks: AtomicUsize::new(0),
                panic: Mutex::new(None),
                work_queue: Mutex::new(VecDeque::new()),
                work_queue_condition_var: Condvar::new(),
            }),
            handle: Handle::current(),
            turbo_tasks: try_turbo_tasks(),
            span: Span::current(),
            env: PhantomData,
        }
    }

    /// Spawns `f` as the next task of this scope. Its return value is stored in the result
    /// slot whose index equals the number of earlier `spawn` calls.
    ///
    /// Tasks with index `1..=WORKER_TASKS` each start a new worker on the tokio runtime. All
    /// other tasks, including index 0, are pushed onto the shared work queue. That queue is
    /// drained by those workers and by the thread that drops the scope.
    ///
    /// # Panics
    ///
    /// Panics if called more times than there are result slots.
    pub fn spawn<F>(&self, f: F)
    where
        F: FnOnce() -> R + Send + 'env,
    {
        let index = self.index.fetch_add(1, Ordering::Relaxed);
        assert!(index < self.results.len(), "Too many tasks spawned");
        let result_cell: &Mutex<Option<R>> = &self.results[index];

        let turbo_tasks = self.turbo_tasks.clone();
        let f: Box<dyn FnOnce() + Send + 'scope> = Box::new(|| {
            let result = {
                if let Some(turbo_tasks) = turbo_tasks {
                    turbo_tasks_scope(turbo_tasks, f)
                } else {
                    f()
                }
            };
            *result_cell.lock() = Some(result);
        });
        let f: *mut (dyn FnOnce() + Send + 'scope) = Box::into_raw(f);

        // SAFETY: only the lifetime bound of the trait object changes; the pointer and its
        // vtable are untouched. Extending it to `'static` is sound only because `Scope`'s
        // `Drop` waits until every spawned job has finished before returning or rethrowing a
        // panic (see the safety contract of `Scope::new`). Nothing borrowed for
        // `'scope`/`'env` is therefore used after those lifetimes end.
        let f = unsafe {
            std::mem::transmute::<
                *mut (dyn FnOnce() + Send + 'scope),
                *mut (dyn FnOnce() + Send + 'static),
            >(f)
        };

        // SAFETY: `f` came from `Box::into_raw` just above and has not been freed or aliased.
        let f = unsafe { Box::from_raw(f) };

        let span = self.span.clone();

        // Count the task before it can possibly run, so the counter cannot reach zero while
        // this task is still pending.
        self.inner.remaining_tasks.fetch_add(1, Ordering::Relaxed);

        if (1..=*WORKER_TASKS).contains(&index) {
            // Start a new worker on the runtime with this task as its first job.
            // NOTE(review): if the runtime never polls this future (e.g. it is shutting down),
            // `remaining_tasks` is never decremented and the scope's `Drop` waits forever —
            // confirm that cannot happen for callers of this API.
            let inner = self.inner.clone();
            self.handle.spawn(async move {
                let _span = span.entered();
                inner.worker(index, f);
            });
        } else {
            self.inner
                .work_queue
                .lock()
                .push_back(WorkQueueJob::Job(index, f));
            self.inner.work_queue_condition_var.notify_one();
        }
    }
}
251
252impl<'scope, 'env: 'scope, R: Send + 'env> Drop for Scope<'scope, 'env, R> {
253 fn drop(&mut self) {
254 self.inner.end_and_help_complete();
255 self.inner.wait_and_rethrow_panic();
256 }
257}
258
259pub fn scope_and_block<'env, F, R>(number_of_tasks: usize, f: F) -> impl Iterator<Item = R>
267where
268 R: Send + 'env,
269 F: for<'scope> FnOnce(&'scope Scope<'scope, 'env, R>) + 'env,
270{
271 let mut results = Vec::with_capacity(number_of_tasks);
272 for _ in 0..number_of_tasks {
273 results.push(Mutex::new(None));
274 }
275 let results = results.into_boxed_slice();
276 let result = {
277 let scope = unsafe { Scope::new(&results) };
279 catch_unwind(AssertUnwindSafe(|| f(&scope)))
280 };
281 if let Err(panic) = result {
282 panic::resume_unwind(panic);
283 }
284 results.into_iter().map(|mutex| {
285 mutex
286 .into_inner()
287 .expect("All values are set when the scope returns without panic")
288 })
289}
290
#[cfg(test)]
mod tests {
    use std::panic::{AssertUnwindSafe, catch_unwind};

    use super::*;

    /// Returns the `&'static str` message of an expected panic; fails the test if `outcome`
    /// did not panic at all.
    fn panic_message(outcome: thread::Result<()>) -> Option<&'static str> {
        let payload = outcome.expect_err("expected the scope to panic");
        payload.downcast_ref::<&'static str>().copied()
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_scope() {
        let results: Vec<usize> = scope_and_block(1000, |scope| {
            (0..1000).for_each(|i| scope.spawn(move || i));
        })
        .collect();
        let expected: Vec<usize> = (0..1000).collect();
        assert_eq!(results, expected);
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_empty_scope() {
        let mut results = scope_and_block(0, |scope| {
            // Never executed; only pins down the result type.
            if false {
                scope.spawn(|| 42);
            }
        });
        assert!(results.next().is_none());
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_single_task() {
        let results: Vec<_> = scope_and_block(1, |scope| scope.spawn(|| 42)).collect();
        assert_eq!(results, [42]);
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_task_finish_before_scope() {
        let results: Vec<_> = scope_and_block(1, |scope| {
            scope.spawn(|| 42);
            // Give the task time to complete while the scope is still open.
            thread::sleep(Duration::from_millis(100));
        })
        .collect();
        assert_eq!(results, [42]);
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_task_finish_after_scope() {
        let results: Vec<_> = scope_and_block(1, |scope| {
            // The task outlives the scope closure, so dropping the scope must wait for it.
            scope.spawn(|| {
                thread::sleep(Duration::from_millis(100));
                42
            });
        })
        .collect();
        assert_eq!(results, [42]);
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_panic_in_scope_factory() {
        let outcome = catch_unwind(AssertUnwindSafe(|| {
            let _results = scope_and_block(1000, |scope| {
                (0..500).for_each(|i| scope.spawn(move || i));
                panic!("Intentional panic");
            });
            unreachable!();
        }));
        assert_eq!(panic_message(outcome), Some("Intentional panic"));
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_panic_in_scope_task() {
        let outcome = catch_unwind(AssertUnwindSafe(|| {
            let _results = scope_and_block(1000, |scope| {
                for i in 0..1000 {
                    scope.spawn(move || match i {
                        500 => panic!("Intentional panic"),
                        501 => panic!("Wrong intentional panic"),
                        _ => i,
                    });
                }
            });
            unreachable!();
        }));
        // Tasks 500 and 501 both panic; the lower index must be the one that is rethrown.
        assert_eq!(panic_message(outcome), Some("Intentional panic"));
    }
}