1use std::{any::Any, fmt::Debug, future::Future, hash::Hash, sync::Arc, time::Duration};
2
3use anyhow::Result;
4use either::Either;
5use serde::{Deserialize, Serialize};
6use turbo_rcstr::RcStr;
7
8use crate::{
9 MagicAny, ResolvedVc, TaskId, TransientInstance, TransientValue, Value, ValueTypeId, Vc,
10 trace::TraceRawVcs,
11};
12
13pub trait TaskInput: Send + Sync + Clone + Debug + PartialEq + Eq + Hash + TraceRawVcs {
16 fn resolve_input(&self) -> impl Future<Output = Result<Self>> + Send + '_ {
17 async { Ok(self.clone()) }
18 }
19 fn is_resolved(&self) -> bool {
20 true
21 }
22 fn is_transient(&self) -> bool;
23}
24
// Implements `TaskInput` for each listed type as a non-transient leaf. Only
// `is_transient` is written out; `is_resolved` and `resolve_input` keep the trait defaults.
macro_rules! impl_task_input {
    ($($ty:ty),*) => {
        $(
            impl TaskInput for $ty {
                fn is_transient(&self) -> bool {
                    false
                }
            }
        )*
    };
}
36
37impl_task_input! {
38 (),
39 bool,
40 u8,
41 u16,
42 u32,
43 i32,
44 u64,
45 usize,
46 RcStr,
47 TaskId,
48 ValueTypeId,
49 Duration
50}
51
52impl<T> TaskInput for Vec<T>
53where
54 T: TaskInput,
55{
56 fn is_resolved(&self) -> bool {
57 self.iter().all(TaskInput::is_resolved)
58 }
59
60 fn is_transient(&self) -> bool {
61 self.iter().any(TaskInput::is_transient)
62 }
63
64 async fn resolve_input(&self) -> Result<Self> {
65 let mut resolved = Vec::with_capacity(self.len());
66 for value in self {
67 resolved.push(value.resolve_input().await?);
68 }
69 Ok(resolved)
70 }
71}
72
73impl<T> TaskInput for Box<T>
74where
75 T: TaskInput,
76{
77 fn is_resolved(&self) -> bool {
78 self.as_ref().is_resolved()
79 }
80
81 fn is_transient(&self) -> bool {
82 self.as_ref().is_transient()
83 }
84
85 async fn resolve_input(&self) -> Result<Self> {
86 Ok(Box::new(Box::pin(self.as_ref().resolve_input()).await?))
87 }
88}
89
90impl<T> TaskInput for Arc<T>
91where
92 T: TaskInput,
93{
94 fn is_resolved(&self) -> bool {
95 self.as_ref().is_resolved()
96 }
97
98 fn is_transient(&self) -> bool {
99 self.as_ref().is_transient()
100 }
101
102 async fn resolve_input(&self) -> Result<Self> {
103 Ok(Arc::new(Box::pin(self.as_ref().resolve_input()).await?))
104 }
105}
106
107impl<T> TaskInput for Option<T>
108where
109 T: TaskInput,
110{
111 fn is_resolved(&self) -> bool {
112 match self {
113 Some(value) => value.is_resolved(),
114 None => true,
115 }
116 }
117
118 fn is_transient(&self) -> bool {
119 match self {
120 Some(value) => value.is_transient(),
121 None => false,
122 }
123 }
124
125 async fn resolve_input(&self) -> Result<Self> {
126 match self {
127 Some(value) => Ok(Some(value.resolve_input().await?)),
128 None => Ok(None),
129 }
130 }
131}
132
133impl<T> TaskInput for Vc<T>
134where
135 T: Send + Sync + ?Sized,
136{
137 fn is_resolved(&self) -> bool {
138 Vc::is_resolved(*self)
139 }
140
141 fn is_transient(&self) -> bool {
142 self.node.is_transient()
143 }
144
145 async fn resolve_input(&self) -> Result<Self> {
146 Vc::resolve(*self).await
147 }
148}
149
150impl<T> TaskInput for ResolvedVc<T>
153where
154 T: Send + Sync + ?Sized,
155{
156 fn is_resolved(&self) -> bool {
157 true
158 }
159
160 fn is_transient(&self) -> bool {
161 self.node.is_transient()
162 }
163
164 async fn resolve_input(&self) -> Result<Self> {
165 Ok(*self)
166 }
167}
168
169impl<T> TaskInput for Value<T>
170where
171 T: Any
172 + std::fmt::Debug
173 + Clone
174 + std::hash::Hash
175 + Eq
176 + Send
177 + Sync
178 + Serialize
179 + for<'de> Deserialize<'de>
180 + TraceRawVcs
181 + 'static,
182{
183 fn is_resolved(&self) -> bool {
184 true
185 }
186
187 fn is_transient(&self) -> bool {
188 false
189 }
190}
191
192impl<T> TaskInput for TransientValue<T>
193where
194 T: MagicAny + Clone + Debug + Hash + Eq + TraceRawVcs + 'static,
195{
196 fn is_transient(&self) -> bool {
197 true
198 }
199}
200
201impl<T> Serialize for TransientValue<T>
202where
203 T: MagicAny + Clone + 'static,
204{
205 fn serialize<S>(&self, _serializer: S) -> Result<S::Ok, S::Error>
206 where
207 S: serde::Serializer,
208 {
209 Err(serde::ser::Error::custom(
210 "cannot serialize transient task inputs",
211 ))
212 }
213}
214
215impl<'de, T> Deserialize<'de> for TransientValue<T>
216where
217 T: MagicAny + Clone + 'static,
218{
219 fn deserialize<D>(_deserializer: D) -> Result<Self, D::Error>
220 where
221 D: serde::Deserializer<'de>,
222 {
223 Err(serde::de::Error::custom(
224 "cannot deserialize transient task inputs",
225 ))
226 }
227}
228
229impl<T> TaskInput for TransientInstance<T>
230where
231 T: Sync + Send + TraceRawVcs + 'static,
232{
233 fn is_transient(&self) -> bool {
234 true
235 }
236}
237
238impl<T> Serialize for TransientInstance<T> {
239 fn serialize<S>(&self, _serializer: S) -> Result<S::Ok, S::Error>
240 where
241 S: serde::Serializer,
242 {
243 Err(serde::ser::Error::custom(
244 "cannot serialize transient task inputs",
245 ))
246 }
247}
248
249impl<'de, T> Deserialize<'de> for TransientInstance<T> {
250 fn deserialize<D>(_deserializer: D) -> Result<Self, D::Error>
251 where
252 D: serde::Deserializer<'de>,
253 {
254 Err(serde::de::Error::custom(
255 "cannot deserialize transient task inputs",
256 ))
257 }
258}
259
260impl<L, R> TaskInput for Either<L, R>
261where
262 L: TaskInput,
263 R: TaskInput,
264{
265 fn resolve_input(&self) -> impl Future<Output = Result<Self>> + Send + '_ {
266 self.as_ref().map_either(
267 |l| async move { anyhow::Ok(Either::Left(l.resolve_input().await?)) },
268 |r| async move { anyhow::Ok(Either::Right(r.resolve_input().await?)) },
269 )
270 }
271
272 fn is_resolved(&self) -> bool {
273 self.as_ref()
274 .either(TaskInput::is_resolved, TaskInput::is_resolved)
275 }
276
277 fn is_transient(&self) -> bool {
278 self.as_ref()
279 .either(TaskInput::is_transient, TaskInput::is_transient)
280 }
281}
282
283macro_rules! tuple_impls {
284 ( $( $name:ident )+ ) => {
285 impl<$($name: TaskInput),+> TaskInput for ($($name,)+)
286 where $($name: TaskInput),+
287 {
288 #[allow(non_snake_case)]
289 fn is_resolved(&self) -> bool {
290 let ($($name,)+) = self;
291 $($name.is_resolved() &&)+ true
292 }
293
294 #[allow(non_snake_case)]
295 fn is_transient(&self) -> bool {
296 let ($($name,)+) = self;
297 $($name.is_transient() ||)+ false
298 }
299
300 #[allow(non_snake_case)]
301 async fn resolve_input(&self) -> Result<Self> {
302 let ($($name,)+) = self;
303 Ok(($($name.resolve_input().await?,)+))
304 }
305 }
306 };
307}
308
309tuple_impls! { A }
311tuple_impls! { A B }
312tuple_impls! { A B C }
313tuple_impls! { A B C D }
314tuple_impls! { A B C D E }
315tuple_impls! { A B C D E F }
316tuple_impls! { A B C D E F G }
317tuple_impls! { A B C D E F G H }
318tuple_impls! { A B C D E F G H I }
319tuple_impls! { A B C D E F G H I J }
320tuple_impls! { A B C D E F G H I J K }
321tuple_impls! { A B C D E F G H I J K L }
322
#[cfg(test)]
mod tests {
    use turbo_tasks_macros::TaskInput;

    use super::*;
    use crate as turbo_tasks;

    fn assert_task_input<T>(_: T)
    where
        T: TaskInput,
    {
    }

    // Leaf input that always reports itself as unresolved (and never transient), used to
    // check how container impls aggregate `is_resolved`.
    #[derive(Clone, Eq, PartialEq, Hash, Debug, TraceRawVcs)]
    struct AlwaysUnresolved;

    impl TaskInput for AlwaysUnresolved {
        fn is_resolved(&self) -> bool {
            false
        }

        fn is_transient(&self) -> bool {
            false
        }
    }

    // Leaf input that always reports itself as transient (and, via the trait default,
    // resolved), used to check how container impls aggregate `is_transient`.
    #[derive(Clone, Eq, PartialEq, Hash, Debug, TraceRawVcs)]
    struct AlwaysTransient;

    impl TaskInput for AlwaysTransient {
        fn is_transient(&self) -> bool {
            true
        }
    }

    #[test]
    fn test_no_fields() -> Result<()> {
        #[derive(
            Clone, TaskInput, Eq, PartialEq, Hash, Debug, Serialize, Deserialize, TraceRawVcs,
        )]
        struct NoFields;

        assert_task_input(NoFields);
        Ok(())
    }

    #[test]
    fn test_one_unnamed_field() -> Result<()> {
        #[derive(
            Clone, TaskInput, Eq, PartialEq, Hash, Debug, Serialize, Deserialize, TraceRawVcs,
        )]
        struct OneUnnamedField(u32);

        assert_task_input(OneUnnamedField(42));
        Ok(())
    }

    #[test]
    fn test_multiple_unnamed_fields() -> Result<()> {
        #[derive(
            Clone, TaskInput, Eq, PartialEq, Hash, Debug, Serialize, Deserialize, TraceRawVcs,
        )]
        struct MultipleUnnamedFields(u32, RcStr);

        assert_task_input(MultipleUnnamedFields(42, "42".into()));
        Ok(())
    }

    #[test]
    fn test_one_named_field() -> Result<()> {
        #[derive(
            Clone, TaskInput, Eq, PartialEq, Hash, Debug, Serialize, Deserialize, TraceRawVcs,
        )]
        struct OneNamedField {
            named: u32,
        }

        assert_task_input(OneNamedField { named: 42 });
        Ok(())
    }

    #[test]
    fn test_multiple_named_fields() -> Result<()> {
        #[derive(
            Clone, TaskInput, Eq, PartialEq, Hash, Debug, Serialize, Deserialize, TraceRawVcs,
        )]
        struct MultipleNamedFields {
            named: u32,
            other: RcStr,
        }

        assert_task_input(MultipleNamedFields {
            named: 42,
            other: "42".into(),
        });
        Ok(())
    }

    #[test]
    fn test_generic_field() -> Result<()> {
        #[derive(
            Clone, TaskInput, Eq, PartialEq, Hash, Debug, Serialize, Deserialize, TraceRawVcs,
        )]
        struct GenericField<T>(T);

        assert_task_input(GenericField(42));
        assert_task_input(GenericField(RcStr::from("42")));
        Ok(())
    }

    #[derive(Clone, TaskInput, Eq, PartialEq, Hash, Debug, Serialize, Deserialize, TraceRawVcs)]
    enum OneVariant {
        Variant,
    }

    #[test]
    fn test_one_variant() -> Result<()> {
        assert_task_input(OneVariant::Variant);
        Ok(())
    }

    #[test]
    fn test_multiple_variants() -> Result<()> {
        #[derive(
            Clone, TaskInput, PartialEq, Eq, Hash, Debug, Serialize, Deserialize, TraceRawVcs,
        )]
        enum MultipleVariants {
            Variant1,
            Variant2,
        }

        assert_task_input(MultipleVariants::Variant2);
        Ok(())
    }

    #[derive(Clone, TaskInput, Eq, PartialEq, Hash, Debug, Serialize, Deserialize, TraceRawVcs)]
    enum MultipleVariantsAndHeterogeneousFields {
        Variant1,
        Variant2(u32),
        Variant3 { named: u32 },
        Variant4(u32, RcStr),
        Variant5 { named: u32, other: RcStr },
    }

    #[test]
    fn test_multiple_variants_and_heterogeneous_fields() -> Result<()> {
        assert_task_input(MultipleVariantsAndHeterogeneousFields::Variant5 {
            named: 42,
            other: "42".into(),
        });
        Ok(())
    }

    #[test]
    fn test_nested_variants() -> Result<()> {
        #[derive(
            Clone, TaskInput, Eq, PartialEq, Hash, Debug, Serialize, Deserialize, TraceRawVcs,
        )]
        enum NestedVariants {
            Variant1,
            Variant2(MultipleVariantsAndHeterogeneousFields),
            Variant3 { named: OneVariant },
            Variant4(OneVariant, RcStr),
            Variant5 { named: OneVariant, other: RcStr },
        }

        assert_task_input(NestedVariants::Variant5 {
            named: OneVariant::Variant,
            other: "42".into(),
        });
        assert_task_input(NestedVariants::Variant2(
            MultipleVariantsAndHeterogeneousFields::Variant5 {
                named: 42,
                other: "42".into(),
            },
        ));
        Ok(())
    }

    // Leaf types registered through `impl_task_input!` keep the defaults: resolved and
    // never transient.
    #[test]
    fn test_leaf_flags() {
        assert!(().is_resolved());
        assert!(!().is_transient());
        assert!(42u32.is_resolved());
        assert!(!42u32.is_transient());
    }

    #[test]
    fn test_option_flags() {
        assert!(None::<AlwaysUnresolved>.is_resolved());
        assert!(!Some(AlwaysUnresolved).is_resolved());
        assert!(!None::<AlwaysTransient>.is_transient());
        assert!(Some(AlwaysTransient).is_transient());
    }

    #[test]
    fn test_vec_flags() {
        // An empty Vec is vacuously resolved and not transient.
        assert!(Vec::<AlwaysUnresolved>::new().is_resolved());
        assert!(!Vec::<AlwaysTransient>::new().is_transient());
        assert!(!vec![AlwaysUnresolved].is_resolved());
        assert!(vec![AlwaysTransient, AlwaysTransient].is_transient());
    }

    #[test]
    fn test_smart_pointer_flags() {
        assert!(!Box::new(AlwaysUnresolved).is_resolved());
        assert!(Box::new(AlwaysTransient).is_transient());
        assert!(!Arc::new(AlwaysUnresolved).is_resolved());
        assert!(Arc::new(AlwaysTransient).is_transient());
    }

    #[test]
    fn test_tuple_flags() {
        assert!((1u32, 2u32).is_resolved());
        assert!(!(1u32, AlwaysUnresolved).is_resolved());
        assert!(!(1u32, 2u32).is_transient());
        assert!((AlwaysTransient, 1u32).is_transient());
    }

    #[test]
    fn test_either_flags() {
        assert!(Either::<u32, AlwaysUnresolved>::Left(1).is_resolved());
        assert!(!Either::<u32, AlwaysUnresolved>::Right(AlwaysUnresolved).is_resolved());
        assert!(Either::<AlwaysTransient, u32>::Left(AlwaysTransient).is_transient());
        assert!(!Either::<AlwaysTransient, u32>::Right(1).is_transient());
    }

    #[test]
    fn test_nested_container_flags() {
        // Transience propagates outward through Option -> Vec -> tuple -> Box.
        let nested = Some(vec![(1u32, Box::new(AlwaysTransient))]);
        assert!(nested.is_transient());
        assert!(nested.is_resolved());
    }
}
481}