// turbo_tasks/task/function.rs

1//! # Function tasks
2//!
3//! This module contains the trait definitions and implementations that are
4//! necessary for accepting functions as tasks when using the
5//! `turbo_tasks::function` macro.
6//!
7//! This system is inspired by Bevy's Systems and Axum's Handlers.
8//!
9//! The original principle is somewhat simple: a function is accepted if all
10//! of its arguments implement `TaskInput` and its return type implements
11//! `TaskOutput`. There are a few hoops one needs to jump through to make this
12//! work, but they are described in this blog post:
13//! <https://blog.logrocket.com/rust-bevy-entity-component-system/>
14//!
15//! However, there is an additional complication in our case: async methods
16//! that accept a reference to the receiver as their first argument.
17//!
//! This complication is handled through our own `AsyncFnN` helper traits (in the
//! spirit of the `async_trait` crate), which allow us to target `async fn` as trait
//! bounds. The naive approach runs into many issues with lifetimes, hence the need
//! for an intermediate trait. However, this implementation doesn't support all async
//! methods (see the commented-out tests).
23
24use std::{future::Future, marker::PhantomData, pin::Pin};
25
26use anyhow::Result;
27
28use super::{TaskInput, TaskOutput};
29use crate::{RawVc, Vc, VcRead, VcValueType, magic_any::MagicAny};
30
/// The boxed, type-erased future returned when a task function is invoked.
///
/// It is `Send` and resolves to the [`RawVc`] of the task's output, or to an error.
pub type NativeTaskFuture = Pin<Box<dyn Future<Output = Result<RawVc>> + Send>>;

/// A type-erased task function. Given an optional receiver and a dynamically typed argument
/// value, it produces the [`NativeTaskFuture`] that runs the task.
pub trait TaskFn: Send + Sync + 'static {
    /// Prepares one invocation of the task function.
    ///
    /// `this` is the receiver for method-style tasks. `arg` holds the remaining arguments; the
    /// implementations in this module downcast it to an argument tuple. An `Err` here means the
    /// invocation could not be prepared. Errors raised while running the task are reported
    /// through the returned future instead.
    fn functor(&self, this: Option<RawVc>, arg: &dyn MagicAny) -> Result<NativeTaskFuture>;
}
36
/// Conversion of a plain (receiver-less) function into a [`TaskFn`].
///
/// `Mode` and `Inputs` are type-level only. They let the blanket impls be told apart during
/// type inference (see [`TaskFnMode`]).
pub trait IntoTaskFn<Mode, Inputs> {
    /// The concrete [`TaskFn`] produced by the conversion.
    type TaskFn: TaskFn;

    /// Wraps `self` into [`Self::TaskFn`].
    fn into_task_fn(self) -> Self::TaskFn;
}
42
43impl<F, Mode, Inputs> IntoTaskFn<Mode, Inputs> for F
44where
45    F: TaskFnInputFunction<Mode, Inputs>,
46    Mode: TaskFnMode,
47    Inputs: TaskInputs,
48{
49    type TaskFn = FunctionTaskFn<F, Mode, Inputs>;
50
51    fn into_task_fn(self) -> Self::TaskFn {
52        FunctionTaskFn {
53            task_fn: self,
54            mode: PhantomData,
55            inputs: PhantomData,
56        }
57    }
58}
59
/// Conversion of a function whose first parameter is a receiver of type `This` into a
/// [`TaskFn`].
///
/// As with [`IntoTaskFn`], `Mode` and `Inputs` are type-level only and guide inference.
pub trait IntoTaskFnWithThis<Mode, This, Inputs> {
    /// The concrete [`TaskFn`] produced by the conversion.
    type TaskFn: TaskFn;

    /// Wraps `self` into [`Self::TaskFn`].
    fn into_task_fn_with_this(self) -> Self::TaskFn;
}
65
66impl<F, Mode, This, Inputs> IntoTaskFnWithThis<Mode, This, Inputs> for F
67where
68    F: TaskFnInputFunctionWithThis<Mode, This, Inputs>,
69    Mode: TaskFnMode,
70    This: Sync + Send + 'static,
71    Inputs: TaskInputs,
72{
73    type TaskFn = FunctionTaskFnWithThis<F, Mode, This, Inputs>;
74
75    fn into_task_fn_with_this(self) -> Self::TaskFn {
76        FunctionTaskFnWithThis {
77            task_fn: self,
78            mode: PhantomData,
79            this: PhantomData,
80            inputs: PhantomData,
81        }
82    }
83}
84
/// [`TaskFn`] adapter for a function that takes no receiver.
///
/// `Mode` and `Inputs` are zero-sized markers recording which `TaskFnInputFunction` impl the
/// function was matched against. Trait bounds are stated on the impls that need them rather
/// than on the struct itself, so merely naming this type does not require restating them.
pub struct FunctionTaskFn<F, Mode, Inputs> {
    /// The wrapped function.
    task_fn: F,
    /// Marker for the inferred [`TaskFnMode`].
    mode: PhantomData<Mode>,
    /// Marker for the argument tuple type.
    inputs: PhantomData<Inputs>,
}
90
91impl<F, Mode, Inputs> TaskFn for FunctionTaskFn<F, Mode, Inputs>
92where
93    F: TaskFnInputFunction<Mode, Inputs>,
94    Mode: TaskFnMode,
95    Inputs: TaskInputs,
96{
97    fn functor(&self, _this: Option<RawVc>, arg: &dyn MagicAny) -> Result<NativeTaskFuture> {
98        TaskFnInputFunction::functor(&self.task_fn, arg)
99    }
100}
101
/// [`TaskFn`] adapter for a function whose first parameter is a receiver of type `This`.
///
/// Every type parameter other than `F` is a zero-sized marker. Trait bounds are stated on the
/// impls that need them rather than on the struct itself, so merely naming this type does not
/// require restating them.
pub struct FunctionTaskFnWithThis<F, Mode, This, Inputs> {
    /// The wrapped function.
    task_fn: F,
    /// Marker for the inferred [`TaskFnMode`].
    mode: PhantomData<Mode>,
    /// Marker for the receiver type.
    this: PhantomData<This>,
    /// Marker for the argument tuple type (excluding the receiver).
    inputs: PhantomData<Inputs>,
}
113
114impl<F, Mode, This, Inputs> TaskFn for FunctionTaskFnWithThis<F, Mode, This, Inputs>
115where
116    F: TaskFnInputFunctionWithThis<Mode, This, Inputs>,
117    Mode: TaskFnMode,
118    This: Sync + Send + 'static,
119    Inputs: TaskInputs,
120{
121    fn functor(&self, this: Option<RawVc>, arg: &dyn MagicAny) -> Result<NativeTaskFuture> {
122        let Some(this) = this else {
123            panic!("Method needs a `self` argument");
124        };
125        TaskFnInputFunctionWithThis::functor(&self.task_fn, this, arg)
126    }
127}
128
/// A function without a receiver that can be driven as a task in the given `Mode`.
///
/// Implemented by `task_fn_impl!` for `Fn`s taking up to 12 [`TaskInput`] arguments. `Clone` is
/// required because those impls clone the function into each future they create.
trait TaskFnInputFunction<Mode: TaskFnMode, Inputs: TaskInputs>: Send + Sync + Clone + 'static {
    /// Extracts the arguments from `arg` and returns the future that runs the function.
    fn functor(&self, arg: &dyn MagicAny) -> Result<NativeTaskFuture>;
}

/// Like `TaskFnInputFunction`, but for functions whose first parameter is a receiver:
/// `&This` or `Vc<This>`, depending on `Mode`.
trait TaskFnInputFunctionWithThis<Mode: TaskFnMode, This: Sync + Send + 'static, Inputs: TaskInputs>:
    Send + Sync + Clone + 'static
{
    /// Builds a `Vc<This>` from `this`, extracts the remaining arguments from `arg`, and returns
    /// the future that runs the function.
    fn functor(&self, this: RawVc, arg: &dyn MagicAny) -> Result<NativeTaskFuture>;
}

/// Marker for the argument tuples accepted by task functions. `task_inputs_impl!` implements it
/// for tuples (up to 13 elements) whose elements are all `'static` [`TaskInput`]s.
pub trait TaskInputs: Send + Sync + 'static {}
140
/// Modes to allow multiple `TaskFnInputFunction` blanket implementations on
/// `Fn`s. Even though the implementations are non-conflicting in practice, they
/// could be in theory (at least given the compiler's current limitations).
/// Despite this, the compiler is still able to infer the correct mode from a
/// function.
pub trait TaskFnMode: Send + Sync + 'static {}

/// Mode for synchronous functions. They are either receiver-less, or take the receiver as a
/// `Vc<Recv>` that is passed through without being read.
pub struct FunctionMode;
impl TaskFnMode for FunctionMode {}

/// Mode for `async` functions. They are either receiver-less, or take the receiver as a
/// `Vc<Recv>`.
pub struct AsyncFunctionMode;
impl TaskFnMode for AsyncFunctionMode {}

/// Mode for synchronous methods taking `&Recv`; the receiver `Vc` is read before the call.
pub struct MethodMode;
impl TaskFnMode for MethodMode {}

/// Mode for `async` methods taking `&Recv`; the receiver `Vc` is read before the call.
pub struct AsyncMethodMode;
impl TaskFnMode for AsyncMethodMode {}
159
// Implements `TaskInputs` for the tuple formed by the given type-parameter identifiers, provided
// every element is a `'static` `TaskInput`. An empty invocation covers the unit type `()`.
macro_rules! task_inputs_impl {
    ( $( $elem:ident )* ) => {
        impl<$( $elem, )*> TaskInputs for ( $( $elem, )* )
        where
            $( $elem: TaskInput + 'static, )*
        {
        }
    };
}
168
169/// Downcast, and clone all the arguments in the singular `arg` tuple.
170///
171/// This helper function for `task_fn_impl!()` reduces the amount of code inside the macro, and
172/// gives the compiler more chances to dedupe monomorphized code across small functions with less
173/// typevars.
174fn get_args<T: MagicAny + Clone>(arg: &dyn MagicAny) -> Result<T> {
175    let value = (arg as &dyn std::any::Any).downcast_ref::<T>().cloned();
176    #[cfg(debug_assertions)]
177    return anyhow::Context::with_context(value, || {
178        crate::native_function::debug_downcast_args_error_msg(
179            std::any::type_name::<T>(),
180            arg.magic_type_name(),
181        )
182    });
183    #[cfg(not(debug_assertions))]
184    return anyhow::Context::context(value, "Invalid argument type");
185}
186
187// Helper function for `task_fn_impl!()`
188async fn output_try_into_non_local_raw_vc(output: impl TaskOutput) -> Result<RawVc> {
189    // TODO: Potential future optimization: If we know we're inside a local task, we can avoid
190    // calling `to_non_local()` here, which might let us avoid constructing a non-local cell for the
191    // local task's return value. Flattening chains of `RawVc::LocalOutput` may still be useful to
192    // reduce traversal later.
193    output.try_into_raw_vc()?.to_non_local().await
194}
195
// For one arity (the `$arg` identifiers), generates the blanket impls that let a function taking
// that many task arguments be used as a task:
//
// - `TaskFnInputFunction<FunctionMode | AsyncFunctionMode, _>` for receiver-less sync/async `Fn`s;
// - `TaskFnInputFunctionWithThis<MethodMode | AsyncMethodMode, Recv, _>` for `Fn`s whose first
//   parameter is `&Recv` (the receiver `Vc` is awaited first to obtain the value);
// - `TaskFnInputFunctionWithThis<FunctionMode | AsyncFunctionMode, Recv, _>` for `Fn`s whose
//   first parameter is `Vc<Recv>` (the receiver is passed through without being read);
// - the public helper trait `$async_fn_trait` (e.g. `AsyncFn2`), which names the future returned
//   by an async `Fn` so it can be required behind a higher-ranked `for<'a>` bound.
//
// In every impl, the arguments are downcast and cloned via `get_args` *before* the future is
// built, so a type mismatch surfaces as an `Err` from `functor`. The function itself only runs
// once the returned future is polled. `#[allow(non_snake_case)]` is needed because the argument
// bindings reuse the type-parameter identifiers (`A1`, `A2`, ...).
//
// NOTE(review): `$arg_len` is matched but never used in the expansion.
macro_rules! task_fn_impl {
    ( $async_fn_trait:ident $arg_len:literal $( $arg:ident )* ) => {
        // Synchronous function without a receiver.
        impl<F, Output, $($arg,)*> TaskFnInputFunction<FunctionMode, ($($arg,)*)> for F
        where
            $($arg: TaskInput + 'static,)*
            F: Fn($($arg,)*) -> Output + Send + Sync + Clone + 'static,
            Output: TaskOutput + 'static,
        {
            #[allow(non_snake_case)]
            fn functor(&self, arg: &dyn MagicAny) -> Result<NativeTaskFuture> {
                let task_fn = self.clone();
                let ($($arg,)*) = get_args::<($($arg,)*)>(arg)?;
                Ok(Box::pin(async move {
                    let output = (task_fn)($($arg,)*);
                    output_try_into_non_local_raw_vc(output).await
                }))
            }
        }

        // Async function without a receiver. Its future is awaited before the output is
        // converted.
        impl<F, Output, FutureOutput, $($arg,)*> TaskFnInputFunction<AsyncFunctionMode, ($($arg,)*)> for F
        where
            $($arg: TaskInput + 'static,)*
            F: Fn($($arg,)*) -> FutureOutput + Send + Sync + Clone + 'static,
            FutureOutput: Future<Output = Output> + Send + 'static,
            Output: TaskOutput + 'static,
        {
            #[allow(non_snake_case)]
            fn functor(&self, arg: &dyn MagicAny) -> Result<NativeTaskFuture> {
                let task_fn = self.clone();
                let ($($arg,)*) = get_args::<($($arg,)*)>(arg)?;
                Ok(Box::pin(async move {
                    let output = (task_fn)($($arg,)*).await;
                    output_try_into_non_local_raw_vc(output).await
                }))
            }
        }

        // Synchronous method taking `&Recv`. Inside the future, the receiver is read (awaited),
        // and the resulting reference is passed to the function.
        impl<F, Output, Recv, $($arg,)*> TaskFnInputFunctionWithThis<MethodMode, Recv, ($($arg,)*)> for F
        where
            Recv: VcValueType,
            $($arg: TaskInput + 'static,)*
            F: Fn(&Recv, $($arg,)*) -> Output + Send + Sync + Clone + 'static,
            Output: TaskOutput + 'static,
        {
            #[allow(non_snake_case)]
            fn functor(&self, this: RawVc, arg: &dyn MagicAny) -> Result<NativeTaskFuture> {
                let task_fn = self.clone();
                let recv = Vc::<Recv>::from(this);
                let ($($arg,)*) = get_args::<($($arg,)*)>(arg)?;
                Ok(Box::pin(async move {
                    let recv = recv.await?;
                    let recv = <Recv::Read as VcRead<Recv>>::target_to_value_ref(&*recv);
                    let output = (task_fn)(recv, $($arg,)*);
                    output_try_into_non_local_raw_vc(output).await
                }))
            }
        }

        // Synchronous function taking the receiver as `Vc<Recv>`; the `Vc` is passed through
        // unread.
        impl<F, Output, Recv, $($arg,)*> TaskFnInputFunctionWithThis<FunctionMode, Recv, ($($arg,)*)> for F
        where
            Recv: Sync + Send + 'static,
            $($arg: TaskInput + 'static,)*
            F: Fn(Vc<Recv>, $($arg,)*) -> Output + Send + Sync + Clone + 'static,
            Output: TaskOutput + 'static,
        {
            #[allow(non_snake_case)]
            fn functor(&self, this: RawVc, arg: &dyn MagicAny) -> Result<NativeTaskFuture> {
                let task_fn = self.clone();
                let recv = Vc::<Recv>::from(this);
                let ($($arg,)*) = get_args::<($($arg,)*)>(arg)?;
                Ok(Box::pin(async move {
                    let output = (task_fn)(recv, $($arg,)*);
                    output_try_into_non_local_raw_vc(output).await
                }))
            }
        }

        // Helper trait that names the future type (and its output) returned by an async `Fn`.
        // `AsyncMethodMode` below requires `for<'a> $async_fn_trait<&'a Recv, ...>`, so the
        // future is allowed to borrow the receiver.
        pub trait $async_fn_trait<A0, $($arg,)*>: Fn(A0, $($arg,)*) -> Self::OutputFuture {
            type OutputFuture: Future<Output = <Self as $async_fn_trait<A0, $($arg,)*>>::Output> + Send;
            type Output: TaskOutput;
        }

        // Blanket impl: every `Fn` that returns a `Send` future whose output is a `'static`
        // `TaskOutput` qualifies.
        impl<F: ?Sized, Fut, A0, $($arg,)*> $async_fn_trait<A0, $($arg,)*> for F
        where
            F: Fn(A0, $($arg,)*) -> Fut,
            Fut: Future + Send,
            Fut::Output: TaskOutput + 'static
        {
            type OutputFuture = Fut;
            type Output = Fut::Output;
        }

        // Async method taking `&Recv`. The read receiver is a local of the boxed future, so it
        // stays alive while the method's own future is awaited.
        impl<F, Recv, $($arg,)*> TaskFnInputFunctionWithThis<AsyncMethodMode, Recv, ($($arg,)*)> for F
        where
            Recv: VcValueType,
            $($arg: TaskInput + 'static,)*
            F: for<'a> $async_fn_trait<&'a Recv, $($arg,)*> + Clone + Send + Sync + 'static,
        {
            #[allow(non_snake_case)]
            fn functor(&self, this: RawVc, arg: &dyn MagicAny) -> Result<NativeTaskFuture> {
                let task_fn = self.clone();
                let recv = Vc::<Recv>::from(this);
                let ($($arg,)*) = get_args::<($($arg,)*)>(arg)?;
                Ok(Box::pin(async move {
                    let recv = recv.await?;
                    let recv = <Recv::Read as VcRead<Recv>>::target_to_value_ref(&*recv);
                    let output = (task_fn)(recv, $($arg,)*).await;
                    output_try_into_non_local_raw_vc(output).await
                }))
            }
        }

        // Async function taking the receiver as `Vc<Recv>`.
        impl<F, Recv, $($arg,)*> TaskFnInputFunctionWithThis<AsyncFunctionMode, Recv, ($($arg,)*)> for F
        where
            Recv: Sync + Send + 'static,
            $($arg: TaskInput + 'static,)*
            F: $async_fn_trait<Vc<Recv>, $($arg,)*> + Clone + Send + Sync + 'static,
        {
            #[allow(non_snake_case)]
            fn functor(&self, this: RawVc, arg: &dyn MagicAny) -> Result<NativeTaskFuture> {
                let task_fn = self.clone();
                let recv = Vc::<Recv>::from(this);
                let ($($arg,)*) = get_args::<($($arg,)*)>(arg)?;
                Ok(Box::pin(async move {
                    let output = (task_fn)(recv, $($arg,)*).await;
                    output_try_into_non_local_raw_vc(output).await
                }))
            }
        }
    };
}
327
328task_fn_impl! { AsyncFn0 0 }
329task_fn_impl! { AsyncFn1 1 A1 }
330task_fn_impl! { AsyncFn2 2 A1 A2 }
331task_fn_impl! { AsyncFn3 3 A1 A2 A3 }
332task_fn_impl! { AsyncFn4 4 A1 A2 A3 A4 }
333task_fn_impl! { AsyncFn5 5 A1 A2 A3 A4 A5 }
334task_fn_impl! { AsyncFn6 6 A1 A2 A3 A4 A5 A6 }
335task_fn_impl! { AsyncFn7 7 A1 A2 A3 A4 A5 A6 A7 }
336task_fn_impl! { AsyncFn8 8 A1 A2 A3 A4 A5 A6 A7 A8 }
337task_fn_impl! { AsyncFn9 9 A1 A2 A3 A4 A5 A6 A7 A8 A9 }
338task_fn_impl! { AsyncFn10 10 A1 A2 A3 A4 A5 A6 A7 A8 A9 A10 }
339task_fn_impl! { AsyncFn11 11 A1 A2 A3 A4 A5 A6 A7 A8 A9 A10 A11 }
340task_fn_impl! { AsyncFn12 12 A1 A2 A3 A4 A5 A6 A7 A8 A9 A10 A11 A12 }
341
342// There needs to be one more implementation than task_fn_impl to account for
343// the receiver.
344task_inputs_impl! {}
345task_inputs_impl! { A1 }
346task_inputs_impl! { A1 A2 }
347task_inputs_impl! { A1 A2 A3 }
348task_inputs_impl! { A1 A2 A3 A4 }
349task_inputs_impl! { A1 A2 A3 A4 A5 }
350task_inputs_impl! { A1 A2 A3 A4 A5 A6 }
351task_inputs_impl! { A1 A2 A3 A4 A5 A6 A7 }
352task_inputs_impl! { A1 A2 A3 A4 A5 A6 A7 A8 }
353task_inputs_impl! { A1 A2 A3 A4 A5 A6 A7 A8 A9 }
354task_inputs_impl! { A1 A2 A3 A4 A5 A6 A7 A8 A9 A10 }
355task_inputs_impl! { A1 A2 A3 A4 A5 A6 A7 A8 A9 A10 A11 }
356task_inputs_impl! { A1 A2 A3 A4 A5 A6 A7 A8 A9 A10 A11 A12 }
357task_inputs_impl! { A1 A2 A3 A4 A5 A6 A7 A8 A9 A10 A11 A12 A13 }
358
#[cfg(test)]
mod tests {
    use turbo_rcstr::RcStr;

    use super::*;
    use crate::{ShrinkToFit, VcCellNewMode, VcDefaultRead};

    // Compile-time coverage of the supported function shapes. Each function below must convert
    // into a `TaskFn`, with the compiler inferring the mode. None of the `todo!()` bodies run:
    // the test only builds the `TaskFn` values and hands them to `accepts_task_fn`, whose body
    // is empty, so `functor` is never called.
    #[test]
    fn test_task_fn() {
        fn no_args() -> crate::Vc<i32> {
            todo!()
        }

        fn one_arg(_a: i32) -> crate::Vc<i32> {
            todo!()
        }

        async fn async_one_arg(_a: i32) -> crate::Vc<i32> {
            todo!()
        }

        // The functions below take `&i32` first and are converted with
        // `into_task_fn_with_this`, i.e. treated as having an `i32` receiver.
        fn with_recv(_a: &i32) -> crate::Vc<i32> {
            todo!()
        }

        async fn async_with_recv(_a: &i32) -> crate::Vc<i32> {
            todo!()
        }

        fn with_recv_and_str(_a: &i32, _s: RcStr) -> crate::Vc<i32> {
            todo!()
        }

        async fn async_with_recv_and_str(_a: &i32, _s: RcStr) -> crate::Vc<i32> {
            todo!()
        }

        async fn async_with_recv_and_str_and_result(_a: &i32, _s: RcStr) -> Result<crate::Vc<i32>> {
            todo!()
        }

        // Fails to compile (at the call site) unless `F` implements `TaskFn`.
        fn accepts_task_fn<F>(_task_fn: F)
        where
            F: TaskFn,
        {
        }

        // Minimal value type used to check that inherent and trait `async` methods taking
        // `&self` are accepted.
        struct Struct;
        impl Struct {
            async fn inherent_method(&self) {}
        }

        impl ShrinkToFit for Struct {
            fn shrink_to_fit(&mut self) {}
        }

        // SAFETY: NOTE(review): the safety contract of `VcValueType` is not visible from this
        // file. This impl exists only so that `Struct` satisfies the `Recv: VcValueType` bound of
        // the method-mode impls at compile time; `get_value_type_id` is never called here.
        unsafe impl VcValueType for Struct {
            type Read = VcDefaultRead<Struct>;

            type CellMode = VcCellNewMode<Struct>;

            fn get_value_type_id() -> crate::ValueTypeId {
                todo!()
            }
        }

        trait AsyncTrait {
            async fn async_method(&self);
        }

        impl AsyncTrait for Struct {
            async fn async_method(&self) {
                todo!()
            }
        }

        // Shapes that are not supported yet (referenced by the module docs).
        // NOTE(review): these take a receiver, so if re-enabled they presumably need
        // `into_task_fn_with_this()` rather than `into_task_fn()`.
        /*
        async fn async_with_recv_and_str_and_lf(
            _a: &i32,
            _s: String,
        ) -> Result<crate::Vc<i32>, crate::Vc<i32>> {
            todo!()
        }

        #[async_trait::async_trait]
        trait BoxAsyncTrait {
            async fn box_async_method(&self);
        }

        #[async_trait::async_trait]
        impl BoxAsyncTrait for Struct {
            async fn box_async_method(&self) {
                todo!()
            }
        }
        */

        // Receiver-less functions: sync and async.
        let _task_fn = no_args.into_task_fn();
        accepts_task_fn(no_args.into_task_fn());
        let _task_fn = one_arg.into_task_fn();
        accepts_task_fn(one_arg.into_task_fn());
        let _task_fn = async_one_arg.into_task_fn();
        accepts_task_fn(async_one_arg.into_task_fn());
        // Functions and methods with a `&Recv` receiver: sync and async, with and without
        // extra arguments, and with a `Result` output.
        let task_fn = with_recv.into_task_fn_with_this();
        accepts_task_fn(task_fn);
        let task_fn = async_with_recv.into_task_fn_with_this();
        accepts_task_fn(task_fn);
        let task_fn = with_recv_and_str.into_task_fn_with_this();
        accepts_task_fn(task_fn);
        let task_fn = async_with_recv_and_str.into_task_fn_with_this();
        accepts_task_fn(task_fn);
        let task_fn = async_with_recv_and_str_and_result.into_task_fn_with_this();
        accepts_task_fn(task_fn);
        let task_fn = <Struct as AsyncTrait>::async_method.into_task_fn_with_this();
        accepts_task_fn(task_fn);
        let task_fn = Struct::inherent_method.into_task_fn_with_this();
        accepts_task_fn(task_fn);

        /*
        let task_fn = <Struct as BoxAsyncTrait>::box_async_method.into_task_fn();
        accepts_task_fn(task_fn);
        let task_fn = async_with_recv_and_str_and_lf.into_task_fn();
        accepts_task_fn(task_fn);
        */
    }
}