turbo_tasks/
invalidation.rs

1use std::{
2    fmt::Display,
3    hash::Hash,
4    mem::replace,
5    sync::{Arc, Weak},
6};
7
8use anyhow::Result;
9use indexmap::map::Entry;
10use serde::{Deserialize, Serialize, de::Visitor};
11use tokio::runtime::Handle;
12use turbo_dyn_eq_hash::{
13    DynEq, DynHash, impl_eq_for_dyn, impl_hash_for_dyn, impl_partial_eq_for_dyn,
14};
15
16use crate::{
17    FxIndexMap, FxIndexSet, TaskId, TurboTasksApi,
18    manager::{current_task, mark_invalidator, with_turbo_tasks},
19    trace::TraceRawVcs,
20    util::StaticOrArc,
21};
22
23/// Get an [`Invalidator`] that can be used to invalidate the current task
24/// based on external events.
25pub fn get_invalidator() -> Invalidator {
26    mark_invalidator();
27
28    let handle = Handle::current();
29    Invalidator {
30        task: current_task("turbo_tasks::get_invalidator()"),
31        turbo_tasks: with_turbo_tasks(Arc::downgrade),
32        handle,
33    }
34}
35
36pub struct Invalidator {
37    task: TaskId,
38    turbo_tasks: Weak<dyn TurboTasksApi>,
39    handle: Handle,
40}
41
42impl Invalidator {
43    pub fn invalidate(self) {
44        let Invalidator {
45            task,
46            turbo_tasks,
47            handle,
48        } = self;
49        let _guard = handle.enter();
50        if let Some(turbo_tasks) = turbo_tasks.upgrade() {
51            turbo_tasks.invalidate(task);
52        }
53    }
54
55    pub fn invalidate_with_reason<T: InvalidationReason>(self, reason: T) {
56        let Invalidator {
57            task,
58            turbo_tasks,
59            handle,
60        } = self;
61        let _guard = handle.enter();
62        if let Some(turbo_tasks) = turbo_tasks.upgrade() {
63            turbo_tasks.invalidate_with_reason(
64                task,
65                (Arc::new(reason) as Arc<dyn InvalidationReason>).into(),
66            );
67        }
68    }
69
70    pub fn invalidate_with_static_reason<T: InvalidationReason>(self, reason: &'static T) {
71        let Invalidator {
72            task,
73            turbo_tasks,
74            handle,
75        } = self;
76        let _guard = handle.enter();
77        if let Some(turbo_tasks) = turbo_tasks.upgrade() {
78            turbo_tasks
79                .invalidate_with_reason(task, (reason as &'static dyn InvalidationReason).into());
80        }
81    }
82}
83
84impl Hash for Invalidator {
85    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
86        self.task.hash(state);
87    }
88}
89
90impl PartialEq for Invalidator {
91    fn eq(&self, other: &Self) -> bool {
92        self.task == other.task
93    }
94}
95
96impl Eq for Invalidator {}
97
98impl TraceRawVcs for Invalidator {
99    fn trace_raw_vcs(&self, _context: &mut crate::trace::TraceRawVcsContext) {
100        // nothing here
101    }
102}
103
104impl Serialize for Invalidator {
105    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
106    where
107        S: serde::Serializer,
108    {
109        serializer.serialize_newtype_struct("Invalidator", &self.task)
110    }
111}
112
113impl<'de> Deserialize<'de> for Invalidator {
114    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
115    where
116        D: serde::Deserializer<'de>,
117    {
118        struct V;
119
120        impl<'de> Visitor<'de> for V {
121            type Value = Invalidator;
122
123            fn expecting(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
124                write!(f, "an Invalidator")
125            }
126
127            fn visit_newtype_struct<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
128            where
129                D: serde::Deserializer<'de>,
130            {
131                Ok(Invalidator {
132                    task: TaskId::deserialize(deserializer)?,
133                    turbo_tasks: with_turbo_tasks(Arc::downgrade),
134                    handle: tokio::runtime::Handle::current(),
135                })
136            }
137        }
138        deserializer.deserialize_newtype_struct("Invalidator", V)
139    }
140}
141
142/// A user-facing reason why a task was invalidated. This should only be used
143/// for invalidation that were triggered by the user.
144///
145/// Reasons are deduplicated, so this need to implement [Eq] and [Hash]
146pub trait InvalidationReason: DynEq + DynHash + Display + Send + Sync + 'static {
147    fn kind(&self) -> Option<StaticOrArc<dyn InvalidationReasonKind>> {
148        None
149    }
150}
151
152/// Invalidation reason kind. This is used to merge multiple reasons of the same
153/// kind into a combined description.
154///
155/// Reason kinds are used a hash map key, so this need to implement [Eq] and
156/// [Hash]
157pub trait InvalidationReasonKind: DynEq + DynHash + Send + Sync + 'static {
158    /// Displays a description of multiple invalidation reasons of the same
159    /// kind. It is only called with two or more reasons.
160    fn fmt(
161        &self,
162        data: &FxIndexSet<StaticOrArc<dyn InvalidationReason>>,
163        f: &mut std::fmt::Formatter<'_>,
164    ) -> std::fmt::Result;
165}
166
167impl_partial_eq_for_dyn!(dyn InvalidationReason);
168impl_eq_for_dyn!(dyn InvalidationReason);
169impl_hash_for_dyn!(dyn InvalidationReason);
170
171impl_partial_eq_for_dyn!(dyn InvalidationReasonKind);
172impl_eq_for_dyn!(dyn InvalidationReasonKind);
173impl_hash_for_dyn!(dyn InvalidationReasonKind);
174
175#[derive(PartialEq, Eq, Hash)]
176enum MapKey {
177    Untyped {
178        unique_tag: usize,
179    },
180    Typed {
181        kind: StaticOrArc<dyn InvalidationReasonKind>,
182    },
183}
184
185enum MapEntry {
186    Single {
187        reason: StaticOrArc<dyn InvalidationReason>,
188    },
189    Multiple {
190        reasons: FxIndexSet<StaticOrArc<dyn InvalidationReason>>,
191    },
192}
193
194/// A set of [InvalidationReason]s. They are automatically deduplicated and
195/// merged by kind during insertion. It implements [Display] to get a readable
196/// representation.
197#[derive(Default)]
198pub struct InvalidationReasonSet {
199    next_unique_tag: usize,
200    // We track typed and untyped entries in the same map to keep the occurrence order of entries.
201    map: FxIndexMap<MapKey, MapEntry>,
202}
203
204impl InvalidationReasonSet {
205    pub(crate) fn insert(&mut self, reason: StaticOrArc<dyn InvalidationReason>) {
206        if let Some(kind) = reason.kind() {
207            let key = MapKey::Typed { kind };
208            match self.map.entry(key) {
209                Entry::Occupied(mut entry) => {
210                    let entry = &mut *entry.get_mut();
211                    match replace(
212                        entry,
213                        MapEntry::Multiple {
214                            reasons: FxIndexSet::default(),
215                        },
216                    ) {
217                        MapEntry::Single {
218                            reason: existing_reason,
219                        } => {
220                            if reason == existing_reason {
221                                *entry = MapEntry::Single {
222                                    reason: existing_reason,
223                                };
224                                return;
225                            }
226                            let mut reasons = FxIndexSet::default();
227                            reasons.insert(existing_reason);
228                            reasons.insert(reason);
229                            *entry = MapEntry::Multiple { reasons };
230                        }
231                        MapEntry::Multiple { mut reasons } => {
232                            reasons.insert(reason);
233                            *entry = MapEntry::Multiple { reasons };
234                        }
235                    }
236                }
237                Entry::Vacant(entry) => {
238                    entry.insert(MapEntry::Single { reason });
239                }
240            }
241        } else {
242            let key = MapKey::Untyped {
243                unique_tag: self.next_unique_tag,
244            };
245            self.next_unique_tag += 1;
246            self.map.insert(key, MapEntry::Single { reason });
247        }
248    }
249
250    pub fn is_empty(&self) -> bool {
251        self.map.is_empty()
252    }
253
254    pub fn len(&self) -> usize {
255        self.map.len()
256    }
257}
258
259impl Display for InvalidationReasonSet {
260    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
261        let count = self.map.len();
262        for (i, (key, entry)) in self.map.iter().enumerate() {
263            if i > 0 {
264                write!(f, ", ")?;
265                if i == count - 1 {
266                    write!(f, "and ")?;
267                }
268            }
269            match entry {
270                MapEntry::Single { reason } => {
271                    write!(f, "{reason}")?;
272                }
273                MapEntry::Multiple { reasons } => {
274                    let MapKey::Typed { kind } = key else {
275                        unreachable!("An untyped reason can't collect more than one reason");
276                    };
277                    kind.fmt(reasons, f)?
278                }
279            }
280        }
281        Ok(())
282    }
283}