turbo_tasks/
scope.rs

//! A scoped tokio spawn implementation that allows a non-'static lifetime for tasks.
2
3use std::{
4    any::Any,
5    marker::PhantomData,
6    panic::{self, AssertUnwindSafe, catch_unwind},
7    pin::Pin,
8    sync::{
9        Arc,
10        atomic::{AtomicUsize, Ordering},
11    },
12    thread::{self, Thread},
13};
14
15use futures::FutureExt;
16use parking_lot::Mutex;
17use tokio::{runtime::Handle, task::block_in_place};
18use tracing::{Instrument, Span, info_span};
19
20use crate::{
21    TurboTasksApi,
22    manager::{try_turbo_tasks, turbo_tasks_future_scope},
23};
24
25struct ScopeInner {
26    main_thread: Thread,
27    remaining_tasks: AtomicUsize,
28    /// The first panic that occurred in the tasks, by task index.
29    /// The usize value is the index of the task.
30    panic: Mutex<Option<(Box<dyn Any + Send + 'static>, usize)>>,
31}
32
33impl ScopeInner {
34    fn on_task_finished(&self, panic: Option<(Box<dyn Any + Send + 'static>, usize)>) {
35        if let Some((err, index)) = panic {
36            let mut old_panic = self.panic.lock();
37            if old_panic.as_ref().is_none_or(|&(_, i)| i > index) {
38                *old_panic = Some((err, index));
39            }
40        }
41        if self.remaining_tasks.fetch_sub(1, Ordering::Release) == 1 {
42            self.main_thread.unpark();
43        }
44    }
45
46    fn wait(&self) {
47        let _span = info_span!("blocking").entered();
48        while self.remaining_tasks.load(Ordering::Acquire) != 0 {
49            thread::park();
50        }
51        if let Some((err, _)) = self.panic.lock().take() {
52            panic::resume_unwind(err);
53        }
54    }
55}
56
57/// Scope to allow spawning tasks with a limited lifetime.
58///
59/// Dropping this Scope will wait for all tasks to complete.
60pub struct Scope<'scope, 'env: 'scope, R: Send + 'env> {
61    results: &'scope [Mutex<Option<R>>],
62    index: AtomicUsize,
63    inner: Arc<ScopeInner>,
64    handle: Handle,
65    turbo_tasks: Option<Arc<dyn TurboTasksApi>>,
66    span: Span,
67    /// Invariance over 'env, to make sure 'env cannot shrink,
68    /// which is necessary for soundness.
69    ///
70    /// see https://doc.rust-lang.org/src/std/thread/scoped.rs.html#12-29
71    env: PhantomData<&'env mut &'env ()>,
72}
73
impl<'scope, 'env: 'scope, R: Send + 'env> Scope<'scope, 'env, R> {
    /// Creates a new scope whose tasks write their results into `results`.
    ///
    /// From the calling context, this captures:
    /// - the current thread, which is later parked while waiting for the tasks;
    /// - the current tokio runtime handle;
    /// - the turbo tasks context, if any;
    /// - the current tracing span.
    ///
    /// # Panics
    ///
    /// `Handle::current()` panics when called outside of a tokio runtime.
    ///
    /// # Safety
    ///
    /// The caller must ensure `Scope` is dropped and not forgotten. Its `Drop` is what waits for
    /// the spawned tasks. Forgetting it would let tasks keep using borrowed `'env` data after
    /// that data is gone.
    unsafe fn new(results: &'scope [Mutex<Option<R>>]) -> Self {
        Self {
            results,
            index: AtomicUsize::new(0),
            inner: Arc::new(ScopeInner {
                main_thread: thread::current(),
                remaining_tasks: AtomicUsize::new(0),
                panic: Mutex::new(None),
            }),
            handle: Handle::current(),
            turbo_tasks: try_turbo_tasks(),
            span: Span::current(),
            env: PhantomData,
        }
    }

    /// Spawns a new task in the scope.
    ///
    /// The future may borrow data for `'env`. Its output is stored in the result slot that
    /// matches the order in which `spawn` was called. The task runs on the runtime captured by
    /// [`Scope::new`], inside the captured tracing span and, if present, the turbo tasks context.
    /// A panic inside the future is caught and recorded; the scope's drop re-raises it.
    ///
    /// # Panics
    ///
    /// Panics with "Too many tasks spawned" if called more times than there are result slots.
    pub fn spawn<F>(&self, f: F)
    where
        F: Future<Output = R> + Send + 'env,
    {
        // Claim a unique result slot; `fetch_add` hands out each index exactly once.
        let index = self.index.fetch_add(1, Ordering::Relaxed);
        assert!(index < self.results.len(), "Too many tasks spawned");
        let result_cell: &Mutex<Option<R>> = &self.results[index];

        // Wrap the user future so its output lands in this task's result slot.
        let f: Box<dyn Future<Output = ()> + Send + 'scope> = Box::new(async move {
            let result = f.await;
            *result_cell.lock() = Some(result);
        });
        // Erase the `'scope` lifetime via a raw-pointer round trip, because `Handle::spawn`
        // requires a `'static` future.
        let f: *mut (dyn Future<Output = ()> + Send + 'scope) = Box::into_raw(f);
        // SAFETY: Scope ensures (e.g. in Drop) that spawned tasks are awaited before the
        // lifetime `'env` ends.
        #[allow(
            clippy::unnecessary_cast,
            reason = "Clippy thinks this is unnecessary, but it actually changes the lifetime"
        )]
        let f = f as *mut (dyn Future<Output = ()> + Send + 'static);
        // SAFETY: We just called `Box::into_raw`.
        let f = unsafe { Box::from_raw(f) };
        // We pin the future in the box in memory to be able to await it.
        let f = Pin::from(f);

        let turbo_tasks = self.turbo_tasks.clone();
        let span = self.span.clone();

        let inner = self.inner.clone();
        // Count the task before spawning it, so the task's decrement in `on_task_finished` can
        // never precede this increment.
        inner.remaining_tasks.fetch_add(1, Ordering::Relaxed);
        // The returned JoinHandle is dropped: completion is tracked through `remaining_tasks`.
        // NOTE(review): if the runtime drops this task without running it to completion (e.g.
        // during shutdown), `on_task_finished` is never called and the scope's drop would wait
        // forever. Confirm whether that can happen for callers of this module.
        self.handle.spawn(async move {
            // Turn a panic into a recorded payload so the task always reports completion.
            let result = AssertUnwindSafe(
                async move {
                    if let Some(turbo_tasks) = turbo_tasks {
                        // Ensure that the turbo tasks context is maintained across the task.
                        turbo_tasks_future_scope(turbo_tasks, f).await;
                    } else {
                        // If no turbo tasks context is available, just run the future.
                        f.await;
                    }
                }
                .instrument(span),
            )
            .catch_unwind()
            .await;
            let panic = result.err().map(|e| (e, index));
            inner.on_task_finished(panic);
        });
    }
}
147
148impl<'scope, 'env: 'scope, R: Send + 'env> Drop for Scope<'scope, 'env, R> {
149    fn drop(&mut self) {
150        self.inner.wait();
151    }
152}
153
154/// Helper method to spawn tasks in parallel, ensuring that all tasks are awaited and errors are
155/// handled. Also ensures turbo tasks and tracing context are maintained across the tasks.
156///
157/// Be aware that although this function avoids starving other independently spawned tasks, any
158/// other code running concurrently in the same task will be suspended during the call to
159/// block_in_place. This can happen e.g. when using the `join!` macro. To avoid this issue, call
160/// `scope_and_block` in `spawn_blocking`.
161pub fn scope_and_block<'env, F, R>(number_of_tasks: usize, f: F) -> impl Iterator<Item = R>
162where
163    R: Send + 'env,
164    F: for<'scope> FnOnce(&'scope Scope<'scope, 'env, R>) + 'env,
165{
166    block_in_place(|| {
167        let mut results = Vec::with_capacity(number_of_tasks);
168        for _ in 0..number_of_tasks {
169            results.push(Mutex::new(None));
170        }
171        let results = results.into_boxed_slice();
172        let result = {
173            // SAFETY: We drop the Scope later.
174            let scope = unsafe { Scope::new(&results) };
175            catch_unwind(AssertUnwindSafe(|| f(&scope)))
176        };
177        if let Err(panic) = result {
178            panic::resume_unwind(panic);
179        }
180        results.into_iter().map(|mutex| {
181            mutex
182                .into_inner()
183                .expect("All values are set when the scope returns without panic")
184        })
185    })
186}
187
#[cfg(test)]
mod tests {
    use std::panic::{AssertUnwindSafe, catch_unwind};

    use super::*;

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_scope() {
        let results = scope_and_block(1000, |scope| {
            for i in 0..1000 {
                scope.spawn(async move { i });
            }
        });
        results.enumerate().for_each(|(i, result)| {
            assert_eq!(result, i);
        });
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_empty_scope() {
        let results = scope_and_block(0, |scope| {
            if false {
                scope.spawn(async move { 42 });
            }
        });
        assert_eq!(results.count(), 0);
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_single_task() {
        let results = scope_and_block(1, |scope| {
            scope.spawn(async move { 42 });
        })
        .collect::<Vec<_>>();
        assert_eq!(results, vec![42]);
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_task_finish_before_scope() {
        let results = scope_and_block(1, |scope| {
            scope.spawn(async move { 42 });
            thread::sleep(std::time::Duration::from_millis(100));
        })
        .collect::<Vec<_>>();
        assert_eq!(results, vec![42]);
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_task_finish_after_scope() {
        let results = scope_and_block(1, |scope| {
            scope.spawn(async move {
                thread::sleep(std::time::Duration::from_millis(100));
                42
            });
        })
        .collect::<Vec<_>>();
        assert_eq!(results, vec![42]);
    }

    /// Tasks may borrow non-`'static` data owned by the caller. This is the reason the scope
    /// exists.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_borrow_from_environment() {
        let data: Vec<usize> = (0..100).collect();
        let data = &data;
        let results = scope_and_block(data.len(), move |scope| {
            for item in data {
                scope.spawn(async move { *item * 2 });
            }
        })
        .collect::<Vec<_>>();
        let expected: Vec<usize> = data.iter().map(|x| x * 2).collect();
        assert_eq!(results, expected);
    }

    /// Spawning more tasks than there are result slots panics in the factory.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_too_many_tasks() {
        let result = catch_unwind(AssertUnwindSafe(|| {
            let _results = scope_and_block(1, |scope| {
                scope.spawn(async move { 1 });
                scope.spawn(async move { 2 });
            });
        }));
        assert_eq!(
            result.unwrap_err().downcast_ref::<&str>(),
            Some(&"Too many tasks spawned")
        );
    }

    /// Spawning fewer tasks than requested is detected lazily, when the iterator reaches the
    /// first unfilled slot.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_too_few_tasks() {
        let mut results = scope_and_block(2, |scope| {
            scope.spawn(async move { 1 });
        });
        assert_eq!(results.next(), Some(1));
        let result = catch_unwind(AssertUnwindSafe(|| results.next()));
        assert!(result.is_err());
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_panic_in_scope_factory() {
        let result = catch_unwind(AssertUnwindSafe(|| {
            let _results = scope_and_block(1000, |scope| {
                for i in 0..500 {
                    scope.spawn(async move { i });
                }
                panic!("Intentional panic");
            });
            unreachable!();
        }));
        assert!(result.is_err());
        assert_eq!(
            result.unwrap_err().downcast_ref::<&str>(),
            Some(&"Intentional panic")
        );
    }

    /// When several tasks panic, the panic of the lowest-indexed task is the one re-raised.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_panic_in_scope_task() {
        let result = catch_unwind(AssertUnwindSafe(|| {
            let _results = scope_and_block(1000, |scope| {
                for i in 0..1000 {
                    scope.spawn(async move {
                        if i == 500 {
                            panic!("Intentional panic");
                        } else if i == 501 {
                            panic!("Wrong intentional panic");
                        } else {
                            i
                        }
                    });
                }
            });
            unreachable!();
        }));
        assert!(result.is_err());
        assert_eq!(
            result.unwrap_err().downcast_ref::<&str>(),
            Some(&"Intentional panic")
        );
    }
}