turbo_tasks/scope.rs

//! A scoped tokio spawn implementation that allows a non-'static lifetime for tasks.

use std::{
    any::Any,
    collections::VecDeque,
    marker::PhantomData,
    panic::{self, AssertUnwindSafe, catch_unwind},
    sync::{
        Arc,
        atomic::{AtomicUsize, Ordering},
    },
    thread::{self, Thread, available_parallelism},
    time::{Duration, Instant},
};

use once_cell::sync::Lazy;
use parking_lot::{Condvar, Mutex};
use tokio::{runtime::Handle, task::block_in_place};
use tracing::{Span, info_span};

use crate::{TurboTasksApi, manager::try_turbo_tasks, turbo_tasks_scope};

/// Number of worker tasks to spawn that process jobs. It's one less than the number of CPUs, as we
/// also use the current task as a worker.
static WORKER_TASKS: Lazy<usize> = Lazy::new(|| available_parallelism().map_or(0, |n| n.get() - 1));

enum WorkQueueJob {
    Job(usize, Box<dyn FnOnce() + Send + 'static>),
    End,
}

struct ScopeInner {
    main_thread: Thread,
    remaining_tasks: AtomicUsize,
    /// The first panic that occurred in the tasks; the `usize` is the index of the panicking
    /// task. If multiple tasks panic, the panic from the lowest task index is kept.
    panic: Mutex<Option<(Box<dyn Any + Send + 'static>, usize)>>,
    /// The work queue for spawned jobs that have not yet been picked up by a worker task.
    work_queue: Mutex<VecDeque<WorkQueueJob>>,
    /// A condition variable to notify worker tasks of new work or end of work.
    work_queue_condition_var: Condvar,
}

impl ScopeInner {
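    /// Records a task's completion. If the task panicked, the panic is stored unless an earlier
    /// (lower-index) panic is already recorded. Unparks the main thread once the last task
    /// finishes.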
    fn on_task_finished(&self, panic: Option<(Box<dyn Any + Send + 'static>, usize)>) {
        if let Some((err, index)) = panic {
            let mut old_panic = self.panic.lock();
            if old_panic.as_ref().is_none_or(|&(_, i)| i > index) {
                *old_panic = Some((err, index));
            }
        }
        if self.remaining_tasks.fetch_sub(1, Ordering::Release) == 1 {
            self.main_thread.unpark();
        }
    }

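    /// Parks the current thread until `remaining_tasks` reaches zero. Parks without
    /// `block_in_place` for up to 1ms before falling back to parking inside `block_in_place`.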
    fn wait(&self) {
        if self.remaining_tasks.load(Ordering::Acquire) == 0 {
            return;
        }

        let _span = info_span!("blocking").entered();

        // Park up to 1ms without block_in_place to avoid the overhead.
        const TIMEOUT: Duration = Duration::from_millis(1);
        let beginning_park = Instant::now();

        let mut timeout_remaining = TIMEOUT;
        loop {
            thread::park_timeout(timeout_remaining);
            if self.remaining_tasks.load(Ordering::Acquire) == 0 {
                return;
            }
            let elapsed = beginning_park.elapsed();
            if elapsed >= TIMEOUT {
                break;
            }
            timeout_remaining = TIMEOUT - elapsed;
        }

        // Park with block_in_place to allow other work to continue.
        block_in_place(|| {
            while self.remaining_tasks.load(Ordering::Acquire) != 0 {
                thread::park();
            }
        });
    }

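    /// Waits for all tasks to finish and resumes the recorded panic, if any.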
    fn wait_and_rethrow_panic(&self) {
        self.wait();
        if let Some((err, _)) = self.panic.lock().take() {
            panic::resume_unwind(err);
        }
    }

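    /// Runs `first_job`, then keeps pulling jobs from the work queue until the `End` marker is
    /// reached. Panics in jobs are caught and reported via [`Self::on_task_finished`].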
    fn worker(&self, first_job_index: usize, first_job: Box<dyn FnOnce() + Send + 'static>) {
        let mut current_job_index = first_job_index;
        let mut current_job = first_job;
        loop {
            let result = catch_unwind(AssertUnwindSafe(current_job));
            let panic = result.err().map(|e| (e, current_job_index));
            self.on_task_finished(panic);
            let Some((index, job)) = self.pick_job_from_work_queue() else {
                return;
            };
            current_job_index = index;
            current_job = job;
        }
    }

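    /// Blocks until a job or the `End` marker is available. Returns `None` on `End`, putting the
    /// marker back so other worker tasks shut down as well.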
    fn pick_job_from_work_queue(&self) -> Option<(usize, Box<dyn FnOnce() + Send + 'static>)> {
        let mut work_queue = self.work_queue.lock();
        let job = loop {
            if let Some(job) = work_queue.pop_front() {
                break job;
            } else {
                self.work_queue_condition_var.wait(&mut work_queue);
            };
        };
        match job {
            WorkQueueJob::Job(index, job) => {
                drop(work_queue);
                Some((index, job))
            }
            WorkQueueJob::End => {
                work_queue.push_front(WorkQueueJob::End);
                drop(work_queue);
                self.work_queue_condition_var.notify_all();
                None
            }
        }
    }

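    /// Pushes the `End` marker onto the work queue and lets the calling thread help process a
    /// still-queued job (and any further queued jobs) before returning.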
    fn end_and_help_complete(&self) {
        let job;
        {
            let mut work_queue = self.work_queue.lock();
            job = work_queue.pop_front();
            work_queue.push_back(WorkQueueJob::End);
        }
        self.work_queue_condition_var.notify_all();
        if let Some(WorkQueueJob::Job(index, job)) = job {
            self.worker(index, job);
        }
    }
}

/// Scope to allow spawning tasks with a limited lifetime.
///
/// Dropping this Scope will wait for all tasks to complete.
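///
/// A `Scope` is created via [`scope_and_block`], which also guarantees that it is dropped.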
pub struct Scope<'scope, 'env: 'scope, R: Send + 'env> {
    results: &'scope [Mutex<Option<R>>],
    index: AtomicUsize,
    inner: Arc<ScopeInner>,
    handle: Handle,
    turbo_tasks: Option<Arc<dyn TurboTasksApi>>,
    span: Span,
    /// Invariance over 'env, to make sure 'env cannot shrink,
    /// which is necessary for soundness.
    ///
    /// see https://doc.rust-lang.org/src/std/thread/scoped.rs.html#12-29
    env: PhantomData<&'env mut &'env ()>,
}

impl<'scope, 'env: 'scope, R: Send + 'env> Scope<'scope, 'env, R> {
    /// Creates a new scope.
    ///
    /// # Safety
    ///
    /// The caller must ensure `Scope` is dropped and not forgotten.
    unsafe fn new(results: &'scope [Mutex<Option<R>>]) -> Self {
        Self {
            results,
            index: AtomicUsize::new(0),
            inner: Arc::new(ScopeInner {
                main_thread: thread::current(),
                remaining_tasks: AtomicUsize::new(0),
                panic: Mutex::new(None),
                work_queue: Mutex::new(VecDeque::new()),
                work_queue_condition_var: Condvar::new(),
            }),
            handle: Handle::current(),
            turbo_tasks: try_turbo_tasks(),
            span: Span::current(),
            env: PhantomData,
        }
    }

    /// Spawns a new task in the scope.
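    ///
    /// # Panics
    ///
    /// Panics if more tasks are spawned than the number of result slots the scope was created
    /// with (`number_of_tasks` in [`scope_and_block`]).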
    pub fn spawn<F>(&self, f: F)
    where
        F: FnOnce() -> R + Send + 'env,
    {
        let index = self.index.fetch_add(1, Ordering::Relaxed);
        assert!(index < self.results.len(), "Too many tasks spawned");
        let result_cell: &Mutex<Option<R>> = &self.results[index];

        let f: Box<dyn FnOnce() + Send + 'scope> = Box::new(|| {
            let result = f();
            *result_cell.lock() = Some(result);
        });
        let f: *mut (dyn FnOnce() + Send + 'scope) = Box::into_raw(f);
        // SAFETY: Scope ensures (e.g. in Drop) that spawned tasks are awaited before the
        // lifetime `'env` ends.
        #[allow(
            clippy::unnecessary_cast,
            reason = "Clippy thinks this is unnecessary, but it actually changes the lifetime"
        )]
        let f = f as *mut (dyn FnOnce() + Send + 'static);
        // SAFETY: We just called `Box::into_raw`.
        let f = unsafe { Box::from_raw(f) };

        let turbo_tasks = self.turbo_tasks.clone();
        let span = self.span.clone();

        self.inner.remaining_tasks.fetch_add(1, Ordering::Relaxed);

        // The first job always goes to the work_queue to be worked on by the main thread.
        // After that we spawn a new worker for every job until we reach WORKER_TASKS.
        // After that we queue up jobs in the work_queue again.
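        // For example, with WORKER_TASKS == 3: job 0 goes to the queue, jobs 1..=3 each spawn
        // their own worker task, and jobs 4 and up go back to the queue, to be picked up by the
        // main thread (in `end_and_help_complete`) or by whichever worker becomes free first.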
        if (1..=*WORKER_TASKS).contains(&index) {
            let inner = self.inner.clone();
            // Spawn a worker task that will process that task and potentially more.
            self.handle.spawn(async move {
                let _span = span.entered();
                if let Some(turbo_tasks) = turbo_tasks {
                    // Ensure that the turbo tasks context is maintained across the worker.
                    turbo_tasks_scope(turbo_tasks, || {
                        inner.worker(index, f);
                    });
                } else {
                    // If no turbo tasks context is available, just run the worker.
                    inner.worker(index, f);
                }
            });
        } else {
            // Queue the task to be processed by a worker task.
            self.inner
                .work_queue
                .lock()
                .push_back(WorkQueueJob::Job(index, f));
            self.inner.work_queue_condition_var.notify_one();
        }
    }
}

impl<'scope, 'env: 'scope, R: Send + 'env> Drop for Scope<'scope, 'env, R> {
    fn drop(&mut self) {
        self.inner.end_and_help_complete();
        self.inner.wait_and_rethrow_panic();
    }
}

/// Helper function to spawn tasks in parallel, ensuring that all tasks are awaited and errors are
/// handled. Also ensures turbo tasks and tracing context are maintained across the tasks.
///
/// Be aware that although this function avoids starving other independently spawned tasks, any
/// other code running concurrently in the same task will be suspended during the call to
/// `block_in_place`. This can happen e.g. when using the `join!` macro. To avoid this issue, call
/// `scope_and_block` in `spawn_blocking`.
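///
/// # Example
///
/// A minimal usage sketch (mirroring the tests below); it must run inside a multi-threaded tokio
/// runtime:
///
/// ```ignore
/// let results: Vec<usize> = scope_and_block(4, |scope| {
///     for i in 0..4 {
///         scope.spawn(move || i * 2);
///     }
/// })
/// .collect();
/// assert_eq!(results, vec![0, 2, 4, 6]);
/// ```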
pub fn scope_and_block<'env, F, R>(number_of_tasks: usize, f: F) -> impl Iterator<Item = R>
where
    R: Send + 'env,
    F: for<'scope> FnOnce(&'scope Scope<'scope, 'env, R>) + 'env,
{
    let mut results = Vec::with_capacity(number_of_tasks);
    for _ in 0..number_of_tasks {
        results.push(Mutex::new(None));
    }
    let results = results.into_boxed_slice();
    let result = {
        // SAFETY: We drop the Scope later.
        let scope = unsafe { Scope::new(&results) };
        catch_unwind(AssertUnwindSafe(|| f(&scope)))
    };
    if let Err(panic) = result {
        panic::resume_unwind(panic);
    }
    results.into_iter().map(|mutex| {
        mutex
            .into_inner()
            .expect("All values are set when the scope returns without panic")
    })
}

#[cfg(test)]
mod tests {
    use std::panic::{AssertUnwindSafe, catch_unwind};

    use super::*;

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_scope() {
        let results = scope_and_block(1000, |scope| {
            for i in 0..1000 {
                scope.spawn(move || i);
            }
        });
        let results = results.collect::<Vec<_>>();
        results.iter().enumerate().for_each(|(i, &result)| {
            assert_eq!(result, i);
        });
        assert_eq!(results.len(), 1000);
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_empty_scope() {
        let results = scope_and_block(0, |scope| {
            if false {
                scope.spawn(|| 42);
            }
        });
        assert_eq!(results.count(), 0);
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_single_task() {
        let results = scope_and_block(1, |scope| {
            scope.spawn(|| 42);
        })
        .collect::<Vec<_>>();
        assert_eq!(results, vec![42]);
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_task_finish_before_scope() {
        let results = scope_and_block(1, |scope| {
            scope.spawn(|| 42);
            thread::sleep(std::time::Duration::from_millis(100));
        })
        .collect::<Vec<_>>();
        assert_eq!(results, vec![42]);
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_task_finish_after_scope() {
        let results = scope_and_block(1, |scope| {
            scope.spawn(|| {
                thread::sleep(std::time::Duration::from_millis(100));
                42
            });
        })
        .collect::<Vec<_>>();
        assert_eq!(results, vec![42]);
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_panic_in_scope_factory() {
        let result = catch_unwind(AssertUnwindSafe(|| {
            let _results = scope_and_block(1000, |scope| {
                for i in 0..500 {
                    scope.spawn(move || i);
                }
                panic!("Intentional panic");
            });
            unreachable!();
        }));
        assert!(result.is_err());
        assert_eq!(
            result.unwrap_err().downcast_ref::<&str>(),
            Some(&"Intentional panic")
        );
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_panic_in_scope_task() {
        let result = catch_unwind(AssertUnwindSafe(|| {
            let _results = scope_and_block(1000, |scope| {
                for i in 0..1000 {
                    scope.spawn(move || {
                        if i == 500 {
                            panic!("Intentional panic");
                        } else if i == 501 {
                            panic!("Wrong intentional panic");
                        } else {
                            i
                        }
                    });
                }
            });
            unreachable!();
        }));
        assert!(result.is_err());
        assert_eq!(
            result.unwrap_err().downcast_ref::<&str>(),
            Some(&"Intentional panic")
        );
    }
}