// turbo_tasks/scope.rs
1//! A scoped tokio spawn implementation that allow a non-'static lifetime for tasks.
2
3use std::{
4    any::Any,
5    collections::VecDeque,
6    marker::PhantomData,
7    panic::{self, AssertUnwindSafe, catch_unwind},
8    sync::{
9        Arc,
10        atomic::{AtomicUsize, Ordering},
11    },
12    thread::{self, Thread, available_parallelism},
13    time::{Duration, Instant},
14};
15
16use once_cell::sync::Lazy;
17use parking_lot::{Condvar, Mutex};
18use tokio::{runtime::Handle, task::block_in_place};
19use tracing::{Span, info_span};
20
21use crate::{TurboTasksApi, manager::try_turbo_tasks, turbo_tasks_scope};
22
23/// Number of worker tasks to spawn that process jobs. It's 1 less than the number of cpus as we
24/// also use the current task as worker.
25static WORKER_TASKS: Lazy<usize> = Lazy::new(|| available_parallelism().map_or(0, |n| n.get() - 1));
26
/// A unit of work in the scope's queue.
enum WorkQueueJob {
    /// A spawned job: the task's index (identifies its result slot and orders panics) and the
    /// closure to run.
    Job(usize, Box<dyn FnOnce() + Send + 'static>),
    /// Sentinel meaning no more jobs will be queued; workers that see it shut down.
    End,
}
31
/// State shared between the thread owning the `Scope` and the worker tasks.
struct ScopeInner {
    /// Handle to the thread that created the `Scope`; unparked when the last task finishes.
    main_thread: Thread,
    /// Number of spawned tasks that have not finished yet.
    remaining_tasks: AtomicUsize,
    /// The first panic that occurred in the tasks, by task index.
    /// The usize value is the index of the task.
    panic: Mutex<Option<(Box<dyn Any + Send + 'static>, usize)>>,
    /// The work queue for spawned jobs that have not yet been picked up by a worker task.
    work_queue: Mutex<VecDeque<WorkQueueJob>>,
    /// A condition variable to notify worker tasks of new work or end of work.
    work_queue_condition_var: Condvar,
}
43
impl ScopeInner {
    /// Records a finished task, storing its panic (if any) and unparking the main thread when
    /// this was the last remaining task.
    ///
    /// Only the panic with the lowest task index is kept, so the panic that is eventually
    /// rethrown is deterministic regardless of completion order.
    fn on_task_finished(&self, panic: Option<(Box<dyn Any + Send + 'static>, usize)>) {
        if let Some((err, index)) = panic {
            let mut old_panic = self.panic.lock();
            // Keep the panic with the smaller task index.
            if old_panic.as_ref().is_none_or(|&(_, i)| i > index) {
                *old_panic = Some((err, index));
            }
        }
        // `Release` pairs with the `Acquire` loads in `wait`. The previous value being 1 means
        // this was the last task, so wake up the (potentially parked) main thread.
        if self.remaining_tasks.fetch_sub(1, Ordering::Release) == 1 {
            self.main_thread.unpark();
        }
    }

    /// Blocks the current (main) thread until all spawned tasks have finished.
    ///
    /// Parks with a short timeout first, and only falls back to `block_in_place` (which tells
    /// tokio to move other work off this worker) when the tasks take longer than 1ms. Spurious
    /// unparks are handled by re-checking `remaining_tasks` in a loop.
    fn wait(&self) {
        if self.remaining_tasks.load(Ordering::Acquire) == 0 {
            return;
        }

        let _span = info_span!("blocking").entered();

        // Park up to 1ms without block_in_place to avoid the overhead.
        const TIMEOUT: Duration = Duration::from_millis(1);
        let beginning_park = Instant::now();

        let mut timeout_remaining = TIMEOUT;
        loop {
            thread::park_timeout(timeout_remaining);
            if self.remaining_tasks.load(Ordering::Acquire) == 0 {
                return;
            }
            let elapsed = beginning_park.elapsed();
            if elapsed >= TIMEOUT {
                break;
            }
            timeout_remaining = TIMEOUT - elapsed;
        }

        // Park with block_in_place to allow to continue other work
        block_in_place(|| {
            while self.remaining_tasks.load(Ordering::Acquire) != 0 {
                thread::park();
            }
        });
    }

    /// Waits for all tasks to finish, then rethrows the recorded panic (the one with the
    /// lowest task index), if any task panicked.
    fn wait_and_rethrow_panic(&self) {
        self.wait();
        if let Some((err, _)) = self.panic.lock().take() {
            panic::resume_unwind(err);
        }
    }

    /// Worker loop: runs `first_job`, then keeps pulling jobs from the work queue until the
    /// `End` sentinel is reached. Panics in jobs are caught and recorded via
    /// `on_task_finished` so one panicking job doesn't take down the worker.
    fn worker(&self, first_job_index: usize, first_job: Box<dyn FnOnce() + Send + 'static>) {
        let mut current_job_index = first_job_index;
        let mut current_job = first_job;
        loop {
            let result = catch_unwind(AssertUnwindSafe(current_job));
            let panic = result.err().map(|e| (e, current_job_index));
            self.on_task_finished(panic);
            let Some((index, job)) = self.pick_job_from_work_queue() else {
                return;
            };
            current_job_index = index;
            current_job = job;
        }
    }

    /// Blocks on the condition variable until a job is available. Returns `None` when the
    /// `End` sentinel is reached; the sentinel is pushed back and all waiters are notified so
    /// every other worker also observes it and shuts down.
    fn pick_job_from_work_queue(&self) -> Option<(usize, Box<dyn FnOnce() + Send + 'static>)> {
        let mut work_queue = self.work_queue.lock();
        let job = loop {
            if let Some(job) = work_queue.pop_front() {
                break job;
            } else {
                self.work_queue_condition_var.wait(&mut work_queue);
            };
        };
        match job {
            WorkQueueJob::Job(index, job) => {
                drop(work_queue);
                Some((index, job))
            }
            WorkQueueJob::End => {
                // Leave the sentinel in the queue so the remaining workers see it too.
                work_queue.push_front(WorkQueueJob::End);
                drop(work_queue);
                self.work_queue_condition_var.notify_all();
                None
            }
        }
    }

    /// Marks the queue as finished by appending the `End` sentinel, and helps drain it: if a
    /// job was still pending, the calling (main) thread becomes a worker and processes jobs
    /// until it reaches the sentinel.
    fn end_and_help_complete(&self) {
        let job;
        {
            let mut work_queue = self.work_queue.lock();
            job = work_queue.pop_front();
            work_queue.push_back(WorkQueueJob::End);
        }
        self.work_queue_condition_var.notify_all();
        if let Some(WorkQueueJob::Job(index, job)) = job {
            self.worker(index, job);
        }
    }
}
147
148/// Scope to allow spawning tasks with a limited lifetime.
149///
150/// Dropping this Scope will wait for all tasks to complete.
/// Scope to allow spawning tasks with a limited lifetime.
///
/// Dropping this Scope will wait for all tasks to complete.
pub struct Scope<'scope, 'env: 'scope, R: Send + 'env> {
    /// One pre-allocated result slot per task; `spawn` writes task `i`'s result into slot `i`.
    results: &'scope [Mutex<Option<R>>],
    /// Next task index to hand out; incremented by every `spawn`.
    index: AtomicUsize,
    /// State shared with the worker tasks (work queue, panic slot, remaining-task counter).
    inner: Arc<ScopeInner>,
    /// Tokio runtime handle captured at construction; used to spawn worker tasks.
    handle: Handle,
    /// Turbo tasks context captured at construction (if any) so jobs run inside it.
    turbo_tasks: Option<Arc<dyn TurboTasksApi>>,
    /// Tracing span captured at construction; entered by the worker tasks.
    span: Span,
    /// Invariance over 'env, to make sure 'env cannot shrink,
    /// which is necessary for soundness.
    ///
    /// see https://doc.rust-lang.org/src/std/thread/scoped.rs.html#12-29
    env: PhantomData<&'env mut &'env ()>,
}
164
impl<'scope, 'env: 'scope, R: Send + 'env> Scope<'scope, 'env, R> {
    /// Creates a new scope over the given result slots, capturing the current thread, tokio
    /// runtime handle, turbo tasks context and tracing span.
    ///
    /// # Safety
    ///
    /// The caller must ensure `Scope` is dropped and not forgotten.
    unsafe fn new(results: &'scope [Mutex<Option<R>>]) -> Self {
        Self {
            results,
            index: AtomicUsize::new(0),
            inner: Arc::new(ScopeInner {
                main_thread: thread::current(),
                remaining_tasks: AtomicUsize::new(0),
                panic: Mutex::new(None),
                work_queue: Mutex::new(VecDeque::new()),
                work_queue_condition_var: Condvar::new(),
            }),
            handle: Handle::current(),
            turbo_tasks: try_turbo_tasks(),
            span: Span::current(),
            env: PhantomData,
        }
    }

    /// Spawns a new task in the scope.
    ///
    /// The task's result is stored in the result slot matching its spawn order. Panics if more
    /// tasks are spawned than result slots were allocated.
    pub fn spawn<F>(&self, f: F)
    where
        F: FnOnce() -> R + Send + 'env,
    {
        let index = self.index.fetch_add(1, Ordering::Relaxed);
        assert!(index < self.results.len(), "Too many tasks spawned");
        let result_cell: &Mutex<Option<R>> = &self.results[index];

        // Wrap the user closure so it restores the captured turbo tasks context and writes its
        // result into the task's slot.
        let turbo_tasks = self.turbo_tasks.clone();
        let f: Box<dyn FnOnce() + Send + 'scope> = Box::new(|| {
            let result = {
                if let Some(turbo_tasks) = turbo_tasks {
                    // Ensure that the turbo tasks context is maintained across the job.
                    turbo_tasks_scope(turbo_tasks, f)
                } else {
                    // If no turbo tasks context is available, just run the job.
                    f()
                }
            };
            *result_cell.lock() = Some(result);
        });
        let f: *mut (dyn FnOnce() + Send + 'scope) = Box::into_raw(f);

        // SAFETY: Scope ensures (e. g. in Drop) that spawned tasks is awaited before the
        // lifetime `'env` ends.
        let f = unsafe {
            std::mem::transmute::<
                *mut (dyn FnOnce() + Send + 'scope),
                *mut (dyn FnOnce() + Send + 'static),
            >(f)
        };

        // SAFETY: We just called `Box::into_raw`.
        let f = unsafe { Box::from_raw(f) };

        let span = self.span.clone();

        // Increment before publishing the job, so the counter cannot reach zero while this job
        // is still pending.
        self.inner.remaining_tasks.fetch_add(1, Ordering::Relaxed);

        // The first job always goes to the work_queue to be worked on by the main thread.
        // After that we spawn a new worker for every job until we reach WORKER_TASKS.
        // After that we queue up jobs in the work_queue again.
        if (1..=*WORKER_TASKS).contains(&index) {
            let inner = self.inner.clone();
            // Spawn a worker task that will process this job and potentially more.
            self.handle.spawn(async move {
                let _span = span.entered();
                inner.worker(index, f);
            });
        } else {
            // Queue the task to be processed by a worker task.
            self.inner
                .work_queue
                .lock()
                .push_back(WorkQueueJob::Job(index, f));
            self.inner.work_queue_condition_var.notify_one();
        }
    }
}
249
250impl<'scope, 'env: 'scope, R: Send + 'env> Drop for Scope<'scope, 'env, R> {
251    fn drop(&mut self) {
252        self.inner.end_and_help_complete();
253        self.inner.wait_and_rethrow_panic();
254    }
255}
256
257/// Helper method to spawn tasks in parallel, ensuring that all tasks are awaited and errors are
258/// handled. Also ensures turbo tasks and tracing context are maintained across the tasks.
259///
260/// Be aware that although this function avoids starving other independently spawned tasks, any
261/// other code running concurrently in the same task will be suspended during the call to
262/// block_in_place. This can happen e.g. when using the `join!` macro. To avoid this issue, call
263/// `scope_and_block` in `spawn_blocking`.
264pub fn scope_and_block<'env, F, R>(number_of_tasks: usize, f: F) -> impl Iterator<Item = R>
265where
266    R: Send + 'env,
267    F: for<'scope> FnOnce(&'scope Scope<'scope, 'env, R>) + 'env,
268{
269    let mut results = Vec::with_capacity(number_of_tasks);
270    for _ in 0..number_of_tasks {
271        results.push(Mutex::new(None));
272    }
273    let results = results.into_boxed_slice();
274    let result = {
275        // SAFETY: We drop the Scope later.
276        let scope = unsafe { Scope::new(&results) };
277        catch_unwind(AssertUnwindSafe(|| f(&scope)))
278    };
279    if let Err(panic) = result {
280        panic::resume_unwind(panic);
281    }
282    results.into_iter().map(|mutex| {
283        mutex
284            .into_inner()
285            .expect("All values are set when the scope returns without panic")
286    })
287}
288
#[cfg(test)]
mod tests {
    use std::panic::{AssertUnwindSafe, catch_unwind};

    use super::*;

    // Many tasks: every result lands in the slot matching its spawn order.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_scope() {
        let results = scope_and_block(1000, |scope| {
            for i in 0..1000 {
                scope.spawn(move || i);
            }
        });
        let results = results.collect::<Vec<_>>();
        results.iter().enumerate().for_each(|(i, &result)| {
            assert_eq!(result, i);
        });
        assert_eq!(results.len(), 1000);
    }

    // Zero tasks: the scope completes without spawning and yields an empty iterator.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_empty_scope() {
        let results = scope_and_block(0, |scope| {
            if false {
                scope.spawn(|| 42);
            }
        });
        assert_eq!(results.count(), 0);
    }

    // A single task is processed (by the main thread draining the queue in Drop).
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_single_task() {
        let results = scope_and_block(1, |scope| {
            scope.spawn(|| 42);
        })
        .collect::<Vec<_>>();
        assert_eq!(results, vec![42]);
    }

    // Task finishes while the factory is still running; the scope must not hang or lose it.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_task_finish_before_scope() {
        let results = scope_and_block(1, |scope| {
            scope.spawn(|| 42);
            thread::sleep(std::time::Duration::from_millis(100));
        })
        .collect::<Vec<_>>();
        assert_eq!(results, vec![42]);
    }

    // Task outlives the factory; Drop must block until it completes.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_task_finish_after_scope() {
        let results = scope_and_block(1, |scope| {
            scope.spawn(|| {
                thread::sleep(std::time::Duration::from_millis(100));
                42
            });
        })
        .collect::<Vec<_>>();
        assert_eq!(results, vec![42]);
    }

    // Panic in the factory itself is rethrown after all already-spawned tasks finished.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_panic_in_scope_factory() {
        let result = catch_unwind(AssertUnwindSafe(|| {
            let _results = scope_and_block(1000, |scope| {
                for i in 0..500 {
                    scope.spawn(move || i);
                }
                panic!("Intentional panic");
            });
            unreachable!();
        }));
        assert!(result.is_err());
        assert_eq!(
            result.unwrap_err().downcast_ref::<&str>(),
            Some(&"Intentional panic")
        );
    }

    // When multiple tasks panic, the panic of the task with the lowest index is rethrown.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_panic_in_scope_task() {
        let result = catch_unwind(AssertUnwindSafe(|| {
            let _results = scope_and_block(1000, |scope| {
                for i in 0..1000 {
                    scope.spawn(move || {
                        if i == 500 {
                            panic!("Intentional panic");
                        } else if i == 501 {
                            panic!("Wrong intentional panic");
                        } else {
                            i
                        }
                    });
                }
            });
            unreachable!();
        }));
        assert!(result.is_err());
        assert_eq!(
            result.unwrap_err().downcast_ref::<&str>(),
            Some(&"Intentional panic")
        );
    }
}