// turbo_tasks/task/task_input.rs

1use std::{
2    collections::{BTreeMap, BTreeSet},
3    fmt::Debug,
4    future::Future,
5    hash::Hash,
6    ops::{Deref, DerefMut},
7    pin::Pin,
8    sync::Arc,
9    task::{Context, Poll},
10    time::Duration,
11};
12
13use anyhow::Result;
14use bincode::{
15    Decode, Encode,
16    de::Decoder,
17    enc::Encoder,
18    error::{DecodeError, EncodeError},
19};
20use either::Either;
21use turbo_frozenmap::{FrozenMap, FrozenSet};
22use turbo_rcstr::RcStr;
23use turbo_tasks_hash::HashAlgorithm;
24
25// This import is necessary for derive macros to work, as their expansion refers to the crate
26// name directly.
27use crate::{self as turbo_tasks, ReadRef};
28use crate::{
29    DynTaskInputs, ResolvedVc, TaskId, TransientInstance, TransientValue, ValueTypeId, Vc,
30    trace::TraceRawVcs,
31};
32
/// A hand-rolled [`Future`] that immediately resolves to `Ok(value.clone())` of the referenced
/// value.
///
/// It only stores an `Option<&T>`, which is pointer-sized because `None` uses the null-pointer
/// niche of the reference (8 bytes on 64-bit targets).
///
/// Used by the [`TaskInput::resolve_input`] default implementation.
///
/// # Panics
///
/// Polling the future again after it has returned [`Poll::Ready`] panics.
struct CloneReady<'a, T> {
    inner: Option<&'a T>,
}

impl<T: Clone> Future for CloneReady<'_, T> {
    type Output = Result<T>;

    fn poll(mut self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Self::Output> {
        // `take` leaves `None` behind, so polling a completed future trips the `expect` below
        // instead of silently cloning the value a second time.
        let value = self
            .inner
            .take()
            .expect("future already polled to completion");
        Poll::Ready(Ok(value.clone()))
    }
}

// `CloneReady` holds only a shared reference and has no self-referential state, so moving it
// after it has been pinned is fine.
impl<T> Unpin for CloneReady<'_, T> {}
55
56/// Trait to implement in order for a type to be accepted as a
57/// [`#[turbo_tasks::function]`][crate::function] argument.
58///
59/// ## Serialization
60///
61/// For persistent caching of a task, arguments must be serializable. All `TaskInput`s must
62/// implement the bincode [`Encode`] and [`Decode`] traits.
63///
64/// Transient task inputs are required to implement [`Encode`] and [`Decode`], but are allowed to
65/// panic at runtime. This requirement could be lifted in the future.
66///
67/// Bincode encoding must be deterministic and compatible with [`Eq`] comparisons. If two
68/// `TaskInput`s compare equal they must also encode to the same bytes.
69///
70/// ## Hash and Eq
71///
72/// Arguments are used as part of keys in a `HashMap`, so they must implement of [`PartialEq`],
73/// [`Eq`], and [`Hash`] traits.
74///
75/// ## [`Vc<T>`][Vc]
76///
77/// A [`Vc`] is a pointer to a cell. It implements `TaskInput` and serves as a "pass by reference"
78/// argument:
79///
80/// - **Memoization**: [`Vc`] is keyed by pointer for memoization purposes. Identical values in
81///   different cells are treated as distinct.
82/// - **Singleton Pattern**: To ensure memoization efficiency, the singleton pattern can be employed
83///   to guarantee that identical values yield the same `Vc`. For more info see [Singleton Pattern
84///   Guide][singleton].
85///
86/// [singleton]: https://turbopack-rust-docs.vercel.sh/turbo-engine/singleton.html
87///
88/// ## Deriving `TaskInput`
89///
90/// Structs or enums can be made into task inputs by deriving `TaskInput`:
91///
92/// ```rust
93/// #[turbo_tasks::task_input]
94/// struct MyStruct {
95///     // Fields go here...
96/// }
97/// ```
98///
99/// Derived `TaskInput` types **passed by value**. When called, arguments are moved into a `Box`,
100/// and then cloned before being passed into the function. If the task is invalidated, the
101/// `TaskInput` is cloned again to allow the function to be re-executed. It's recommended to ensure
102/// that these types are cheap to clone.
103///
104/// Reference-counted types like [`Arc`] are cheap to clone, but each reference contained in a
105/// `TaskInput` will be serialized independently in the persistent cache, and may consume extra disk
106/// space. If an [`Arc`] points to a large type, consider wrapping that type in [`Vc`], so that only
107/// one copy of the value will be serialized.
108pub trait TaskInput:
109    Send + Sync + Clone + Debug + PartialEq + Eq + Hash + TraceRawVcs + Encode + Decode<()>
110{
111    /// This method should resolve any [`Vc`]s nested inside of this object, cloning the object in
112    /// the process. If the input is unresolved ([`TaskInput::is_resolved`]) a "local" resolution
113    /// task is created that runs this method.
114    fn resolve_input(&self) -> impl Future<Output = Result<Self>> + Send + '_ {
115        CloneReady { inner: Some(self) }
116    }
117
118    /// This should return `true` if there are any unresolved [`Vc`]s in the type.
119    ///
120    /// Note that [`Vc`]s can sometimes be internally resolved, so you should call
121    /// [`Vc::is_resolved`] (or rely on the derive macro for this trait) instead of returning `true`
122    /// for any [`Vc`]. [`ResolvedVc::is_resolved`] always returns `true`.
123    ///
124    /// If this returns `true`, a "local" resolution task calling [`TaskInput::resolve_input`] will
125    /// be spawned before the function accepting the arguments is run.
126    ///
127    /// If this returns `false`, the `TaskInput` will be [cloned][Clone] instead of resolved, and
128    /// the function's task will be spawned directly without a resolution step.
129    fn is_resolved(&self) -> bool {
130        true
131    }
132
133    /// This should return true if this object contains a [`Vc`] (or any subtype of [`Vc`]) pointing
134    /// to a cell owned by a transient task.
135    ///
136    /// Any function called with a transient `TaskInput` will be transient. Any [`Vc`] constructed
137    /// in a transient task or in a top-level [`run_once`][crate::run_once] closure will be
138    /// transient.
139    ///
140    /// Internally, a [`Vc`] can be determined to be transient by comparing the owning task's id
141    /// with the [`TRANSIENT_TASK_BIT`][crate::TRANSIENT_TASK_BIT] mask.
142    fn is_transient(&self) -> bool;
143}
144
145macro_rules! impl_task_input {
146    ($($t:ty),*) => {
147        $(
148            impl TaskInput for $t {
149                fn is_transient(&self) -> bool {
150                    false
151                }
152            }
153        )*
154    };
155}
156
157impl_task_input! {
158    (),
159    bool,
160    u8,
161    u16,
162    u32,
163    i32,
164    u64,
165    usize,
166    RcStr,
167    TaskId,
168    ValueTypeId,
169    Duration,
170    String,
171    HashAlgorithm
172}
173
174impl<T> TaskInput for Vec<T>
175where
176    T: TaskInput,
177{
178    fn is_resolved(&self) -> bool {
179        self.iter().all(TaskInput::is_resolved)
180    }
181
182    fn is_transient(&self) -> bool {
183        self.iter().any(TaskInput::is_transient)
184    }
185
186    async fn resolve_input(&self) -> Result<Self> {
187        let mut resolved = Vec::with_capacity(self.len());
188        for value in self {
189            resolved.push(value.resolve_input().await?);
190        }
191        Ok(resolved)
192    }
193}
194
195impl<T> TaskInput for Box<T>
196where
197    T: TaskInput,
198{
199    fn is_resolved(&self) -> bool {
200        self.as_ref().is_resolved()
201    }
202
203    fn is_transient(&self) -> bool {
204        self.as_ref().is_transient()
205    }
206
207    async fn resolve_input(&self) -> Result<Self> {
208        Ok(Box::new(Box::pin(self.as_ref().resolve_input()).await?))
209    }
210}
211
212impl<T> TaskInput for Arc<T>
213where
214    T: TaskInput,
215{
216    fn is_resolved(&self) -> bool {
217        self.as_ref().is_resolved()
218    }
219
220    fn is_transient(&self) -> bool {
221        self.as_ref().is_transient()
222    }
223
224    async fn resolve_input(&self) -> Result<Self> {
225        Ok(Arc::new(Box::pin(self.as_ref().resolve_input()).await?))
226    }
227}
228
229impl<T> TaskInput for ReadRef<T>
230where
231    T: TaskInput,
232{
233    fn is_resolved(&self) -> bool {
234        Self::as_raw_ref(self).is_resolved()
235    }
236
237    fn is_transient(&self) -> bool {
238        Self::as_raw_ref(self).is_transient()
239    }
240
241    async fn resolve_input(&self) -> Result<Self> {
242        Ok(ReadRef::new_owned(
243            Box::pin(Self::as_raw_ref(self).resolve_input()).await?,
244        ))
245    }
246}
247
248impl<T> TaskInput for Option<T>
249where
250    T: TaskInput,
251{
252    fn is_resolved(&self) -> bool {
253        match self {
254            Some(value) => value.is_resolved(),
255            None => true,
256        }
257    }
258
259    fn is_transient(&self) -> bool {
260        match self {
261            Some(value) => value.is_transient(),
262            None => false,
263        }
264    }
265
266    async fn resolve_input(&self) -> Result<Self> {
267        match self {
268            Some(value) => Ok(Some(value.resolve_input().await?)),
269            None => Ok(None),
270        }
271    }
272}
273
274impl<T> TaskInput for Vc<T>
275where
276    T: Send + Sync + ?Sized,
277{
278    fn is_resolved(&self) -> bool {
279        Vc::is_resolved(*self)
280    }
281
282    fn is_transient(&self) -> bool {
283        self.node.is_transient()
284    }
285
286    fn resolve_input(&self) -> impl Future<Output = Result<Self>> + Send + '_ {
287        // It isn't ideal to use this function but it exactly matches this usecase (resolved but
288        // still a Vc)
289        (*self).resolve()
290    }
291}
292
293// `TaskInput` isn't needed/used for a bare `ResolvedVc`, as we'll expose `ResolvedVc` arguments as
294// `Vc`, but it is useful for structs that contain `ResolvedVc` and want to derive `TaskInput`.
295impl<T> TaskInput for ResolvedVc<T>
296where
297    T: Send + Sync + ?Sized,
298{
299    fn is_resolved(&self) -> bool {
300        true
301    }
302
303    fn is_transient(&self) -> bool {
304        self.node.is_transient()
305    }
306}
307
308impl<T> TaskInput for TransientValue<T>
309where
310    T: DynTaskInputs + Clone + Debug + Hash + Eq + TraceRawVcs + 'static,
311{
312    fn is_transient(&self) -> bool {
313        true
314    }
315}
316
317impl<T> Encode for TransientValue<T> {
318    fn encode<E: Encoder>(&self, _encoder: &mut E) -> Result<(), EncodeError> {
319        Err(EncodeError::Other("cannot encode transient task inputs"))
320    }
321}
322
323impl<Context, T> Decode<Context> for TransientValue<T> {
324    fn decode<D: Decoder<Context = Context>>(_decoder: &mut D) -> Result<Self, DecodeError> {
325        Err(DecodeError::Other("cannot decode transient task inputs"))
326    }
327}
328
329impl<T> TaskInput for TransientInstance<T>
330where
331    T: Sync + Send + TraceRawVcs + 'static,
332{
333    fn is_transient(&self) -> bool {
334        true
335    }
336}
337
338impl<T> Encode for TransientInstance<T> {
339    fn encode<E: Encoder>(&self, _encoder: &mut E) -> Result<(), EncodeError> {
340        Err(EncodeError::Other("cannot encode transient task inputs"))
341    }
342}
343
344impl<Context, T> Decode<Context> for TransientInstance<T> {
345    fn decode<D: Decoder<Context = Context>>(_decoder: &mut D) -> Result<Self, DecodeError> {
346        Err(DecodeError::Other("cannot decode transient task inputs"))
347    }
348}
349
350impl<K, V> TaskInput for BTreeMap<K, V>
351where
352    K: TaskInput + Ord,
353    V: TaskInput,
354{
355    async fn resolve_input(&self) -> Result<Self> {
356        let mut new_map = BTreeMap::new();
357        for (k, v) in self {
358            new_map.insert(
359                TaskInput::resolve_input(k).await?,
360                TaskInput::resolve_input(v).await?,
361            );
362        }
363        Ok(new_map)
364    }
365
366    fn is_resolved(&self) -> bool {
367        self.iter()
368            .all(|(k, v)| TaskInput::is_resolved(k) && TaskInput::is_resolved(v))
369    }
370
371    fn is_transient(&self) -> bool {
372        self.iter()
373            .any(|(k, v)| TaskInput::is_transient(k) || TaskInput::is_transient(v))
374    }
375}
376
377impl<T> TaskInput for BTreeSet<T>
378where
379    T: TaskInput + Ord,
380{
381    async fn resolve_input(&self) -> Result<Self> {
382        let mut new_set = BTreeSet::new();
383        for value in self {
384            new_set.insert(TaskInput::resolve_input(value).await?);
385        }
386        Ok(new_set)
387    }
388
389    fn is_resolved(&self) -> bool {
390        self.iter().all(TaskInput::is_resolved)
391    }
392
393    fn is_transient(&self) -> bool {
394        self.iter().any(TaskInput::is_transient)
395    }
396}
397
398impl<K, V> TaskInput for FrozenMap<K, V>
399where
400    K: TaskInput + Ord + 'static,
401    V: TaskInput + 'static,
402{
403    async fn resolve_input(&self) -> Result<Self> {
404        let mut new_entries = Vec::with_capacity(self.len());
405        for (k, v) in self {
406            new_entries.push((
407                TaskInput::resolve_input(k).await?,
408                TaskInput::resolve_input(v).await?,
409            ));
410        }
411        // note: resolving might deduplicate `Vc`s in keys
412        Ok(Self::from(new_entries))
413    }
414
415    fn is_resolved(&self) -> bool {
416        self.iter()
417            .all(|(k, v)| TaskInput::is_resolved(k) && TaskInput::is_resolved(v))
418    }
419
420    fn is_transient(&self) -> bool {
421        self.iter()
422            .any(|(k, v)| TaskInput::is_transient(k) || TaskInput::is_transient(v))
423    }
424}
425
426impl<T> TaskInput for FrozenSet<T>
427where
428    T: TaskInput + Ord + 'static,
429{
430    async fn resolve_input(&self) -> Result<Self> {
431        let mut new_set = Vec::with_capacity(self.len());
432        for value in self {
433            new_set.push(TaskInput::resolve_input(value).await?);
434        }
435        Ok(Self::from_iter(new_set))
436    }
437
438    fn is_resolved(&self) -> bool {
439        self.iter().all(TaskInput::is_resolved)
440    }
441
442    fn is_transient(&self) -> bool {
443        self.iter().any(TaskInput::is_transient)
444    }
445}
446
447/// A thin wrapper around [`Either`] that implements the traits required by [`TaskInput`], notably
448/// [`Encode`] and [`Decode`].
449#[derive(Clone, Debug, PartialEq, Eq, Hash, TraceRawVcs)]
450pub struct EitherTaskInput<L, R>(pub Either<L, R>);
451
452impl<L, R> Deref for EitherTaskInput<L, R> {
453    type Target = Either<L, R>;
454
455    fn deref(&self) -> &Self::Target {
456        &self.0
457    }
458}
459
460impl<L, R> DerefMut for EitherTaskInput<L, R> {
461    fn deref_mut(&mut self) -> &mut Self::Target {
462        &mut self.0
463    }
464}
465
466impl<L, R> Encode for EitherTaskInput<L, R>
467where
468    L: Encode,
469    R: Encode,
470{
471    fn encode<E: Encoder>(&self, encoder: &mut E) -> Result<(), EncodeError> {
472        turbo_bincode::either::encode(self, encoder)
473    }
474}
475
476impl<Context, L, R> Decode<Context> for EitherTaskInput<L, R>
477where
478    L: Decode<Context>,
479    R: Decode<Context>,
480{
481    fn decode<D: Decoder<Context = Context>>(decoder: &mut D) -> Result<Self, DecodeError> {
482        turbo_bincode::either::decode(decoder).map(Self)
483    }
484}
485
486impl<L, R> TaskInput for EitherTaskInput<L, R>
487where
488    L: TaskInput,
489    R: TaskInput,
490{
491    fn resolve_input(&self) -> impl Future<Output = Result<Self>> + Send + '_ {
492        self.as_ref().map_either(
493            |l| async move { anyhow::Ok(Self(Either::Left(l.resolve_input().await?))) },
494            |r| async move { anyhow::Ok(Self(Either::Right(r.resolve_input().await?))) },
495        )
496    }
497
498    fn is_resolved(&self) -> bool {
499        self.as_ref()
500            .either(TaskInput::is_resolved, TaskInput::is_resolved)
501    }
502
503    fn is_transient(&self) -> bool {
504        self.as_ref()
505            .either(TaskInput::is_transient, TaskInput::is_transient)
506    }
507}
508
509macro_rules! tuple_impls {
510    ( $( $name:ident )+ ) => {
511        impl<$($name: TaskInput),+> TaskInput for ($($name,)+)
512        where $($name: TaskInput),+
513        {
514            #[allow(non_snake_case)]
515            fn is_resolved(&self) -> bool {
516                let ($($name,)+) = self;
517                $($name.is_resolved() &&)+ true
518            }
519
520            #[allow(non_snake_case)]
521            fn is_transient(&self) -> bool {
522                let ($($name,)+) = self;
523                $($name.is_transient() ||)+ false
524            }
525
526            #[allow(non_snake_case)]
527            async fn resolve_input(&self) -> Result<Self> {
528                let ($($name,)+) = self;
529                Ok(($($name.resolve_input().await?,)+))
530            }
531        }
532    };
533}
534
535// Implement `TaskInput` for all tuples of 1 to 12 elements.
536tuple_impls! { A }
537tuple_impls! { A B }
538tuple_impls! { A B C }
539tuple_impls! { A B C D }
540tuple_impls! { A B C D E }
541tuple_impls! { A B C D E F }
542tuple_impls! { A B C D E F G }
543tuple_impls! { A B C D E F G H }
544tuple_impls! { A B C D E F G H I }
545tuple_impls! { A B C D E F G H I J }
546tuple_impls! { A B C D E F G H I J K }
547tuple_impls! { A B C D E F G H I J K L }
548
#[cfg(test)]
mod tests {
    use turbo_rcstr::rcstr;

    use super::*;

    /// Compiles only when `T: TaskInput`; the value itself is discarded.
    fn assert_task_input<T: TaskInput>(_value: T) {}

    // Module-level fixtures, reused by `test_nested_variants`.

    #[turbo_tasks::task_input]
    #[derive(Clone, Eq, PartialEq, Hash, Debug, Encode, Decode, TraceRawVcs)]
    enum OneVariant {
        Variant,
    }

    #[turbo_tasks::task_input]
    #[derive(Clone, Eq, PartialEq, Hash, Debug, Encode, Decode, TraceRawVcs)]
    enum MultipleVariantsAndHeterogeneousFields {
        Variant1,
        Variant2(u32),
        Variant3 { named: u32 },
        Variant4(u32, RcStr),
        Variant5 { named: u32, other: RcStr },
    }

    #[test]
    fn test_no_fields() {
        #[turbo_tasks::task_input]
        #[derive(Clone, Eq, PartialEq, Hash, Debug, Encode, Decode, TraceRawVcs)]
        struct NoFields;

        assert_task_input(NoFields);
    }

    #[test]
    fn test_one_unnamed_field() {
        #[turbo_tasks::task_input]
        #[derive(Clone, Eq, PartialEq, Hash, Debug, Encode, Decode, TraceRawVcs)]
        struct OneUnnamedField(u32);

        assert_task_input(OneUnnamedField(42));
    }

    #[test]
    fn test_multiple_unnamed_fields() {
        #[turbo_tasks::task_input]
        #[derive(Clone, Eq, PartialEq, Hash, Debug, Encode, Decode, TraceRawVcs)]
        struct MultipleUnnamedFields(u32, RcStr);

        assert_task_input(MultipleUnnamedFields(42, rcstr!("42")));
    }

    #[test]
    fn test_one_named_field() {
        #[turbo_tasks::task_input]
        #[derive(Clone, Eq, PartialEq, Hash, Debug, Encode, Decode, TraceRawVcs)]
        struct OneNamedField {
            named: u32,
        }

        assert_task_input(OneNamedField { named: 42 });
    }

    #[test]
    fn test_multiple_named_fields() {
        #[turbo_tasks::task_input]
        #[derive(Clone, Eq, PartialEq, Hash, Debug, Encode, Decode, TraceRawVcs)]
        struct MultipleNamedFields {
            named: u32,
            other: RcStr,
        }

        let value = MultipleNamedFields {
            named: 42,
            other: rcstr!("42"),
        };
        assert_task_input(value);
    }

    #[test]
    fn test_generic_field() {
        #[turbo_tasks::task_input]
        #[derive(Clone, Eq, PartialEq, Hash, Debug, Encode, Decode, TraceRawVcs)]
        struct GenericField<T>(T);

        // The same generic type must work for different concrete parameters.
        assert_task_input(GenericField(42));
        assert_task_input(GenericField(rcstr!("42")));
    }

    #[test]
    fn test_one_variant() {
        assert_task_input(OneVariant::Variant);
    }

    #[test]
    fn test_multiple_variants() {
        #[turbo_tasks::task_input]
        #[derive(Clone, PartialEq, Eq, Hash, Debug, Encode, Decode, TraceRawVcs)]
        enum MultipleVariants {
            Variant1,
            Variant2,
        }

        assert_task_input(MultipleVariants::Variant2);
    }

    #[test]
    fn test_multiple_variants_and_heterogeneous_fields() {
        let value = MultipleVariantsAndHeterogeneousFields::Variant5 {
            named: 42,
            other: rcstr!("42"),
        };
        assert_task_input(value);
    }

    #[test]
    fn test_nested_variants() {
        #[turbo_tasks::task_input]
        #[derive(Clone, Eq, PartialEq, Hash, Debug, Encode, Decode, TraceRawVcs)]
        enum NestedVariants {
            Variant1,
            Variant2(MultipleVariantsAndHeterogeneousFields),
            Variant3 { named: OneVariant },
            Variant4(OneVariant, RcStr),
            Variant5 { named: OneVariant, other: RcStr },
        }

        let with_named_fields = NestedVariants::Variant5 {
            named: OneVariant::Variant,
            other: rcstr!("42"),
        };
        assert_task_input(with_named_fields);

        let with_nested_enum =
            NestedVariants::Variant2(MultipleVariantsAndHeterogeneousFields::Variant5 {
                named: 42,
                other: rcstr!("42"),
            });
        assert_task_input(with_nested_enum);
    }
}