1use std::{future::Future, marker::PhantomData, pin::Pin};
25
26use anyhow::Result;
27
28use super::{TaskInput, TaskOutput};
29use crate::{RawVc, Vc, VcRead, VcValueType, dyn_task_inputs::DynTaskInputs};
30
31pub type NativeTaskFuture = Pin<Box<dyn Future<Output = Result<RawVc>> + Send>>;
32
33pub trait TaskFn: Send + Sync + 'static {
34 fn functor(&self, this: Option<RawVc>, arg: &dyn DynTaskInputs) -> Result<NativeTaskFuture>;
35}
36
37pub trait TaskFnInputs: TaskFn {
39 type INPUTS: TaskInput + TaskInputs;
40}
41
42pub const fn into_task_fn<
43 Mode: TaskFnMode,
44 Inputs: TaskInputs,
45 F: TaskFnInputFunction<Mode, Inputs>,
46>(
47 f: F,
48) -> FunctionTaskFn<F, Mode, Inputs> {
49 FunctionTaskFn {
50 task_fn: f,
51 mode: PhantomData,
52 inputs: PhantomData,
53 }
54}
55
56pub const fn into_task_fn_with_this<
57 Mode: TaskFnMode,
58 This: Send + Sync + 'static,
59 Inputs: TaskInputs,
60 F: TaskFnInputFunctionWithThis<Mode, This, Inputs>,
61>(
62 f: F,
63) -> FunctionTaskFnWithThis<F, Mode, This, Inputs> {
64 FunctionTaskFnWithThis {
65 task_fn: f,
66 mode: PhantomData,
67 this: PhantomData,
68 inputs: PhantomData,
69 }
70}
71
/// Adapter that exposes a receiver-less function `F` as a [`TaskFn`].
///
/// `Mode` and `Inputs` are compile-time markers only: they select which
/// `TaskFnInputFunction` impl is used to call `task_fn` and occupy no space.
///
/// The struct deliberately declares no trait bounds of its own. They are
/// enforced where they matter: by [`into_task_fn`] and by the `TaskFn` /
/// `TaskFnInputs` impls.
pub struct FunctionTaskFn<F, Mode, Inputs> {
    /// The wrapped function.
    task_fn: F,
    /// Calling-convention marker.
    mode: PhantomData<Mode>,
    /// Input-tuple marker.
    inputs: PhantomData<Inputs>,
}
77
78impl<F, Mode, Inputs> TaskFn for FunctionTaskFn<F, Mode, Inputs>
79where
80 F: TaskFnInputFunction<Mode, Inputs>,
81 Mode: TaskFnMode,
82 Inputs: TaskInputs,
83{
84 fn functor(&self, _this: Option<RawVc>, arg: &dyn DynTaskInputs) -> Result<NativeTaskFuture> {
85 TaskFnInputFunction::functor(self.task_fn, arg)
86 }
87}
88
89impl<F, Mode, Inputs> TaskFnInputs for FunctionTaskFn<F, Mode, Inputs>
90where
91 F: TaskFnInputFunction<Mode, Inputs>,
92 Mode: TaskFnMode,
93 Inputs: TaskInputs + TaskInput,
94{
95 type INPUTS = Inputs;
96}
97
/// Adapter that exposes a function `F` taking a receiver as a [`TaskFn`].
///
/// `Mode`, `This` and `Inputs` are compile-time markers only: they select
/// which `TaskFnInputFunctionWithThis` impl is used to call `task_fn` and
/// occupy no space.
///
/// The struct deliberately declares no trait bounds of its own. They are
/// enforced where they matter: by [`into_task_fn_with_this`] and by the
/// `TaskFn` / `TaskFnInputs` impls.
pub struct FunctionTaskFnWithThis<F, Mode, This, Inputs> {
    /// The wrapped function.
    task_fn: F,
    /// Calling-convention marker.
    mode: PhantomData<Mode>,
    /// Receiver-type marker.
    this: PhantomData<This>,
    /// Input-tuple marker.
    inputs: PhantomData<Inputs>,
}
109
110impl<F, Mode, This, Inputs> TaskFn for FunctionTaskFnWithThis<F, Mode, This, Inputs>
111where
112 F: TaskFnInputFunctionWithThis<Mode, This, Inputs>,
113 Mode: TaskFnMode,
114 This: Sync + Send + 'static,
115 Inputs: TaskInputs,
116{
117 fn functor(&self, this: Option<RawVc>, arg: &dyn DynTaskInputs) -> Result<NativeTaskFuture> {
118 let Some(this) = this else {
119 panic!("Method needs a `self` argument");
120 };
121 TaskFnInputFunctionWithThis::functor(self.task_fn, this, arg)
122 }
123}
124
125impl<F, Mode, This, Inputs> TaskFnInputs for FunctionTaskFnWithThis<F, Mode, This, Inputs>
126where
127 F: TaskFnInputFunctionWithThis<Mode, This, Inputs>,
128 Mode: TaskFnMode,
129 This: Sync + Send + 'static,
130 Inputs: TaskInputs + TaskInput,
131{
132 type INPUTS = Inputs;
133}
134
135#[doc(hidden)]
136pub trait TaskFnInputFunction<Mode: TaskFnMode, Inputs: TaskInputs>:
137 Send + Sync + Copy + 'static
138{
139 fn functor(self, arg: &dyn DynTaskInputs) -> Result<NativeTaskFuture>;
140}
141
142#[doc(hidden)]
143pub trait TaskFnInputFunctionWithThis<
144 Mode: TaskFnMode,
145 This: Sync + Send + 'static,
146 Inputs: TaskInputs,
147>: Send + Sync + Copy + 'static
148{
149 fn functor(self, this: RawVc, arg: &dyn DynTaskInputs) -> Result<NativeTaskFuture>;
150}
151
/// Marker for types that can serve as the complete input list of a task
/// function.
///
/// Implemented for tuples of `TaskInput` values by `task_inputs_impl!`.
pub trait TaskInputs
where
    Self: Send + Sync + 'static,
{
}
153
/// Marker trait for the calling convention of a task function.
///
/// The mode is a type-level tag only; it selects which blanket impl generated
/// by `task_fn_impl!` applies to a given function.
pub trait TaskFnMode
where
    Self: Send + Sync + 'static,
{
}

/// Synchronous function returning its output directly. When used with a
/// receiver, the receiver is passed as a `Vc`.
pub struct FunctionMode;

/// Function returning a future that is awaited to obtain the output. When
/// used with a receiver, the receiver is passed as a `Vc`.
pub struct AsyncFunctionMode;

/// Synchronous method whose receiver is passed by shared reference after the
/// receiver cell has been read.
pub struct MethodMode;

/// Asynchronous method whose receiver is passed by shared reference after the
/// receiver cell has been read.
pub struct AsyncMethodMode;

impl TaskFnMode for FunctionMode {}
impl TaskFnMode for AsyncFunctionMode {}
impl TaskFnMode for MethodMode {}
impl TaskFnMode for AsyncMethodMode {}
172
// Implements the `TaskInputs` marker for the tuple built from the given type
// parameter names, requiring every element to be `TaskInput + 'static`.
// Invoked with no names it covers the unit tuple `()`.
macro_rules! task_inputs_impl {
    ( $( $arg:ident )* ) => {
        impl<$($arg: TaskInput + 'static,)*> TaskInputs for ($($arg,)*) {}
    };
}
181
182fn get_args<T: DynTaskInputs + Clone>(arg: &dyn DynTaskInputs) -> Result<T> {
188 let value = (arg as &dyn std::any::Any).downcast_ref::<T>().cloned();
189 #[cfg(debug_assertions)]
190 return anyhow::Context::with_context(value, || {
191 crate::native_function::debug_downcast_args_error_msg(
192 std::any::type_name::<T>(),
193 arg.dyn_type_name(),
194 )
195 });
196 #[cfg(not(debug_assertions))]
197 return anyhow::Context::context(value, "Invalid argument type");
198}
199
200macro_rules! task_fn_impl {
201 ( $async_fn_trait:ident $arg_len:literal $( $arg:ident )* ) => {
202 impl<F, Output, $($arg,)*> TaskFnInputFunction<FunctionMode, ($($arg,)*)> for F
203 where
204 $($arg: TaskInput + 'static,)*
205 F: Fn($($arg,)*) -> Output + Send + Sync + Copy + 'static,
206 Output: TaskOutput + 'static,
207 {
208 #[allow(non_snake_case)]
209 fn functor(self, arg: &dyn DynTaskInputs) -> Result<NativeTaskFuture> {
210 let ($($arg,)*) = get_args::<($($arg,)*)>(arg)?;
211 Ok(Box::pin(async move {
212 (self)($($arg,)*).try_into_raw_vc()
213 }))
214 }
215 }
216
217 impl<F, Output, FutureOutput, $($arg,)*> TaskFnInputFunction<AsyncFunctionMode, ($($arg,)*)> for F
218 where
219 $($arg: TaskInput + 'static,)*
220 F: Fn($($arg,)*) -> FutureOutput + Send + Sync + Copy + 'static,
221 FutureOutput: Future<Output = Output> + Send + 'static,
222 Output: TaskOutput + 'static,
223 {
224 #[allow(non_snake_case)]
225 fn functor(self, arg: &dyn DynTaskInputs) -> Result<NativeTaskFuture> {
226 let ($($arg,)*) = get_args::<($($arg,)*)>(arg)?;
227 Ok(Box::pin(async move {
228 (self)($($arg,)*).await.try_into_raw_vc()
229 }))
230 }
231 }
232
233 impl<F, Output, Recv, $($arg,)*> TaskFnInputFunctionWithThis<MethodMode, Recv, ($($arg,)*)> for F
234 where
235 Recv: VcValueType,
236 $($arg: TaskInput + 'static,)*
237 F: Fn(&Recv, $($arg,)*) -> Output + Send + Sync + Copy + 'static,
238 Output: TaskOutput + 'static,
239 {
240 #[allow(non_snake_case)]
241 fn functor(self, this: RawVc, arg: &dyn DynTaskInputs) -> Result<NativeTaskFuture> {
242 let recv = Vc::<Recv>::from(this);
243 let ($($arg,)*) = get_args::<($($arg,)*)>(arg)?;
244 Ok(Box::pin(async move {
245 let recv = recv.await?;
246 let recv = <Recv::Read as VcRead<Recv>>::target_to_value_ref(&*recv);
247 (self)(recv, $($arg,)*).try_into_raw_vc()
248 }))
249 }
250 }
251
252 impl<F, Output, Recv, $($arg,)*> TaskFnInputFunctionWithThis<FunctionMode, Recv, ($($arg,)*)> for F
253 where
254 Recv: Sync + Send + 'static,
255 $($arg: TaskInput + 'static,)*
256 F: Fn(Vc<Recv>, $($arg,)*) -> Output + Send + Sync + Copy + 'static,
257 Output: TaskOutput + 'static,
258 {
259 #[allow(non_snake_case)]
260 fn functor(self, this: RawVc, arg: &dyn DynTaskInputs) -> Result<NativeTaskFuture> {
261 let recv = Vc::<Recv>::from(this);
262 let ($($arg,)*) = get_args::<($($arg,)*)>(arg)?;
263 Ok(Box::pin(async move {
264 (self)(recv, $($arg,)*).try_into_raw_vc()
265 }))
266 }
267 }
268
269 pub trait $async_fn_trait<A0, $($arg,)*>: Fn(A0, $($arg,)*) -> Self::OutputFuture {
270 type OutputFuture: Future<Output = <Self as $async_fn_trait<A0, $($arg,)*>>::Output> + Send;
271 type Output: TaskOutput;
272 }
273
274 impl<F: ?Sized, Fut, A0, $($arg,)*> $async_fn_trait<A0, $($arg,)*> for F
275 where
276 F: Fn(A0, $($arg,)*) -> Fut,
277 Fut: Future + Send,
278 Fut::Output: TaskOutput + 'static
279 {
280 type OutputFuture = Fut;
281 type Output = Fut::Output;
282 }
283
284 impl<F, Recv, $($arg,)*> TaskFnInputFunctionWithThis<AsyncMethodMode, Recv, ($($arg,)*)> for F
285 where
286 Recv: VcValueType,
287 $($arg: TaskInput + 'static,)*
288 F: for<'a> $async_fn_trait<&'a Recv, $($arg,)*> + Copy + Send + Sync + 'static,
289 {
290 #[allow(non_snake_case)]
291 fn functor(self, this: RawVc, arg: &dyn DynTaskInputs) -> Result<NativeTaskFuture> {
292 let recv = Vc::<Recv>::from(this);
293 let ($($arg,)*) = get_args::<($($arg,)*)>(arg)?;
294 Ok(Box::pin(async move {
295 let recv = recv.await?;
296 let recv = <Recv::Read as VcRead<Recv>>::target_to_value_ref(&*recv);
297 (self)(recv, $($arg,)*).await.try_into_raw_vc()
298 }))
299 }
300 }
301
302 impl<F, Recv, $($arg,)*> TaskFnInputFunctionWithThis<AsyncFunctionMode, Recv, ($($arg,)*)> for F
303 where
304 Recv: Sync + Send + 'static,
305 $($arg: TaskInput + 'static,)*
306 F: $async_fn_trait<Vc<Recv>, $($arg,)*> + Copy + Send + Sync + 'static,
307 {
308 #[allow(non_snake_case)]
309 fn functor(self, this: RawVc, arg: &dyn DynTaskInputs) -> Result<NativeTaskFuture> {
310 let recv = Vc::<Recv>::from(this);
311 let ($($arg,)*) = get_args::<($($arg,)*)>(arg)?;
312 Ok(Box::pin(async move {
313 (self)(recv, $($arg,)*).await.try_into_raw_vc()
314 }))
315 }
316 }
317 };
318}
319
320task_fn_impl! { AsyncFn0 0 }
321task_fn_impl! { AsyncFn1 1 A1 }
322task_fn_impl! { AsyncFn2 2 A1 A2 }
323task_fn_impl! { AsyncFn3 3 A1 A2 A3 }
324task_fn_impl! { AsyncFn4 4 A1 A2 A3 A4 }
325task_fn_impl! { AsyncFn5 5 A1 A2 A3 A4 A5 }
326task_fn_impl! { AsyncFn6 6 A1 A2 A3 A4 A5 A6 }
327task_fn_impl! { AsyncFn7 7 A1 A2 A3 A4 A5 A6 A7 }
328task_fn_impl! { AsyncFn8 8 A1 A2 A3 A4 A5 A6 A7 A8 }
329task_fn_impl! { AsyncFn9 9 A1 A2 A3 A4 A5 A6 A7 A8 A9 }
330task_fn_impl! { AsyncFn10 10 A1 A2 A3 A4 A5 A6 A7 A8 A9 A10 }
331task_fn_impl! { AsyncFn11 11 A1 A2 A3 A4 A5 A6 A7 A8 A9 A10 A11 }
332task_fn_impl! { AsyncFn12 12 A1 A2 A3 A4 A5 A6 A7 A8 A9 A10 A11 A12 }
333
334task_inputs_impl! {}
337task_inputs_impl! { A1 }
338task_inputs_impl! { A1 A2 }
339task_inputs_impl! { A1 A2 A3 }
340task_inputs_impl! { A1 A2 A3 A4 }
341task_inputs_impl! { A1 A2 A3 A4 A5 }
342task_inputs_impl! { A1 A2 A3 A4 A5 A6 }
343task_inputs_impl! { A1 A2 A3 A4 A5 A6 A7 }
344task_inputs_impl! { A1 A2 A3 A4 A5 A6 A7 A8 }
345task_inputs_impl! { A1 A2 A3 A4 A5 A6 A7 A8 A9 }
346task_inputs_impl! { A1 A2 A3 A4 A5 A6 A7 A8 A9 A10 }
347task_inputs_impl! { A1 A2 A3 A4 A5 A6 A7 A8 A9 A10 A11 }
348task_inputs_impl! { A1 A2 A3 A4 A5 A6 A7 A8 A9 A10 A11 A12 }
349task_inputs_impl! { A1 A2 A3 A4 A5 A6 A7 A8 A9 A10 A11 A12 A13 }
350
#[cfg(test)]
mod tests {
    use turbo_rcstr::RcStr;

    use super::*;
    use crate::{ShrinkToFit, VcCellNewMode, VcDefaultRead};

    /// Type-level test: every supported function shape must be accepted by
    /// `into_task_fn` / `into_task_fn_with_this` and yield a [`TaskFn`].
    ///
    /// None of the fixture functions is ever called, so their `todo!()`
    /// bodies are never reached.
    #[test]
    fn test_task_fn() {
        // Receiver type used for the method-style fixtures.
        struct Struct;

        impl Struct {
            async fn inherent_method(&self) {}
        }

        impl ShrinkToFit for Struct {
            fn shrink_to_fit(&mut self) {}
        }

        // NOTE(review): this impl exists only so `Struct` type-checks as a
        // receiver; no value is ever stored or read through it in this test.
        unsafe impl VcValueType for Struct {
            type Read = VcDefaultRead<Struct>;

            type CellMode = VcCellNewMode<Struct>;

            fn get_value_type_id() -> crate::ValueTypeId {
                todo!()
            }

            fn has_serialization() -> bool {
                false
            }
        }

        trait AsyncTrait {
            async fn async_method(&self);
        }

        impl AsyncTrait for Struct {
            async fn async_method(&self) {
                todo!()
            }
        }

        // Free functions without a receiver.
        fn no_args() -> crate::Vc<i32> {
            todo!()
        }

        fn one_arg(_a: i32) -> crate::Vc<i32> {
            todo!()
        }

        async fn async_one_arg(_a: i32) -> crate::Vc<i32> {
            todo!()
        }

        // Functions taking a borrowed receiver.
        fn with_recv(_a: &i32) -> crate::Vc<i32> {
            todo!()
        }

        async fn async_with_recv(_a: &i32) -> crate::Vc<i32> {
            todo!()
        }

        fn with_recv_and_str(_a: &i32, _s: RcStr) -> crate::Vc<i32> {
            todo!()
        }

        async fn async_with_recv_and_str(_a: &i32, _s: RcStr) -> crate::Vc<i32> {
            todo!()
        }

        async fn async_with_recv_and_str_and_result(_a: &i32, _s: RcStr) -> Result<crate::Vc<i32>> {
            todo!()
        }

        // Compiles only when the argument implements `TaskFn`.
        fn accepts_task_fn<F>(_task_fn: F)
        where
            F: TaskFn,
        {
        }

        accepts_task_fn(into_task_fn(no_args));
        accepts_task_fn(into_task_fn(one_arg));
        accepts_task_fn(into_task_fn(async_one_arg));

        accepts_task_fn(into_task_fn_with_this(with_recv));
        accepts_task_fn(into_task_fn_with_this(async_with_recv));
        accepts_task_fn(into_task_fn_with_this(with_recv_and_str));
        accepts_task_fn(into_task_fn_with_this(async_with_recv_and_str));
        accepts_task_fn(into_task_fn_with_this(async_with_recv_and_str_and_result));

        accepts_task_fn(into_task_fn_with_this(<Struct as AsyncTrait>::async_method));
        accepts_task_fn(into_task_fn_with_this(Struct::inherent_method));
    }
}