1use std::{
2 collections::{BTreeMap, BTreeSet},
3 fmt::Debug,
4 future::Future,
5 hash::Hash,
6 ops::{Deref, DerefMut},
7 pin::Pin,
8 sync::Arc,
9 task::{Context, Poll},
10 time::Duration,
11};
12
13use anyhow::Result;
14use bincode::{
15 Decode, Encode,
16 de::Decoder,
17 enc::Encoder,
18 error::{DecodeError, EncodeError},
19};
20use either::Either;
21use turbo_frozenmap::{FrozenMap, FrozenSet};
22use turbo_rcstr::RcStr;
23use turbo_tasks_hash::HashAlgorithm;
24
25use crate::{self as turbo_tasks, ReadRef};
28use crate::{
29 DynTaskInputs, ResolvedVc, TaskId, TransientInstance, TransientValue, ValueTypeId, Vc,
30 trace::TraceRawVcs,
31};
32
/// A future that completes on its first poll with a clone of a borrowed value.
///
/// The clone is taken lazily inside `poll`, not when the future is created.
/// `inner` is `Some` until the future has completed and `None` afterwards.
struct CloneReady<'a, T> {
    pub inner: Option<&'a T>,
}
40
41impl<'a, T: Clone> Future for CloneReady<'a, T> {
42 type Output = Result<T>;
43
44 fn poll(mut self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Self::Output> {
45 Poll::Ready(Ok(self
46 .inner
47 .take()
48 .expect("future already polled to completion")
49 .clone()))
50 }
51}
52
53impl<'a, T> Unpin for CloneReady<'a, T> {}
55
56pub trait TaskInput:
109 Send + Sync + Clone + Debug + PartialEq + Eq + Hash + TraceRawVcs + Encode + Decode<()>
110{
111 fn resolve_input(&self) -> impl Future<Output = Result<Self>> + Send + '_ {
115 CloneReady { inner: Some(self) }
116 }
117
118 fn is_resolved(&self) -> bool {
130 true
131 }
132
133 fn is_transient(&self) -> bool;
143}
144
/// Implements [`TaskInput`] for each listed type as a never-transient input.
///
/// The generated impls rely on the trait's default `resolve_input` (a clone)
/// and default `is_resolved` (`true`). The list may end with a trailing comma.
macro_rules! impl_task_input {
    ($($t:ty),* $(,)?) => {
        $(
            impl TaskInput for $t {
                fn is_transient(&self) -> bool {
                    false
                }
            }
        )*
    };
}
156
157impl_task_input! {
158 (),
159 bool,
160 u8,
161 u16,
162 u32,
163 i32,
164 u64,
165 u128,
166 usize,
167 RcStr,
168 TaskId,
169 ValueTypeId,
170 Duration,
171 String,
172 HashAlgorithm
173}
174
175impl<T> TaskInput for Vec<T>
176where
177 T: TaskInput,
178{
179 fn is_resolved(&self) -> bool {
180 self.iter().all(TaskInput::is_resolved)
181 }
182
183 fn is_transient(&self) -> bool {
184 self.iter().any(TaskInput::is_transient)
185 }
186
187 async fn resolve_input(&self) -> Result<Self> {
188 let mut resolved = Vec::with_capacity(self.len());
189 for value in self {
190 resolved.push(value.resolve_input().await?);
191 }
192 Ok(resolved)
193 }
194}
195
196impl<T> TaskInput for Box<T>
197where
198 T: TaskInput,
199{
200 fn is_resolved(&self) -> bool {
201 self.as_ref().is_resolved()
202 }
203
204 fn is_transient(&self) -> bool {
205 self.as_ref().is_transient()
206 }
207
208 async fn resolve_input(&self) -> Result<Self> {
209 Ok(Box::new(Box::pin(self.as_ref().resolve_input()).await?))
210 }
211}
212
213impl<T> TaskInput for Arc<T>
214where
215 T: TaskInput,
216{
217 fn is_resolved(&self) -> bool {
218 self.as_ref().is_resolved()
219 }
220
221 fn is_transient(&self) -> bool {
222 self.as_ref().is_transient()
223 }
224
225 async fn resolve_input(&self) -> Result<Self> {
226 Ok(Arc::new(Box::pin(self.as_ref().resolve_input()).await?))
227 }
228}
229
230impl<T> TaskInput for ReadRef<T>
231where
232 T: TaskInput,
233{
234 fn is_resolved(&self) -> bool {
235 Self::as_raw_ref(self).is_resolved()
236 }
237
238 fn is_transient(&self) -> bool {
239 Self::as_raw_ref(self).is_transient()
240 }
241
242 async fn resolve_input(&self) -> Result<Self> {
243 Ok(ReadRef::new_owned(
244 Box::pin(Self::as_raw_ref(self).resolve_input()).await?,
245 ))
246 }
247}
248
249impl<T> TaskInput for Option<T>
250where
251 T: TaskInput,
252{
253 fn is_resolved(&self) -> bool {
254 match self {
255 Some(value) => value.is_resolved(),
256 None => true,
257 }
258 }
259
260 fn is_transient(&self) -> bool {
261 match self {
262 Some(value) => value.is_transient(),
263 None => false,
264 }
265 }
266
267 async fn resolve_input(&self) -> Result<Self> {
268 match self {
269 Some(value) => Ok(Some(value.resolve_input().await?)),
270 None => Ok(None),
271 }
272 }
273}
274
275impl<T> TaskInput for Vc<T>
276where
277 T: Send + Sync + ?Sized,
278{
279 fn is_resolved(&self) -> bool {
280 Vc::is_resolved(*self)
281 }
282
283 fn is_transient(&self) -> bool {
284 self.node.is_transient()
285 }
286
287 fn resolve_input(&self) -> impl Future<Output = Result<Self>> + Send + '_ {
288 (*self).resolve()
291 }
292}
293
294impl<T> TaskInput for ResolvedVc<T>
297where
298 T: Send + Sync + ?Sized,
299{
300 fn is_resolved(&self) -> bool {
301 true
302 }
303
304 fn is_transient(&self) -> bool {
305 self.node.is_transient()
306 }
307}
308
309impl<T> TaskInput for TransientValue<T>
310where
311 T: DynTaskInputs + Clone + Debug + Hash + Eq + TraceRawVcs + 'static,
312{
313 fn is_transient(&self) -> bool {
314 true
315 }
316}
317
318impl<T> Encode for TransientValue<T> {
319 fn encode<E: Encoder>(&self, _encoder: &mut E) -> Result<(), EncodeError> {
320 Err(EncodeError::Other("cannot encode transient task inputs"))
321 }
322}
323
324impl<Context, T> Decode<Context> for TransientValue<T> {
325 fn decode<D: Decoder<Context = Context>>(_decoder: &mut D) -> Result<Self, DecodeError> {
326 Err(DecodeError::Other("cannot decode transient task inputs"))
327 }
328}
329
330impl<T> TaskInput for TransientInstance<T>
331where
332 T: Sync + Send + TraceRawVcs + 'static,
333{
334 fn is_transient(&self) -> bool {
335 true
336 }
337}
338
339impl<T> Encode for TransientInstance<T> {
340 fn encode<E: Encoder>(&self, _encoder: &mut E) -> Result<(), EncodeError> {
341 Err(EncodeError::Other("cannot encode transient task inputs"))
342 }
343}
344
345impl<Context, T> Decode<Context> for TransientInstance<T> {
346 fn decode<D: Decoder<Context = Context>>(_decoder: &mut D) -> Result<Self, DecodeError> {
347 Err(DecodeError::Other("cannot decode transient task inputs"))
348 }
349}
350
351impl<K, V> TaskInput for BTreeMap<K, V>
352where
353 K: TaskInput + Ord,
354 V: TaskInput,
355{
356 async fn resolve_input(&self) -> Result<Self> {
357 let mut new_map = BTreeMap::new();
358 for (k, v) in self {
359 new_map.insert(
360 TaskInput::resolve_input(k).await?,
361 TaskInput::resolve_input(v).await?,
362 );
363 }
364 Ok(new_map)
365 }
366
367 fn is_resolved(&self) -> bool {
368 self.iter()
369 .all(|(k, v)| TaskInput::is_resolved(k) && TaskInput::is_resolved(v))
370 }
371
372 fn is_transient(&self) -> bool {
373 self.iter()
374 .any(|(k, v)| TaskInput::is_transient(k) || TaskInput::is_transient(v))
375 }
376}
377
378impl<T> TaskInput for BTreeSet<T>
379where
380 T: TaskInput + Ord,
381{
382 async fn resolve_input(&self) -> Result<Self> {
383 let mut new_set = BTreeSet::new();
384 for value in self {
385 new_set.insert(TaskInput::resolve_input(value).await?);
386 }
387 Ok(new_set)
388 }
389
390 fn is_resolved(&self) -> bool {
391 self.iter().all(TaskInput::is_resolved)
392 }
393
394 fn is_transient(&self) -> bool {
395 self.iter().any(TaskInput::is_transient)
396 }
397}
398
399impl<K, V> TaskInput for FrozenMap<K, V>
400where
401 K: TaskInput + Ord + 'static,
402 V: TaskInput + 'static,
403{
404 async fn resolve_input(&self) -> Result<Self> {
405 let mut new_entries = Vec::with_capacity(self.len());
406 for (k, v) in self {
407 new_entries.push((
408 TaskInput::resolve_input(k).await?,
409 TaskInput::resolve_input(v).await?,
410 ));
411 }
412 Ok(Self::from(new_entries))
414 }
415
416 fn is_resolved(&self) -> bool {
417 self.iter()
418 .all(|(k, v)| TaskInput::is_resolved(k) && TaskInput::is_resolved(v))
419 }
420
421 fn is_transient(&self) -> bool {
422 self.iter()
423 .any(|(k, v)| TaskInput::is_transient(k) || TaskInput::is_transient(v))
424 }
425}
426
427impl<T> TaskInput for FrozenSet<T>
428where
429 T: TaskInput + Ord + 'static,
430{
431 async fn resolve_input(&self) -> Result<Self> {
432 let mut new_set = Vec::with_capacity(self.len());
433 for value in self {
434 new_set.push(TaskInput::resolve_input(value).await?);
435 }
436 Ok(Self::from_iter(new_set))
437 }
438
439 fn is_resolved(&self) -> bool {
440 self.iter().all(TaskInput::is_resolved)
441 }
442
443 fn is_transient(&self) -> bool {
444 self.iter().any(TaskInput::is_transient)
445 }
446}
447
448#[derive(Clone, Debug, PartialEq, Eq, Hash, TraceRawVcs)]
451pub struct EitherTaskInput<L, R>(pub Either<L, R>);
452
453impl<L, R> Deref for EitherTaskInput<L, R> {
454 type Target = Either<L, R>;
455
456 fn deref(&self) -> &Self::Target {
457 &self.0
458 }
459}
460
461impl<L, R> DerefMut for EitherTaskInput<L, R> {
462 fn deref_mut(&mut self) -> &mut Self::Target {
463 &mut self.0
464 }
465}
466
467impl<L, R> Encode for EitherTaskInput<L, R>
468where
469 L: Encode,
470 R: Encode,
471{
472 fn encode<E: Encoder>(&self, encoder: &mut E) -> Result<(), EncodeError> {
473 turbo_bincode::either::encode(self, encoder)
474 }
475}
476
477impl<Context, L, R> Decode<Context> for EitherTaskInput<L, R>
478where
479 L: Decode<Context>,
480 R: Decode<Context>,
481{
482 fn decode<D: Decoder<Context = Context>>(decoder: &mut D) -> Result<Self, DecodeError> {
483 turbo_bincode::either::decode(decoder).map(Self)
484 }
485}
486
487impl<L, R> TaskInput for EitherTaskInput<L, R>
488where
489 L: TaskInput,
490 R: TaskInput,
491{
492 fn resolve_input(&self) -> impl Future<Output = Result<Self>> + Send + '_ {
493 self.as_ref().map_either(
494 |l| async move { anyhow::Ok(Self(Either::Left(l.resolve_input().await?))) },
495 |r| async move { anyhow::Ok(Self(Either::Right(r.resolve_input().await?))) },
496 )
497 }
498
499 fn is_resolved(&self) -> bool {
500 self.as_ref()
501 .either(TaskInput::is_resolved, TaskInput::is_resolved)
502 }
503
504 fn is_transient(&self) -> bool {
505 self.as_ref()
506 .either(TaskInput::is_transient, TaskInput::is_transient)
507 }
508}
509
/// Implements [`TaskInput`] for the tuple whose element types are the given
/// identifiers (e.g. `tuple_impls! { A B }` covers `(A, B)`).
///
/// A tuple is resolved only if every element is and transient if any element
/// is. Resolution awaits each element in turn, left to right.
macro_rules! tuple_impls {
    ( $( $name:ident )+ ) => {
        impl<$($name: TaskInput),+> TaskInput for ($($name,)+) {
            // The type-parameter identifiers are reused as the binding names
            // of the destructured tuple elements, so they are upper-case.
            #[allow(non_snake_case)]
            fn is_resolved(&self) -> bool {
                let ($($name,)+) = self;
                $($name.is_resolved() &&)+ true
            }

            #[allow(non_snake_case)]
            fn is_transient(&self) -> bool {
                let ($($name,)+) = self;
                $($name.is_transient() ||)+ false
            }

            #[allow(non_snake_case)]
            async fn resolve_input(&self) -> Result<Self> {
                let ($($name,)+) = self;
                Ok(($($name.resolve_input().await?,)+))
            }
        }
    };
}
535
536tuple_impls! { A }
538tuple_impls! { A B }
539tuple_impls! { A B C }
540tuple_impls! { A B C D }
541tuple_impls! { A B C D E }
542tuple_impls! { A B C D E F }
543tuple_impls! { A B C D E F G }
544tuple_impls! { A B C D E F G H }
545tuple_impls! { A B C D E F G H I }
546tuple_impls! { A B C D E F G H I J }
547tuple_impls! { A B C D E F G H I J K }
548tuple_impls! { A B C D E F G H I J K L }
549
// Compile-time checks that types annotated with `#[turbo_tasks::task_input]`
// implement `TaskInput`, covering structs and enums of various field shapes.
// Each test passes only if `assert_task_input` type-checks; nothing is
// asserted at runtime.
#[cfg(test)]
mod tests {
    use turbo_rcstr::rcstr;

    use super::*;

    // Accepts any value whose type implements `TaskInput`; compiling a call is
    // the whole check.
    fn assert_task_input<T>(_: T)
    where
        T: TaskInput,
    {
    }

    // Unit struct.
    #[test]
    fn test_no_fields() -> Result<()> {
        #[turbo_tasks::task_input]
        #[derive(Clone, Eq, PartialEq, Hash, Debug, Encode, Decode, TraceRawVcs)]
        struct NoFields;

        assert_task_input(NoFields);
        Ok(())
    }

    // Tuple struct with a single field.
    #[test]
    fn test_one_unnamed_field() -> Result<()> {
        #[turbo_tasks::task_input]
        #[derive(Clone, Eq, PartialEq, Hash, Debug, Encode, Decode, TraceRawVcs)]
        struct OneUnnamedField(u32);

        assert_task_input(OneUnnamedField(42));
        Ok(())
    }

    // Tuple struct with multiple fields of different types.
    #[test]
    fn test_multiple_unnamed_fields() -> Result<()> {
        #[turbo_tasks::task_input]
        #[derive(Clone, Eq, PartialEq, Hash, Debug, Encode, Decode, TraceRawVcs)]
        struct MultipleUnnamedFields(u32, RcStr);

        assert_task_input(MultipleUnnamedFields(42, rcstr!("42")));
        Ok(())
    }

    // Struct with a single named field.
    #[test]
    fn test_one_named_field() -> Result<()> {
        #[turbo_tasks::task_input]
        #[derive(Clone, Eq, PartialEq, Hash, Debug, Encode, Decode, TraceRawVcs)]
        struct OneNamedField {
            named: u32,
        }

        assert_task_input(OneNamedField { named: 42 });
        Ok(())
    }

    // Struct with multiple named fields of different types.
    #[test]
    fn test_multiple_named_fields() -> Result<()> {
        #[turbo_tasks::task_input]
        #[derive(Clone, Eq, PartialEq, Hash, Debug, Encode, Decode, TraceRawVcs)]
        struct MultipleNamedFields {
            named: u32,
            other: RcStr,
        }

        assert_task_input(MultipleNamedFields {
            named: 42,
            other: rcstr!("42"),
        });
        Ok(())
    }

    // Generic struct, instantiated with two different `TaskInput` types.
    #[test]
    fn test_generic_field() -> Result<()> {
        #[turbo_tasks::task_input]
        #[derive(Clone, Eq, PartialEq, Hash, Debug, Encode, Decode, TraceRawVcs)]
        struct GenericField<T>(T);

        assert_task_input(GenericField(42));
        assert_task_input(GenericField(rcstr!("42")));
        Ok(())
    }

    // Single-variant enum; declared at module level because
    // `test_nested_variants` also uses it as a field type.
    #[turbo_tasks::task_input]
    #[derive(Clone, Eq, PartialEq, Hash, Debug, Encode, Decode, TraceRawVcs)]
    enum OneVariant {
        Variant,
    }

    #[test]
    fn test_one_variant() -> Result<()> {
        assert_task_input(OneVariant::Variant);
        Ok(())
    }

    // Enum with several unit variants.
    #[test]
    fn test_multiple_variants() -> Result<()> {
        #[turbo_tasks::task_input]
        #[derive(Clone, PartialEq, Eq, Hash, Debug, Encode, Decode, TraceRawVcs)]
        enum MultipleVariants {
            Variant1,
            Variant2,
        }

        assert_task_input(MultipleVariants::Variant2);
        Ok(())
    }

    // Enum mixing unit, tuple and struct variants; module-level so it can be
    // nested inside `NestedVariants` below.
    #[turbo_tasks::task_input]
    #[derive(Clone, Eq, PartialEq, Hash, Debug, Encode, Decode, TraceRawVcs)]
    enum MultipleVariantsAndHeterogeneousFields {
        Variant1,
        Variant2(u32),
        Variant3 { named: u32 },
        Variant4(u32, RcStr),
        Variant5 { named: u32, other: RcStr },
    }

    #[test]
    fn test_multiple_variants_and_heterogeneous_fields() -> Result<()> {
        assert_task_input(MultipleVariantsAndHeterogeneousFields::Variant5 {
            named: 42,
            other: rcstr!("42"),
        });
        Ok(())
    }

    // Enum whose variants contain other `task_input` enums.
    #[test]
    fn test_nested_variants() -> Result<()> {
        #[turbo_tasks::task_input]
        #[derive(Clone, Eq, PartialEq, Hash, Debug, Encode, Decode, TraceRawVcs)]
        enum NestedVariants {
            Variant1,
            Variant2(MultipleVariantsAndHeterogeneousFields),
            Variant3 { named: OneVariant },
            Variant4(OneVariant, RcStr),
            Variant5 { named: OneVariant, other: RcStr },
        }

        assert_task_input(NestedVariants::Variant5 {
            named: OneVariant::Variant,
            other: rcstr!("42"),
        });
        assert_task_input(NestedVariants::Variant2(
            MultipleVariantsAndHeterogeneousFields::Variant5 {
                named: 42,
                other: rcstr!("42"),
            },
        ));
        Ok(())
    }
}