1use std::{
2 fmt::{Debug, Display},
3 mem::transmute_copy,
4 num::{NonZero, NonZeroU64, TryFromIntError},
5 ops::Deref,
6};
7
8use serde::{Deserialize, Serialize, de::Visitor};
9
10use crate::{
11 TaskPersistence, registry,
12 trace::{TraceRawVcs, TraceRawVcsContext},
13};
14
/// Defines a `Copy` newtype identifier wrapping a `NonZero<$primitive>`.
///
/// Syntax: `define_id!(Name: primitive, derive(..), serde(..), doc = "..", ...)`; the
/// `derive(..)`, `serde(..)` and `doc = ..` parts are optional and `doc` may repeat.
///
/// The generated type gets `MIN`/`MAX` constants, `new`/`new_unchecked` constructors, a widening
/// `to_non_zero_u64`, `Display`, `Deref<Target = primitive>`, `From<NonZero<primitive>>`,
/// `From<Name> for NonZeroU64`, a no-op `TraceRawVcs` impl and, for every primitive except
/// `u64`, `TryFrom<primitive>` and `TryFrom<NonZeroU64>`.
///
/// `$primitive` is deliberately captured as an `ident`, not a `ty`: a forwarded `ty` fragment is
/// opaque to the next macro invocation and can never match the literal `u64` in the
/// `@impl_try_from_primitive_conversion` arm, which would send `u64` ids to the generic arm and
/// produce a `TryFrom<NonZeroU64>` impl that conflicts with core's blanket `TryFrom` impl.
macro_rules! define_id {
    (
        $name:ident : $primitive:ident
        $(,derive($($derive:ty),*))?
        $(,serde($serde:tt))?
        $(,doc = $doc:literal)*
        $(,)?
    ) => {
        $(#[doc = $doc])*
        #[derive(Hash, Clone, Copy, PartialEq, Eq, PartialOrd, Ord $($(,$derive)*)? )]
        $(#[serde($serde)])?
        pub struct $name {
            id: NonZero<$primitive>,
        }

        impl $name {
            /// The smallest representable id (`1`).
            pub const MIN: Self = Self { id: NonZero::<$primitive>::MIN };
            /// The largest representable id.
            pub const MAX: Self = Self { id: NonZero::<$primitive>::MAX };

            /// Creates an id from a raw value without checking it.
            ///
            /// # Safety
            ///
            /// `id` must not be zero.
            pub const unsafe fn new_unchecked(id: $primitive) -> Self {
                // SAFETY: the caller guarantees `id != 0`.
                Self { id: unsafe { NonZero::<$primitive>::new_unchecked(id) } }
            }

            /// Creates an id from a raw value, returning `None` when `id` is zero.
            pub fn new(id: $primitive) -> Option<Self> {
                NonZero::<$primitive>::new(id).map(|id| Self { id })
            }

            /// Losslessly widens this id to a [`NonZeroU64`].
            pub const fn to_non_zero_u64(self) -> NonZeroU64 {
                // Compile-time guard that the `as u64` cast below can never truncate.
                const {
                    assert!(<$primitive>::BITS <= u64::BITS);
                }
                // SAFETY: `self.id` is non-zero and the widening cast preserves its value.
                unsafe { NonZeroU64::new_unchecked(self.id.get() as u64) }
            }
        }

        impl Display for $name {
            /// Formats as `<TypeName> <raw id>`, e.g. `TaskId 7`.
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                write!(f, concat!(stringify!($name), " {}"), self.id)
            }
        }

        impl Deref for $name {
            type Target = $primitive;

            /// Exposes the raw (non-zero) primitive value by reference.
            fn deref(&self) -> &Self::Target {
                // SAFETY: `NonZero<T>` is guaranteed to have the same layout as `T`, so the
                // `&NonZero<T>` copied out here is a valid `&T` to the same, still-borrowed field.
                unsafe { transmute_copy(&&self.id) }
            }
        }

        define_id!(@impl_try_from_primitive_conversion $name $primitive);

        impl From<NonZero<$primitive>> for $name {
            fn from(id: NonZero::<$primitive>) -> Self {
                Self { id }
            }
        }

        impl From<$name> for NonZeroU64 {
            fn from(id: $name) -> Self {
                id.to_non_zero_u64()
            }
        }

        impl TraceRawVcs for $name {
            // Ids hold no `Vc`s, so there is nothing to trace.
            fn trace_raw_vcs(&self, _trace_context: &mut TraceRawVcsContext) {}
        }
    };
    // `u64` ids get no extra conversions. `TryFrom<NonZeroU64>` is already supplied by core's
    // blanket `impl<T, U: Into<T>> TryFrom<U> for T` through the `From<NonZero<u64>>` impl
    // above, so writing it here would be a conflicting implementation.
    (
        @impl_try_from_primitive_conversion $name:ident u64
    ) => {};
    // Narrower primitives: fallible conversions from the raw primitive and from `NonZeroU64`.
    (
        @impl_try_from_primitive_conversion $name:ident $primitive:ident
    ) => {
        impl TryFrom<$primitive> for $name {
            type Error = TryFromIntError;

            /// Fails when `id` is zero.
            fn try_from(id: $primitive) -> Result<Self, Self::Error> {
                Ok(Self { id: NonZero::try_from(id)? })
            }
        }

        impl TryFrom<NonZeroU64> for $name {
            type Error = TryFromIntError;

            /// Fails when `id` does not fit into the backing primitive.
            fn try_from(id: NonZeroU64) -> Result<Self, Self::Error> {
                Ok(Self { id: NonZero::try_from(id)? })
            }
        }
    };
}
120
121define_id!(TaskId: u32, derive(Serialize, Deserialize), serde(transparent));
122define_id!(FunctionId: u16);
123define_id!(ValueTypeId: u16);
124define_id!(TraitTypeId: u16);
125define_id!(SessionId: u32, derive(Debug, Serialize, Deserialize), serde(transparent));
126define_id!(
127 LocalTaskId: u32,
128 derive(Debug, Serialize, Deserialize),
129 serde(transparent),
130 doc = "Represents the nth `local` function call inside a task.",
131);
132define_id!(
133 ExecutionId: u16,
134 derive(Debug, Serialize, Deserialize),
135 serde(transparent),
136 doc = "An identifier for a specific task execution. Used to assert that local `Vc`s don't \
137 leak. This value may overflow and re-use old values.",
138);
139
140impl Debug for TaskId {
141 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
142 f.debug_struct("TaskId").field("id", &self.id).finish()
143 }
144}
145
/// The most significant bit of a task id's raw `u32` value; when set, the task is transient
/// rather than persistent.
pub const TRANSIENT_TASK_BIT: u32 = 1 << 31;
147
148impl TaskId {
149 pub fn is_transient(&self) -> bool {
150 **self & TRANSIENT_TASK_BIT != 0
151 }
152 pub fn persistence(&self) -> TaskPersistence {
153 if self.is_transient() {
155 TaskPersistence::Transient
156 } else {
157 TaskPersistence::Persistent
158 }
159 }
160}
161
/// Implements `Serialize`, `Deserialize` and `Debug` for a `u16`-backed, registry-checked id.
///
/// * `$ty` — the id type; it must have a `NonZero<u16>` field named `id` and a
///   `new(u16) -> Option<Self>` constructor.
/// * `$get_object` — maps an id (by value) to its registered object; used by `Debug`.
/// * `$validate_type_id` — returns `Some(error)` (where `error: Display`) when the id must be
///   rejected during deserialization, `None` when it is acceptable.
/// * `$visitor_name` — name of the private serde `Visitor` type generated for `$ty`.
///
/// Ids are serialized as a plain `u16`. Deserialization accepts any integer that fits in a
/// non-zero `u16`, whichever `visit_*` method the format chooses to call.
macro_rules! make_serializable {
    ($ty:ty, $get_object:path, $validate_type_id:path, $visitor_name:ident) => {
        impl Serialize for $ty {
            fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
            where
                S: serde::Serializer,
            {
                serializer.serialize_u16(self.id.get())
            }
        }

        impl<'de> Deserialize<'de> for $ty {
            fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
            where
                D: serde::Deserializer<'de>,
            {
                deserializer.deserialize_u16($visitor_name)
            }
        }

        struct $visitor_name;

        impl<'de> Visitor<'de> for $visitor_name {
            type Value = $ty;

            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
                formatter.write_str(concat!("an id of a registered ", stringify!($ty)))
            }

            /// Rejects zero and any id the registry validator reports an error for.
            fn visit_u16<E>(self, v: u16) -> Result<Self::Value, E>
            where
                E: serde::de::Error,
            {
                let Some(value) = Self::Value::new(v) else {
                    return Err(E::invalid_value(
                        serde::de::Unexpected::Unsigned(u64::from(v)),
                        &"a non zero u16",
                    ));
                };
                match $validate_type_id(value) {
                    Some(error) => Err(E::custom(error)),
                    None => Ok(value),
                }
            }

            // Self-describing formats (e.g. JSON) ignore the `deserialize_u16` hint and pass
            // unsigned integers to `visit_u64` (the default `visit_u8`/`visit_u32` forward
            // here too), whose default implementation rejects them. Narrow and reuse the
            // `u16` path.
            fn visit_u64<E>(self, v: u64) -> Result<Self::Value, E>
            where
                E: serde::de::Error,
            {
                match u16::try_from(v) {
                    Ok(narrow) => self.visit_u16(narrow),
                    Err(_) => Err(E::invalid_value(serde::de::Unexpected::Unsigned(v), &self)),
                }
            }

            // Formats with only signed integers (e.g. TOML) call `visit_i64`; accept values
            // that fit in a `u16` and reject everything else.
            fn visit_i64<E>(self, v: i64) -> Result<Self::Value, E>
            where
                E: serde::de::Error,
            {
                match u16::try_from(v) {
                    Ok(narrow) => self.visit_u16(narrow),
                    Err(_) => Err(E::invalid_value(serde::de::Unexpected::Signed(v), &self)),
                }
            }
        }

        impl Debug for $ty {
            /// Shows the raw id alongside the registered object it refers to.
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                f.debug_struct(stringify!($ty))
                    .field("id", &self.id)
                    .field("name", &$get_object(*self))
                    .finish()
            }
        }
    };
}
218
219make_serializable!(
220 ValueTypeId,
221 registry::get_value_type,
222 registry::validate_value_type_id,
223 ValueTypeVisitor
224);
225make_serializable!(
226 TraitTypeId,
227 registry::get_trait,
228 registry::validate_trait_type_id,
229 TraitTypeVisitor
230);
231make_serializable!(
232 FunctionId,
233 registry::get_native_function,
234 registry::validate_function_id,
235 FunctionTypeVisitor
236);