1use std::{
2 collections::{BTreeMap, BTreeSet},
3 fmt::Debug,
4 future::Future,
5 hash::Hash,
6 ops::{Deref, DerefMut},
7 sync::Arc,
8 time::Duration,
9};
10
11use anyhow::Result;
12use bincode::{
13 Decode, Encode,
14 de::Decoder,
15 enc::Encoder,
16 error::{DecodeError, EncodeError},
17};
18use either::Either;
19use turbo_frozenmap::{FrozenMap, FrozenSet};
20use turbo_rcstr::RcStr;
21
22use crate as turbo_tasks;
25use crate::{
26 MagicAny, ReadRef, ResolvedVc, TaskId, TransientInstance, TransientValue, ValueTypeId, Vc,
27 trace::TraceRawVcs,
28};
29
30pub trait TaskInput:
39 Send + Sync + Clone + Debug + PartialEq + Eq + Hash + TraceRawVcs + Encode + Decode<()>
40{
41 fn resolve_input(&self) -> impl Future<Output = Result<Self>> + Send + '_ {
42 async { Ok(self.clone()) }
43 }
44 fn is_resolved(&self) -> bool {
45 true
46 }
47 fn is_transient(&self) -> bool;
48}
49
/// Implements [`TaskInput`] for each listed type as an opaque leaf value.
///
/// The generated impl only provides `is_transient` (always `false`); `is_resolved` and
/// `resolve_input` fall back to the trait defaults (report `true` / clone the value).
macro_rules! impl_task_input {
    ($($ty:ty),*) => {
        $(
            impl TaskInput for $ty {
                fn is_transient(&self) -> bool {
                    false
                }
            }
        )*
    };
}
61
62impl_task_input! {
63 (),
64 bool,
65 u8,
66 u16,
67 u32,
68 i32,
69 u64,
70 usize,
71 RcStr,
72 TaskId,
73 ValueTypeId,
74 Duration,
75 String
76}
77
78impl<T> TaskInput for Vec<T>
79where
80 T: TaskInput,
81{
82 fn is_resolved(&self) -> bool {
83 self.iter().all(TaskInput::is_resolved)
84 }
85
86 fn is_transient(&self) -> bool {
87 self.iter().any(TaskInput::is_transient)
88 }
89
90 async fn resolve_input(&self) -> Result<Self> {
91 let mut resolved = Vec::with_capacity(self.len());
92 for value in self {
93 resolved.push(value.resolve_input().await?);
94 }
95 Ok(resolved)
96 }
97}
98
99impl<T> TaskInput for Box<T>
100where
101 T: TaskInput,
102{
103 fn is_resolved(&self) -> bool {
104 self.as_ref().is_resolved()
105 }
106
107 fn is_transient(&self) -> bool {
108 self.as_ref().is_transient()
109 }
110
111 async fn resolve_input(&self) -> Result<Self> {
112 Ok(Box::new(Box::pin(self.as_ref().resolve_input()).await?))
113 }
114}
115
116impl<T> TaskInput for Arc<T>
117where
118 T: TaskInput,
119{
120 fn is_resolved(&self) -> bool {
121 self.as_ref().is_resolved()
122 }
123
124 fn is_transient(&self) -> bool {
125 self.as_ref().is_transient()
126 }
127
128 async fn resolve_input(&self) -> Result<Self> {
129 Ok(Arc::new(Box::pin(self.as_ref().resolve_input()).await?))
130 }
131}
132
133impl<T> TaskInput for ReadRef<T>
134where
135 T: TaskInput,
136{
137 fn is_resolved(&self) -> bool {
138 Self::as_raw_ref(self).is_resolved()
139 }
140
141 fn is_transient(&self) -> bool {
142 Self::as_raw_ref(self).is_transient()
143 }
144
145 async fn resolve_input(&self) -> Result<Self> {
146 Ok(ReadRef::new_owned(
147 Box::pin(Self::as_raw_ref(self).resolve_input()).await?,
148 ))
149 }
150}
151
152impl<T> TaskInput for Option<T>
153where
154 T: TaskInput,
155{
156 fn is_resolved(&self) -> bool {
157 match self {
158 Some(value) => value.is_resolved(),
159 None => true,
160 }
161 }
162
163 fn is_transient(&self) -> bool {
164 match self {
165 Some(value) => value.is_transient(),
166 None => false,
167 }
168 }
169
170 async fn resolve_input(&self) -> Result<Self> {
171 match self {
172 Some(value) => Ok(Some(value.resolve_input().await?)),
173 None => Ok(None),
174 }
175 }
176}
177
178impl<T> TaskInput for Vc<T>
179where
180 T: Send + Sync + ?Sized,
181{
182 fn is_resolved(&self) -> bool {
183 Vc::is_resolved(*self)
184 }
185
186 fn is_transient(&self) -> bool {
187 self.node.is_transient()
188 }
189
190 async fn resolve_input(&self) -> Result<Self> {
191 Vc::resolve(*self).await
192 }
193}
194
195impl<T> TaskInput for ResolvedVc<T>
198where
199 T: Send + Sync + ?Sized,
200{
201 fn is_resolved(&self) -> bool {
202 true
203 }
204
205 fn is_transient(&self) -> bool {
206 self.node.is_transient()
207 }
208
209 async fn resolve_input(&self) -> Result<Self> {
210 Ok(*self)
211 }
212}
213
214impl<T> TaskInput for TransientValue<T>
215where
216 T: MagicAny + Clone + Debug + Hash + Eq + TraceRawVcs + 'static,
217{
218 fn is_transient(&self) -> bool {
219 true
220 }
221}
222
223impl<T> Encode for TransientValue<T> {
224 fn encode<E: Encoder>(&self, _encoder: &mut E) -> Result<(), EncodeError> {
225 Err(EncodeError::Other("cannot encode transient task inputs"))
226 }
227}
228
229impl<Context, T> Decode<Context> for TransientValue<T> {
230 fn decode<D: Decoder<Context = Context>>(_decoder: &mut D) -> Result<Self, DecodeError> {
231 Err(DecodeError::Other("cannot decode transient task inputs"))
232 }
233}
234
235impl<T> TaskInput for TransientInstance<T>
236where
237 T: Sync + Send + TraceRawVcs + 'static,
238{
239 fn is_transient(&self) -> bool {
240 true
241 }
242}
243
244impl<T> Encode for TransientInstance<T> {
245 fn encode<E: Encoder>(&self, _encoder: &mut E) -> Result<(), EncodeError> {
246 Err(EncodeError::Other("cannot encode transient task inputs"))
247 }
248}
249
250impl<Context, T> Decode<Context> for TransientInstance<T> {
251 fn decode<D: Decoder<Context = Context>>(_decoder: &mut D) -> Result<Self, DecodeError> {
252 Err(DecodeError::Other("cannot decode transient task inputs"))
253 }
254}
255
256impl<K, V> TaskInput for BTreeMap<K, V>
257where
258 K: TaskInput + Ord,
259 V: TaskInput,
260{
261 async fn resolve_input(&self) -> Result<Self> {
262 let mut new_map = BTreeMap::new();
263 for (k, v) in self {
264 new_map.insert(
265 TaskInput::resolve_input(k).await?,
266 TaskInput::resolve_input(v).await?,
267 );
268 }
269 Ok(new_map)
270 }
271
272 fn is_resolved(&self) -> bool {
273 self.iter()
274 .all(|(k, v)| TaskInput::is_resolved(k) && TaskInput::is_resolved(v))
275 }
276
277 fn is_transient(&self) -> bool {
278 self.iter()
279 .any(|(k, v)| TaskInput::is_transient(k) || TaskInput::is_transient(v))
280 }
281}
282
283impl<T> TaskInput for BTreeSet<T>
284where
285 T: TaskInput + Ord,
286{
287 async fn resolve_input(&self) -> Result<Self> {
288 let mut new_set = BTreeSet::new();
289 for value in self {
290 new_set.insert(TaskInput::resolve_input(value).await?);
291 }
292 Ok(new_set)
293 }
294
295 fn is_resolved(&self) -> bool {
296 self.iter().all(TaskInput::is_resolved)
297 }
298
299 fn is_transient(&self) -> bool {
300 self.iter().any(TaskInput::is_transient)
301 }
302}
303
304impl<K, V> TaskInput for FrozenMap<K, V>
305where
306 K: TaskInput + Ord + 'static,
307 V: TaskInput + 'static,
308{
309 async fn resolve_input(&self) -> Result<Self> {
310 let mut new_entries = Vec::with_capacity(self.len());
311 for (k, v) in self {
312 new_entries.push((
313 TaskInput::resolve_input(k).await?,
314 TaskInput::resolve_input(v).await?,
315 ));
316 }
317 Ok(Self::from(new_entries))
319 }
320
321 fn is_resolved(&self) -> bool {
322 self.iter()
323 .all(|(k, v)| TaskInput::is_resolved(k) && TaskInput::is_resolved(v))
324 }
325
326 fn is_transient(&self) -> bool {
327 self.iter()
328 .any(|(k, v)| TaskInput::is_transient(k) || TaskInput::is_transient(v))
329 }
330}
331
332impl<T> TaskInput for FrozenSet<T>
333where
334 T: TaskInput + Ord + 'static,
335{
336 async fn resolve_input(&self) -> Result<Self> {
337 let mut new_set = Vec::with_capacity(self.len());
338 for value in self {
339 new_set.push(TaskInput::resolve_input(value).await?);
340 }
341 Ok(Self::from_iter(new_set))
342 }
343
344 fn is_resolved(&self) -> bool {
345 self.iter().all(TaskInput::is_resolved)
346 }
347
348 fn is_transient(&self) -> bool {
349 self.iter().any(TaskInput::is_transient)
350 }
351}
352
353#[derive(Clone, Debug, PartialEq, Eq, Hash, TraceRawVcs)]
356pub struct EitherTaskInput<L, R>(pub Either<L, R>);
357
358impl<L, R> Deref for EitherTaskInput<L, R> {
359 type Target = Either<L, R>;
360
361 fn deref(&self) -> &Self::Target {
362 &self.0
363 }
364}
365
366impl<L, R> DerefMut for EitherTaskInput<L, R> {
367 fn deref_mut(&mut self) -> &mut Self::Target {
368 &mut self.0
369 }
370}
371
372impl<L, R> Encode for EitherTaskInput<L, R>
373where
374 L: Encode,
375 R: Encode,
376{
377 fn encode<E: Encoder>(&self, encoder: &mut E) -> Result<(), EncodeError> {
378 turbo_bincode::either::encode(self, encoder)
379 }
380}
381
382impl<Context, L, R> Decode<Context> for EitherTaskInput<L, R>
383where
384 L: Decode<Context>,
385 R: Decode<Context>,
386{
387 fn decode<D: Decoder<Context = Context>>(decoder: &mut D) -> Result<Self, DecodeError> {
388 turbo_bincode::either::decode(decoder).map(Self)
389 }
390}
391
392impl<L, R> TaskInput for EitherTaskInput<L, R>
393where
394 L: TaskInput,
395 R: TaskInput,
396{
397 fn resolve_input(&self) -> impl Future<Output = Result<Self>> + Send + '_ {
398 self.as_ref().map_either(
399 |l| async move { anyhow::Ok(Self(Either::Left(l.resolve_input().await?))) },
400 |r| async move { anyhow::Ok(Self(Either::Right(r.resolve_input().await?))) },
401 )
402 }
403
404 fn is_resolved(&self) -> bool {
405 self.as_ref()
406 .either(TaskInput::is_resolved, TaskInput::is_resolved)
407 }
408
409 fn is_transient(&self) -> bool {
410 self.as_ref()
411 .either(TaskInput::is_transient, TaskInput::is_transient)
412 }
413}
414
/// Implements [`TaskInput`] for the tuple whose element types are the given identifiers.
///
/// A tuple is resolved only if every element is, and is transient if any element is.
/// `resolve_input` resolves the elements sequentially, left to right, stopping at the first
/// error.
macro_rules! tuple_impls {
    ( $( $name:ident )+ ) => {
        // The `TaskInput` bound is stated once, in the `where` clause. The original also
        // repeated it inline in the generic parameter list, which was redundant.
        impl<$($name),+> TaskInput for ($($name,)+)
        where
            $($name: TaskInput),+
        {
            // The type-parameter identifiers (`A`, `B`, ...) are reused as the names of the
            // destructured element bindings, so the bindings are upper-case.
            #[allow(non_snake_case)]
            fn is_resolved(&self) -> bool {
                let ($($name,)+) = self;
                $($name.is_resolved() &&)+ true
            }

            // Same identifier reuse as above.
            #[allow(non_snake_case)]
            fn is_transient(&self) -> bool {
                let ($($name,)+) = self;
                $($name.is_transient() ||)+ false
            }

            // Same identifier reuse as above.
            #[allow(non_snake_case)]
            async fn resolve_input(&self) -> Result<Self> {
                let ($($name,)+) = self;
                Ok(($($name.resolve_input().await?,)+))
            }
        }
    };
}
440
441tuple_impls! { A }
443tuple_impls! { A B }
444tuple_impls! { A B C }
445tuple_impls! { A B C D }
446tuple_impls! { A B C D E }
447tuple_impls! { A B C D E F }
448tuple_impls! { A B C D E F G }
449tuple_impls! { A B C D E F G H }
450tuple_impls! { A B C D E F G H I }
451tuple_impls! { A B C D E F G H I J }
452tuple_impls! { A B C D E F G H I J K }
453tuple_impls! { A B C D E F G H I J K L }
454
#[cfg(test)]
mod tests {
    use turbo_rcstr::rcstr;
    use turbo_tasks_macros::TaskInput;

    use super::*;

    /// Compile-time check that `T` implements [`TaskInput`]; does nothing at runtime.
    fn assert_task_input<T>(_: T)
    where
        T: TaskInput,
    {
    }

    #[test]
    fn test_no_fields() -> Result<()> {
        #[derive(Clone, TaskInput, Eq, PartialEq, Hash, Debug, Encode, Decode, TraceRawVcs)]
        struct NoFields;

        assert_task_input(NoFields);
        Ok(())
    }

    #[test]
    fn test_one_unnamed_field() -> Result<()> {
        #[derive(Clone, TaskInput, Eq, PartialEq, Hash, Debug, Encode, Decode, TraceRawVcs)]
        struct OneUnnamedField(u32);

        assert_task_input(OneUnnamedField(42));
        Ok(())
    }

    #[test]
    fn test_multiple_unnamed_fields() -> Result<()> {
        #[derive(Clone, TaskInput, Eq, PartialEq, Hash, Debug, Encode, Decode, TraceRawVcs)]
        struct MultipleUnnamedFields(u32, RcStr);

        assert_task_input(MultipleUnnamedFields(42, rcstr!("42")));
        Ok(())
    }

    #[test]
    fn test_one_named_field() -> Result<()> {
        #[derive(Clone, TaskInput, Eq, PartialEq, Hash, Debug, Encode, Decode, TraceRawVcs)]
        struct OneNamedField {
            named: u32,
        }

        assert_task_input(OneNamedField { named: 42 });
        Ok(())
    }

    #[test]
    fn test_multiple_named_fields() -> Result<()> {
        #[derive(Clone, TaskInput, Eq, PartialEq, Hash, Debug, Encode, Decode, TraceRawVcs)]
        struct MultipleNamedFields {
            named: u32,
            other: RcStr,
        }

        assert_task_input(MultipleNamedFields {
            named: 42,
            other: rcstr!("42"),
        });
        Ok(())
    }

    #[test]
    fn test_generic_field() -> Result<()> {
        #[derive(Clone, TaskInput, Eq, PartialEq, Hash, Debug, Encode, Decode, TraceRawVcs)]
        struct GenericField<T>(T);

        assert_task_input(GenericField(42));
        assert_task_input(GenericField(rcstr!("42")));
        Ok(())
    }

    #[derive(Clone, TaskInput, Eq, PartialEq, Hash, Debug, Encode, Decode, TraceRawVcs)]
    enum OneVariant {
        Variant,
    }

    #[test]
    fn test_one_variant() -> Result<()> {
        assert_task_input(OneVariant::Variant);
        Ok(())
    }

    #[test]
    fn test_multiple_variants() -> Result<()> {
        #[derive(Clone, TaskInput, PartialEq, Eq, Hash, Debug, Encode, Decode, TraceRawVcs)]
        enum MultipleVariants {
            Variant1,
            Variant2,
        }

        assert_task_input(MultipleVariants::Variant2);
        Ok(())
    }

    #[derive(Clone, TaskInput, Eq, PartialEq, Hash, Debug, Encode, Decode, TraceRawVcs)]
    enum MultipleVariantsAndHeterogeneousFields {
        Variant1,
        Variant2(u32),
        Variant3 { named: u32 },
        Variant4(u32, RcStr),
        Variant5 { named: u32, other: RcStr },
    }

    #[test]
    fn test_multiple_variants_and_heterogeneous_fields() -> Result<()> {
        assert_task_input(MultipleVariantsAndHeterogeneousFields::Variant5 {
            named: 42,
            other: rcstr!("42"),
        });
        Ok(())
    }

    #[test]
    fn test_nested_variants() -> Result<()> {
        #[derive(Clone, TaskInput, Eq, PartialEq, Hash, Debug, Encode, Decode, TraceRawVcs)]
        enum NestedVariants {
            Variant1,
            Variant2(MultipleVariantsAndHeterogeneousFields),
            Variant3 { named: OneVariant },
            Variant4(OneVariant, RcStr),
            Variant5 { named: OneVariant, other: RcStr },
        }

        assert_task_input(NestedVariants::Variant5 {
            named: OneVariant::Variant,
            other: rcstr!("42"),
        });
        assert_task_input(NestedVariants::Variant2(
            MultipleVariantsAndHeterogeneousFields::Variant5 {
                named: 42,
                other: rcstr!("42"),
            },
        ));
        Ok(())
    }

    /// Exercises the hand-written container impls in this module (the tests above only
    /// check derive-generated impls). Inputs built purely from leaf values must be resolved
    /// and non-transient, including empty and `None` containers.
    #[test]
    fn test_builtin_container_flags() -> Result<()> {
        let empty: Vec<u32> = Vec::new();
        assert!(empty.is_resolved());
        assert!(!empty.is_transient());
        assert!(vec![1u32, 2, 3].is_resolved());
        assert!(!vec![1u32, 2, 3].is_transient());

        assert!(None::<u32>.is_resolved());
        assert!(!None::<u32>.is_transient());
        assert!(Some(rcstr!("42")).is_resolved());
        assert!(!Some(rcstr!("42")).is_transient());

        let map = BTreeMap::from([(1u32, rcstr!("one")), (2u32, rcstr!("two"))]);
        assert!(map.is_resolved());
        assert!(!map.is_transient());

        let set = BTreeSet::from([1u32, 2, 3]);
        assert!(set.is_resolved());
        assert!(!set.is_transient());

        assert!(Box::new(42u32).is_resolved());
        assert!(!Box::new(42u32).is_transient());
        assert!(Arc::new(42u32).is_resolved());
        assert!(!Arc::new(42u32).is_transient());

        let tuple = (1u32, rcstr!("a"), true, Some(2u64));
        assert!(tuple.is_resolved());
        assert!(!tuple.is_transient());

        let left: EitherTaskInput<u32, RcStr> = EitherTaskInput(Either::Left(1));
        let right: EitherTaskInput<u32, RcStr> = EitherTaskInput(Either::Right(rcstr!("r")));
        assert!(left.is_resolved());
        assert!(right.is_resolved());
        assert!(!left.is_transient());
        assert!(!right.is_transient());

        Ok(())
    }
}