1use std::{
2 fmt::{Debug, Display},
3 mem::transmute_copy,
4 num::{NonZero, NonZeroU64, TryFromIntError},
5 ops::Deref,
6};
7
8use bincode::{
9 Decode, Encode,
10 de::Decoder,
11 enc::Encoder,
12 error::{DecodeError, EncodeError},
13 impl_borrow_decode,
14};
15use serde::{Deserialize, Serialize, de::Visitor};
16
17use crate::{
18 TaskPersistence, registry,
19 trace::{TraceRawVcs, TraceRawVcsContext},
20};
21
/// Defines a `Copy` newtype id wrapping a [`NonZero`] integer.
///
/// Syntax: `define_id!(Name: primitive, derive(...), serde(...), doc = "...", ...)`, where the
/// `derive`, `serde` and `doc` parts are optional.
///
/// The primitive must be written as a single identifier (`u16`, `u32`, `u64`, ...). It is
/// captured as `ident` rather than `ty` because a forwarded `ty` fragment is opaque to the
/// internal `@impl_try_from_primitive_conversion` rule. An opaque fragment could never match
/// the literal `u64` arm, which would make that arm dead and break 64-bit ids.
///
/// The generated type provides:
/// - `MIN`/`MAX` and the `new`/`new_unchecked` constructors;
/// - `Display` and `Deref` to the primitive;
/// - conversions from/to `NonZero` values and `TryFrom` conversions;
/// - a no-op `TraceRawVcs` impl.
macro_rules! define_id {
    (
        $name:ident : $primitive:ident
        $(,derive($($derive:ty),*))?
        $(,serde($serde:tt))?
        $(,doc = $doc:literal)*
        $(,)?
    ) => {
        $(#[doc = $doc])*
        #[derive(Hash, Clone, Copy, PartialEq, Eq, PartialOrd, Ord $($(,$derive)*)? )]
        $(#[serde($serde)])?
        pub struct $name {
            id: NonZero<$primitive>,
        }

        impl $name {
            /// The smallest representable id (raw value `1`).
            pub const MIN: Self = Self { id: NonZero::<$primitive>::MIN };
            /// The largest representable id.
            pub const MAX: Self = Self { id: NonZero::<$primitive>::MAX };

            /// Creates an id from a raw value without checking that it is non-zero.
            ///
            /// # Safety
            ///
            /// `id` must not be zero.
            pub const unsafe fn new_unchecked(id: $primitive) -> Self {
                // SAFETY: the caller guarantees that `id` is non-zero.
                Self { id: unsafe { NonZero::<$primitive>::new_unchecked(id) } }
            }

            /// Creates an id from a raw value, returning `None` when `id` is zero.
            pub fn new(id: $primitive) -> Option<Self> {
                NonZero::<$primitive>::new(id).map(|id| Self { id })
            }

            /// Converts this id into a [`NonZeroU64`] without losing information.
            pub const fn to_non_zero_u64(self) -> NonZeroU64 {
                // Compile-time guard: the `as u64` cast below must never truncate.
                const {
                    assert!(<$primitive>::BITS <= u64::BITS);
                }
                // SAFETY: `self.id` is non-zero. A cast to a type at least as wide (asserted
                // above) cannot turn a non-zero value into zero.
                unsafe { NonZeroU64::new_unchecked(self.id.get() as u64) }
            }
        }

        impl Display for $name {
            /// Formats the id as `"<Name> <raw value>"`.
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                write!(f, concat!(stringify!($name), " {}"), self.id)
            }
        }

        impl Deref for $name {
            type Target = $primitive;

            fn deref(&self) -> &Self::Target {
                // SAFETY: std guarantees that `NonZero<T>` has the same layout as `T`, so a
                // `&NonZero<T>` may be reinterpreted as a `&T` (read-only, so the non-zero
                // invariant cannot be violated). The copied reference keeps the lifetime of
                // `&self`.
                unsafe { transmute_copy::<&NonZero<$primitive>, &$primitive>(&&self.id) }
            }
        }

        define_id!(@impl_try_from_primitive_conversion $name $primitive);

        impl From<NonZero<$primitive>> for $name {
            fn from(id: NonZero<$primitive>) -> Self {
                Self { id }
            }
        }

        impl From<$name> for NonZeroU64 {
            fn from(id: $name) -> Self {
                id.to_non_zero_u64()
            }
        }

        impl TraceRawVcs for $name {
            // Ids hold no `Vc`s, so there is nothing to trace.
            fn trace_raw_vcs(&self, _trace_context: &mut TraceRawVcsContext) {}
        }
    };
    (
        @impl_try_from_primitive_conversion $name:ident u64
    ) => {
        // `TryFrom<NonZeroU64>` already comes from the blanket `TryFrom` impl backed by
        // `From<NonZero<u64>>`. Writing it by hand would conflict, so only the fallible
        // conversion from the raw primitive is added here.
        impl TryFrom<u64> for $name {
            type Error = TryFromIntError;

            /// Fails when `id` is zero.
            fn try_from(id: u64) -> Result<Self, Self::Error> {
                Ok(Self { id: NonZero::try_from(id)? })
            }
        }
    };
    (
        @impl_try_from_primitive_conversion $name:ident $primitive:ident
    ) => {
        impl TryFrom<$primitive> for $name {
            type Error = TryFromIntError;

            /// Fails when `id` is zero.
            fn try_from(id: $primitive) -> Result<Self, Self::Error> {
                Ok(Self {
                    id: NonZero::try_from(id)?
                })
            }
        }

        impl TryFrom<NonZeroU64> for $name {
            type Error = TryFromIntError;

            /// Fails when `id` does not fit into the id's primitive type.
            fn try_from(id: NonZeroU64) -> Result<Self, Self::Error> {
                Ok(Self { id: NonZero::try_from(id)? })
            }
        }
    };
}
128
129define_id!(TaskId: u32, derive(Serialize, Deserialize, Encode, Decode), serde(transparent));
130define_id!(FunctionId: u16);
131define_id!(ValueTypeId: u16);
132define_id!(TraitTypeId: u16);
133define_id!(
134 LocalTaskId: u32,
135 derive(Debug, Serialize, Deserialize, Encode, Decode),
136 serde(transparent),
137 doc = "Represents the nth `local` function call inside a task.",
138);
139define_id!(
140 ExecutionId: u16,
141 derive(Debug, Serialize, Deserialize, Encode, Decode),
142 serde(transparent),
143 doc = "An identifier for a specific task execution. Used to assert that local `Vc`s don't \
144 leak. This value may overflow and re-use old values.",
145);
146
147impl Debug for TaskId {
148 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
149 f.debug_struct("TaskId").field("id", &self.id).finish()
150 }
151}
152
153pub const TRANSIENT_TASK_BIT: u32 = 0x8000_0000;
154
155impl TaskId {
156 pub fn is_transient(&self) -> bool {
157 **self & TRANSIENT_TASK_BIT != 0
158 }
159 pub fn persistence(&self) -> TaskPersistence {
160 if self.is_transient() {
162 TaskPersistence::Transient
163 } else {
164 TaskPersistence::Persistent
165 }
166 }
167}
168
/// Implements several traits for an id type (created by `define_id!`) that refers to an entry
/// in the registry:
/// - serde `Serialize`/`Deserialize`, using a `u16` representation;
/// - `Debug`;
/// - bincode `Encode`/`Decode`.
///
/// Arguments:
/// - `$ty`: the id type.
/// - `$primitive`: its underlying integer type, used by `Encode`/`Decode` and for range
///   checks. The serde representation is hard-wired to `u16`.
/// - `$get_object`: called with the id to obtain the registered object shown by `Debug`.
/// - `$validate_type_id`: called with a decoded id. It returns `Some(error)` when the id must
///   be rejected, and deserialization/decoding then fails with that error's message.
macro_rules! make_registered_serializable {
    ($ty:ty, $primitive:ty, $get_object:path, $validate_type_id:path $(,)?) => {
        impl Serialize for $ty {
            fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
            where
                S: serde::Serializer,
            {
                serializer.serialize_u16(self.id.into())
            }
        }

        impl<'de> Deserialize<'de> for $ty {
            fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
            where
                D: serde::Deserializer<'de>,
            {
                struct DeserializeVisitor;

                impl<'de> Visitor<'de> for DeserializeVisitor {
                    type Value = $ty;

                    fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
                        formatter.write_str(concat!("an id of a registered ", stringify!($ty)))
                    }

                    // Implemented for `u64` rather than `u16` for two reasons:
                    // - serde's default `visit_u8`, `visit_u16` and `visit_u32` all forward
                    //   here;
                    // - self-describing formats (e.g. JSON) report every unsigned integer
                    //   through `visit_u64`, ignoring the `deserialize_u16` hint.
                    // A `visit_u16`-only visitor would reject those formats.
                    fn visit_u64<E>(self, v: u64) -> Result<Self::Value, E>
                    where
                        E: serde::de::Error,
                    {
                        let unexpected = serde::de::Unexpected::Unsigned(v);
                        // Reject values that do not fit into the id's primitive type.
                        let Ok(raw) = <$primitive>::try_from(v) else {
                            return Err(E::invalid_value(unexpected, &self));
                        };
                        let Some(value) = Self::Value::new(raw) else {
                            return Err(E::invalid_value(unexpected, &"a non zero u16"));
                        };
                        if let Some(error) = $validate_type_id(value) {
                            Err(E::custom(error))
                        } else {
                            Ok(value)
                        }
                    }
                }

                deserializer.deserialize_u16(DeserializeVisitor)
            }
        }

        impl Debug for $ty {
            /// Shows the raw id together with the registered object it refers to.
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                f.debug_struct(stringify!($ty))
                    .field("id", &self.id)
                    .field("name", &$get_object(*self))
                    .finish()
            }
        }

        impl Encode for $ty {
            fn encode<E: Encoder>(&self, encoder: &mut E) -> Result<(), EncodeError> {
                <NonZero<$primitive> as Encode>::encode(&self.id, encoder)
            }
        }

        impl<Context> Decode<Context> for $ty {
            /// Decodes the raw non-zero integer. The id is rejected if
            /// `$validate_type_id` reports an error for it.
            fn decode<D: Decoder<Context = Context>>(decoder: &mut D) -> Result<Self, DecodeError> {
                let value = Self {
                    id: NonZero::<$primitive>::decode(decoder)?,
                };
                if let Some(error) = $validate_type_id(value) {
                    Err(DecodeError::OtherString(error.to_string()))
                } else {
                    Ok(value)
                }
            }
        }

        impl_borrow_decode!($ty);
    };
}
245
246make_registered_serializable!(
247 ValueTypeId,
248 u16,
249 registry::get_value_type,
250 registry::validate_value_type_id,
251);
252make_registered_serializable!(
253 TraitTypeId,
254 u16,
255 registry::get_trait,
256 registry::validate_trait_type_id,
257);
258make_registered_serializable!(
259 FunctionId,
260 u16,
261 registry::get_native_function,
262 registry::validate_function_id,
263);