1use std::{
2 fmt::Display,
3 hash::Hash,
4 mem::replace,
5 sync::{Arc, Weak},
6};
7
8use anyhow::Result;
9use bincode::{
10 Decode, Encode,
11 de::Decoder,
12 enc::Encoder,
13 error::{DecodeError, EncodeError},
14 impl_borrow_decode,
15};
16use indexmap::map::Entry;
17use tokio::runtime::Handle;
18use turbo_dyn_eq_hash::{
19 DynEq, DynHash, impl_eq_for_dyn, impl_hash_for_dyn, impl_partial_eq_for_dyn,
20};
21
22use crate::{
23 FxIndexMap, FxIndexSet, TaskId, TurboTasksApi,
24 manager::{current_task_if_available, mark_invalidator, with_turbo_tasks},
25 trace::TraceRawVcs,
26 util::StaticOrArc,
27};
28
29pub fn get_invalidator() -> Option<Invalidator> {
33 if let Some(task) = current_task_if_available("turbo_tasks::get_invalidator()") {
34 mark_invalidator();
35
36 let handle = Handle::current();
37 Some(Invalidator {
38 task,
39 turbo_tasks: with_turbo_tasks(Arc::downgrade),
40 handle,
41 })
42 } else {
43 None
44 }
45}
46
/// A handle that can invalidate one specific task at a later point in time.
///
/// Instances are created by [`get_invalidator`] (or by decoding a serialized
/// one). Equality and hashing only consider the `task` field.
pub struct Invalidator {
    /// The task that gets invalidated.
    task: TaskId,
    /// Weak reference to the turbo-tasks instance; upgraded on every
    /// invalidation, which silently does nothing once the instance is gone.
    turbo_tasks: Weak<dyn TurboTasksApi>,
    /// Tokio runtime handle that is entered while the invalidation is issued.
    handle: Handle,
}
52
53impl Invalidator {
54 pub fn invalidate(self) {
55 let Invalidator {
56 task,
57 turbo_tasks,
58 handle,
59 } = self;
60 let _guard = handle.enter();
61 if let Some(turbo_tasks) = turbo_tasks.upgrade() {
62 turbo_tasks.invalidate(task);
63 }
64 }
65
66 pub fn invalidate_with_reason<T: InvalidationReason>(self, reason: T) {
67 let Invalidator {
68 task,
69 turbo_tasks,
70 handle,
71 } = self;
72 let _guard = handle.enter();
73 if let Some(turbo_tasks) = turbo_tasks.upgrade() {
74 turbo_tasks.invalidate_with_reason(
75 task,
76 (Arc::new(reason) as Arc<dyn InvalidationReason>).into(),
77 );
78 }
79 }
80
81 pub fn invalidate_with_static_reason<T: InvalidationReason>(self, reason: &'static T) {
82 let Invalidator {
83 task,
84 turbo_tasks,
85 handle,
86 } = self;
87 let _guard = handle.enter();
88 if let Some(turbo_tasks) = turbo_tasks.upgrade() {
89 turbo_tasks
90 .invalidate_with_reason(task, (reason as &'static dyn InvalidationReason).into());
91 }
92 }
93}
94
95impl Hash for Invalidator {
96 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
97 self.task.hash(state);
98 }
99}
100
101impl PartialEq for Invalidator {
102 fn eq(&self, other: &Self) -> bool {
103 self.task == other.task
104 }
105}
106
// Marker impl: equality is delegated to the comparison of the `task` field.
impl Eq for Invalidator {}
108
impl TraceRawVcs for Invalidator {
    /// Intentionally a no-op: nothing is reported to the tracing context for
    /// an `Invalidator`.
    fn trace_raw_vcs(&self, _context: &mut crate::trace::TraceRawVcsContext) {
        // nothing to trace
    }
}
114
115impl Encode for Invalidator {
116 fn encode<E: Encoder>(&self, encoder: &mut E) -> Result<(), EncodeError> {
117 Encode::encode(&self.task, encoder)
118 }
119}
120
121impl<Context> Decode<Context> for Invalidator {
122 fn decode<D: Decoder<Context = Context>>(decoder: &mut D) -> Result<Self, DecodeError> {
123 Ok(Invalidator {
124 task: Decode::decode(decoder)?,
125 turbo_tasks: with_turbo_tasks(Arc::downgrade),
126 handle: tokio::runtime::Handle::current(),
127 })
128 }
129}
130
// bincode helper macro; presumably derives `BorrowDecode` for `Invalidator`
// from the `Decode` impl above (see the bincode docs to confirm).
impl_borrow_decode!(Invalidator);
132
/// A reason attached to a task invalidation.
///
/// Implementors must be displayable (`Display`), support dynamic equality and
/// hashing (`DynEq`, `DynHash`), and be `Send + Sync + 'static`.
pub trait InvalidationReason: DynEq + DynHash + Display + Send + Sync + 'static {
    /// Returns the kind of this reason, if it has one.
    ///
    /// [`InvalidationReasonSet`] groups reasons that share a kind into one
    /// entry; reasons without a kind always get an entry of their own.
    /// The default implementation returns `None`.
    fn kind(&self) -> Option<StaticOrArc<dyn InvalidationReasonKind>> {
        None
    }
}
142
/// A kind under which several [`InvalidationReason`]s can be grouped.
///
/// Must support dynamic equality and hashing (`DynEq`, `DynHash`), because
/// kinds are used as map keys inside [`InvalidationReasonSet`].
pub trait InvalidationReasonKind: DynEq + DynHash + Send + Sync + 'static {
    /// Writes a combined representation of all `data` reasons to `f`.
    ///
    /// Called by the `Display` impl of [`InvalidationReasonSet`] for entries
    /// that collected more than one reason of this kind.
    fn fmt(
        &self,
        data: &FxIndexSet<StaticOrArc<dyn InvalidationReason>>,
        f: &mut std::fmt::Formatter<'_>,
    ) -> std::fmt::Result;
}
157
// `turbo_dyn_eq_hash` helper macros; presumably they provide `PartialEq`,
// `Eq` and `Hash` for the trait objects so that they can be compared and used
// as set/map keys below (confirm against the macro definitions).
impl_partial_eq_for_dyn!(dyn InvalidationReason);
impl_eq_for_dyn!(dyn InvalidationReason);
impl_hash_for_dyn!(dyn InvalidationReason);

impl_partial_eq_for_dyn!(dyn InvalidationReasonKind);
impl_eq_for_dyn!(dyn InvalidationReasonKind);
impl_hash_for_dyn!(dyn InvalidationReasonKind);
165
/// Key of an entry in [`InvalidationReasonSet`].
#[derive(PartialEq, Eq, Hash)]
enum MapKey {
    /// Key for a reason without a kind. Each insertion uses a fresh tag, so
    /// such reasons never share an entry.
    Untyped {
        unique_tag: usize,
    },
    /// Key for reasons with a kind; all reasons of the same kind share it.
    Typed {
        kind: StaticOrArc<dyn InvalidationReasonKind>,
    },
}
175
/// Value of an entry in [`InvalidationReasonSet`].
enum MapEntry {
    /// Exactly one reason was recorded for the key.
    Single {
        reason: StaticOrArc<dyn InvalidationReason>,
    },
    /// Several distinct reasons of the same kind were recorded.
    Multiple {
        reasons: FxIndexSet<StaticOrArc<dyn InvalidationReason>>,
    },
}
184
/// A collection of [`InvalidationReason`]s.
///
/// Reasons that report a kind are grouped (and deduplicated) per kind on
/// insertion; reasons without a kind are stored individually. The `Display`
/// impl renders all entries as a comma separated list.
#[derive(Default)]
pub struct InvalidationReasonSet {
    /// Tag handed out to the next reason that has no kind.
    next_unique_tag: usize,
    /// Entries in insertion order.
    map: FxIndexMap<MapKey, MapEntry>,
}
194
195impl InvalidationReasonSet {
196 pub(crate) fn insert(&mut self, reason: StaticOrArc<dyn InvalidationReason>) {
197 if let Some(kind) = reason.kind() {
198 let key = MapKey::Typed { kind };
199 match self.map.entry(key) {
200 Entry::Occupied(mut entry) => {
201 let entry = &mut *entry.get_mut();
202 match replace(
203 entry,
204 MapEntry::Multiple {
205 reasons: FxIndexSet::default(),
206 },
207 ) {
208 MapEntry::Single {
209 reason: existing_reason,
210 } => {
211 if reason == existing_reason {
212 *entry = MapEntry::Single {
213 reason: existing_reason,
214 };
215 return;
216 }
217 let mut reasons = FxIndexSet::default();
218 reasons.insert(existing_reason);
219 reasons.insert(reason);
220 *entry = MapEntry::Multiple { reasons };
221 }
222 MapEntry::Multiple { mut reasons } => {
223 reasons.insert(reason);
224 *entry = MapEntry::Multiple { reasons };
225 }
226 }
227 }
228 Entry::Vacant(entry) => {
229 entry.insert(MapEntry::Single { reason });
230 }
231 }
232 } else {
233 let key = MapKey::Untyped {
234 unique_tag: self.next_unique_tag,
235 };
236 self.next_unique_tag += 1;
237 self.map.insert(key, MapEntry::Single { reason });
238 }
239 }
240
241 pub fn is_empty(&self) -> bool {
242 self.map.is_empty()
243 }
244
245 pub fn len(&self) -> usize {
246 self.map.len()
247 }
248}
249
250impl Display for InvalidationReasonSet {
251 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
252 let count = self.map.len();
253 for (i, (key, entry)) in self.map.iter().enumerate() {
254 if i > 0 {
255 write!(f, ", ")?;
256 if i == count - 1 {
257 write!(f, "and ")?;
258 }
259 }
260 match entry {
261 MapEntry::Single { reason } => {
262 write!(f, "{reason}")?;
263 }
264 MapEntry::Multiple { reasons } => {
265 let MapKey::Typed { kind } = key else {
266 unreachable!("An untyped reason can't collect more than one reason");
267 };
268 kind.fmt(reasons, f)?
269 }
270 }
271 }
272 Ok(())
273 }
274}