turbo_tasks/
invalidation.rs

1use std::{
2    any::{Any, TypeId},
3    fmt::Display,
4    hash::{Hash, Hasher},
5    mem::replace,
6    sync::{Arc, Weak},
7};
8
9use anyhow::Result;
10use indexmap::map::Entry;
11use serde::{Deserialize, Serialize, de::Visitor};
12use tokio::{runtime::Handle, task_local};
13
14use crate::{
15    FxIndexMap, FxIndexSet, TaskId, TurboTasksApi,
16    magic_any::HasherMut,
17    manager::{current_task, with_turbo_tasks},
18    trace::TraceRawVcs,
19    util::StaticOrArc,
20};
21
22task_local! {
23    static DISALLOW_INVALIDATOR: ();
24}
25
26pub fn disallow_invalidator<R>(f: impl Future<Output = R>) -> impl Future<Output = R> {
27    DISALLOW_INVALIDATOR.scope((), f)
28}
29
30/// Get an [`Invalidator`] that can be used to invalidate the current task
31/// based on external events.
32pub fn get_invalidator() -> Invalidator {
33    if DISALLOW_INVALIDATOR.try_with(|_| {}).is_ok() {
34        panic!(
35            "Invalidator can only be used in the turbo-tasks function that has \
36             #[turbo_tasks::function(invalidator)] attribute"
37        );
38    }
39
40    let handle = Handle::current();
41    Invalidator {
42        task: current_task("turbo_tasks::get_invalidator()"),
43        turbo_tasks: with_turbo_tasks(Arc::downgrade),
44        handle,
45    }
46}
47
48pub struct Invalidator {
49    task: TaskId,
50    turbo_tasks: Weak<dyn TurboTasksApi>,
51    handle: Handle,
52}
53
54impl Invalidator {
55    pub fn invalidate(self) {
56        let Invalidator {
57            task,
58            turbo_tasks,
59            handle,
60        } = self;
61        let _guard = handle.enter();
62        if let Some(turbo_tasks) = turbo_tasks.upgrade() {
63            turbo_tasks.invalidate(task);
64        }
65    }
66
67    pub fn invalidate_with_reason<T: InvalidationReason>(self, reason: T) {
68        let Invalidator {
69            task,
70            turbo_tasks,
71            handle,
72        } = self;
73        let _guard = handle.enter();
74        if let Some(turbo_tasks) = turbo_tasks.upgrade() {
75            turbo_tasks.invalidate_with_reason(
76                task,
77                (Arc::new(reason) as Arc<dyn InvalidationReason>).into(),
78            );
79        }
80    }
81
82    pub fn invalidate_with_static_reason<T: InvalidationReason>(self, reason: &'static T) {
83        let Invalidator {
84            task,
85            turbo_tasks,
86            handle,
87        } = self;
88        let _guard = handle.enter();
89        if let Some(turbo_tasks) = turbo_tasks.upgrade() {
90            turbo_tasks
91                .invalidate_with_reason(task, (reason as &'static dyn InvalidationReason).into());
92        }
93    }
94}
95
96impl Hash for Invalidator {
97    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
98        self.task.hash(state);
99    }
100}
101
102impl PartialEq for Invalidator {
103    fn eq(&self, other: &Self) -> bool {
104        self.task == other.task
105    }
106}
107
108impl Eq for Invalidator {}
109
110impl TraceRawVcs for Invalidator {
111    fn trace_raw_vcs(&self, _context: &mut crate::trace::TraceRawVcsContext) {
112        // nothing here
113    }
114}
115
116impl Serialize for Invalidator {
117    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
118    where
119        S: serde::Serializer,
120    {
121        serializer.serialize_newtype_struct("Invalidator", &self.task)
122    }
123}
124
125impl<'de> Deserialize<'de> for Invalidator {
126    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
127    where
128        D: serde::Deserializer<'de>,
129    {
130        struct V;
131
132        impl<'de> Visitor<'de> for V {
133            type Value = Invalidator;
134
135            fn expecting(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
136                write!(f, "an Invalidator")
137            }
138
139            fn visit_newtype_struct<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
140            where
141                D: serde::Deserializer<'de>,
142            {
143                Ok(Invalidator {
144                    task: TaskId::deserialize(deserializer)?,
145                    turbo_tasks: with_turbo_tasks(Arc::downgrade),
146                    handle: tokio::runtime::Handle::current(),
147                })
148            }
149        }
150        deserializer.deserialize_newtype_struct("Invalidator", V)
151    }
152}
153
/// Object-safe equality and hashing for type-erased values.
///
/// A blanket implementation covers every `'static` type that is [`PartialEq`], [`Eq`]
/// and [`Hash`]. Trait objects built on this trait can therefore compare and hash the
/// concrete values behind them.
pub trait DynamicEqHash {
    /// Returns `self` as [`Any`] so callers can inspect or downcast the concrete type.
    fn as_any(&self) -> &dyn Any;

    /// Returns whether `other` has the same concrete type as `self` and compares equal.
    fn dyn_eq(&self, other: &dyn Any) -> bool;

    /// Feeds the concrete type and the value into `state`.
    fn dyn_hash(&self, state: &mut dyn Hasher);
}
159
160impl<T: Any + PartialEq + Eq + Hash> DynamicEqHash for T {
161    fn as_any(&self) -> &dyn Any {
162        self
163    }
164
165    fn dyn_eq(&self, other: &dyn Any) -> bool {
166        other
167            .downcast_ref::<Self>()
168            .map(|other| self.eq(other))
169            .unwrap_or(false)
170    }
171
172    fn dyn_hash(&self, state: &mut dyn Hasher) {
173        Hash::hash(&(TypeId::of::<Self>(), self), &mut HasherMut(state));
174    }
175}
176
177/// A user-facing reason why a task was invalidated. This should only be used
178/// for invalidation that were triggered by the user.
179///
180/// Reasons are deduplicated, so this need to implement [Eq] and [Hash]
181pub trait InvalidationReason: DynamicEqHash + Display + Send + Sync + 'static {
182    fn kind(&self) -> Option<StaticOrArc<dyn InvalidationReasonKind>> {
183        None
184    }
185}
186
187/// Invalidation reason kind. This is used to merge multiple reasons of the same
188/// kind into a combined description.
189///
190/// Reason kinds are used a hash map key, so this need to implement [Eq] and
191/// [Hash]
192pub trait InvalidationReasonKind: DynamicEqHash + Send + Sync + 'static {
193    /// Displays a description of multiple invalidation reasons of the same
194    /// kind. It is only called with two or more reasons.
195    fn fmt(
196        &self,
197        data: &FxIndexSet<StaticOrArc<dyn InvalidationReason>>,
198        f: &mut std::fmt::Formatter<'_>,
199    ) -> std::fmt::Result;
200}
201
/// Implements [`PartialEq`], [`Eq`] and [`Hash`] for a `dyn Trait` type whose trait has
/// [`DynamicEqHash`] as a supertrait.
///
/// The impls forward to the dynamically dispatched `dyn_eq` / `dyn_hash` of the concrete
/// value behind the trait object.
macro_rules! impl_eq_hash {
    ($ty:ty) => {
        impl PartialEq for $ty {
            fn eq(&self, other: &Self) -> bool {
                let other_any = other.as_any();
                DynamicEqHash::dyn_eq(self, other_any)
            }
        }

        impl Eq for $ty {}

        impl Hash for $ty {
            fn hash<H: Hasher>(&self, state: &mut H) {
                // Hash the concrete type id, then let `dyn_hash` feed its own data
                // (which includes the type id again, followed by the value).
                let concrete_type = Any::type_id(self.as_any());
                Hash::hash(&concrete_type, state);
                let erased_state: &mut dyn Hasher = state;
                DynamicEqHash::dyn_hash(self, erased_state)
            }
        }
    };
}
220
221impl_eq_hash!(dyn InvalidationReason);
222impl_eq_hash!(dyn InvalidationReasonKind);
223
224#[derive(PartialEq, Eq, Hash)]
225enum MapKey {
226    Untyped {
227        unique_tag: usize,
228    },
229    Typed {
230        kind: StaticOrArc<dyn InvalidationReasonKind>,
231    },
232}
233
234enum MapEntry {
235    Single {
236        reason: StaticOrArc<dyn InvalidationReason>,
237    },
238    Multiple {
239        reasons: FxIndexSet<StaticOrArc<dyn InvalidationReason>>,
240    },
241}
242
243/// A set of [InvalidationReason]s. They are automatically deduplicated and
244/// merged by kind during insertion. It implements [Display] to get a readable
245/// representation.
246#[derive(Default)]
247pub struct InvalidationReasonSet {
248    next_unique_tag: usize,
249    // We track typed and untyped entries in the same map to keep the occurrence order of entries.
250    map: FxIndexMap<MapKey, MapEntry>,
251}
252
253impl InvalidationReasonSet {
254    pub(crate) fn insert(&mut self, reason: StaticOrArc<dyn InvalidationReason>) {
255        if let Some(kind) = reason.kind() {
256            let key = MapKey::Typed { kind };
257            match self.map.entry(key) {
258                Entry::Occupied(mut entry) => {
259                    let entry = &mut *entry.get_mut();
260                    match replace(
261                        entry,
262                        MapEntry::Multiple {
263                            reasons: FxIndexSet::default(),
264                        },
265                    ) {
266                        MapEntry::Single {
267                            reason: existing_reason,
268                        } => {
269                            if reason == existing_reason {
270                                *entry = MapEntry::Single {
271                                    reason: existing_reason,
272                                };
273                                return;
274                            }
275                            let mut reasons = FxIndexSet::default();
276                            reasons.insert(existing_reason);
277                            reasons.insert(reason);
278                            *entry = MapEntry::Multiple { reasons };
279                        }
280                        MapEntry::Multiple { mut reasons } => {
281                            reasons.insert(reason);
282                            *entry = MapEntry::Multiple { reasons };
283                        }
284                    }
285                }
286                Entry::Vacant(entry) => {
287                    entry.insert(MapEntry::Single { reason });
288                }
289            }
290        } else {
291            let key = MapKey::Untyped {
292                unique_tag: self.next_unique_tag,
293            };
294            self.next_unique_tag += 1;
295            self.map.insert(key, MapEntry::Single { reason });
296        }
297    }
298
299    pub fn is_empty(&self) -> bool {
300        self.map.is_empty()
301    }
302
303    pub fn len(&self) -> usize {
304        self.map.len()
305    }
306}
307
308impl Display for InvalidationReasonSet {
309    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
310        let count = self.map.len();
311        for (i, (key, entry)) in self.map.iter().enumerate() {
312            if i > 0 {
313                write!(f, ", ")?;
314                if i == count - 1 {
315                    write!(f, "and ")?;
316                }
317            }
318            match entry {
319                MapEntry::Single { reason } => {
320                    write!(f, "{reason}")?;
321                }
322                MapEntry::Multiple { reasons } => {
323                    let MapKey::Typed { kind } = key else {
324                        unreachable!("An untyped reason can't collect more than one reason");
325                    };
326                    kind.fmt(reasons, f)?
327                }
328            }
329        }
330        Ok(())
331    }
332}