turbo_tasks/task/
task_input.rs

1use std::{
2    collections::{BTreeMap, BTreeSet},
3    fmt::Debug,
4    future::Future,
5    hash::Hash,
6    ops::{Deref, DerefMut},
7    sync::Arc,
8    time::Duration,
9};
10
11use anyhow::Result;
12use bincode::{
13    Decode, Encode,
14    de::Decoder,
15    enc::Encoder,
16    error::{DecodeError, EncodeError},
17};
18use either::Either;
19use turbo_frozenmap::{FrozenMap, FrozenSet};
20use turbo_rcstr::RcStr;
21
22// This import is necessary for derive macros to work, as their expansion refers to the crate
23// name directly.
24use crate as turbo_tasks;
25use crate::{
26    MagicAny, ReadRef, ResolvedVc, TaskId, TransientInstance, TransientValue, ValueTypeId, Vc,
27    trace::TraceRawVcs,
28};
29
/// Trait to implement in order for a type to be accepted as a
/// [`#[turbo_tasks::function]`][crate::function] argument.
///
/// Transient task inputs are required to implement [`Encode`] and [`Decode`], but are allowed to
/// panic at runtime (the transient wrappers in this module return an encode/decode error
/// instead). This requirement could be lifted in the future.
///
/// Bincode encoding must be deterministic and compatible with [`Eq`] comparisons. If two
/// `TaskInput`s compare equal they must also encode to the same bytes.
pub trait TaskInput:
    Send + Sync + Clone + Debug + PartialEq + Eq + Hash + TraceRawVcs + Encode + Decode<()>
{
    /// Returns a copy of this input with every contained value resolved.
    ///
    /// The default implementation simply clones `self`, which is appropriate for inputs that
    /// hold nothing resolvable. Container implementations resolve each element in turn.
    fn resolve_input(&self) -> impl Future<Output = Result<Self>> + Send + '_ {
        async { Ok(self.clone()) }
    }
    /// Returns `true` when this input is already resolved.
    ///
    /// Defaults to `true`, consistent with the cloning default of
    /// [`TaskInput::resolve_input`].
    fn is_resolved(&self) -> bool {
        true
    }
    /// Returns `true` if this input, or anything it contains, is transient.
    ///
    /// There is deliberately no default: every implementor must decide explicitly.
    fn is_transient(&self) -> bool;
}
49
50macro_rules! impl_task_input {
51    ($($t:ty),*) => {
52        $(
53            impl TaskInput for $t {
54                fn is_transient(&self) -> bool {
55                    false
56                }
57            }
58        )*
59    };
60}
61
62impl_task_input! {
63    (),
64    bool,
65    u8,
66    u16,
67    u32,
68    i32,
69    u64,
70    usize,
71    RcStr,
72    TaskId,
73    ValueTypeId,
74    Duration,
75    String
76}
77
78impl<T> TaskInput for Vec<T>
79where
80    T: TaskInput,
81{
82    fn is_resolved(&self) -> bool {
83        self.iter().all(TaskInput::is_resolved)
84    }
85
86    fn is_transient(&self) -> bool {
87        self.iter().any(TaskInput::is_transient)
88    }
89
90    async fn resolve_input(&self) -> Result<Self> {
91        let mut resolved = Vec::with_capacity(self.len());
92        for value in self {
93            resolved.push(value.resolve_input().await?);
94        }
95        Ok(resolved)
96    }
97}
98
99impl<T> TaskInput for Box<T>
100where
101    T: TaskInput,
102{
103    fn is_resolved(&self) -> bool {
104        self.as_ref().is_resolved()
105    }
106
107    fn is_transient(&self) -> bool {
108        self.as_ref().is_transient()
109    }
110
111    async fn resolve_input(&self) -> Result<Self> {
112        Ok(Box::new(Box::pin(self.as_ref().resolve_input()).await?))
113    }
114}
115
116impl<T> TaskInput for Arc<T>
117where
118    T: TaskInput,
119{
120    fn is_resolved(&self) -> bool {
121        self.as_ref().is_resolved()
122    }
123
124    fn is_transient(&self) -> bool {
125        self.as_ref().is_transient()
126    }
127
128    async fn resolve_input(&self) -> Result<Self> {
129        Ok(Arc::new(Box::pin(self.as_ref().resolve_input()).await?))
130    }
131}
132
133impl<T> TaskInput for ReadRef<T>
134where
135    T: TaskInput,
136{
137    fn is_resolved(&self) -> bool {
138        Self::as_raw_ref(self).is_resolved()
139    }
140
141    fn is_transient(&self) -> bool {
142        Self::as_raw_ref(self).is_transient()
143    }
144
145    async fn resolve_input(&self) -> Result<Self> {
146        Ok(ReadRef::new_owned(
147            Box::pin(Self::as_raw_ref(self).resolve_input()).await?,
148        ))
149    }
150}
151
152impl<T> TaskInput for Option<T>
153where
154    T: TaskInput,
155{
156    fn is_resolved(&self) -> bool {
157        match self {
158            Some(value) => value.is_resolved(),
159            None => true,
160        }
161    }
162
163    fn is_transient(&self) -> bool {
164        match self {
165            Some(value) => value.is_transient(),
166            None => false,
167        }
168    }
169
170    async fn resolve_input(&self) -> Result<Self> {
171        match self {
172            Some(value) => Ok(Some(value.resolve_input().await?)),
173            None => Ok(None),
174        }
175    }
176}
177
/// A `Vc` argument is resolved by awaiting `Vc::resolve`; transience comes from its `node`.
impl<T> TaskInput for Vc<T>
where
    T: Send + Sync + ?Sized,
{
    fn is_resolved(&self) -> bool {
        // Path-call syntax with `*self` (passed by value) selects `Vc`'s own associated
        // `is_resolved`, not this trait method. Method syntax (`self.is_resolved()`) could
        // resolve back to the trait method here and recurse.
        Vc::is_resolved(*self)
    }

    fn is_transient(&self) -> bool {
        // Delegates to the underlying `node`.
        self.node.is_transient()
    }

    async fn resolve_input(&self) -> Result<Self> {
        // As above, path-call syntax targets `Vc`'s own `resolve`; any error is propagated.
        Vc::resolve(*self).await
    }
}
194
195// `TaskInput` isn't needed/used for a bare `ResolvedVc`, as we'll expose `ResolvedVc` arguments as
196// `Vc`, but it is useful for structs that contain `ResolvedVc` and want to derive `TaskInput`.
197impl<T> TaskInput for ResolvedVc<T>
198where
199    T: Send + Sync + ?Sized,
200{
201    fn is_resolved(&self) -> bool {
202        true
203    }
204
205    fn is_transient(&self) -> bool {
206        self.node.is_transient()
207    }
208
209    async fn resolve_input(&self) -> Result<Self> {
210        Ok(*self)
211    }
212}
213
214impl<T> TaskInput for TransientValue<T>
215where
216    T: MagicAny + Clone + Debug + Hash + Eq + TraceRawVcs + 'static,
217{
218    fn is_transient(&self) -> bool {
219        true
220    }
221}
222
223impl<T> Encode for TransientValue<T> {
224    fn encode<E: Encoder>(&self, _encoder: &mut E) -> Result<(), EncodeError> {
225        Err(EncodeError::Other("cannot encode transient task inputs"))
226    }
227}
228
229impl<Context, T> Decode<Context> for TransientValue<T> {
230    fn decode<D: Decoder<Context = Context>>(_decoder: &mut D) -> Result<Self, DecodeError> {
231        Err(DecodeError::Other("cannot decode transient task inputs"))
232    }
233}
234
235impl<T> TaskInput for TransientInstance<T>
236where
237    T: Sync + Send + TraceRawVcs + 'static,
238{
239    fn is_transient(&self) -> bool {
240        true
241    }
242}
243
244impl<T> Encode for TransientInstance<T> {
245    fn encode<E: Encoder>(&self, _encoder: &mut E) -> Result<(), EncodeError> {
246        Err(EncodeError::Other("cannot encode transient task inputs"))
247    }
248}
249
250impl<Context, T> Decode<Context> for TransientInstance<T> {
251    fn decode<D: Decoder<Context = Context>>(_decoder: &mut D) -> Result<Self, DecodeError> {
252        Err(DecodeError::Other("cannot decode transient task inputs"))
253    }
254}
255
256impl<K, V> TaskInput for BTreeMap<K, V>
257where
258    K: TaskInput + Ord,
259    V: TaskInput,
260{
261    async fn resolve_input(&self) -> Result<Self> {
262        let mut new_map = BTreeMap::new();
263        for (k, v) in self {
264            new_map.insert(
265                TaskInput::resolve_input(k).await?,
266                TaskInput::resolve_input(v).await?,
267            );
268        }
269        Ok(new_map)
270    }
271
272    fn is_resolved(&self) -> bool {
273        self.iter()
274            .all(|(k, v)| TaskInput::is_resolved(k) && TaskInput::is_resolved(v))
275    }
276
277    fn is_transient(&self) -> bool {
278        self.iter()
279            .any(|(k, v)| TaskInput::is_transient(k) || TaskInput::is_transient(v))
280    }
281}
282
283impl<T> TaskInput for BTreeSet<T>
284where
285    T: TaskInput + Ord,
286{
287    async fn resolve_input(&self) -> Result<Self> {
288        let mut new_set = BTreeSet::new();
289        for value in self {
290            new_set.insert(TaskInput::resolve_input(value).await?);
291        }
292        Ok(new_set)
293    }
294
295    fn is_resolved(&self) -> bool {
296        self.iter().all(TaskInput::is_resolved)
297    }
298
299    fn is_transient(&self) -> bool {
300        self.iter().any(TaskInput::is_transient)
301    }
302}
303
304impl<K, V> TaskInput for FrozenMap<K, V>
305where
306    K: TaskInput + Ord + 'static,
307    V: TaskInput + 'static,
308{
309    async fn resolve_input(&self) -> Result<Self> {
310        let mut new_entries = Vec::with_capacity(self.len());
311        for (k, v) in self {
312            new_entries.push((
313                TaskInput::resolve_input(k).await?,
314                TaskInput::resolve_input(v).await?,
315            ));
316        }
317        // note: resolving might deduplicate `Vc`s in keys
318        Ok(Self::from(new_entries))
319    }
320
321    fn is_resolved(&self) -> bool {
322        self.iter()
323            .all(|(k, v)| TaskInput::is_resolved(k) && TaskInput::is_resolved(v))
324    }
325
326    fn is_transient(&self) -> bool {
327        self.iter()
328            .any(|(k, v)| TaskInput::is_transient(k) || TaskInput::is_transient(v))
329    }
330}
331
332impl<T> TaskInput for FrozenSet<T>
333where
334    T: TaskInput + Ord + 'static,
335{
336    async fn resolve_input(&self) -> Result<Self> {
337        let mut new_set = Vec::with_capacity(self.len());
338        for value in self {
339            new_set.push(TaskInput::resolve_input(value).await?);
340        }
341        Ok(Self::from_iter(new_set))
342    }
343
344    fn is_resolved(&self) -> bool {
345        self.iter().all(TaskInput::is_resolved)
346    }
347
348    fn is_transient(&self) -> bool {
349        self.iter().any(TaskInput::is_transient)
350    }
351}
352
353/// A thin wrapper around [`Either`] that implements the traits required by [`TaskInput`], notably
354/// [`Encode`] and [`Decode`].
355#[derive(Clone, Debug, PartialEq, Eq, Hash, TraceRawVcs)]
356pub struct EitherTaskInput<L, R>(pub Either<L, R>);
357
358impl<L, R> Deref for EitherTaskInput<L, R> {
359    type Target = Either<L, R>;
360
361    fn deref(&self) -> &Self::Target {
362        &self.0
363    }
364}
365
366impl<L, R> DerefMut for EitherTaskInput<L, R> {
367    fn deref_mut(&mut self) -> &mut Self::Target {
368        &mut self.0
369    }
370}
371
/// Encodes the wrapped [`Either`] using the shared `turbo_bincode::either::encode` helper.
impl<L, R> Encode for EitherTaskInput<L, R>
where
    L: Encode,
    R: Encode,
{
    fn encode<E: Encoder>(&self, encoder: &mut E) -> Result<(), EncodeError> {
        // `self` is a `&EitherTaskInput`; presumably deref-coerced to `&Either` by the helper's
        // signature — confirm against `turbo_bincode::either`.
        turbo_bincode::either::encode(self, encoder)
    }
}

/// Decodes an [`Either`] with `turbo_bincode::either::decode` and wraps it in
/// `EitherTaskInput`.
impl<Context, L, R> Decode<Context> for EitherTaskInput<L, R>
where
    L: Decode<Context>,
    R: Decode<Context>,
{
    fn decode<D: Decoder<Context = Context>>(decoder: &mut D) -> Result<Self, DecodeError> {
        turbo_bincode::either::decode(decoder).map(Self)
    }
}
391
/// Every operation delegates to whichever side (`Left` or `Right`) is present.
impl<L, R> TaskInput for EitherTaskInput<L, R>
where
    L: TaskInput,
    R: TaskInput,
{
    /// Resolves the present side and keeps it on the same side of the `Either`.
    fn resolve_input(&self) -> impl Future<Output = Result<Self>> + Send + '_ {
        // `map_either` produces an `Either` of two distinct async-block futures. This relies on
        // `Either<F1, F2>` implementing `Future` when both sides share an output type, which
        // avoids boxing. `anyhow::Ok` pins each block's error type, which `?` alone would leave
        // uninferred.
        self.as_ref().map_either(
            |l| async move { anyhow::Ok(Self(Either::Left(l.resolve_input().await?))) },
            |r| async move { anyhow::Ok(Self(Either::Right(r.resolve_input().await?))) },
        )
    }

    /// Delegates to the present side.
    fn is_resolved(&self) -> bool {
        self.as_ref()
            .either(TaskInput::is_resolved, TaskInput::is_resolved)
    }

    /// Delegates to the present side.
    fn is_transient(&self) -> bool {
        self.as_ref()
            .either(TaskInput::is_transient, TaskInput::is_transient)
    }
}
414
415macro_rules! tuple_impls {
416    ( $( $name:ident )+ ) => {
417        impl<$($name: TaskInput),+> TaskInput for ($($name,)+)
418        where $($name: TaskInput),+
419        {
420            #[allow(non_snake_case)]
421            fn is_resolved(&self) -> bool {
422                let ($($name,)+) = self;
423                $($name.is_resolved() &&)+ true
424            }
425
426            #[allow(non_snake_case)]
427            fn is_transient(&self) -> bool {
428                let ($($name,)+) = self;
429                $($name.is_transient() ||)+ false
430            }
431
432            #[allow(non_snake_case)]
433            async fn resolve_input(&self) -> Result<Self> {
434                let ($($name,)+) = self;
435                Ok(($($name.resolve_input().await?,)+))
436            }
437        }
438    };
439}
440
441// Implement `TaskInput` for all tuples of 1 to 12 elements.
442tuple_impls! { A }
443tuple_impls! { A B }
444tuple_impls! { A B C }
445tuple_impls! { A B C D }
446tuple_impls! { A B C D E }
447tuple_impls! { A B C D E F }
448tuple_impls! { A B C D E F G }
449tuple_impls! { A B C D E F G H }
450tuple_impls! { A B C D E F G H I }
451tuple_impls! { A B C D E F G H I J }
452tuple_impls! { A B C D E F G H I J K }
453tuple_impls! { A B C D E F G H I J K L }
454
// These tests are compile-time checks: each one only builds if the `TaskInput` derive produces an
// implementation for the given struct/enum shape.
#[cfg(test)]
mod tests {
    use turbo_rcstr::rcstr;
    use turbo_tasks_macros::TaskInput;

    use super::*;

    /// Type-checks only if `T` implements [`TaskInput`]; the argument is consumed and ignored.
    fn assert_task_input<T>(_: T)
    where
        T: TaskInput,
    {
    }

    // Unit struct with no fields.
    #[test]
    fn test_no_fields() -> Result<()> {
        #[derive(Clone, TaskInput, Eq, PartialEq, Hash, Debug, Encode, Decode, TraceRawVcs)]
        struct NoFields;

        assert_task_input(NoFields);
        Ok(())
    }

    // Tuple struct with a single field.
    #[test]
    fn test_one_unnamed_field() -> Result<()> {
        #[derive(Clone, TaskInput, Eq, PartialEq, Hash, Debug, Encode, Decode, TraceRawVcs)]
        struct OneUnnamedField(u32);

        assert_task_input(OneUnnamedField(42));
        Ok(())
    }

    // Tuple struct with several fields of different types.
    #[test]
    fn test_multiple_unnamed_fields() -> Result<()> {
        #[derive(Clone, TaskInput, Eq, PartialEq, Hash, Debug, Encode, Decode, TraceRawVcs)]
        struct MultipleUnnamedFields(u32, RcStr);

        assert_task_input(MultipleUnnamedFields(42, rcstr!("42")));
        Ok(())
    }

    // Struct with a single named field.
    #[test]
    fn test_one_named_field() -> Result<()> {
        #[derive(Clone, TaskInput, Eq, PartialEq, Hash, Debug, Encode, Decode, TraceRawVcs)]
        struct OneNamedField {
            named: u32,
        }

        assert_task_input(OneNamedField { named: 42 });
        Ok(())
    }

    // Struct with several named fields of different types.
    #[test]
    fn test_multiple_named_fields() -> Result<()> {
        #[derive(Clone, TaskInput, Eq, PartialEq, Hash, Debug, Encode, Decode, TraceRawVcs)]
        struct MultipleNamedFields {
            named: u32,
            other: RcStr,
        }

        assert_task_input(MultipleNamedFields {
            named: 42,
            other: rcstr!("42"),
        });
        Ok(())
    }

    // Generic struct, instantiated with two different element types.
    #[test]
    fn test_generic_field() -> Result<()> {
        #[derive(Clone, TaskInput, Eq, PartialEq, Hash, Debug, Encode, Decode, TraceRawVcs)]
        struct GenericField<T>(T);

        assert_task_input(GenericField(42));
        assert_task_input(GenericField(rcstr!("42")));
        Ok(())
    }

    // Declared at module level so `test_nested_variants` can reuse it as a field type.
    #[derive(Clone, TaskInput, Eq, PartialEq, Hash, Debug, Encode, Decode, TraceRawVcs)]
    enum OneVariant {
        Variant,
    }

    // Enum with a single unit variant.
    #[test]
    fn test_one_variant() -> Result<()> {
        assert_task_input(OneVariant::Variant);
        Ok(())
    }

    // Enum with several unit variants.
    #[test]
    fn test_multiple_variants() -> Result<()> {
        #[derive(Clone, TaskInput, PartialEq, Eq, Hash, Debug, Encode, Decode, TraceRawVcs)]
        enum MultipleVariants {
            Variant1,
            Variant2,
        }

        assert_task_input(MultipleVariants::Variant2);
        Ok(())
    }

    // Declared at module level so `test_nested_variants` can reuse it as a field type. Covers
    // unit, tuple and struct variants in a single enum.
    #[derive(Clone, TaskInput, Eq, PartialEq, Hash, Debug, Encode, Decode, TraceRawVcs)]
    enum MultipleVariantsAndHeterogeneousFields {
        Variant1,
        Variant2(u32),
        Variant3 { named: u32 },
        Variant4(u32, RcStr),
        Variant5 { named: u32, other: RcStr },
    }

    // Enum mixing unit, tuple and struct variants.
    #[test]
    fn test_multiple_variants_and_heterogeneous_fields() -> Result<()> {
        assert_task_input(MultipleVariantsAndHeterogeneousFields::Variant5 {
            named: 42,
            other: rcstr!("42"),
        });
        Ok(())
    }

    // Enum whose variants hold other derived `TaskInput` enums.
    #[test]
    fn test_nested_variants() -> Result<()> {
        #[derive(Clone, TaskInput, Eq, PartialEq, Hash, Debug, Encode, Decode, TraceRawVcs)]
        enum NestedVariants {
            Variant1,
            Variant2(MultipleVariantsAndHeterogeneousFields),
            Variant3 { named: OneVariant },
            Variant4(OneVariant, RcStr),
            Variant5 { named: OneVariant, other: RcStr },
        }

        assert_task_input(NestedVariants::Variant5 {
            named: OneVariant::Variant,
            other: rcstr!("42"),
        });
        assert_task_input(NestedVariants::Variant2(
            MultipleVariantsAndHeterogeneousFields::Variant5 {
                named: 42,
                other: rcstr!("42"),
            },
        ));
        Ok(())
    }
}