turbo_tasks/task/
task_input.rs

1use std::{any::Any, fmt::Debug, future::Future, hash::Hash, sync::Arc, time::Duration};
2
3use anyhow::Result;
4use either::Either;
5use serde::{Deserialize, Serialize};
6use turbo_rcstr::RcStr;
7
8use crate::{
9    MagicAny, ResolvedVc, TaskId, TransientInstance, TransientValue, Value, ValueTypeId, Vc,
10    trace::TraceRawVcs,
11};
12
13/// Trait to implement in order for a type to be accepted as a
14/// [`#[turbo_tasks::function]`][crate::function] argument.
15pub trait TaskInput: Send + Sync + Clone + Debug + PartialEq + Eq + Hash + TraceRawVcs {
16    fn resolve_input(&self) -> impl Future<Output = Result<Self>> + Send + '_ {
17        async { Ok(self.clone()) }
18    }
19    fn is_resolved(&self) -> bool {
20        true
21    }
22    fn is_transient(&self) -> bool;
23}
24
/// Implements [`TaskInput`] for each listed type as a non-transient input.
///
/// Only `is_transient` is generated (always `false`). `is_resolved` and `resolve_input` come
/// from the trait defaults.
macro_rules! impl_task_input {
    ($($ty:ty),*) => {
        $(
            impl TaskInput for $ty {
                fn is_transient(&self) -> bool {
                    false
                }
            }
        )*
    };
}
36
37impl_task_input! {
38    (),
39    bool,
40    u8,
41    u16,
42    u32,
43    i32,
44    u64,
45    usize,
46    RcStr,
47    TaskId,
48    ValueTypeId,
49    Duration
50}
51
52impl<T> TaskInput for Vec<T>
53where
54    T: TaskInput,
55{
56    fn is_resolved(&self) -> bool {
57        self.iter().all(TaskInput::is_resolved)
58    }
59
60    fn is_transient(&self) -> bool {
61        self.iter().any(TaskInput::is_transient)
62    }
63
64    async fn resolve_input(&self) -> Result<Self> {
65        let mut resolved = Vec::with_capacity(self.len());
66        for value in self {
67            resolved.push(value.resolve_input().await?);
68        }
69        Ok(resolved)
70    }
71}
72
73impl<T> TaskInput for Box<T>
74where
75    T: TaskInput,
76{
77    fn is_resolved(&self) -> bool {
78        self.as_ref().is_resolved()
79    }
80
81    fn is_transient(&self) -> bool {
82        self.as_ref().is_transient()
83    }
84
85    async fn resolve_input(&self) -> Result<Self> {
86        Ok(Box::new(Box::pin(self.as_ref().resolve_input()).await?))
87    }
88}
89
90impl<T> TaskInput for Arc<T>
91where
92    T: TaskInput,
93{
94    fn is_resolved(&self) -> bool {
95        self.as_ref().is_resolved()
96    }
97
98    fn is_transient(&self) -> bool {
99        self.as_ref().is_transient()
100    }
101
102    async fn resolve_input(&self) -> Result<Self> {
103        Ok(Arc::new(Box::pin(self.as_ref().resolve_input()).await?))
104    }
105}
106
107impl<T> TaskInput for Option<T>
108where
109    T: TaskInput,
110{
111    fn is_resolved(&self) -> bool {
112        match self {
113            Some(value) => value.is_resolved(),
114            None => true,
115        }
116    }
117
118    fn is_transient(&self) -> bool {
119        match self {
120            Some(value) => value.is_transient(),
121            None => false,
122        }
123    }
124
125    async fn resolve_input(&self) -> Result<Self> {
126        match self {
127            Some(value) => Ok(Some(value.resolve_input().await?)),
128            None => Ok(None),
129        }
130    }
131}
132
133impl<T> TaskInput for Vc<T>
134where
135    T: Send + Sync + ?Sized,
136{
137    fn is_resolved(&self) -> bool {
138        Vc::is_resolved(*self)
139    }
140
141    fn is_transient(&self) -> bool {
142        self.node.is_transient()
143    }
144
145    async fn resolve_input(&self) -> Result<Self> {
146        Vc::resolve(*self).await
147    }
148}
149
150// `TaskInput` isn't needed/used for a bare `ResolvedVc`, as we'll expose `ResolvedVc` arguments as
151// `Vc`, but it is useful for structs that contain `ResolvedVc` and want to derive `TaskInput`.
152impl<T> TaskInput for ResolvedVc<T>
153where
154    T: Send + Sync + ?Sized,
155{
156    fn is_resolved(&self) -> bool {
157        true
158    }
159
160    fn is_transient(&self) -> bool {
161        self.node.is_transient()
162    }
163
164    async fn resolve_input(&self) -> Result<Self> {
165        Ok(*self)
166    }
167}
168
169impl<T> TaskInput for Value<T>
170where
171    T: Any
172        + std::fmt::Debug
173        + Clone
174        + std::hash::Hash
175        + Eq
176        + Send
177        + Sync
178        + Serialize
179        + for<'de> Deserialize<'de>
180        + TraceRawVcs
181        + 'static,
182{
183    fn is_resolved(&self) -> bool {
184        true
185    }
186
187    fn is_transient(&self) -> bool {
188        false
189    }
190}
191
192impl<T> TaskInput for TransientValue<T>
193where
194    T: MagicAny + Clone + Debug + Hash + Eq + TraceRawVcs + 'static,
195{
196    fn is_transient(&self) -> bool {
197        true
198    }
199}
200
201impl<T> Serialize for TransientValue<T>
202where
203    T: MagicAny + Clone + 'static,
204{
205    fn serialize<S>(&self, _serializer: S) -> Result<S::Ok, S::Error>
206    where
207        S: serde::Serializer,
208    {
209        Err(serde::ser::Error::custom(
210            "cannot serialize transient task inputs",
211        ))
212    }
213}
214
215impl<'de, T> Deserialize<'de> for TransientValue<T>
216where
217    T: MagicAny + Clone + 'static,
218{
219    fn deserialize<D>(_deserializer: D) -> Result<Self, D::Error>
220    where
221        D: serde::Deserializer<'de>,
222    {
223        Err(serde::de::Error::custom(
224            "cannot deserialize transient task inputs",
225        ))
226    }
227}
228
229impl<T> TaskInput for TransientInstance<T>
230where
231    T: Sync + Send + TraceRawVcs + 'static,
232{
233    fn is_transient(&self) -> bool {
234        true
235    }
236}
237
238impl<T> Serialize for TransientInstance<T> {
239    fn serialize<S>(&self, _serializer: S) -> Result<S::Ok, S::Error>
240    where
241        S: serde::Serializer,
242    {
243        Err(serde::ser::Error::custom(
244            "cannot serialize transient task inputs",
245        ))
246    }
247}
248
249impl<'de, T> Deserialize<'de> for TransientInstance<T> {
250    fn deserialize<D>(_deserializer: D) -> Result<Self, D::Error>
251    where
252        D: serde::Deserializer<'de>,
253    {
254        Err(serde::de::Error::custom(
255            "cannot deserialize transient task inputs",
256        ))
257    }
258}
259
260impl<L, R> TaskInput for Either<L, R>
261where
262    L: TaskInput,
263    R: TaskInput,
264{
265    fn resolve_input(&self) -> impl Future<Output = Result<Self>> + Send + '_ {
266        self.as_ref().map_either(
267            |l| async move { anyhow::Ok(Either::Left(l.resolve_input().await?)) },
268            |r| async move { anyhow::Ok(Either::Right(r.resolve_input().await?)) },
269        )
270    }
271
272    fn is_resolved(&self) -> bool {
273        self.as_ref()
274            .either(TaskInput::is_resolved, TaskInput::is_resolved)
275    }
276
277    fn is_transient(&self) -> bool {
278        self.as_ref()
279            .either(TaskInput::is_transient, TaskInput::is_transient)
280    }
281}
282
283macro_rules! tuple_impls {
284    ( $( $name:ident )+ ) => {
285        impl<$($name: TaskInput),+> TaskInput for ($($name,)+)
286        where $($name: TaskInput),+
287        {
288            #[allow(non_snake_case)]
289            fn is_resolved(&self) -> bool {
290                let ($($name,)+) = self;
291                $($name.is_resolved() &&)+ true
292            }
293
294            #[allow(non_snake_case)]
295            fn is_transient(&self) -> bool {
296                let ($($name,)+) = self;
297                $($name.is_transient() ||)+ false
298            }
299
300            #[allow(non_snake_case)]
301            async fn resolve_input(&self) -> Result<Self> {
302                let ($($name,)+) = self;
303                Ok(($($name.resolve_input().await?,)+))
304            }
305        }
306    };
307}
308
309// Implement `TaskInput` for all tuples of 1 to 12 elements.
310tuple_impls! { A }
311tuple_impls! { A B }
312tuple_impls! { A B C }
313tuple_impls! { A B C D }
314tuple_impls! { A B C D E }
315tuple_impls! { A B C D E F }
316tuple_impls! { A B C D E F G }
317tuple_impls! { A B C D E F G H }
318tuple_impls! { A B C D E F G H I }
319tuple_impls! { A B C D E F G H I J }
320tuple_impls! { A B C D E F G H I J K }
321tuple_impls! { A B C D E F G H I J K L }
322
#[cfg(test)]
mod tests {
    use turbo_tasks_macros::TaskInput;

    use super::*;
    // The derive's expansion names the crate `turbo_tasks`, so alias this crate under that name
    // for the expansion to resolve from inside the crate.
    use crate as turbo_tasks;

    /// Compiles only for types implementing `TaskInput`; the argument is simply dropped.
    fn assert_task_input<T: TaskInput>(_input: T) {}

    #[test]
    fn test_no_fields() {
        #[derive(
            Clone, TaskInput, Eq, PartialEq, Hash, Debug, Serialize, Deserialize, TraceRawVcs,
        )]
        struct NoFields;

        assert_task_input(NoFields);
    }

    #[test]
    fn test_one_unnamed_field() {
        #[derive(
            Clone, TaskInput, Eq, PartialEq, Hash, Debug, Serialize, Deserialize, TraceRawVcs,
        )]
        struct OneUnnamedField(u32);

        assert_task_input(OneUnnamedField(42));
    }

    #[test]
    fn test_multiple_unnamed_fields() {
        #[derive(
            Clone, TaskInput, Eq, PartialEq, Hash, Debug, Serialize, Deserialize, TraceRawVcs,
        )]
        struct MultipleUnnamedFields(u32, RcStr);

        let input = MultipleUnnamedFields(42, "42".into());
        assert_task_input(input);
    }

    #[test]
    fn test_one_named_field() {
        #[derive(
            Clone, TaskInput, Eq, PartialEq, Hash, Debug, Serialize, Deserialize, TraceRawVcs,
        )]
        struct OneNamedField {
            named: u32,
        }

        let input = OneNamedField { named: 42 };
        assert_task_input(input);
    }

    #[test]
    fn test_multiple_named_fields() {
        #[derive(
            Clone, TaskInput, Eq, PartialEq, Hash, Debug, Serialize, Deserialize, TraceRawVcs,
        )]
        struct MultipleNamedFields {
            named: u32,
            other: RcStr,
        }

        let input = MultipleNamedFields {
            named: 42,
            other: "42".into(),
        };
        assert_task_input(input);
    }

    #[test]
    fn test_generic_field() {
        #[derive(
            Clone, TaskInput, Eq, PartialEq, Hash, Debug, Serialize, Deserialize, TraceRawVcs,
        )]
        struct GenericField<T>(T);

        // The same generic wrapper must work for both a primitive and a string field.
        assert_task_input(GenericField(42));
        assert_task_input(GenericField(RcStr::from("42")));
    }

    #[derive(Clone, TaskInput, Eq, PartialEq, Hash, Debug, Serialize, Deserialize, TraceRawVcs)]
    enum OneVariant {
        Variant,
    }

    #[test]
    fn test_one_variant() {
        assert_task_input(OneVariant::Variant);
    }

    #[test]
    fn test_multiple_variants() {
        #[derive(
            Clone, TaskInput, PartialEq, Eq, Hash, Debug, Serialize, Deserialize, TraceRawVcs,
        )]
        enum MultipleVariants {
            Variant1,
            Variant2,
        }

        assert_task_input(MultipleVariants::Variant2);
    }

    #[derive(Clone, TaskInput, Eq, PartialEq, Hash, Debug, Serialize, Deserialize, TraceRawVcs)]
    enum MultipleVariantsAndHeterogeneousFields {
        Variant1,
        Variant2(u32),
        Variant3 { named: u32 },
        Variant4(u32, RcStr),
        Variant5 { named: u32, other: RcStr },
    }

    #[test]
    fn test_multiple_variants_and_heterogeneous_fields() {
        let input = MultipleVariantsAndHeterogeneousFields::Variant5 {
            named: 42,
            other: "42".into(),
        };
        assert_task_input(input);
    }

    #[test]
    fn test_nested_variants() {
        #[derive(
            Clone, TaskInput, Eq, PartialEq, Hash, Debug, Serialize, Deserialize, TraceRawVcs,
        )]
        enum NestedVariants {
            Variant1,
            Variant2(MultipleVariantsAndHeterogeneousFields),
            Variant3 { named: OneVariant },
            Variant4(OneVariant, RcStr),
            Variant5 { named: OneVariant, other: RcStr },
        }

        let with_named_fields = NestedVariants::Variant5 {
            named: OneVariant::Variant,
            other: "42".into(),
        };
        assert_task_input(with_named_fields);

        let with_nested_enum =
            NestedVariants::Variant2(MultipleVariantsAndHeterogeneousFields::Variant5 {
                named: 42,
                other: "42".into(),
            });
        assert_task_input(with_nested_enum);
    }
}