1use std::{
2 fmt::Display,
3 hash::Hash,
4 mem::replace,
5 sync::{Arc, Weak},
6};
7
8use anyhow::Result;
9use indexmap::map::Entry;
10use serde::{Deserialize, Serialize, de::Visitor};
11use tokio::runtime::Handle;
12use turbo_dyn_eq_hash::{
13 DynEq, DynHash, impl_eq_for_dyn, impl_hash_for_dyn, impl_partial_eq_for_dyn,
14};
15
16use crate::{
17 FxIndexMap, FxIndexSet, TaskId, TurboTasksApi,
18 manager::{current_task, mark_invalidator, with_turbo_tasks},
19 trace::TraceRawVcs,
20 util::StaticOrArc,
21};
22
23pub fn get_invalidator() -> Invalidator {
26 mark_invalidator();
27
28 let handle = Handle::current();
29 Invalidator {
30 task: current_task("turbo_tasks::get_invalidator()"),
31 turbo_tasks: with_turbo_tasks(Arc::downgrade),
32 handle,
33 }
34}
35
/// A token that can later invalidate the task that was current when it was created
/// (see `get_invalidator`). All `invalidate*` methods consume it.
pub struct Invalidator {
    /// The task to invalidate.
    task: TaskId,
    /// Weak reference, so an outstanding invalidator does not keep the turbo-tasks instance
    /// alive. Invalidation is skipped if the instance can no longer be upgraded.
    turbo_tasks: Weak<dyn TurboTasksApi>,
    /// Runtime handle captured at creation. It is entered before calling into turbo-tasks.
    handle: Handle,
}
41
42impl Invalidator {
43 pub fn invalidate(self) {
44 let Invalidator {
45 task,
46 turbo_tasks,
47 handle,
48 } = self;
49 let _guard = handle.enter();
50 if let Some(turbo_tasks) = turbo_tasks.upgrade() {
51 turbo_tasks.invalidate(task);
52 }
53 }
54
55 pub fn invalidate_with_reason<T: InvalidationReason>(self, reason: T) {
56 let Invalidator {
57 task,
58 turbo_tasks,
59 handle,
60 } = self;
61 let _guard = handle.enter();
62 if let Some(turbo_tasks) = turbo_tasks.upgrade() {
63 turbo_tasks.invalidate_with_reason(
64 task,
65 (Arc::new(reason) as Arc<dyn InvalidationReason>).into(),
66 );
67 }
68 }
69
70 pub fn invalidate_with_static_reason<T: InvalidationReason>(self, reason: &'static T) {
71 let Invalidator {
72 task,
73 turbo_tasks,
74 handle,
75 } = self;
76 let _guard = handle.enter();
77 if let Some(turbo_tasks) = turbo_tasks.upgrade() {
78 turbo_tasks
79 .invalidate_with_reason(task, (reason as &'static dyn InvalidationReason).into());
80 }
81 }
82}
83
84impl Hash for Invalidator {
85 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
86 self.task.hash(state);
87 }
88}
89
90impl PartialEq for Invalidator {
91 fn eq(&self, other: &Self) -> bool {
92 self.task == other.task
93 }
94}
95
96impl Eq for Invalidator {}
97
98impl TraceRawVcs for Invalidator {
99 fn trace_raw_vcs(&self, _context: &mut crate::trace::TraceRawVcsContext) {
100 }
102}
103
104impl Serialize for Invalidator {
105 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
106 where
107 S: serde::Serializer,
108 {
109 serializer.serialize_newtype_struct("Invalidator", &self.task)
110 }
111}
112
113impl<'de> Deserialize<'de> for Invalidator {
114 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
115 where
116 D: serde::Deserializer<'de>,
117 {
118 struct V;
119
120 impl<'de> Visitor<'de> for V {
121 type Value = Invalidator;
122
123 fn expecting(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
124 write!(f, "an Invalidator")
125 }
126
127 fn visit_newtype_struct<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
128 where
129 D: serde::Deserializer<'de>,
130 {
131 Ok(Invalidator {
132 task: TaskId::deserialize(deserializer)?,
133 turbo_tasks: with_turbo_tasks(Arc::downgrade),
134 handle: tokio::runtime::Handle::current(),
135 })
136 }
137 }
138 deserializer.deserialize_newtype_struct("Invalidator", V)
139 }
140}
141
/// A reason attached to a task invalidation.
///
/// Implementors must be comparable and hashable through `DynEq`/`DynHash`, so reasons can be
/// deduplicated inside an `InvalidationReasonSet`. Their `Display` output is used when a
/// reason is formatted on its own.
pub trait InvalidationReason: DynEq + DynHash + Display + Send + Sync + 'static {
    /// Optional grouping kind.
    ///
    /// Reasons returning an equal kind are stored under one entry of an
    /// `InvalidationReasonSet`. When more than one distinct reason is collected, the group is
    /// formatted with `InvalidationReasonKind::fmt`. Returning `None` (the default) keeps the
    /// reason in its own entry.
    fn kind(&self) -> Option<StaticOrArc<dyn InvalidationReasonKind>> {
        None
    }
}
151
/// A grouping key for `InvalidationReason`s that should be reported together.
///
/// Used as a map key (via `DynEq`/`DynHash`) in `InvalidationReasonSet`.
pub trait InvalidationReasonKind: DynEq + DynHash + Send + Sync + 'static {
    /// Formats a group of reasons that share this kind.
    ///
    /// `data` holds the distinct reasons collected for this kind, in insertion order. Within
    /// this module it is only invoked once at least two distinct reasons have been collected.
    fn fmt(
        &self,
        data: &FxIndexSet<StaticOrArc<dyn InvalidationReason>>,
        f: &mut std::fmt::Formatter<'_>,
    ) -> std::fmt::Result;
}
166
// Give the trait objects `PartialEq`/`Eq`/`Hash` impls (built on `DynEq`/`DynHash`), so
// `StaticOrArc<dyn ...>` values can be compared and used as set/map keys below.
impl_partial_eq_for_dyn!(dyn InvalidationReason);
impl_eq_for_dyn!(dyn InvalidationReason);
impl_hash_for_dyn!(dyn InvalidationReason);

impl_partial_eq_for_dyn!(dyn InvalidationReasonKind);
impl_eq_for_dyn!(dyn InvalidationReasonKind);
impl_hash_for_dyn!(dyn InvalidationReasonKind);
174
/// Key under which a reason is stored in `InvalidationReasonSet`.
#[derive(PartialEq, Eq, Hash)]
enum MapKey {
    /// A reason without a kind. `unique_tag` is a fresh counter value, so such reasons
    /// never share an entry.
    Untyped {
        unique_tag: usize,
    },
    /// A reason with a kind. All reasons with an equal kind share this entry.
    Typed {
        kind: StaticOrArc<dyn InvalidationReasonKind>,
    },
}
184
/// The reason(s) stored under one `MapKey`.
enum MapEntry {
    /// Exactly one reason. Untyped keys always hold this variant.
    Single {
        reason: StaticOrArc<dyn InvalidationReason>,
    },
    /// Two or more distinct reasons sharing a kind, in insertion order.
    Multiple {
        reasons: FxIndexSet<StaticOrArc<dyn InvalidationReason>>,
    },
}
193
/// A collection of invalidation reasons.
///
/// Reasons with an equal kind are grouped into one entry, and reasons without a kind each get
/// their own entry. `Display` renders the entries as a list.
#[derive(Default)]
pub struct InvalidationReasonSet {
    /// Counter used to give each untyped reason a distinct `MapKey::Untyped` key.
    next_unique_tag: usize,
    /// Entries in insertion order.
    map: FxIndexMap<MapKey, MapEntry>,
}
203
204impl InvalidationReasonSet {
205 pub(crate) fn insert(&mut self, reason: StaticOrArc<dyn InvalidationReason>) {
206 if let Some(kind) = reason.kind() {
207 let key = MapKey::Typed { kind };
208 match self.map.entry(key) {
209 Entry::Occupied(mut entry) => {
210 let entry = &mut *entry.get_mut();
211 match replace(
212 entry,
213 MapEntry::Multiple {
214 reasons: FxIndexSet::default(),
215 },
216 ) {
217 MapEntry::Single {
218 reason: existing_reason,
219 } => {
220 if reason == existing_reason {
221 *entry = MapEntry::Single {
222 reason: existing_reason,
223 };
224 return;
225 }
226 let mut reasons = FxIndexSet::default();
227 reasons.insert(existing_reason);
228 reasons.insert(reason);
229 *entry = MapEntry::Multiple { reasons };
230 }
231 MapEntry::Multiple { mut reasons } => {
232 reasons.insert(reason);
233 *entry = MapEntry::Multiple { reasons };
234 }
235 }
236 }
237 Entry::Vacant(entry) => {
238 entry.insert(MapEntry::Single { reason });
239 }
240 }
241 } else {
242 let key = MapKey::Untyped {
243 unique_tag: self.next_unique_tag,
244 };
245 self.next_unique_tag += 1;
246 self.map.insert(key, MapEntry::Single { reason });
247 }
248 }
249
250 pub fn is_empty(&self) -> bool {
251 self.map.is_empty()
252 }
253
254 pub fn len(&self) -> usize {
255 self.map.len()
256 }
257}
258
259impl Display for InvalidationReasonSet {
260 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
261 let count = self.map.len();
262 for (i, (key, entry)) in self.map.iter().enumerate() {
263 if i > 0 {
264 write!(f, ", ")?;
265 if i == count - 1 {
266 write!(f, "and ")?;
267 }
268 }
269 match entry {
270 MapEntry::Single { reason } => {
271 write!(f, "{reason}")?;
272 }
273 MapEntry::Multiple { reasons } => {
274 let MapKey::Typed { kind } = key else {
275 unreachable!("An untyped reason can't collect more than one reason");
276 };
277 kind.fmt(reasons, f)?
278 }
279 }
280 }
281 Ok(())
282 }
283}