Skip to main content

turbo_tasks/task/
task_input.rs

1use std::{
2    collections::{BTreeMap, BTreeSet},
3    fmt::Debug,
4    future::Future,
5    hash::Hash,
6    ops::{Deref, DerefMut},
7    pin::Pin,
8    sync::Arc,
9    task::{Context, Poll},
10    time::Duration,
11};
12
13use anyhow::Result;
14use bincode::{
15    Decode, Encode,
16    de::Decoder,
17    enc::Encoder,
18    error::{DecodeError, EncodeError},
19};
20use either::Either;
21use turbo_frozenmap::{FrozenMap, FrozenSet};
22use turbo_rcstr::RcStr;
23use turbo_tasks_hash::HashAlgorithm;
24
25// This import is necessary for derive macros to work, as their expansion refers to the crate
26// name directly.
27use crate::{self as turbo_tasks, ReadRef};
28use crate::{
29    DynTaskInputs, ResolvedVc, TaskId, TransientInstance, TransientValue, ValueTypeId, Vc,
30    trace::TraceRawVcs,
31};
32
33/// An 8-byte hand-rolled [`Future`] that immediately resolves to `Ok(self.clone())` of the
34/// referenced value.
35///
36/// Used by the [`TaskInput::resolve_input`] default implementation
37struct CloneReady<'a, T> {
38    pub inner: Option<&'a T>,
39}
40
41impl<'a, T: Clone> Future for CloneReady<'a, T> {
42    type Output = Result<T>;
43
44    fn poll(mut self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Self::Output> {
45        Poll::Ready(Ok(self
46            .inner
47            .take()
48            .expect("future already polled to completion")
49            .clone()))
50    }
51}
52
53// `CloneReady` holds only a shared reference; it has no self-referential state.
54impl<'a, T> Unpin for CloneReady<'a, T> {}
55
56/// Trait to implement in order for a type to be accepted as a
57/// [`#[turbo_tasks::function]`][crate::function] argument.
58///
59/// ## Serialization
60///
61/// For persistent caching of a task, arguments must be serializable. All `TaskInput`s must
62/// implement the bincode [`Encode`] and [`Decode`] traits.
63///
64/// Transient task inputs are required to implement [`Encode`] and [`Decode`], but are allowed to
65/// panic at runtime. This requirement could be lifted in the future.
66///
67/// Bincode encoding must be deterministic and compatible with [`Eq`] comparisons. If two
68/// `TaskInput`s compare equal they must also encode to the same bytes.
69///
70/// ## Hash and Eq
71///
72/// Arguments are used as part of keys in a `HashMap`, so they must implement of [`PartialEq`],
73/// [`Eq`], and [`Hash`] traits.
74///
75/// ## [`Vc<T>`][Vc]
76///
77/// A [`Vc`] is a pointer to a cell. It implements `TaskInput` and serves as a "pass by reference"
78/// argument:
79///
80/// - **Memoization**: [`Vc`] is keyed by pointer for memoization purposes. Identical values in
81///   different cells are treated as distinct.
82/// - **Singleton Pattern**: To ensure memoization efficiency, the singleton pattern can be employed
83///   to guarantee that identical values yield the same `Vc`. For more info see [Singleton Pattern
84///   Guide][singleton].
85///
86/// [singleton]: https://turbopack-rust-docs.vercel.sh/turbo-engine/singleton.html
87///
88/// ## Deriving `TaskInput`
89///
90/// Structs or enums can be made into task inputs by deriving `TaskInput`:
91///
92/// ```rust
93/// #[turbo_tasks::task_input]
94/// struct MyStruct {
95///     // Fields go here...
96/// }
97/// ```
98///
99/// Derived `TaskInput` types **passed by value**. When called, arguments are moved into a `Box`,
100/// and then cloned before being passed into the function. If the task is invalidated, the
101/// `TaskInput` is cloned again to allow the function to be re-executed. It's recommended to ensure
102/// that these types are cheap to clone.
103///
104/// Reference-counted types like [`Arc`] are cheap to clone, but each reference contained in a
105/// `TaskInput` will be serialized independently in the persistent cache, and may consume extra disk
106/// space. If an [`Arc`] points to a large type, consider wrapping that type in [`Vc`], so that only
107/// one copy of the value will be serialized.
108pub trait TaskInput:
109    Send + Sync + Clone + Debug + PartialEq + Eq + Hash + TraceRawVcs + Encode + Decode<()>
110{
111    /// This method should resolve any [`Vc`]s nested inside of this object, cloning the object in
112    /// the process. If the input is unresolved ([`TaskInput::is_resolved`]) a "local" resolution
113    /// task is created that runs this method.
114    fn resolve_input(&self) -> impl Future<Output = Result<Self>> + Send + '_ {
115        CloneReady { inner: Some(self) }
116    }
117
118    /// This should return `true` if there are any unresolved [`Vc`]s in the type.
119    ///
120    /// Note that [`Vc`]s can sometimes be internally resolved, so you should call
121    /// [`Vc::is_resolved`] (or rely on the derive macro for this trait) instead of returning `true`
122    /// for any [`Vc`]. [`ResolvedVc::is_resolved`] always returns `true`.
123    ///
124    /// If this returns `true`, a "local" resolution task calling [`TaskInput::resolve_input`] will
125    /// be spawned before the function accepting the arguments is run.
126    ///
127    /// If this returns `false`, the `TaskInput` will be [cloned][Clone] instead of resolved, and
128    /// the function's task will be spawned directly without a resolution step.
129    fn is_resolved(&self) -> bool {
130        true
131    }
132
133    /// This should return true if this object contains a [`Vc`] (or any subtype of [`Vc`]) pointing
134    /// to a cell owned by a transient task.
135    ///
136    /// Any function called with a transient `TaskInput` will be transient. Any [`Vc`] constructed
137    /// in a transient task or in a top-level [`run_once`][crate::run_once] closure will be
138    /// transient.
139    ///
140    /// Internally, a [`Vc`] can be determined to be transient by comparing the owning task's id
141    /// with the [`TRANSIENT_TASK_BIT`][crate::TRANSIENT_TASK_BIT] mask.
142    fn is_transient(&self) -> bool;
143}
144
145macro_rules! impl_task_input {
146    ($($t:ty),*) => {
147        $(
148            impl TaskInput for $t {
149                fn is_transient(&self) -> bool {
150                    false
151                }
152            }
153        )*
154    };
155}
156
157impl_task_input! {
158    (),
159    bool,
160    u8,
161    u16,
162    u32,
163    i32,
164    u64,
165    u128,
166    usize,
167    RcStr,
168    TaskId,
169    ValueTypeId,
170    Duration,
171    String,
172    HashAlgorithm
173}
174
175impl<T> TaskInput for Vec<T>
176where
177    T: TaskInput,
178{
179    fn is_resolved(&self) -> bool {
180        self.iter().all(TaskInput::is_resolved)
181    }
182
183    fn is_transient(&self) -> bool {
184        self.iter().any(TaskInput::is_transient)
185    }
186
187    async fn resolve_input(&self) -> Result<Self> {
188        let mut resolved = Vec::with_capacity(self.len());
189        for value in self {
190            resolved.push(value.resolve_input().await?);
191        }
192        Ok(resolved)
193    }
194}
195
196impl<T> TaskInput for Box<T>
197where
198    T: TaskInput,
199{
200    fn is_resolved(&self) -> bool {
201        self.as_ref().is_resolved()
202    }
203
204    fn is_transient(&self) -> bool {
205        self.as_ref().is_transient()
206    }
207
208    async fn resolve_input(&self) -> Result<Self> {
209        Ok(Box::new(Box::pin(self.as_ref().resolve_input()).await?))
210    }
211}
212
213impl<T> TaskInput for Arc<T>
214where
215    T: TaskInput,
216{
217    fn is_resolved(&self) -> bool {
218        self.as_ref().is_resolved()
219    }
220
221    fn is_transient(&self) -> bool {
222        self.as_ref().is_transient()
223    }
224
225    async fn resolve_input(&self) -> Result<Self> {
226        Ok(Arc::new(Box::pin(self.as_ref().resolve_input()).await?))
227    }
228}
229
230impl<T> TaskInput for ReadRef<T>
231where
232    T: TaskInput,
233{
234    fn is_resolved(&self) -> bool {
235        Self::as_raw_ref(self).is_resolved()
236    }
237
238    fn is_transient(&self) -> bool {
239        Self::as_raw_ref(self).is_transient()
240    }
241
242    async fn resolve_input(&self) -> Result<Self> {
243        Ok(ReadRef::new_owned(
244            Box::pin(Self::as_raw_ref(self).resolve_input()).await?,
245        ))
246    }
247}
248
249impl<T> TaskInput for Option<T>
250where
251    T: TaskInput,
252{
253    fn is_resolved(&self) -> bool {
254        match self {
255            Some(value) => value.is_resolved(),
256            None => true,
257        }
258    }
259
260    fn is_transient(&self) -> bool {
261        match self {
262            Some(value) => value.is_transient(),
263            None => false,
264        }
265    }
266
267    async fn resolve_input(&self) -> Result<Self> {
268        match self {
269            Some(value) => Ok(Some(value.resolve_input().await?)),
270            None => Ok(None),
271        }
272    }
273}
274
275impl<T> TaskInput for Vc<T>
276where
277    T: Send + Sync + ?Sized,
278{
279    fn is_resolved(&self) -> bool {
280        Vc::is_resolved(*self)
281    }
282
283    fn is_transient(&self) -> bool {
284        self.node.is_transient()
285    }
286
287    fn resolve_input(&self) -> impl Future<Output = Result<Self>> + Send + '_ {
288        // It isn't ideal to use this function but it exactly matches this usecase (resolved but
289        // still a Vc)
290        (*self).resolve()
291    }
292}
293
294// `TaskInput` isn't needed/used for a bare `ResolvedVc`, as we'll expose `ResolvedVc` arguments as
295// `Vc`, but it is useful for structs that contain `ResolvedVc` and want to derive `TaskInput`.
296impl<T> TaskInput for ResolvedVc<T>
297where
298    T: Send + Sync + ?Sized,
299{
300    fn is_resolved(&self) -> bool {
301        true
302    }
303
304    fn is_transient(&self) -> bool {
305        self.node.is_transient()
306    }
307}
308
309impl<T> TaskInput for TransientValue<T>
310where
311    T: DynTaskInputs + Clone + Debug + Hash + Eq + TraceRawVcs + 'static,
312{
313    fn is_transient(&self) -> bool {
314        true
315    }
316}
317
318impl<T> Encode for TransientValue<T> {
319    fn encode<E: Encoder>(&self, _encoder: &mut E) -> Result<(), EncodeError> {
320        Err(EncodeError::Other("cannot encode transient task inputs"))
321    }
322}
323
324impl<Context, T> Decode<Context> for TransientValue<T> {
325    fn decode<D: Decoder<Context = Context>>(_decoder: &mut D) -> Result<Self, DecodeError> {
326        Err(DecodeError::Other("cannot decode transient task inputs"))
327    }
328}
329
330impl<T> TaskInput for TransientInstance<T>
331where
332    T: Sync + Send + TraceRawVcs + 'static,
333{
334    fn is_transient(&self) -> bool {
335        true
336    }
337}
338
339impl<T> Encode for TransientInstance<T> {
340    fn encode<E: Encoder>(&self, _encoder: &mut E) -> Result<(), EncodeError> {
341        Err(EncodeError::Other("cannot encode transient task inputs"))
342    }
343}
344
345impl<Context, T> Decode<Context> for TransientInstance<T> {
346    fn decode<D: Decoder<Context = Context>>(_decoder: &mut D) -> Result<Self, DecodeError> {
347        Err(DecodeError::Other("cannot decode transient task inputs"))
348    }
349}
350
351impl<K, V> TaskInput for BTreeMap<K, V>
352where
353    K: TaskInput + Ord,
354    V: TaskInput,
355{
356    async fn resolve_input(&self) -> Result<Self> {
357        let mut new_map = BTreeMap::new();
358        for (k, v) in self {
359            new_map.insert(
360                TaskInput::resolve_input(k).await?,
361                TaskInput::resolve_input(v).await?,
362            );
363        }
364        Ok(new_map)
365    }
366
367    fn is_resolved(&self) -> bool {
368        self.iter()
369            .all(|(k, v)| TaskInput::is_resolved(k) && TaskInput::is_resolved(v))
370    }
371
372    fn is_transient(&self) -> bool {
373        self.iter()
374            .any(|(k, v)| TaskInput::is_transient(k) || TaskInput::is_transient(v))
375    }
376}
377
378impl<T> TaskInput for BTreeSet<T>
379where
380    T: TaskInput + Ord,
381{
382    async fn resolve_input(&self) -> Result<Self> {
383        let mut new_set = BTreeSet::new();
384        for value in self {
385            new_set.insert(TaskInput::resolve_input(value).await?);
386        }
387        Ok(new_set)
388    }
389
390    fn is_resolved(&self) -> bool {
391        self.iter().all(TaskInput::is_resolved)
392    }
393
394    fn is_transient(&self) -> bool {
395        self.iter().any(TaskInput::is_transient)
396    }
397}
398
399impl<K, V> TaskInput for FrozenMap<K, V>
400where
401    K: TaskInput + Ord + 'static,
402    V: TaskInput + 'static,
403{
404    async fn resolve_input(&self) -> Result<Self> {
405        let mut new_entries = Vec::with_capacity(self.len());
406        for (k, v) in self {
407            new_entries.push((
408                TaskInput::resolve_input(k).await?,
409                TaskInput::resolve_input(v).await?,
410            ));
411        }
412        // note: resolving might deduplicate `Vc`s in keys
413        Ok(Self::from(new_entries))
414    }
415
416    fn is_resolved(&self) -> bool {
417        self.iter()
418            .all(|(k, v)| TaskInput::is_resolved(k) && TaskInput::is_resolved(v))
419    }
420
421    fn is_transient(&self) -> bool {
422        self.iter()
423            .any(|(k, v)| TaskInput::is_transient(k) || TaskInput::is_transient(v))
424    }
425}
426
427impl<T> TaskInput for FrozenSet<T>
428where
429    T: TaskInput + Ord + 'static,
430{
431    async fn resolve_input(&self) -> Result<Self> {
432        let mut new_set = Vec::with_capacity(self.len());
433        for value in self {
434            new_set.push(TaskInput::resolve_input(value).await?);
435        }
436        Ok(Self::from_iter(new_set))
437    }
438
439    fn is_resolved(&self) -> bool {
440        self.iter().all(TaskInput::is_resolved)
441    }
442
443    fn is_transient(&self) -> bool {
444        self.iter().any(TaskInput::is_transient)
445    }
446}
447
448/// A thin wrapper around [`Either`] that implements the traits required by [`TaskInput`], notably
449/// [`Encode`] and [`Decode`].
450#[derive(Clone, Debug, PartialEq, Eq, Hash, TraceRawVcs)]
451pub struct EitherTaskInput<L, R>(pub Either<L, R>);
452
453impl<L, R> Deref for EitherTaskInput<L, R> {
454    type Target = Either<L, R>;
455
456    fn deref(&self) -> &Self::Target {
457        &self.0
458    }
459}
460
461impl<L, R> DerefMut for EitherTaskInput<L, R> {
462    fn deref_mut(&mut self) -> &mut Self::Target {
463        &mut self.0
464    }
465}
466
467impl<L, R> Encode for EitherTaskInput<L, R>
468where
469    L: Encode,
470    R: Encode,
471{
472    fn encode<E: Encoder>(&self, encoder: &mut E) -> Result<(), EncodeError> {
473        turbo_bincode::either::encode(self, encoder)
474    }
475}
476
477impl<Context, L, R> Decode<Context> for EitherTaskInput<L, R>
478where
479    L: Decode<Context>,
480    R: Decode<Context>,
481{
482    fn decode<D: Decoder<Context = Context>>(decoder: &mut D) -> Result<Self, DecodeError> {
483        turbo_bincode::either::decode(decoder).map(Self)
484    }
485}
486
487impl<L, R> TaskInput for EitherTaskInput<L, R>
488where
489    L: TaskInput,
490    R: TaskInput,
491{
492    fn resolve_input(&self) -> impl Future<Output = Result<Self>> + Send + '_ {
493        self.as_ref().map_either(
494            |l| async move { anyhow::Ok(Self(Either::Left(l.resolve_input().await?))) },
495            |r| async move { anyhow::Ok(Self(Either::Right(r.resolve_input().await?))) },
496        )
497    }
498
499    fn is_resolved(&self) -> bool {
500        self.as_ref()
501            .either(TaskInput::is_resolved, TaskInput::is_resolved)
502    }
503
504    fn is_transient(&self) -> bool {
505        self.as_ref()
506            .either(TaskInput::is_transient, TaskInput::is_transient)
507    }
508}
509
510macro_rules! tuple_impls {
511    ( $( $name:ident )+ ) => {
512        impl<$($name: TaskInput),+> TaskInput for ($($name,)+)
513        where $($name: TaskInput),+
514        {
515            #[allow(non_snake_case)]
516            fn is_resolved(&self) -> bool {
517                let ($($name,)+) = self;
518                $($name.is_resolved() &&)+ true
519            }
520
521            #[allow(non_snake_case)]
522            fn is_transient(&self) -> bool {
523                let ($($name,)+) = self;
524                $($name.is_transient() ||)+ false
525            }
526
527            #[allow(non_snake_case)]
528            async fn resolve_input(&self) -> Result<Self> {
529                let ($($name,)+) = self;
530                Ok(($($name.resolve_input().await?,)+))
531            }
532        }
533    };
534}
535
536// Implement `TaskInput` for all tuples of 1 to 12 elements.
537tuple_impls! { A }
538tuple_impls! { A B }
539tuple_impls! { A B C }
540tuple_impls! { A B C D }
541tuple_impls! { A B C D E }
542tuple_impls! { A B C D E F }
543tuple_impls! { A B C D E F G }
544tuple_impls! { A B C D E F G H }
545tuple_impls! { A B C D E F G H I }
546tuple_impls! { A B C D E F G H I J }
547tuple_impls! { A B C D E F G H I J K }
548tuple_impls! { A B C D E F G H I J K L }
549
#[cfg(test)]
mod tests {
    // Compile-time checks that `#[turbo_tasks::task_input]` yields a usable `TaskInput`
    // implementation for structs and enums of various shapes. Each test passes as long as it
    // compiles and `assert_task_input` accepts the value.

    use turbo_rcstr::rcstr;

    use super::*;

    /// Statically asserts that `T` implements [`TaskInput`]; the value itself is discarded.
    fn assert_task_input<T: TaskInput>(_value: T) {}

    // Shared fixtures: these enums are also used as field types by `test_nested_variants`.

    #[turbo_tasks::task_input]
    #[derive(Clone, Eq, PartialEq, Hash, Debug, Encode, Decode, TraceRawVcs)]
    enum OneVariant {
        Variant,
    }

    #[turbo_tasks::task_input]
    #[derive(Clone, Eq, PartialEq, Hash, Debug, Encode, Decode, TraceRawVcs)]
    enum MultipleVariantsAndHeterogeneousFields {
        Variant1,
        Variant2(u32),
        Variant3 { named: u32 },
        Variant4(u32, RcStr),
        Variant5 { named: u32, other: RcStr },
    }

    #[test]
    fn test_no_fields() -> Result<()> {
        #[turbo_tasks::task_input]
        #[derive(Clone, Eq, PartialEq, Hash, Debug, Encode, Decode, TraceRawVcs)]
        struct NoFields;

        let input = NoFields;
        assert_task_input(input);
        Ok(())
    }

    #[test]
    fn test_one_unnamed_field() -> Result<()> {
        #[turbo_tasks::task_input]
        #[derive(Clone, Eq, PartialEq, Hash, Debug, Encode, Decode, TraceRawVcs)]
        struct OneUnnamedField(u32);

        let input = OneUnnamedField(42);
        assert_task_input(input);
        Ok(())
    }

    #[test]
    fn test_multiple_unnamed_fields() -> Result<()> {
        #[turbo_tasks::task_input]
        #[derive(Clone, Eq, PartialEq, Hash, Debug, Encode, Decode, TraceRawVcs)]
        struct MultipleUnnamedFields(u32, RcStr);

        let input = MultipleUnnamedFields(42, rcstr!("42"));
        assert_task_input(input);
        Ok(())
    }

    #[test]
    fn test_one_named_field() -> Result<()> {
        #[turbo_tasks::task_input]
        #[derive(Clone, Eq, PartialEq, Hash, Debug, Encode, Decode, TraceRawVcs)]
        struct OneNamedField {
            named: u32,
        }

        let input = OneNamedField { named: 42 };
        assert_task_input(input);
        Ok(())
    }

    #[test]
    fn test_multiple_named_fields() -> Result<()> {
        #[turbo_tasks::task_input]
        #[derive(Clone, Eq, PartialEq, Hash, Debug, Encode, Decode, TraceRawVcs)]
        struct MultipleNamedFields {
            named: u32,
            other: RcStr,
        }

        let input = MultipleNamedFields {
            named: 42,
            other: rcstr!("42"),
        };
        assert_task_input(input);
        Ok(())
    }

    #[test]
    fn test_generic_field() -> Result<()> {
        #[turbo_tasks::task_input]
        #[derive(Clone, Eq, PartialEq, Hash, Debug, Encode, Decode, TraceRawVcs)]
        struct GenericField<T>(T);

        // Instantiate the generic struct with two different `TaskInput` field types.
        assert_task_input(GenericField(42));
        assert_task_input(GenericField(rcstr!("42")));
        Ok(())
    }

    #[test]
    fn test_one_variant() -> Result<()> {
        let input = OneVariant::Variant;
        assert_task_input(input);
        Ok(())
    }

    #[test]
    fn test_multiple_variants() -> Result<()> {
        #[turbo_tasks::task_input]
        #[derive(Clone, PartialEq, Eq, Hash, Debug, Encode, Decode, TraceRawVcs)]
        enum MultipleVariants {
            Variant1,
            Variant2,
        }

        let input = MultipleVariants::Variant2;
        assert_task_input(input);
        Ok(())
    }

    #[test]
    fn test_multiple_variants_and_heterogeneous_fields() -> Result<()> {
        let input = MultipleVariantsAndHeterogeneousFields::Variant5 {
            named: 42,
            other: rcstr!("42"),
        };
        assert_task_input(input);
        Ok(())
    }

    #[test]
    fn test_nested_variants() -> Result<()> {
        #[turbo_tasks::task_input]
        #[derive(Clone, Eq, PartialEq, Hash, Debug, Encode, Decode, TraceRawVcs)]
        enum NestedVariants {
            Variant1,
            Variant2(MultipleVariantsAndHeterogeneousFields),
            Variant3 { named: OneVariant },
            Variant4(OneVariant, RcStr),
            Variant5 { named: OneVariant, other: RcStr },
        }

        let with_struct_variant = NestedVariants::Variant5 {
            named: OneVariant::Variant,
            other: rcstr!("42"),
        };
        assert_task_input(with_struct_variant);

        let with_nested_enum = NestedVariants::Variant2(
            MultipleVariantsAndHeterogeneousFields::Variant5 {
                named: 42,
                other: rcstr!("42"),
            },
        );
        assert_task_input(with_nested_enum);
        Ok(())
    }
}