Skip to main content

turbo_tasks/task/
task_input.rs

1use std::{
2    collections::{BTreeMap, BTreeSet},
3    fmt::Debug,
4    future::Future,
5    hash::Hash,
6    ops::{Deref, DerefMut},
7    sync::Arc,
8    time::Duration,
9};
10
11use anyhow::Result;
12use bincode::{
13    Decode, Encode,
14    de::Decoder,
15    enc::Encoder,
16    error::{DecodeError, EncodeError},
17};
18use either::Either;
19use turbo_frozenmap::{FrozenMap, FrozenSet};
20use turbo_rcstr::RcStr;
21use turbo_tasks_hash::HashAlgorithm;
22
23// This import is necessary for derive macros to work, as their expansion refers to the crate
24// name directly.
25use crate as turbo_tasks;
26use crate::{
27    DynTaskInputs, ReadRef, ResolvedVc, TaskId, TransientInstance, TransientValue, ValueTypeId, Vc,
28    trace::TraceRawVcs,
29};
30
31/// Trait to implement in order for a type to be accepted as a
32/// [`#[turbo_tasks::function]`][crate::function] argument.
33///
34/// ## Serialization
35///
36/// For persistent caching of a task, arguments must be serializable. All `TaskInput`s must
37/// implement the bincode [`Encode`] and [`Decode`] traits.
38///
39/// Transient task inputs are required to implement [`Encode`] and [`Decode`], but are allowed to
40/// panic at runtime. This requirement could be lifted in the future.
41///
42/// Bincode encoding must be deterministic and compatible with [`Eq`] comparisons. If two
43/// `TaskInput`s compare equal they must also encode to the same bytes.
44///
45/// ## Hash and Eq
46///
47/// Arguments are used as part of keys in a `HashMap`, so they must implement of [`PartialEq`],
48/// [`Eq`], and [`Hash`] traits.
49///
50/// ## [`Vc<T>`][Vc]
51///
52/// A [`Vc`] is a pointer to a cell. It implements `TaskInput` and serves as a "pass by reference"
53/// argument:
54///
55/// - **Memoization**: [`Vc`] is keyed by pointer for memoization purposes. Identical values in
56///   different cells are treated as distinct.
57/// - **Singleton Pattern**: To ensure memoization efficiency, the singleton pattern can be employed
58///   to guarantee that identical values yield the same `Vc`. For more info see [Singleton Pattern
59///   Guide][singleton].
60///
61/// [singleton]: https://turbopack-rust-docs.vercel.sh/turbo-engine/singleton.html
62///
63/// ## Deriving `TaskInput`
64///
65/// Structs or enums can be made into task inputs by deriving `TaskInput`:
66///
67/// ```rust
68/// #[derive(TaskInput)]
69/// struct MyStruct {
70///     // Fields go here...
71/// }
72/// ```
73///
74/// Derived `TaskInput` types **passed by value**. When called, arguments are moved into a `Box`,
75/// and then cloned before being passed into the function. If the task is invalidated, the
76/// `TaskInput` is cloned again to allow the function to be re-executed. It's recommended to ensure
77/// that these types are cheap to clone.
78///
79/// Reference-counted types like [`Arc`] are cheap to clone, but each reference contained in a
80/// `TaskInput` will be serialized independently in the persistent cache, and may consume extra disk
81/// space. If an [`Arc`] points to a large type, consider wrapping that type in [`Vc`], so that only
82/// one copy of the value will be serialized.
83pub trait TaskInput:
84    Send + Sync + Clone + Debug + PartialEq + Eq + Hash + TraceRawVcs + Encode + Decode<()>
85{
86    /// This method should resolve any [`Vc`]s nested inside of this object, cloning the object in
87    /// the process. If the input is unresolved ([`TaskInput::is_resolved`]) a "local" resolution
88    /// task is created that runs this method.
89    fn resolve_input(&self) -> impl Future<Output = Result<Self>> + Send + '_ {
90        async { Ok(self.clone()) }
91    }
92
93    /// This should return `true` if there are any unresolved [`Vc`]s in the type.
94    ///
95    /// Note that [`Vc`]s can sometimes be internally resolved, so you should call
96    /// [`Vc::is_resolved`] (or rely on the derive macro for this trait) instead of returning `true`
97    /// for any [`Vc`]. [`ResolvedVc::is_resolved`] always returns `true`.
98    ///
99    /// If this returns `true`, a "local" resolution task calling [`TaskInput::resolve_input`] will
100    /// be spawned before the function accepting the arguments is run.
101    ///
102    /// If this returns `false`, the `TaskInput` will be [cloned][Clone] instead of resolved, and
103    /// the function's task will be spawned directly without a resolution step.
104    fn is_resolved(&self) -> bool {
105        true
106    }
107
108    /// This should return true if this object contains a [`Vc`] (or any subtype of [`Vc`]) pointing
109    /// to a cell owned by a transient task.
110    ///
111    /// Any function called with a transient `TaskInput` will be transient. Any [`Vc`] constructed
112    /// in a transient task or in a top-level [`run_once`][crate::run_once] closure will be
113    /// transient.
114    ///
115    /// Internally, a [`Vc`] can be determined to be transient by comparing the owning task's id
116    /// with the [`TRANSIENT_TASK_BIT`][crate::TRANSIENT_TASK_BIT] mask.
117    fn is_transient(&self) -> bool;
118}
119
/// Implements [`TaskInput`] for each listed type, relying on the trait's default
/// `is_resolved` (always `true`) and `resolve_input` (a plain clone), and reporting every value
/// as non-transient.
macro_rules! impl_task_input {
    ($($leaf_ty:ty),*) => {
        $(
            impl TaskInput for $leaf_ty {
                fn is_transient(&self) -> bool { false }
            }
        )*
    };
}
131
132impl_task_input! {
133    (),
134    bool,
135    u8,
136    u16,
137    u32,
138    i32,
139    u64,
140    usize,
141    RcStr,
142    TaskId,
143    ValueTypeId,
144    Duration,
145    String,
146    HashAlgorithm
147}
148
149impl<T> TaskInput for Vec<T>
150where
151    T: TaskInput,
152{
153    fn is_resolved(&self) -> bool {
154        self.iter().all(TaskInput::is_resolved)
155    }
156
157    fn is_transient(&self) -> bool {
158        self.iter().any(TaskInput::is_transient)
159    }
160
161    async fn resolve_input(&self) -> Result<Self> {
162        let mut resolved = Vec::with_capacity(self.len());
163        for value in self {
164            resolved.push(value.resolve_input().await?);
165        }
166        Ok(resolved)
167    }
168}
169
170impl<T> TaskInput for Box<T>
171where
172    T: TaskInput,
173{
174    fn is_resolved(&self) -> bool {
175        self.as_ref().is_resolved()
176    }
177
178    fn is_transient(&self) -> bool {
179        self.as_ref().is_transient()
180    }
181
182    async fn resolve_input(&self) -> Result<Self> {
183        Ok(Box::new(Box::pin(self.as_ref().resolve_input()).await?))
184    }
185}
186
187impl<T> TaskInput for Arc<T>
188where
189    T: TaskInput,
190{
191    fn is_resolved(&self) -> bool {
192        self.as_ref().is_resolved()
193    }
194
195    fn is_transient(&self) -> bool {
196        self.as_ref().is_transient()
197    }
198
199    async fn resolve_input(&self) -> Result<Self> {
200        Ok(Arc::new(Box::pin(self.as_ref().resolve_input()).await?))
201    }
202}
203
204impl<T> TaskInput for ReadRef<T>
205where
206    T: TaskInput,
207{
208    fn is_resolved(&self) -> bool {
209        Self::as_raw_ref(self).is_resolved()
210    }
211
212    fn is_transient(&self) -> bool {
213        Self::as_raw_ref(self).is_transient()
214    }
215
216    async fn resolve_input(&self) -> Result<Self> {
217        Ok(ReadRef::new_owned(
218            Box::pin(Self::as_raw_ref(self).resolve_input()).await?,
219        ))
220    }
221}
222
223impl<T> TaskInput for Option<T>
224where
225    T: TaskInput,
226{
227    fn is_resolved(&self) -> bool {
228        match self {
229            Some(value) => value.is_resolved(),
230            None => true,
231        }
232    }
233
234    fn is_transient(&self) -> bool {
235        match self {
236            Some(value) => value.is_transient(),
237            None => false,
238        }
239    }
240
241    async fn resolve_input(&self) -> Result<Self> {
242        match self {
243            Some(value) => Ok(Some(value.resolve_input().await?)),
244            None => Ok(None),
245        }
246    }
247}
248
249impl<T> TaskInput for Vc<T>
250where
251    T: Send + Sync + ?Sized,
252{
253    fn is_resolved(&self) -> bool {
254        Vc::is_resolved(*self)
255    }
256
257    fn is_transient(&self) -> bool {
258        self.node.is_transient()
259    }
260
261    async fn resolve_input(&self) -> Result<Self> {
262        Ok(*(*self).to_resolved().await?)
263    }
264}
265
266// `TaskInput` isn't needed/used for a bare `ResolvedVc`, as we'll expose `ResolvedVc` arguments as
267// `Vc`, but it is useful for structs that contain `ResolvedVc` and want to derive `TaskInput`.
268impl<T> TaskInput for ResolvedVc<T>
269where
270    T: Send + Sync + ?Sized,
271{
272    fn is_resolved(&self) -> bool {
273        true
274    }
275
276    fn is_transient(&self) -> bool {
277        self.node.is_transient()
278    }
279
280    async fn resolve_input(&self) -> Result<Self> {
281        Ok(*self)
282    }
283}
284
285impl<T> TaskInput for TransientValue<T>
286where
287    T: DynTaskInputs + Clone + Debug + Hash + Eq + TraceRawVcs + 'static,
288{
289    fn is_transient(&self) -> bool {
290        true
291    }
292}
293
294impl<T> Encode for TransientValue<T> {
295    fn encode<E: Encoder>(&self, _encoder: &mut E) -> Result<(), EncodeError> {
296        Err(EncodeError::Other("cannot encode transient task inputs"))
297    }
298}
299
300impl<Context, T> Decode<Context> for TransientValue<T> {
301    fn decode<D: Decoder<Context = Context>>(_decoder: &mut D) -> Result<Self, DecodeError> {
302        Err(DecodeError::Other("cannot decode transient task inputs"))
303    }
304}
305
306impl<T> TaskInput for TransientInstance<T>
307where
308    T: Sync + Send + TraceRawVcs + 'static,
309{
310    fn is_transient(&self) -> bool {
311        true
312    }
313}
314
315impl<T> Encode for TransientInstance<T> {
316    fn encode<E: Encoder>(&self, _encoder: &mut E) -> Result<(), EncodeError> {
317        Err(EncodeError::Other("cannot encode transient task inputs"))
318    }
319}
320
321impl<Context, T> Decode<Context> for TransientInstance<T> {
322    fn decode<D: Decoder<Context = Context>>(_decoder: &mut D) -> Result<Self, DecodeError> {
323        Err(DecodeError::Other("cannot decode transient task inputs"))
324    }
325}
326
327impl<K, V> TaskInput for BTreeMap<K, V>
328where
329    K: TaskInput + Ord,
330    V: TaskInput,
331{
332    async fn resolve_input(&self) -> Result<Self> {
333        let mut new_map = BTreeMap::new();
334        for (k, v) in self {
335            new_map.insert(
336                TaskInput::resolve_input(k).await?,
337                TaskInput::resolve_input(v).await?,
338            );
339        }
340        Ok(new_map)
341    }
342
343    fn is_resolved(&self) -> bool {
344        self.iter()
345            .all(|(k, v)| TaskInput::is_resolved(k) && TaskInput::is_resolved(v))
346    }
347
348    fn is_transient(&self) -> bool {
349        self.iter()
350            .any(|(k, v)| TaskInput::is_transient(k) || TaskInput::is_transient(v))
351    }
352}
353
354impl<T> TaskInput for BTreeSet<T>
355where
356    T: TaskInput + Ord,
357{
358    async fn resolve_input(&self) -> Result<Self> {
359        let mut new_set = BTreeSet::new();
360        for value in self {
361            new_set.insert(TaskInput::resolve_input(value).await?);
362        }
363        Ok(new_set)
364    }
365
366    fn is_resolved(&self) -> bool {
367        self.iter().all(TaskInput::is_resolved)
368    }
369
370    fn is_transient(&self) -> bool {
371        self.iter().any(TaskInput::is_transient)
372    }
373}
374
375impl<K, V> TaskInput for FrozenMap<K, V>
376where
377    K: TaskInput + Ord + 'static,
378    V: TaskInput + 'static,
379{
380    async fn resolve_input(&self) -> Result<Self> {
381        let mut new_entries = Vec::with_capacity(self.len());
382        for (k, v) in self {
383            new_entries.push((
384                TaskInput::resolve_input(k).await?,
385                TaskInput::resolve_input(v).await?,
386            ));
387        }
388        // note: resolving might deduplicate `Vc`s in keys
389        Ok(Self::from(new_entries))
390    }
391
392    fn is_resolved(&self) -> bool {
393        self.iter()
394            .all(|(k, v)| TaskInput::is_resolved(k) && TaskInput::is_resolved(v))
395    }
396
397    fn is_transient(&self) -> bool {
398        self.iter()
399            .any(|(k, v)| TaskInput::is_transient(k) || TaskInput::is_transient(v))
400    }
401}
402
403impl<T> TaskInput for FrozenSet<T>
404where
405    T: TaskInput + Ord + 'static,
406{
407    async fn resolve_input(&self) -> Result<Self> {
408        let mut new_set = Vec::with_capacity(self.len());
409        for value in self {
410            new_set.push(TaskInput::resolve_input(value).await?);
411        }
412        Ok(Self::from_iter(new_set))
413    }
414
415    fn is_resolved(&self) -> bool {
416        self.iter().all(TaskInput::is_resolved)
417    }
418
419    fn is_transient(&self) -> bool {
420        self.iter().any(TaskInput::is_transient)
421    }
422}
423
424/// A thin wrapper around [`Either`] that implements the traits required by [`TaskInput`], notably
425/// [`Encode`] and [`Decode`].
426#[derive(Clone, Debug, PartialEq, Eq, Hash, TraceRawVcs)]
427pub struct EitherTaskInput<L, R>(pub Either<L, R>);
428
429impl<L, R> Deref for EitherTaskInput<L, R> {
430    type Target = Either<L, R>;
431
432    fn deref(&self) -> &Self::Target {
433        &self.0
434    }
435}
436
437impl<L, R> DerefMut for EitherTaskInput<L, R> {
438    fn deref_mut(&mut self) -> &mut Self::Target {
439        &mut self.0
440    }
441}
442
443impl<L, R> Encode for EitherTaskInput<L, R>
444where
445    L: Encode,
446    R: Encode,
447{
448    fn encode<E: Encoder>(&self, encoder: &mut E) -> Result<(), EncodeError> {
449        turbo_bincode::either::encode(self, encoder)
450    }
451}
452
453impl<Context, L, R> Decode<Context> for EitherTaskInput<L, R>
454where
455    L: Decode<Context>,
456    R: Decode<Context>,
457{
458    fn decode<D: Decoder<Context = Context>>(decoder: &mut D) -> Result<Self, DecodeError> {
459        turbo_bincode::either::decode(decoder).map(Self)
460    }
461}
462
463impl<L, R> TaskInput for EitherTaskInput<L, R>
464where
465    L: TaskInput,
466    R: TaskInput,
467{
468    fn resolve_input(&self) -> impl Future<Output = Result<Self>> + Send + '_ {
469        self.as_ref().map_either(
470            |l| async move { anyhow::Ok(Self(Either::Left(l.resolve_input().await?))) },
471            |r| async move { anyhow::Ok(Self(Either::Right(r.resolve_input().await?))) },
472        )
473    }
474
475    fn is_resolved(&self) -> bool {
476        self.as_ref()
477            .either(TaskInput::is_resolved, TaskInput::is_resolved)
478    }
479
480    fn is_transient(&self) -> bool {
481        self.as_ref()
482            .either(TaskInput::is_transient, TaskInput::is_transient)
483    }
484}
485
// Implements `TaskInput` for a tuple whose element types are the given identifiers. The tuple
// is destructured and the per-element results combined: resolved iff every element is,
// transient iff any element is, and resolved element by element, left to right.
macro_rules! tuple_impls {
    ( $( $ty:ident )+ ) => {
        impl<$($ty),+> TaskInput for ($($ty,)+)
        where
            $($ty: TaskInput,)+
        {
            // The bindings reuse the (uppercase) type parameter names, hence the allow.
            #[allow(non_snake_case)]
            async fn resolve_input(&self) -> Result<Self> {
                let ($($ty,)+) = self;
                Ok(($($ty.resolve_input().await?,)+))
            }

            #[allow(non_snake_case)]
            fn is_resolved(&self) -> bool {
                let ($($ty,)+) = self;
                $($ty.is_resolved() &&)+ true
            }

            #[allow(non_snake_case)]
            fn is_transient(&self) -> bool {
                let ($($ty,)+) = self;
                $($ty.is_transient() ||)+ false
            }
        }
    };
}
511
512// Implement `TaskInput` for all tuples of 1 to 12 elements.
513tuple_impls! { A }
514tuple_impls! { A B }
515tuple_impls! { A B C }
516tuple_impls! { A B C D }
517tuple_impls! { A B C D E }
518tuple_impls! { A B C D E F }
519tuple_impls! { A B C D E F G }
520tuple_impls! { A B C D E F G H }
521tuple_impls! { A B C D E F G H I }
522tuple_impls! { A B C D E F G H I J }
523tuple_impls! { A B C D E F G H I J K }
524tuple_impls! { A B C D E F G H I J K L }
525
#[cfg(test)]
mod tests {
    use turbo_rcstr::rcstr;
    use turbo_tasks_macros::TaskInput;

    use super::*;

    /// Compile-time check that `T` satisfies the `TaskInput` bound.
    fn assert_task_input<T>(_: T)
    where
        T: TaskInput,
    {
    }

    #[test]
    fn test_no_fields() -> Result<()> {
        #[derive(Clone, TaskInput, Eq, PartialEq, Hash, Debug, Encode, Decode, TraceRawVcs)]
        struct NoFields;

        assert_task_input(NoFields);
        Ok(())
    }

    #[test]
    fn test_one_unnamed_field() -> Result<()> {
        #[derive(Clone, TaskInput, Eq, PartialEq, Hash, Debug, Encode, Decode, TraceRawVcs)]
        struct OneUnnamedField(u32);

        assert_task_input(OneUnnamedField(42));
        Ok(())
    }

    #[test]
    fn test_multiple_unnamed_fields() -> Result<()> {
        #[derive(Clone, TaskInput, Eq, PartialEq, Hash, Debug, Encode, Decode, TraceRawVcs)]
        struct MultipleUnnamedFields(u32, RcStr);

        assert_task_input(MultipleUnnamedFields(42, rcstr!("42")));
        Ok(())
    }

    #[test]
    fn test_one_named_field() -> Result<()> {
        #[derive(Clone, TaskInput, Eq, PartialEq, Hash, Debug, Encode, Decode, TraceRawVcs)]
        struct OneNamedField {
            named: u32,
        }

        assert_task_input(OneNamedField { named: 42 });
        Ok(())
    }

    #[test]
    fn test_multiple_named_fields() -> Result<()> {
        #[derive(Clone, TaskInput, Eq, PartialEq, Hash, Debug, Encode, Decode, TraceRawVcs)]
        struct MultipleNamedFields {
            named: u32,
            other: RcStr,
        }

        assert_task_input(MultipleNamedFields {
            named: 42,
            other: rcstr!("42"),
        });
        Ok(())
    }

    #[test]
    fn test_generic_field() -> Result<()> {
        #[derive(Clone, TaskInput, Eq, PartialEq, Hash, Debug, Encode, Decode, TraceRawVcs)]
        struct GenericField<T>(T);

        assert_task_input(GenericField(42));
        assert_task_input(GenericField(rcstr!("42")));
        Ok(())
    }

    #[derive(Clone, TaskInput, Eq, PartialEq, Hash, Debug, Encode, Decode, TraceRawVcs)]
    enum OneVariant {
        Variant,
    }

    #[test]
    fn test_one_variant() -> Result<()> {
        assert_task_input(OneVariant::Variant);
        Ok(())
    }

    #[test]
    fn test_multiple_variants() -> Result<()> {
        #[derive(Clone, TaskInput, PartialEq, Eq, Hash, Debug, Encode, Decode, TraceRawVcs)]
        enum MultipleVariants {
            Variant1,
            Variant2,
        }

        assert_task_input(MultipleVariants::Variant2);
        Ok(())
    }

    #[derive(Clone, TaskInput, Eq, PartialEq, Hash, Debug, Encode, Decode, TraceRawVcs)]
    enum MultipleVariantsAndHeterogeneousFields {
        Variant1,
        Variant2(u32),
        Variant3 { named: u32 },
        Variant4(u32, RcStr),
        Variant5 { named: u32, other: RcStr },
    }

    #[test]
    fn test_multiple_variants_and_heterogeneous_fields() -> Result<()> {
        assert_task_input(MultipleVariantsAndHeterogeneousFields::Variant5 {
            named: 42,
            other: rcstr!("42"),
        });
        Ok(())
    }

    #[test]
    fn test_nested_variants() -> Result<()> {
        #[derive(Clone, TaskInput, Eq, PartialEq, Hash, Debug, Encode, Decode, TraceRawVcs)]
        enum NestedVariants {
            Variant1,
            Variant2(MultipleVariantsAndHeterogeneousFields),
            Variant3 { named: OneVariant },
            Variant4(OneVariant, RcStr),
            Variant5 { named: OneVariant, other: RcStr },
        }

        assert_task_input(NestedVariants::Variant5 {
            named: OneVariant::Variant,
            other: rcstr!("42"),
        });
        assert_task_input(NestedVariants::Variant2(
            MultipleVariantsAndHeterogeneousFields::Variant5 {
                named: 42,
                other: rcstr!("42"),
            },
        ));
        Ok(())
    }

    /// The container impls combine their elements' flags. Leaf types contain no `Vc`s, so
    /// containers of them must report resolved and non-transient.
    #[test]
    fn test_container_resolution_flags() {
        let list = vec![1u32, 2, 3];
        assert!(list.is_resolved());
        assert!(!list.is_transient());

        assert!(Some(rcstr!("42")).is_resolved());
        assert!(None::<u32>.is_resolved());
        assert!(!None::<u32>.is_transient());

        let tuple = (1u32, rcstr!("42"), true);
        assert!(tuple.is_resolved());
        assert!(!tuple.is_transient());

        let map = BTreeMap::from([(1u32, rcstr!("a")), (2u32, rcstr!("b"))]);
        assert!(map.is_resolved());
        assert!(!map.is_transient());

        let shared = Arc::new(vec![OneVariant::Variant]);
        assert!(shared.is_resolved());
        assert!(!shared.is_transient());
    }
}