turbo_tasks/
invalidation.rs

1use std::{
2    any::{Any, TypeId},
3    fmt::Display,
4    hash::{Hash, Hasher},
5    mem::replace,
6    sync::{Arc, Weak},
7};
8
9use anyhow::Result;
10use indexmap::map::Entry;
11use serde::{Deserialize, Serialize, de::Visitor};
12use tokio::runtime::Handle;
13
14use crate::{
15    FxIndexMap, FxIndexSet, TaskId, TurboTasksApi,
16    magic_any::HasherMut,
17    manager::{current_task, with_turbo_tasks},
18    trace::TraceRawVcs,
19    util::StaticOrArc,
20};
21
22/// Get an [`Invalidator`] that can be used to invalidate the current task
23/// based on external events.
24pub fn get_invalidator() -> Invalidator {
25    let handle = Handle::current();
26    Invalidator {
27        task: current_task("turbo_tasks::get_invalidator()"),
28        turbo_tasks: with_turbo_tasks(Arc::downgrade),
29        handle,
30    }
31}
32
33pub struct Invalidator {
34    task: TaskId,
35    turbo_tasks: Weak<dyn TurboTasksApi>,
36    handle: Handle,
37}
38
39impl Invalidator {
40    pub fn invalidate(self) {
41        let Invalidator {
42            task,
43            turbo_tasks,
44            handle,
45        } = self;
46        let _guard = handle.enter();
47        if let Some(turbo_tasks) = turbo_tasks.upgrade() {
48            turbo_tasks.invalidate(task);
49        }
50    }
51
52    pub fn invalidate_with_reason<T: InvalidationReason>(self, reason: T) {
53        let Invalidator {
54            task,
55            turbo_tasks,
56            handle,
57        } = self;
58        let _guard = handle.enter();
59        if let Some(turbo_tasks) = turbo_tasks.upgrade() {
60            turbo_tasks.invalidate_with_reason(
61                task,
62                (Arc::new(reason) as Arc<dyn InvalidationReason>).into(),
63            );
64        }
65    }
66
67    pub fn invalidate_with_static_reason<T: InvalidationReason>(self, reason: &'static T) {
68        let Invalidator {
69            task,
70            turbo_tasks,
71            handle,
72        } = self;
73        let _guard = handle.enter();
74        if let Some(turbo_tasks) = turbo_tasks.upgrade() {
75            turbo_tasks
76                .invalidate_with_reason(task, (reason as &'static dyn InvalidationReason).into());
77        }
78    }
79}
80
81impl Hash for Invalidator {
82    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
83        self.task.hash(state);
84    }
85}
86
87impl PartialEq for Invalidator {
88    fn eq(&self, other: &Self) -> bool {
89        self.task == other.task
90    }
91}
92
93impl Eq for Invalidator {}
94
95impl TraceRawVcs for Invalidator {
96    fn trace_raw_vcs(&self, _context: &mut crate::trace::TraceRawVcsContext) {
97        // nothing here
98    }
99}
100
101impl Serialize for Invalidator {
102    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
103    where
104        S: serde::Serializer,
105    {
106        serializer.serialize_newtype_struct("Invalidator", &self.task)
107    }
108}
109
110impl<'de> Deserialize<'de> for Invalidator {
111    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
112    where
113        D: serde::Deserializer<'de>,
114    {
115        struct V;
116
117        impl<'de> Visitor<'de> for V {
118            type Value = Invalidator;
119
120            fn expecting(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
121                write!(f, "an Invalidator")
122            }
123
124            fn visit_newtype_struct<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
125            where
126                D: serde::Deserializer<'de>,
127            {
128                Ok(Invalidator {
129                    task: TaskId::deserialize(deserializer)?,
130                    turbo_tasks: with_turbo_tasks(Arc::downgrade),
131                    handle: tokio::runtime::Handle::current(),
132                })
133            }
134        }
135        deserializer.deserialize_newtype_struct("Invalidator", V)
136    }
137}
138
/// Object-safe equality and hashing for type-erased values.
///
/// A blanket impl provides this for every `Any + PartialEq + Eq + Hash` type.
/// That is what lets trait objects built on top of it be compared and hashed.
pub trait DynamicEqHash {
    /// Returns `self` as a [`Any`] trait object, e.g. for downcasting.
    fn as_any(&self) -> &dyn Any;
    /// Compares `self` against a type-erased `other`.
    fn dyn_eq(&self, other: &dyn Any) -> bool;
    /// Feeds `self` into a type-erased hasher.
    fn dyn_hash(&self, state: &mut dyn Hasher);
}
144
145impl<T: Any + PartialEq + Eq + Hash> DynamicEqHash for T {
146    fn as_any(&self) -> &dyn Any {
147        self
148    }
149
150    fn dyn_eq(&self, other: &dyn Any) -> bool {
151        other
152            .downcast_ref::<Self>()
153            .map(|other| self.eq(other))
154            .unwrap_or(false)
155    }
156
157    fn dyn_hash(&self, state: &mut dyn Hasher) {
158        Hash::hash(&(TypeId::of::<Self>(), self), &mut HasherMut(state));
159    }
160}
161
162/// A user-facing reason why a task was invalidated. This should only be used
163/// for invalidation that were triggered by the user.
164///
165/// Reasons are deduplicated, so this need to implement [Eq] and [Hash]
166pub trait InvalidationReason: DynamicEqHash + Display + Send + Sync + 'static {
167    fn kind(&self) -> Option<StaticOrArc<dyn InvalidationReasonKind>> {
168        None
169    }
170}
171
172/// Invalidation reason kind. This is used to merge multiple reasons of the same
173/// kind into a combined description.
174///
175/// Reason kinds are used a hash map key, so this need to implement [Eq] and
176/// [Hash]
177pub trait InvalidationReasonKind: DynamicEqHash + Send + Sync + 'static {
178    /// Displays a description of multiple invalidation reasons of the same
179    /// kind. It is only called with two or more reasons.
180    fn fmt(
181        &self,
182        data: &FxIndexSet<StaticOrArc<dyn InvalidationReason>>,
183        f: &mut std::fmt::Formatter<'_>,
184    ) -> std::fmt::Result;
185}
186
187macro_rules! impl_eq_hash {
188    ($ty:ty) => {
189        impl PartialEq for $ty {
190            fn eq(&self, other: &Self) -> bool {
191                DynamicEqHash::dyn_eq(self, other.as_any())
192            }
193        }
194
195        impl Eq for $ty {}
196
197        impl Hash for $ty {
198            fn hash<H: Hasher>(&self, state: &mut H) {
199                self.as_any().type_id().hash(state);
200                DynamicEqHash::dyn_hash(self, state as &mut dyn Hasher)
201            }
202        }
203    };
204}
205
206impl_eq_hash!(dyn InvalidationReason);
207impl_eq_hash!(dyn InvalidationReasonKind);
208
209#[derive(PartialEq, Eq, Hash)]
210enum MapKey {
211    Untyped {
212        unique_tag: usize,
213    },
214    Typed {
215        kind: StaticOrArc<dyn InvalidationReasonKind>,
216    },
217}
218
219enum MapEntry {
220    Single {
221        reason: StaticOrArc<dyn InvalidationReason>,
222    },
223    Multiple {
224        reasons: FxIndexSet<StaticOrArc<dyn InvalidationReason>>,
225    },
226}
227
228/// A set of [InvalidationReason]s. They are automatically deduplicated and
229/// merged by kind during insertion. It implements [Display] to get a readable
230/// representation.
231#[derive(Default)]
232pub struct InvalidationReasonSet {
233    next_unique_tag: usize,
234    // We track typed and untyped entries in the same map to keep the occurence order of entries.
235    map: FxIndexMap<MapKey, MapEntry>,
236}
237
238impl InvalidationReasonSet {
239    pub(crate) fn insert(&mut self, reason: StaticOrArc<dyn InvalidationReason>) {
240        if let Some(kind) = reason.kind() {
241            let key = MapKey::Typed { kind };
242            match self.map.entry(key) {
243                Entry::Occupied(mut entry) => {
244                    let entry = &mut *entry.get_mut();
245                    match replace(
246                        entry,
247                        MapEntry::Multiple {
248                            reasons: FxIndexSet::default(),
249                        },
250                    ) {
251                        MapEntry::Single {
252                            reason: existing_reason,
253                        } => {
254                            if reason == existing_reason {
255                                *entry = MapEntry::Single {
256                                    reason: existing_reason,
257                                };
258                                return;
259                            }
260                            let mut reasons = FxIndexSet::default();
261                            reasons.insert(existing_reason);
262                            reasons.insert(reason);
263                            *entry = MapEntry::Multiple { reasons };
264                        }
265                        MapEntry::Multiple { mut reasons } => {
266                            reasons.insert(reason);
267                            *entry = MapEntry::Multiple { reasons };
268                        }
269                    }
270                }
271                Entry::Vacant(entry) => {
272                    entry.insert(MapEntry::Single { reason });
273                }
274            }
275        } else {
276            let key = MapKey::Untyped {
277                unique_tag: self.next_unique_tag,
278            };
279            self.next_unique_tag += 1;
280            self.map.insert(key, MapEntry::Single { reason });
281        }
282    }
283
284    pub fn is_empty(&self) -> bool {
285        self.map.is_empty()
286    }
287
288    pub fn len(&self) -> usize {
289        self.map.len()
290    }
291}
292
293impl Display for InvalidationReasonSet {
294    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
295        let count = self.map.len();
296        for (i, (key, entry)) in self.map.iter().enumerate() {
297            if i > 0 {
298                write!(f, ", ")?;
299                if i == count - 1 {
300                    write!(f, "and ")?;
301                }
302            }
303            match entry {
304                MapEntry::Single { reason } => {
305                    write!(f, "{reason}")?;
306                }
307                MapEntry::Multiple { reasons } => {
308                    let MapKey::Typed { kind } = key else {
309                        unreachable!("An untyped reason can't collect more than one reason");
310                    };
311                    kind.fmt(reasons, f)?
312                }
313            }
314        }
315        Ok(())
316    }
317}