turbo_tasks/id_factory.rs
1use std::{
2 any::type_name,
3 marker::PhantomData,
4 num::NonZeroU64,
5 sync::atomic::{AtomicU64, Ordering},
6};
7
8use concurrent_queue::ConcurrentQueue;
9
/// A helper for constructing id types like [`FunctionId`][crate::FunctionId].
///
/// Ids are handed out sequentially from the inclusive range `start..=max` given to
/// [`IdFactory::new`] / [`IdFactory::new_const`].
///
/// For ids that may be re-used, see [`IdFactoryWithReuse`].
pub struct IdFactory<T> {
    /// A value starting at 0 and incremented each time a new id is allocated. Regardless of the
    /// underlying type, a u64 is used to cheaply detect overflows.
    counter: AtomicU64,
    /// The largest valid value of `counter` (`max - start`). We've overflowed if
    /// `counter > max_count`.
    max_count: u64,
    /// Added to the value received from `counter` to form the id. Equal to `start`, so it is
    /// always non-zero.
    id_offset: u64,
    _phantom_data: PhantomData<T>,
}

impl<T> IdFactory<T> {
    /// Create a factory for ids in the range `start..=max`.
    ///
    /// # Panics
    ///
    /// Panics if `start >= max`.
    pub fn new(start: T, max: T) -> Self
    where
        T: Into<NonZeroU64> + Ord,
    {
        Self::new_const(start.into(), max.into())
    }

    /// Create a factory for ids in the range `start..=max`.
    ///
    /// Provides a less convenient API than [`IdFactory::new`], but skips a type conversion that
    /// would make the function non-const. Note that `max` is not checked against the range of
    /// `T`; values that `T` cannot represent will panic at allocation time.
    ///
    /// # Panics
    ///
    /// Panics if `start >= max`.
    pub const fn new_const(start: NonZeroU64, max: NonZeroU64) -> Self {
        assert!(start.get() < max.get());
        Self {
            // Always start `counter` at 0, don't use the value of `start` because `start` could be
            // close to `u64::MAX`.
            counter: AtomicU64::new(0),
            max_count: max.get() - start.get(),
            id_offset: start.get(),
            _phantom_data: PhantomData,
        }
    }
}

impl<T> IdFactory<T>
where
    T: TryFrom<NonZeroU64>,
{
    /// Return a unique new id.
    ///
    /// # Panics
    ///
    /// Panics if the id type overflows, i.e. every id in `start..=max` has already been handed
    /// out, or if the value cannot be converted into `T`. In debug builds, the process is aborted
    /// if the internal `u64` counter itself wraps around.
    pub fn get(&self) -> T {
        let count = self.counter.fetch_add(1, Ordering::Relaxed);

        #[cfg(debug_assertions)]
        {
            if count == u64::MAX {
                // u64 counter is about to overflow -- this should never happen! A `u64` counter
                // starting at 0 should take decades to overflow on a single machine.
                //
                // This is unrecoverable because other threads may have already read the overflowed
                // value, so abort the entire process.
                std::process::abort()
            }
        }

        // `max_count` might be something like `u32::MAX`. The extra bits of `u64` are useful to
        // detect overflows in that case. We assume the u64 counter is large enough to never
        // overflow.
        if count > self.max_count {
            panic!(
                "Max id limit (overflow) hit while attempting to generate a unique {}",
                type_name::<T>(),
            )
        }

        let new_id_u64 = count + self.id_offset;
        // SAFETY: `count <= max_count` was checked above, so
        // `id_offset <= new_id_u64 <= id_offset + max_count == max <= u64::MAX`. The addition
        // therefore cannot overflow, and the result is non-zero because `id_offset` (== `start`)
        // is non-zero.
        let new_id = unsafe { NonZeroU64::new_unchecked(new_id_u64) };

        Self::id_from_nonzero(new_id)
    }

    /// Returns an id, potentially allowing an overflow. This may cause ids to be silently re-used.
    /// Used for [`crate::id::ExecutionId`].
    ///
    /// Ids cycle through the whole inclusive range `start..=max` before repeating.
    ///
    /// If id re-use is desired only for "freed" ids, use [`IdFactoryWithReuse`] instead.
    pub fn wrapping_get(&self) -> T {
        let count = self.counter.fetch_add(1, Ordering::Relaxed);

        // The inclusive range `start..=max` contains `max_count + 1` ids; reducing modulo
        // `max_count` alone would never yield `max`. Since `max_count == max - start` and
        // `start >= 1`, `max_count + 1 <= u64::MAX`, so this neither overflows nor is zero.
        let id_count = self.max_count + 1;
        let new_id_u64 = (count % id_count) + self.id_offset;
        // SAFETY: `count % id_count <= max_count`, so `new_id_u64` lies in `id_offset..=max`. The
        // addition cannot overflow, and the result is non-zero because `id_offset` is non-zero.
        let new_id = unsafe { NonZeroU64::new_unchecked(new_id_u64) };

        Self::id_from_nonzero(new_id)
    }

    /// Converts an already range-checked raw id into `T`, panicking if `T` rejects it.
    fn id_from_nonzero(new_id: NonZeroU64) -> T {
        match T::try_from(new_id) {
            Ok(id) => id,
            // With any sane implementation of `TryFrom`, this shouldn't happen, as callers have
            // already bounded the value by `max`. (Could happen with the `new_const` constructor,
            // which does not check `max` against `T`.)
            Err(_) => panic!(
                "Failed to convert NonZeroU64 value of {} into {}",
                new_id,
                type_name::<T>()
            ),
        }
    }
}
123
124/// An [`IdFactory`], but extended with a free list to allow for id reuse for ids such as
125/// [`BackendJobId`][crate::backend::BackendJobId].
126///
127/// If silent untracked re-use of ids is okay, consider using the cheaper
128/// [`IdFactory::wrapping_get`] method.
129pub struct IdFactoryWithReuse<T> {
130 factory: IdFactory<T>,
131 free_ids: ConcurrentQueue<T>,
132}
133
134impl<T> IdFactoryWithReuse<T>
135where
136 T: Into<NonZeroU64> + Ord,
137{
138 /// Create a factory for ids in the range `start..=max`.
139 pub fn new(start: T, max: T) -> Self {
140 Self {
141 factory: IdFactory::new(start, max),
142 free_ids: ConcurrentQueue::unbounded(),
143 }
144 }
145
146 /// Create a factory for ids in the range `start..=max`. Provides a less convenient API than
147 /// [`IdFactoryWithReuse::new`], but skips a type conversion that would make the function
148 /// non-const.
149 pub const fn new_const(start: NonZeroU64, max: NonZeroU64) -> Self {
150 Self {
151 factory: IdFactory::new_const(start, max),
152 free_ids: ConcurrentQueue::unbounded(),
153 }
154 }
155}
156
157impl<T> IdFactoryWithReuse<T>
158where
159 T: TryFrom<NonZeroU64>,
160{
161 /// Return a new or potentially reused id.
162 ///
163 /// Panics if the id type overflows.
164 pub fn get(&self) -> T {
165 self.free_ids.pop().unwrap_or_else(|_| self.factory.get())
166 }
167
168 /// Add an id to the free list, allowing it to be re-used on a subsequent call to
169 /// [`IdFactoryWithReuse::get`].
170 ///
171 /// # Safety
172 ///
173 /// The id must no longer be used. Must be a valid id that was previously returned by
174 /// [`IdFactoryWithReuse::get`].
175 pub unsafe fn reuse(&self, id: T) {
176 let _ = self.free_ids.push(id);
177 }
178}
179
#[cfg(test)]
mod tests {
    use std::num::NonZeroU8;

    use super::*;

    #[test]
    #[should_panic(expected = "Max id limit (overflow)")]
    fn test_overflow_detection() {
        let factory = IdFactory::new(NonZeroU8::MIN, NonZeroU8::MAX);
        assert_eq!(factory.get(), NonZeroU8::new(1).unwrap());
        assert_eq!(factory.get(), NonZeroU8::new(2).unwrap());
        for _ in 2..256 {
            factory.get();
        }
    }

    #[test]
    #[should_panic(expected = "Max id limit (overflow)")]
    fn test_overflow_detection_near_u64_max() {
        let factory = IdFactory::new(NonZeroU64::try_from(u64::MAX - 5).unwrap(), NonZeroU64::MAX);
        for _ in 0..=6 {
            factory.get();
        }
    }

    /// `wrapping_get` must never panic; once the range is exhausted it starts over at `start`.
    #[test]
    fn test_wrapping_get_wraps_instead_of_panicking() {
        let factory = IdFactory::new(NonZeroU8::MIN, NonZeroU8::MAX);
        let ids: Vec<NonZeroU8> = (0..600).map(|_| factory.wrapping_get()).collect();
        assert_eq!(ids[0], NonZeroU8::MIN);
        assert!(ids[1..].contains(&NonZeroU8::MIN));
    }

    /// Wrapping near the top of the `u64` range must not overflow or leave `start..=max`.
    #[test]
    fn test_wrapping_get_near_u64_max() {
        let start = NonZeroU64::try_from(u64::MAX - 5).unwrap();
        let factory = IdFactory::new(start, NonZeroU64::MAX);
        for _ in 0..20 {
            let id: NonZeroU64 = factory.wrapping_get();
            assert!(id >= start);
        }
    }

    /// A freed id is handed out again before any fresh id is allocated.
    #[test]
    fn test_reuse_returns_freed_id() {
        let factory = IdFactoryWithReuse::new(NonZeroU8::MIN, NonZeroU8::MAX);
        let first = factory.get();
        let second = factory.get();
        assert_eq!(first, NonZeroU8::new(1).unwrap());
        assert_eq!(second, NonZeroU8::new(2).unwrap());
        // SAFETY: `first` came from `factory.get()` and is not used as a live id afterwards.
        unsafe { factory.reuse(first) };
        assert_eq!(factory.get(), NonZeroU8::new(1).unwrap());
        assert_eq!(factory.get(), NonZeroU8::new(3).unwrap());
    }
}