1use std::{
2 collections::BTreeMap, fmt::Debug, future::Future, hash::Hash, sync::Arc, time::Duration,
3};
4
5use anyhow::Result;
6use either::Either;
7use serde::{Deserialize, Serialize};
8use turbo_rcstr::RcStr;
9
10use crate::{
11 MagicAny, ResolvedVc, TaskId, TransientInstance, TransientValue, ValueTypeId, Vc,
12 trace::TraceRawVcs,
13};
14
15pub trait TaskInput: Send + Sync + Clone + Debug + PartialEq + Eq + Hash + TraceRawVcs {
18 fn resolve_input(&self) -> impl Future<Output = Result<Self>> + Send + '_ {
19 async { Ok(self.clone()) }
20 }
21 fn is_resolved(&self) -> bool {
22 true
23 }
24 fn is_transient(&self) -> bool;
25}
26
/// Implements [`TaskInput`] for each type in a comma-separated list (no
/// trailing comma).
///
/// Every generated impl reports its values as non-transient and keeps the
/// trait's default `is_resolved` / `resolve_input`. The list is consumed one
/// type at a time: the head gets an impl and the macro recurses on the tail.
macro_rules! impl_task_input {
    () => {};
    ($head:ty $(, $rest:ty)*) => {
        impl TaskInput for $head {
            fn is_transient(&self) -> bool {
                false
            }
        }

        impl_task_input! { $($rest),* }
    };
}
38
39impl_task_input! {
40 (),
41 bool,
42 u8,
43 u16,
44 u32,
45 i32,
46 u64,
47 usize,
48 RcStr,
49 TaskId,
50 ValueTypeId,
51 Duration,
52 String
53}
54
55impl<T> TaskInput for Vec<T>
56where
57 T: TaskInput,
58{
59 fn is_resolved(&self) -> bool {
60 self.iter().all(TaskInput::is_resolved)
61 }
62
63 fn is_transient(&self) -> bool {
64 self.iter().any(TaskInput::is_transient)
65 }
66
67 async fn resolve_input(&self) -> Result<Self> {
68 let mut resolved = Vec::with_capacity(self.len());
69 for value in self {
70 resolved.push(value.resolve_input().await?);
71 }
72 Ok(resolved)
73 }
74}
75
76impl<T> TaskInput for Box<T>
77where
78 T: TaskInput,
79{
80 fn is_resolved(&self) -> bool {
81 self.as_ref().is_resolved()
82 }
83
84 fn is_transient(&self) -> bool {
85 self.as_ref().is_transient()
86 }
87
88 async fn resolve_input(&self) -> Result<Self> {
89 Ok(Box::new(Box::pin(self.as_ref().resolve_input()).await?))
90 }
91}
92
93impl<T> TaskInput for Arc<T>
94where
95 T: TaskInput,
96{
97 fn is_resolved(&self) -> bool {
98 self.as_ref().is_resolved()
99 }
100
101 fn is_transient(&self) -> bool {
102 self.as_ref().is_transient()
103 }
104
105 async fn resolve_input(&self) -> Result<Self> {
106 Ok(Arc::new(Box::pin(self.as_ref().resolve_input()).await?))
107 }
108}
109
110impl<T> TaskInput for Option<T>
111where
112 T: TaskInput,
113{
114 fn is_resolved(&self) -> bool {
115 match self {
116 Some(value) => value.is_resolved(),
117 None => true,
118 }
119 }
120
121 fn is_transient(&self) -> bool {
122 match self {
123 Some(value) => value.is_transient(),
124 None => false,
125 }
126 }
127
128 async fn resolve_input(&self) -> Result<Self> {
129 match self {
130 Some(value) => Ok(Some(value.resolve_input().await?)),
131 None => Ok(None),
132 }
133 }
134}
135
136impl<T> TaskInput for Vc<T>
137where
138 T: Send + Sync + ?Sized,
139{
140 fn is_resolved(&self) -> bool {
141 Vc::is_resolved(*self)
142 }
143
144 fn is_transient(&self) -> bool {
145 self.node.is_transient()
146 }
147
148 async fn resolve_input(&self) -> Result<Self> {
149 Vc::resolve(*self).await
150 }
151}
152
153impl<T> TaskInput for ResolvedVc<T>
156where
157 T: Send + Sync + ?Sized,
158{
159 fn is_resolved(&self) -> bool {
160 true
161 }
162
163 fn is_transient(&self) -> bool {
164 self.node.is_transient()
165 }
166
167 async fn resolve_input(&self) -> Result<Self> {
168 Ok(*self)
169 }
170}
171
172impl<T> TaskInput for TransientValue<T>
173where
174 T: MagicAny + Clone + Debug + Hash + Eq + TraceRawVcs + 'static,
175{
176 fn is_transient(&self) -> bool {
177 true
178 }
179}
180
181impl<T> Serialize for TransientValue<T>
182where
183 T: MagicAny + Clone + 'static,
184{
185 fn serialize<S>(&self, _serializer: S) -> Result<S::Ok, S::Error>
186 where
187 S: serde::Serializer,
188 {
189 Err(serde::ser::Error::custom(
190 "cannot serialize transient task inputs",
191 ))
192 }
193}
194
195impl<'de, T> Deserialize<'de> for TransientValue<T>
196where
197 T: MagicAny + Clone + 'static,
198{
199 fn deserialize<D>(_deserializer: D) -> Result<Self, D::Error>
200 where
201 D: serde::Deserializer<'de>,
202 {
203 Err(serde::de::Error::custom(
204 "cannot deserialize transient task inputs",
205 ))
206 }
207}
208
209impl<T> TaskInput for TransientInstance<T>
210where
211 T: Sync + Send + TraceRawVcs + 'static,
212{
213 fn is_transient(&self) -> bool {
214 true
215 }
216}
217
218impl<T> Serialize for TransientInstance<T> {
219 fn serialize<S>(&self, _serializer: S) -> Result<S::Ok, S::Error>
220 where
221 S: serde::Serializer,
222 {
223 Err(serde::ser::Error::custom(
224 "cannot serialize transient task inputs",
225 ))
226 }
227}
228
229impl<'de, T> Deserialize<'de> for TransientInstance<T> {
230 fn deserialize<D>(_deserializer: D) -> Result<Self, D::Error>
231 where
232 D: serde::Deserializer<'de>,
233 {
234 Err(serde::de::Error::custom(
235 "cannot deserialize transient task inputs",
236 ))
237 }
238}
239
240impl<L, R> TaskInput for Either<L, R>
241where
242 L: TaskInput,
243 R: TaskInput,
244{
245 fn resolve_input(&self) -> impl Future<Output = Result<Self>> + Send + '_ {
246 self.as_ref().map_either(
247 |l| async move { anyhow::Ok(Either::Left(l.resolve_input().await?)) },
248 |r| async move { anyhow::Ok(Either::Right(r.resolve_input().await?)) },
249 )
250 }
251
252 fn is_resolved(&self) -> bool {
253 self.as_ref()
254 .either(TaskInput::is_resolved, TaskInput::is_resolved)
255 }
256
257 fn is_transient(&self) -> bool {
258 self.as_ref()
259 .either(TaskInput::is_transient, TaskInput::is_transient)
260 }
261}
262
263impl<K, V> TaskInput for BTreeMap<K, V>
264where
265 K: TaskInput + Ord,
266 V: TaskInput,
267{
268 async fn resolve_input(&self) -> Result<Self> {
269 let mut new_map = BTreeMap::new();
270 for (k, v) in self {
271 new_map.insert(
272 TaskInput::resolve_input(k).await?,
273 TaskInput::resolve_input(v).await?,
274 );
275 }
276 Ok(new_map)
277 }
278
279 fn is_resolved(&self) -> bool {
280 self.iter()
281 .all(|(k, v)| TaskInput::is_resolved(k) || TaskInput::is_resolved(v))
282 }
283
284 fn is_transient(&self) -> bool {
285 self.iter()
286 .any(|(k, v)| TaskInput::is_transient(k) || TaskInput::is_transient(v))
287 }
288}
/// Implements [`TaskInput`] for a tuple whose element types are named by the
/// given identifiers.
///
/// A tuple is resolved when every element is and transient when any element
/// is. Resolution resolves elements left to right and stops at the first error.
macro_rules! tuple_impls {
    ( $( $name:ident )+ ) => {
        impl<$($name),+> TaskInput for ($($name,)+)
        where
            $($name: TaskInput),+
        {
            // The type-parameter identifiers are reused as binding names for
            // the tuple's elements, hence the `non_snake_case` allowances.
            #[allow(non_snake_case)]
            fn is_resolved(&self) -> bool {
                let ($(ref $name,)+) = *self;
                true $(&& $name.is_resolved())+
            }

            #[allow(non_snake_case)]
            fn is_transient(&self) -> bool {
                let ($(ref $name,)+) = *self;
                false $(|| $name.is_transient())+
            }

            #[allow(non_snake_case)]
            async fn resolve_input(&self) -> Result<Self> {
                let ($(ref $name,)+) = *self;
                Ok(($($name.resolve_input().await?,)+))
            }
        }
    };
}
314
315tuple_impls! { A }
317tuple_impls! { A B }
318tuple_impls! { A B C }
319tuple_impls! { A B C D }
320tuple_impls! { A B C D E }
321tuple_impls! { A B C D E F }
322tuple_impls! { A B C D E F G }
323tuple_impls! { A B C D E F G H }
324tuple_impls! { A B C D E F G H I }
325tuple_impls! { A B C D E F G H I J }
326tuple_impls! { A B C D E F G H I J K }
327tuple_impls! { A B C D E F G H I J K L }
328
#[cfg(test)]
mod tests {
    use turbo_rcstr::rcstr;
    use turbo_tasks_macros::TaskInput;

    use super::*;
    // NOTE(review): presumably lets paths emitted by the derive macros
    // (`turbo_tasks::...`) resolve from inside this crate — confirm.
    use crate as turbo_tasks;

    /// Accepts any value whose type implements [`TaskInput`]. The call itself
    /// is a no-op; it only type-checks if the derive produced a usable impl.
    fn require_task_input<T>(_value: T)
    where
        T: TaskInput,
    {
    }

    // Shared fixtures (also reused by `test_nested_variants`).

    #[derive(Clone, TaskInput, Eq, PartialEq, Hash, Debug, Serialize, Deserialize, TraceRawVcs)]
    enum OneVariant {
        Variant,
    }

    #[derive(Clone, TaskInput, Eq, PartialEq, Hash, Debug, Serialize, Deserialize, TraceRawVcs)]
    enum MultipleVariantsAndHeterogeneousFields {
        Variant1,
        Variant2(u32),
        Variant3 { named: u32 },
        Variant4(u32, RcStr),
        Variant5 { named: u32, other: RcStr },
    }

    // Structs.

    #[test]
    fn test_no_fields() {
        #[derive(
            Clone, TaskInput, Eq, PartialEq, Hash, Debug, Serialize, Deserialize, TraceRawVcs,
        )]
        struct NoFields;

        require_task_input(NoFields);
    }

    #[test]
    fn test_one_unnamed_field() {
        #[derive(
            Clone, TaskInput, Eq, PartialEq, Hash, Debug, Serialize, Deserialize, TraceRawVcs,
        )]
        struct OneUnnamedField(u32);

        require_task_input(OneUnnamedField(42));
    }

    #[test]
    fn test_multiple_unnamed_fields() {
        #[derive(
            Clone, TaskInput, Eq, PartialEq, Hash, Debug, Serialize, Deserialize, TraceRawVcs,
        )]
        struct MultipleUnnamedFields(u32, RcStr);

        require_task_input(MultipleUnnamedFields(42, rcstr!("42")));
    }

    #[test]
    fn test_one_named_field() {
        #[derive(
            Clone, TaskInput, Eq, PartialEq, Hash, Debug, Serialize, Deserialize, TraceRawVcs,
        )]
        struct OneNamedField {
            named: u32,
        }

        require_task_input(OneNamedField { named: 42 });
    }

    #[test]
    fn test_multiple_named_fields() {
        #[derive(
            Clone, TaskInput, Eq, PartialEq, Hash, Debug, Serialize, Deserialize, TraceRawVcs,
        )]
        struct MultipleNamedFields {
            named: u32,
            other: RcStr,
        }

        let value = MultipleNamedFields {
            named: 42,
            other: rcstr!("42"),
        };
        require_task_input(value);
    }

    #[test]
    fn test_generic_field() {
        #[derive(
            Clone, TaskInput, Eq, PartialEq, Hash, Debug, Serialize, Deserialize, TraceRawVcs,
        )]
        struct GenericField<T>(T);

        require_task_input(GenericField(42));
        require_task_input(GenericField(rcstr!("42")));
    }

    // Enums.

    #[test]
    fn test_one_variant() {
        require_task_input(OneVariant::Variant);
    }

    #[test]
    fn test_multiple_variants() {
        #[derive(
            Clone, TaskInput, Eq, PartialEq, Hash, Debug, Serialize, Deserialize, TraceRawVcs,
        )]
        enum MultipleVariants {
            Variant1,
            Variant2,
        }

        require_task_input(MultipleVariants::Variant2);
    }

    #[test]
    fn test_multiple_variants_and_heterogeneous_fields() {
        let value = MultipleVariantsAndHeterogeneousFields::Variant5 {
            named: 42,
            other: rcstr!("42"),
        };
        require_task_input(value);
    }

    #[test]
    fn test_nested_variants() {
        #[derive(
            Clone, TaskInput, Eq, PartialEq, Hash, Debug, Serialize, Deserialize, TraceRawVcs,
        )]
        enum NestedVariants {
            Variant1,
            Variant2(MultipleVariantsAndHeterogeneousFields),
            Variant3 { named: OneVariant },
            Variant4(OneVariant, RcStr),
            Variant5 { named: OneVariant, other: RcStr },
        }

        let struct_like = NestedVariants::Variant5 {
            named: OneVariant::Variant,
            other: rcstr!("42"),
        };
        require_task_input(struct_like);

        let wrapped = NestedVariants::Variant2(MultipleVariantsAndHeterogeneousFields::Variant5 {
            named: 42,
            other: rcstr!("42"),
        });
        require_task_input(wrapped);
    }
}