1use std::{
2 collections::{BTreeMap, BTreeSet},
3 fmt::Debug,
4 future::Future,
5 hash::Hash,
6 ops::{Deref, DerefMut},
7 sync::Arc,
8 time::Duration,
9};
10
11use anyhow::Result;
12use bincode::{
13 Decode, Encode,
14 de::Decoder,
15 enc::Encoder,
16 error::{DecodeError, EncodeError},
17};
18use either::Either;
19use turbo_frozenmap::{FrozenMap, FrozenSet};
20use turbo_rcstr::RcStr;
21use turbo_tasks_hash::HashAlgorithm;
22
23use crate as turbo_tasks;
26use crate::{
27 MagicAny, ReadRef, ResolvedVc, TaskId, TransientInstance, TransientValue, ValueTypeId, Vc,
28 trace::TraceRawVcs,
29};
30
31pub trait TaskInput:
40 Send + Sync + Clone + Debug + PartialEq + Eq + Hash + TraceRawVcs + Encode + Decode<()>
41{
42 fn resolve_input(&self) -> impl Future<Output = Result<Self>> + Send + '_ {
43 async { Ok(self.clone()) }
44 }
45 fn is_resolved(&self) -> bool {
46 true
47 }
48 fn is_transient(&self) -> bool;
49}
50
/// Implements [`TaskInput`] for each listed type as a leaf input.
///
/// Generated impls only override `is_transient` (always `false`). They keep the trait defaults
/// for `is_resolved` (`true`) and `resolve_input` (clone). An optional trailing comma is
/// accepted after the last type.
macro_rules! impl_task_input {
    ($($t:ty),* $(,)?) => {
        $(
            impl TaskInput for $t {
                fn is_transient(&self) -> bool {
                    false
                }
            }
        )*
    };
}
62
63impl_task_input! {
64 (),
65 bool,
66 u8,
67 u16,
68 u32,
69 i32,
70 u64,
71 usize,
72 RcStr,
73 TaskId,
74 ValueTypeId,
75 Duration,
76 String,
77 HashAlgorithm
78}
79
80impl<T> TaskInput for Vec<T>
81where
82 T: TaskInput,
83{
84 fn is_resolved(&self) -> bool {
85 self.iter().all(TaskInput::is_resolved)
86 }
87
88 fn is_transient(&self) -> bool {
89 self.iter().any(TaskInput::is_transient)
90 }
91
92 async fn resolve_input(&self) -> Result<Self> {
93 let mut resolved = Vec::with_capacity(self.len());
94 for value in self {
95 resolved.push(value.resolve_input().await?);
96 }
97 Ok(resolved)
98 }
99}
100
101impl<T> TaskInput for Box<T>
102where
103 T: TaskInput,
104{
105 fn is_resolved(&self) -> bool {
106 self.as_ref().is_resolved()
107 }
108
109 fn is_transient(&self) -> bool {
110 self.as_ref().is_transient()
111 }
112
113 async fn resolve_input(&self) -> Result<Self> {
114 Ok(Box::new(Box::pin(self.as_ref().resolve_input()).await?))
115 }
116}
117
118impl<T> TaskInput for Arc<T>
119where
120 T: TaskInput,
121{
122 fn is_resolved(&self) -> bool {
123 self.as_ref().is_resolved()
124 }
125
126 fn is_transient(&self) -> bool {
127 self.as_ref().is_transient()
128 }
129
130 async fn resolve_input(&self) -> Result<Self> {
131 Ok(Arc::new(Box::pin(self.as_ref().resolve_input()).await?))
132 }
133}
134
135impl<T> TaskInput for ReadRef<T>
136where
137 T: TaskInput,
138{
139 fn is_resolved(&self) -> bool {
140 Self::as_raw_ref(self).is_resolved()
141 }
142
143 fn is_transient(&self) -> bool {
144 Self::as_raw_ref(self).is_transient()
145 }
146
147 async fn resolve_input(&self) -> Result<Self> {
148 Ok(ReadRef::new_owned(
149 Box::pin(Self::as_raw_ref(self).resolve_input()).await?,
150 ))
151 }
152}
153
154impl<T> TaskInput for Option<T>
155where
156 T: TaskInput,
157{
158 fn is_resolved(&self) -> bool {
159 match self {
160 Some(value) => value.is_resolved(),
161 None => true,
162 }
163 }
164
165 fn is_transient(&self) -> bool {
166 match self {
167 Some(value) => value.is_transient(),
168 None => false,
169 }
170 }
171
172 async fn resolve_input(&self) -> Result<Self> {
173 match self {
174 Some(value) => Ok(Some(value.resolve_input().await?)),
175 None => Ok(None),
176 }
177 }
178}
179
180impl<T> TaskInput for Vc<T>
181where
182 T: Send + Sync + ?Sized,
183{
184 fn is_resolved(&self) -> bool {
185 Vc::is_resolved(*self)
186 }
187
188 fn is_transient(&self) -> bool {
189 self.node.is_transient()
190 }
191
192 async fn resolve_input(&self) -> Result<Self> {
193 Vc::resolve(*self).await
194 }
195}
196
197impl<T> TaskInput for ResolvedVc<T>
200where
201 T: Send + Sync + ?Sized,
202{
203 fn is_resolved(&self) -> bool {
204 true
205 }
206
207 fn is_transient(&self) -> bool {
208 self.node.is_transient()
209 }
210
211 async fn resolve_input(&self) -> Result<Self> {
212 Ok(*self)
213 }
214}
215
216impl<T> TaskInput for TransientValue<T>
217where
218 T: MagicAny + Clone + Debug + Hash + Eq + TraceRawVcs + 'static,
219{
220 fn is_transient(&self) -> bool {
221 true
222 }
223}
224
225impl<T> Encode for TransientValue<T> {
226 fn encode<E: Encoder>(&self, _encoder: &mut E) -> Result<(), EncodeError> {
227 Err(EncodeError::Other("cannot encode transient task inputs"))
228 }
229}
230
231impl<Context, T> Decode<Context> for TransientValue<T> {
232 fn decode<D: Decoder<Context = Context>>(_decoder: &mut D) -> Result<Self, DecodeError> {
233 Err(DecodeError::Other("cannot decode transient task inputs"))
234 }
235}
236
237impl<T> TaskInput for TransientInstance<T>
238where
239 T: Sync + Send + TraceRawVcs + 'static,
240{
241 fn is_transient(&self) -> bool {
242 true
243 }
244}
245
246impl<T> Encode for TransientInstance<T> {
247 fn encode<E: Encoder>(&self, _encoder: &mut E) -> Result<(), EncodeError> {
248 Err(EncodeError::Other("cannot encode transient task inputs"))
249 }
250}
251
252impl<Context, T> Decode<Context> for TransientInstance<T> {
253 fn decode<D: Decoder<Context = Context>>(_decoder: &mut D) -> Result<Self, DecodeError> {
254 Err(DecodeError::Other("cannot decode transient task inputs"))
255 }
256}
257
258impl<K, V> TaskInput for BTreeMap<K, V>
259where
260 K: TaskInput + Ord,
261 V: TaskInput,
262{
263 async fn resolve_input(&self) -> Result<Self> {
264 let mut new_map = BTreeMap::new();
265 for (k, v) in self {
266 new_map.insert(
267 TaskInput::resolve_input(k).await?,
268 TaskInput::resolve_input(v).await?,
269 );
270 }
271 Ok(new_map)
272 }
273
274 fn is_resolved(&self) -> bool {
275 self.iter()
276 .all(|(k, v)| TaskInput::is_resolved(k) && TaskInput::is_resolved(v))
277 }
278
279 fn is_transient(&self) -> bool {
280 self.iter()
281 .any(|(k, v)| TaskInput::is_transient(k) || TaskInput::is_transient(v))
282 }
283}
284
285impl<T> TaskInput for BTreeSet<T>
286where
287 T: TaskInput + Ord,
288{
289 async fn resolve_input(&self) -> Result<Self> {
290 let mut new_set = BTreeSet::new();
291 for value in self {
292 new_set.insert(TaskInput::resolve_input(value).await?);
293 }
294 Ok(new_set)
295 }
296
297 fn is_resolved(&self) -> bool {
298 self.iter().all(TaskInput::is_resolved)
299 }
300
301 fn is_transient(&self) -> bool {
302 self.iter().any(TaskInput::is_transient)
303 }
304}
305
306impl<K, V> TaskInput for FrozenMap<K, V>
307where
308 K: TaskInput + Ord + 'static,
309 V: TaskInput + 'static,
310{
311 async fn resolve_input(&self) -> Result<Self> {
312 let mut new_entries = Vec::with_capacity(self.len());
313 for (k, v) in self {
314 new_entries.push((
315 TaskInput::resolve_input(k).await?,
316 TaskInput::resolve_input(v).await?,
317 ));
318 }
319 Ok(Self::from(new_entries))
321 }
322
323 fn is_resolved(&self) -> bool {
324 self.iter()
325 .all(|(k, v)| TaskInput::is_resolved(k) && TaskInput::is_resolved(v))
326 }
327
328 fn is_transient(&self) -> bool {
329 self.iter()
330 .any(|(k, v)| TaskInput::is_transient(k) || TaskInput::is_transient(v))
331 }
332}
333
334impl<T> TaskInput for FrozenSet<T>
335where
336 T: TaskInput + Ord + 'static,
337{
338 async fn resolve_input(&self) -> Result<Self> {
339 let mut new_set = Vec::with_capacity(self.len());
340 for value in self {
341 new_set.push(TaskInput::resolve_input(value).await?);
342 }
343 Ok(Self::from_iter(new_set))
344 }
345
346 fn is_resolved(&self) -> bool {
347 self.iter().all(TaskInput::is_resolved)
348 }
349
350 fn is_transient(&self) -> bool {
351 self.iter().any(TaskInput::is_transient)
352 }
353}
354
/// Newtype around [`Either`] so that an `Either` of task inputs can itself be a task input.
///
/// A wrapper is required because `Either` and the bincode `Encode`/`Decode` traits both come
/// from other crates, and the orphan rule forbids implementing them for `Either` here. The
/// `Encode`/`Decode` impls for this type delegate to `turbo_bincode::either`. `Deref`/`DerefMut`
/// give access to the inner `Either`.
#[derive(Clone, Debug, PartialEq, Eq, Hash, TraceRawVcs)]
pub struct EitherTaskInput<L, R>(pub Either<L, R>);
359
360impl<L, R> Deref for EitherTaskInput<L, R> {
361 type Target = Either<L, R>;
362
363 fn deref(&self) -> &Self::Target {
364 &self.0
365 }
366}
367
368impl<L, R> DerefMut for EitherTaskInput<L, R> {
369 fn deref_mut(&mut self) -> &mut Self::Target {
370 &mut self.0
371 }
372}
373
374impl<L, R> Encode for EitherTaskInput<L, R>
375where
376 L: Encode,
377 R: Encode,
378{
379 fn encode<E: Encoder>(&self, encoder: &mut E) -> Result<(), EncodeError> {
380 turbo_bincode::either::encode(self, encoder)
381 }
382}
383
384impl<Context, L, R> Decode<Context> for EitherTaskInput<L, R>
385where
386 L: Decode<Context>,
387 R: Decode<Context>,
388{
389 fn decode<D: Decoder<Context = Context>>(decoder: &mut D) -> Result<Self, DecodeError> {
390 turbo_bincode::either::decode(decoder).map(Self)
391 }
392}
393
394impl<L, R> TaskInput for EitherTaskInput<L, R>
395where
396 L: TaskInput,
397 R: TaskInput,
398{
399 fn resolve_input(&self) -> impl Future<Output = Result<Self>> + Send + '_ {
400 self.as_ref().map_either(
401 |l| async move { anyhow::Ok(Self(Either::Left(l.resolve_input().await?))) },
402 |r| async move { anyhow::Ok(Self(Either::Right(r.resolve_input().await?))) },
403 )
404 }
405
406 fn is_resolved(&self) -> bool {
407 self.as_ref()
408 .either(TaskInput::is_resolved, TaskInput::is_resolved)
409 }
410
411 fn is_transient(&self) -> bool {
412 self.as_ref()
413 .either(TaskInput::is_transient, TaskInput::is_transient)
414 }
415}
416
/// Implements [`TaskInput`] for the tuple whose element types are the given identifiers.
///
/// The identifiers serve both as the tuple's type parameters and, inside the generated methods,
/// as the bindings for each destructured element. A tuple is resolved when every element is
/// resolved and transient when any element is. It resolves element by element, left to right,
/// and stops at the first error.
macro_rules! tuple_impls {
    ( $( $name:ident )+ ) => {
        // Each element type is bounded once, in the `where` clause.
        impl<$($name),+> TaskInput for ($($name,)+)
        where
            $($name: TaskInput),+
        {
            // Element bindings reuse the uppercase type-parameter identifiers.
            #[allow(non_snake_case)]
            fn is_resolved(&self) -> bool {
                let ($($name,)+) = self;
                $($name.is_resolved() &&)+ true
            }

            // Element bindings reuse the uppercase type-parameter identifiers.
            #[allow(non_snake_case)]
            fn is_transient(&self) -> bool {
                let ($($name,)+) = self;
                $($name.is_transient() ||)+ false
            }

            // Element bindings reuse the uppercase type-parameter identifiers.
            #[allow(non_snake_case)]
            async fn resolve_input(&self) -> Result<Self> {
                let ($($name,)+) = self;
                Ok(($($name.resolve_input().await?,)+))
            }
        }
    };
}
442
443tuple_impls! { A }
445tuple_impls! { A B }
446tuple_impls! { A B C }
447tuple_impls! { A B C D }
448tuple_impls! { A B C D E }
449tuple_impls! { A B C D E F }
450tuple_impls! { A B C D E F G }
451tuple_impls! { A B C D E F G H }
452tuple_impls! { A B C D E F G H I }
453tuple_impls! { A B C D E F G H I J }
454tuple_impls! { A B C D E F G H I J K }
455tuple_impls! { A B C D E F G H I J K L }
456
// Compile-time checks that `#[derive(TaskInput)]` produces a usable `TaskInput` impl for a range
// of struct and enum shapes. Each test passes only if its `assert_task_input` call type-checks.
#[cfg(test)]
mod tests {
    use turbo_rcstr::rcstr;
    use turbo_tasks_macros::TaskInput;

    use super::*;

    /// Only compiles when `T: TaskInput`; the argument is discarded.
    fn assert_task_input<T>(_: T)
    where
        T: TaskInput,
    {
    }

    // Unit struct.
    #[test]
    fn test_no_fields() -> Result<()> {
        #[derive(Clone, TaskInput, Eq, PartialEq, Hash, Debug, Encode, Decode, TraceRawVcs)]
        struct NoFields;

        assert_task_input(NoFields);
        Ok(())
    }

    // Tuple struct with a single field.
    #[test]
    fn test_one_unnamed_field() -> Result<()> {
        #[derive(Clone, TaskInput, Eq, PartialEq, Hash, Debug, Encode, Decode, TraceRawVcs)]
        struct OneUnnamedField(u32);

        assert_task_input(OneUnnamedField(42));
        Ok(())
    }

    // Tuple struct with several fields of different types.
    #[test]
    fn test_multiple_unnamed_fields() -> Result<()> {
        #[derive(Clone, TaskInput, Eq, PartialEq, Hash, Debug, Encode, Decode, TraceRawVcs)]
        struct MultipleUnnamedFields(u32, RcStr);

        assert_task_input(MultipleUnnamedFields(42, rcstr!("42")));
        Ok(())
    }

    // Struct with a single named field.
    #[test]
    fn test_one_named_field() -> Result<()> {
        #[derive(Clone, TaskInput, Eq, PartialEq, Hash, Debug, Encode, Decode, TraceRawVcs)]
        struct OneNamedField {
            named: u32,
        }

        assert_task_input(OneNamedField { named: 42 });
        Ok(())
    }

    // Struct with several named fields of different types.
    #[test]
    fn test_multiple_named_fields() -> Result<()> {
        #[derive(Clone, TaskInput, Eq, PartialEq, Hash, Debug, Encode, Decode, TraceRawVcs)]
        struct MultipleNamedFields {
            named: u32,
            other: RcStr,
        }

        assert_task_input(MultipleNamedFields {
            named: 42,
            other: rcstr!("42"),
        });
        Ok(())
    }

    // Generic struct, instantiated with two different field types.
    #[test]
    fn test_generic_field() -> Result<()> {
        #[derive(Clone, TaskInput, Eq, PartialEq, Hash, Debug, Encode, Decode, TraceRawVcs)]
        struct GenericField<T>(T);

        assert_task_input(GenericField(42));
        assert_task_input(GenericField(rcstr!("42")));
        Ok(())
    }

    // Declared at module level so `test_nested_variants` can reuse it as a field type.
    #[derive(Clone, TaskInput, Eq, PartialEq, Hash, Debug, Encode, Decode, TraceRawVcs)]
    enum OneVariant {
        Variant,
    }

    // Enum with a single unit variant.
    #[test]
    fn test_one_variant() -> Result<()> {
        assert_task_input(OneVariant::Variant);
        Ok(())
    }

    // Enum with several unit variants.
    #[test]
    fn test_multiple_variants() -> Result<()> {
        #[derive(Clone, TaskInput, PartialEq, Eq, Hash, Debug, Encode, Decode, TraceRawVcs)]
        enum MultipleVariants {
            Variant1,
            Variant2,
        }

        assert_task_input(MultipleVariants::Variant2);
        Ok(())
    }

    // Declared at module level so `test_nested_variants` can reuse it as a field type. It covers
    // unit, tuple and struct variants in one enum.
    #[derive(Clone, TaskInput, Eq, PartialEq, Hash, Debug, Encode, Decode, TraceRawVcs)]
    enum MultipleVariantsAndHeterogeneousFields {
        Variant1,
        Variant2(u32),
        Variant3 { named: u32 },
        Variant4(u32, RcStr),
        Variant5 { named: u32, other: RcStr },
    }

    // Enum mixing unit, tuple and struct variants.
    #[test]
    fn test_multiple_variants_and_heterogeneous_fields() -> Result<()> {
        assert_task_input(MultipleVariantsAndHeterogeneousFields::Variant5 {
            named: 42,
            other: rcstr!("42"),
        });
        Ok(())
    }

    // Enum whose variant fields are themselves derived `TaskInput` enums.
    #[test]
    fn test_nested_variants() -> Result<()> {
        #[derive(Clone, TaskInput, Eq, PartialEq, Hash, Debug, Encode, Decode, TraceRawVcs)]
        enum NestedVariants {
            Variant1,
            Variant2(MultipleVariantsAndHeterogeneousFields),
            Variant3 { named: OneVariant },
            Variant4(OneVariant, RcStr),
            Variant5 { named: OneVariant, other: RcStr },
        }

        assert_task_input(NestedVariants::Variant5 {
            named: OneVariant::Variant,
            other: rcstr!("42"),
        });
        assert_task_input(NestedVariants::Variant2(
            MultipleVariantsAndHeterogeneousFields::Variant5 {
                named: 42,
                other: rcstr!("42"),
            },
        ));
        Ok(())
    }
}