Skip to main content

turbo_tasks/
id.rs

1use std::{
2    fmt::{Debug, Display},
3    mem::transmute_copy,
4    num::{NonZero, NonZeroU64, TryFromIntError},
5    ops::Deref,
6};
7
8use bincode::{
9    Decode, Encode,
10    de::Decoder,
11    enc::Encoder,
12    error::{DecodeError, EncodeError},
13    impl_borrow_decode,
14};
15use serde::{Deserialize, Serialize, de::Visitor};
16
17use crate::{
18    TaskPersistence, registry,
19    trace::{TraceRawVcs, TraceRawVcsContext},
20};
21
/// Defines a `Copy` id newtype that wraps a `NonZero<$primitive>`.
///
/// Invocation: `define_id!(Name: primitive, derive(..), serde(..), doc = "..")`. Everything after
/// `Name: primitive` is optional, but the parts must appear in that order; `doc` may repeat.
///
/// `$primitive` is deliberately captured as an `ident` rather than a `ty`: a fragment captured as
/// `ty` becomes opaque when forwarded to another macro arm and can then never match the literal
/// `u64` that the `@impl_non_zero_u64_conversion` arms dispatch on.
macro_rules! define_id {
    (
        $name:ident : $primitive:ident
        $(,derive($($derive:ty),*))?
        $(,serde($serde:tt))?
        $(,doc = $doc:literal)*
        $(,)?
    ) => {
        $(#[doc = $doc])*
        #[derive(Hash, Clone, Copy, PartialEq, Eq, PartialOrd, Ord $($(,$derive)*)? )]
        $(#[serde($serde)])?
        pub struct $name {
            id: NonZero<$primitive>,
        }

        impl $name {
            /// The smallest representable id.
            pub const MIN: Self = Self { id: NonZero::<$primitive>::MIN };
            /// The largest representable id.
            pub const MAX: Self = Self { id: NonZero::<$primitive>::MAX };

            /// Constructs a wrapper type from the numeric identifier.
            ///
            /// # Safety
            ///
            /// The passed `id` must not be zero.
            pub const unsafe fn new_unchecked(id: $primitive) -> Self {
                // SAFETY: the caller guarantees that `id` is non-zero.
                Self { id: unsafe { NonZero::<$primitive>::new_unchecked(id) } }
            }

            /// Constructs a wrapper type from the numeric identifier.
            ///
            /// Returns `None` if the provided `id` is zero, otherwise returns
            /// `Some(Self)` containing the wrapped non-zero identifier.
            pub fn new(id: $primitive) -> Option<Self> {
                NonZero::<$primitive>::new(id).map(|id| Self { id })
            }

            /// Allows `const` conversion to a [`NonZeroU64`], useful with
            /// [`crate::id_factory::IdFactory::new_const`].
            pub const fn to_non_zero_u64(self) -> NonZeroU64 {
                const {
                    assert!(<$primitive>::BITS <= u64::BITS);
                }
                // SAFETY: `self.id` is non-zero and the assertion above guarantees that widening
                // it to `u64` is lossless, so the result is non-zero too.
                unsafe { NonZeroU64::new_unchecked(self.id.get() as u64) }
            }
        }

        /// Formats as the type name followed by the raw id, e.g. `TaskId 5`.
        impl Display for $name {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                write!(f, concat!(stringify!($name), " {}"), self.id)
            }
        }

        /// Exposes the raw (always non-zero) integer value.
        impl Deref for $name {
            type Target = $primitive;

            fn deref(&self) -> &Self::Target {
                let id: &NonZero<$primitive> = &self.id;
                // SAFETY: `NonZero<T>` is guaranteed to have the same layout and bit validity as
                // `T` (except that `0` is invalid), so a `&NonZero<T>` is a valid `&T`.
                unsafe { transmute_copy::<&NonZero<$primitive>, &$primitive>(&id) }
            }
        }

        impl TryFrom<$primitive> for $name {
            type Error = TryFromIntError;

            /// Fails if `id` is zero.
            fn try_from(id: $primitive) -> Result<Self, Self::Error> {
                Ok(Self { id: NonZero::try_from(id)? })
            }
        }

        define_id!(@impl_non_zero_u64_conversion $name $primitive);

        impl From<NonZero<$primitive>> for $name {
            fn from(id: NonZero<$primitive>) -> Self {
                Self { id }
            }
        }

        impl From<$name> for NonZeroU64 {
            fn from(id: $name) -> Self {
                id.to_non_zero_u64()
            }
        }

        // The only field is a `NonZero` integer, so there is nothing to trace.
        impl TraceRawVcs for $name {
            fn trace_raw_vcs(&self, _trace_context: &mut TraceRawVcsContext) {}
        }
    };
    (
        @impl_non_zero_u64_conversion $name:ident u64
    ) => {
        // `TryFrom<NonZeroU64>` already comes from the standard library's blanket
        // `impl<T, U: Into<T>> TryFrom<U> for T` through `From<NonZero<u64>>`; writing it by
        // hand would conflict with that blanket impl.
    };
    (
        @impl_non_zero_u64_conversion $name:ident $primitive:ident
    ) => {
        impl TryFrom<NonZeroU64> for $name {
            type Error = TryFromIntError;

            /// Fails if `id` does not fit into the wrapped integer type.
            fn try_from(id: NonZeroU64) -> Result<Self, Self::Error> {
                Ok(Self { id: NonZero::try_from(id)? })
            }
        }
    };
}
128
129define_id!(TaskId: u32, derive(Serialize, Deserialize, Encode, Decode), serde(transparent));
130define_id!(FunctionId: u16);
131define_id!(ValueTypeId: u16);
132define_id!(TraitTypeId: u16);
133define_id!(
134    LocalTaskId: u32,
135    derive(Debug, Serialize, Deserialize, Encode, Decode),
136    serde(transparent),
137    doc = "Represents the nth `local` function call inside a task.",
138);
139define_id!(
140    ExecutionId: u16,
141    derive(Debug, Serialize, Deserialize, Encode, Decode),
142    serde(transparent),
143    doc = "An identifier for a specific task execution. Used to assert that local `Vc`s don't \
144        leak. This value may overflow and re-use old values.",
145);
146
147impl Debug for TaskId {
148    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
149        f.debug_struct("TaskId").field("id", &self.id).finish()
150    }
151}
152
153unsafe impl crate::NonLocalValue for TaskId {}
154
155pub const TRANSIENT_TASK_BIT: u32 = 0x8000_0000;
156
157impl TaskId {
158    pub fn is_transient(&self) -> bool {
159        **self & TRANSIENT_TASK_BIT != 0
160    }
161    pub fn persistence(&self) -> TaskPersistence {
162        // tasks with `TaskPersistence::LocalCells` have no `TaskId`, so we can ignore that case
163        if self.is_transient() {
164            TaskPersistence::Transient
165        } else {
166            TaskPersistence::Persistent
167        }
168    }
169}
170
/// Implements serde `Serialize`/`Deserialize`, bincode `Encode`/`Decode`/`BorrowDecode`, and
/// `Debug` for an id type (created by `define_id!`) that refers to a registry entry.
///
/// * `$ty`: the id type; its `id: NonZero<$primitive>` field is accessed directly.
/// * `$primitive`: the id's integer type. The serde impls use `serialize_u16`/`deserialize_u16`,
///   so they only support `u16` (every invocation in this file passes `u16`).
/// * `$get_object`: looks up the registry entry for an id; its result is printed by `Debug`.
/// * `$validate_type_id`: returns `Some(error)` for an invalid id. Both the serde and the bincode
///   decoding paths call it and turn `Some(error)` into a decoding error.
macro_rules! make_registered_serializable {
    ($ty:ty, $primitive:ty, $get_object:path, $validate_type_id:path $(,)?) => {
        impl Serialize for $ty {
            fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
            where
                S: serde::Serializer,
            {
                serializer.serialize_u16(self.id.get())
            }
        }

        impl<'de> Deserialize<'de> for $ty {
            fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
            where
                D: serde::Deserializer<'de>,
            {
                struct DeserializeVisitor;

                impl<'de> Visitor<'de> for DeserializeVisitor {
                    type Value = $ty;

                    fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
                        formatter.write_str(concat!("an id of a registered ", stringify!($ty)))
                    }

                    // Unsigned input is handled in `visit_u64`: serde's default
                    // `visit_u8`/`visit_u16`/`visit_u32` forward here, and self-describing
                    // formats (e.g. JSON) call `visit_u64` even when `deserialize_u16` was
                    // requested. Implementing only `visit_u16` would reject ids from such formats.
                    fn visit_u64<E>(self, v: u64) -> Result<Self::Value, E>
                    where
                        E: serde::de::Error,
                    {
                        // Rejects zero as well as values that don't fit into the id's integer type.
                        let Some(value) = <$primitive>::try_from(v).ok().and_then(<$ty>::new)
                        else {
                            return Err(E::invalid_value(
                                serde::de::Unexpected::Unsigned(v),
                                &self,
                            ));
                        };
                        match $validate_type_id(value) {
                            Some(error) => Err(E::custom(error)),
                            None => Ok(value),
                        }
                    }

                    // Formats that only have signed integers (e.g. TOML) call `visit_i64`;
                    // accept non-negative values and reject the rest.
                    fn visit_i64<E>(self, v: i64) -> Result<Self::Value, E>
                    where
                        E: serde::de::Error,
                    {
                        match u64::try_from(v) {
                            Ok(unsigned) => self.visit_u64(unsigned),
                            Err(_) => Err(E::invalid_value(
                                serde::de::Unexpected::Signed(v),
                                &self,
                            )),
                        }
                    }
                }

                deserializer.deserialize_u16(DeserializeVisitor)
            }
        }

        /// Prints the raw id together with the registry entry returned by `$get_object`.
        impl Debug for $ty {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                f.debug_struct(stringify!($ty))
                    .field("id", &self.id)
                    .field("name", &$get_object(*self))
                    .finish()
            }
        }

        impl Encode for $ty {
            fn encode<E: Encoder>(&self, encoder: &mut E) -> Result<(), EncodeError> {
                <NonZero<$primitive> as Encode>::encode(&self.id, encoder)
            }
        }

        impl<Context> Decode<Context> for $ty {
            fn decode<D: Decoder<Context = Context>>(decoder: &mut D) -> Result<Self, DecodeError> {
                let value = Self {
                    id: NonZero::<$primitive>::decode(decoder)?,
                };
                // Reject ids that the registry does not consider valid.
                match $validate_type_id(value) {
                    Some(error) => Err(DecodeError::OtherString(error.to_string())),
                    None => Ok(value),
                }
            }
        }

        impl_borrow_decode!($ty);
    };
}
247
248make_registered_serializable!(
249    ValueTypeId,
250    u16,
251    registry::get_value_type,
252    registry::validate_value_type_id,
253);
254make_registered_serializable!(
255    TraitTypeId,
256    u16,
257    registry::get_trait,
258    registry::validate_trait_type_id,
259);
260make_registered_serializable!(
261    FunctionId,
262    u16,
263    registry::get_native_function,
264    registry::validate_function_id,
265);