turbo_tasks/id_factory.rs
1use std::{
2 any::type_name,
3 marker::PhantomData,
4 num::NonZeroU64,
5 sync::atomic::{AtomicU64, Ordering},
6};
7
8use concurrent_queue::ConcurrentQueue;
9
/// A helper for constructing id types like [`FunctionId`][crate::FunctionId].
///
/// Ids are handed out sequentially from the inclusive range `start..=max` passed to the
/// constructor.
///
/// For ids that may be re-used, see [`IdFactoryWithReuse`].
pub struct IdFactory<T> {
    /// A value starting at 0 and incremented each time a new id is allocated. Regardless of the
    /// underlying type, a u64 is used to cheaply detect overflows.
    counter: AtomicU64,
    /// The largest valid value of `counter` (`max - start`). We've overflowed if
    /// `counter > max_count`.
    max_count: u64,
    /// Added to the value received from `counter` to form the id (this is `start`).
    id_offset: u64,
    _phantom_data: PhantomData<T>,
}

impl<T> IdFactory<T> {
    /// Create a factory for ids in the range `start..=max`.
    ///
    /// # Panics
    ///
    /// Panics if `start >= max`.
    pub fn new(start: T, max: T) -> Self
    where
        T: Into<NonZeroU64> + Ord,
    {
        Self::new_const(start.into(), max.into())
    }

    /// Create a factory for ids in the range `start..=max`.
    ///
    /// Provides a less convenient API than [`IdFactory::new`], but skips a type conversion that
    /// would make the function non-const.
    ///
    /// # Panics
    ///
    /// Panics if `start >= max` (a compile-time error when evaluated in a const context).
    pub const fn new_const(start: NonZeroU64, max: NonZeroU64) -> Self {
        assert!(start.get() < max.get());
        Self {
            // Always start `counter` at 0, don't use the value of `start` because `start` could be
            // close to `u64::MAX`.
            counter: AtomicU64::new(0),
            max_count: max.get() - start.get(),
            id_offset: start.get(),
            _phantom_data: PhantomData,
        }
    }
}

impl<T> IdFactory<T>
where
    T: TryFrom<NonZeroU64>,
{
    /// Return a unique new id.
    ///
    /// # Panics
    ///
    /// Panics if the id range `start..=max` has been exhausted, or if the resulting value cannot
    /// be converted into `T`.
    pub fn get(&self) -> T {
        // `Relaxed` is sufficient: uniqueness only depends on the atomicity of the increment, and
        // no other memory is published through `counter`.
        let count = self.counter.fetch_add(1, Ordering::Relaxed);

        #[cfg(debug_assertions)]
        {
            if count == u64::MAX {
                // u64 counter is about to overflow -- this should never happen! A `u64` counter
                // starting at 0 should take decades to overflow on a single machine.
                //
                // This is unrecoverable because other threads may have already read the overflowed
                // value, so abort the entire process.
                std::process::abort()
            }
        }

        // `max_count` might be something like `u32::MAX`. The extra bits of `u64` are useful to
        // detect overflows in that case. We assume the u64 counter is large enough to never
        // overflow.
        if count > self.max_count {
            panic!(
                "Max id limit (overflow) hit while attempting to generate a unique {}",
                type_name::<T>(),
            )
        }

        self.id_for_count(count)
    }

    /// Returns an id, wrapping back around to `start` once `max` has been handed out. This may
    /// cause ids to be silently re-used.
    /// Used for [`crate::id::ExecutionId`].
    ///
    /// If id re-use is desired only for "freed" ids, use [`IdFactoryWithReuse`] instead.
    pub fn wrapping_get(&self) -> T {
        let count = self.counter.fetch_add(1, Ordering::Relaxed);

        // The range `start..=max` contains `max_count + 1` ids, so reduce modulo that to include
        // `max` itself. `max_count = max - start <= u64::MAX - 1` because `start >= 1`, so the
        // `+ 1` cannot overflow.
        self.id_for_count(count % (self.max_count + 1))
    }

    /// Map a counter value in `0..=max_count` to the id `id_offset + count`, converted into `T`.
    ///
    /// Panics if the conversion into `T` fails.
    fn id_for_count(&self, count: u64) -> T {
        debug_assert!(count <= self.max_count);

        // SAFETY: `count <= max_count == max - id_offset`, so `id_offset + count <= max` does not
        // overflow, and it is non-zero because `id_offset` came from a `NonZeroU64`.
        let new_id = unsafe { NonZeroU64::new_unchecked(count + self.id_offset) };

        match new_id.try_into() {
            Ok(id) => id,
            // With any sane implementation of `TryFrom`, this shouldn't happen, as the caller has
            // already bounded `count` by `max_count`. (Could happen with the `new_const`
            // constructor, which does not tie `max` to `T`'s range.)
            Err(_) => panic!(
                "Failed to convert NonZeroU64 value of {} into {}",
                new_id,
                type_name::<T>()
            ),
        }
    }
}
123
124/// An [`IdFactory`], but extended with a free list to allow for id reuse.
125///
126/// If silent untracked re-use of ids is okay, consider using the cheaper
127/// [`IdFactory::wrapping_get`] method.
128pub struct IdFactoryWithReuse<T> {
129 factory: IdFactory<T>,
130 free_ids: ConcurrentQueue<T>,
131}
132
133impl<T> IdFactoryWithReuse<T>
134where
135 T: Into<NonZeroU64> + Ord,
136{
137 /// Create a factory for ids in the range `start..=max`.
138 pub fn new(start: T, max: T) -> Self {
139 Self {
140 factory: IdFactory::new(start, max),
141 free_ids: ConcurrentQueue::unbounded(),
142 }
143 }
144
145 /// Create a factory for ids in the range `start..=max`. Provides a less convenient API than
146 /// [`IdFactoryWithReuse::new`], but skips a type conversion that would make the function
147 /// non-const.
148 pub const fn new_const(start: NonZeroU64, max: NonZeroU64) -> Self {
149 Self {
150 factory: IdFactory::new_const(start, max),
151 free_ids: ConcurrentQueue::unbounded(),
152 }
153 }
154}
155
156impl<T> IdFactoryWithReuse<T>
157where
158 T: TryFrom<NonZeroU64>,
159{
160 /// Return a new or potentially reused id.
161 ///
162 /// Panics if the id type overflows.
163 pub fn get(&self) -> T {
164 self.free_ids.pop().unwrap_or_else(|_| self.factory.get())
165 }
166
167 /// Add an id to the free list, allowing it to be re-used on a subsequent call to
168 /// [`IdFactoryWithReuse::get`].
169 ///
170 /// # Safety
171 ///
172 /// The id must no longer be used. Must be a valid id that was previously returned by
173 /// [`IdFactoryWithReuse::get`].
174 pub unsafe fn reuse(&self, id: T) {
175 let _ = self.free_ids.push(id);
176 }
177}
178
#[cfg(test)]
mod tests {
    use std::num::NonZeroU8;

    use super::*;

    /// A `NonZeroU8` factory holds 255 ids; the 256th request must panic.
    #[test]
    #[should_panic(expected = "Max id limit (overflow)")]
    fn test_overflow_detection() {
        let factory = IdFactory::new(NonZeroU8::MIN, NonZeroU8::MAX);

        let first = factory.get();
        assert_eq!(first, NonZeroU8::new(1).unwrap());
        let second = factory.get();
        assert_eq!(second, NonZeroU8::new(2).unwrap());

        // 254 further requests: ids 3..=255 succeed, then the last one overflows.
        (2..256).for_each(|_| {
            factory.get();
        });
    }

    /// Overflow detection must still work when the range ends exactly at `u64::MAX`.
    #[test]
    #[should_panic(expected = "Max id limit (overflow)")]
    fn test_overflow_detection_near_u64_max() {
        let start = NonZeroU64::try_from(u64::MAX - 5).unwrap();
        let factory = IdFactory::new(start, NonZeroU64::MAX);

        // Six ids fit in `start..=u64::MAX`; the seventh request must panic.
        for _attempt in 0..7 {
            factory.get();
        }
    }
}
204}