turbo_tasks/
id.rs

1use std::{
2    fmt::{Debug, Display},
3    mem::transmute_copy,
4    num::{NonZero, NonZeroU64, TryFromIntError},
5    ops::Deref,
6};
7
8use serde::{Deserialize, Serialize, de::Visitor};
9
10use crate::{
11    TaskPersistence, registry,
12    trace::{TraceRawVcs, TraceRawVcsContext},
13};
14
/// Defines a `Copy` newtype identifier that wraps a `NonZero<$primitive>`.
///
/// Syntax: `define_id!(Name: primitive [, derive(..)] [, serde(..)] [, doc = ".."]* [,])`.
///
/// For every id type this generates:
/// - the `MIN` / `MAX` constants and the `new`, `new_unchecked` and `to_non_zero_u64` helpers,
/// - `Display` (rendered as `Name <id>`) and `Deref<Target = primitive>`,
/// - `From<NonZero<primitive>>` and `From<Name> for NonZeroU64`,
/// - `TryFrom<primitive>` and `TryFrom<NonZeroU64>` for every primitive other than `u64`,
/// - an empty `TraceRawVcs` impl.
///
/// The primitive is captured as an `ident` rather than a `ty` on purpose: a `ty` fragment becomes
/// opaque once it is forwarded to another macro invocation and can then never match the literal
/// `u64` token of the `@impl_try_from_primitive_conversion` arm below. With a `ty` capture a `u64`
/// id would fall through to the generic arm and emit a `TryFrom<NonZeroU64>` impl that conflicts
/// with the blanket impl derived from `From<NonZero<u64>>`.
macro_rules! define_id {
    (
        $name:ident : $primitive:ident
        $(,derive($($derive:ty),*))?
        $(,serde($serde:tt))?
        $(,doc = $doc:literal)*
        $(,)?
    ) => {
        $(#[doc = $doc])*
        #[derive(Hash, Clone, Copy, PartialEq, Eq, PartialOrd, Ord $($(,$derive)*)? )]
        $(#[serde($serde)])?
        pub struct $name {
            id: NonZero<$primitive>,
        }

        impl $name {
            /// The smallest representable identifier.
            pub const MIN: Self = Self { id: NonZero::<$primitive>::MIN };
            /// The largest representable identifier.
            pub const MAX: Self = Self { id: NonZero::<$primitive>::MAX };

            /// Constructs a wrapper type from the numeric identifier.
            ///
            /// # Safety
            ///
            /// The passed `id` must not be zero.
            pub const unsafe fn new_unchecked(id: $primitive) -> Self {
                // SAFETY: the caller guarantees that `id` is non-zero.
                Self { id: unsafe { NonZero::<$primitive>::new_unchecked(id) } }
            }

            /// Constructs a wrapper type from the numeric identifier.
            ///
            /// Returns `None` if the provided `id` is zero, otherwise returns
            /// `Some(Self)` containing the wrapped non-zero identifier.
            pub const fn new(id: $primitive) -> Option<Self> {
                // A `match` is used instead of `Option::map` so that this can be a `const fn`.
                match NonZero::<$primitive>::new(id) {
                    Some(id) => Some(Self { id }),
                    None => None,
                }
            }

            /// Allows `const` conversion to a [`NonZeroU64`], useful with
            /// [`crate::id_factory::IdFactory::new_const`].
            pub const fn to_non_zero_u64(self) -> NonZeroU64 {
                // Compile-time guard: the cast below must never truncate.
                const {
                    assert!(<$primitive>::BITS <= u64::BITS);
                }
                // SAFETY: `self.id` is non-zero and, per the assertion above, the cast to `u64`
                // does not drop any bits, so the result is non-zero as well.
                unsafe { NonZeroU64::new_unchecked(self.id.get() as u64) }
            }
        }

        impl Display for $name {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                write!(f, concat!(stringify!($name), " {}"), self.id)
            }
        }

        impl Deref for $name {
            type Target = $primitive;

            fn deref(&self) -> &Self::Target {
                // SAFETY: `NonZero<T>` is guaranteed to have the same layout and bit validity as
                // `T` (except for the all-zero pattern), so the `&NonZero<T>` read out of the
                // outer reference can be reinterpreted as a `&T` with the same lifetime.
                unsafe { transmute_copy(&&self.id) }
            }
        }

        define_id!(@impl_try_from_primitive_conversion $name $primitive);

        impl From<NonZero<$primitive>> for $name {
            fn from(id: NonZero<$primitive>) -> Self {
                Self { id }
            }
        }

        impl From<$name> for NonZeroU64 {
            fn from(id: $name) -> Self {
                id.to_non_zero_u64()
            }
        }

        impl TraceRawVcs for $name {
            // Ids carry no `RawVc`s, so there is nothing to trace.
            fn trace_raw_vcs(&self, _trace_context: &mut TraceRawVcsContext) {}
        }
    };
    (
        @impl_try_from_primitive_conversion $name:ident u64
    ) => {
        // we get a `TryFrom` blanket impl for free via the `From` impl
    };
    (
        @impl_try_from_primitive_conversion $name:ident $primitive:ty
    ) => {
        impl TryFrom<$primitive> for $name {
            type Error = TryFromIntError;

            /// Fails when `id` is zero.
            fn try_from(id: $primitive) -> Result<Self, Self::Error> {
                Ok(Self {
                    id: NonZero::try_from(id)?,
                })
            }
        }

        impl TryFrom<NonZeroU64> for $name {
            type Error = TryFromIntError;

            /// Fails when `id` does not fit into the underlying primitive.
            fn try_from(id: NonZeroU64) -> Result<Self, Self::Error> {
                Ok(Self {
                    id: NonZero::try_from(id)?,
                })
            }
        }
    };
}
120
121define_id!(TaskId: u32, derive(Serialize, Deserialize), serde(transparent));
122define_id!(FunctionId: u16);
123define_id!(ValueTypeId: u16);
124define_id!(TraitTypeId: u16);
125define_id!(SessionId: u32, derive(Debug, Serialize, Deserialize), serde(transparent));
126define_id!(
127    LocalTaskId: u32,
128    derive(Debug, Serialize, Deserialize),
129    serde(transparent),
130    doc = "Represents the nth `local` function call inside a task.",
131);
132define_id!(
133    ExecutionId: u16,
134    derive(Debug, Serialize, Deserialize),
135    serde(transparent),
136    doc = "An identifier for a specific task execution. Used to assert that local `Vc`s don't \
137        leak. This value may overflow and re-use old values.",
138);
139
140impl Debug for TaskId {
141    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
142        f.debug_struct("TaskId").field("id", &self.id).finish()
143    }
144}
145
/// Bit mask (the most significant bit of a `u32`) that [`TaskId::is_transient`] tests to decide
/// whether a task id is transient.
pub const TRANSIENT_TASK_BIT: u32 = 1 << 31;
147
148impl TaskId {
149    pub fn is_transient(&self) -> bool {
150        **self & TRANSIENT_TASK_BIT != 0
151    }
152    pub fn persistence(&self) -> TaskPersistence {
153        // tasks with `TaskPersistence::LocalCells` have no `TaskId`, so we can ignore that case
154        if self.is_transient() {
155            TaskPersistence::Transient
156        } else {
157            TaskPersistence::Persistent
158        }
159    }
160}
161
/// Implements `Serialize`, `Deserialize` and `Debug` for a `u16`-backed id type.
///
/// Arguments:
/// - `$ty`: the id type (must expose a `new(u16) -> Option<Self>` constructor and an `id` field),
/// - `$get_object`: a function called with the id; its result is printed as the `name` field of
///   the `Debug` output,
/// - `$validate_type_id`: a function called with a freshly deserialized id; returning
///   `Some(error)` rejects the id with that error as a custom deserialization error,
/// - `$visitor_name`: the name of the serde visitor struct that gets generated.
macro_rules! make_serializable {
    ($ty:ty, $get_object:path, $validate_type_id:path, $visitor_name:ident) => {
        impl Serialize for $ty {
            fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
            where
                S: serde::Serializer,
            {
                serializer.serialize_u16(self.id.get())
            }
        }

        impl<'de> Deserialize<'de> for $ty {
            fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
            where
                D: serde::Deserializer<'de>,
            {
                deserializer.deserialize_u16($visitor_name)
            }
        }

        struct $visitor_name;

        impl<'de> Visitor<'de> for $visitor_name {
            type Value = $ty;

            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
                formatter.write_str(concat!("an id of a registered ", stringify!($ty)))
            }

            fn visit_u16<E>(self, v: u16) -> Result<Self::Value, E>
            where
                E: serde::de::Error,
            {
                // Zero is the only `u16` that cannot be turned into an id.
                let Some(value) = Self::Value::new(v) else {
                    return Err(E::unknown_variant(&v.to_string(), &["a non zero u16"]));
                };
                // Reject ids the validation hook does not accept.
                match $validate_type_id(value) {
                    Some(error) => Err(E::custom(error)),
                    None => Ok(value),
                }
            }

            // `deserialize_u16` is only a hint: a deserializer may report unsigned integers
            // through `visit_u64` instead (the other unsigned `visit_*` defaults forward here as
            // well). Accept those as long as the value fits into a `u16`.
            fn visit_u64<E>(self, v: u64) -> Result<Self::Value, E>
            where
                E: serde::de::Error,
            {
                match u16::try_from(v) {
                    Ok(v) => self.visit_u16(v),
                    Err(_) => Err(E::invalid_value(
                        serde::de::Unexpected::Unsigned(v),
                        &self,
                    )),
                }
            }
        }

        impl Debug for $ty {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                f.debug_struct(stringify!($ty))
                    .field("id", &self.id)
                    .field("name", &$get_object(*self))
                    .finish()
            }
        }
    };
}
218
219make_serializable!(
220    ValueTypeId,
221    registry::get_value_type,
222    registry::validate_value_type_id,
223    ValueTypeVisitor
224);
225make_serializable!(
226    TraitTypeId,
227    registry::get_trait,
228    registry::validate_trait_type_id,
229    TraitTypeVisitor
230);
231make_serializable!(
232    FunctionId,
233    registry::get_native_function,
234    registry::validate_function_id,
235    FunctionTypeVisitor
236);