1use std::{
4 any::Any,
5 marker::PhantomData,
6 panic::{self, AssertUnwindSafe, catch_unwind},
7 pin::Pin,
8 sync::{
9 Arc,
10 atomic::{AtomicUsize, Ordering},
11 },
12 thread::{self, Thread},
13};
14
15use futures::FutureExt;
16use parking_lot::Mutex;
17use tokio::{runtime::Handle, task::block_in_place};
18use tracing::{Instrument, Span, info_span};
19
20use crate::{
21 TurboTasksApi,
22 manager::{try_turbo_tasks, turbo_tasks_future_scope},
23};
24
25struct ScopeInner {
26 main_thread: Thread,
27 remaining_tasks: AtomicUsize,
28 panic: Mutex<Option<(Box<dyn Any + Send + 'static>, usize)>>,
31}
32
33impl ScopeInner {
34 fn on_task_finished(&self, panic: Option<(Box<dyn Any + Send + 'static>, usize)>) {
35 if let Some((err, index)) = panic {
36 let mut old_panic = self.panic.lock();
37 if old_panic.as_ref().is_none_or(|&(_, i)| i > index) {
38 *old_panic = Some((err, index));
39 }
40 }
41 if self.remaining_tasks.fetch_sub(1, Ordering::Release) == 1 {
42 self.main_thread.unpark();
43 }
44 }
45
46 fn wait(&self) {
47 let _span = info_span!("blocking").entered();
48 while self.remaining_tasks.load(Ordering::Acquire) != 0 {
49 thread::park();
50 }
51 if let Some((err, _)) = self.panic.lock().take() {
52 panic::resume_unwind(err);
53 }
54 }
55}
56
/// Handle passed to the closure of `scope_and_block`, used to spawn futures whose
/// outputs are written into pre-allocated result slots.
///
/// Spawned futures may borrow data for `'env`. This is only sound because dropping the
/// `Scope` blocks until every spawned task has reported completion (see the `Drop` impl).
pub struct Scope<'scope, 'env: 'scope, R: Send + 'env> {
    // One slot per allowed task; the task with index `i` writes its output to `results[i]`.
    results: &'scope [Mutex<Option<R>>],
    // Counter handing out the next slot index in `spawn`.
    index: AtomicUsize,
    // State shared with every spawned task (remaining-task counter, recorded panic).
    inner: Arc<ScopeInner>,
    // Tokio runtime handle captured in `new`; all tasks are spawned onto it.
    handle: Handle,
    // turbo-tasks instance that was current in `new` (if any); passed to
    // `turbo_tasks_future_scope` for each spawned task.
    turbo_tasks: Option<Arc<dyn TurboTasksApi>>,
    // Span that was current in `new`; every spawned task is instrumented with it.
    span: Span,
    // `&'env mut &'env ()` makes the struct invariant in `'env`.
    env: PhantomData<&'env mut &'env ()>,
}

impl<'scope, 'env: 'scope, R: Send + 'env> Scope<'scope, 'env, R> {
    /// Creates a scope whose spawned tasks write their outputs into `results`.
    ///
    /// Captures the current thread (unparked when the last task finishes), the current
    /// tokio runtime handle, the current turbo-tasks instance (if any) and the current
    /// tracing span.
    ///
    /// # Panics
    ///
    /// `Handle::current()` panics when called outside a tokio runtime context.
    ///
    /// # Safety
    ///
    /// The returned `Scope` must be dropped (never leaked, e.g. via `mem::forget`)
    /// before `results` or any `'env` data borrowed by spawned futures is freed.
    /// `spawn` erases the futures' lifetimes to `'static`, and only `Drop` waits for
    /// them to finish. The scope should also be dropped on the thread that created it.
    /// Only that thread is unparked when tasks finish, so waiting elsewhere may never
    /// wake up.
    unsafe fn new(results: &'scope [Mutex<Option<R>>]) -> Self {
        Self {
            results,
            index: AtomicUsize::new(0),
            inner: Arc::new(ScopeInner {
                main_thread: thread::current(),
                remaining_tasks: AtomicUsize::new(0),
                panic: Mutex::new(None),
            }),
            handle: Handle::current(),
            turbo_tasks: try_turbo_tasks(),
            span: Span::current(),
            env: PhantomData,
        }
    }

    /// Spawns `f` onto the runtime captured in `new` and stores its output in the next
    /// free result slot. Slots are handed out by an atomic counter, so they follow call
    /// order for a single caller.
    ///
    /// The task is instrumented with the span captured in `new`. When a turbo-tasks
    /// instance was current at that point, the future is run through
    /// `turbo_tasks_future_scope`. A panic in `f` is caught and recorded together with
    /// the slot index; the scope's `Drop` later re-raises the lowest-index panic.
    ///
    /// # Panics
    ///
    /// Panics with "Too many tasks spawned" if called more times than there are slots.
    pub fn spawn<F>(&self, f: F)
    where
        F: Future<Output = R> + Send + 'env,
    {
        let index = self.index.fetch_add(1, Ordering::Relaxed);
        assert!(index < self.results.len(), "Too many tasks spawned");
        let result_cell: &Mutex<Option<R>> = &self.results[index];

        // Wrap the user future so its output lands in this task's slot.
        let f: Box<dyn Future<Output = ()> + Send + 'scope> = Box::new(async move {
            let result = f.await;
            *result_cell.lock() = Some(result);
        });
        let f: *mut (dyn Future<Output = ()> + Send + 'scope) = Box::into_raw(f);
        #[allow(
            clippy::unnecessary_cast,
            reason = "Clippy thinks this is unnecessary, but it actually changes the lifetime"
        )]
        let f = f as *mut (dyn Future<Output = ()> + Send + 'static);
        // SAFETY: the pointer comes from `Box::into_raw` directly above, and the cast only
        // changes the trait object's lifetime bound, not its type or layout. Extending
        // the lifetime to `'static` is sound because `Scope::drop` blocks until
        // `remaining_tasks` reaches zero. The task below only calls `on_task_finished`
        // after this future has been awaited (or its panic caught), so the borrowed data
        // outlives every use.
        // NOTE(review): if tokio drops the task without polling it to completion (e.g. the
        // runtime shutting down — confirm), `on_task_finished` is never called and
        // `wait` blocks indefinitely.
        let f = unsafe { Box::from_raw(f) };
        let f = Pin::from(f);

        let turbo_tasks = self.turbo_tasks.clone();
        let span = self.span.clone();

        let inner = self.inner.clone();
        // Count the task before spawning it, so it cannot finish (and decrement) first.
        inner.remaining_tasks.fetch_add(1, Ordering::Relaxed);
        self.handle.spawn(async move {
            let result = AssertUnwindSafe(
                async move {
                    if let Some(turbo_tasks) = turbo_tasks {
                        turbo_tasks_future_scope(turbo_tasks, f).await;
                    } else {
                        f.await;
                    }
                }
                .instrument(span),
            )
            .catch_unwind()
            .await;
            // Tag a caught panic with the slot index so the lowest index wins.
            let panic = result.err().map(|e| (e, index));
            inner.on_task_finished(panic);
        });
    }
}

impl<'scope, 'env: 'scope, R: Send + 'env> Drop for Scope<'scope, 'env, R> {
    /// Blocks until every spawned task has finished, then re-raises the panic of the
    /// lowest-index task that panicked, if any. This wait is what makes the lifetime
    /// erasure in `spawn` sound.
    fn drop(&mut self) {
        self.inner.wait();
    }
}
153
154pub fn scope_and_block<'env, F, R>(number_of_tasks: usize, f: F) -> impl Iterator<Item = R>
162where
163 R: Send + 'env,
164 F: for<'scope> FnOnce(&'scope Scope<'scope, 'env, R>) + 'env,
165{
166 block_in_place(|| {
167 let mut results = Vec::with_capacity(number_of_tasks);
168 for _ in 0..number_of_tasks {
169 results.push(Mutex::new(None));
170 }
171 let results = results.into_boxed_slice();
172 let result = {
173 let scope = unsafe { Scope::new(&results) };
175 catch_unwind(AssertUnwindSafe(|| f(&scope)))
176 };
177 if let Err(panic) = result {
178 panic::resume_unwind(panic);
179 }
180 results.into_iter().map(|mutex| {
181 mutex
182 .into_inner()
183 .expect("All values are set when the scope returns without panic")
184 })
185 })
186}
187
#[cfg(test)]
mod tests {
    use std::panic::{AssertUnwindSafe, catch_unwind};

    use super::*;

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_scope() {
        let results = scope_and_block(1000, |scope| {
            for i in 0..1000 {
                scope.spawn(async move { i });
            }
        });
        results.enumerate().for_each(|(i, result)| {
            assert_eq!(result, i);
        });
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_empty_scope() {
        let results = scope_and_block(0, |scope| {
            if false {
                scope.spawn(async move { 42 });
            }
        });
        assert_eq!(results.count(), 0);
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_single_task() {
        let results = scope_and_block(1, |scope| {
            scope.spawn(async move { 42 });
        })
        .collect::<Vec<_>>();
        assert_eq!(results, vec![42]);
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_task_finish_before_scope() {
        let results = scope_and_block(1, |scope| {
            scope.spawn(async move { 42 });
            thread::sleep(std::time::Duration::from_millis(100));
        })
        .collect::<Vec<_>>();
        assert_eq!(results, vec![42]);
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_task_finish_after_scope() {
        let results = scope_and_block(1, |scope| {
            scope.spawn(async move {
                thread::sleep(std::time::Duration::from_millis(100));
                42
            });
        })
        .collect::<Vec<_>>();
        assert_eq!(results, vec![42]);
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_panic_in_scope_factory() {
        let result = catch_unwind(AssertUnwindSafe(|| {
            let _results = scope_and_block(1000, |scope| {
                for i in 0..500 {
                    scope.spawn(async move { i });
                }
                panic!("Intentional panic");
            });
            unreachable!();
        }));
        assert!(result.is_err());
        assert_eq!(
            result.unwrap_err().downcast_ref::<&str>(),
            Some(&"Intentional panic")
        );
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_panic_in_scope_task() {
        let result = catch_unwind(AssertUnwindSafe(|| {
            let _results = scope_and_block(1000, |scope| {
                for i in 0..1000 {
                    scope.spawn(async move {
                        if i == 500 {
                            panic!("Intentional panic");
                        } else if i == 501 {
                            panic!("Wrong intentional panic");
                        } else {
                            i
                        }
                    });
                }
            });
            unreachable!();
        }));
        assert!(result.is_err());
        assert_eq!(
            result.unwrap_err().downcast_ref::<&str>(),
            Some(&"Intentional panic")
        );
    }

    /// Spawning more tasks than there are slots panics inside the scope factory. That
    /// panic is re-raised once the already-spawned task has finished.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_too_many_tasks_spawned() {
        let result = catch_unwind(AssertUnwindSafe(|| {
            let _results = scope_and_block(1, |scope| {
                scope.spawn(async move { 1 });
                scope.spawn(async move { 2 });
            });
            unreachable!();
        }));
        assert!(result.is_err());
        assert_eq!(
            result.unwrap_err().downcast_ref::<&str>(),
            Some(&"Too many tasks spawned")
        );
    }

    /// Spawning fewer tasks than requested is not detected when the scope returns. It
    /// surfaces only when iteration reaches the first unfilled slot.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_too_few_tasks_spawned() {
        let mut results = scope_and_block(2, |scope| {
            scope.spawn(async move { 1 });
        });
        assert_eq!(results.next(), Some(1));
        let result = catch_unwind(AssertUnwindSafe(|| results.next()));
        let payload = result.unwrap_err();
        // The payload is a `String` or a `&'static str` depending on how the message
        // was formatted; accept either.
        let message = payload
            .downcast_ref::<String>()
            .map(String::as_str)
            .or_else(|| payload.downcast_ref::<&str>().copied());
        assert_eq!(
            message,
            Some("All values are set when the scope returns without panic")
        );
    }
}