1use std::{
2 any::{Any, TypeId},
3 fmt::Display,
4 hash::{Hash, Hasher},
5 mem::replace,
6 sync::{Arc, Weak},
7};
8
9use anyhow::Result;
10use indexmap::map::Entry;
11use serde::{Deserialize, Serialize, de::Visitor};
12use tokio::runtime::Handle;
13
14use crate::{
15 FxIndexMap, FxIndexSet, TaskId, TurboTasksApi,
16 magic_any::HasherMut,
17 manager::{current_task, with_turbo_tasks},
18 trace::TraceRawVcs,
19 util::StaticOrArc,
20};
21
22pub fn get_invalidator() -> Invalidator {
25 let handle = Handle::current();
26 Invalidator {
27 task: current_task("turbo_tasks::get_invalidator()"),
28 turbo_tasks: with_turbo_tasks(Arc::downgrade),
29 handle,
30 }
31}
32
/// A handle that can invalidate one specific task, possibly from outside that task.
///
/// Created by [`get_invalidator`] (or by deserialization). Equality and hashing only
/// consider `task`.
pub struct Invalidator {
    /// The task that will be invalidated.
    task: TaskId,
    /// Weak so that holding an `Invalidator` does not keep the turbo-tasks instance alive;
    /// invalidation is skipped once it has been dropped.
    turbo_tasks: Weak<dyn TurboTasksApi>,
    /// Runtime handle captured at creation; entered while invalidating.
    handle: Handle,
}
38
39impl Invalidator {
40 pub fn invalidate(self) {
41 let Invalidator {
42 task,
43 turbo_tasks,
44 handle,
45 } = self;
46 let _guard = handle.enter();
47 if let Some(turbo_tasks) = turbo_tasks.upgrade() {
48 turbo_tasks.invalidate(task);
49 }
50 }
51
52 pub fn invalidate_with_reason<T: InvalidationReason>(self, reason: T) {
53 let Invalidator {
54 task,
55 turbo_tasks,
56 handle,
57 } = self;
58 let _guard = handle.enter();
59 if let Some(turbo_tasks) = turbo_tasks.upgrade() {
60 turbo_tasks.invalidate_with_reason(
61 task,
62 (Arc::new(reason) as Arc<dyn InvalidationReason>).into(),
63 );
64 }
65 }
66
67 pub fn invalidate_with_static_reason<T: InvalidationReason>(self, reason: &'static T) {
68 let Invalidator {
69 task,
70 turbo_tasks,
71 handle,
72 } = self;
73 let _guard = handle.enter();
74 if let Some(turbo_tasks) = turbo_tasks.upgrade() {
75 turbo_tasks
76 .invalidate_with_reason(task, (reason as &'static dyn InvalidationReason).into());
77 }
78 }
79}
80
81impl Hash for Invalidator {
82 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
83 self.task.hash(state);
84 }
85}
86
87impl PartialEq for Invalidator {
88 fn eq(&self, other: &Self) -> bool {
89 self.task == other.task
90 }
91}
92
93impl Eq for Invalidator {}
94
95impl TraceRawVcs for Invalidator {
96 fn trace_raw_vcs(&self, _context: &mut crate::trace::TraceRawVcsContext) {
97 }
99}
100
101impl Serialize for Invalidator {
102 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
103 where
104 S: serde::Serializer,
105 {
106 serializer.serialize_newtype_struct("Invalidator", &self.task)
107 }
108}
109
110impl<'de> Deserialize<'de> for Invalidator {
111 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
112 where
113 D: serde::Deserializer<'de>,
114 {
115 struct V;
116
117 impl<'de> Visitor<'de> for V {
118 type Value = Invalidator;
119
120 fn expecting(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
121 write!(f, "an Invalidator")
122 }
123
124 fn visit_newtype_struct<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
125 where
126 D: serde::Deserializer<'de>,
127 {
128 Ok(Invalidator {
129 task: TaskId::deserialize(deserializer)?,
130 turbo_tasks: with_turbo_tasks(Arc::downgrade),
131 handle: tokio::runtime::Handle::current(),
132 })
133 }
134 }
135 deserializer.deserialize_newtype_struct("Invalidator", V)
136 }
137}
138
/// Object-safe equality and hashing for values used behind a trait object.
///
/// `PartialEq`/`Eq`/`Hash` are not object-safe, so this trait exposes type-erased
/// counterparts. A blanket implementation is provided for every `Any + PartialEq + Eq + Hash`
/// type.
pub trait DynamicEqHash {
    /// Compares `self` with a type-erased `other`; expected to return `false` when the
    /// concrete types differ.
    fn dyn_eq(&self, other: &dyn Any) -> bool;

    /// Feeds `self` into a type-erased hasher.
    fn dyn_hash(&self, state: &mut dyn Hasher);

    /// Upcasts to `&dyn Any` (e.g. to pass as the `other` argument of `dyn_eq`).
    fn as_any(&self) -> &dyn Any;
}
144
145impl<T: Any + PartialEq + Eq + Hash> DynamicEqHash for T {
146 fn as_any(&self) -> &dyn Any {
147 self
148 }
149
150 fn dyn_eq(&self, other: &dyn Any) -> bool {
151 other
152 .downcast_ref::<Self>()
153 .map(|other| self.eq(other))
154 .unwrap_or(false)
155 }
156
157 fn dyn_hash(&self, state: &mut dyn Hasher) {
158 Hash::hash(&(TypeId::of::<Self>(), self), &mut HasherMut(state));
159 }
160}
161
162pub trait InvalidationReason: DynamicEqHash + Display + Send + Sync + 'static {
167 fn kind(&self) -> Option<StaticOrArc<dyn InvalidationReasonKind>> {
168 None
169 }
170}
171
172pub trait InvalidationReasonKind: DynamicEqHash + Send + Sync + 'static {
178 fn fmt(
181 &self,
182 data: &FxIndexSet<StaticOrArc<dyn InvalidationReason>>,
183 f: &mut std::fmt::Formatter<'_>,
184 ) -> std::fmt::Result;
185}
186
// Implements `PartialEq`, `Eq` and `Hash` for a trait-object type (`dyn Trait`) whose trait
// has `DynamicEqHash` as a supertrait, by delegating to the type-erased methods.
macro_rules! impl_eq_hash {
    ($ty:ty) => {
        impl PartialEq for $ty {
            fn eq(&self, other: &Self) -> bool {
                // UFCS call so the trait-object implementation is used; `dyn_eq` is
                // `false` for different concrete types.
                DynamicEqHash::dyn_eq(self, other.as_any())
            }
        }

        impl Eq for $ty {}

        impl Hash for $ty {
            fn hash<H: Hasher>(&self, state: &mut H) {
                // Hash the concrete type first, then the value itself. For blanket-impl
                // types `dyn_hash` also mixes the `TypeId` in, so it is hashed twice.
                self.as_any().type_id().hash(state);
                DynamicEqHash::dyn_hash(self, state as &mut dyn Hasher)
            }
        }
    };
}
205
// Lets trait objects of both traits be stored in hashed collections (`FxIndexSet`, map keys).
impl_eq_hash!(dyn InvalidationReason);
impl_eq_hash!(dyn InvalidationReasonKind);
208
/// Key of the internal map of `InvalidationReasonSet`.
#[derive(PartialEq, Eq, Hash)]
enum MapKey {
    /// Used for reasons whose `kind()` is `None`. Each such reason receives a fresh tag
    /// from `next_unique_tag`, so it never shares an entry with another reason.
    Untyped {
        unique_tag: usize,
    },
    /// Used for reasons that report a kind; reasons with equal kinds share one entry.
    Typed {
        kind: StaticOrArc<dyn InvalidationReasonKind>,
    },
}
218
/// Value of the internal map of `InvalidationReasonSet`.
enum MapEntry {
    /// Exactly one reason is stored under the key.
    Single {
        reason: StaticOrArc<dyn InvalidationReason>,
    },
    /// Two or more distinct reasons sharing a `MapKey::Typed` key; untyped keys never
    /// reach this variant.
    Multiple {
        reasons: FxIndexSet<StaticOrArc<dyn InvalidationReason>>,
    },
}
227
/// A collection of invalidation reasons that groups reasons of the same kind.
///
/// Reasons reporting an equal `kind()` are stored together and printed via that kind's
/// `InvalidationReasonKind::fmt`; kind-less reasons are each stored separately.
#[derive(Default)]
pub struct InvalidationReasonSet {
    /// Counter that produces a distinct `MapKey::Untyped` tag for every kind-less reason.
    next_unique_tag: usize,
    /// Entries keyed by kind (or unique tag); the indexmap-based map keeps first-insertion
    /// order, which is the order `Display` prints them in.
    map: FxIndexMap<MapKey, MapEntry>,
}
237
238impl InvalidationReasonSet {
239 pub(crate) fn insert(&mut self, reason: StaticOrArc<dyn InvalidationReason>) {
240 if let Some(kind) = reason.kind() {
241 let key = MapKey::Typed { kind };
242 match self.map.entry(key) {
243 Entry::Occupied(mut entry) => {
244 let entry = &mut *entry.get_mut();
245 match replace(
246 entry,
247 MapEntry::Multiple {
248 reasons: FxIndexSet::default(),
249 },
250 ) {
251 MapEntry::Single {
252 reason: existing_reason,
253 } => {
254 if reason == existing_reason {
255 *entry = MapEntry::Single {
256 reason: existing_reason,
257 };
258 return;
259 }
260 let mut reasons = FxIndexSet::default();
261 reasons.insert(existing_reason);
262 reasons.insert(reason);
263 *entry = MapEntry::Multiple { reasons };
264 }
265 MapEntry::Multiple { mut reasons } => {
266 reasons.insert(reason);
267 *entry = MapEntry::Multiple { reasons };
268 }
269 }
270 }
271 Entry::Vacant(entry) => {
272 entry.insert(MapEntry::Single { reason });
273 }
274 }
275 } else {
276 let key = MapKey::Untyped {
277 unique_tag: self.next_unique_tag,
278 };
279 self.next_unique_tag += 1;
280 self.map.insert(key, MapEntry::Single { reason });
281 }
282 }
283
284 pub fn is_empty(&self) -> bool {
285 self.map.is_empty()
286 }
287
288 pub fn len(&self) -> usize {
289 self.map.len()
290 }
291}
292
293impl Display for InvalidationReasonSet {
294 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
295 let count = self.map.len();
296 for (i, (key, entry)) in self.map.iter().enumerate() {
297 if i > 0 {
298 write!(f, ", ")?;
299 if i == count - 1 {
300 write!(f, "and ")?;
301 }
302 }
303 match entry {
304 MapEntry::Single { reason } => {
305 write!(f, "{reason}")?;
306 }
307 MapEntry::Multiple { reasons } => {
308 let MapKey::Typed { kind } = key else {
309 unreachable!("An untyped reason can't collect more than one reason");
310 };
311 kind.fmt(reasons, f)?
312 }
313 }
314 }
315 Ok(())
316 }
317}