turbo_tasks/task/
task_input.rs

1use std::{
2    collections::{BTreeMap, BTreeSet},
3    fmt::Debug,
4    future::Future,
5    hash::Hash,
6    sync::Arc,
7    time::Duration,
8};
9
10use anyhow::Result;
11use either::Either;
12use serde::{Deserialize, Serialize};
13use turbo_rcstr::RcStr;
14
15use crate::{
16    MagicAny, ReadRef, ResolvedVc, TaskId, TransientInstance, TransientValue, ValueTypeId, Vc,
17    trace::TraceRawVcs,
18};
19
20/// Trait to implement in order for a type to be accepted as a
21/// [`#[turbo_tasks::function]`][crate::function] argument.
22///
23/// Serialization must be deterministic and compatible with `eq` comparisons.  If two `TaskInputs`
24/// compare equal they must also serialize to the same bytes.
25pub trait TaskInput: Send + Sync + Clone + Debug + PartialEq + Eq + Hash + TraceRawVcs {
26    fn resolve_input(&self) -> impl Future<Output = Result<Self>> + Send + '_ {
27        async { Ok(self.clone()) }
28    }
29    fn is_resolved(&self) -> bool {
30        true
31    }
32    fn is_transient(&self) -> bool;
33}
34
/// Implements [`TaskInput`] for each listed type as a plain, never-transient value.
///
/// Only `is_transient` is written out; `resolve_input` and `is_resolved` keep the trait defaults
/// (clone `self`, report `true`).
macro_rules! impl_task_input {
    ($($input_ty:ty),*) => {
        $(
            impl TaskInput for $input_ty {
                fn is_transient(&self) -> bool {
                    // Plain values never carry transient state.
                    false
                }
            }
        )*
    };
}
46
47impl_task_input! {
48    (),
49    bool,
50    u8,
51    u16,
52    u32,
53    i32,
54    u64,
55    usize,
56    RcStr,
57    TaskId,
58    ValueTypeId,
59    Duration,
60    String
61}
62
63impl<T> TaskInput for Vec<T>
64where
65    T: TaskInput,
66{
67    fn is_resolved(&self) -> bool {
68        self.iter().all(TaskInput::is_resolved)
69    }
70
71    fn is_transient(&self) -> bool {
72        self.iter().any(TaskInput::is_transient)
73    }
74
75    async fn resolve_input(&self) -> Result<Self> {
76        let mut resolved = Vec::with_capacity(self.len());
77        for value in self {
78            resolved.push(value.resolve_input().await?);
79        }
80        Ok(resolved)
81    }
82}
83
84impl<T> TaskInput for Box<T>
85where
86    T: TaskInput,
87{
88    fn is_resolved(&self) -> bool {
89        self.as_ref().is_resolved()
90    }
91
92    fn is_transient(&self) -> bool {
93        self.as_ref().is_transient()
94    }
95
96    async fn resolve_input(&self) -> Result<Self> {
97        Ok(Box::new(Box::pin(self.as_ref().resolve_input()).await?))
98    }
99}
100
101impl<T> TaskInput for Arc<T>
102where
103    T: TaskInput,
104{
105    fn is_resolved(&self) -> bool {
106        self.as_ref().is_resolved()
107    }
108
109    fn is_transient(&self) -> bool {
110        self.as_ref().is_transient()
111    }
112
113    async fn resolve_input(&self) -> Result<Self> {
114        Ok(Arc::new(Box::pin(self.as_ref().resolve_input()).await?))
115    }
116}
117
118impl<T> TaskInput for ReadRef<T>
119where
120    T: TaskInput,
121{
122    fn is_resolved(&self) -> bool {
123        Self::as_raw_ref(self).is_resolved()
124    }
125
126    fn is_transient(&self) -> bool {
127        Self::as_raw_ref(self).is_transient()
128    }
129
130    async fn resolve_input(&self) -> Result<Self> {
131        Ok(ReadRef::new_owned(
132            Box::pin(Self::as_raw_ref(self).resolve_input()).await?,
133        ))
134    }
135}
136
137impl<T> TaskInput for Option<T>
138where
139    T: TaskInput,
140{
141    fn is_resolved(&self) -> bool {
142        match self {
143            Some(value) => value.is_resolved(),
144            None => true,
145        }
146    }
147
148    fn is_transient(&self) -> bool {
149        match self {
150            Some(value) => value.is_transient(),
151            None => false,
152        }
153    }
154
155    async fn resolve_input(&self) -> Result<Self> {
156        match self {
157            Some(value) => Ok(Some(value.resolve_input().await?)),
158            None => Ok(None),
159        }
160    }
161}
162
163impl<T> TaskInput for Vc<T>
164where
165    T: Send + Sync + ?Sized,
166{
167    fn is_resolved(&self) -> bool {
168        Vc::is_resolved(*self)
169    }
170
171    fn is_transient(&self) -> bool {
172        self.node.is_transient()
173    }
174
175    async fn resolve_input(&self) -> Result<Self> {
176        Vc::resolve(*self).await
177    }
178}
179
180// `TaskInput` isn't needed/used for a bare `ResolvedVc`, as we'll expose `ResolvedVc` arguments as
181// `Vc`, but it is useful for structs that contain `ResolvedVc` and want to derive `TaskInput`.
182impl<T> TaskInput for ResolvedVc<T>
183where
184    T: Send + Sync + ?Sized,
185{
186    fn is_resolved(&self) -> bool {
187        true
188    }
189
190    fn is_transient(&self) -> bool {
191        self.node.is_transient()
192    }
193
194    async fn resolve_input(&self) -> Result<Self> {
195        Ok(*self)
196    }
197}
198
199impl<T> TaskInput for TransientValue<T>
200where
201    T: MagicAny + Clone + Debug + Hash + Eq + TraceRawVcs + 'static,
202{
203    fn is_transient(&self) -> bool {
204        true
205    }
206}
207
208impl<T> Serialize for TransientValue<T>
209where
210    T: MagicAny + Clone + 'static,
211{
212    fn serialize<S>(&self, _serializer: S) -> Result<S::Ok, S::Error>
213    where
214        S: serde::Serializer,
215    {
216        Err(serde::ser::Error::custom(
217            "cannot serialize transient task inputs",
218        ))
219    }
220}
221
222impl<'de, T> Deserialize<'de> for TransientValue<T>
223where
224    T: MagicAny + Clone + 'static,
225{
226    fn deserialize<D>(_deserializer: D) -> Result<Self, D::Error>
227    where
228        D: serde::Deserializer<'de>,
229    {
230        Err(serde::de::Error::custom(
231            "cannot deserialize transient task inputs",
232        ))
233    }
234}
235
236impl<T> TaskInput for TransientInstance<T>
237where
238    T: Sync + Send + TraceRawVcs + 'static,
239{
240    fn is_transient(&self) -> bool {
241        true
242    }
243}
244
245impl<T> Serialize for TransientInstance<T> {
246    fn serialize<S>(&self, _serializer: S) -> Result<S::Ok, S::Error>
247    where
248        S: serde::Serializer,
249    {
250        Err(serde::ser::Error::custom(
251            "cannot serialize transient task inputs",
252        ))
253    }
254}
255
256impl<'de, T> Deserialize<'de> for TransientInstance<T> {
257    fn deserialize<D>(_deserializer: D) -> Result<Self, D::Error>
258    where
259        D: serde::Deserializer<'de>,
260    {
261        Err(serde::de::Error::custom(
262            "cannot deserialize transient task inputs",
263        ))
264    }
265}
266
267impl<L, R> TaskInput for Either<L, R>
268where
269    L: TaskInput,
270    R: TaskInput,
271{
272    fn resolve_input(&self) -> impl Future<Output = Result<Self>> + Send + '_ {
273        self.as_ref().map_either(
274            |l| async move { anyhow::Ok(Either::Left(l.resolve_input().await?)) },
275            |r| async move { anyhow::Ok(Either::Right(r.resolve_input().await?)) },
276        )
277    }
278
279    fn is_resolved(&self) -> bool {
280        self.as_ref()
281            .either(TaskInput::is_resolved, TaskInput::is_resolved)
282    }
283
284    fn is_transient(&self) -> bool {
285        self.as_ref()
286            .either(TaskInput::is_transient, TaskInput::is_transient)
287    }
288}
289
290impl<K, V> TaskInput for BTreeMap<K, V>
291where
292    K: TaskInput + Ord,
293    V: TaskInput,
294{
295    async fn resolve_input(&self) -> Result<Self> {
296        let mut new_map = BTreeMap::new();
297        for (k, v) in self {
298            new_map.insert(
299                TaskInput::resolve_input(k).await?,
300                TaskInput::resolve_input(v).await?,
301            );
302        }
303        Ok(new_map)
304    }
305
306    fn is_resolved(&self) -> bool {
307        self.iter()
308            .all(|(k, v)| TaskInput::is_resolved(k) && TaskInput::is_resolved(v))
309    }
310
311    fn is_transient(&self) -> bool {
312        self.iter()
313            .any(|(k, v)| TaskInput::is_transient(k) || TaskInput::is_transient(v))
314    }
315}
316
317impl<T> TaskInput for BTreeSet<T>
318where
319    T: TaskInput + Ord,
320{
321    async fn resolve_input(&self) -> Result<Self> {
322        let mut new_map = BTreeSet::new();
323        for value in self {
324            new_map.insert(TaskInput::resolve_input(value).await?);
325        }
326        Ok(new_map)
327    }
328
329    fn is_resolved(&self) -> bool {
330        self.iter().all(TaskInput::is_resolved)
331    }
332
333    fn is_transient(&self) -> bool {
334        self.iter().any(TaskInput::is_transient)
335    }
336}
337
338macro_rules! tuple_impls {
339    ( $( $name:ident )+ ) => {
340        impl<$($name: TaskInput),+> TaskInput for ($($name,)+)
341        where $($name: TaskInput),+
342        {
343            #[allow(non_snake_case)]
344            fn is_resolved(&self) -> bool {
345                let ($($name,)+) = self;
346                $($name.is_resolved() &&)+ true
347            }
348
349            #[allow(non_snake_case)]
350            fn is_transient(&self) -> bool {
351                let ($($name,)+) = self;
352                $($name.is_transient() ||)+ false
353            }
354
355            #[allow(non_snake_case)]
356            async fn resolve_input(&self) -> Result<Self> {
357                let ($($name,)+) = self;
358                Ok(($($name.resolve_input().await?,)+))
359            }
360        }
361    };
362}
363
364// Implement `TaskInput` for all tuples of 1 to 12 elements.
365tuple_impls! { A }
366tuple_impls! { A B }
367tuple_impls! { A B C }
368tuple_impls! { A B C D }
369tuple_impls! { A B C D E }
370tuple_impls! { A B C D E F }
371tuple_impls! { A B C D E F G }
372tuple_impls! { A B C D E F G H }
373tuple_impls! { A B C D E F G H I }
374tuple_impls! { A B C D E F G H I J }
375tuple_impls! { A B C D E F G H I J K }
376tuple_impls! { A B C D E F G H I J K L }
377
#[cfg(test)]
mod tests {
    use turbo_rcstr::rcstr;
    use turbo_tasks_macros::TaskInput;

    use super::*;
    // This is necessary for the derive macro to work, as its expansion refers to
    // the crate name directly.
    use crate as turbo_tasks;

    /// Compile-time check that `T` implements `TaskInput`.
    fn assert_task_input<T>(_: T)
    where
        T: TaskInput,
    {
    }

    #[test]
    fn test_no_fields() -> Result<()> {
        #[derive(
            Clone, TaskInput, Eq, PartialEq, Hash, Debug, Serialize, Deserialize, TraceRawVcs,
        )]
        struct NoFields;

        assert_task_input(NoFields);
        Ok(())
    }

    #[test]
    fn test_one_unnamed_field() -> Result<()> {
        #[derive(
            Clone, TaskInput, Eq, PartialEq, Hash, Debug, Serialize, Deserialize, TraceRawVcs,
        )]
        struct OneUnnamedField(u32);

        assert_task_input(OneUnnamedField(42));
        Ok(())
    }

    #[test]
    fn test_multiple_unnamed_fields() -> Result<()> {
        #[derive(
            Clone, TaskInput, Eq, PartialEq, Hash, Debug, Serialize, Deserialize, TraceRawVcs,
        )]
        struct MultipleUnnamedFields(u32, RcStr);

        assert_task_input(MultipleUnnamedFields(42, rcstr!("42")));
        Ok(())
    }

    #[test]
    fn test_one_named_field() -> Result<()> {
        #[derive(
            Clone, TaskInput, Eq, PartialEq, Hash, Debug, Serialize, Deserialize, TraceRawVcs,
        )]
        struct OneNamedField {
            named: u32,
        }

        assert_task_input(OneNamedField { named: 42 });
        Ok(())
    }

    #[test]
    fn test_multiple_named_fields() -> Result<()> {
        #[derive(
            Clone, TaskInput, Eq, PartialEq, Hash, Debug, Serialize, Deserialize, TraceRawVcs,
        )]
        struct MultipleNamedFields {
            named: u32,
            other: RcStr,
        }

        assert_task_input(MultipleNamedFields {
            named: 42,
            other: rcstr!("42"),
        });
        Ok(())
    }

    #[test]
    fn test_generic_field() -> Result<()> {
        #[derive(
            Clone, TaskInput, Eq, PartialEq, Hash, Debug, Serialize, Deserialize, TraceRawVcs,
        )]
        struct GenericField<T>(T);

        assert_task_input(GenericField(42));
        assert_task_input(GenericField(rcstr!("42")));
        Ok(())
    }

    #[derive(Clone, TaskInput, Eq, PartialEq, Hash, Debug, Serialize, Deserialize, TraceRawVcs)]
    enum OneVariant {
        Variant,
    }

    #[test]
    fn test_one_variant() -> Result<()> {
        assert_task_input(OneVariant::Variant);
        Ok(())
    }

    #[test]
    fn test_multiple_variants() -> Result<()> {
        #[derive(
            Clone, TaskInput, PartialEq, Eq, Hash, Debug, Serialize, Deserialize, TraceRawVcs,
        )]
        enum MultipleVariants {
            Variant1,
            Variant2,
        }

        assert_task_input(MultipleVariants::Variant2);
        Ok(())
    }

    #[derive(Clone, TaskInput, Eq, PartialEq, Hash, Debug, Serialize, Deserialize, TraceRawVcs)]
    enum MultipleVariantsAndHeterogeneousFields {
        Variant1,
        Variant2(u32),
        Variant3 { named: u32 },
        Variant4(u32, RcStr),
        Variant5 { named: u32, other: RcStr },
    }

    #[test]
    fn test_multiple_variants_and_heterogeneous_fields() -> Result<()> {
        assert_task_input(MultipleVariantsAndHeterogeneousFields::Variant5 {
            named: 42,
            other: rcstr!("42"),
        });
        Ok(())
    }

    #[test]
    fn test_nested_variants() -> Result<()> {
        #[derive(
            Clone, TaskInput, Eq, PartialEq, Hash, Debug, Serialize, Deserialize, TraceRawVcs,
        )]
        enum NestedVariants {
            Variant1,
            Variant2(MultipleVariantsAndHeterogeneousFields),
            Variant3 { named: OneVariant },
            Variant4(OneVariant, RcStr),
            Variant5 { named: OneVariant, other: RcStr },
        }

        assert_task_input(NestedVariants::Variant5 {
            named: OneVariant::Variant,
            other: rcstr!("42"),
        });
        assert_task_input(NestedVariants::Variant2(
            MultipleVariantsAndHeterogeneousFields::Variant5 {
                named: 42,
                other: rcstr!("42"),
            },
        ));
        Ok(())
    }

    /// Every container / wrapper impl in this module must accept plain leaf inputs.
    #[test]
    fn test_container_impls() -> Result<()> {
        assert_task_input(vec![42u32]);
        assert_task_input(Some(rcstr!("42")));
        assert_task_input(Option::<u32>::None);
        assert_task_input(Box::new(42u32));
        assert_task_input(Arc::new(rcstr!("42")));
        assert_task_input(Either::<u32, RcStr>::Left(42));
        assert_task_input(BTreeMap::from([(42u32, rcstr!("42"))]));
        assert_task_input(BTreeSet::from([42u32]));
        assert_task_input((42u32, rcstr!("42"), Duration::from_secs(1)));
        Ok(())
    }

    /// Containers built purely from leaf values must report resolved and non-transient.
    #[test]
    fn test_plain_containers_are_resolved_and_not_transient() -> Result<()> {
        let value = (
            vec![1u32, 2, 3],
            Some(rcstr!("42")),
            BTreeSet::from([1u64, 2]),
            Either::<u32, RcStr>::Right(rcstr!("42")),
        );
        assert!(value.is_resolved());
        assert!(!value.is_transient());

        let empty: Vec<u32> = Vec::new();
        assert!(empty.is_resolved());
        assert!(!empty.is_transient());
        Ok(())
    }
}