1use std::{
2 collections::{BTreeMap, BTreeSet},
3 fmt::Debug,
4 future::Future,
5 hash::Hash,
6 sync::Arc,
7 time::Duration,
8};
9
10use anyhow::Result;
11use either::Either;
12use serde::{Deserialize, Serialize};
13use turbo_rcstr::RcStr;
14
15use crate::{
16 MagicAny, ResolvedVc, TaskId, TransientInstance, TransientValue, ValueTypeId, Vc,
17 trace::TraceRawVcs,
18};
19
20pub trait TaskInput: Send + Sync + Clone + Debug + PartialEq + Eq + Hash + TraceRawVcs {
23 fn resolve_input(&self) -> impl Future<Output = Result<Self>> + Send + '_ {
24 async { Ok(self.clone()) }
25 }
26 fn is_resolved(&self) -> bool {
27 true
28 }
29 fn is_transient(&self) -> bool;
30}
31
32macro_rules! impl_task_input {
33 ($($t:ty),*) => {
34 $(
35 impl TaskInput for $t {
36 fn is_transient(&self) -> bool {
37 false
38 }
39 }
40 )*
41 };
42}
43
44impl_task_input! {
45 (),
46 bool,
47 u8,
48 u16,
49 u32,
50 i32,
51 u64,
52 usize,
53 RcStr,
54 TaskId,
55 ValueTypeId,
56 Duration,
57 String
58}
59
60impl<T> TaskInput for Vec<T>
61where
62 T: TaskInput,
63{
64 fn is_resolved(&self) -> bool {
65 self.iter().all(TaskInput::is_resolved)
66 }
67
68 fn is_transient(&self) -> bool {
69 self.iter().any(TaskInput::is_transient)
70 }
71
72 async fn resolve_input(&self) -> Result<Self> {
73 let mut resolved = Vec::with_capacity(self.len());
74 for value in self {
75 resolved.push(value.resolve_input().await?);
76 }
77 Ok(resolved)
78 }
79}
80
81impl<T> TaskInput for Box<T>
82where
83 T: TaskInput,
84{
85 fn is_resolved(&self) -> bool {
86 self.as_ref().is_resolved()
87 }
88
89 fn is_transient(&self) -> bool {
90 self.as_ref().is_transient()
91 }
92
93 async fn resolve_input(&self) -> Result<Self> {
94 Ok(Box::new(Box::pin(self.as_ref().resolve_input()).await?))
95 }
96}
97
98impl<T> TaskInput for Arc<T>
99where
100 T: TaskInput,
101{
102 fn is_resolved(&self) -> bool {
103 self.as_ref().is_resolved()
104 }
105
106 fn is_transient(&self) -> bool {
107 self.as_ref().is_transient()
108 }
109
110 async fn resolve_input(&self) -> Result<Self> {
111 Ok(Arc::new(Box::pin(self.as_ref().resolve_input()).await?))
112 }
113}
114
115impl<T> TaskInput for Option<T>
116where
117 T: TaskInput,
118{
119 fn is_resolved(&self) -> bool {
120 match self {
121 Some(value) => value.is_resolved(),
122 None => true,
123 }
124 }
125
126 fn is_transient(&self) -> bool {
127 match self {
128 Some(value) => value.is_transient(),
129 None => false,
130 }
131 }
132
133 async fn resolve_input(&self) -> Result<Self> {
134 match self {
135 Some(value) => Ok(Some(value.resolve_input().await?)),
136 None => Ok(None),
137 }
138 }
139}
140
141impl<T> TaskInput for Vc<T>
142where
143 T: Send + Sync + ?Sized,
144{
145 fn is_resolved(&self) -> bool {
146 Vc::is_resolved(*self)
147 }
148
149 fn is_transient(&self) -> bool {
150 self.node.is_transient()
151 }
152
153 async fn resolve_input(&self) -> Result<Self> {
154 Vc::resolve(*self).await
155 }
156}
157
158impl<T> TaskInput for ResolvedVc<T>
161where
162 T: Send + Sync + ?Sized,
163{
164 fn is_resolved(&self) -> bool {
165 true
166 }
167
168 fn is_transient(&self) -> bool {
169 self.node.is_transient()
170 }
171
172 async fn resolve_input(&self) -> Result<Self> {
173 Ok(*self)
174 }
175}
176
177impl<T> TaskInput for TransientValue<T>
178where
179 T: MagicAny + Clone + Debug + Hash + Eq + TraceRawVcs + 'static,
180{
181 fn is_transient(&self) -> bool {
182 true
183 }
184}
185
186impl<T> Serialize for TransientValue<T>
187where
188 T: MagicAny + Clone + 'static,
189{
190 fn serialize<S>(&self, _serializer: S) -> Result<S::Ok, S::Error>
191 where
192 S: serde::Serializer,
193 {
194 Err(serde::ser::Error::custom(
195 "cannot serialize transient task inputs",
196 ))
197 }
198}
199
200impl<'de, T> Deserialize<'de> for TransientValue<T>
201where
202 T: MagicAny + Clone + 'static,
203{
204 fn deserialize<D>(_deserializer: D) -> Result<Self, D::Error>
205 where
206 D: serde::Deserializer<'de>,
207 {
208 Err(serde::de::Error::custom(
209 "cannot deserialize transient task inputs",
210 ))
211 }
212}
213
214impl<T> TaskInput for TransientInstance<T>
215where
216 T: Sync + Send + TraceRawVcs + 'static,
217{
218 fn is_transient(&self) -> bool {
219 true
220 }
221}
222
223impl<T> Serialize for TransientInstance<T> {
224 fn serialize<S>(&self, _serializer: S) -> Result<S::Ok, S::Error>
225 where
226 S: serde::Serializer,
227 {
228 Err(serde::ser::Error::custom(
229 "cannot serialize transient task inputs",
230 ))
231 }
232}
233
234impl<'de, T> Deserialize<'de> for TransientInstance<T> {
235 fn deserialize<D>(_deserializer: D) -> Result<Self, D::Error>
236 where
237 D: serde::Deserializer<'de>,
238 {
239 Err(serde::de::Error::custom(
240 "cannot deserialize transient task inputs",
241 ))
242 }
243}
244
245impl<L, R> TaskInput for Either<L, R>
246where
247 L: TaskInput,
248 R: TaskInput,
249{
250 fn resolve_input(&self) -> impl Future<Output = Result<Self>> + Send + '_ {
251 self.as_ref().map_either(
252 |l| async move { anyhow::Ok(Either::Left(l.resolve_input().await?)) },
253 |r| async move { anyhow::Ok(Either::Right(r.resolve_input().await?)) },
254 )
255 }
256
257 fn is_resolved(&self) -> bool {
258 self.as_ref()
259 .either(TaskInput::is_resolved, TaskInput::is_resolved)
260 }
261
262 fn is_transient(&self) -> bool {
263 self.as_ref()
264 .either(TaskInput::is_transient, TaskInput::is_transient)
265 }
266}
267
268impl<K, V> TaskInput for BTreeMap<K, V>
269where
270 K: TaskInput + Ord,
271 V: TaskInput,
272{
273 async fn resolve_input(&self) -> Result<Self> {
274 let mut new_map = BTreeMap::new();
275 for (k, v) in self {
276 new_map.insert(
277 TaskInput::resolve_input(k).await?,
278 TaskInput::resolve_input(v).await?,
279 );
280 }
281 Ok(new_map)
282 }
283
284 fn is_resolved(&self) -> bool {
285 self.iter()
286 .all(|(k, v)| TaskInput::is_resolved(k) && TaskInput::is_resolved(v))
287 }
288
289 fn is_transient(&self) -> bool {
290 self.iter()
291 .any(|(k, v)| TaskInput::is_transient(k) || TaskInput::is_transient(v))
292 }
293}
294
295impl<T> TaskInput for BTreeSet<T>
296where
297 T: TaskInput + Ord,
298{
299 async fn resolve_input(&self) -> Result<Self> {
300 let mut new_map = BTreeSet::new();
301 for value in self {
302 new_map.insert(TaskInput::resolve_input(value).await?);
303 }
304 Ok(new_map)
305 }
306
307 fn is_resolved(&self) -> bool {
308 self.iter().all(TaskInput::is_resolved)
309 }
310
311 fn is_transient(&self) -> bool {
312 self.iter().any(TaskInput::is_transient)
313 }
314}
315
316macro_rules! tuple_impls {
317 ( $( $name:ident )+ ) => {
318 impl<$($name: TaskInput),+> TaskInput for ($($name,)+)
319 where $($name: TaskInput),+
320 {
321 #[allow(non_snake_case)]
322 fn is_resolved(&self) -> bool {
323 let ($($name,)+) = self;
324 $($name.is_resolved() &&)+ true
325 }
326
327 #[allow(non_snake_case)]
328 fn is_transient(&self) -> bool {
329 let ($($name,)+) = self;
330 $($name.is_transient() ||)+ false
331 }
332
333 #[allow(non_snake_case)]
334 async fn resolve_input(&self) -> Result<Self> {
335 let ($($name,)+) = self;
336 Ok(($($name.resolve_input().await?,)+))
337 }
338 }
339 };
340}
341
342tuple_impls! { A }
344tuple_impls! { A B }
345tuple_impls! { A B C }
346tuple_impls! { A B C D }
347tuple_impls! { A B C D E }
348tuple_impls! { A B C D E F }
349tuple_impls! { A B C D E F G }
350tuple_impls! { A B C D E F G H }
351tuple_impls! { A B C D E F G H I }
352tuple_impls! { A B C D E F G H I J }
353tuple_impls! { A B C D E F G H I J K }
354tuple_impls! { A B C D E F G H I J K L }
355
/// Checks that `#[derive(TaskInput)]` produces a usable impl for a range of
/// struct and enum shapes. Each test only has to type-check; nothing is
/// asserted at runtime beyond returning `Ok(())`.
#[cfg(test)]
mod tests {
    use turbo_rcstr::rcstr;
    use turbo_tasks_macros::TaskInput;

    use super::*;
    // Makes the name `turbo_tasks` resolve to this crate.
    // NOTE(review): presumably required because the derive expands to
    // `turbo_tasks::…` paths — confirm against the macro implementation.
    use crate as turbo_tasks;

    /// Accepts only values whose type implements [`TaskInput`]; the value
    /// itself is ignored.
    fn assert_task_input<T: TaskInput>(_: T) {}

    #[test]
    fn test_no_fields() -> Result<()> {
        #[derive(
            Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize, TraceRawVcs, TaskInput,
        )]
        struct NoFields;

        let input = NoFields;
        assert_task_input(input);
        Ok(())
    }

    #[test]
    fn test_one_unnamed_field() -> Result<()> {
        #[derive(
            Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize, TraceRawVcs, TaskInput,
        )]
        struct OneUnnamedField(u32);

        let input = OneUnnamedField(42);
        assert_task_input(input);
        Ok(())
    }

    #[test]
    fn test_multiple_unnamed_fields() -> Result<()> {
        #[derive(
            Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize, TraceRawVcs, TaskInput,
        )]
        struct MultipleUnnamedFields(u32, RcStr);

        let input = MultipleUnnamedFields(42, rcstr!("42"));
        assert_task_input(input);
        Ok(())
    }

    #[test]
    fn test_one_named_field() -> Result<()> {
        #[derive(
            Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize, TraceRawVcs, TaskInput,
        )]
        struct OneNamedField {
            named: u32,
        }

        let input = OneNamedField { named: 42 };
        assert_task_input(input);
        Ok(())
    }

    #[test]
    fn test_multiple_named_fields() -> Result<()> {
        #[derive(
            Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize, TraceRawVcs, TaskInput,
        )]
        struct MultipleNamedFields {
            named: u32,
            other: RcStr,
        }

        let input = MultipleNamedFields {
            named: 42,
            other: rcstr!("42"),
        };
        assert_task_input(input);
        Ok(())
    }

    #[test]
    fn test_generic_field() -> Result<()> {
        #[derive(
            Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize, TraceRawVcs, TaskInput,
        )]
        struct GenericField<T>(T);

        // The same generic wrapper is exercised with two different payload types.
        let numeric = GenericField(42);
        assert_task_input(numeric);
        let textual = GenericField(rcstr!("42"));
        assert_task_input(textual);
        Ok(())
    }

    // Declared at module level because `test_nested_variants` embeds it too.
    #[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize, TraceRawVcs, TaskInput)]
    enum OneVariant {
        Variant,
    }

    #[test]
    fn test_one_variant() -> Result<()> {
        let input = OneVariant::Variant;
        assert_task_input(input);
        Ok(())
    }

    #[test]
    fn test_multiple_variants() -> Result<()> {
        #[derive(
            Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize, TraceRawVcs, TaskInput,
        )]
        enum MultipleVariants {
            Variant1,
            Variant2,
        }

        let input = MultipleVariants::Variant2;
        assert_task_input(input);
        Ok(())
    }

    // Declared at module level because `test_nested_variants` embeds it too.
    #[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize, TraceRawVcs, TaskInput)]
    enum MultipleVariantsAndHeterogeneousFields {
        Variant1,
        Variant2(u32),
        Variant3 { named: u32 },
        Variant4(u32, RcStr),
        Variant5 { named: u32, other: RcStr },
    }

    #[test]
    fn test_multiple_variants_and_heterogeneous_fields() -> Result<()> {
        let input = MultipleVariantsAndHeterogeneousFields::Variant5 {
            named: 42,
            other: rcstr!("42"),
        };
        assert_task_input(input);
        Ok(())
    }

    #[test]
    fn test_nested_variants() -> Result<()> {
        #[derive(
            Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize, TraceRawVcs, TaskInput,
        )]
        enum NestedVariants {
            Variant1,
            Variant2(MultipleVariantsAndHeterogeneousFields),
            Variant3 { named: OneVariant },
            Variant4(OneVariant, RcStr),
            Variant5 { named: OneVariant, other: RcStr },
        }

        let with_named_fields = NestedVariants::Variant5 {
            named: OneVariant::Variant,
            other: rcstr!("42"),
        };
        assert_task_input(with_named_fields);

        let with_nested_enum =
            NestedVariants::Variant2(MultipleVariantsAndHeterogeneousFields::Variant5 {
                named: 42,
                other: rcstr!("42"),
            });
        assert_task_input(with_nested_enum);
        Ok(())
    }
}