turbo_tasks/
invalidation.rs

1use std::{
2    fmt::Display,
3    hash::Hash,
4    mem::replace,
5    sync::{Arc, Weak},
6};
7
8use anyhow::Result;
9use bincode::{
10    Decode, Encode,
11    de::Decoder,
12    enc::Encoder,
13    error::{DecodeError, EncodeError},
14    impl_borrow_decode,
15};
16use indexmap::map::Entry;
17use tokio::runtime::Handle;
18use turbo_dyn_eq_hash::{
19    DynEq, DynHash, impl_eq_for_dyn, impl_hash_for_dyn, impl_partial_eq_for_dyn,
20};
21
22use crate::{
23    FxIndexMap, FxIndexSet, TaskId, TurboTasksApi,
24    manager::{current_task_if_available, mark_invalidator, with_turbo_tasks},
25    trace::TraceRawVcs,
26    util::StaticOrArc,
27};
28
29/// Get an [`Invalidator`] that can be used to invalidate the current task
30/// based on external events.
31/// Returns `None` if called outside of a task context.
32pub fn get_invalidator() -> Option<Invalidator> {
33    if let Some(task) = current_task_if_available("turbo_tasks::get_invalidator()") {
34        mark_invalidator();
35
36        let handle = Handle::current();
37        Some(Invalidator {
38            task,
39            turbo_tasks: with_turbo_tasks(Arc::downgrade),
40            handle,
41        })
42    } else {
43        None
44    }
45}
46
/// A handle that invalidates one specific task in response to an external
/// event (obtained via [`get_invalidator`]).
///
/// The `invalidate*` methods consume the invalidator, enter the captured tokio
/// runtime, and do nothing if the turbo-tasks instance can no longer be
/// upgraded from the weak reference.
pub struct Invalidator {
    /// The task that gets invalidated.
    task: TaskId,
    /// Weak reference to the turbo-tasks API. It is only upgraded at
    /// invalidation time, so holding an `Invalidator` does not keep the
    /// instance alive.
    turbo_tasks: Weak<dyn TurboTasksApi>,
    /// Runtime handle that is entered while invalidating.
    handle: Handle,
}
52
53impl Invalidator {
54    pub fn invalidate(self) {
55        let Invalidator {
56            task,
57            turbo_tasks,
58            handle,
59        } = self;
60        let _guard = handle.enter();
61        if let Some(turbo_tasks) = turbo_tasks.upgrade() {
62            turbo_tasks.invalidate(task);
63        }
64    }
65
66    pub fn invalidate_with_reason<T: InvalidationReason>(self, reason: T) {
67        let Invalidator {
68            task,
69            turbo_tasks,
70            handle,
71        } = self;
72        let _guard = handle.enter();
73        if let Some(turbo_tasks) = turbo_tasks.upgrade() {
74            turbo_tasks.invalidate_with_reason(
75                task,
76                (Arc::new(reason) as Arc<dyn InvalidationReason>).into(),
77            );
78        }
79    }
80
81    pub fn invalidate_with_static_reason<T: InvalidationReason>(self, reason: &'static T) {
82        let Invalidator {
83            task,
84            turbo_tasks,
85            handle,
86        } = self;
87        let _guard = handle.enter();
88        if let Some(turbo_tasks) = turbo_tasks.upgrade() {
89            turbo_tasks
90                .invalidate_with_reason(task, (reason as &'static dyn InvalidationReason).into());
91        }
92    }
93}
94
95impl Hash for Invalidator {
96    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
97        self.task.hash(state);
98    }
99}
100
101impl PartialEq for Invalidator {
102    fn eq(&self, other: &Self) -> bool {
103        self.task == other.task
104    }
105}
106
107impl Eq for Invalidator {}
108
109impl TraceRawVcs for Invalidator {
110    fn trace_raw_vcs(&self, _context: &mut crate::trace::TraceRawVcsContext) {
111        // nothing here
112    }
113}
114
115impl Encode for Invalidator {
116    fn encode<E: Encoder>(&self, encoder: &mut E) -> Result<(), EncodeError> {
117        Encode::encode(&self.task, encoder)
118    }
119}
120
121impl<Context> Decode<Context> for Invalidator {
122    fn decode<D: Decoder<Context = Context>>(decoder: &mut D) -> Result<Self, DecodeError> {
123        Ok(Invalidator {
124            task: Decode::decode(decoder)?,
125            turbo_tasks: with_turbo_tasks(Arc::downgrade),
126            handle: tokio::runtime::Handle::current(),
127        })
128    }
129}
130
131impl_borrow_decode!(Invalidator);
132
133/// A user-facing reason why a task was invalidated. This should only be used
134/// for invalidation that were triggered by the user.
135///
136/// Reasons are deduplicated, so this need to implement [Eq] and [Hash]
137pub trait InvalidationReason: DynEq + DynHash + Display + Send + Sync + 'static {
138    fn kind(&self) -> Option<StaticOrArc<dyn InvalidationReasonKind>> {
139        None
140    }
141}
142
143/// Invalidation reason kind. This is used to merge multiple reasons of the same
144/// kind into a combined description.
145///
146/// Reason kinds are used a hash map key, so this need to implement [Eq] and
147/// [Hash]
148pub trait InvalidationReasonKind: DynEq + DynHash + Send + Sync + 'static {
149    /// Displays a description of multiple invalidation reasons of the same
150    /// kind. It is only called with two or more reasons.
151    fn fmt(
152        &self,
153        data: &FxIndexSet<StaticOrArc<dyn InvalidationReason>>,
154        f: &mut std::fmt::Formatter<'_>,
155    ) -> std::fmt::Result;
156}
157
// `PartialEq`/`Eq`/`Hash` for the trait objects, built on their `DynEq`/`DynHash`
// supertraits. Reasons are stored in `FxIndexSet`s and compared with `==`, and
// reason kinds are part of the hashed `MapKey`, so both need these impls.
impl_partial_eq_for_dyn!(dyn InvalidationReason);
impl_eq_for_dyn!(dyn InvalidationReason);
impl_hash_for_dyn!(dyn InvalidationReason);

impl_partial_eq_for_dyn!(dyn InvalidationReasonKind);
impl_eq_for_dyn!(dyn InvalidationReasonKind);
impl_hash_for_dyn!(dyn InvalidationReasonKind);
165
/// Key of an entry in [`InvalidationReasonSet`].
#[derive(PartialEq, Eq, Hash)]
enum MapKey {
    /// A reason without a kind. `unique_tag` comes from an incrementing
    /// counter, so every untyped reason gets its own entry.
    Untyped {
        unique_tag: usize,
    },
    /// All reasons that report this kind share a single entry.
    Typed {
        kind: StaticOrArc<dyn InvalidationReasonKind>,
    },
}
175
/// Value of an entry in [`InvalidationReasonSet`].
enum MapEntry {
    /// Exactly one reason. This is always the case for untyped entries.
    Single {
        reason: StaticOrArc<dyn InvalidationReason>,
    },
    /// Two or more distinct reasons of the same kind, in insertion order.
    Multiple {
        reasons: FxIndexSet<StaticOrArc<dyn InvalidationReason>>,
    },
}
184
/// A set of [InvalidationReason]s. It implements [Display] to get a readable
/// representation.
///
/// Reasons that report a kind are deduplicated and merged by kind during
/// insertion. Reasons without a kind are not deduplicated; each one keeps its
/// own entry.
#[derive(Default)]
pub struct InvalidationReasonSet {
    /// Tag assigned to the next untyped reason, making its `MapKey` unique.
    next_unique_tag: usize,
    // We track typed and untyped entries in the same map to keep the occurrence order of entries.
    map: FxIndexMap<MapKey, MapEntry>,
}
194
195impl InvalidationReasonSet {
196    pub(crate) fn insert(&mut self, reason: StaticOrArc<dyn InvalidationReason>) {
197        if let Some(kind) = reason.kind() {
198            let key = MapKey::Typed { kind };
199            match self.map.entry(key) {
200                Entry::Occupied(mut entry) => {
201                    let entry = &mut *entry.get_mut();
202                    match replace(
203                        entry,
204                        MapEntry::Multiple {
205                            reasons: FxIndexSet::default(),
206                        },
207                    ) {
208                        MapEntry::Single {
209                            reason: existing_reason,
210                        } => {
211                            if reason == existing_reason {
212                                *entry = MapEntry::Single {
213                                    reason: existing_reason,
214                                };
215                                return;
216                            }
217                            let mut reasons = FxIndexSet::default();
218                            reasons.insert(existing_reason);
219                            reasons.insert(reason);
220                            *entry = MapEntry::Multiple { reasons };
221                        }
222                        MapEntry::Multiple { mut reasons } => {
223                            reasons.insert(reason);
224                            *entry = MapEntry::Multiple { reasons };
225                        }
226                    }
227                }
228                Entry::Vacant(entry) => {
229                    entry.insert(MapEntry::Single { reason });
230                }
231            }
232        } else {
233            let key = MapKey::Untyped {
234                unique_tag: self.next_unique_tag,
235            };
236            self.next_unique_tag += 1;
237            self.map.insert(key, MapEntry::Single { reason });
238        }
239    }
240
241    pub fn is_empty(&self) -> bool {
242        self.map.is_empty()
243    }
244
245    pub fn len(&self) -> usize {
246        self.map.len()
247    }
248}
249
250impl Display for InvalidationReasonSet {
251    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
252        let count = self.map.len();
253        for (i, (key, entry)) in self.map.iter().enumerate() {
254            if i > 0 {
255                write!(f, ", ")?;
256                if i == count - 1 {
257                    write!(f, "and ")?;
258                }
259            }
260            match entry {
261                MapEntry::Single { reason } => {
262                    write!(f, "{reason}")?;
263                }
264                MapEntry::Multiple { reasons } => {
265                    let MapKey::Typed { kind } = key else {
266                        unreachable!("An untyped reason can't collect more than one reason");
267                    };
268                    kind.fmt(reasons, f)?
269                }
270            }
271        }
272        Ok(())
273    }
274}