turbo_tasks/
id.rs

1use std::{
2    fmt::{Debug, Display},
3    mem::transmute_copy,
4    num::{NonZero, NonZeroU64, TryFromIntError},
5    ops::Deref,
6};
7
8use serde::{Deserialize, Serialize, de::Visitor};
9
10use crate::{
11    TaskPersistence, registry,
12    trace::{TraceRawVcs, TraceRawVcsContext},
13};
14
/// Defines a `Copy` newtype id wrapping a `NonZero<$primitive>`.
///
/// Syntax: `define_id!(Name: primitive[, derive(..)][, serde(..)][, doc = ".."]*[,])`.
///
/// The generated type gets `MIN`/`MAX` constants, an unsafe unchecked constructor, a `const`
/// widening to [`NonZeroU64`], `Display` (`"Name <id>"`), `Deref` to the primitive,
/// `From<NonZero<primitive>>`, `TryFrom<primitive>`, `TryFrom<NonZeroU64>` and a no-op
/// `TraceRawVcs` implementation.
///
/// `$primitive` is captured as an `ident`, not a `ty`. A `ty` fragment forwarded to the nested
/// `@impl_try_from_primitive_conversion` call is opaque and can never match that arm's literal
/// `u64` token. A `u64`-backed id would then receive a hand-written `TryFrom<NonZeroU64>` impl
/// that conflicts with the blanket one. `ident` fragments, by contrast, can be matched against
/// literal tokens.
macro_rules! define_id {
    (
        $name:ident : $primitive:ident
        $(,derive($($derive:ty),*))?
        $(,serde($serde:tt))?
        $(,doc = $doc:literal)*
        $(,)?
    ) => {
        $(#[doc = $doc])*
        #[derive(Hash, Clone, Copy, PartialEq, Eq, PartialOrd, Ord $($(,$derive)*)? )]
        $(#[serde($serde)])?
        pub struct $name {
            id: NonZero<$primitive>,
        }

        impl $name {
            /// The smallest id representable by the backing `NonZero` type.
            pub const MIN: Self = Self { id: NonZero::<$primitive>::MIN };
            /// The largest id representable by the backing `NonZero` type.
            pub const MAX: Self = Self { id: NonZero::<$primitive>::MAX };

            /// Constructs a wrapper type from the numeric identifier.
            ///
            /// # Safety
            ///
            /// The passed `id` must not be zero.
            pub const unsafe fn new_unchecked(id: $primitive) -> Self {
                // SAFETY: the caller guarantees `id != 0`.
                Self { id: unsafe { NonZero::<$primitive>::new_unchecked(id) } }
            }

            /// Allows `const` conversion to a [`NonZeroU64`], useful with
            /// [`crate::id_factory::IdFactory::new_const`].
            pub const fn to_non_zero_u64(self) -> NonZeroU64 {
                // Reject, at compile time, primitives wider than `u64`.
                const {
                    assert!(<$primitive>::BITS <= u64::BITS);
                }
                // SAFETY: `self.id` is non-zero, and a widening `as` cast (guaranteed by the
                // assertion above) never maps a non-zero value to zero.
                unsafe { NonZeroU64::new_unchecked(self.id.get() as u64) }
            }
        }

        impl Display for $name {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                write!(f, concat!(stringify!($name), " {}"), self.id)
            }
        }

        impl Deref for $name {
            type Target = $primitive;

            fn deref(&self) -> &Self::Target {
                // SAFETY: `NonZero<T>` is documented to have the same layout as `T`, so a
                // `&NonZero<T>` may be reinterpreted as a `&T`; the result is only read through.
                unsafe { transmute_copy(&&self.id) }
            }
        }

        define_id!(@impl_try_from_primitive_conversion $name $primitive);

        impl From<NonZero<$primitive>> for $name {
            fn from(id: NonZero::<$primitive>) -> Self {
                Self {
                    id,
                }
            }
        }

        impl From<$name> for NonZeroU64 {
            fn from(id: $name) -> Self {
                id.to_non_zero_u64()
            }
        }

        impl TraceRawVcs for $name {
            fn trace_raw_vcs(&self, _trace_context: &mut TraceRawVcsContext) {}
        }
    };
    (
        @impl_try_from_primitive_conversion $name:ident u64
    ) => {
        // `TryFrom<NonZeroU64>` comes for free from the blanket `TryFrom<U> for T where
        // U: Into<T>` impl via `From<NonZero<u64>>`, so only the primitive conversion is needed.
        impl TryFrom<u64> for $name {
            type Error = TryFromIntError;

            fn try_from(id: u64) -> Result<Self, Self::Error> {
                Ok(Self { id: NonZero::try_from(id)? })
            }
        }
    };
    (
        @impl_try_from_primitive_conversion $name:ident $primitive:ident
    ) => {
        impl TryFrom<$primitive> for $name {
            type Error = TryFromIntError;

            fn try_from(id: $primitive) -> Result<Self, Self::Error> {
                Ok(Self {
                    id: NonZero::try_from(id)?
                })
            }
        }

        impl TryFrom<NonZeroU64> for $name {
            type Error = TryFromIntError;

            fn try_from(id: NonZeroU64) -> Result<Self, Self::Error> {
                Ok(Self { id: NonZero::try_from(id)? })
            }
        }
    };
}
114
115define_id!(TaskId: u32, derive(Serialize, Deserialize), serde(transparent));
116define_id!(FunctionId: u32);
117define_id!(ValueTypeId: u32);
118define_id!(TraitTypeId: u32);
119define_id!(BackendJobId: u32);
120define_id!(SessionId: u32, derive(Debug, Serialize, Deserialize), serde(transparent));
121define_id!(
122    LocalTaskId: u32,
123    derive(Debug, Serialize, Deserialize),
124    serde(transparent),
125    doc = "Represents the nth `local` function call inside a task.",
126);
127define_id!(
128    ExecutionId: u16,
129    derive(Debug, Serialize, Deserialize),
130    serde(transparent),
131    doc = "An identifier for a specific task execution. Used to assert that local `Vc`s don't \
132        leak. This value may overflow and re-use old values.",
133);
134
135impl Debug for TaskId {
136    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
137        f.debug_struct("TaskId").field("id", &self.id).finish()
138    }
139}
140
141pub const TRANSIENT_TASK_BIT: u32 = 0x8000_0000;
142
143impl TaskId {
144    pub fn is_transient(&self) -> bool {
145        **self & TRANSIENT_TASK_BIT != 0
146    }
147    pub fn persistence(&self) -> TaskPersistence {
148        // tasks with `TaskPersistence::LocalCells` have no `TaskId`, so we can ignore that case
149        if self.is_transient() {
150            TaskPersistence::Transient
151        } else {
152            TaskPersistence::Persistent
153        }
154    }
155}
156
/// Implements `Serialize`, `Deserialize` and `Debug` for a registry-backed id type `$ty`.
///
/// Values are serialized as their registered global name rather than the numeric id.
///
/// * `$ty`: the id type; it must be `Copy`, because it is passed by value via `*self`.
/// * `$get_global_name`: called as `$get_global_name(id)`. It must return the id's global name
///   in a form usable as a `&str`. It is used by `serialize` and for the `name` field of the
///   `Debug` output.
/// * `$get_id`: called as `$get_id(name)` and must return `Option<$ty>`. `None` makes
///   deserialization fail.
/// * `$visitor_name`: name of the private serde `Visitor` type generated for `$ty`.
macro_rules! make_serializable {
    ($ty:ty, $get_global_name:path, $get_id:path, $visitor_name:ident) => {
        impl Serialize for $ty {
            fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
            where
                S: serde::Serializer,
            {
                serializer.serialize_str($get_global_name(*self))
            }
        }

        impl<'de> Deserialize<'de> for $ty {
            fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
            where
                D: serde::Deserializer<'de>,
            {
                deserializer.deserialize_str($visitor_name)
            }
        }

        struct $visitor_name;

        impl<'de> Visitor<'de> for $visitor_name {
            type Value = $ty;

            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
                formatter.write_str(concat!("a name of a registered ", stringify!($ty)))
            }

            fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
            where
                E: serde::de::Error,
            {
                // An unregistered name is an invalid string value, not an enum variant.
                // `invalid_value` reports "invalid value: string \"..\", expected a name of a
                // registered .." (via `expecting`). `unknown_variant` with an empty list
                // instead claimed "unknown variant `..`, there are no variants".
                $get_id(v).ok_or_else(|| E::invalid_value(serde::de::Unexpected::Str(v), &self))
            }
        }

        impl Debug for $ty {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                f.debug_struct(stringify!($ty))
                    .field("id", &self.id)
                    .field("name", &$get_global_name(*self))
                    .finish()
            }
        }
    };
}
204
205make_serializable!(
206    FunctionId,
207    registry::get_function_global_name,
208    registry::get_function_id_by_global_name,
209    FunctionIdVisitor
210);
211make_serializable!(
212    ValueTypeId,
213    registry::get_value_type_global_name,
214    registry::get_value_type_id_by_global_name,
215    ValueTypeVisitor
216);
217make_serializable!(
218    TraitTypeId,
219    registry::get_trait_type_global_name,
220    registry::get_trait_type_id_by_global_name,
221    TraitTypeVisitor
222);