1use std::{
2 collections::{BTreeMap, BTreeSet},
3 fmt::Debug,
4 future::Future,
5 hash::Hash,
6 ops::{Deref, DerefMut},
7 pin::Pin,
8 sync::Arc,
9 task::{Context, Poll},
10 time::Duration,
11};
12
13use anyhow::Result;
14use bincode::{
15 Decode, Encode,
16 de::Decoder,
17 enc::Encoder,
18 error::{DecodeError, EncodeError},
19};
20use either::Either;
21use turbo_frozenmap::{FrozenMap, FrozenSet};
22use turbo_rcstr::RcStr;
23use turbo_tasks_hash::HashAlgorithm;
24
25use crate::{self as turbo_tasks, ReadRef};
28use crate::{
29 DynTaskInputs, ResolvedVc, TaskId, TransientInstance, TransientValue, ValueTypeId, Vc,
30 trace::TraceRawVcs,
31};
32
33struct CloneReady<'a, T> {
38 pub inner: Option<&'a T>,
39}
40
41impl<'a, T: Clone> Future for CloneReady<'a, T> {
42 type Output = Result<T>;
43
44 fn poll(mut self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Self::Output> {
45 Poll::Ready(Ok(self
46 .inner
47 .take()
48 .expect("future already polled to completion")
49 .clone()))
50 }
51}
52
53impl<'a, T> Unpin for CloneReady<'a, T> {}
55
56pub trait TaskInput:
109 Send + Sync + Clone + Debug + PartialEq + Eq + Hash + TraceRawVcs + Encode + Decode<()>
110{
111 fn resolve_input(&self) -> impl Future<Output = Result<Self>> + Send + '_ {
115 CloneReady { inner: Some(self) }
116 }
117
118 fn is_resolved(&self) -> bool {
130 true
131 }
132
133 fn is_transient(&self) -> bool;
143}
144
/// Implements [`TaskInput`] for each listed type as a never-transient value.
///
/// `resolve_input` and `is_resolved` fall back to the trait defaults: a plain
/// clone, and `true`.
macro_rules! impl_task_input {
    ($($ty:ty),*) => {
        $(
            impl TaskInput for $ty {
                fn is_transient(&self) -> bool {
                    false
                }
            }
        )*
    };
}
156
157impl_task_input! {
158 (),
159 bool,
160 u8,
161 u16,
162 u32,
163 i32,
164 u64,
165 usize,
166 RcStr,
167 TaskId,
168 ValueTypeId,
169 Duration,
170 String,
171 HashAlgorithm
172}
173
174impl<T> TaskInput for Vec<T>
175where
176 T: TaskInput,
177{
178 fn is_resolved(&self) -> bool {
179 self.iter().all(TaskInput::is_resolved)
180 }
181
182 fn is_transient(&self) -> bool {
183 self.iter().any(TaskInput::is_transient)
184 }
185
186 async fn resolve_input(&self) -> Result<Self> {
187 let mut resolved = Vec::with_capacity(self.len());
188 for value in self {
189 resolved.push(value.resolve_input().await?);
190 }
191 Ok(resolved)
192 }
193}
194
195impl<T> TaskInput for Box<T>
196where
197 T: TaskInput,
198{
199 fn is_resolved(&self) -> bool {
200 self.as_ref().is_resolved()
201 }
202
203 fn is_transient(&self) -> bool {
204 self.as_ref().is_transient()
205 }
206
207 async fn resolve_input(&self) -> Result<Self> {
208 Ok(Box::new(Box::pin(self.as_ref().resolve_input()).await?))
209 }
210}
211
212impl<T> TaskInput for Arc<T>
213where
214 T: TaskInput,
215{
216 fn is_resolved(&self) -> bool {
217 self.as_ref().is_resolved()
218 }
219
220 fn is_transient(&self) -> bool {
221 self.as_ref().is_transient()
222 }
223
224 async fn resolve_input(&self) -> Result<Self> {
225 Ok(Arc::new(Box::pin(self.as_ref().resolve_input()).await?))
226 }
227}
228
229impl<T> TaskInput for ReadRef<T>
230where
231 T: TaskInput,
232{
233 fn is_resolved(&self) -> bool {
234 Self::as_raw_ref(self).is_resolved()
235 }
236
237 fn is_transient(&self) -> bool {
238 Self::as_raw_ref(self).is_transient()
239 }
240
241 async fn resolve_input(&self) -> Result<Self> {
242 Ok(ReadRef::new_owned(
243 Box::pin(Self::as_raw_ref(self).resolve_input()).await?,
244 ))
245 }
246}
247
248impl<T> TaskInput for Option<T>
249where
250 T: TaskInput,
251{
252 fn is_resolved(&self) -> bool {
253 match self {
254 Some(value) => value.is_resolved(),
255 None => true,
256 }
257 }
258
259 fn is_transient(&self) -> bool {
260 match self {
261 Some(value) => value.is_transient(),
262 None => false,
263 }
264 }
265
266 async fn resolve_input(&self) -> Result<Self> {
267 match self {
268 Some(value) => Ok(Some(value.resolve_input().await?)),
269 None => Ok(None),
270 }
271 }
272}
273
274impl<T> TaskInput for Vc<T>
275where
276 T: Send + Sync + ?Sized,
277{
278 fn is_resolved(&self) -> bool {
279 Vc::is_resolved(*self)
280 }
281
282 fn is_transient(&self) -> bool {
283 self.node.is_transient()
284 }
285
286 fn resolve_input(&self) -> impl Future<Output = Result<Self>> + Send + '_ {
287 (*self).resolve()
290 }
291}
292
293impl<T> TaskInput for ResolvedVc<T>
296where
297 T: Send + Sync + ?Sized,
298{
299 fn is_resolved(&self) -> bool {
300 true
301 }
302
303 fn is_transient(&self) -> bool {
304 self.node.is_transient()
305 }
306}
307
308impl<T> TaskInput for TransientValue<T>
309where
310 T: DynTaskInputs + Clone + Debug + Hash + Eq + TraceRawVcs + 'static,
311{
312 fn is_transient(&self) -> bool {
313 true
314 }
315}
316
317impl<T> Encode for TransientValue<T> {
318 fn encode<E: Encoder>(&self, _encoder: &mut E) -> Result<(), EncodeError> {
319 Err(EncodeError::Other("cannot encode transient task inputs"))
320 }
321}
322
323impl<Context, T> Decode<Context> for TransientValue<T> {
324 fn decode<D: Decoder<Context = Context>>(_decoder: &mut D) -> Result<Self, DecodeError> {
325 Err(DecodeError::Other("cannot decode transient task inputs"))
326 }
327}
328
329impl<T> TaskInput for TransientInstance<T>
330where
331 T: Sync + Send + TraceRawVcs + 'static,
332{
333 fn is_transient(&self) -> bool {
334 true
335 }
336}
337
338impl<T> Encode for TransientInstance<T> {
339 fn encode<E: Encoder>(&self, _encoder: &mut E) -> Result<(), EncodeError> {
340 Err(EncodeError::Other("cannot encode transient task inputs"))
341 }
342}
343
344impl<Context, T> Decode<Context> for TransientInstance<T> {
345 fn decode<D: Decoder<Context = Context>>(_decoder: &mut D) -> Result<Self, DecodeError> {
346 Err(DecodeError::Other("cannot decode transient task inputs"))
347 }
348}
349
350impl<K, V> TaskInput for BTreeMap<K, V>
351where
352 K: TaskInput + Ord,
353 V: TaskInput,
354{
355 async fn resolve_input(&self) -> Result<Self> {
356 let mut new_map = BTreeMap::new();
357 for (k, v) in self {
358 new_map.insert(
359 TaskInput::resolve_input(k).await?,
360 TaskInput::resolve_input(v).await?,
361 );
362 }
363 Ok(new_map)
364 }
365
366 fn is_resolved(&self) -> bool {
367 self.iter()
368 .all(|(k, v)| TaskInput::is_resolved(k) && TaskInput::is_resolved(v))
369 }
370
371 fn is_transient(&self) -> bool {
372 self.iter()
373 .any(|(k, v)| TaskInput::is_transient(k) || TaskInput::is_transient(v))
374 }
375}
376
377impl<T> TaskInput for BTreeSet<T>
378where
379 T: TaskInput + Ord,
380{
381 async fn resolve_input(&self) -> Result<Self> {
382 let mut new_set = BTreeSet::new();
383 for value in self {
384 new_set.insert(TaskInput::resolve_input(value).await?);
385 }
386 Ok(new_set)
387 }
388
389 fn is_resolved(&self) -> bool {
390 self.iter().all(TaskInput::is_resolved)
391 }
392
393 fn is_transient(&self) -> bool {
394 self.iter().any(TaskInput::is_transient)
395 }
396}
397
398impl<K, V> TaskInput for FrozenMap<K, V>
399where
400 K: TaskInput + Ord + 'static,
401 V: TaskInput + 'static,
402{
403 async fn resolve_input(&self) -> Result<Self> {
404 let mut new_entries = Vec::with_capacity(self.len());
405 for (k, v) in self {
406 new_entries.push((
407 TaskInput::resolve_input(k).await?,
408 TaskInput::resolve_input(v).await?,
409 ));
410 }
411 Ok(Self::from(new_entries))
413 }
414
415 fn is_resolved(&self) -> bool {
416 self.iter()
417 .all(|(k, v)| TaskInput::is_resolved(k) && TaskInput::is_resolved(v))
418 }
419
420 fn is_transient(&self) -> bool {
421 self.iter()
422 .any(|(k, v)| TaskInput::is_transient(k) || TaskInput::is_transient(v))
423 }
424}
425
426impl<T> TaskInput for FrozenSet<T>
427where
428 T: TaskInput + Ord + 'static,
429{
430 async fn resolve_input(&self) -> Result<Self> {
431 let mut new_set = Vec::with_capacity(self.len());
432 for value in self {
433 new_set.push(TaskInput::resolve_input(value).await?);
434 }
435 Ok(Self::from_iter(new_set))
436 }
437
438 fn is_resolved(&self) -> bool {
439 self.iter().all(TaskInput::is_resolved)
440 }
441
442 fn is_transient(&self) -> bool {
443 self.iter().any(TaskInput::is_transient)
444 }
445}
446
447#[derive(Clone, Debug, PartialEq, Eq, Hash, TraceRawVcs)]
450pub struct EitherTaskInput<L, R>(pub Either<L, R>);
451
452impl<L, R> Deref for EitherTaskInput<L, R> {
453 type Target = Either<L, R>;
454
455 fn deref(&self) -> &Self::Target {
456 &self.0
457 }
458}
459
460impl<L, R> DerefMut for EitherTaskInput<L, R> {
461 fn deref_mut(&mut self) -> &mut Self::Target {
462 &mut self.0
463 }
464}
465
466impl<L, R> Encode for EitherTaskInput<L, R>
467where
468 L: Encode,
469 R: Encode,
470{
471 fn encode<E: Encoder>(&self, encoder: &mut E) -> Result<(), EncodeError> {
472 turbo_bincode::either::encode(self, encoder)
473 }
474}
475
476impl<Context, L, R> Decode<Context> for EitherTaskInput<L, R>
477where
478 L: Decode<Context>,
479 R: Decode<Context>,
480{
481 fn decode<D: Decoder<Context = Context>>(decoder: &mut D) -> Result<Self, DecodeError> {
482 turbo_bincode::either::decode(decoder).map(Self)
483 }
484}
485
486impl<L, R> TaskInput for EitherTaskInput<L, R>
487where
488 L: TaskInput,
489 R: TaskInput,
490{
491 fn resolve_input(&self) -> impl Future<Output = Result<Self>> + Send + '_ {
492 self.as_ref().map_either(
493 |l| async move { anyhow::Ok(Self(Either::Left(l.resolve_input().await?))) },
494 |r| async move { anyhow::Ok(Self(Either::Right(r.resolve_input().await?))) },
495 )
496 }
497
498 fn is_resolved(&self) -> bool {
499 self.as_ref()
500 .either(TaskInput::is_resolved, TaskInput::is_resolved)
501 }
502
503 fn is_transient(&self) -> bool {
504 self.as_ref()
505 .either(TaskInput::is_transient, TaskInput::is_transient)
506 }
507}
508
/// Implements [`TaskInput`] for a tuple of the given element type parameters.
///
/// Each identifier is used both as a type parameter and as the binding for
/// the corresponding element when `self` is destructured. A tuple is resolved
/// when every element is, and transient when any element is. Elements are
/// resolved left to right.
macro_rules! tuple_impls {
    ( $( $name:ident )+ ) => {
        impl<$($name),+> TaskInput for ($($name,)+)
        where
            $($name: TaskInput),+
        {
            // The bindings reuse the upper-case type-parameter identifiers.
            #[allow(non_snake_case)]
            fn is_resolved(&self) -> bool {
                let ($($name,)+) = self;
                true $(&& $name.is_resolved())+
            }

            // The bindings reuse the upper-case type-parameter identifiers.
            #[allow(non_snake_case)]
            fn is_transient(&self) -> bool {
                let ($($name,)+) = self;
                false $(|| $name.is_transient())+
            }

            // The bindings reuse the upper-case type-parameter identifiers.
            #[allow(non_snake_case)]
            async fn resolve_input(&self) -> Result<Self> {
                let ($($name,)+) = self;
                Ok(($($name.resolve_input().await?,)+))
            }
        }
    };
}
534
535tuple_impls! { A }
537tuple_impls! { A B }
538tuple_impls! { A B C }
539tuple_impls! { A B C D }
540tuple_impls! { A B C D E }
541tuple_impls! { A B C D E F }
542tuple_impls! { A B C D E F G }
543tuple_impls! { A B C D E F G H }
544tuple_impls! { A B C D E F G H I }
545tuple_impls! { A B C D E F G H I J }
546tuple_impls! { A B C D E F G H I J K }
547tuple_impls! { A B C D E F G H I J K L }
548
#[cfg(test)]
mod tests {
    use turbo_rcstr::rcstr;

    use super::*;

    /// Type-level check: a call to this only compiles when `T: TaskInput`.
    ///
    /// None of the types below list `TaskInput` in their derives, so the impl
    /// must come from the `#[turbo_tasks::task_input]` attribute.
    fn assert_task_input<T: TaskInput>(_value: T) {}

    #[test]
    fn test_no_fields() {
        #[turbo_tasks::task_input]
        #[derive(Clone, Eq, PartialEq, Hash, Debug, Encode, Decode, TraceRawVcs)]
        struct NoFields;

        assert_task_input(NoFields);
    }

    #[test]
    fn test_one_unnamed_field() {
        #[turbo_tasks::task_input]
        #[derive(Clone, Eq, PartialEq, Hash, Debug, Encode, Decode, TraceRawVcs)]
        struct OneUnnamedField(u32);

        assert_task_input(OneUnnamedField(42));
    }

    #[test]
    fn test_multiple_unnamed_fields() {
        #[turbo_tasks::task_input]
        #[derive(Clone, Eq, PartialEq, Hash, Debug, Encode, Decode, TraceRawVcs)]
        struct MultipleUnnamedFields(u32, RcStr);

        assert_task_input(MultipleUnnamedFields(42, rcstr!("42")));
    }

    #[test]
    fn test_one_named_field() {
        #[turbo_tasks::task_input]
        #[derive(Clone, Eq, PartialEq, Hash, Debug, Encode, Decode, TraceRawVcs)]
        struct OneNamedField {
            named: u32,
        }

        assert_task_input(OneNamedField { named: 42 });
    }

    #[test]
    fn test_multiple_named_fields() {
        #[turbo_tasks::task_input]
        #[derive(Clone, Eq, PartialEq, Hash, Debug, Encode, Decode, TraceRawVcs)]
        struct MultipleNamedFields {
            named: u32,
            other: RcStr,
        }

        let value = MultipleNamedFields {
            named: 42,
            other: rcstr!("42"),
        };
        assert_task_input(value);
    }

    #[test]
    fn test_generic_field() {
        #[turbo_tasks::task_input]
        #[derive(Clone, Eq, PartialEq, Hash, Debug, Encode, Decode, TraceRawVcs)]
        struct GenericField<T>(T);

        assert_task_input(GenericField(42));
        assert_task_input(GenericField(rcstr!("42")));
    }

    // Declared at module scope because `test_nested_variants` reuses it.
    #[turbo_tasks::task_input]
    #[derive(Clone, Eq, PartialEq, Hash, Debug, Encode, Decode, TraceRawVcs)]
    enum OneVariant {
        Variant,
    }

    #[test]
    fn test_one_variant() {
        assert_task_input(OneVariant::Variant);
    }

    #[test]
    fn test_multiple_variants() {
        #[turbo_tasks::task_input]
        #[derive(Clone, PartialEq, Eq, Hash, Debug, Encode, Decode, TraceRawVcs)]
        enum MultipleVariants {
            Variant1,
            Variant2,
        }

        assert_task_input(MultipleVariants::Variant2);
    }

    // Declared at module scope because `test_nested_variants` reuses it.
    #[turbo_tasks::task_input]
    #[derive(Clone, Eq, PartialEq, Hash, Debug, Encode, Decode, TraceRawVcs)]
    enum MultipleVariantsAndHeterogeneousFields {
        Variant1,
        Variant2(u32),
        Variant3 { named: u32 },
        Variant4(u32, RcStr),
        Variant5 { named: u32, other: RcStr },
    }

    #[test]
    fn test_multiple_variants_and_heterogeneous_fields() {
        let value = MultipleVariantsAndHeterogeneousFields::Variant5 {
            named: 42,
            other: rcstr!("42"),
        };
        assert_task_input(value);
    }

    #[test]
    fn test_nested_variants() {
        #[turbo_tasks::task_input]
        #[derive(Clone, Eq, PartialEq, Hash, Debug, Encode, Decode, TraceRawVcs)]
        enum NestedVariants {
            Variant1,
            Variant2(MultipleVariantsAndHeterogeneousFields),
            Variant3 { named: OneVariant },
            Variant4(OneVariant, RcStr),
            Variant5 { named: OneVariant, other: RcStr },
        }

        let flat = NestedVariants::Variant5 {
            named: OneVariant::Variant,
            other: rcstr!("42"),
        };
        assert_task_input(flat);

        let nested = NestedVariants::Variant2(MultipleVariantsAndHeterogeneousFields::Variant5 {
            named: 42,
            other: rcstr!("42"),
        });
        assert_task_input(nested);
    }
}