turbo_tasks/task/
task_input.rs

1use std::{
2    collections::BTreeMap, fmt::Debug, future::Future, hash::Hash, sync::Arc, time::Duration,
3};
4
5use anyhow::Result;
6use either::Either;
7use serde::{Deserialize, Serialize};
8use turbo_rcstr::RcStr;
9
10use crate::{
11    MagicAny, ResolvedVc, TaskId, TransientInstance, TransientValue, ValueTypeId, Vc,
12    trace::TraceRawVcs,
13};
14
15/// Trait to implement in order for a type to be accepted as a
16/// [`#[turbo_tasks::function]`][crate::function] argument.
17pub trait TaskInput: Send + Sync + Clone + Debug + PartialEq + Eq + Hash + TraceRawVcs {
18    fn resolve_input(&self) -> impl Future<Output = Result<Self>> + Send + '_ {
19        async { Ok(self.clone()) }
20    }
21    fn is_resolved(&self) -> bool {
22        true
23    }
24    fn is_transient(&self) -> bool;
25}
26
// Implements `TaskInput` for each listed type using only the trait defaults: the value is
// considered resolved, resolves to a clone of itself, and is never transient.
macro_rules! impl_task_input {
    ($($ty:ty),* $(,)?) => {
        $(
            impl TaskInput for $ty {
                fn is_transient(&self) -> bool {
                    false
                }
            }
        )*
    };
}
38
39impl_task_input! {
40    (),
41    bool,
42    u8,
43    u16,
44    u32,
45    i32,
46    u64,
47    usize,
48    RcStr,
49    TaskId,
50    ValueTypeId,
51    Duration,
52    String
53}
54
55impl<T> TaskInput for Vec<T>
56where
57    T: TaskInput,
58{
59    fn is_resolved(&self) -> bool {
60        self.iter().all(TaskInput::is_resolved)
61    }
62
63    fn is_transient(&self) -> bool {
64        self.iter().any(TaskInput::is_transient)
65    }
66
67    async fn resolve_input(&self) -> Result<Self> {
68        let mut resolved = Vec::with_capacity(self.len());
69        for value in self {
70            resolved.push(value.resolve_input().await?);
71        }
72        Ok(resolved)
73    }
74}
75
76impl<T> TaskInput for Box<T>
77where
78    T: TaskInput,
79{
80    fn is_resolved(&self) -> bool {
81        self.as_ref().is_resolved()
82    }
83
84    fn is_transient(&self) -> bool {
85        self.as_ref().is_transient()
86    }
87
88    async fn resolve_input(&self) -> Result<Self> {
89        Ok(Box::new(Box::pin(self.as_ref().resolve_input()).await?))
90    }
91}
92
93impl<T> TaskInput for Arc<T>
94where
95    T: TaskInput,
96{
97    fn is_resolved(&self) -> bool {
98        self.as_ref().is_resolved()
99    }
100
101    fn is_transient(&self) -> bool {
102        self.as_ref().is_transient()
103    }
104
105    async fn resolve_input(&self) -> Result<Self> {
106        Ok(Arc::new(Box::pin(self.as_ref().resolve_input()).await?))
107    }
108}
109
110impl<T> TaskInput for Option<T>
111where
112    T: TaskInput,
113{
114    fn is_resolved(&self) -> bool {
115        match self {
116            Some(value) => value.is_resolved(),
117            None => true,
118        }
119    }
120
121    fn is_transient(&self) -> bool {
122        match self {
123            Some(value) => value.is_transient(),
124            None => false,
125        }
126    }
127
128    async fn resolve_input(&self) -> Result<Self> {
129        match self {
130            Some(value) => Ok(Some(value.resolve_input().await?)),
131            None => Ok(None),
132        }
133    }
134}
135
136impl<T> TaskInput for Vc<T>
137where
138    T: Send + Sync + ?Sized,
139{
140    fn is_resolved(&self) -> bool {
141        Vc::is_resolved(*self)
142    }
143
144    fn is_transient(&self) -> bool {
145        self.node.is_transient()
146    }
147
148    async fn resolve_input(&self) -> Result<Self> {
149        Vc::resolve(*self).await
150    }
151}
152
153// `TaskInput` isn't needed/used for a bare `ResolvedVc`, as we'll expose `ResolvedVc` arguments as
154// `Vc`, but it is useful for structs that contain `ResolvedVc` and want to derive `TaskInput`.
155impl<T> TaskInput for ResolvedVc<T>
156where
157    T: Send + Sync + ?Sized,
158{
159    fn is_resolved(&self) -> bool {
160        true
161    }
162
163    fn is_transient(&self) -> bool {
164        self.node.is_transient()
165    }
166
167    async fn resolve_input(&self) -> Result<Self> {
168        Ok(*self)
169    }
170}
171
172impl<T> TaskInput for TransientValue<T>
173where
174    T: MagicAny + Clone + Debug + Hash + Eq + TraceRawVcs + 'static,
175{
176    fn is_transient(&self) -> bool {
177        true
178    }
179}
180
181impl<T> Serialize for TransientValue<T>
182where
183    T: MagicAny + Clone + 'static,
184{
185    fn serialize<S>(&self, _serializer: S) -> Result<S::Ok, S::Error>
186    where
187        S: serde::Serializer,
188    {
189        Err(serde::ser::Error::custom(
190            "cannot serialize transient task inputs",
191        ))
192    }
193}
194
195impl<'de, T> Deserialize<'de> for TransientValue<T>
196where
197    T: MagicAny + Clone + 'static,
198{
199    fn deserialize<D>(_deserializer: D) -> Result<Self, D::Error>
200    where
201        D: serde::Deserializer<'de>,
202    {
203        Err(serde::de::Error::custom(
204            "cannot deserialize transient task inputs",
205        ))
206    }
207}
208
209impl<T> TaskInput for TransientInstance<T>
210where
211    T: Sync + Send + TraceRawVcs + 'static,
212{
213    fn is_transient(&self) -> bool {
214        true
215    }
216}
217
218impl<T> Serialize for TransientInstance<T> {
219    fn serialize<S>(&self, _serializer: S) -> Result<S::Ok, S::Error>
220    where
221        S: serde::Serializer,
222    {
223        Err(serde::ser::Error::custom(
224            "cannot serialize transient task inputs",
225        ))
226    }
227}
228
229impl<'de, T> Deserialize<'de> for TransientInstance<T> {
230    fn deserialize<D>(_deserializer: D) -> Result<Self, D::Error>
231    where
232        D: serde::Deserializer<'de>,
233    {
234        Err(serde::de::Error::custom(
235            "cannot deserialize transient task inputs",
236        ))
237    }
238}
239
240impl<L, R> TaskInput for Either<L, R>
241where
242    L: TaskInput,
243    R: TaskInput,
244{
245    fn resolve_input(&self) -> impl Future<Output = Result<Self>> + Send + '_ {
246        self.as_ref().map_either(
247            |l| async move { anyhow::Ok(Either::Left(l.resolve_input().await?)) },
248            |r| async move { anyhow::Ok(Either::Right(r.resolve_input().await?)) },
249        )
250    }
251
252    fn is_resolved(&self) -> bool {
253        self.as_ref()
254            .either(TaskInput::is_resolved, TaskInput::is_resolved)
255    }
256
257    fn is_transient(&self) -> bool {
258        self.as_ref()
259            .either(TaskInput::is_transient, TaskInput::is_transient)
260    }
261}
262
263impl<K, V> TaskInput for BTreeMap<K, V>
264where
265    K: TaskInput + Ord,
266    V: TaskInput,
267{
268    async fn resolve_input(&self) -> Result<Self> {
269        let mut new_map = BTreeMap::new();
270        for (k, v) in self {
271            new_map.insert(
272                TaskInput::resolve_input(k).await?,
273                TaskInput::resolve_input(v).await?,
274            );
275        }
276        Ok(new_map)
277    }
278
279    fn is_resolved(&self) -> bool {
280        self.iter()
281            .all(|(k, v)| TaskInput::is_resolved(k) || TaskInput::is_resolved(v))
282    }
283
284    fn is_transient(&self) -> bool {
285        self.iter()
286            .any(|(k, v)| TaskInput::is_transient(k) || TaskInput::is_transient(v))
287    }
288}
289macro_rules! tuple_impls {
290    ( $( $name:ident )+ ) => {
291        impl<$($name: TaskInput),+> TaskInput for ($($name,)+)
292        where $($name: TaskInput),+
293        {
294            #[allow(non_snake_case)]
295            fn is_resolved(&self) -> bool {
296                let ($($name,)+) = self;
297                $($name.is_resolved() &&)+ true
298            }
299
300            #[allow(non_snake_case)]
301            fn is_transient(&self) -> bool {
302                let ($($name,)+) = self;
303                $($name.is_transient() ||)+ false
304            }
305
306            #[allow(non_snake_case)]
307            async fn resolve_input(&self) -> Result<Self> {
308                let ($($name,)+) = self;
309                Ok(($($name.resolve_input().await?,)+))
310            }
311        }
312    };
313}
314
315// Implement `TaskInput` for all tuples of 1 to 12 elements.
316tuple_impls! { A }
317tuple_impls! { A B }
318tuple_impls! { A B C }
319tuple_impls! { A B C D }
320tuple_impls! { A B C D E }
321tuple_impls! { A B C D E F }
322tuple_impls! { A B C D E F G }
323tuple_impls! { A B C D E F G H }
324tuple_impls! { A B C D E F G H I }
325tuple_impls! { A B C D E F G H I J }
326tuple_impls! { A B C D E F G H I J K }
327tuple_impls! { A B C D E F G H I J K L }
328
#[cfg(test)]
mod tests {
    //! Compile-time checks that `#[derive(TaskInput)]` yields a `TaskInput` impl for a range of
    //! struct and enum shapes. Each test passes as long as it compiles; none of them inspect
    //! runtime behavior of the derived methods.

    use turbo_rcstr::rcstr;
    use turbo_tasks_macros::TaskInput;

    use super::*;
    // This is necessary for the derive macro to work, as its expansion refers to
    // the crate name directly.
    use crate as turbo_tasks;

    // Accepts any value whose type implements `TaskInput` and does nothing with it. Calling it
    // only forces the compiler to check the trait bound.
    fn assert_task_input<T>(_: T)
    where
        T: TaskInput,
    {
    }

    // Unit struct.
    #[test]
    fn test_no_fields() -> Result<()> {
        #[derive(
            Clone, TaskInput, Eq, PartialEq, Hash, Debug, Serialize, Deserialize, TraceRawVcs,
        )]
        struct NoFields;

        assert_task_input(NoFields);
        Ok(())
    }

    // Tuple struct with a single field.
    #[test]
    fn test_one_unnamed_field() -> Result<()> {
        #[derive(
            Clone, TaskInput, Eq, PartialEq, Hash, Debug, Serialize, Deserialize, TraceRawVcs,
        )]
        struct OneUnnamedField(u32);

        assert_task_input(OneUnnamedField(42));
        Ok(())
    }

    // Tuple struct with several fields of different types.
    #[test]
    fn test_multiple_unnamed_fields() -> Result<()> {
        #[derive(
            Clone, TaskInput, Eq, PartialEq, Hash, Debug, Serialize, Deserialize, TraceRawVcs,
        )]
        struct MultipleUnnamedFields(u32, RcStr);

        assert_task_input(MultipleUnnamedFields(42, rcstr!("42")));
        Ok(())
    }

    // Struct with a single named field.
    #[test]
    fn test_one_named_field() -> Result<()> {
        #[derive(
            Clone, TaskInput, Eq, PartialEq, Hash, Debug, Serialize, Deserialize, TraceRawVcs,
        )]
        struct OneNamedField {
            named: u32,
        }

        assert_task_input(OneNamedField { named: 42 });
        Ok(())
    }

    // Struct with several named fields of different types.
    #[test]
    fn test_multiple_named_fields() -> Result<()> {
        #[derive(
            Clone, TaskInput, Eq, PartialEq, Hash, Debug, Serialize, Deserialize, TraceRawVcs,
        )]
        struct MultipleNamedFields {
            named: u32,
            other: RcStr,
        }

        assert_task_input(MultipleNamedFields {
            named: 42,
            other: rcstr!("42"),
        });
        Ok(())
    }

    // Generic struct, instantiated with two different `TaskInput` types.
    #[test]
    fn test_generic_field() -> Result<()> {
        #[derive(
            Clone, TaskInput, Eq, PartialEq, Hash, Debug, Serialize, Deserialize, TraceRawVcs,
        )]
        struct GenericField<T>(T);

        assert_task_input(GenericField(42));
        assert_task_input(GenericField(rcstr!("42")));
        Ok(())
    }

    // Declared at module level so `test_nested_variants` can embed it as a field type.
    #[derive(Clone, TaskInput, Eq, PartialEq, Hash, Debug, Serialize, Deserialize, TraceRawVcs)]
    enum OneVariant {
        Variant,
    }

    // Enum with a single fieldless variant.
    #[test]
    fn test_one_variant() -> Result<()> {
        assert_task_input(OneVariant::Variant);
        Ok(())
    }

    // Enum with several fieldless variants.
    #[test]
    fn test_multiple_variants() -> Result<()> {
        #[derive(
            Clone, TaskInput, PartialEq, Eq, Hash, Debug, Serialize, Deserialize, TraceRawVcs,
        )]
        enum MultipleVariants {
            Variant1,
            Variant2,
        }

        assert_task_input(MultipleVariants::Variant2);
        Ok(())
    }

    // Covers every variant shape (unit, tuple, struct-like) in one enum. It is also declared at
    // module level so `test_nested_variants` can embed it.
    #[derive(Clone, TaskInput, Eq, PartialEq, Hash, Debug, Serialize, Deserialize, TraceRawVcs)]
    enum MultipleVariantsAndHeterogeneousFields {
        Variant1,
        Variant2(u32),
        Variant3 { named: u32 },
        Variant4(u32, RcStr),
        Variant5 { named: u32, other: RcStr },
    }

    #[test]
    fn test_multiple_variants_and_heterogeneous_fields() -> Result<()> {
        assert_task_input(MultipleVariantsAndHeterogeneousFields::Variant5 {
            named: 42,
            other: rcstr!("42"),
        });
        Ok(())
    }

    // Enum whose variants carry other derived `TaskInput` enums as fields.
    #[test]
    fn test_nested_variants() -> Result<()> {
        #[derive(
            Clone, TaskInput, Eq, PartialEq, Hash, Debug, Serialize, Deserialize, TraceRawVcs,
        )]
        enum NestedVariants {
            Variant1,
            Variant2(MultipleVariantsAndHeterogeneousFields),
            Variant3 { named: OneVariant },
            Variant4(OneVariant, RcStr),
            Variant5 { named: OneVariant, other: RcStr },
        }

        assert_task_input(NestedVariants::Variant5 {
            named: OneVariant::Variant,
            other: rcstr!("42"),
        });
        assert_task_input(NestedVariants::Variant2(
            MultipleVariantsAndHeterogeneousFields::Variant5 {
                named: 42,
                other: rcstr!("42"),
            },
        ));
        Ok(())
    }
}