1use std::{
4 any::Any,
5 collections::VecDeque,
6 marker::PhantomData,
7 panic::{self, AssertUnwindSafe, catch_unwind},
8 sync::{
9 Arc,
10 atomic::{AtomicUsize, Ordering},
11 },
12 thread::{self, Thread, available_parallelism},
13 time::{Duration, Instant},
14};
15
16use once_cell::sync::Lazy;
17use parking_lot::{Condvar, Mutex};
18use tokio::{runtime::Handle, task::block_in_place};
19use tracing::{Span, info_span};
20
21use crate::{TurboTasksApi, manager::try_turbo_tasks, turbo_tasks_scope};
22
/// Number of dedicated worker tasks to spawn per scope: the available parallelism minus one
/// (the thread that created the scope also helps drain the work queue), or 0 when the
/// parallelism cannot be determined.
static WORKER_TASKS: Lazy<usize> = Lazy::new(|| available_parallelism().map_or(0, |n| n.get() - 1));
26
/// An entry in the shared work queue.
enum WorkQueueJob {
    /// A queued task: its spawn index (used for result slots and panic ordering) and the
    /// closure to run.
    Job(usize, Box<dyn FnOnce() + Send + 'static>),
    /// Sentinel enqueued when the scope ends; tells workers to stop picking up jobs.
    End,
}
31
/// State shared between the scope's owning (main) thread and all worker tasks.
struct ScopeInner {
    /// Handle to the thread that created the scope; unparked when the last task finishes.
    main_thread: Thread,
    /// Count of spawned-but-unfinished tasks; the main thread parks in `wait` until it
    /// reaches zero.
    remaining_tasks: AtomicUsize,
    /// Panic payload plus spawn index of a panicked task. Only the lowest-index panic is
    /// kept, so the rethrown panic is deterministic regardless of completion order.
    panic: Mutex<Option<(Box<dyn Any + Send + 'static>, usize)>>,
    /// Jobs waiting to be picked up by a worker (or by the main thread when it helps out).
    work_queue: Mutex<VecDeque<WorkQueueJob>>,
    /// Notified when entries (jobs or the `End` sentinel) are pushed onto `work_queue`.
    work_queue_condition_var: Condvar,
}
43
impl ScopeInner {
    /// Records the completion of one task. If the task panicked, stores its payload when it is
    /// the lowest-index panic seen so far; unparks the main thread when this was the last
    /// remaining task.
    fn on_task_finished(&self, panic: Option<(Box<dyn Any + Send + 'static>, usize)>) {
        if let Some((err, index)) = panic {
            let mut old_panic = self.panic.lock();
            // Keep only the panic with the smallest spawn index so the rethrown panic is
            // deterministic even when tasks finish out of order.
            if old_panic.as_ref().is_none_or(|&(_, i)| i > index) {
                *old_panic = Some((err, index));
            }
        }
        // `Release` pairs with the `Acquire` loads in `wait`: once the main thread observes the
        // counter at zero, it also observes this task's writes (result slot, panic).
        if self.remaining_tasks.fetch_sub(1, Ordering::Release) == 1 {
            self.main_thread.unpark();
        }
    }

    /// Blocks the current (main) thread until `remaining_tasks` drops to zero.
    ///
    /// First parks with short timeouts for up to `TIMEOUT`; only if tasks are still running
    /// after that does it fall back to `block_in_place` + unbounded parking, informing the
    /// tokio runtime that this thread will block for a while.
    fn wait(&self) {
        if self.remaining_tasks.load(Ordering::Acquire) == 0 {
            return;
        }

        let _span = info_span!("blocking").entered();

        const TIMEOUT: Duration = Duration::from_millis(1);
        let beginning_park = Instant::now();

        // `park_timeout` may return early (spurious wakeups or unparks from earlier task
        // completions), so re-check the counter and keep parking for the remainder of TIMEOUT.
        let mut timeout_remaining = TIMEOUT;
        loop {
            thread::park_timeout(timeout_remaining);
            if self.remaining_tasks.load(Ordering::Acquire) == 0 {
                return;
            }
            let elapsed = beginning_park.elapsed();
            if elapsed >= TIMEOUT {
                break;
            }
            timeout_remaining = TIMEOUT - elapsed;
        }

        // Still busy after TIMEOUT: tell tokio we are going to block so it can move other work
        // off this runtime thread, then park until the last task unparks us.
        block_in_place(|| {
            while self.remaining_tasks.load(Ordering::Acquire) != 0 {
                thread::park();
            }
        });
    }

    /// Waits for all tasks to finish, then rethrows the stored (lowest-index) panic, if any
    /// task panicked.
    fn wait_and_rethrow_panic(&self) {
        self.wait();
        if let Some((err, _)) = self.panic.lock().take() {
            panic::resume_unwind(err);
        }
    }

    /// Worker loop: runs `first_job`, then keeps pulling jobs from the work queue until the
    /// `End` sentinel is reached. Each job runs under `catch_unwind` so a panicking task is
    /// still counted as finished and the loop keeps serving the queue.
    fn worker(&self, first_job_index: usize, first_job: Box<dyn FnOnce() + Send + 'static>) {
        let mut current_job_index = first_job_index;
        let mut current_job = first_job;
        loop {
            let result = catch_unwind(AssertUnwindSafe(current_job));
            let panic = result.err().map(|e| (e, current_job_index));
            self.on_task_finished(panic);
            let Some((index, job)) = self.pick_job_from_work_queue() else {
                return;
            };
            current_job_index = index;
            current_job = job;
        }
    }

    /// Blocks on the condition variable until a queue entry is available. Returns the next job
    /// to run, or `None` once the `End` sentinel is at the front of the queue.
    fn pick_job_from_work_queue(&self) -> Option<(usize, Box<dyn FnOnce() + Send + 'static>)> {
        let mut work_queue = self.work_queue.lock();
        let job = loop {
            if let Some(job) = work_queue.pop_front() {
                break job;
            } else {
                // Releases the lock while waiting; re-check `pop_front` after every wakeup.
                self.work_queue_condition_var.wait(&mut work_queue);
            };
        };
        match job {
            WorkQueueJob::Job(index, job) => {
                drop(work_queue);
                Some((index, job))
            }
            WorkQueueJob::End => {
                // Put the sentinel back and wake all other waiters so every worker eventually
                // sees `End` and shuts down too.
                work_queue.push_front(WorkQueueJob::End);
                drop(work_queue);
                self.work_queue_condition_var.notify_all();
                None
            }
        }
    }

    /// Marks the queue as complete by enqueuing the `End` sentinel, then lets the calling
    /// (main) thread help drain the queue: if a job was already pending, the caller becomes a
    /// worker and processes jobs until it hits the sentinel.
    fn end_and_help_complete(&self) {
        let job;
        {
            let mut work_queue = self.work_queue.lock();
            job = work_queue.pop_front();
            work_queue.push_back(WorkQueueJob::End);
        }
        self.work_queue_condition_var.notify_all();
        if let Some(WorkQueueJob::Job(index, job)) = job {
            self.worker(index, job);
        }
    }
}
147
/// A scope that runs spawned closures on the tokio runtime while letting them borrow data
/// for `'env` from the enclosing stack frame. Dropping the scope blocks until every spawned
/// task has finished (see the `Drop` impl), which is what makes those borrows sound.
pub struct Scope<'scope, 'env: 'scope, R: Send + 'env> {
    /// One result slot per task; the `index`-th spawned closure writes into slot `index`.
    results: &'scope [Mutex<Option<R>>],
    /// Next spawn index, incremented by every `spawn` call.
    index: AtomicUsize,
    /// State shared with the worker tasks.
    inner: Arc<ScopeInner>,
    /// Tokio runtime handle captured at scope creation, used to spawn worker tasks.
    handle: Handle,
    /// Turbo-tasks context captured at scope creation and re-entered inside each task.
    turbo_tasks: Option<Arc<dyn TurboTasksApi>>,
    /// Tracing span captured at scope creation and entered by the worker tasks.
    span: Span,
    /// `&'env mut &'env ()` makes the scope invariant over `'env`, preventing the borrowed
    /// lifetime from being shortened.
    env: PhantomData<&'env mut &'env ()>,
}
164
impl<'scope, 'env: 'scope, R: Send + 'env> Scope<'scope, 'env, R> {
    /// Creates a scope that writes task results into `results`, capturing the current tokio
    /// handle, turbo-tasks context, and tracing span so tasks run in the creator's environment.
    ///
    /// # Safety
    ///
    /// The caller must keep `results` (and all `'env` borrows) alive until this scope is
    /// dropped. The `Drop` impl blocks until all spawned tasks have finished, which is what
    /// makes the lifetime transmute in `spawn` sound.
    unsafe fn new(results: &'scope [Mutex<Option<R>>]) -> Self {
        Self {
            results,
            index: AtomicUsize::new(0),
            inner: Arc::new(ScopeInner {
                main_thread: thread::current(),
                remaining_tasks: AtomicUsize::new(0),
                panic: Mutex::new(None),
                work_queue: Mutex::new(VecDeque::new()),
                work_queue_condition_var: Condvar::new(),
            }),
            handle: Handle::current(),
            turbo_tasks: try_turbo_tasks(),
            span: Span::current(),
            env: PhantomData,
        }
    }

    /// Spawns a task in this scope; its return value is stored in the result slot matching its
    /// spawn order. Panics if more tasks are spawned than there are result slots.
    pub fn spawn<F>(&self, f: F)
    where
        F: FnOnce() -> R + Send + 'env,
    {
        let index = self.index.fetch_add(1, Ordering::Relaxed);
        assert!(index < self.results.len(), "Too many tasks spawned");
        let result_cell: &Mutex<Option<R>> = &self.results[index];

        let turbo_tasks = self.turbo_tasks.clone();
        // Wrap the user closure: re-enter the captured turbo-tasks context (if any) and store
        // the result into this task's slot.
        let f: Box<dyn FnOnce() + Send + 'scope> = Box::new(|| {
            let result = {
                if let Some(turbo_tasks) = turbo_tasks {
                    turbo_tasks_scope(turbo_tasks, f)
                } else {
                    f()
                }
            };
            *result_cell.lock() = Some(result);
        });
        let f: *mut (dyn FnOnce() + Send + 'scope) = Box::into_raw(f);

        // SAFETY: Erases the `'scope` lifetime so the closure can be handed to tokio, which
        // requires `'static`. Sound because `Drop for Scope` blocks until every task has run
        // to completion, so the closure never outlives the data it borrows.
        let f = unsafe {
            std::mem::transmute::<
                *mut (dyn FnOnce() + Send + 'scope),
                *mut (dyn FnOnce() + Send + 'static),
            >(f)
        };

        // SAFETY: `f` was just produced by `Box::into_raw` and is not aliased elsewhere.
        let f = unsafe { Box::from_raw(f) };

        let span = self.span.clone();

        // Count the task before it can possibly run/finish; `wait` observes this via Acquire.
        self.inner.remaining_tasks.fetch_add(1, Ordering::Relaxed);

        // Spawn indices 1..=WORKER_TASKS each get a dedicated tokio task that runs this job and
        // then keeps serving the work queue. Index 0 and any overflow indices are only queued,
        // to be picked up by those workers or by the main thread in `end_and_help_complete`.
        if (1..=*WORKER_TASKS).contains(&index) {
            let inner = self.inner.clone();
            self.handle.spawn(async move {
                let _span = span.entered();
                inner.worker(index, f);
            });
        } else {
            self.inner
                .work_queue
                .lock()
                .push_back(WorkQueueJob::Job(index, f));
            self.inner.work_queue_condition_var.notify_one();
        }
    }
}
249
impl<'scope, 'env: 'scope, R: Send + 'env> Drop for Scope<'scope, 'env, R> {
    /// Blocks until every spawned task has finished, then rethrows the lowest-index task panic
    /// (if any). Running to completion here is what makes the `'scope` -> `'static` transmute
    /// in `spawn` sound: no task may outlive the borrowed data.
    fn drop(&mut self) {
        // Enqueue the `End` sentinel and let this thread help drain pending jobs first.
        self.inner.end_and_help_complete();
        self.inner.wait_and_rethrow_panic();
    }
}
256
257pub fn scope_and_block<'env, F, R>(number_of_tasks: usize, f: F) -> impl Iterator<Item = R>
265where
266 R: Send + 'env,
267 F: for<'scope> FnOnce(&'scope Scope<'scope, 'env, R>) + 'env,
268{
269 let mut results = Vec::with_capacity(number_of_tasks);
270 for _ in 0..number_of_tasks {
271 results.push(Mutex::new(None));
272 }
273 let results = results.into_boxed_slice();
274 let result = {
275 let scope = unsafe { Scope::new(&results) };
277 catch_unwind(AssertUnwindSafe(|| f(&scope)))
278 };
279 if let Err(panic) = result {
280 panic::resume_unwind(panic);
281 }
282 results.into_iter().map(|mutex| {
283 mutex
284 .into_inner()
285 .expect("All values are set when the scope returns without panic")
286 })
287}
288
#[cfg(test)]
mod tests {
    use std::panic::{AssertUnwindSafe, catch_unwind};

    use super::*;

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_scope() {
        // Results must come back in spawn order, one per spawned task.
        let collected: Vec<usize> = scope_and_block(1000, |scope| {
            for value in 0..1000 {
                scope.spawn(move || value);
            }
        })
        .collect();
        assert_eq!(collected.len(), 1000);
        for (expected, &actual) in collected.iter().enumerate() {
            assert_eq!(actual, expected);
        }
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_empty_scope() {
        let iter = scope_and_block(0, |scope| {
            // Never executed; only present to pin the result type.
            if false {
                scope.spawn(|| 42);
            }
        });
        assert_eq!(iter.count(), 0);
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_single_task() {
        let collected: Vec<_> = scope_and_block(1, |scope| {
            scope.spawn(|| 42);
        })
        .collect();
        assert_eq!(collected, vec![42]);
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_task_finish_before_scope() {
        // The task finishes while the factory is still sleeping.
        let collected: Vec<_> = scope_and_block(1, |scope| {
            scope.spawn(|| 42);
            thread::sleep(std::time::Duration::from_millis(100));
        })
        .collect();
        assert_eq!(collected, vec![42]);
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_task_finish_after_scope() {
        // The task outlives the factory; the scope must block until it completes.
        let collected: Vec<_> = scope_and_block(1, |scope| {
            scope.spawn(|| {
                thread::sleep(std::time::Duration::from_millis(100));
                42
            });
        })
        .collect();
        assert_eq!(collected, vec![42]);
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_panic_in_scope_factory() {
        let result = catch_unwind(AssertUnwindSafe(|| {
            let _results = scope_and_block(1000, |scope| {
                (0..500).for_each(|i| scope.spawn(move || i));
                panic!("Intentional panic");
            });
            unreachable!();
        }));
        assert!(result.is_err());
        let payload = result.unwrap_err();
        assert_eq!(payload.downcast_ref::<&str>(), Some(&"Intentional panic"));
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_panic_in_scope_task() {
        let result = catch_unwind(AssertUnwindSafe(|| {
            let _results = scope_and_block(1000, |scope| {
                for i in 0..1000 {
                    scope.spawn(move || match i {
                        // The lowest-index panic must be the one that gets rethrown.
                        500 => panic!("Intentional panic"),
                        501 => panic!("Wrong intentional panic"),
                        _ => i,
                    });
                }
            });
            unreachable!();
        }));
        assert!(result.is_err());
        let payload = result.unwrap_err();
        assert_eq!(payload.downcast_ref::<&str>(), Some(&"Intentional panic"));
    }
}