1use std::{
2 collections::{BTreeMap, BTreeSet},
3 fmt::Debug,
4 future::Future,
5 hash::Hash,
6 sync::Arc,
7 time::Duration,
8};
9
10use anyhow::Result;
11use either::Either;
12use serde::{Deserialize, Serialize};
13use turbo_rcstr::RcStr;
14
15use crate::{
16 MagicAny, ReadRef, ResolvedVc, TaskId, TransientInstance, TransientValue, ValueTypeId, Vc,
17 trace::TraceRawVcs,
18};
19
20pub trait TaskInput: Send + Sync + Clone + Debug + PartialEq + Eq + Hash + TraceRawVcs {
26 fn resolve_input(&self) -> impl Future<Output = Result<Self>> + Send + '_ {
27 async { Ok(self.clone()) }
28 }
29 fn is_resolved(&self) -> bool {
30 true
31 }
32 fn is_transient(&self) -> bool;
33}
34
35macro_rules! impl_task_input {
36 ($($t:ty),*) => {
37 $(
38 impl TaskInput for $t {
39 fn is_transient(&self) -> bool {
40 false
41 }
42 }
43 )*
44 };
45}
46
47impl_task_input! {
48 (),
49 bool,
50 u8,
51 u16,
52 u32,
53 i32,
54 u64,
55 usize,
56 RcStr,
57 TaskId,
58 ValueTypeId,
59 Duration,
60 String
61}
62
63impl<T> TaskInput for Vec<T>
64where
65 T: TaskInput,
66{
67 fn is_resolved(&self) -> bool {
68 self.iter().all(TaskInput::is_resolved)
69 }
70
71 fn is_transient(&self) -> bool {
72 self.iter().any(TaskInput::is_transient)
73 }
74
75 async fn resolve_input(&self) -> Result<Self> {
76 let mut resolved = Vec::with_capacity(self.len());
77 for value in self {
78 resolved.push(value.resolve_input().await?);
79 }
80 Ok(resolved)
81 }
82}
83
84impl<T> TaskInput for Box<T>
85where
86 T: TaskInput,
87{
88 fn is_resolved(&self) -> bool {
89 self.as_ref().is_resolved()
90 }
91
92 fn is_transient(&self) -> bool {
93 self.as_ref().is_transient()
94 }
95
96 async fn resolve_input(&self) -> Result<Self> {
97 Ok(Box::new(Box::pin(self.as_ref().resolve_input()).await?))
98 }
99}
100
101impl<T> TaskInput for Arc<T>
102where
103 T: TaskInput,
104{
105 fn is_resolved(&self) -> bool {
106 self.as_ref().is_resolved()
107 }
108
109 fn is_transient(&self) -> bool {
110 self.as_ref().is_transient()
111 }
112
113 async fn resolve_input(&self) -> Result<Self> {
114 Ok(Arc::new(Box::pin(self.as_ref().resolve_input()).await?))
115 }
116}
117
118impl<T> TaskInput for ReadRef<T>
119where
120 T: TaskInput,
121{
122 fn is_resolved(&self) -> bool {
123 Self::as_raw_ref(self).is_resolved()
124 }
125
126 fn is_transient(&self) -> bool {
127 Self::as_raw_ref(self).is_transient()
128 }
129
130 async fn resolve_input(&self) -> Result<Self> {
131 Ok(ReadRef::new_owned(
132 Box::pin(Self::as_raw_ref(self).resolve_input()).await?,
133 ))
134 }
135}
136
137impl<T> TaskInput for Option<T>
138where
139 T: TaskInput,
140{
141 fn is_resolved(&self) -> bool {
142 match self {
143 Some(value) => value.is_resolved(),
144 None => true,
145 }
146 }
147
148 fn is_transient(&self) -> bool {
149 match self {
150 Some(value) => value.is_transient(),
151 None => false,
152 }
153 }
154
155 async fn resolve_input(&self) -> Result<Self> {
156 match self {
157 Some(value) => Ok(Some(value.resolve_input().await?)),
158 None => Ok(None),
159 }
160 }
161}
162
163impl<T> TaskInput for Vc<T>
164where
165 T: Send + Sync + ?Sized,
166{
167 fn is_resolved(&self) -> bool {
168 Vc::is_resolved(*self)
169 }
170
171 fn is_transient(&self) -> bool {
172 self.node.is_transient()
173 }
174
175 async fn resolve_input(&self) -> Result<Self> {
176 Vc::resolve(*self).await
177 }
178}
179
180impl<T> TaskInput for ResolvedVc<T>
183where
184 T: Send + Sync + ?Sized,
185{
186 fn is_resolved(&self) -> bool {
187 true
188 }
189
190 fn is_transient(&self) -> bool {
191 self.node.is_transient()
192 }
193
194 async fn resolve_input(&self) -> Result<Self> {
195 Ok(*self)
196 }
197}
198
199impl<T> TaskInput for TransientValue<T>
200where
201 T: MagicAny + Clone + Debug + Hash + Eq + TraceRawVcs + 'static,
202{
203 fn is_transient(&self) -> bool {
204 true
205 }
206}
207
208impl<T> Serialize for TransientValue<T>
209where
210 T: MagicAny + Clone + 'static,
211{
212 fn serialize<S>(&self, _serializer: S) -> Result<S::Ok, S::Error>
213 where
214 S: serde::Serializer,
215 {
216 Err(serde::ser::Error::custom(
217 "cannot serialize transient task inputs",
218 ))
219 }
220}
221
222impl<'de, T> Deserialize<'de> for TransientValue<T>
223where
224 T: MagicAny + Clone + 'static,
225{
226 fn deserialize<D>(_deserializer: D) -> Result<Self, D::Error>
227 where
228 D: serde::Deserializer<'de>,
229 {
230 Err(serde::de::Error::custom(
231 "cannot deserialize transient task inputs",
232 ))
233 }
234}
235
236impl<T> TaskInput for TransientInstance<T>
237where
238 T: Sync + Send + TraceRawVcs + 'static,
239{
240 fn is_transient(&self) -> bool {
241 true
242 }
243}
244
245impl<T> Serialize for TransientInstance<T> {
246 fn serialize<S>(&self, _serializer: S) -> Result<S::Ok, S::Error>
247 where
248 S: serde::Serializer,
249 {
250 Err(serde::ser::Error::custom(
251 "cannot serialize transient task inputs",
252 ))
253 }
254}
255
256impl<'de, T> Deserialize<'de> for TransientInstance<T> {
257 fn deserialize<D>(_deserializer: D) -> Result<Self, D::Error>
258 where
259 D: serde::Deserializer<'de>,
260 {
261 Err(serde::de::Error::custom(
262 "cannot deserialize transient task inputs",
263 ))
264 }
265}
266
267impl<L, R> TaskInput for Either<L, R>
268where
269 L: TaskInput,
270 R: TaskInput,
271{
272 fn resolve_input(&self) -> impl Future<Output = Result<Self>> + Send + '_ {
273 self.as_ref().map_either(
274 |l| async move { anyhow::Ok(Either::Left(l.resolve_input().await?)) },
275 |r| async move { anyhow::Ok(Either::Right(r.resolve_input().await?)) },
276 )
277 }
278
279 fn is_resolved(&self) -> bool {
280 self.as_ref()
281 .either(TaskInput::is_resolved, TaskInput::is_resolved)
282 }
283
284 fn is_transient(&self) -> bool {
285 self.as_ref()
286 .either(TaskInput::is_transient, TaskInput::is_transient)
287 }
288}
289
290impl<K, V> TaskInput for BTreeMap<K, V>
291where
292 K: TaskInput + Ord,
293 V: TaskInput,
294{
295 async fn resolve_input(&self) -> Result<Self> {
296 let mut new_map = BTreeMap::new();
297 for (k, v) in self {
298 new_map.insert(
299 TaskInput::resolve_input(k).await?,
300 TaskInput::resolve_input(v).await?,
301 );
302 }
303 Ok(new_map)
304 }
305
306 fn is_resolved(&self) -> bool {
307 self.iter()
308 .all(|(k, v)| TaskInput::is_resolved(k) && TaskInput::is_resolved(v))
309 }
310
311 fn is_transient(&self) -> bool {
312 self.iter()
313 .any(|(k, v)| TaskInput::is_transient(k) || TaskInput::is_transient(v))
314 }
315}
316
317impl<T> TaskInput for BTreeSet<T>
318where
319 T: TaskInput + Ord,
320{
321 async fn resolve_input(&self) -> Result<Self> {
322 let mut new_map = BTreeSet::new();
323 for value in self {
324 new_map.insert(TaskInput::resolve_input(value).await?);
325 }
326 Ok(new_map)
327 }
328
329 fn is_resolved(&self) -> bool {
330 self.iter().all(TaskInput::is_resolved)
331 }
332
333 fn is_transient(&self) -> bool {
334 self.iter().any(TaskInput::is_transient)
335 }
336}
337
338macro_rules! tuple_impls {
339 ( $( $name:ident )+ ) => {
340 impl<$($name: TaskInput),+> TaskInput for ($($name,)+)
341 where $($name: TaskInput),+
342 {
343 #[allow(non_snake_case)]
344 fn is_resolved(&self) -> bool {
345 let ($($name,)+) = self;
346 $($name.is_resolved() &&)+ true
347 }
348
349 #[allow(non_snake_case)]
350 fn is_transient(&self) -> bool {
351 let ($($name,)+) = self;
352 $($name.is_transient() ||)+ false
353 }
354
355 #[allow(non_snake_case)]
356 async fn resolve_input(&self) -> Result<Self> {
357 let ($($name,)+) = self;
358 Ok(($($name.resolve_input().await?,)+))
359 }
360 }
361 };
362}
363
364tuple_impls! { A }
366tuple_impls! { A B }
367tuple_impls! { A B C }
368tuple_impls! { A B C D }
369tuple_impls! { A B C D E }
370tuple_impls! { A B C D E F }
371tuple_impls! { A B C D E F G }
372tuple_impls! { A B C D E F G H }
373tuple_impls! { A B C D E F G H I }
374tuple_impls! { A B C D E F G H I J }
375tuple_impls! { A B C D E F G H I J K }
376tuple_impls! { A B C D E F G H I J K L }
377
#[cfg(test)]
mod tests {
    use turbo_rcstr::rcstr;
    use turbo_tasks_macros::TaskInput;

    use super::*;
    // The derive macros presumably expand to `turbo_tasks::...` paths; aliasing this crate
    // lets them resolve when used inside the crate itself.
    use crate as turbo_tasks;

    /// Compile-time check that `T` implements `TaskInput`.
    fn assert_task_input<T>(_: T)
    where
        T: TaskInput,
    {
    }

    #[test]
    fn test_builtin_impls_resolution_and_transience() {
        // Leaf types use the trait defaults: always resolved, never transient.
        assert!(42u32.is_resolved());
        assert!(!42u32.is_transient());

        // Containers aggregate over their elements.
        assert!(vec![1u32, 2, 3].is_resolved());
        assert!(!vec![1u32, 2, 3].is_transient());
        assert!(Some(rcstr!("x")).is_resolved());
        assert!(!Some(rcstr!("x")).is_transient());
        assert!(None::<u32>.is_resolved());
        assert!(!None::<u32>.is_transient());

        // Tuples aggregate over their elements as well.
        assert!((1u8, true, rcstr!("x")).is_resolved());
        assert!(!(1u8, true, rcstr!("x")).is_transient());
    }

    #[test]
    fn test_no_fields() {
        #[derive(
            Clone, TaskInput, Eq, PartialEq, Hash, Debug, Serialize, Deserialize, TraceRawVcs,
        )]
        struct NoFields;

        assert_task_input(NoFields);
    }

    #[test]
    fn test_one_unnamed_field() {
        #[derive(
            Clone, TaskInput, Eq, PartialEq, Hash, Debug, Serialize, Deserialize, TraceRawVcs,
        )]
        struct OneUnnamedField(u32);

        assert_task_input(OneUnnamedField(42));
    }

    #[test]
    fn test_multiple_unnamed_fields() {
        #[derive(
            Clone, TaskInput, Eq, PartialEq, Hash, Debug, Serialize, Deserialize, TraceRawVcs,
        )]
        struct MultipleUnnamedFields(u32, RcStr);

        assert_task_input(MultipleUnnamedFields(42, rcstr!("42")));
    }

    #[test]
    fn test_one_named_field() {
        #[derive(
            Clone, TaskInput, Eq, PartialEq, Hash, Debug, Serialize, Deserialize, TraceRawVcs,
        )]
        struct OneNamedField {
            named: u32,
        }

        assert_task_input(OneNamedField { named: 42 });
    }

    #[test]
    fn test_multiple_named_fields() {
        #[derive(
            Clone, TaskInput, Eq, PartialEq, Hash, Debug, Serialize, Deserialize, TraceRawVcs,
        )]
        struct MultipleNamedFields {
            named: u32,
            other: RcStr,
        }

        assert_task_input(MultipleNamedFields {
            named: 42,
            other: rcstr!("42"),
        });
    }

    #[test]
    fn test_generic_field() {
        #[derive(
            Clone, TaskInput, Eq, PartialEq, Hash, Debug, Serialize, Deserialize, TraceRawVcs,
        )]
        struct GenericField<T>(T);

        assert_task_input(GenericField(42));
        assert_task_input(GenericField(rcstr!("42")));
    }

    #[derive(Clone, TaskInput, Eq, PartialEq, Hash, Debug, Serialize, Deserialize, TraceRawVcs)]
    enum OneVariant {
        Variant,
    }

    #[test]
    fn test_one_variant() {
        assert_task_input(OneVariant::Variant);
    }

    #[test]
    fn test_multiple_variants() {
        #[derive(
            Clone, TaskInput, PartialEq, Eq, Hash, Debug, Serialize, Deserialize, TraceRawVcs,
        )]
        enum MultipleVariants {
            Variant1,
            Variant2,
        }

        assert_task_input(MultipleVariants::Variant2);
    }

    #[derive(Clone, TaskInput, Eq, PartialEq, Hash, Debug, Serialize, Deserialize, TraceRawVcs)]
    enum MultipleVariantsAndHeterogeneousFields {
        Variant1,
        Variant2(u32),
        Variant3 { named: u32 },
        Variant4(u32, RcStr),
        Variant5 { named: u32, other: RcStr },
    }

    #[test]
    fn test_multiple_variants_and_heterogeneous_fields() {
        assert_task_input(MultipleVariantsAndHeterogeneousFields::Variant5 {
            named: 42,
            other: rcstr!("42"),
        });
    }

    #[test]
    fn test_nested_variants() {
        #[derive(
            Clone, TaskInput, Eq, PartialEq, Hash, Debug, Serialize, Deserialize, TraceRawVcs,
        )]
        enum NestedVariants {
            Variant1,
            Variant2(MultipleVariantsAndHeterogeneousFields),
            Variant3 { named: OneVariant },
            Variant4(OneVariant, RcStr),
            Variant5 { named: OneVariant, other: RcStr },
        }

        assert_task_input(NestedVariants::Variant5 {
            named: OneVariant::Variant,
            other: rcstr!("42"),
        });
        assert_task_input(NestedVariants::Variant2(
            MultipleVariantsAndHeterogeneousFields::Variant5 {
                named: 42,
                other: rcstr!("42"),
            },
        ));
    }
}