Skip to main content

turbo_tasks/
scope.rs

//! A scoped tokio spawn implementation that allows a non-'static lifetime for tasks.
2
3use std::{
4    any::Any,
5    collections::VecDeque,
6    marker::PhantomData,
7    panic::{self, AssertUnwindSafe, catch_unwind},
8    sync::{
9        Arc, LazyLock,
10        atomic::{AtomicUsize, Ordering},
11    },
12    thread::{self, Thread},
13    time::{Duration, Instant},
14};
15
16use parking_lot::{Condvar, Mutex};
17use tokio::{runtime::Handle, task::block_in_place};
18use tracing::{Span, info_span};
19
20use crate::{
21    TurboTasksApi, manager::try_turbo_tasks, parallel::available_parallelism, turbo_tasks_scope,
22};
23
24/// Number of worker tasks to spawn that process jobs. It's 1 less than the number of cpus as we
25/// also use the current task as worker.
26static WORKER_TASKS: LazyLock<usize> =
27    LazyLock::new(|| available_parallelism().map_or(0, |n| n.get() - 1));
28
/// An entry in a scope's work queue.
enum WorkQueueJob {
    /// A pending job, together with the index of the task it belongs to.
    Job(usize, Box<dyn FnOnce() + Send + 'static>),
    /// Marker that no further jobs will be queued. Workers stop once they reach it.
    End,
}
33
34struct ScopeInner {
35    main_thread: Thread,
36    remaining_tasks: AtomicUsize,
37    /// The first panic that occurred in the tasks, by task index.
38    /// The usize value is the index of the task.
39    panic: Mutex<Option<(Box<dyn Any + Send + 'static>, usize)>>,
40    /// The work queue for spawned jobs that have not yet been picked up by a worker task.
41    work_queue: Mutex<VecDeque<WorkQueueJob>>,
42    /// A condition variable to notify worker tasks of new work or end of work.
43    work_queue_condition_var: Condvar,
44}
45
46impl ScopeInner {
47    fn on_task_finished(&self, panic: Option<(Box<dyn Any + Send + 'static>, usize)>) {
48        if let Some((err, index)) = panic {
49            let mut old_panic = self.panic.lock();
50            if old_panic.as_ref().is_none_or(|&(_, i)| i > index) {
51                *old_panic = Some((err, index));
52            }
53        }
54        if self.remaining_tasks.fetch_sub(1, Ordering::Release) == 1 {
55            self.main_thread.unpark();
56        }
57    }
58
59    fn wait(&self) {
60        if self.remaining_tasks.load(Ordering::Acquire) == 0 {
61            return;
62        }
63
64        let _span = info_span!("blocking").entered();
65
66        // Park up to 1ms without block_in_place to avoid the overhead.
67        const TIMEOUT: Duration = Duration::from_millis(1);
68        let beginning_park = Instant::now();
69
70        let mut timeout_remaining = TIMEOUT;
71        loop {
72            thread::park_timeout(timeout_remaining);
73            if self.remaining_tasks.load(Ordering::Acquire) == 0 {
74                return;
75            }
76            let elapsed = beginning_park.elapsed();
77            if elapsed >= TIMEOUT {
78                break;
79            }
80            timeout_remaining = TIMEOUT - elapsed;
81        }
82
83        // Park with block_in_place to allow to continue other work
84        block_in_place(|| {
85            while self.remaining_tasks.load(Ordering::Acquire) != 0 {
86                thread::park();
87            }
88        });
89    }
90
91    fn wait_and_rethrow_panic(&self) {
92        self.wait();
93        if let Some((err, _)) = self.panic.lock().take() {
94            panic::resume_unwind(err);
95        }
96    }
97
98    fn worker(&self, first_job_index: usize, first_job: Box<dyn FnOnce() + Send + 'static>) {
99        let mut current_job_index = first_job_index;
100        let mut current_job = first_job;
101        loop {
102            let result = catch_unwind(AssertUnwindSafe(current_job));
103            let panic = result.err().map(|e| (e, current_job_index));
104            self.on_task_finished(panic);
105            let Some((index, job)) = self.pick_job_from_work_queue() else {
106                return;
107            };
108            current_job_index = index;
109            current_job = job;
110        }
111    }
112
113    fn pick_job_from_work_queue(&self) -> Option<(usize, Box<dyn FnOnce() + Send + 'static>)> {
114        let mut work_queue = self.work_queue.lock();
115        let job = loop {
116            if let Some(job) = work_queue.pop_front() {
117                break job;
118            } else {
119                self.work_queue_condition_var.wait(&mut work_queue);
120            };
121        };
122        match job {
123            WorkQueueJob::Job(index, job) => {
124                drop(work_queue);
125                Some((index, job))
126            }
127            WorkQueueJob::End => {
128                work_queue.push_front(WorkQueueJob::End);
129                drop(work_queue);
130                self.work_queue_condition_var.notify_all();
131                None
132            }
133        }
134    }
135
136    fn end_and_help_complete(&self) {
137        let job;
138        {
139            let mut work_queue = self.work_queue.lock();
140            job = work_queue.pop_front();
141            work_queue.push_back(WorkQueueJob::End);
142        }
143        self.work_queue_condition_var.notify_all();
144        if let Some(WorkQueueJob::Job(index, job)) = job {
145            self.worker(index, job);
146        }
147    }
148}
149
150/// Scope to allow spawning tasks with a limited lifetime.
151///
152/// Dropping this Scope will wait for all tasks to complete.
153pub struct Scope<'scope, 'env: 'scope, R: Send + 'env> {
154    results: &'scope [Mutex<Option<R>>],
155    index: AtomicUsize,
156    inner: Arc<ScopeInner>,
157    handle: Handle,
158    turbo_tasks: Option<Arc<dyn TurboTasksApi>>,
159    span: Span,
160    /// Invariance over 'env, to make sure 'env cannot shrink, which is necessary for soundness.
161    ///
162    /// See the comment in the stdlib implementation:
163    /// <https://github.com/rust-lang/rust/blob/3b1b0ef4d8/library/std/src/thread/scoped.rs#L12-L33>
164    env: PhantomData<&'env mut &'env ()>,
165}
166
167impl<'scope, 'env: 'scope, R: Send + 'env> Scope<'scope, 'env, R> {
168    /// Creates a new scope.
169    ///
170    /// # Safety
171    ///
172    /// The caller must ensure `Scope` is dropped and not forgotten.
173    unsafe fn new(results: &'scope [Mutex<Option<R>>]) -> Self {
174        Self {
175            results,
176            index: AtomicUsize::new(0),
177            inner: Arc::new(ScopeInner {
178                main_thread: thread::current(),
179                remaining_tasks: AtomicUsize::new(0),
180                panic: Mutex::new(None),
181                work_queue: Mutex::new(VecDeque::new()),
182                work_queue_condition_var: Condvar::new(),
183            }),
184            handle: Handle::current(),
185            turbo_tasks: try_turbo_tasks(),
186            span: Span::current(),
187            env: PhantomData,
188        }
189    }
190
191    /// Spawns a new task in the scope.
192    pub fn spawn<F>(&self, f: F)
193    where
194        F: FnOnce() -> R + Send + 'env,
195    {
196        let index = self.index.fetch_add(1, Ordering::Relaxed);
197        assert!(index < self.results.len(), "Too many tasks spawned");
198        let result_cell: &Mutex<Option<R>> = &self.results[index];
199
200        let turbo_tasks = self.turbo_tasks.clone();
201        let f: Box<dyn FnOnce() + Send + 'scope> = Box::new(|| {
202            let result = {
203                if let Some(turbo_tasks) = turbo_tasks {
204                    // Ensure that the turbo tasks context is maintained across the job.
205                    turbo_tasks_scope(turbo_tasks, f)
206                } else {
207                    // If no turbo tasks context is available, just run the job.
208                    f()
209                }
210            };
211            *result_cell.lock() = Some(result);
212        });
213        let f: *mut (dyn FnOnce() + Send + 'scope) = Box::into_raw(f);
214
215        // SAFETY: Scope ensures (e. g. in Drop) that spawned tasks is awaited before the
216        // lifetime `'env` ends.
217        let f = unsafe {
218            std::mem::transmute::<
219                *mut (dyn FnOnce() + Send + 'scope),
220                *mut (dyn FnOnce() + Send + 'static),
221            >(f)
222        };
223
224        // SAFETY: We just called `Box::into_raw`.
225        let f = unsafe { Box::from_raw(f) };
226
227        let span = self.span.clone();
228
229        self.inner.remaining_tasks.fetch_add(1, Ordering::Relaxed);
230
231        // The first job always goes to the work_queue to be worked on by the main thread.
232        // After that we spawn a new worker for every job until we reach WORKER_TASKS.
233        // After that we queue up jobs in the work_queue again.
234        if (1..=*WORKER_TASKS).contains(&index) {
235            let inner = self.inner.clone();
236            // Spawn a worker task that will process that tasks and potentially more.
237            self.handle.spawn(async move {
238                let _span = span.entered();
239                inner.worker(index, f);
240            });
241        } else {
242            // Queue the task to be processed by a worker task.
243            self.inner
244                .work_queue
245                .lock()
246                .push_back(WorkQueueJob::Job(index, f));
247            self.inner.work_queue_condition_var.notify_one();
248        }
249    }
250}
251
252impl<'scope, 'env: 'scope, R: Send + 'env> Drop for Scope<'scope, 'env, R> {
253    fn drop(&mut self) {
254        self.inner.end_and_help_complete();
255        self.inner.wait_and_rethrow_panic();
256    }
257}
258
259/// Helper method to spawn tasks in parallel, ensuring that all tasks are awaited and errors are
260/// handled. Also ensures turbo tasks and tracing context are maintained across the tasks.
261///
262/// Be aware that although this function avoids starving other independently spawned tasks, any
263/// other code running concurrently in the same task will be suspended during the call to
264/// block_in_place. This can happen e.g. when using the `join!` macro. To avoid this issue, call
265/// `scope_and_block` in `spawn_blocking`.
266pub fn scope_and_block<'env, F, R>(number_of_tasks: usize, f: F) -> impl Iterator<Item = R>
267where
268    R: Send + 'env,
269    F: for<'scope> FnOnce(&'scope Scope<'scope, 'env, R>) + 'env,
270{
271    let mut results = Vec::with_capacity(number_of_tasks);
272    for _ in 0..number_of_tasks {
273        results.push(Mutex::new(None));
274    }
275    let results = results.into_boxed_slice();
276    let result = {
277        // SAFETY: We drop the Scope later.
278        let scope = unsafe { Scope::new(&results) };
279        catch_unwind(AssertUnwindSafe(|| f(&scope)))
280    };
281    if let Err(panic) = result {
282        panic::resume_unwind(panic);
283    }
284    results.into_iter().map(|mutex| {
285        mutex
286            .into_inner()
287            .expect("All values are set when the scope returns without panic")
288    })
289}
290
#[cfg(test)]
mod tests {
    use std::panic::{AssertUnwindSafe, catch_unwind};

    use super::*;

    /// Results of many tasks are returned in spawn order.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_scope() {
        let results: Vec<usize> = scope_and_block(1000, |scope| {
            (0..1000).for_each(|i| scope.spawn(move || i));
        })
        .collect();
        assert_eq!(results.len(), 1000);
        for (expected, actual) in results.into_iter().enumerate() {
            assert_eq!(actual, expected);
        }
    }

    /// A scope without any task yields no results.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_empty_scope() {
        let mut results = scope_and_block(0, |scope| {
            // Never executed; it only pins the result type for inference.
            if false {
                scope.spawn(|| 42);
            }
        });
        assert!(results.next().is_none());
    }

    /// A single task yields exactly its result.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_single_task() {
        let results: Vec<_> = scope_and_block(1, |scope| scope.spawn(|| 42)).collect();
        assert_eq!(results, [42]);
    }

    /// The scope factory keeps running for a while after spawning its task.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_task_finish_before_scope() {
        let results: Vec<_> = scope_and_block(1, |scope| {
            scope.spawn(|| 42);
            thread::sleep(Duration::from_millis(100));
        })
        .collect();
        assert_eq!(results, [42]);
    }

    /// The task keeps running for a while after the scope factory has returned.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_task_finish_after_scope() {
        let results: Vec<_> = scope_and_block(1, |scope| {
            scope.spawn(|| {
                thread::sleep(Duration::from_millis(100));
                42
            });
        })
        .collect();
        assert_eq!(results, [42]);
    }

    /// A panic in the scope factory is propagated to the caller.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_panic_in_scope_factory() {
        let outcome = catch_unwind(AssertUnwindSafe(|| {
            let _results = scope_and_block(1000, |scope| {
                (0..500).for_each(|i| scope.spawn(move || i));
                panic!("Intentional panic");
            });
            unreachable!();
        }));
        let payload = outcome.unwrap_err();
        assert_eq!(
            payload.downcast_ref::<&str>().copied(),
            Some("Intentional panic")
        );
    }

    /// Of several panicking tasks, the panic of the task with the lowest index is propagated.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_panic_in_scope_task() {
        let outcome = catch_unwind(AssertUnwindSafe(|| {
            let _results = scope_and_block(1000, |scope| {
                for i in 0..1000 {
                    scope.spawn(move || match i {
                        500 => panic!("Intentional panic"),
                        501 => panic!("Wrong intentional panic"),
                        _ => i,
                    });
                }
            });
            unreachable!();
        }));
        let payload = outcome.unwrap_err();
        assert_eq!(
            payload.downcast_ref::<&str>().copied(),
            Some("Intentional panic")
        );
    }
}