1use std::{
2 any::TypeId,
3 cell::SyncUnsafeCell,
4 fmt::{self, Debug, Display, Formatter},
5 hash::Hash,
6};
7
8use bincode::{Decode, Encode};
9use tracing::Span;
10use turbo_bincode::{AnyDecodeFn, AnyEncodeFn};
11
12use crate::{
13 RawVc, SharedReference, TaskPriority, VcValueType,
14 dyn_task_inputs::any_as_encode,
15 id::TraitTypeId,
16 macro_helpers::{CollectableTraitMethods, NativeFunction},
17 registry::{RegistryType, get_trait_type_id, trait_type_count, turbo_registry},
18 task::shared_reference::TypedSharedReference,
19 vc::VcCellMode,
20};
21
/// Function pointer that turns a `TypedSharedReference` into a `RawVc`.
/// `ValueType::new_inner` fills it from `VcCellMode::raw_cell`.
type RawCellFactoryFn = fn(TypedSharedReference) -> RawVc;
/// Implementations of one trait's methods for one value type.
/// Indexed by `TraitMethod::index` (see `ValueType::get_trait_method`).
type Vtable = &'static [&'static NativeFunction];
24
/// How values of a [`ValueType`] participate in serialization.
pub enum ValueTypePersistence {
    /// Values are encoded and decoded.
    ///
    /// The first function encodes a type-erased value. The second decodes a value
    /// and wraps it in a new [`SharedReference`] (see `ValueType::persistable`).
    Persistable(AnyEncodeFn, AnyDecodeFn<SharedReference>),
    /// Values are not serialized.
    ///
    /// Selected by `#[turbo_tasks::value(serialization = "skip")]`; `has_serialization()`
    /// is false for these types (see tests below).
    Skip,
    /// Selected by `serialization = "hash"`; `has_serialization()` is false.
    ///
    /// NOTE(review): presumably only a hash of the value is retained — confirm.
    HashOnly,
}
54
/// Eviction policy attached to a [`ValueType`].
pub enum Evictability {
    /// Default for `#[turbo_tasks::value]` when no `evict` option is given (see tests below).
    /// NOTE(review): presumably values may be evicted freely — confirm.
    Always,
    /// Selected by `evict = "last"`.
    /// NOTE(review): presumably evicted only after cheaper entries — confirm.
    Expensive,
    /// Selected by `evict = "never"`.
    /// NOTE(review): presumably values of this type are never evicted — confirm.
    Never,
}
79
/// Runtime descriptor registered for each `#[turbo_tasks::value]` type.
pub struct ValueType {
    /// Registry metadata: the Rust type name, the global name, and the `TypeId`.
    /// Built by `RegistryType::new::<T>` in `new_inner`.
    pub ty: RegistryType,

    /// Whether and how values of this type are serialized.
    pub persistence: ValueTypePersistence,

    /// Eviction policy for values of this type.
    pub evictability: Evictability,

    /// Creates a `RawVc` cell from a typed shared reference.
    /// Taken from the type's `VcCellMode::raw_cell`.
    pub(crate) raw_cell: RawCellFactoryFn,

    /// Per-trait vtables.
    ///
    /// Written only by `register_trait` and read through `trait_info`; both go
    /// through `unsafe` access to the cell.
    traits: SyncUnsafeCell<ValueTypeTraits>,
}
106
107impl Debug for ValueType {
108 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
109 f.debug_struct("ValueType")
110 .field("name", &self.ty.name)
111 .finish()
112 }
113}
114
115impl Display for ValueType {
116 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
117 f.write_str(self.ty.name)
118 }
119}
120
/// Trait-vtable storage for a single [`ValueType`].
struct ValueTypeTraits {
    /// `None` until the first `register_trait` call.
    ///
    /// After that call it holds one slot per trait type (`trait_type_count()` slots),
    /// indexed by trait type id minus one. A slot is `Some` only for traits this value
    /// type implements.
    traits: Option<Box<[Option<Vtable>]>>,
}
127
/// Adapter that lets a value type be encoded through a custom bincode wrapper.
/// Used by `ValueType::new_with_bincode_wrappers`.
pub trait ManualEncodeWrapper: Encode {
    /// The value type being wrapped.
    type Value;

    /// Wraps a borrowed value in something encodable.
    ///
    /// The returned encoder only needs to implement `Encode`; it need not be `Self`.
    fn new<'a>(value: &'a Self::Value) -> impl Encode + 'a;
}
134
/// Adapter that lets a value type be decoded through a custom bincode wrapper.
/// Used by `ValueType::new_with_bincode_wrappers`.
pub trait ManualDecodeWrapper: Decode<()> {
    /// The value type produced by unwrapping.
    type Value;

    /// Consumes the decoded wrapper and returns the wrapped value.
    fn inner(self) -> Self::Value;
}
140
141impl ValueType {
142 pub const fn new<T: VcValueType>(
150 global_name: &'static str,
151 persistence: ValueTypePersistence,
152 evictability: Evictability,
153 ) -> Self {
154 Self::new_inner::<T>(global_name, persistence, evictability)
155 }
156
157 pub const fn persistable<T: VcValueType + Encode + Decode<()>>(
163 global_name: &'static str,
164 evictability: Evictability,
165 ) -> Self {
166 Self::new_inner::<T>(
167 global_name,
168 ValueTypePersistence::Persistable(
169 |this, enc| {
170 T::encode(any_as_encode::<T>(this), enc)?;
171 Ok(())
172 },
173 |dec| {
174 let val = T::decode(dec)?;
175 Ok(SharedReference::new(triomphe::Arc::new(val)))
176 },
177 ),
178 evictability,
179 )
180 }
181
182 pub const fn new_with_bincode_wrappers<
190 T: VcValueType,
191 E: ManualEncodeWrapper<Value = T>,
192 D: ManualDecodeWrapper<Value = T>,
193 >(
194 global_name: &'static str,
195 evictability: Evictability,
196 ) -> Self {
197 Self::new_inner::<T>(
198 global_name,
199 ValueTypePersistence::Persistable(
200 |this, enc| {
201 E::new(any_as_encode::<T>(this)).encode(enc)?;
202 Ok(())
203 },
204 |dec| {
205 let val = D::inner(D::decode(dec)?);
206 Ok(SharedReference::new(triomphe::Arc::new(val)))
207 },
208 ),
209 evictability,
210 )
211 }
212
213 const fn new_inner<T: VcValueType>(
215 global_name: &'static str,
216 persistence: ValueTypePersistence,
217 evictability: Evictability,
218 ) -> Self {
219 Self {
220 ty: RegistryType::new::<T>(std::any::type_name::<T>(), global_name),
221 persistence,
222 evictability,
223 raw_cell: <T::CellMode as VcCellMode<T>>::raw_cell,
224 traits: SyncUnsafeCell::new(ValueTypeTraits { traits: None }),
225 }
226 }
227
228 pub fn type_id(&self) -> TypeId {
230 self.ty.type_id
231 }
232
233 #[inline]
234 fn trait_info(&self) -> &ValueTypeTraits {
235 unsafe { &*self.traits.get() }
237 }
238
239 #[inline]
240 pub fn get_trait_method(
241 &self,
242 trait_method: &'static TraitMethod,
243 ) -> Option<&'static NativeFunction> {
244 let trait_type_id = trait_method.trait_type_id();
245 let vtable = self.trait_info().traits.as_ref()?[*trait_type_id as usize - 1]?;
246 Some(vtable[trait_method.index as usize])
247 }
248
249 fn register_trait(&self, trait_type: &'static TraitType, trait_methods: Vtable) {
250 let traits = unsafe { &mut *self.traits.get() };
252 let trait_type_id = get_trait_type_id(trait_type);
253 let array = traits
254 .traits
255 .get_or_insert_with(|| vec![None; trait_type_count()].into_boxed_slice());
256 array[*trait_type_id as usize - 1] = Some(trait_methods);
257 }
258
259 #[inline]
260 pub fn has_trait(&self, trait_type: &TraitTypeId) -> bool {
261 self.trait_info()
262 .traits
263 .as_ref()
264 .is_some_and(|t| t[**trait_type as usize - 1].is_some())
265 }
266}
267
// Registry glue for `ValueType` under the "Value" kind. The macro comes from
// `crate::registry`.
// NOTE(review): the exact items it generates are defined in `registry` — confirm there.
turbo_registry!("Value", ValueType);
269
270pub(crate) fn register_all_trait_methods(_: &[&'static ValueType]) {
273 for entry in inventory::iter::<CollectableTraitMethods> {
274 for (i, impl_method) in entry.methods.iter().enumerate() {
275 let trait_method = &entry.trait_type.methods[i];
276 if trait_method.is_root != impl_method.is_root {
277 let attr = if trait_method.is_root {
278 "the trait method has `#[turbo_tasks::function(root)]` but the impl does not"
279 } else {
280 "the impl has `#[turbo_tasks::function(root)]` but the trait method does not"
281 };
282 panic!(
283 "`root` attribute mismatch on `{}::{}` for `{}`: {}. The `root` attribute \
284 must match between trait and impl methods.",
285 trait_method.trait_name,
286 trait_method.method_name,
287 entry.value_type.ty.name,
288 attr,
289 );
290 }
291 }
292 entry
293 .value_type
294 .register_trait(entry.trait_type, entry.methods);
295 (entry.finalize_vtable_registry)();
297 }
298}
299
/// Static description of one method declared by a turbo-tasks trait.
pub struct TraitMethod {
    /// The trait that declares this method.
    pub trait_type: &'static TraitType,
    /// Position of this method in the trait's vtables (see `ValueType::get_trait_method`).
    pub index: u8,
    /// Trait name, used in diagnostics and tracing spans.
    pub trait_name: &'static str,
    /// Method name. Also used to resolve overrides in `index_of_method_name`.
    pub method_name: &'static str,
    /// The trait's default implementation, if it provides one.
    pub default_method: Option<&'static NativeFunction>,
    /// Whether the trait method carries `#[turbo_tasks::function(root)]`.
    /// Impls must match it; `register_all_trait_methods` checks this.
    pub is_root: bool,
}
310impl Hash for TraitMethod {
311 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
312 (self as *const TraitMethod).hash(state);
313 }
314}
315
316impl Eq for TraitMethod {}
317
318impl PartialEq for TraitMethod {
319 fn eq(&self, other: &Self) -> bool {
320 std::ptr::eq(self, other)
321 }
322}
323impl Debug for TraitMethod {
324 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
325 f.debug_struct("TraitMethod")
326 .field("trait_name", &self.trait_name)
327 .field("name", &self.method_name)
328 .field("default_method", &self.default_method)
329 .finish()
330 }
331}
impl TraitMethod {
    /// Returns the registry id of the trait that declares this method.
    #[inline]
    fn trait_type_id(&self) -> TraitTypeId {
        // SAFETY (assumed invariant — confirm): `ty.id` is read through the raw pointer
        // returned by `.get()`, without synchronization. This assumes the registry wrote
        // the id before any trait-method lookup and does not write it concurrently.
        let raw = unsafe { std::ptr::read(self.trait_type.ty.id.get()) };
        // Zero means "not yet assigned by the registry".
        debug_assert!(raw != 0, "TraitMethod::trait_type_id not initialized");
        // NOTE(review): presumably `new_unchecked` requires a nonzero id. That is checked
        // above only in debug builds.
        unsafe { TraitTypeId::new_unchecked(raw) }
    }

    /// Creates a trace-level span for resolving a call to this trait method.
    ///
    /// The span is named `turbo_tasks::resolve_trait_call` and is tagged with
    /// `Trait::method` and the task priority.
    pub(crate) fn resolve_span(&self, priority: TaskPriority) -> Span {
        tracing::trace_span!(
            "turbo_tasks::resolve_trait_call",
            name = format_args!("{}::{}", &self.trait_name, &self.method_name),
            priority = %priority,
        )
    }
}
351
/// Runtime descriptor for a turbo-tasks trait.
pub struct TraitType {
    /// Registry metadata: the name, the global name, and the id cell read by
    /// `TraitMethod::trait_type_id`.
    pub ty: RegistryType,
    /// The trait's methods, in vtable order.
    pub methods: &'static [TraitMethod],
    /// Default implementations, if any.
    ///
    /// NOTE(review): presumably parallel to `methods` — confirm.
    pub default_methods: &'static [Option<&'static NativeFunction>],
}
357
358impl Debug for TraitType {
359 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
360 let mut d = f.debug_struct("TraitType");
361 d.field("name", &self.ty.name);
362 for method in self.methods.iter() {
363 d.field(method.method_name, method);
364 }
365 d.finish()
366 }
367}
368
369impl Display for TraitType {
370 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
371 write!(f, "trait {}", self.ty.name)
372 }
373}
374
375impl TraitType {
376 pub const fn new<T: 'static>(
377 name: &'static str,
378 global_name: &'static str,
379 methods: &'static [TraitMethod],
380 default_methods: &'static [Option<&'static NativeFunction>],
381 ) -> Self {
382 Self {
383 ty: RegistryType::new::<T>(name, global_name),
384 methods,
385 default_methods,
386 }
387 }
388
389 #[cfg(test)]
390 pub fn get(&self, name: &str) -> &TraitMethod {
391 self.methods
392 .iter()
393 .find(|method| method.method_name == name)
394 .expect("Method not found!")
395 }
396}
397
// Registry glue for `TraitType` under the "Trait" kind. The macro comes from
// `crate::registry`.
// NOTE(review): the exact items it generates are defined in `registry` — confirm there.
turbo_registry!("Trait", TraitType);
399
/// Compile-time shape of a trait's vtable, consumed by `build_trait_vtable`.
pub trait TraitVtablePrototype {
    /// Number of vtable slots.
    ///
    /// NOTE(review): presumably the trait's method count — confirm.
    const LEN: usize;
    /// Default implementation for each slot, if any.
    ///
    /// `build_trait_vtable` reads entries `0..LEN` (its const generic), so this slice
    /// must be at least that long.
    const DEFAULTS: &'static [Option<&'static NativeFunction>];
}
405pub const fn index_of_method_name(methods: &'static [TraitMethod], name: &'static str) -> usize {
409 let mut i = 0;
410 'outer: while i < methods.len() {
411 let entry = methods[i].method_name;
412 if entry.len() == name.len() {
413 let mut j = 0;
414 while j < name.len() {
415 if entry.as_bytes()[j] != name.as_bytes()[j] {
416 i += 1;
417 continue 'outer;
418 }
419 j += 1;
420 }
421 return i;
422 }
423 i += 1;
424 }
425 panic!("Method not found!")
426}
427
428pub const fn build_trait_vtable<
429 B: TraitVtablePrototype + crate::registry::RegistryDef<TraitType>,
430 const LEN: usize,
431>(
432 overrides: &[(&'static str, &'static NativeFunction)],
433) -> [&'static NativeFunction; LEN] {
434 let mut methods = [&crate::native_function::VTABLE_DEFAULT; LEN];
435 let mut i = 0;
436 while i < LEN {
437 if let Some(default) = B::DEFAULTS[i] {
438 methods[i] = default;
439 }
440 i += 1;
441 }
442 let mut i = 0;
444 while i < overrides.len() {
445 let (name, f) = overrides[i];
446 methods[index_of_method_name(
447 <B as crate::registry::RegistryDef<TraitType>>::DEF.methods,
448 name,
449 )] = f;
450 i += 1;
451 }
452 methods
453}
454
#[cfg(test)]
mod tests {
    // Checks that each `#[turbo_tasks::value(...)]` option combination produces the
    // expected `ValueTypePersistence` / `Evictability` pair on the registered
    // `ValueType`, and that `has_serialization()` is true only for persistable types.
    use super::{Evictability, ValueTypePersistence};
    use crate::{self as turbo_tasks, VcValueType, registry};

    // `serialization = "skip"` with default eviction.
    #[turbo_tasks::value(serialization = "skip")]
    struct SkipValue(#[turbo_tasks(trace_ignore)] u32);

    // `serialization = "hash"` with default eviction.
    #[turbo_tasks::value(serialization = "hash")]
    struct HashValue(u32);

    // Not serialized, with `evict = "last"`.
    #[turbo_tasks::value(serialization = "skip", evict = "last")]
    struct SkipExpensiveValue(#[turbo_tasks(trace_ignore)] u32);

    // Not serialized, with `evict = "never"`.
    #[turbo_tasks::value(serialization = "skip", evict = "never", cell = "new", eq = "manual")]
    struct SessionStatefulValue;

    // Default (auto) serialization with default eviction.
    #[turbo_tasks::value]
    struct PersistableValue(u32);

    // Default (auto) serialization with `evict = "never"`.
    #[turbo_tasks::value(evict = "never")]
    struct PersistableNeverValue(u32);

    #[test]
    fn skip_maps_to_skip_always() {
        let vt = registry::get_value_type(SkipValue::get_value_type_id());
        assert!(
            matches!(vt.persistence, ValueTypePersistence::Skip),
            "`serialization = \"skip\"` must map to ValueTypePersistence::Skip"
        );
        assert!(matches!(vt.evictability, Evictability::Always));
        assert!(!SkipValue::has_serialization());
    }

    #[test]
    fn hash_maps_to_hash_only_always() {
        let vt = registry::get_value_type(HashValue::get_value_type_id());
        assert!(
            matches!(vt.persistence, ValueTypePersistence::HashOnly),
            "`serialization = \"hash\"` must map to ValueTypePersistence::HashOnly"
        );
        assert!(matches!(vt.evictability, Evictability::Always));
        assert!(!HashValue::has_serialization());
    }

    #[test]
    fn skip_expensive_maps_to_skip_expensive() {
        let vt = registry::get_value_type(SkipExpensiveValue::get_value_type_id());
        assert!(matches!(vt.persistence, ValueTypePersistence::Skip));
        assert!(
            matches!(vt.evictability, Evictability::Expensive),
            "`serialization = \"skip\", evict = \"last\"` must map to Evictability::Expensive"
        );
        assert!(!SkipExpensiveValue::has_serialization());
    }

    #[test]
    fn session_stateful_maps_to_skip_never() {
        let vt = registry::get_value_type(SessionStatefulValue::get_value_type_id());
        assert!(matches!(vt.persistence, ValueTypePersistence::Skip));
        assert!(
            matches!(vt.evictability, Evictability::Never),
            "`serialization = \"skip\", evict = \"never\"` must map to Evictability::Never"
        );
        assert!(!SessionStatefulValue::has_serialization());
    }

    #[test]
    fn default_maps_to_persistable_always() {
        let vt = registry::get_value_type(PersistableValue::get_value_type_id());
        assert!(
            matches!(vt.persistence, ValueTypePersistence::Persistable(_, _)),
            "default (auto) serialization must map to ValueTypePersistence::Persistable"
        );
        assert!(matches!(vt.evictability, Evictability::Always));
        assert!(PersistableValue::has_serialization());
    }

    #[test]
    fn persistable_never_maps_to_persistable_never() {
        let vt = registry::get_value_type(PersistableNeverValue::get_value_type_id());
        assert!(
            matches!(vt.persistence, ValueTypePersistence::Persistable(_, _)),
            "`evict = \"never\"` (default serialization) must keep \
             ValueTypePersistence::Persistable"
        );
        assert!(
            matches!(vt.evictability, Evictability::Never),
            "`evict = \"never\"` must map to Evictability::Never"
        );
        assert!(PersistableNeverValue::has_serialization());
    }
}