turbo_tasks/
state.rs

1use std::{
2    any::type_name,
3    fmt::Debug,
4    mem::take,
5    ops::{Deref, DerefMut},
6};
7
8use auto_hash_map::AutoSet;
9use parking_lot::{Mutex, MutexGuard};
10use serde::{Deserialize, Serialize};
11use tracing::trace_span;
12
13use crate::{
14    Invalidator, OperationValue, SerializationInvalidator, get_invalidator, mark_session_dependent,
15    mark_stateful, trace::TraceRawVcs,
16};
17
/// The lock-protected payload shared by [`State`] and [`TransientState`]: the stored value
/// together with the invalidators registered by tracked reads.
#[derive(Serialize, Deserialize)]
struct StateInner<T> {
    /// The current value.
    value: T,
    /// Invalidators registered through `add_invalidator` (tracked `get` calls). They are drained
    /// and fired whenever the value is changed.
    invalidators: AutoSet<Invalidator>,
}
23
24impl<T> StateInner<T> {
25    pub fn new(value: T) -> Self {
26        Self {
27            value,
28            invalidators: AutoSet::new(),
29        }
30    }
31
32    pub fn add_invalidator(&mut self, invalidator: Invalidator) {
33        self.invalidators.insert(invalidator);
34    }
35
36    pub fn set_unconditionally(&mut self, value: T) {
37        self.value = value;
38        let _span = trace_span!("state value changed", value_type = type_name::<T>()).entered();
39        for invalidator in take(&mut self.invalidators) {
40            invalidator.invalidate();
41        }
42    }
43
44    pub fn update_conditionally(&mut self, update: impl FnOnce(&mut T) -> bool) -> bool {
45        if !update(&mut self.value) {
46            return false;
47        }
48        let _span = trace_span!("state value changed", value_type = type_name::<T>()).entered();
49        for invalidator in take(&mut self.invalidators) {
50            invalidator.invalidate();
51        }
52        true
53    }
54}
55
56impl<T: PartialEq> StateInner<T> {
57    pub fn set(&mut self, value: T) -> bool {
58        if self.value == value {
59            return false;
60        }
61        let _span = trace_span!("state value changed", value_type = type_name::<T>()).entered();
62        self.value = value;
63        for invalidator in take(&mut self.invalidators) {
64            invalidator.invalidate();
65        }
66        true
67    }
68}
69
/// A locked view of a state's value, returned by the `get` / `get_untracked` accessors of
/// [`State`] and [`TransientState`].
///
/// Dereferences to the stored value. Taking a mutable reference marks the value as mutated.
/// When a mutated `StateRef` is dropped, it fires all registered invalidators and, if present,
/// the serialization invalidator. The state's lock stays held for as long as this value lives.
pub struct StateRef<'a, T> {
    /// Invalidated on drop after a mutation. `None` for [`TransientState`], whose value is not
    /// serialized.
    serialization_invalidator: Option<&'a SerializationInvalidator>,
    /// Guard that keeps the state locked for the lifetime of this reference.
    inner: MutexGuard<'a, StateInner<T>>,
    /// Set by `deref_mut`; tells `drop` whether invalidation is required.
    mutated: bool,
}
75
76impl<T> Deref for StateRef<'_, T> {
77    type Target = T;
78
79    fn deref(&self) -> &Self::Target {
80        &self.inner.value
81    }
82}
83
84impl<T> DerefMut for StateRef<'_, T> {
85    fn deref_mut(&mut self) -> &mut Self::Target {
86        self.mutated = true;
87        &mut self.inner.value
88    }
89}
90
91impl<T> Drop for StateRef<'_, T> {
92    fn drop(&mut self) {
93        if self.mutated {
94            let _span = trace_span!("state value changed", value_type = type_name::<T>()).entered();
95            for invalidator in take(&mut self.inner.invalidators) {
96                invalidator.invalidate();
97            }
98            if let Some(serialization_invalidator) = self.serialization_invalidator {
99                serialization_invalidator.invalidate();
100            }
101        }
102    }
103}
104
/// **This API violates core assumptions of turbo-tasks, is believed to be unsound, and there's no
/// plan to fix it.** You should prefer to use [collectibles][crate::CollectiblesSource] instead of
/// state where at all possible. This API may be removed in the future.
///
/// An [internally-mutable] type, similar to [`RefCell`][std::cell::RefCell] or [`Mutex`], that can
/// be stored inside a [`VcValueType`].
///
/// **[`State`] should only be used with [`OperationVc`] and types that implement
/// [`OperationValue`]**.
///
/// Setting values inside a [`State`] bypasses the normal argument and return value tracking
/// that tracks child function calls and re-runs tasks until their values have settled. That
/// system is needed for [strong consistency]. [`OperationVc`] ensures that function calls are
/// reconnected with the parent/child call graph.
///
/// When reading a `State` with [`State::get`], the state itself (though not any values inside of
/// it) is marked as a dependency of the current task.
///
/// [internally-mutable]: https://doc.rust-lang.org/book/ch15-05-interior-mutability.html
/// [`VcValueType`]: crate::VcValueType
/// [strong consistency]: crate::OperationVc::read_strongly_consistent
/// [`OperationVc`]: crate::OperationVc
/// [`OperationValue`]: crate::OperationValue
#[derive(Serialize, Deserialize)]
pub struct State<T> {
    /// Obtained from `mark_stateful()` when the state is created. It is invalidated after every
    /// change to the value (presumably so that the persisted form gets refreshed; confirm against
    /// `SerializationInvalidator`).
    serialization_invalidator: SerializationInvalidator,
    /// The value and its registered invalidators, behind a lock.
    inner: Mutex<StateInner<T>>,
}
133
134impl<T: Debug> Debug for State<T> {
135    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
136        f.debug_struct("State")
137            .field("value", &self.inner.lock().value)
138            .finish()
139    }
140}
141
142impl<T: TraceRawVcs> TraceRawVcs for State<T> {
143    fn trace_raw_vcs(&self, trace_context: &mut crate::trace::TraceRawVcsContext) {
144        self.inner.lock().value.trace_raw_vcs(trace_context);
145    }
146}
147
148impl<T: Default + OperationValue> Default for State<T> {
149    fn default() -> Self {
150        // Need to be explicit to ensure marking as stateful.
151        Self::new(Default::default())
152    }
153}
154
155impl<T> PartialEq for State<T> {
156    fn eq(&self, _other: &Self) -> bool {
157        false
158    }
159}
160impl<T> Eq for State<T> {}
161
162impl<T> State<T> {
163    pub fn new(value: T) -> Self
164    where
165        T: OperationValue,
166    {
167        Self {
168            serialization_invalidator: mark_stateful(),
169            inner: Mutex::new(StateInner::new(value)),
170        }
171    }
172
173    /// Gets the current value of the state. The current task will be registered
174    /// as dependency of the state and will be invalidated when the state
175    /// changes.
176    pub fn get(&self) -> StateRef<'_, T> {
177        let invalidator = get_invalidator();
178        let mut inner = self.inner.lock();
179        if let Some(invalidator) = invalidator {
180            inner.add_invalidator(invalidator);
181        }
182        StateRef {
183            serialization_invalidator: Some(&self.serialization_invalidator),
184            inner,
185            mutated: false,
186        }
187    }
188
189    /// Gets the current value of the state. Untracked.
190    pub fn get_untracked(&self) -> StateRef<'_, T> {
191        let inner = self.inner.lock();
192        StateRef {
193            serialization_invalidator: Some(&self.serialization_invalidator),
194            inner,
195            mutated: false,
196        }
197    }
198
199    /// Sets the current state without comparing it with the old value. This
200    /// should only be used if one is sure that the value has changed.
201    pub fn set_unconditionally(&self, value: T) {
202        {
203            let mut inner = self.inner.lock();
204            inner.set_unconditionally(value);
205        }
206        self.serialization_invalidator.invalidate();
207    }
208
209    /// Updates the current state with the `update` function. The `update`
210    /// function need to return `true` when the value was modified. Exposing
211    /// the current value from the `update` function is not allowed and will
212    /// result in incorrect cache invalidation.
213    pub fn update_conditionally(&self, update: impl FnOnce(&mut T) -> bool) {
214        {
215            let mut inner = self.inner.lock();
216            if !inner.update_conditionally(update) {
217                return;
218            }
219        }
220        self.serialization_invalidator.invalidate();
221    }
222}
223
224impl<T: PartialEq> State<T> {
225    /// Update the current state when the `value` is different from the current
226    /// value. `T` must implement [PartialEq] for this to work.
227    pub fn set(&self, value: T) {
228        {
229            let mut inner = self.inner.lock();
230            if !inner.set(value) {
231                return;
232            }
233        }
234        self.serialization_invalidator.invalidate();
235    }
236}
237
/// Like [`State`], but its value is never persisted. It serializes as `()` and deserializes back
/// to an empty (`None`) state. Tracked reads also mark the current task as session dependent.
pub struct TransientState<T> {
    /// The optional value and its registered invalidators, behind a lock.
    inner: Mutex<StateInner<Option<T>>>,
}
241
242impl<T> Serialize for TransientState<T> {
243    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
244    where
245        S: serde::Serializer,
246    {
247        Serialize::serialize(&(), serializer)
248    }
249}
250
251impl<'de, T> Deserialize<'de> for TransientState<T> {
252    fn deserialize<D>(deserializer: D) -> Result<TransientState<T>, D::Error>
253    where
254        D: serde::Deserializer<'de>,
255    {
256        let () = Deserialize::deserialize(deserializer)?;
257        Ok(TransientState {
258            inner: Mutex::new(StateInner::new(Default::default())),
259        })
260    }
261}
262
263impl<T: Debug> Debug for TransientState<T> {
264    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
265        f.debug_struct("TransientState")
266            .field("value", &self.inner.lock().value)
267            .finish()
268    }
269}
270
271impl<T: TraceRawVcs> TraceRawVcs for TransientState<T> {
272    fn trace_raw_vcs(&self, trace_context: &mut crate::trace::TraceRawVcsContext) {
273        self.inner.lock().value.trace_raw_vcs(trace_context);
274    }
275}
276
277impl<T> Default for TransientState<T> {
278    fn default() -> Self {
279        // Need to be explicit to ensure marking as stateful.
280        Self::new()
281    }
282}
283
284impl<T> PartialEq for TransientState<T> {
285    fn eq(&self, _other: &Self) -> bool {
286        false
287    }
288}
289impl<T> Eq for TransientState<T> {}
290
291impl<T> TransientState<T> {
292    pub fn new() -> Self {
293        mark_stateful();
294        Self {
295            inner: Mutex::new(StateInner::new(None)),
296        }
297    }
298
299    /// Gets the current value of the state. The current task will be registered
300    /// as dependency of the state and will be invalidated when the state
301    /// changes.
302    pub fn get(&self) -> StateRef<'_, Option<T>> {
303        mark_session_dependent();
304        let invalidator = get_invalidator();
305        let mut inner = self.inner.lock();
306        if let Some(invalidator) = invalidator {
307            inner.add_invalidator(invalidator);
308        }
309        StateRef {
310            serialization_invalidator: None,
311            inner,
312            mutated: false,
313        }
314    }
315
316    /// Gets the current value of the state. Untracked.
317    pub fn get_untracked(&self) -> StateRef<'_, Option<T>> {
318        let inner = self.inner.lock();
319        StateRef {
320            serialization_invalidator: None,
321            inner,
322            mutated: false,
323        }
324    }
325
326    /// Sets the current state without comparing it with the old value. This
327    /// should only be used if one is sure that the value has changed.
328    pub fn set_unconditionally(&self, value: T) {
329        let mut inner = self.inner.lock();
330        inner.set_unconditionally(Some(value));
331    }
332
333    /// Unset the current value without comparing it with the old value. This
334    /// should only be used if one is sure that the value has changed.
335    pub fn unset_unconditionally(&self) {
336        let mut inner = self.inner.lock();
337        inner.set_unconditionally(None);
338    }
339
340    /// Updates the current state with the `update` function. The `update`
341    /// function need to return `true` when the value was modified. Exposing
342    /// the current value from the `update` function is not allowed and will
343    /// result in incorrect cache invalidation.
344    pub fn update_conditionally(&self, update: impl FnOnce(&mut Option<T>) -> bool) {
345        let mut inner = self.inner.lock();
346        inner.update_conditionally(update);
347    }
348}
349
350impl<T: PartialEq> TransientState<T> {
351    /// Update the current state when the `value` is different from the current
352    /// value. `T` must implement [PartialEq] for this to work.
353    pub fn set(&self, value: T) {
354        let mut inner = self.inner.lock();
355        inner.set(Some(value));
356    }
357
358    /// Unset the current value.
359    pub fn unset(&self) {
360        let mut inner = self.inner.lock();
361        inner.set(None);
362    }
363}