Skip to main content

turbo_tasks/task/
task_input.rs

1use std::{
2    collections::{BTreeMap, BTreeSet},
3    fmt::Debug,
4    future::Future,
5    hash::Hash,
6    ops::{Deref, DerefMut},
7    sync::Arc,
8    time::Duration,
9};
10
11use anyhow::Result;
12use bincode::{
13    Decode, Encode,
14    de::Decoder,
15    enc::Encoder,
16    error::{DecodeError, EncodeError},
17};
18use either::Either;
19use turbo_frozenmap::{FrozenMap, FrozenSet};
20use turbo_rcstr::RcStr;
21use turbo_tasks_hash::HashAlgorithm;
22
23// This import is necessary for derive macros to work, as their expansion refers to the crate
24// name directly.
25use crate as turbo_tasks;
26use crate::{
27    MagicAny, ReadRef, ResolvedVc, TaskId, TransientInstance, TransientValue, ValueTypeId, Vc,
28    trace::TraceRawVcs,
29};
30
31/// Trait to implement in order for a type to be accepted as a
32/// [`#[turbo_tasks::function]`][crate::function] argument.
33///
34/// Transient task inputs are required to implement [`Encode`] and [`Decode`], but are allowed to
35/// panic at runtime. This requirement could be lifted in the future.
36///
37/// Bincode encoding must be deterministic and compatible with [`Eq`] comparisons. If two
38/// `TaskInput`s compare equal they must also encode to the same bytes.
39pub trait TaskInput:
40    Send + Sync + Clone + Debug + PartialEq + Eq + Hash + TraceRawVcs + Encode + Decode<()>
41{
42    fn resolve_input(&self) -> impl Future<Output = Result<Self>> + Send + '_ {
43        async { Ok(self.clone()) }
44    }
45    fn is_resolved(&self) -> bool {
46        true
47    }
48    fn is_transient(&self) -> bool;
49}
50
51macro_rules! impl_task_input {
52    ($($t:ty),*) => {
53        $(
54            impl TaskInput for $t {
55                fn is_transient(&self) -> bool {
56                    false
57                }
58            }
59        )*
60    };
61}
62
63impl_task_input! {
64    (),
65    bool,
66    u8,
67    u16,
68    u32,
69    i32,
70    u64,
71    usize,
72    RcStr,
73    TaskId,
74    ValueTypeId,
75    Duration,
76    String,
77    HashAlgorithm
78}
79
80impl<T> TaskInput for Vec<T>
81where
82    T: TaskInput,
83{
84    fn is_resolved(&self) -> bool {
85        self.iter().all(TaskInput::is_resolved)
86    }
87
88    fn is_transient(&self) -> bool {
89        self.iter().any(TaskInput::is_transient)
90    }
91
92    async fn resolve_input(&self) -> Result<Self> {
93        let mut resolved = Vec::with_capacity(self.len());
94        for value in self {
95            resolved.push(value.resolve_input().await?);
96        }
97        Ok(resolved)
98    }
99}
100
101impl<T> TaskInput for Box<T>
102where
103    T: TaskInput,
104{
105    fn is_resolved(&self) -> bool {
106        self.as_ref().is_resolved()
107    }
108
109    fn is_transient(&self) -> bool {
110        self.as_ref().is_transient()
111    }
112
113    async fn resolve_input(&self) -> Result<Self> {
114        Ok(Box::new(Box::pin(self.as_ref().resolve_input()).await?))
115    }
116}
117
118impl<T> TaskInput for Arc<T>
119where
120    T: TaskInput,
121{
122    fn is_resolved(&self) -> bool {
123        self.as_ref().is_resolved()
124    }
125
126    fn is_transient(&self) -> bool {
127        self.as_ref().is_transient()
128    }
129
130    async fn resolve_input(&self) -> Result<Self> {
131        Ok(Arc::new(Box::pin(self.as_ref().resolve_input()).await?))
132    }
133}
134
135impl<T> TaskInput for ReadRef<T>
136where
137    T: TaskInput,
138{
139    fn is_resolved(&self) -> bool {
140        Self::as_raw_ref(self).is_resolved()
141    }
142
143    fn is_transient(&self) -> bool {
144        Self::as_raw_ref(self).is_transient()
145    }
146
147    async fn resolve_input(&self) -> Result<Self> {
148        Ok(ReadRef::new_owned(
149            Box::pin(Self::as_raw_ref(self).resolve_input()).await?,
150        ))
151    }
152}
153
154impl<T> TaskInput for Option<T>
155where
156    T: TaskInput,
157{
158    fn is_resolved(&self) -> bool {
159        match self {
160            Some(value) => value.is_resolved(),
161            None => true,
162        }
163    }
164
165    fn is_transient(&self) -> bool {
166        match self {
167            Some(value) => value.is_transient(),
168            None => false,
169        }
170    }
171
172    async fn resolve_input(&self) -> Result<Self> {
173        match self {
174            Some(value) => Ok(Some(value.resolve_input().await?)),
175            None => Ok(None),
176        }
177    }
178}
179
180impl<T> TaskInput for Vc<T>
181where
182    T: Send + Sync + ?Sized,
183{
184    fn is_resolved(&self) -> bool {
185        Vc::is_resolved(*self)
186    }
187
188    fn is_transient(&self) -> bool {
189        self.node.is_transient()
190    }
191
192    async fn resolve_input(&self) -> Result<Self> {
193        Vc::resolve(*self).await
194    }
195}
196
197// `TaskInput` isn't needed/used for a bare `ResolvedVc`, as we'll expose `ResolvedVc` arguments as
198// `Vc`, but it is useful for structs that contain `ResolvedVc` and want to derive `TaskInput`.
199impl<T> TaskInput for ResolvedVc<T>
200where
201    T: Send + Sync + ?Sized,
202{
203    fn is_resolved(&self) -> bool {
204        true
205    }
206
207    fn is_transient(&self) -> bool {
208        self.node.is_transient()
209    }
210
211    async fn resolve_input(&self) -> Result<Self> {
212        Ok(*self)
213    }
214}
215
216impl<T> TaskInput for TransientValue<T>
217where
218    T: MagicAny + Clone + Debug + Hash + Eq + TraceRawVcs + 'static,
219{
220    fn is_transient(&self) -> bool {
221        true
222    }
223}
224
225impl<T> Encode for TransientValue<T> {
226    fn encode<E: Encoder>(&self, _encoder: &mut E) -> Result<(), EncodeError> {
227        Err(EncodeError::Other("cannot encode transient task inputs"))
228    }
229}
230
231impl<Context, T> Decode<Context> for TransientValue<T> {
232    fn decode<D: Decoder<Context = Context>>(_decoder: &mut D) -> Result<Self, DecodeError> {
233        Err(DecodeError::Other("cannot decode transient task inputs"))
234    }
235}
236
237impl<T> TaskInput for TransientInstance<T>
238where
239    T: Sync + Send + TraceRawVcs + 'static,
240{
241    fn is_transient(&self) -> bool {
242        true
243    }
244}
245
246impl<T> Encode for TransientInstance<T> {
247    fn encode<E: Encoder>(&self, _encoder: &mut E) -> Result<(), EncodeError> {
248        Err(EncodeError::Other("cannot encode transient task inputs"))
249    }
250}
251
252impl<Context, T> Decode<Context> for TransientInstance<T> {
253    fn decode<D: Decoder<Context = Context>>(_decoder: &mut D) -> Result<Self, DecodeError> {
254        Err(DecodeError::Other("cannot decode transient task inputs"))
255    }
256}
257
258impl<K, V> TaskInput for BTreeMap<K, V>
259where
260    K: TaskInput + Ord,
261    V: TaskInput,
262{
263    async fn resolve_input(&self) -> Result<Self> {
264        let mut new_map = BTreeMap::new();
265        for (k, v) in self {
266            new_map.insert(
267                TaskInput::resolve_input(k).await?,
268                TaskInput::resolve_input(v).await?,
269            );
270        }
271        Ok(new_map)
272    }
273
274    fn is_resolved(&self) -> bool {
275        self.iter()
276            .all(|(k, v)| TaskInput::is_resolved(k) && TaskInput::is_resolved(v))
277    }
278
279    fn is_transient(&self) -> bool {
280        self.iter()
281            .any(|(k, v)| TaskInput::is_transient(k) || TaskInput::is_transient(v))
282    }
283}
284
285impl<T> TaskInput for BTreeSet<T>
286where
287    T: TaskInput + Ord,
288{
289    async fn resolve_input(&self) -> Result<Self> {
290        let mut new_set = BTreeSet::new();
291        for value in self {
292            new_set.insert(TaskInput::resolve_input(value).await?);
293        }
294        Ok(new_set)
295    }
296
297    fn is_resolved(&self) -> bool {
298        self.iter().all(TaskInput::is_resolved)
299    }
300
301    fn is_transient(&self) -> bool {
302        self.iter().any(TaskInput::is_transient)
303    }
304}
305
306impl<K, V> TaskInput for FrozenMap<K, V>
307where
308    K: TaskInput + Ord + 'static,
309    V: TaskInput + 'static,
310{
311    async fn resolve_input(&self) -> Result<Self> {
312        let mut new_entries = Vec::with_capacity(self.len());
313        for (k, v) in self {
314            new_entries.push((
315                TaskInput::resolve_input(k).await?,
316                TaskInput::resolve_input(v).await?,
317            ));
318        }
319        // note: resolving might deduplicate `Vc`s in keys
320        Ok(Self::from(new_entries))
321    }
322
323    fn is_resolved(&self) -> bool {
324        self.iter()
325            .all(|(k, v)| TaskInput::is_resolved(k) && TaskInput::is_resolved(v))
326    }
327
328    fn is_transient(&self) -> bool {
329        self.iter()
330            .any(|(k, v)| TaskInput::is_transient(k) || TaskInput::is_transient(v))
331    }
332}
333
334impl<T> TaskInput for FrozenSet<T>
335where
336    T: TaskInput + Ord + 'static,
337{
338    async fn resolve_input(&self) -> Result<Self> {
339        let mut new_set = Vec::with_capacity(self.len());
340        for value in self {
341            new_set.push(TaskInput::resolve_input(value).await?);
342        }
343        Ok(Self::from_iter(new_set))
344    }
345
346    fn is_resolved(&self) -> bool {
347        self.iter().all(TaskInput::is_resolved)
348    }
349
350    fn is_transient(&self) -> bool {
351        self.iter().any(TaskInput::is_transient)
352    }
353}
354
355/// A thin wrapper around [`Either`] that implements the traits required by [`TaskInput`], notably
356/// [`Encode`] and [`Decode`].
357#[derive(Clone, Debug, PartialEq, Eq, Hash, TraceRawVcs)]
358pub struct EitherTaskInput<L, R>(pub Either<L, R>);
359
360impl<L, R> Deref for EitherTaskInput<L, R> {
361    type Target = Either<L, R>;
362
363    fn deref(&self) -> &Self::Target {
364        &self.0
365    }
366}
367
368impl<L, R> DerefMut for EitherTaskInput<L, R> {
369    fn deref_mut(&mut self) -> &mut Self::Target {
370        &mut self.0
371    }
372}
373
374impl<L, R> Encode for EitherTaskInput<L, R>
375where
376    L: Encode,
377    R: Encode,
378{
379    fn encode<E: Encoder>(&self, encoder: &mut E) -> Result<(), EncodeError> {
380        turbo_bincode::either::encode(self, encoder)
381    }
382}
383
384impl<Context, L, R> Decode<Context> for EitherTaskInput<L, R>
385where
386    L: Decode<Context>,
387    R: Decode<Context>,
388{
389    fn decode<D: Decoder<Context = Context>>(decoder: &mut D) -> Result<Self, DecodeError> {
390        turbo_bincode::either::decode(decoder).map(Self)
391    }
392}
393
394impl<L, R> TaskInput for EitherTaskInput<L, R>
395where
396    L: TaskInput,
397    R: TaskInput,
398{
399    fn resolve_input(&self) -> impl Future<Output = Result<Self>> + Send + '_ {
400        self.as_ref().map_either(
401            |l| async move { anyhow::Ok(Self(Either::Left(l.resolve_input().await?))) },
402            |r| async move { anyhow::Ok(Self(Either::Right(r.resolve_input().await?))) },
403        )
404    }
405
406    fn is_resolved(&self) -> bool {
407        self.as_ref()
408            .either(TaskInput::is_resolved, TaskInput::is_resolved)
409    }
410
411    fn is_transient(&self) -> bool {
412        self.as_ref()
413            .either(TaskInput::is_transient, TaskInput::is_transient)
414    }
415}
416
417macro_rules! tuple_impls {
418    ( $( $name:ident )+ ) => {
419        impl<$($name: TaskInput),+> TaskInput for ($($name,)+)
420        where $($name: TaskInput),+
421        {
422            #[allow(non_snake_case)]
423            fn is_resolved(&self) -> bool {
424                let ($($name,)+) = self;
425                $($name.is_resolved() &&)+ true
426            }
427
428            #[allow(non_snake_case)]
429            fn is_transient(&self) -> bool {
430                let ($($name,)+) = self;
431                $($name.is_transient() ||)+ false
432            }
433
434            #[allow(non_snake_case)]
435            async fn resolve_input(&self) -> Result<Self> {
436                let ($($name,)+) = self;
437                Ok(($($name.resolve_input().await?,)+))
438            }
439        }
440    };
441}
442
443// Implement `TaskInput` for all tuples of 1 to 12 elements.
444tuple_impls! { A }
445tuple_impls! { A B }
446tuple_impls! { A B C }
447tuple_impls! { A B C D }
448tuple_impls! { A B C D E }
449tuple_impls! { A B C D E F }
450tuple_impls! { A B C D E F G }
451tuple_impls! { A B C D E F G H }
452tuple_impls! { A B C D E F G H I }
453tuple_impls! { A B C D E F G H I J }
454tuple_impls! { A B C D E F G H I J K }
455tuple_impls! { A B C D E F G H I J K L }
456
/// Checks that `#[derive(TaskInput)]` produces a usable impl for a range of struct and enum
/// shapes. Each test passes as soon as it type-checks; nothing is asserted at runtime.
#[cfg(test)]
mod tests {
    use turbo_rcstr::rcstr;
    use turbo_tasks_macros::TaskInput;

    use super::*;

    /// Type-level probe: a call compiles only when `T` implements [`TaskInput`]. The argument is
    /// dropped unused.
    fn assert_task_input<T: TaskInput>(_: T) {}

    // ---- structs ----

    /// Unit struct.
    #[test]
    fn test_no_fields() -> Result<()> {
        #[derive(Clone, TaskInput, Eq, PartialEq, Hash, Debug, Encode, Decode, TraceRawVcs)]
        struct NoFields;

        assert_task_input(NoFields);
        Ok(())
    }

    /// Tuple struct with a single field.
    #[test]
    fn test_one_unnamed_field() -> Result<()> {
        #[derive(Clone, TaskInput, Eq, PartialEq, Hash, Debug, Encode, Decode, TraceRawVcs)]
        struct OneUnnamedField(u32);

        assert_task_input(OneUnnamedField(42));
        Ok(())
    }

    /// Tuple struct with two fields of different types.
    #[test]
    fn test_multiple_unnamed_fields() -> Result<()> {
        #[derive(Clone, TaskInput, Eq, PartialEq, Hash, Debug, Encode, Decode, TraceRawVcs)]
        struct MultipleUnnamedFields(u32, RcStr);

        assert_task_input(MultipleUnnamedFields(42, rcstr!("42")));
        Ok(())
    }

    /// Struct with a single named field.
    #[test]
    fn test_one_named_field() -> Result<()> {
        #[derive(Clone, TaskInput, Eq, PartialEq, Hash, Debug, Encode, Decode, TraceRawVcs)]
        struct OneNamedField {
            named: u32,
        }

        assert_task_input(OneNamedField { named: 42 });
        Ok(())
    }

    /// Struct with two named fields of different types.
    #[test]
    fn test_multiple_named_fields() -> Result<()> {
        #[derive(Clone, TaskInput, Eq, PartialEq, Hash, Debug, Encode, Decode, TraceRawVcs)]
        struct MultipleNamedFields {
            named: u32,
            other: RcStr,
        }

        assert_task_input(MultipleNamedFields {
            named: 42,
            other: rcstr!("42"),
        });
        Ok(())
    }

    /// Generic tuple struct, instantiated once with an integer and once with an `RcStr`.
    #[test]
    fn test_generic_field() -> Result<()> {
        #[derive(Clone, TaskInput, Eq, PartialEq, Hash, Debug, Encode, Decode, TraceRawVcs)]
        struct GenericField<T>(T);

        assert_task_input(GenericField(42));
        assert_task_input(GenericField(rcstr!("42")));
        Ok(())
    }

    // ---- enums ----

    // Declared at module level because `test_nested_variants` reuses it as a field type.
    #[derive(Clone, TaskInput, Eq, PartialEq, Hash, Debug, Encode, Decode, TraceRawVcs)]
    enum OneVariant {
        Variant,
    }

    /// Enum with a single unit variant.
    #[test]
    fn test_one_variant() -> Result<()> {
        assert_task_input(OneVariant::Variant);
        Ok(())
    }

    /// Enum with two unit variants.
    #[test]
    fn test_multiple_variants() -> Result<()> {
        #[derive(Clone, TaskInput, PartialEq, Eq, Hash, Debug, Encode, Decode, TraceRawVcs)]
        enum MultipleVariants {
            Variant1,
            Variant2,
        }

        assert_task_input(MultipleVariants::Variant2);
        Ok(())
    }

    // Declared at module level because `test_nested_variants` reuses it as a field type.
    #[derive(Clone, TaskInput, Eq, PartialEq, Hash, Debug, Encode, Decode, TraceRawVcs)]
    enum MultipleVariantsAndHeterogeneousFields {
        Variant1,
        Variant2(u32),
        Variant3 { named: u32 },
        Variant4(u32, RcStr),
        Variant5 { named: u32, other: RcStr },
    }

    /// Enum mixing unit, tuple and struct-like variants.
    #[test]
    fn test_multiple_variants_and_heterogeneous_fields() -> Result<()> {
        assert_task_input(MultipleVariantsAndHeterogeneousFields::Variant5 {
            named: 42,
            other: rcstr!("42"),
        });
        Ok(())
    }

    /// Enum whose variants carry other derived `TaskInput` enums.
    #[test]
    fn test_nested_variants() -> Result<()> {
        #[derive(Clone, TaskInput, Eq, PartialEq, Hash, Debug, Encode, Decode, TraceRawVcs)]
        enum NestedVariants {
            Variant1,
            Variant2(MultipleVariantsAndHeterogeneousFields),
            Variant3 { named: OneVariant },
            Variant4(OneVariant, RcStr),
            Variant5 { named: OneVariant, other: RcStr },
        }

        assert_task_input(NestedVariants::Variant5 {
            named: OneVariant::Variant,
            other: rcstr!("42"),
        });
        assert_task_input(NestedVariants::Variant2(
            MultipleVariantsAndHeterogeneousFields::Variant5 {
                named: 42,
                other: rcstr!("42"),
            },
        ));
        Ok(())
    }
}