1use std::{
2 collections::{BTreeMap, BTreeSet},
3 fmt::Debug,
4 future::Future,
5 hash::Hash,
6 ops::{Deref, DerefMut},
7 sync::Arc,
8 time::Duration,
9};
10
11use anyhow::Result;
12use bincode::{
13 Decode, Encode,
14 de::Decoder,
15 enc::Encoder,
16 error::{DecodeError, EncodeError},
17};
18use either::Either;
19use turbo_frozenmap::{FrozenMap, FrozenSet};
20use turbo_rcstr::RcStr;
21use turbo_tasks_hash::HashAlgorithm;
22
23use crate as turbo_tasks;
26use crate::{
27 DynTaskInputs, ReadRef, ResolvedVc, TaskId, TransientInstance, TransientValue, ValueTypeId, Vc,
28 trace::TraceRawVcs,
29};
30
/// A value that can be passed as an argument to a turbo-tasks task.
///
/// The supertraits make every input cloneable, comparable, hashable, debuggable, traceable for
/// contained `Vc`s ([`TraceRawVcs`]) and bincode-serializable ([`Encode`] / [`Decode`]).
///
/// Implementations in this module cover plain values, `Vc` / `ResolvedVc`, the transient
/// wrappers, common std containers, [`EitherTaskInput`] and tuples of up to twelve elements.
pub trait TaskInput:
    Send + Sync + Clone + Debug + PartialEq + Eq + Hash + TraceRawVcs + Encode + Decode<()>
{
    /// Returns a copy of this input with every contained `Vc` resolved.
    ///
    /// The default implementation just clones `self`, which is appropriate for inputs that
    /// contain no `Vc`s. Container implementations in this module resolve their elements one
    /// after another and return the first error encountered.
    ///
    /// The returned future is `Send` and borrows `self`.
    fn resolve_input(&self) -> impl Future<Output = Result<Self>> + Send + '_ {
        async { Ok(self.clone()) }
    }

    /// Returns whether this input is already resolved.
    ///
    /// Defaults to `true`. `Vc` delegates to `Vc::is_resolved`, `ResolvedVc` always returns
    /// `true`, and containers are resolved only if every element is.
    fn is_resolved(&self) -> bool {
        true
    }

    /// Returns whether this input contains transient data.
    ///
    /// This method has no default, so every implementation must decide. In this module,
    /// `TransientValue` and `TransientInstance` always return `true`, `Vc` / `ResolvedVc` ask
    /// their underlying node, and containers are transient if any element is.
    ///
    /// NOTE(review): the transient wrappers' `Encode` / `Decode` impls always fail, so transient
    /// inputs cannot be serialized.
    fn is_transient(&self) -> bool;
}
119
/// Implements [`TaskInput`] for plain value types that contain no `Vc`s and are never
/// transient.
///
/// The generated impls keep the trait's default `resolve_input` (a clone) and `is_resolved`
/// (`true`), and report `is_transient() == false`. The type list is comma-separated and may
/// end with an optional trailing comma.
macro_rules! impl_task_input {
    ($($t:ty),* $(,)?) => {
        $(
            impl TaskInput for $t {
                fn is_transient(&self) -> bool {
                    false
                }
            }
        )*
    };
}
131
132impl_task_input! {
133 (),
134 bool,
135 u8,
136 u16,
137 u32,
138 i32,
139 u64,
140 usize,
141 RcStr,
142 TaskId,
143 ValueTypeId,
144 Duration,
145 String,
146 HashAlgorithm
147}
148
149impl<T> TaskInput for Vec<T>
150where
151 T: TaskInput,
152{
153 fn is_resolved(&self) -> bool {
154 self.iter().all(TaskInput::is_resolved)
155 }
156
157 fn is_transient(&self) -> bool {
158 self.iter().any(TaskInput::is_transient)
159 }
160
161 async fn resolve_input(&self) -> Result<Self> {
162 let mut resolved = Vec::with_capacity(self.len());
163 for value in self {
164 resolved.push(value.resolve_input().await?);
165 }
166 Ok(resolved)
167 }
168}
169
170impl<T> TaskInput for Box<T>
171where
172 T: TaskInput,
173{
174 fn is_resolved(&self) -> bool {
175 self.as_ref().is_resolved()
176 }
177
178 fn is_transient(&self) -> bool {
179 self.as_ref().is_transient()
180 }
181
182 async fn resolve_input(&self) -> Result<Self> {
183 Ok(Box::new(Box::pin(self.as_ref().resolve_input()).await?))
184 }
185}
186
187impl<T> TaskInput for Arc<T>
188where
189 T: TaskInput,
190{
191 fn is_resolved(&self) -> bool {
192 self.as_ref().is_resolved()
193 }
194
195 fn is_transient(&self) -> bool {
196 self.as_ref().is_transient()
197 }
198
199 async fn resolve_input(&self) -> Result<Self> {
200 Ok(Arc::new(Box::pin(self.as_ref().resolve_input()).await?))
201 }
202}
203
204impl<T> TaskInput for ReadRef<T>
205where
206 T: TaskInput,
207{
208 fn is_resolved(&self) -> bool {
209 Self::as_raw_ref(self).is_resolved()
210 }
211
212 fn is_transient(&self) -> bool {
213 Self::as_raw_ref(self).is_transient()
214 }
215
216 async fn resolve_input(&self) -> Result<Self> {
217 Ok(ReadRef::new_owned(
218 Box::pin(Self::as_raw_ref(self).resolve_input()).await?,
219 ))
220 }
221}
222
223impl<T> TaskInput for Option<T>
224where
225 T: TaskInput,
226{
227 fn is_resolved(&self) -> bool {
228 match self {
229 Some(value) => value.is_resolved(),
230 None => true,
231 }
232 }
233
234 fn is_transient(&self) -> bool {
235 match self {
236 Some(value) => value.is_transient(),
237 None => false,
238 }
239 }
240
241 async fn resolve_input(&self) -> Result<Self> {
242 match self {
243 Some(value) => Ok(Some(value.resolve_input().await?)),
244 None => Ok(None),
245 }
246 }
247}
248
249impl<T> TaskInput for Vc<T>
250where
251 T: Send + Sync + ?Sized,
252{
253 fn is_resolved(&self) -> bool {
254 Vc::is_resolved(*self)
255 }
256
257 fn is_transient(&self) -> bool {
258 self.node.is_transient()
259 }
260
261 async fn resolve_input(&self) -> Result<Self> {
262 Ok(*(*self).to_resolved().await?)
263 }
264}
265
266impl<T> TaskInput for ResolvedVc<T>
269where
270 T: Send + Sync + ?Sized,
271{
272 fn is_resolved(&self) -> bool {
273 true
274 }
275
276 fn is_transient(&self) -> bool {
277 self.node.is_transient()
278 }
279
280 async fn resolve_input(&self) -> Result<Self> {
281 Ok(*self)
282 }
283}
284
285impl<T> TaskInput for TransientValue<T>
286where
287 T: DynTaskInputs + Clone + Debug + Hash + Eq + TraceRawVcs + 'static,
288{
289 fn is_transient(&self) -> bool {
290 true
291 }
292}
293
294impl<T> Encode for TransientValue<T> {
295 fn encode<E: Encoder>(&self, _encoder: &mut E) -> Result<(), EncodeError> {
296 Err(EncodeError::Other("cannot encode transient task inputs"))
297 }
298}
299
300impl<Context, T> Decode<Context> for TransientValue<T> {
301 fn decode<D: Decoder<Context = Context>>(_decoder: &mut D) -> Result<Self, DecodeError> {
302 Err(DecodeError::Other("cannot decode transient task inputs"))
303 }
304}
305
306impl<T> TaskInput for TransientInstance<T>
307where
308 T: Sync + Send + TraceRawVcs + 'static,
309{
310 fn is_transient(&self) -> bool {
311 true
312 }
313}
314
315impl<T> Encode for TransientInstance<T> {
316 fn encode<E: Encoder>(&self, _encoder: &mut E) -> Result<(), EncodeError> {
317 Err(EncodeError::Other("cannot encode transient task inputs"))
318 }
319}
320
321impl<Context, T> Decode<Context> for TransientInstance<T> {
322 fn decode<D: Decoder<Context = Context>>(_decoder: &mut D) -> Result<Self, DecodeError> {
323 Err(DecodeError::Other("cannot decode transient task inputs"))
324 }
325}
326
327impl<K, V> TaskInput for BTreeMap<K, V>
328where
329 K: TaskInput + Ord,
330 V: TaskInput,
331{
332 async fn resolve_input(&self) -> Result<Self> {
333 let mut new_map = BTreeMap::new();
334 for (k, v) in self {
335 new_map.insert(
336 TaskInput::resolve_input(k).await?,
337 TaskInput::resolve_input(v).await?,
338 );
339 }
340 Ok(new_map)
341 }
342
343 fn is_resolved(&self) -> bool {
344 self.iter()
345 .all(|(k, v)| TaskInput::is_resolved(k) && TaskInput::is_resolved(v))
346 }
347
348 fn is_transient(&self) -> bool {
349 self.iter()
350 .any(|(k, v)| TaskInput::is_transient(k) || TaskInput::is_transient(v))
351 }
352}
353
354impl<T> TaskInput for BTreeSet<T>
355where
356 T: TaskInput + Ord,
357{
358 async fn resolve_input(&self) -> Result<Self> {
359 let mut new_set = BTreeSet::new();
360 for value in self {
361 new_set.insert(TaskInput::resolve_input(value).await?);
362 }
363 Ok(new_set)
364 }
365
366 fn is_resolved(&self) -> bool {
367 self.iter().all(TaskInput::is_resolved)
368 }
369
370 fn is_transient(&self) -> bool {
371 self.iter().any(TaskInput::is_transient)
372 }
373}
374
375impl<K, V> TaskInput for FrozenMap<K, V>
376where
377 K: TaskInput + Ord + 'static,
378 V: TaskInput + 'static,
379{
380 async fn resolve_input(&self) -> Result<Self> {
381 let mut new_entries = Vec::with_capacity(self.len());
382 for (k, v) in self {
383 new_entries.push((
384 TaskInput::resolve_input(k).await?,
385 TaskInput::resolve_input(v).await?,
386 ));
387 }
388 Ok(Self::from(new_entries))
390 }
391
392 fn is_resolved(&self) -> bool {
393 self.iter()
394 .all(|(k, v)| TaskInput::is_resolved(k) && TaskInput::is_resolved(v))
395 }
396
397 fn is_transient(&self) -> bool {
398 self.iter()
399 .any(|(k, v)| TaskInput::is_transient(k) || TaskInput::is_transient(v))
400 }
401}
402
403impl<T> TaskInput for FrozenSet<T>
404where
405 T: TaskInput + Ord + 'static,
406{
407 async fn resolve_input(&self) -> Result<Self> {
408 let mut new_set = Vec::with_capacity(self.len());
409 for value in self {
410 new_set.push(TaskInput::resolve_input(value).await?);
411 }
412 Ok(Self::from_iter(new_set))
413 }
414
415 fn is_resolved(&self) -> bool {
416 self.iter().all(TaskInput::is_resolved)
417 }
418
419 fn is_transient(&self) -> bool {
420 self.iter().any(TaskInput::is_transient)
421 }
422}
423
424#[derive(Clone, Debug, PartialEq, Eq, Hash, TraceRawVcs)]
427pub struct EitherTaskInput<L, R>(pub Either<L, R>);
428
429impl<L, R> Deref for EitherTaskInput<L, R> {
430 type Target = Either<L, R>;
431
432 fn deref(&self) -> &Self::Target {
433 &self.0
434 }
435}
436
437impl<L, R> DerefMut for EitherTaskInput<L, R> {
438 fn deref_mut(&mut self) -> &mut Self::Target {
439 &mut self.0
440 }
441}
442
443impl<L, R> Encode for EitherTaskInput<L, R>
444where
445 L: Encode,
446 R: Encode,
447{
448 fn encode<E: Encoder>(&self, encoder: &mut E) -> Result<(), EncodeError> {
449 turbo_bincode::either::encode(self, encoder)
450 }
451}
452
453impl<Context, L, R> Decode<Context> for EitherTaskInput<L, R>
454where
455 L: Decode<Context>,
456 R: Decode<Context>,
457{
458 fn decode<D: Decoder<Context = Context>>(decoder: &mut D) -> Result<Self, DecodeError> {
459 turbo_bincode::either::decode(decoder).map(Self)
460 }
461}
462
463impl<L, R> TaskInput for EitherTaskInput<L, R>
464where
465 L: TaskInput,
466 R: TaskInput,
467{
468 fn resolve_input(&self) -> impl Future<Output = Result<Self>> + Send + '_ {
469 self.as_ref().map_either(
470 |l| async move { anyhow::Ok(Self(Either::Left(l.resolve_input().await?))) },
471 |r| async move { anyhow::Ok(Self(Either::Right(r.resolve_input().await?))) },
472 )
473 }
474
475 fn is_resolved(&self) -> bool {
476 self.as_ref()
477 .either(TaskInput::is_resolved, TaskInput::is_resolved)
478 }
479
480 fn is_transient(&self) -> bool {
481 self.as_ref()
482 .either(TaskInput::is_transient, TaskInput::is_transient)
483 }
484}
485
/// Implements [`TaskInput`] for a tuple whose element types are all [`TaskInput`].
///
/// The tuple is resolved when every element is resolved and transient when any element is
/// transient. `resolve_input` resolves the elements left to right and stops at the first error.
macro_rules! tuple_impls {
    ( $( $name:ident )+ ) => {
        // The `TaskInput` bound is stated once, in the generics list; the type parameter names
        // are reused as binding names when destructuring `self`, hence the `non_snake_case`
        // allowances on each method.
        impl<$($name: TaskInput),+> TaskInput for ($($name,)+) {
            #[allow(non_snake_case)]
            fn is_resolved(&self) -> bool {
                let ($($name,)+) = self;
                $($name.is_resolved() &&)+ true
            }

            #[allow(non_snake_case)]
            fn is_transient(&self) -> bool {
                let ($($name,)+) = self;
                $($name.is_transient() ||)+ false
            }

            #[allow(non_snake_case)]
            async fn resolve_input(&self) -> Result<Self> {
                let ($($name,)+) = self;
                Ok(($($name.resolve_input().await?,)+))
            }
        }
    };
}
511
512tuple_impls! { A }
514tuple_impls! { A B }
515tuple_impls! { A B C }
516tuple_impls! { A B C D }
517tuple_impls! { A B C D E }
518tuple_impls! { A B C D E F }
519tuple_impls! { A B C D E F G }
520tuple_impls! { A B C D E F G H }
521tuple_impls! { A B C D E F G H I }
522tuple_impls! { A B C D E F G H I J }
523tuple_impls! { A B C D E F G H I J K }
524tuple_impls! { A B C D E F G H I J K L }
525
#[cfg(test)]
mod tests {
    use turbo_rcstr::rcstr;
    use turbo_tasks_macros::TaskInput;

    use super::*;

    /// Compile-time check that `T` implements [`TaskInput`]; does nothing at runtime.
    fn assert_task_input<T>(_: T)
    where
        T: TaskInput,
    {
    }

    #[test]
    fn test_no_fields() -> Result<()> {
        #[derive(Clone, TaskInput, Eq, PartialEq, Hash, Debug, Encode, Decode, TraceRawVcs)]
        struct NoFields;

        assert_task_input(NoFields);
        Ok(())
    }

    #[test]
    fn test_one_unnamed_field() -> Result<()> {
        #[derive(Clone, TaskInput, Eq, PartialEq, Hash, Debug, Encode, Decode, TraceRawVcs)]
        struct OneUnnamedField(u32);

        assert_task_input(OneUnnamedField(42));
        Ok(())
    }

    #[test]
    fn test_multiple_unnamed_fields() -> Result<()> {
        #[derive(Clone, TaskInput, Eq, PartialEq, Hash, Debug, Encode, Decode, TraceRawVcs)]
        struct MultipleUnnamedFields(u32, RcStr);

        assert_task_input(MultipleUnnamedFields(42, rcstr!("42")));
        Ok(())
    }

    #[test]
    fn test_one_named_field() -> Result<()> {
        #[derive(Clone, TaskInput, Eq, PartialEq, Hash, Debug, Encode, Decode, TraceRawVcs)]
        struct OneNamedField {
            named: u32,
        }

        assert_task_input(OneNamedField { named: 42 });
        Ok(())
    }

    #[test]
    fn test_multiple_named_fields() -> Result<()> {
        #[derive(Clone, TaskInput, Eq, PartialEq, Hash, Debug, Encode, Decode, TraceRawVcs)]
        struct MultipleNamedFields {
            named: u32,
            other: RcStr,
        }

        assert_task_input(MultipleNamedFields {
            named: 42,
            other: rcstr!("42"),
        });
        Ok(())
    }

    #[test]
    fn test_generic_field() -> Result<()> {
        #[derive(Clone, TaskInput, Eq, PartialEq, Hash, Debug, Encode, Decode, TraceRawVcs)]
        struct GenericField<T>(T);

        assert_task_input(GenericField(42));
        assert_task_input(GenericField(rcstr!("42")));
        Ok(())
    }

    #[derive(Clone, TaskInput, Eq, PartialEq, Hash, Debug, Encode, Decode, TraceRawVcs)]
    enum OneVariant {
        Variant,
    }

    #[test]
    fn test_one_variant() -> Result<()> {
        assert_task_input(OneVariant::Variant);
        Ok(())
    }

    #[test]
    fn test_multiple_variants() -> Result<()> {
        #[derive(Clone, TaskInput, PartialEq, Eq, Hash, Debug, Encode, Decode, TraceRawVcs)]
        enum MultipleVariants {
            Variant1,
            Variant2,
        }

        assert_task_input(MultipleVariants::Variant2);
        Ok(())
    }

    #[derive(Clone, TaskInput, Eq, PartialEq, Hash, Debug, Encode, Decode, TraceRawVcs)]
    enum MultipleVariantsAndHeterogeneousFields {
        Variant1,
        Variant2(u32),
        Variant3 { named: u32 },
        Variant4(u32, RcStr),
        Variant5 { named: u32, other: RcStr },
    }

    #[test]
    fn test_multiple_variants_and_heterogeneous_fields() -> Result<()> {
        assert_task_input(MultipleVariantsAndHeterogeneousFields::Variant5 {
            named: 42,
            other: rcstr!("42"),
        });
        Ok(())
    }

    #[test]
    fn test_nested_variants() -> Result<()> {
        #[derive(Clone, TaskInput, Eq, PartialEq, Hash, Debug, Encode, Decode, TraceRawVcs)]
        enum NestedVariants {
            Variant1,
            Variant2(MultipleVariantsAndHeterogeneousFields),
            Variant3 { named: OneVariant },
            Variant4(OneVariant, RcStr),
            Variant5 { named: OneVariant, other: RcStr },
        }

        assert_task_input(NestedVariants::Variant5 {
            named: OneVariant::Variant,
            other: rcstr!("42"),
        });
        assert_task_input(NestedVariants::Variant2(
            MultipleVariantsAndHeterogeneousFields::Variant5 {
                named: 42,
                other: rcstr!("42"),
            },
        ));
        Ok(())
    }

    /// Exercises the hand-written container impls at runtime: inputs built only from plain
    /// values must report themselves as resolved and not transient.
    #[test]
    fn test_plain_containers_are_resolved_and_not_transient() -> Result<()> {
        let values = vec![1u32, 2, 3];
        assert!(TaskInput::is_resolved(&values));
        assert!(!TaskInput::is_transient(&values));

        let none: Option<u32> = None;
        assert!(TaskInput::is_resolved(&none));
        assert!(!TaskInput::is_transient(&none));

        let boxed = Box::new(1u32);
        assert!(TaskInput::is_resolved(&boxed));
        assert!(!TaskInput::is_transient(&boxed));

        let shared = Arc::new(rcstr!("shared"));
        assert!(TaskInput::is_resolved(&shared));
        assert!(!TaskInput::is_transient(&shared));

        let map = BTreeMap::from([(1u32, rcstr!("a")), (2u32, rcstr!("b"))]);
        assert!(TaskInput::is_resolved(&map));
        assert!(!TaskInput::is_transient(&map));

        let tuple = (1u32, rcstr!("x"), Some(true));
        assert!(TaskInput::is_resolved(&tuple));
        assert!(!TaskInput::is_transient(&tuple));

        let either = EitherTaskInput::<u32, RcStr>(Either::Left(7));
        assert!(TaskInput::is_resolved(&either));
        assert!(!TaskInput::is_transient(&either));
        Ok(())
    }
}