turbo_tasks/task/task_input.rs

1use std::{
2    collections::{BTreeMap, BTreeSet},
3    fmt::Debug,
4    future::Future,
5    hash::Hash,
6    sync::Arc,
7    time::Duration,
8};
9
10use anyhow::Result;
11use either::Either;
12use serde::{Deserialize, Serialize};
13use turbo_rcstr::RcStr;
14
15use crate::{
16    MagicAny, ResolvedVc, TaskId, TransientInstance, TransientValue, ValueTypeId, Vc,
17    trace::TraceRawVcs,
18};
19
20/// Trait to implement in order for a type to be accepted as a
21/// [`#[turbo_tasks::function]`][crate::function] argument.
22pub trait TaskInput: Send + Sync + Clone + Debug + PartialEq + Eq + Hash + TraceRawVcs {
23    fn resolve_input(&self) -> impl Future<Output = Result<Self>> + Send + '_ {
24        async { Ok(self.clone()) }
25    }
26    fn is_resolved(&self) -> bool {
27        true
28    }
29    fn is_transient(&self) -> bool;
30}
31
/// Implements [`TaskInput`] for each listed type as a value that is never transient.
///
/// The generated impls keep the trait defaults for `is_resolved` (always `true`) and
/// `resolve_input` (a clone of the value).
macro_rules! impl_task_input {
    ($($ty:ty),*) => {
        $(
            impl TaskInput for $ty {
                fn is_transient(&self) -> bool {
                    false
                }
            }
        )*
    };
}
43
44impl_task_input! {
45    (),
46    bool,
47    u8,
48    u16,
49    u32,
50    i32,
51    u64,
52    usize,
53    RcStr,
54    TaskId,
55    ValueTypeId,
56    Duration,
57    String
58}
59
60impl<T> TaskInput for Vec<T>
61where
62    T: TaskInput,
63{
64    fn is_resolved(&self) -> bool {
65        self.iter().all(TaskInput::is_resolved)
66    }
67
68    fn is_transient(&self) -> bool {
69        self.iter().any(TaskInput::is_transient)
70    }
71
72    async fn resolve_input(&self) -> Result<Self> {
73        let mut resolved = Vec::with_capacity(self.len());
74        for value in self {
75            resolved.push(value.resolve_input().await?);
76        }
77        Ok(resolved)
78    }
79}
80
81impl<T> TaskInput for Box<T>
82where
83    T: TaskInput,
84{
85    fn is_resolved(&self) -> bool {
86        self.as_ref().is_resolved()
87    }
88
89    fn is_transient(&self) -> bool {
90        self.as_ref().is_transient()
91    }
92
93    async fn resolve_input(&self) -> Result<Self> {
94        Ok(Box::new(Box::pin(self.as_ref().resolve_input()).await?))
95    }
96}
97
98impl<T> TaskInput for Arc<T>
99where
100    T: TaskInput,
101{
102    fn is_resolved(&self) -> bool {
103        self.as_ref().is_resolved()
104    }
105
106    fn is_transient(&self) -> bool {
107        self.as_ref().is_transient()
108    }
109
110    async fn resolve_input(&self) -> Result<Self> {
111        Ok(Arc::new(Box::pin(self.as_ref().resolve_input()).await?))
112    }
113}
114
115impl<T> TaskInput for Option<T>
116where
117    T: TaskInput,
118{
119    fn is_resolved(&self) -> bool {
120        match self {
121            Some(value) => value.is_resolved(),
122            None => true,
123        }
124    }
125
126    fn is_transient(&self) -> bool {
127        match self {
128            Some(value) => value.is_transient(),
129            None => false,
130        }
131    }
132
133    async fn resolve_input(&self) -> Result<Self> {
134        match self {
135            Some(value) => Ok(Some(value.resolve_input().await?)),
136            None => Ok(None),
137        }
138    }
139}
140
141impl<T> TaskInput for Vc<T>
142where
143    T: Send + Sync + ?Sized,
144{
145    fn is_resolved(&self) -> bool {
146        Vc::is_resolved(*self)
147    }
148
149    fn is_transient(&self) -> bool {
150        self.node.is_transient()
151    }
152
153    async fn resolve_input(&self) -> Result<Self> {
154        Vc::resolve(*self).await
155    }
156}
157
158// `TaskInput` isn't needed/used for a bare `ResolvedVc`, as we'll expose `ResolvedVc` arguments as
159// `Vc`, but it is useful for structs that contain `ResolvedVc` and want to derive `TaskInput`.
160impl<T> TaskInput for ResolvedVc<T>
161where
162    T: Send + Sync + ?Sized,
163{
164    fn is_resolved(&self) -> bool {
165        true
166    }
167
168    fn is_transient(&self) -> bool {
169        self.node.is_transient()
170    }
171
172    async fn resolve_input(&self) -> Result<Self> {
173        Ok(*self)
174    }
175}
176
177impl<T> TaskInput for TransientValue<T>
178where
179    T: MagicAny + Clone + Debug + Hash + Eq + TraceRawVcs + 'static,
180{
181    fn is_transient(&self) -> bool {
182        true
183    }
184}
185
186impl<T> Serialize for TransientValue<T>
187where
188    T: MagicAny + Clone + 'static,
189{
190    fn serialize<S>(&self, _serializer: S) -> Result<S::Ok, S::Error>
191    where
192        S: serde::Serializer,
193    {
194        Err(serde::ser::Error::custom(
195            "cannot serialize transient task inputs",
196        ))
197    }
198}
199
200impl<'de, T> Deserialize<'de> for TransientValue<T>
201where
202    T: MagicAny + Clone + 'static,
203{
204    fn deserialize<D>(_deserializer: D) -> Result<Self, D::Error>
205    where
206        D: serde::Deserializer<'de>,
207    {
208        Err(serde::de::Error::custom(
209            "cannot deserialize transient task inputs",
210        ))
211    }
212}
213
214impl<T> TaskInput for TransientInstance<T>
215where
216    T: Sync + Send + TraceRawVcs + 'static,
217{
218    fn is_transient(&self) -> bool {
219        true
220    }
221}
222
223impl<T> Serialize for TransientInstance<T> {
224    fn serialize<S>(&self, _serializer: S) -> Result<S::Ok, S::Error>
225    where
226        S: serde::Serializer,
227    {
228        Err(serde::ser::Error::custom(
229            "cannot serialize transient task inputs",
230        ))
231    }
232}
233
234impl<'de, T> Deserialize<'de> for TransientInstance<T> {
235    fn deserialize<D>(_deserializer: D) -> Result<Self, D::Error>
236    where
237        D: serde::Deserializer<'de>,
238    {
239        Err(serde::de::Error::custom(
240            "cannot deserialize transient task inputs",
241        ))
242    }
243}
244
245impl<L, R> TaskInput for Either<L, R>
246where
247    L: TaskInput,
248    R: TaskInput,
249{
250    fn resolve_input(&self) -> impl Future<Output = Result<Self>> + Send + '_ {
251        self.as_ref().map_either(
252            |l| async move { anyhow::Ok(Either::Left(l.resolve_input().await?)) },
253            |r| async move { anyhow::Ok(Either::Right(r.resolve_input().await?)) },
254        )
255    }
256
257    fn is_resolved(&self) -> bool {
258        self.as_ref()
259            .either(TaskInput::is_resolved, TaskInput::is_resolved)
260    }
261
262    fn is_transient(&self) -> bool {
263        self.as_ref()
264            .either(TaskInput::is_transient, TaskInput::is_transient)
265    }
266}
267
268impl<K, V> TaskInput for BTreeMap<K, V>
269where
270    K: TaskInput + Ord,
271    V: TaskInput,
272{
273    async fn resolve_input(&self) -> Result<Self> {
274        let mut new_map = BTreeMap::new();
275        for (k, v) in self {
276            new_map.insert(
277                TaskInput::resolve_input(k).await?,
278                TaskInput::resolve_input(v).await?,
279            );
280        }
281        Ok(new_map)
282    }
283
284    fn is_resolved(&self) -> bool {
285        self.iter()
286            .all(|(k, v)| TaskInput::is_resolved(k) && TaskInput::is_resolved(v))
287    }
288
289    fn is_transient(&self) -> bool {
290        self.iter()
291            .any(|(k, v)| TaskInput::is_transient(k) || TaskInput::is_transient(v))
292    }
293}
294
295impl<T> TaskInput for BTreeSet<T>
296where
297    T: TaskInput + Ord,
298{
299    async fn resolve_input(&self) -> Result<Self> {
300        let mut new_map = BTreeSet::new();
301        for value in self {
302            new_map.insert(TaskInput::resolve_input(value).await?);
303        }
304        Ok(new_map)
305    }
306
307    fn is_resolved(&self) -> bool {
308        self.iter().all(TaskInput::is_resolved)
309    }
310
311    fn is_transient(&self) -> bool {
312        self.iter().any(TaskInput::is_transient)
313    }
314}
315
316macro_rules! tuple_impls {
317    ( $( $name:ident )+ ) => {
318        impl<$($name: TaskInput),+> TaskInput for ($($name,)+)
319        where $($name: TaskInput),+
320        {
321            #[allow(non_snake_case)]
322            fn is_resolved(&self) -> bool {
323                let ($($name,)+) = self;
324                $($name.is_resolved() &&)+ true
325            }
326
327            #[allow(non_snake_case)]
328            fn is_transient(&self) -> bool {
329                let ($($name,)+) = self;
330                $($name.is_transient() ||)+ false
331            }
332
333            #[allow(non_snake_case)]
334            async fn resolve_input(&self) -> Result<Self> {
335                let ($($name,)+) = self;
336                Ok(($($name.resolve_input().await?,)+))
337            }
338        }
339    };
340}
341
342// Implement `TaskInput` for all tuples of 1 to 12 elements.
343tuple_impls! { A }
344tuple_impls! { A B }
345tuple_impls! { A B C }
346tuple_impls! { A B C D }
347tuple_impls! { A B C D E }
348tuple_impls! { A B C D E F }
349tuple_impls! { A B C D E F G }
350tuple_impls! { A B C D E F G H }
351tuple_impls! { A B C D E F G H I }
352tuple_impls! { A B C D E F G H I J }
353tuple_impls! { A B C D E F G H I J K }
354tuple_impls! { A B C D E F G H I J K L }
355
#[cfg(test)]
mod tests {
    use turbo_rcstr::rcstr;
    use turbo_tasks_macros::TaskInput;

    use super::*;
    // The `TaskInput` derive expands to paths that name the `turbo_tasks` crate directly, so
    // this crate must be reachable under that name from inside its own tests.
    use crate as turbo_tasks;

    /// Compiles only when `T` implements [`TaskInput`]; the value itself is discarded.
    fn assert_task_input<T: TaskInput>(_value: T) {}

    // Used directly by `test_one_variant` and as a field type in `test_nested_variants`.
    #[derive(Clone, TaskInput, Eq, PartialEq, Hash, Debug, Serialize, Deserialize, TraceRawVcs)]
    enum OneVariant {
        Variant,
    }

    // Used directly by `test_multiple_variants_and_heterogeneous_fields` and as a field type in
    // `test_nested_variants`.
    #[derive(Clone, TaskInput, Eq, PartialEq, Hash, Debug, Serialize, Deserialize, TraceRawVcs)]
    enum MultipleVariantsAndHeterogeneousFields {
        Variant1,
        Variant2(u32),
        Variant3 { named: u32 },
        Variant4(u32, RcStr),
        Variant5 { named: u32, other: RcStr },
    }

    #[test]
    fn test_no_fields() {
        #[derive(
            Clone, TaskInput, Eq, PartialEq, Hash, Debug, Serialize, Deserialize, TraceRawVcs,
        )]
        struct NoFields;

        assert_task_input(NoFields);
    }

    #[test]
    fn test_one_unnamed_field() {
        #[derive(
            Clone, TaskInput, Eq, PartialEq, Hash, Debug, Serialize, Deserialize, TraceRawVcs,
        )]
        struct OneUnnamedField(u32);

        assert_task_input(OneUnnamedField(42));
    }

    #[test]
    fn test_multiple_unnamed_fields() {
        #[derive(
            Clone, TaskInput, Eq, PartialEq, Hash, Debug, Serialize, Deserialize, TraceRawVcs,
        )]
        struct MultipleUnnamedFields(u32, RcStr);

        assert_task_input(MultipleUnnamedFields(42, rcstr!("42")));
    }

    #[test]
    fn test_one_named_field() {
        #[derive(
            Clone, TaskInput, Eq, PartialEq, Hash, Debug, Serialize, Deserialize, TraceRawVcs,
        )]
        struct OneNamedField {
            named: u32,
        }

        assert_task_input(OneNamedField { named: 42 });
    }

    #[test]
    fn test_multiple_named_fields() {
        #[derive(
            Clone, TaskInput, Eq, PartialEq, Hash, Debug, Serialize, Deserialize, TraceRawVcs,
        )]
        struct MultipleNamedFields {
            named: u32,
            other: RcStr,
        }

        let value = MultipleNamedFields {
            named: 42,
            other: rcstr!("42"),
        };
        assert_task_input(value);
    }

    #[test]
    fn test_generic_field() {
        #[derive(
            Clone, TaskInput, Eq, PartialEq, Hash, Debug, Serialize, Deserialize, TraceRawVcs,
        )]
        struct GenericField<T>(T);

        assert_task_input(GenericField(42));
        assert_task_input(GenericField(rcstr!("42")));
    }

    #[test]
    fn test_one_variant() {
        assert_task_input(OneVariant::Variant);
    }

    #[test]
    fn test_multiple_variants() {
        #[derive(
            Clone, TaskInput, Eq, PartialEq, Hash, Debug, Serialize, Deserialize, TraceRawVcs,
        )]
        enum MultipleVariants {
            Variant1,
            Variant2,
        }

        assert_task_input(MultipleVariants::Variant2);
    }

    #[test]
    fn test_multiple_variants_and_heterogeneous_fields() {
        let value = MultipleVariantsAndHeterogeneousFields::Variant5 {
            named: 42,
            other: rcstr!("42"),
        };
        assert_task_input(value);
    }

    #[test]
    fn test_nested_variants() {
        #[derive(
            Clone, TaskInput, Eq, PartialEq, Hash, Debug, Serialize, Deserialize, TraceRawVcs,
        )]
        enum NestedVariants {
            Variant1,
            Variant2(MultipleVariantsAndHeterogeneousFields),
            Variant3 { named: OneVariant },
            Variant4(OneVariant, RcStr),
            Variant5 { named: OneVariant, other: RcStr },
        }

        let struct_like = NestedVariants::Variant5 {
            named: OneVariant::Variant,
            other: rcstr!("42"),
        };
        assert_task_input(struct_like);

        let nested = NestedVariants::Variant2(MultipleVariantsAndHeterogeneousFields::Variant5 {
            named: 42,
            other: rcstr!("42"),
        });
        assert_task_input(nested);
    }
}