Skip to main content

turbo_tasks/task/
function.rs

//! # Function tasks
//!
//! This module contains the trait definitions and implementations that are
//! necessary for accepting functions as tasks when using the
//! `turbo_tasks::function` macro.
//!
//! This system is inspired by Bevy's Systems and Axum's Handlers.
//!
//! The original principle is somewhat simple: a function is accepted if all
//! of its arguments implement `TaskInput` and its return type implements
//! `TaskOutput`. There are a few hoops one needs to jump through to make this
//! work, but they are described in this blog post:
//! <https://blog.logrocket.com/rust-bevy-entity-component-system/>
//!
//! However, there is an additional complication in our case: async methods
//! that accept a reference to the receiver as their first argument.
//!
//! This complication is handled through our own version of the `async_trait`
//! crate, which allows us to target `async fn` as trait bounds. The naive
//! approach runs into many issues with lifetimes, hence the need for an
//! intermediate trait. However, this implementation doesn't support all async
//! methods (see the commented-out tests).
23
24use std::{future::Future, marker::PhantomData, pin::Pin};
25
26use anyhow::Result;
27
28use super::{TaskInput, TaskOutput};
29use crate::{RawVc, Vc, VcRead, VcValueType, magic_any::MagicAny};
30
/// The boxed, pinned, `Send` future produced when a task function is invoked.
///
/// It resolves to `Result<RawVc>`; every `functor` method in this module returns this type
/// wrapped in a `Result`.
pub type NativeTaskFuture = Pin<Box<dyn Future<Output = Result<RawVc>> + Send>>;
32
/// A type-erased task function.
///
/// Implementors turn an optional receiver and a type-erased argument value into the future that
/// runs the task.
pub trait TaskFn: Send + Sync + 'static {
    /// Builds the future for one invocation.
    ///
    /// * `this` - the receiver, if any. `FunctionTaskFn` ignores it, while
    ///   `FunctionTaskFnWithThis` panics when it is `None`.
    /// * `arg` - the type-erased arguments. The implementations in this module downcast it to
    ///   the tuple of the function's argument types.
    ///
    /// Returns `Err` when the future cannot be constructed (in this module: when the downcast of
    /// `arg` fails).
    fn functor(&self, this: Option<RawVc>, arg: &dyn MagicAny) -> Result<NativeTaskFuture>;
}
36
/// A trait for `TaskFn` implementations that allows task inputs to be extracted as a type.
pub trait TaskFnInputs: TaskFn {
    /// The type of the task's inputs. The wrappers in this module set it to their `Inputs`
    /// type parameter.
    type INPUTS: TaskInput + TaskInputs;
}
41
42pub const fn into_task_fn<
43    Mode: TaskFnMode,
44    Inputs: TaskInputs,
45    F: TaskFnInputFunction<Mode, Inputs>,
46>(
47    f: F,
48) -> FunctionTaskFn<F, Mode, Inputs> {
49    FunctionTaskFn {
50        task_fn: f,
51        mode: PhantomData,
52        inputs: PhantomData,
53    }
54}
55
56pub const fn into_task_fn_with_this<
57    Mode: TaskFnMode,
58    This: Send + Sync + 'static,
59    Inputs: TaskInputs,
60    F: TaskFnInputFunctionWithThis<Mode, This, Inputs>,
61>(
62    f: F,
63) -> FunctionTaskFnWithThis<F, Mode, This, Inputs> {
64    FunctionTaskFnWithThis {
65        task_fn: f,
66        mode: PhantomData,
67        this: PhantomData,
68        inputs: PhantomData,
69    }
70}
71
72pub struct FunctionTaskFn<F, Mode: TaskFnMode, Inputs: TaskInputs> {
73    task_fn: F,
74    mode: PhantomData<Mode>,
75    inputs: PhantomData<Inputs>,
76}
77
78impl<F, Mode, Inputs> TaskFn for FunctionTaskFn<F, Mode, Inputs>
79where
80    F: TaskFnInputFunction<Mode, Inputs>,
81    Mode: TaskFnMode,
82    Inputs: TaskInputs,
83{
84    fn functor(&self, _this: Option<RawVc>, arg: &dyn MagicAny) -> Result<NativeTaskFuture> {
85        TaskFnInputFunction::functor(&self.task_fn, arg)
86    }
87}
88
89impl<F, Mode, Inputs> TaskFnInputs for FunctionTaskFn<F, Mode, Inputs>
90where
91    F: TaskFnInputFunction<Mode, Inputs>,
92    Mode: TaskFnMode,
93    Inputs: TaskInputs + TaskInput,
94{
95    type INPUTS = Inputs;
96}
97
98pub struct FunctionTaskFnWithThis<
99    F,
100    Mode: TaskFnMode,
101    This: Sync + Send + 'static,
102    Inputs: TaskInputs,
103> {
104    task_fn: F,
105    mode: PhantomData<Mode>,
106    this: PhantomData<This>,
107    inputs: PhantomData<Inputs>,
108}
109
110impl<F, Mode, This, Inputs> TaskFn for FunctionTaskFnWithThis<F, Mode, This, Inputs>
111where
112    F: TaskFnInputFunctionWithThis<Mode, This, Inputs>,
113    Mode: TaskFnMode,
114    This: Sync + Send + 'static,
115    Inputs: TaskInputs,
116{
117    fn functor(&self, this: Option<RawVc>, arg: &dyn MagicAny) -> Result<NativeTaskFuture> {
118        let Some(this) = this else {
119            panic!("Method needs a `self` argument");
120        };
121        TaskFnInputFunctionWithThis::functor(&self.task_fn, this, arg)
122    }
123}
124
125impl<F, Mode, This, Inputs> TaskFnInputs for FunctionTaskFnWithThis<F, Mode, This, Inputs>
126where
127    F: TaskFnInputFunctionWithThis<Mode, This, Inputs>,
128    Mode: TaskFnMode,
129    This: Sync + Send + 'static,
130    Inputs: TaskInputs + TaskInput,
131{
132    type INPUTS = Inputs;
133}
134
/// Implemented (via `task_fn_impl!`) for callables that can run as a task without a receiver.
///
/// `Mode` distinguishes the blanket implementations from each other and `Inputs` is the tuple of
/// the callable's argument types. `Clone` is a supertrait because the generated implementations
/// clone the callable into the future they return.
#[doc(hidden)]
pub trait TaskFnInputFunction<Mode: TaskFnMode, Inputs: TaskInputs>:
    Send + Sync + Clone + 'static
{
    /// Downcasts `arg` to the argument tuple and returns the future that calls the function.
    fn functor(&self, arg: &dyn MagicAny) -> Result<NativeTaskFuture>;
}
141
/// Implemented (via `task_fn_impl!`) for callables whose first parameter is a receiver, taken
/// either as `&This` or as `Vc<This>`.
///
/// `Mode` distinguishes the blanket implementations from each other and `Inputs` is the tuple of
/// the remaining argument types. `Clone` is a supertrait because the generated implementations
/// clone the callable into the future they return.
#[doc(hidden)]
pub trait TaskFnInputFunctionWithThis<
    Mode: TaskFnMode,
    This: Sync + Send + 'static,
    Inputs: TaskInputs,
>: Send + Sync + Clone + 'static
{
    /// Downcasts `arg` to the argument tuple and returns the future that calls the function
    /// with `this` as its receiver.
    fn functor(&self, this: RawVc, arg: &dyn MagicAny) -> Result<NativeTaskFuture>;
}
151
/// Marker trait for the tuple of a task function's arguments. It is implemented by
/// `task_inputs_impl!` for tuples whose elements all implement `TaskInput`.
pub trait TaskInputs: Send + Sync + 'static {}
153
/// Modes to allow multiple `TaskFnInputFunction` blanket implementations on
/// `Fn`s. Even though the implementations are non-conflicting in practice, they
/// could be in theory (at least with the compiler's current limitations).
/// Despite this, the compiler is still able to infer the correct mode from a
/// function.
pub trait TaskFnMode: Send + Sync + 'static {}
160
/// Mode for synchronous functions, including those that take the receiver as `Vc<Recv>`.
pub struct FunctionMode;
impl TaskFnMode for FunctionMode {}

/// Mode for functions returning a future, including those that take the receiver as `Vc<Recv>`.
pub struct AsyncFunctionMode;
impl TaskFnMode for AsyncFunctionMode {}

/// Mode for synchronous methods that take the receiver as `&Recv`.
pub struct MethodMode;
impl TaskFnMode for MethodMode {}

/// Mode for asynchronous methods that take the receiver as `&Recv`.
pub struct AsyncMethodMode;
impl TaskFnMode for AsyncMethodMode {}
172
/// Implements `TaskInputs` for the tuple made of the given type identifiers, requiring every
/// element to be `TaskInput + 'static`. With no identifiers this covers the unit tuple `()`.
macro_rules! task_inputs_impl {
    ( $( $arg:ident )* ) => {
        impl<$($arg: TaskInput + 'static,)*> TaskInputs for ($($arg,)*) {}
    };
}
181
182/// Downcast, and clone all the arguments in the singular `arg` tuple.
183///
184/// This helper function for `task_fn_impl!()` reduces the amount of code inside the macro, and
185/// gives the compiler more chances to dedupe monomorphized code across small functions with less
186/// typevars.
187fn get_args<T: MagicAny + Clone>(arg: &dyn MagicAny) -> Result<T> {
188    let value = (arg as &dyn std::any::Any).downcast_ref::<T>().cloned();
189    #[cfg(debug_assertions)]
190    return anyhow::Context::with_context(value, || {
191        crate::native_function::debug_downcast_args_error_msg(
192            std::any::type_name::<T>(),
193            arg.magic_type_name(),
194        )
195    });
196    #[cfg(not(debug_assertions))]
197    return anyhow::Context::context(value, "Invalid argument type");
198}
199
200// Helper function for `task_fn_impl!()`
201async fn output_try_into_non_local_raw_vc(output: impl TaskOutput) -> Result<RawVc> {
202    output.try_into_raw_vc()?.to_non_local().await
203}
204
// Generates, for one argument count, every blanket implementation that lets a callable be used
// as a task function:
//
// * `TaskFnInputFunction` for plain sync and async functions,
// * `TaskFnInputFunctionWithThis` for sync/async callables taking `&Recv` or `Vc<Recv>` first,
// * the helper trait `$async_fn_trait`, which names the future returned by an async callable so
//   that it can be used in a higher-ranked bound over the receiver's lifetime.
//
// `$arg_len` is accepted for readability at the call site but is not referenced in the
// expansion. Each `$arg` identifier is used both as a type parameter and as the local variable
// holding that argument, hence the `non_snake_case` allowances.
macro_rules! task_fn_impl {
    ( $async_fn_trait:ident $arg_len:literal $( $arg:ident )* ) => {
        // Sync function: `fn(args...) -> Output`.
        impl<F, Output, $($arg,)*> TaskFnInputFunction<FunctionMode, ($($arg,)*)> for F
        where
            $($arg: TaskInput + 'static,)*
            F: Fn($($arg,)*) -> Output + Send + Sync + Clone + 'static,
            Output: TaskOutput + 'static,
        {
            #[allow(non_snake_case)]
            fn functor(&self, arg: &dyn MagicAny) -> Result<NativeTaskFuture> {
                let func = self.clone();
                let ($($arg,)*) = get_args::<($($arg,)*)>(arg)?;
                Ok(Box::pin(async move {
                    let result = func($($arg,)*);
                    output_try_into_non_local_raw_vc(result).await
                }))
            }
        }

        // Async function: `fn(args...) -> impl Future<Output = Output>`.
        impl<F, Output, FutureOutput, $($arg,)*> TaskFnInputFunction<AsyncFunctionMode, ($($arg,)*)> for F
        where
            $($arg: TaskInput + 'static,)*
            F: Fn($($arg,)*) -> FutureOutput + Send + Sync + Clone + 'static,
            FutureOutput: Future<Output = Output> + Send + 'static,
            Output: TaskOutput + 'static,
        {
            #[allow(non_snake_case)]
            fn functor(&self, arg: &dyn MagicAny) -> Result<NativeTaskFuture> {
                let func = self.clone();
                let ($($arg,)*) = get_args::<($($arg,)*)>(arg)?;
                Ok(Box::pin(async move {
                    let result = func($($arg,)*).await;
                    output_try_into_non_local_raw_vc(result).await
                }))
            }
        }

        // Sync method: `fn(&Recv, args...) -> Output`. The receiver cell is read inside the
        // future and the function is handed a reference to the value.
        impl<F, Output, Recv, $($arg,)*> TaskFnInputFunctionWithThis<MethodMode, Recv, ($($arg,)*)> for F
        where
            Recv: VcValueType,
            $($arg: TaskInput + 'static,)*
            F: Fn(&Recv, $($arg,)*) -> Output + Send + Sync + Clone + 'static,
            Output: TaskOutput + 'static,
        {
            #[allow(non_snake_case)]
            fn functor(&self, this: RawVc, arg: &dyn MagicAny) -> Result<NativeTaskFuture> {
                let func = self.clone();
                let receiver: Vc<Recv> = Vc::from(this);
                let ($($arg,)*) = get_args::<($($arg,)*)>(arg)?;
                Ok(Box::pin(async move {
                    let read = receiver.await?;
                    let value = <Recv::Read as VcRead<Recv>>::target_to_value_ref(&*read);
                    let result = func(value, $($arg,)*);
                    output_try_into_non_local_raw_vc(result).await
                }))
            }
        }

        // Sync function taking the receiver unresolved: `fn(Vc<Recv>, args...) -> Output`.
        impl<F, Output, Recv, $($arg,)*> TaskFnInputFunctionWithThis<FunctionMode, Recv, ($($arg,)*)> for F
        where
            Recv: Sync + Send + 'static,
            $($arg: TaskInput + 'static,)*
            F: Fn(Vc<Recv>, $($arg,)*) -> Output + Send + Sync + Clone + 'static,
            Output: TaskOutput + 'static,
        {
            #[allow(non_snake_case)]
            fn functor(&self, this: RawVc, arg: &dyn MagicAny) -> Result<NativeTaskFuture> {
                let func = self.clone();
                let receiver: Vc<Recv> = Vc::from(this);
                let ($($arg,)*) = get_args::<($($arg,)*)>(arg)?;
                Ok(Box::pin(async move {
                    let result = func(receiver, $($arg,)*);
                    output_try_into_non_local_raw_vc(result).await
                }))
            }
        }

        // Intermediate trait describing "a callable returning a `Send` future whose output is a
        // `TaskOutput`". It gives the returned future a name (`OutputFuture`), which is what
        // allows the `for<'a>` bound in the async-method implementation below.
        pub trait $async_fn_trait<A0, $($arg,)*>: Fn(A0, $($arg,)*) -> Self::OutputFuture {
            type OutputFuture: Future<Output = <Self as $async_fn_trait<A0, $($arg,)*>>::Output> + Send;
            type Output: TaskOutput;
        }

        impl<F: ?Sized, Fut, A0, $($arg,)*> $async_fn_trait<A0, $($arg,)*> for F
        where
            F: Fn(A0, $($arg,)*) -> Fut,
            Fut: Future + Send,
            Fut::Output: TaskOutput + 'static
        {
            type OutputFuture = Fut;
            type Output = Fut::Output;
        }

        // Async method: `async fn(&Recv, args...)`. The receiver cell is read inside the future
        // and the function is handed a reference to the value.
        impl<F, Recv, $($arg,)*> TaskFnInputFunctionWithThis<AsyncMethodMode, Recv, ($($arg,)*)> for F
        where
            Recv: VcValueType,
            $($arg: TaskInput + 'static,)*
            F: for<'a> $async_fn_trait<&'a Recv, $($arg,)*> + Clone + Send + Sync + 'static,
        {
            #[allow(non_snake_case)]
            fn functor(&self, this: RawVc, arg: &dyn MagicAny) -> Result<NativeTaskFuture> {
                let func = self.clone();
                let receiver: Vc<Recv> = Vc::from(this);
                let ($($arg,)*) = get_args::<($($arg,)*)>(arg)?;
                Ok(Box::pin(async move {
                    let read = receiver.await?;
                    let value = <Recv::Read as VcRead<Recv>>::target_to_value_ref(&*read);
                    let result = func(value, $($arg,)*).await;
                    output_try_into_non_local_raw_vc(result).await
                }))
            }
        }

        // Async function taking the receiver unresolved: `async fn(Vc<Recv>, args...)`.
        impl<F, Recv, $($arg,)*> TaskFnInputFunctionWithThis<AsyncFunctionMode, Recv, ($($arg,)*)> for F
        where
            Recv: Sync + Send + 'static,
            $($arg: TaskInput + 'static,)*
            F: $async_fn_trait<Vc<Recv>, $($arg,)*> + Clone + Send + Sync + 'static,
        {
            #[allow(non_snake_case)]
            fn functor(&self, this: RawVc, arg: &dyn MagicAny) -> Result<NativeTaskFuture> {
                let func = self.clone();
                let receiver: Vc<Recv> = Vc::from(this);
                let ($($arg,)*) = get_args::<($($arg,)*)>(arg)?;
                Ok(Box::pin(async move {
                    let result = func(receiver, $($arg,)*).await;
                    output_try_into_non_local_raw_vc(result).await
                }))
            }
        }
    };
}
336
// Generate the task-function implementations for callables taking 0 through 12 arguments (not
// counting the receiver). Each invocation also defines its own `AsyncFnN` helper trait.
task_fn_impl! { AsyncFn0 0 }
task_fn_impl! { AsyncFn1 1 A1 }
task_fn_impl! { AsyncFn2 2 A1 A2 }
task_fn_impl! { AsyncFn3 3 A1 A2 A3 }
task_fn_impl! { AsyncFn4 4 A1 A2 A3 A4 }
task_fn_impl! { AsyncFn5 5 A1 A2 A3 A4 A5 }
task_fn_impl! { AsyncFn6 6 A1 A2 A3 A4 A5 A6 }
task_fn_impl! { AsyncFn7 7 A1 A2 A3 A4 A5 A6 A7 }
task_fn_impl! { AsyncFn8 8 A1 A2 A3 A4 A5 A6 A7 A8 }
task_fn_impl! { AsyncFn9 9 A1 A2 A3 A4 A5 A6 A7 A8 A9 }
task_fn_impl! { AsyncFn10 10 A1 A2 A3 A4 A5 A6 A7 A8 A9 A10 }
task_fn_impl! { AsyncFn11 11 A1 A2 A3 A4 A5 A6 A7 A8 A9 A10 A11 }
task_fn_impl! { AsyncFn12 12 A1 A2 A3 A4 A5 A6 A7 A8 A9 A10 A11 A12 }
350
// There needs to be one more implementation than task_fn_impl to account for
// the receiver. This implements `TaskInputs` for tuples of 0 through 13 elements.
task_inputs_impl! {}
task_inputs_impl! { A1 }
task_inputs_impl! { A1 A2 }
task_inputs_impl! { A1 A2 A3 }
task_inputs_impl! { A1 A2 A3 A4 }
task_inputs_impl! { A1 A2 A3 A4 A5 }
task_inputs_impl! { A1 A2 A3 A4 A5 A6 }
task_inputs_impl! { A1 A2 A3 A4 A5 A6 A7 }
task_inputs_impl! { A1 A2 A3 A4 A5 A6 A7 A8 }
task_inputs_impl! { A1 A2 A3 A4 A5 A6 A7 A8 A9 }
task_inputs_impl! { A1 A2 A3 A4 A5 A6 A7 A8 A9 A10 }
task_inputs_impl! { A1 A2 A3 A4 A5 A6 A7 A8 A9 A10 A11 }
task_inputs_impl! { A1 A2 A3 A4 A5 A6 A7 A8 A9 A10 A11 A12 }
task_inputs_impl! { A1 A2 A3 A4 A5 A6 A7 A8 A9 A10 A11 A12 A13 }
367
#[cfg(test)]
mod tests {
    use turbo_rcstr::RcStr;

    use super::*;
    use crate::{ShrinkToFit, VcCellNewMode, VcDefaultRead};

    /// Type-level test: each function shape below must be accepted by `into_task_fn` /
    /// `into_task_fn_with_this` and yield something that implements `TaskFn`. None of the
    /// functions is ever called, so their `todo!()` bodies are never reached.
    #[test]
    fn test_task_fn() {
        fn no_args() -> Vc<i32> {
            todo!()
        }

        fn one_arg(_a: i32) -> Vc<i32> {
            todo!()
        }

        async fn async_one_arg(_a: i32) -> Vc<i32> {
            todo!()
        }

        fn with_recv(_a: &i32) -> Vc<i32> {
            todo!()
        }

        async fn async_with_recv(_a: &i32) -> Vc<i32> {
            todo!()
        }

        fn with_recv_and_str(_a: &i32, _s: RcStr) -> Vc<i32> {
            todo!()
        }

        async fn async_with_recv_and_str(_a: &i32, _s: RcStr) -> Vc<i32> {
            todo!()
        }

        async fn async_with_recv_and_str_and_result(_a: &i32, _s: RcStr) -> Result<Vc<i32>> {
            todo!()
        }

        // Only the trait bound matters; the value itself is discarded.
        fn accepts_task_fn<F: TaskFn>(_task_fn: F) {}

        struct Struct;
        impl Struct {
            async fn inherent_method(&self) {}
        }

        impl ShrinkToFit for Struct {
            fn shrink_to_fit(&mut self) {}
        }

        unsafe impl VcValueType for Struct {
            type Read = VcDefaultRead<Struct>;

            type CellMode = VcCellNewMode<Struct>;

            fn get_value_type_id() -> crate::ValueTypeId {
                todo!()
            }

            fn has_serialization() -> bool {
                false
            }
        }

        trait AsyncTrait {
            async fn async_method(&self);
        }

        impl AsyncTrait for Struct {
            async fn async_method(&self) {
                todo!()
            }
        }

        // Shapes that are currently NOT supported (see the module documentation).
        /*
        async fn async_with_recv_and_str_and_lf(
            _a: &i32,
            _s: String,
        ) -> Result<crate::Vc<i32>, crate::Vc<i32>> {
            todo!()
        }

        #[async_trait::async_trait]
        trait BoxAsyncTrait {
            async fn box_async_method(&self);
        }

        #[async_trait::async_trait]
        impl BoxAsyncTrait for Struct {
            async fn box_async_method(&self) {
                todo!()
            }
        }
        */

        // Free functions.
        accepts_task_fn(into_task_fn(no_args));
        accepts_task_fn(into_task_fn(one_arg));
        accepts_task_fn(into_task_fn(async_one_arg));

        // Functions and methods taking a receiver by reference.
        accepts_task_fn(into_task_fn_with_this(with_recv));
        accepts_task_fn(into_task_fn_with_this(async_with_recv));
        accepts_task_fn(into_task_fn_with_this(with_recv_and_str));
        accepts_task_fn(into_task_fn_with_this(async_with_recv_and_str));
        accepts_task_fn(into_task_fn_with_this(async_with_recv_and_str_and_result));
        accepts_task_fn(into_task_fn_with_this(<Struct as AsyncTrait>::async_method));
        accepts_task_fn(into_task_fn_with_this(Struct::inherent_method));

        /*
        let task_fn = <Struct as BoxAsyncTrait>::box_async_method.into_task_fn();
        accepts_task_fn(task_fn);
        let task_fn = async_with_recv_and_str_and_lf.into_task_fn();
        accepts_task_fn(task_fn);
        */
    }
}