1use std::{
2 any::TypeId,
3 cell::SyncUnsafeCell,
4 fmt::{self, Debug, Display, Formatter},
5 hash::Hash,
6};
7
8use bincode::{Decode, Encode};
9use tracing::Span;
10use turbo_bincode::{AnyDecodeFn, AnyEncodeFn};
11
12use crate::{
13 RawVc, SharedReference, TaskPriority, VcValueType,
14 dyn_task_inputs::any_as_encode,
15 id::TraitTypeId,
16 macro_helpers::{NativeFunction, TRAIT_IMPLS_SLICE},
17 registry::{
18 RegistryType, get_trait_type_id, get_value_type_id_unchecked, impl_ptr_identity,
19 trait_type_count,
20 },
21 task::shared_reference::TypedSharedReference,
22 vc::VcCellMode,
23};
24
/// Creates a cell for a value type from a type-erased value. Each [`ValueType`]
/// stores one of these in `raw_cell`, taken from
/// `<T::CellMode as VcCellMode<T>>::raw_cell`.
type RawCellFactoryFn = fn(TypedSharedReference) -> RawVc;
/// A trait's method table for one value type: one `NativeFunction` per trait
/// method, indexed by `TraitMethod::index`.
type Vtable = &'static [&'static NativeFunction];
27
/// Whether and how values of a [`ValueType`] are persisted.
///
/// The `serialization` option of `#[turbo_tasks::value]` selects the variant
/// (see the tests at the bottom of this file):
/// - the default maps to [`Persistable`](Self::Persistable);
/// - `serialization = "skip"` maps to [`Skip`](Self::Skip);
/// - `serialization = "hash"` maps to [`HashOnly`](Self::HashOnly).
pub enum ValueTypePersistence {
    /// Values can be encoded and decoded. Holds a type-erased encoder and a
    /// decoder that rebuilds the value as a [`SharedReference`] (see
    /// `ValueType::persistable` and `ValueType::new_with_bincode_wrappers`).
    Persistable(AnyEncodeFn, AnyDecodeFn<SharedReference>),
    /// Values are not serialized (`has_serialization()` is `false`).
    Skip,
    /// Values are not serialized (`has_serialization()` is `false`). Presumably
    /// only a hash of the value is kept. NOTE(review): confirm against the
    /// storage backend.
    HashOnly,
}
57
/// Eviction policy for cells of a [`ValueType`].
///
/// The `evict` option of `#[turbo_tasks::value]` selects the variant (see the
/// tests at the bottom of this file).
pub enum Evictability {
    /// Default when no `evict` option is given.
    Always,
    /// Selected by `evict = "last"`. Presumably these values are evicted only
    /// after cheaper ones. NOTE(review): confirm the exact policy.
    Expensive,
    /// Selected by `evict = "never"`.
    Never,
}
82
/// Registry entry describing one value type (a `VcValueType`).
///
/// Built by the `const` constructors in `impl ValueType`. Identity comes from
/// `impl_ptr_identity!`, which is presumably address-based.
pub struct ValueType {
    /// Registry metadata: the Rust type name (`std::any::type_name`), the
    /// global name, and the `TypeId`.
    pub ty: RegistryType,

    /// Whether and how values of this type are persisted.
    pub persistence: ValueTypePersistence,

    /// Eviction policy for cells of this type.
    pub evictability: Evictability,

    /// Creates a cell from a type-erased value; taken from
    /// `<T::CellMode as VcCellMode<T>>::raw_cell`.
    pub(crate) raw_cell: RawCellFactoryFn,

    /// Per-trait vtables. Written by `register_trait` and read through
    /// `trait_info` without any lock, hence the `SyncUnsafeCell`.
    traits: SyncUnsafeCell<ValueTypeTraits>,
}
impl_ptr_identity!(ValueType);
110
111impl Debug for ValueType {
112 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
113 f.debug_struct("ValueType")
114 .field("name", &self.ty.name)
115 .finish()
116 }
117}
118
119impl Display for ValueType {
120 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
121 f.write_str(self.ty.name)
122 }
123}
124
/// The trait implementations registered for one [`ValueType`].
struct ValueTypeTraits {
    /// `None` until the first trait is registered. After that it is a slice
    /// with `trait_type_count()` slots, indexed by `TraitTypeId - 1`; a slot is
    /// `Some(vtable)` once that trait has been registered for this type.
    traits: Option<Box<[Option<Vtable>]>>,
}
131
/// Encodes a value type through a wrapper instead of the value's own `Encode`
/// impl. Used by `ValueType::new_with_bincode_wrappers`.
pub trait ManualEncodeWrapper: Encode {
    /// The value type being wrapped.
    type Value;

    /// Borrows `value` as something encodable; the result is what actually
    /// gets encoded.
    fn new<'a>(value: &'a Self::Value) -> impl Encode + 'a;
}

/// Decoding counterpart of [`ManualEncodeWrapper`]: `Self` is decoded first,
/// then unwrapped into the value with [`inner`](Self::inner).
pub trait ManualDecodeWrapper: Decode<()> {
    /// The value type being wrapped.
    type Value;

    /// Consumes the decoded wrapper and returns the value.
    fn inner(self) -> Self::Value;
}
144
impl ValueType {
    /// Creates a value type with an explicitly chosen persistence mode.
    pub const fn new<T: VcValueType>(
        global_name: &'static str,
        persistence: ValueTypePersistence,
        evictability: Evictability,
    ) -> Self {
        Self::new_inner::<T>(global_name, persistence, evictability)
    }

    /// Creates a persistable value type that encodes and decodes `T` with its
    /// own `Encode` / `Decode<()>` implementations.
    pub const fn persistable<T: VcValueType + Encode + Decode<()>>(
        global_name: &'static str,
        evictability: Evictability,
    ) -> Self {
        Self::new_inner::<T>(
            global_name,
            ValueTypePersistence::Persistable(
                // Encoder: view the type-erased value as `T` and encode it.
                |this, enc| {
                    T::encode(any_as_encode::<T>(this), enc)?;
                    Ok(())
                },
                // Decoder: decode a `T` and wrap it in a new `SharedReference`.
                |dec| {
                    let val = T::decode(dec)?;
                    Ok(SharedReference::new(triomphe::Arc::new(val)))
                },
            ),
            evictability,
        )
    }

    /// Creates a persistable value type whose encoding goes through the
    /// wrapper `E` and whose decoding goes through the wrapper `D`, instead of
    /// `T`'s own impls.
    pub const fn new_with_bincode_wrappers<
        T: VcValueType,
        E: ManualEncodeWrapper<Value = T>,
        D: ManualDecodeWrapper<Value = T>,
    >(
        global_name: &'static str,
        evictability: Evictability,
    ) -> Self {
        Self::new_inner::<T>(
            global_name,
            ValueTypePersistence::Persistable(
                // Encoder: borrow the value as `&T`, wrap it with `E`, encode the wrapper.
                |this, enc| {
                    E::new(any_as_encode::<T>(this)).encode(enc)?;
                    Ok(())
                },
                // Decoder: decode a `D`, unwrap it to `T`, store it in a `SharedReference`.
                |dec| {
                    let val = D::inner(D::decode(dec)?);
                    Ok(SharedReference::new(triomphe::Arc::new(val)))
                },
            ),
            evictability,
        )
    }

    /// Shared constructor. Records the type metadata and the cell factory, and
    /// starts with no registered traits.
    const fn new_inner<T: VcValueType>(
        global_name: &'static str,
        persistence: ValueTypePersistence,
        evictability: Evictability,
    ) -> Self {
        Self {
            ty: RegistryType::new::<T>(std::any::type_name::<T>(), global_name),
            persistence,
            evictability,
            raw_cell: <T::CellMode as VcCellMode<T>>::raw_cell,
            // Filled in later by `register_trait`.
            traits: SyncUnsafeCell::new(ValueTypeTraits { traits: None }),
        }
    }

    /// Returns the `TypeId` recorded in this type's registry metadata.
    pub fn type_id(&self) -> TypeId {
        self.ty.type_id
    }

    /// Shared view of the registered trait vtables.
    #[inline]
    fn trait_info(&self) -> &ValueTypeTraits {
        // SAFETY (review): shared read of the `SyncUnsafeCell`. This is sound
        // only if `register_trait` never runs while the returned reference is
        // alive. Presumably all registration finishes before any lookup.
        // TODO: confirm the call order.
        unsafe { &*self.traits.get() }
    }

    /// Looks up this type's implementation of `trait_method`.
    ///
    /// Returns `None` if no trait is registered for this type at all, or if
    /// the method's trait is not registered for it.
    #[inline]
    pub fn get_trait_method(
        &self,
        trait_method: &'static TraitMethod,
    ) -> Option<&'static NativeFunction> {
        let trait_type_id = trait_method.trait_type_id();
        // Trait ids are 1-based; the slot array is 0-based.
        let vtable = self.trait_info().traits.as_ref()?[*trait_type_id as usize - 1]?;
        Some(vtable[trait_method.index as usize])
    }

    /// Stores `trait_methods` as this type's vtable for `trait_type`.
    ///
    /// On first use this allocates one slot per known trait type
    /// (`trait_type_count()`).
    fn register_trait(&self, trait_type: &'static TraitType, trait_methods: Vtable) {
        // SAFETY (review): mutable access to the `SyncUnsafeCell`, which needs
        // exclusive access. The only caller in this file is
        // `register_all_trait_methods`. Presumably it runs once, before any
        // `trait_info` reader. TODO: confirm.
        let traits = unsafe { &mut *self.traits.get() };
        let trait_type_id = get_trait_type_id(trait_type);
        let array = traits
            .traits
            .get_or_insert_with(|| vec![None; trait_type_count()].into_boxed_slice());
        array[*trait_type_id as usize - 1] = Some(trait_methods);
    }

    /// Returns whether a vtable is registered for `trait_type` on this type.
    #[inline]
    pub fn has_trait(&self, trait_type: &TraitTypeId) -> bool {
        self.trait_info()
            .traits
            .as_ref()
            .is_some_and(|t| t[**trait_type as usize - 1].is_some())
    }
}
271
272pub(crate) fn register_all_trait_methods() {
275 for entry in TRAIT_IMPLS_SLICE.iter() {
276 for (i, impl_method) in entry.methods.iter().enumerate() {
277 let trait_method = &entry.trait_type.methods[i];
278 if trait_method.is_root != impl_method.is_root {
279 let attr = if trait_method.is_root {
280 "the trait method has `#[turbo_tasks::function(root)]` but the impl does not"
281 } else {
282 "the impl has `#[turbo_tasks::function(root)]` but the trait method does not"
283 };
284 panic!(
285 "`root` attribute mismatch on `{}::{}` for `{}`: {}. The `root` attribute \
286 must match between trait and impl methods.",
287 trait_method.trait_name,
288 trait_method.method_name,
289 entry.value_type.ty.name,
290 attr,
291 );
292 }
293 }
294 let value_type = entry.value_type;
295 value_type.register_trait(entry.trait_type, entry.methods);
296 let id = unsafe { get_value_type_id_unchecked(value_type) };
300 (entry.install_vtable)(id);
302 }
303}
304
/// Static description of one method of a turbo-tasks trait.
///
/// Equality and hashing are by address (see the `PartialEq` / `Hash` impls).
pub struct TraitMethod {
    /// The trait this method belongs to.
    pub trait_type: &'static TraitType,
    /// This method's slot in each implementing type's vtable.
    pub index: u8,
    /// Name of the trait; used in diagnostics and tracing spans.
    pub trait_name: &'static str,
    /// Name of the method; used in diagnostics, tracing spans and vtable
    /// construction (`index_of_method_name`).
    pub method_name: &'static str,
    /// Default implementation provided by the trait, if any.
    pub default_method: Option<&'static NativeFunction>,
    /// Whether the trait method is marked `#[turbo_tasks::function(root)]`.
    /// It must match the implementation's flag, which
    /// `register_all_trait_methods` checks.
    pub is_root: bool,
}
315impl Hash for TraitMethod {
316 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
317 (self as *const TraitMethod).hash(state);
318 }
319}
320
321impl Eq for TraitMethod {}
322
323impl PartialEq for TraitMethod {
324 fn eq(&self, other: &Self) -> bool {
325 std::ptr::eq(self, other)
326 }
327}
328impl Debug for TraitMethod {
329 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
330 f.debug_struct("TraitMethod")
331 .field("trait_name", &self.trait_name)
332 .field("name", &self.method_name)
333 .field("default_method", &self.default_method)
334 .finish()
335 }
336}
impl TraitMethod {
    /// Returns the registry id of the trait this method belongs to.
    ///
    /// The id is read straight from the trait's registry slot. It must
    /// already be assigned (non-zero); this is only checked in debug builds.
    #[inline]
    fn trait_type_id(&self) -> TraitTypeId {
        // SAFETY (review): unsynchronized read of the registry id cell.
        // Presumably the id is written once, during registry initialization,
        // before any call here. TODO: confirm.
        let raw = unsafe { std::ptr::read(self.trait_type.ty.id.get()) };
        debug_assert!(raw != 0, "TraitMethod::trait_type_id not initialized");
        // SAFETY (review): `new_unchecked` presumably requires a non-zero id.
        // Release builds skip the check above and rely on initialization order.
        unsafe { TraitTypeId::new_unchecked(raw) }
    }

    /// Creates the trace-level span for resolving a call to this trait
    /// method. The span records `Trait::method` as `name`, plus the task
    /// `priority`.
    pub(crate) fn resolve_span(&self, priority: TaskPriority) -> Span {
        tracing::trace_span!(
            "turbo_tasks::resolve_trait_call",
            name = format_args!("{}::{}", &self.trait_name, &self.method_name),
            priority = %priority,
        )
    }
}
356
/// Registry entry describing a turbo-tasks trait.
///
/// Identity comes from `impl_ptr_identity!`, which is presumably address-based.
pub struct TraitType {
    /// Registry metadata: name, global name and the id slot read by
    /// `TraitMethod::trait_type_id`.
    pub ty: RegistryType,
    /// The trait's methods. A method's position here is its vtable slot
    /// (`build_trait_vtable` places overrides by `index_of_method_name`).
    pub methods: &'static [TraitMethod],
    /// Default implementations. Presumably parallel to `methods`.
    /// TODO: confirm.
    pub default_methods: &'static [Option<&'static NativeFunction>],
}
impl_ptr_identity!(TraitType);
363
364impl Debug for TraitType {
365 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
366 let mut d = f.debug_struct("TraitType");
367 d.field("name", &self.ty.name);
368 for method in self.methods.iter() {
369 d.field(method.method_name, method);
370 }
371 d.finish()
372 }
373}
374
375impl Display for TraitType {
376 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
377 write!(f, "trait {}", self.ty.name)
378 }
379}
380
381impl TraitType {
382 pub const fn new<T: 'static>(
383 name: &'static str,
384 global_name: &'static str,
385 methods: &'static [TraitMethod],
386 default_methods: &'static [Option<&'static NativeFunction>],
387 ) -> Self {
388 Self {
389 ty: RegistryType::new::<T>(name, global_name),
390 methods,
391 default_methods,
392 }
393 }
394
395 #[cfg(test)]
396 pub fn get(&self, name: &str) -> &TraitMethod {
397 self.methods
398 .iter()
399 .find(|method| method.method_name == name)
400 .expect("Method not found!")
401 }
402}
403
/// Compile-time description of a trait's vtable layout, consumed by
/// `build_trait_vtable`.
pub trait TraitVtablePrototype {
    /// Presumably the number of vtable slots, one per trait method.
    /// `build_trait_vtable` itself takes the length as a const generic.
    const LEN: usize;
    /// Trait-provided default implementation for each slot, or `None`.
    /// `None` slots keep `VTABLE_DEFAULT` unless the impl overrides them.
    const DEFAULTS: &'static [Option<&'static NativeFunction>];
}
408
409pub const fn index_of_method_name(methods: &'static [TraitMethod], name: &'static str) -> usize {
413 let mut i = 0;
414 'outer: while i < methods.len() {
415 let entry = methods[i].method_name;
416 if entry.len() == name.len() {
417 let mut j = 0;
418 while j < name.len() {
419 if entry.as_bytes()[j] != name.as_bytes()[j] {
420 i += 1;
421 continue 'outer;
422 }
423 j += 1;
424 }
425 return i;
426 }
427 i += 1;
428 }
429 panic!("Method not found!")
430}
431
432pub const fn build_trait_vtable<
433 B: TraitVtablePrototype + crate::registry::RegistryDef<TraitType>,
434 const LEN: usize,
435>(
436 overrides: &[(&'static str, &'static NativeFunction)],
437) -> [&'static NativeFunction; LEN] {
438 let mut methods = [&crate::native_function::VTABLE_DEFAULT; LEN];
439 let mut i = 0;
440 while i < LEN {
441 if let Some(default) = B::DEFAULTS[i] {
442 methods[i] = default;
443 }
444 i += 1;
445 }
446 let mut i = 0;
448 while i < overrides.len() {
449 let (name, f) = overrides[i];
450 methods[index_of_method_name(
451 <B as crate::registry::RegistryDef<TraitType>>::DEF.methods,
452 name,
453 )] = f;
454 i += 1;
455 }
456 methods
457}
458
// Checks how `#[turbo_tasks::value]` options map onto the registered
// `ValueType`:
// - `serialization` selects the `ValueTypePersistence` variant;
// - `evict` selects the `Evictability` variant;
// - `has_serialization()` is true only for persistable types.
#[cfg(test)]
mod tests {
    use super::{Evictability, ValueTypePersistence};
    use crate::{self as turbo_tasks, VcValueType, registry};

    // `serialization = "skip"`, default eviction.
    #[turbo_tasks::value(serialization = "skip")]
    struct SkipValue(#[turbo_tasks(trace_ignore)] u32);

    // `serialization = "hash"`, default eviction.
    #[turbo_tasks::value(serialization = "hash")]
    struct HashValue(u32);

    // `serialization = "skip"` with `evict = "last"`.
    #[turbo_tasks::value(serialization = "skip", evict = "last")]
    struct SkipExpensiveValue(#[turbo_tasks(trace_ignore)] u32);

    // `serialization = "skip"` with `evict = "never"`.
    #[turbo_tasks::value(serialization = "skip", evict = "never", cell = "new", eq = "manual")]
    struct SessionStatefulValue;

    // Default serialization and eviction.
    #[turbo_tasks::value]
    struct PersistableValue(u32);

    // Default serialization with `evict = "never"`.
    #[turbo_tasks::value(evict = "never")]
    struct PersistableNeverValue(u32);

    // skip -> (Skip, Always), no serialization.
    #[test]
    fn skip_maps_to_skip_always() {
        let vt = registry::get_value_type(SkipValue::get_value_type_id());
        assert!(
            matches!(vt.persistence, ValueTypePersistence::Skip),
            "`serialization = \"skip\"` must map to ValueTypePersistence::Skip"
        );
        assert!(matches!(vt.evictability, Evictability::Always));
        assert!(!SkipValue::has_serialization());
    }

    // hash -> (HashOnly, Always), no serialization.
    #[test]
    fn hash_maps_to_hash_only_always() {
        let vt = registry::get_value_type(HashValue::get_value_type_id());
        assert!(
            matches!(vt.persistence, ValueTypePersistence::HashOnly),
            "`serialization = \"hash\"` must map to ValueTypePersistence::HashOnly"
        );
        assert!(matches!(vt.evictability, Evictability::Always));
        assert!(!HashValue::has_serialization());
    }

    // skip + evict="last" -> (Skip, Expensive).
    #[test]
    fn skip_expensive_maps_to_skip_expensive() {
        let vt = registry::get_value_type(SkipExpensiveValue::get_value_type_id());
        assert!(matches!(vt.persistence, ValueTypePersistence::Skip));
        assert!(
            matches!(vt.evictability, Evictability::Expensive),
            "`serialization = \"skip\", evict = \"last\"` must map to Evictability::Expensive"
        );
        assert!(!SkipExpensiveValue::has_serialization());
    }

    // skip + evict="never" -> (Skip, Never).
    #[test]
    fn session_stateful_maps_to_skip_never() {
        let vt = registry::get_value_type(SessionStatefulValue::get_value_type_id());
        assert!(matches!(vt.persistence, ValueTypePersistence::Skip));
        assert!(
            matches!(vt.evictability, Evictability::Never),
            "`serialization = \"skip\", evict = \"never\"` must map to Evictability::Never"
        );
        assert!(!SessionStatefulValue::has_serialization());
    }

    // default -> (Persistable, Always), serialization available.
    #[test]
    fn default_maps_to_persistable_always() {
        let vt = registry::get_value_type(PersistableValue::get_value_type_id());
        assert!(
            matches!(vt.persistence, ValueTypePersistence::Persistable(_, _)),
            "default (auto) serialization must map to ValueTypePersistence::Persistable"
        );
        assert!(matches!(vt.evictability, Evictability::Always));
        assert!(PersistableValue::has_serialization());
    }

    // evict="never" alone -> (Persistable, Never), serialization available.
    #[test]
    fn persistable_never_maps_to_persistable_never() {
        let vt = registry::get_value_type(PersistableNeverValue::get_value_type_id());
        assert!(
            matches!(vt.persistence, ValueTypePersistence::Persistable(_, _)),
            "`evict = \"never\"` (default serialization) must keep \
             ValueTypePersistence::Persistable"
        );
        assert!(
            matches!(vt.evictability, Evictability::Never),
            "`evict = \"never\"` must map to Evictability::Never"
        );
        assert!(PersistableNeverValue::has_serialization());
    }
}