1use std::{
2 any::{Any, TypeId},
3 fmt::Display,
4 hash::{Hash, Hasher},
5 mem::replace,
6 sync::{Arc, Weak},
7};
8
9use anyhow::Result;
10use indexmap::map::Entry;
11use serde::{Deserialize, Serialize, de::Visitor};
12use tokio::{runtime::Handle, task_local};
13
14use crate::{
15 FxIndexMap, FxIndexSet, TaskId, TurboTasksApi,
16 magic_any::HasherMut,
17 manager::{current_task, with_turbo_tasks},
18 trace::TraceRawVcs,
19 util::StaticOrArc,
20};
21
22task_local! {
23 static DISALLOW_INVALIDATOR: ();
24}
25
26pub fn disallow_invalidator<R>(f: impl Future<Output = R>) -> impl Future<Output = R> {
27 DISALLOW_INVALIDATOR.scope((), f)
28}
29
30pub fn get_invalidator() -> Invalidator {
33 if DISALLOW_INVALIDATOR.try_with(|_| {}).is_ok() {
34 panic!(
35 "Invalidator can only be used in the turbo-tasks function that has \
36 #[turbo_tasks::function(invalidator)] attribute"
37 );
38 }
39
40 let handle = Handle::current();
41 Invalidator {
42 task: current_task("turbo_tasks::get_invalidator()"),
43 turbo_tasks: with_turbo_tasks(Arc::downgrade),
44 handle,
45 }
46}
47
48pub struct Invalidator {
49 task: TaskId,
50 turbo_tasks: Weak<dyn TurboTasksApi>,
51 handle: Handle,
52}
53
54impl Invalidator {
55 pub fn invalidate(self) {
56 let Invalidator {
57 task,
58 turbo_tasks,
59 handle,
60 } = self;
61 let _guard = handle.enter();
62 if let Some(turbo_tasks) = turbo_tasks.upgrade() {
63 turbo_tasks.invalidate(task);
64 }
65 }
66
67 pub fn invalidate_with_reason<T: InvalidationReason>(self, reason: T) {
68 let Invalidator {
69 task,
70 turbo_tasks,
71 handle,
72 } = self;
73 let _guard = handle.enter();
74 if let Some(turbo_tasks) = turbo_tasks.upgrade() {
75 turbo_tasks.invalidate_with_reason(
76 task,
77 (Arc::new(reason) as Arc<dyn InvalidationReason>).into(),
78 );
79 }
80 }
81
82 pub fn invalidate_with_static_reason<T: InvalidationReason>(self, reason: &'static T) {
83 let Invalidator {
84 task,
85 turbo_tasks,
86 handle,
87 } = self;
88 let _guard = handle.enter();
89 if let Some(turbo_tasks) = turbo_tasks.upgrade() {
90 turbo_tasks
91 .invalidate_with_reason(task, (reason as &'static dyn InvalidationReason).into());
92 }
93 }
94}
95
96impl Hash for Invalidator {
97 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
98 self.task.hash(state);
99 }
100}
101
102impl PartialEq for Invalidator {
103 fn eq(&self, other: &Self) -> bool {
104 self.task == other.task
105 }
106}
107
108impl Eq for Invalidator {}
109
110impl TraceRawVcs for Invalidator {
111 fn trace_raw_vcs(&self, _context: &mut crate::trace::TraceRawVcsContext) {
112 }
114}
115
116impl Serialize for Invalidator {
117 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
118 where
119 S: serde::Serializer,
120 {
121 serializer.serialize_newtype_struct("Invalidator", &self.task)
122 }
123}
124
125impl<'de> Deserialize<'de> for Invalidator {
126 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
127 where
128 D: serde::Deserializer<'de>,
129 {
130 struct V;
131
132 impl<'de> Visitor<'de> for V {
133 type Value = Invalidator;
134
135 fn expecting(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
136 write!(f, "an Invalidator")
137 }
138
139 fn visit_newtype_struct<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
140 where
141 D: serde::Deserializer<'de>,
142 {
143 Ok(Invalidator {
144 task: TaskId::deserialize(deserializer)?,
145 turbo_tasks: with_turbo_tasks(Arc::downgrade),
146 handle: tokio::runtime::Handle::current(),
147 })
148 }
149 }
150 deserializer.deserialize_newtype_struct("Invalidator", V)
151 }
152}
153
154pub trait DynamicEqHash {
155 fn as_any(&self) -> &dyn Any;
156 fn dyn_eq(&self, other: &dyn Any) -> bool;
157 fn dyn_hash(&self, state: &mut dyn Hasher);
158}
159
160impl<T: Any + PartialEq + Eq + Hash> DynamicEqHash for T {
161 fn as_any(&self) -> &dyn Any {
162 self
163 }
164
165 fn dyn_eq(&self, other: &dyn Any) -> bool {
166 other
167 .downcast_ref::<Self>()
168 .map(|other| self.eq(other))
169 .unwrap_or(false)
170 }
171
172 fn dyn_hash(&self, state: &mut dyn Hasher) {
173 Hash::hash(&(TypeId::of::<Self>(), self), &mut HasherMut(state));
174 }
175}
176
177pub trait InvalidationReason: DynamicEqHash + Display + Send + Sync + 'static {
182 fn kind(&self) -> Option<StaticOrArc<dyn InvalidationReasonKind>> {
183 None
184 }
185}
186
187pub trait InvalidationReasonKind: DynamicEqHash + Send + Sync + 'static {
193 fn fmt(
196 &self,
197 data: &FxIndexSet<StaticOrArc<dyn InvalidationReason>>,
198 f: &mut std::fmt::Formatter<'_>,
199 ) -> std::fmt::Result;
200}
201
202macro_rules! impl_eq_hash {
203 ($ty:ty) => {
204 impl PartialEq for $ty {
205 fn eq(&self, other: &Self) -> bool {
206 DynamicEqHash::dyn_eq(self, other.as_any())
207 }
208 }
209
210 impl Eq for $ty {}
211
212 impl Hash for $ty {
213 fn hash<H: Hasher>(&self, state: &mut H) {
214 self.as_any().type_id().hash(state);
215 DynamicEqHash::dyn_hash(self, state as &mut dyn Hasher)
216 }
217 }
218 };
219}
220
221impl_eq_hash!(dyn InvalidationReason);
222impl_eq_hash!(dyn InvalidationReasonKind);
223
224#[derive(PartialEq, Eq, Hash)]
225enum MapKey {
226 Untyped {
227 unique_tag: usize,
228 },
229 Typed {
230 kind: StaticOrArc<dyn InvalidationReasonKind>,
231 },
232}
233
234enum MapEntry {
235 Single {
236 reason: StaticOrArc<dyn InvalidationReason>,
237 },
238 Multiple {
239 reasons: FxIndexSet<StaticOrArc<dyn InvalidationReason>>,
240 },
241}
242
243#[derive(Default)]
247pub struct InvalidationReasonSet {
248 next_unique_tag: usize,
249 map: FxIndexMap<MapKey, MapEntry>,
251}
252
253impl InvalidationReasonSet {
254 pub(crate) fn insert(&mut self, reason: StaticOrArc<dyn InvalidationReason>) {
255 if let Some(kind) = reason.kind() {
256 let key = MapKey::Typed { kind };
257 match self.map.entry(key) {
258 Entry::Occupied(mut entry) => {
259 let entry = &mut *entry.get_mut();
260 match replace(
261 entry,
262 MapEntry::Multiple {
263 reasons: FxIndexSet::default(),
264 },
265 ) {
266 MapEntry::Single {
267 reason: existing_reason,
268 } => {
269 if reason == existing_reason {
270 *entry = MapEntry::Single {
271 reason: existing_reason,
272 };
273 return;
274 }
275 let mut reasons = FxIndexSet::default();
276 reasons.insert(existing_reason);
277 reasons.insert(reason);
278 *entry = MapEntry::Multiple { reasons };
279 }
280 MapEntry::Multiple { mut reasons } => {
281 reasons.insert(reason);
282 *entry = MapEntry::Multiple { reasons };
283 }
284 }
285 }
286 Entry::Vacant(entry) => {
287 entry.insert(MapEntry::Single { reason });
288 }
289 }
290 } else {
291 let key = MapKey::Untyped {
292 unique_tag: self.next_unique_tag,
293 };
294 self.next_unique_tag += 1;
295 self.map.insert(key, MapEntry::Single { reason });
296 }
297 }
298
299 pub fn is_empty(&self) -> bool {
300 self.map.is_empty()
301 }
302
303 pub fn len(&self) -> usize {
304 self.map.len()
305 }
306}
307
308impl Display for InvalidationReasonSet {
309 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
310 let count = self.map.len();
311 for (i, (key, entry)) in self.map.iter().enumerate() {
312 if i > 0 {
313 write!(f, ", ")?;
314 if i == count - 1 {
315 write!(f, "and ")?;
316 }
317 }
318 match entry {
319 MapEntry::Single { reason } => {
320 write!(f, "{reason}")?;
321 }
322 MapEntry::Multiple { reasons } => {
323 let MapKey::Typed { kind } = key else {
324 unreachable!("An untyped reason can't collect more than one reason");
325 };
326 kind.fmt(reasons, f)?
327 }
328 }
329 }
330 Ok(())
331 }
332}