1use std::{
2    fmt::Display,
3    hash::Hash,
4    mem::replace,
5    sync::{Arc, Weak},
6};
7
8use anyhow::Result;
9use indexmap::map::Entry;
10use serde::{Deserialize, Serialize, de::Visitor};
11use tokio::runtime::Handle;
12use turbo_dyn_eq_hash::{
13    DynEq, DynHash, impl_eq_for_dyn, impl_hash_for_dyn, impl_partial_eq_for_dyn,
14};
15
16use crate::{
17    FxIndexMap, FxIndexSet, TaskId, TurboTasksApi,
18    manager::{current_task_if_available, mark_invalidator, with_turbo_tasks},
19    trace::TraceRawVcs,
20    util::StaticOrArc,
21};
22
23pub fn get_invalidator() -> Option<Invalidator> {
27    if let Some(task) = current_task_if_available("turbo_tasks::get_invalidator()") {
28        mark_invalidator();
29
30        let handle = Handle::current();
31        Some(Invalidator {
32            task,
33            turbo_tasks: with_turbo_tasks(Arc::downgrade),
34            handle,
35        })
36    } else {
37        None
38    }
39}
40
/// A handle for invalidating one specific task, obtained from [`get_invalidator`].
///
/// Consuming it through one of the `invalidate*` methods invalidates that task. Only a weak
/// reference to the turbo-tasks instance is kept, so an outstanding `Invalidator` does not
/// keep the instance alive. Invalidating after the instance was dropped does nothing.
pub struct Invalidator {
    /// The task that will be invalidated.
    task: TaskId,
    /// Weak reference; upgraded only at the moment of invalidation.
    turbo_tasks: Weak<dyn TurboTasksApi>,
    /// Runtime handle that is entered while the invalidation call runs.
    handle: Handle,
}
46
47impl Invalidator {
48    pub fn invalidate(self) {
49        let Invalidator {
50            task,
51            turbo_tasks,
52            handle,
53        } = self;
54        let _guard = handle.enter();
55        if let Some(turbo_tasks) = turbo_tasks.upgrade() {
56            turbo_tasks.invalidate(task);
57        }
58    }
59
60    pub fn invalidate_with_reason<T: InvalidationReason>(self, reason: T) {
61        let Invalidator {
62            task,
63            turbo_tasks,
64            handle,
65        } = self;
66        let _guard = handle.enter();
67        if let Some(turbo_tasks) = turbo_tasks.upgrade() {
68            turbo_tasks.invalidate_with_reason(
69                task,
70                (Arc::new(reason) as Arc<dyn InvalidationReason>).into(),
71            );
72        }
73    }
74
75    pub fn invalidate_with_static_reason<T: InvalidationReason>(self, reason: &'static T) {
76        let Invalidator {
77            task,
78            turbo_tasks,
79            handle,
80        } = self;
81        let _guard = handle.enter();
82        if let Some(turbo_tasks) = turbo_tasks.upgrade() {
83            turbo_tasks
84                .invalidate_with_reason(task, (reason as &'static dyn InvalidationReason).into());
85        }
86    }
87}
88
89impl Hash for Invalidator {
90    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
91        self.task.hash(state);
92    }
93}
94
95impl PartialEq for Invalidator {
96    fn eq(&self, other: &Self) -> bool {
97        self.task == other.task
98    }
99}
100
101impl Eq for Invalidator {}
102
103impl TraceRawVcs for Invalidator {
104    fn trace_raw_vcs(&self, _context: &mut crate::trace::TraceRawVcsContext) {
105        }
107}
108
109impl Serialize for Invalidator {
110    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
111    where
112        S: serde::Serializer,
113    {
114        serializer.serialize_newtype_struct("Invalidator", &self.task)
115    }
116}
117
118impl<'de> Deserialize<'de> for Invalidator {
119    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
120    where
121        D: serde::Deserializer<'de>,
122    {
123        struct V;
124
125        impl<'de> Visitor<'de> for V {
126            type Value = Invalidator;
127
128            fn expecting(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
129                write!(f, "an Invalidator")
130            }
131
132            fn visit_newtype_struct<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
133            where
134                D: serde::Deserializer<'de>,
135            {
136                Ok(Invalidator {
137                    task: TaskId::deserialize(deserializer)?,
138                    turbo_tasks: with_turbo_tasks(Arc::downgrade),
139                    handle: tokio::runtime::Handle::current(),
140                })
141            }
142        }
143        deserializer.deserialize_newtype_struct("Invalidator", V)
144    }
145}
146
/// A reason why a task was invalidated, shown to users through its `Display` impl.
///
/// The `DynEq` and `DynHash` supertraits allow reasons to be compared and hashed as
/// `dyn InvalidationReason` trait objects.
pub trait InvalidationReason: DynEq + DynHash + Display + Send + Sync + 'static {
    /// Returns the kind this reason belongs to, if any. The default is `None`.
    ///
    /// `InvalidationReasonSet` keeps all reasons that report the same kind in one entry.
    /// When that entry holds several distinct reasons, they are formatted together through
    /// [`InvalidationReasonKind::fmt`]. Reasons returning `None` are each listed on their own.
    fn kind(&self) -> Option<StaticOrArc<dyn InvalidationReasonKind>> {
        None
    }
}
156
/// A category of [`InvalidationReason`]s that can be summarized together instead of being
/// listed one by one.
///
/// Kinds are compared and hashed as trait objects (via `DynEq` / `DynHash`), because
/// `InvalidationReasonSet` uses them as part of its map keys.
pub trait InvalidationReasonKind: DynEq + DynHash + Send + Sync + 'static {
    /// Writes a summary of `data` into `f`. `data` is the set of distinct reasons that
    /// all reported this kind.
    ///
    /// `InvalidationReasonSet`'s `Display` impl calls this only when more than one distinct
    /// reason of this kind was recorded. A lone reason is printed with its own `Display`.
    fn fmt(
        &self,
        data: &FxIndexSet<StaticOrArc<dyn InvalidationReason>>,
        f: &mut std::fmt::Formatter<'_>,
    ) -> std::fmt::Result;
}
171
// `PartialEq` / `Eq` / `Hash` for the reason trait object. This lets reasons be stored in
// `FxIndexSet`s and compared during `InvalidationReasonSet::insert`. These are presumably
// built on the `DynEq` / `DynHash` supertraits (NOTE(review): confirm in turbo_dyn_eq_hash).
impl_partial_eq_for_dyn!(dyn InvalidationReason);
impl_eq_for_dyn!(dyn InvalidationReason);
impl_hash_for_dyn!(dyn InvalidationReason);

// The same impls for the kind trait object, which is stored inside the derived-`Hash`
// `MapKey::Typed` variant.
impl_partial_eq_for_dyn!(dyn InvalidationReasonKind);
impl_eq_for_dyn!(dyn InvalidationReasonKind);
impl_hash_for_dyn!(dyn InvalidationReasonKind);
179
/// Key under which `InvalidationReasonSet` stores its entries.
#[derive(PartialEq, Eq, Hash)]
enum MapKey {
    /// Used for a reason that reported no kind. The tag is taken from `next_unique_tag`, so
    /// every such reason gets an entry of its own.
    Untyped {
        unique_tag: usize,
    },
    /// Shared by every reason that reported this kind.
    Typed {
        kind: StaticOrArc<dyn InvalidationReasonKind>,
    },
}
188}
189
/// The reason(s) stored under one `MapKey`.
enum MapEntry {
    /// A single reason. Untyped keys only ever hold this variant.
    Single {
        reason: StaticOrArc<dyn InvalidationReason>,
    },
    /// Several distinct reasons sharing one kind, kept in insertion order.
    Multiple {
        reasons: FxIndexSet<StaticOrArc<dyn InvalidationReason>>,
    },
}
197}
198
/// A collection of invalidation reasons. Reasons of the same kind are merged into one
/// entry, and the whole collection can be rendered as a single list via `Display`.
#[derive(Default)]
pub struct InvalidationReasonSet {
    // Tag for the next reason without a kind; incremented on every such insert.
    next_unique_tag: usize,
    // Entries in insertion order, keyed either by kind or by a unique tag.
    map: FxIndexMap<MapKey, MapEntry>,
}
208
209impl InvalidationReasonSet {
210    pub(crate) fn insert(&mut self, reason: StaticOrArc<dyn InvalidationReason>) {
211        if let Some(kind) = reason.kind() {
212            let key = MapKey::Typed { kind };
213            match self.map.entry(key) {
214                Entry::Occupied(mut entry) => {
215                    let entry = &mut *entry.get_mut();
216                    match replace(
217                        entry,
218                        MapEntry::Multiple {
219                            reasons: FxIndexSet::default(),
220                        },
221                    ) {
222                        MapEntry::Single {
223                            reason: existing_reason,
224                        } => {
225                            if reason == existing_reason {
226                                *entry = MapEntry::Single {
227                                    reason: existing_reason,
228                                };
229                                return;
230                            }
231                            let mut reasons = FxIndexSet::default();
232                            reasons.insert(existing_reason);
233                            reasons.insert(reason);
234                            *entry = MapEntry::Multiple { reasons };
235                        }
236                        MapEntry::Multiple { mut reasons } => {
237                            reasons.insert(reason);
238                            *entry = MapEntry::Multiple { reasons };
239                        }
240                    }
241                }
242                Entry::Vacant(entry) => {
243                    entry.insert(MapEntry::Single { reason });
244                }
245            }
246        } else {
247            let key = MapKey::Untyped {
248                unique_tag: self.next_unique_tag,
249            };
250            self.next_unique_tag += 1;
251            self.map.insert(key, MapEntry::Single { reason });
252        }
253    }
254
255    pub fn is_empty(&self) -> bool {
256        self.map.is_empty()
257    }
258
259    pub fn len(&self) -> usize {
260        self.map.len()
261    }
262}
263
264impl Display for InvalidationReasonSet {
265    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
266        let count = self.map.len();
267        for (i, (key, entry)) in self.map.iter().enumerate() {
268            if i > 0 {
269                write!(f, ", ")?;
270                if i == count - 1 {
271                    write!(f, "and ")?;
272                }
273            }
274            match entry {
275                MapEntry::Single { reason } => {
276                    write!(f, "{reason}")?;
277                }
278                MapEntry::Multiple { reasons } => {
279                    let MapKey::Typed { kind } = key else {
280                        unreachable!("An untyped reason can't collect more than one reason");
281                    };
282                    kind.fmt(reasons, f)?
283                }
284            }
285        }
286        Ok(())
287    }
288}