1use std::{
2 fmt::{Debug, Display},
3 mem::transmute_copy,
4 num::{NonZero, NonZeroU64, TryFromIntError},
5 ops::Deref,
6};
7
8use serde::{Deserialize, Serialize, de::Visitor};
9
10use crate::{
11 TaskPersistence, registry,
12 trace::{TraceRawVcs, TraceRawVcsContext},
13};
14
/// Defines a `Copy` identifier newtype backed by `NonZero<$primitive>`.
///
/// Syntax:
/// `define_id!(Name: primitive [, derive(Trait, ...)] [, serde(attr)] [, doc = "..."]* [,])`
///
/// Every generated type gets `MIN`/`MAX`, `new_unchecked`, `to_non_zero_u64`, `Display`,
/// `Deref<Target = primitive>`, `From<NonZero<primitive>>`, `From<Name> for NonZeroU64`,
/// `TryFrom` conversions (see the `@impl_try_from_*` arms) and a no-op `TraceRawVcs` impl.
macro_rules! define_id {
    (
        // `$primitive` is captured as an `ident`, not a `ty`: it is forwarded to the
        // `@impl_try_from_primitive_conversion` arms below, and a forwarded `ty` fragment is
        // opaque to the next matcher, so it could never match the literal `u64` arm.
        $name:ident : $primitive:ident
        $(,derive($($derive:ty),*))?
        $(,serde($serde:tt))?
        $(,doc = $doc:literal)*
        $(,)?
    ) => {
        $(#[doc = $doc])*
        #[derive(Hash, Clone, Copy, PartialEq, Eq, PartialOrd, Ord $($(,$derive)*)? )]
        $(#[serde($serde)])?
        pub struct $name {
            id: NonZero<$primitive>,
        }

        impl $name {
            /// The smallest representable id (`1`).
            pub const MIN: Self = Self { id: NonZero::<$primitive>::MIN };
            /// The largest representable id (the primitive's `MAX`).
            pub const MAX: Self = Self { id: NonZero::<$primitive>::MAX };

            /// Creates an id from a raw value without checking that it is non-zero.
            ///
            /// # Safety
            ///
            /// `id` must not be `0`.
            pub const unsafe fn new_unchecked(id: $primitive) -> Self {
                // SAFETY: the caller guarantees `id != 0`.
                Self { id: unsafe { NonZero::<$primitive>::new_unchecked(id) } }
            }

            /// Widens this id to a [`NonZeroU64`].
            pub const fn to_non_zero_u64(self) -> NonZeroU64 {
                const {
                    assert!(<$primitive>::BITS <= u64::BITS);
                }
                // SAFETY: `self.id` is non-zero, and the cast is a lossless widening (enforced by
                // the const assertion above), so the result is non-zero as well.
                unsafe { NonZeroU64::new_unchecked(self.id.get() as u64) }
            }
        }

        impl Display for $name {
            /// Formats as `<Name> <id>`, e.g. `TaskId 5`.
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                write!(f, concat!(stringify!($name), " {}"), self.id)
            }
        }

        impl Deref for $name {
            type Target = $primitive;

            fn deref(&self) -> &Self::Target {
                // SAFETY: `NonZero<T>` is guaranteed to have the same layout and bit validity as
                // `T` (minus the zero value), so a `&NonZero<T>` can be reinterpreted as a `&T`
                // to the same memory for the same lifetime.
                unsafe { transmute_copy::<&NonZero<$primitive>, &$primitive>(&&self.id) }
            }
        }

        define_id!(@impl_try_from_primitive_conversion $name $primitive);

        impl From<NonZero<$primitive>> for $name {
            fn from(id: NonZero::<$primitive>) -> Self {
                Self {
                    id,
                }
            }
        }

        impl From<$name> for NonZeroU64 {
            fn from(id: $name) -> Self {
                id.to_non_zero_u64()
            }
        }

        impl TraceRawVcs for $name {
            // Ids hold no `Vc`s, so there is nothing to trace.
            fn trace_raw_vcs(&self, _trace_context: &mut TraceRawVcsContext) {}
        }
    };
    // `u64`-backed ids must not get a hand-written `TryFrom<NonZeroU64>`: the `From<NonZero<u64>>`
    // impl above already yields one through core's blanket `impl<T, U: Into<T>> TryFrom<U> for T`,
    // and a second impl would conflict with it.
    (
        @impl_try_from_primitive_conversion $name:ident u64
    ) => {
        define_id!(@impl_try_from_raw $name u64);
    };
    // Narrower ids: fallible conversions from both the raw primitive and `NonZeroU64`.
    (
        @impl_try_from_primitive_conversion $name:ident $primitive:ident
    ) => {
        define_id!(@impl_try_from_raw $name $primitive);

        impl TryFrom<NonZeroU64> for $name {
            type Error = TryFromIntError;

            /// Fails if `id` does not fit in the backing primitive.
            fn try_from(id: NonZeroU64) -> Result<Self, Self::Error> {
                Ok(Self { id: NonZero::try_from(id)? })
            }
        }
    };
    // Fallible conversion from the raw backing primitive; fails on `0`.
    (
        @impl_try_from_raw $name:ident $primitive:ident
    ) => {
        impl TryFrom<$primitive> for $name {
            type Error = TryFromIntError;

            fn try_from(id: $primitive) -> Result<Self, Self::Error> {
                Ok(Self {
                    id: NonZero::try_from(id)?
                })
            }
        }
    };
}
114
115define_id!(TaskId: u32, derive(Serialize, Deserialize), serde(transparent));
116define_id!(FunctionId: u32);
117define_id!(ValueTypeId: u32);
118define_id!(TraitTypeId: u32);
119define_id!(BackendJobId: u32);
120define_id!(SessionId: u32, derive(Debug, Serialize, Deserialize), serde(transparent));
121define_id!(
122 LocalTaskId: u32,
123 derive(Debug, Serialize, Deserialize),
124 serde(transparent),
125 doc = "Represents the nth `local` function call inside a task.",
126);
127define_id!(
128 ExecutionId: u16,
129 derive(Debug, Serialize, Deserialize),
130 serde(transparent),
131 doc = "An identifier for a specific task execution. Used to assert that local `Vc`s don't \
132 leak. This value may overflow and re-use old values.",
133);
134
135impl Debug for TaskId {
136 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
137 f.debug_struct("TaskId").field("id", &self.id).finish()
138 }
139}
140
141pub const TRANSIENT_TASK_BIT: u32 = 0x8000_0000;
142
143impl TaskId {
144 pub fn is_transient(&self) -> bool {
145 **self & TRANSIENT_TASK_BIT != 0
146 }
147 pub fn persistence(&self) -> TaskPersistence {
148 if self.is_transient() {
150 TaskPersistence::Transient
151 } else {
152 TaskPersistence::Persistent
153 }
154 }
155}
156
/// Implements `Serialize`, `Deserialize` and `Debug` for a registry-backed id type. The id is
/// encoded as its registered global name instead of its numeric value.
///
/// - `$ty`: the id type; it must be `Copy` (it is passed by value via `*self`) and have an `id`
///   field (printed by `Debug`).
/// - `$get_global_name`: maps a `$ty` to its global name; the result is passed to
///   `serialize_str`.
/// - `$get_id`: maps a global name back to `Option<$ty>`; `None` means the name is unknown.
/// - `$visitor_name`: name for the generated serde `Visitor` unit struct.
macro_rules! make_serializable {
    ($ty:ty, $get_global_name:path, $get_id:path, $visitor_name:ident) => {
        impl Serialize for $ty {
            fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
            where
                S: serde::Serializer,
            {
                serializer.serialize_str($get_global_name(*self))
            }
        }

        impl<'de> Deserialize<'de> for $ty {
            fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
            where
                D: serde::Deserializer<'de>,
            {
                deserializer.deserialize_str($visitor_name)
            }
        }

        struct $visitor_name;

        impl<'de> Visitor<'de> for $visitor_name {
            type Value = $ty;

            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
                formatter.write_str(concat!("a name of a registered ", stringify!($ty)))
            }

            fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
            where
                E: serde::de::Error,
            {
                // Report an unknown name as an invalid string value, checked against `expecting`
                // above. `unknown_variant(v, &[])` would claim "there are no variants", which is
                // misleading for a registry lookup.
                $get_id(v).ok_or_else(|| E::invalid_value(serde::de::Unexpected::Str(v), &self))
            }
        }

        impl Debug for $ty {
            /// Prints both the numeric `id` and the registered global `name`.
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                f.debug_struct(stringify!($ty))
                    .field("id", &self.id)
                    .field("name", &$get_global_name(*self))
                    .finish()
            }
        }
    };
}
204
205make_serializable!(
206 FunctionId,
207 registry::get_function_global_name,
208 registry::get_function_id_by_global_name,
209 FunctionIdVisitor
210);
211make_serializable!(
212 ValueTypeId,
213 registry::get_value_type_global_name,
214 registry::get_value_type_id_by_global_name,
215 ValueTypeVisitor
216);
217make_serializable!(
218 TraitTypeId,
219 registry::get_trait_type_global_name,
220 registry::get_trait_type_id_by_global_name,
221 TraitTypeVisitor
222);