turbo_tasks/
state.rs

1use std::{
2    fmt::Debug,
3    mem::take,
4    ops::{Deref, DerefMut},
5};
6
7use auto_hash_map::AutoSet;
8use parking_lot::{Mutex, MutexGuard};
9use serde::{Deserialize, Serialize};
10
11use crate::{
12    Invalidator, OperationValue, SerializationInvalidator, get_invalidator, mark_session_dependent,
13    mark_stateful, trace::TraceRawVcs,
14};
15
16#[derive(Serialize, Deserialize)]
17struct StateInner<T> {
18    value: T,
19    invalidators: AutoSet<Invalidator>,
20}
21
22impl<T> StateInner<T> {
23    pub fn new(value: T) -> Self {
24        Self {
25            value,
26            invalidators: AutoSet::new(),
27        }
28    }
29
30    pub fn add_invalidator(&mut self, invalidator: Invalidator) {
31        self.invalidators.insert(invalidator);
32    }
33
34    pub fn set_unconditionally(&mut self, value: T) {
35        self.value = value;
36        for invalidator in take(&mut self.invalidators) {
37            invalidator.invalidate();
38        }
39    }
40
41    pub fn update_conditionally(&mut self, update: impl FnOnce(&mut T) -> bool) -> bool {
42        if !update(&mut self.value) {
43            return false;
44        }
45        for invalidator in take(&mut self.invalidators) {
46            invalidator.invalidate();
47        }
48        true
49    }
50}
51
52impl<T: PartialEq> StateInner<T> {
53    pub fn set(&mut self, value: T) -> bool {
54        if self.value == value {
55            return false;
56        }
57        self.value = value;
58        for invalidator in take(&mut self.invalidators) {
59            invalidator.invalidate();
60        }
61        true
62    }
63}
64
65pub struct StateRef<'a, T> {
66    serialization_invalidator: Option<&'a SerializationInvalidator>,
67    inner: MutexGuard<'a, StateInner<T>>,
68    mutated: bool,
69}
70
71impl<T> Deref for StateRef<'_, T> {
72    type Target = T;
73
74    fn deref(&self) -> &Self::Target {
75        &self.inner.value
76    }
77}
78
79impl<T> DerefMut for StateRef<'_, T> {
80    fn deref_mut(&mut self) -> &mut Self::Target {
81        self.mutated = true;
82        &mut self.inner.value
83    }
84}
85
86impl<T> Drop for StateRef<'_, T> {
87    fn drop(&mut self) {
88        if self.mutated {
89            for invalidator in take(&mut self.inner.invalidators) {
90                invalidator.invalidate();
91            }
92            if let Some(serialization_invalidator) = self.serialization_invalidator {
93                serialization_invalidator.invalidate();
94            }
95        }
96    }
97}
98
/// **This API violates a core assumption of turbo-tasks, is believed to be unsound, and there's
/// no plan to fix it.** You should prefer to use [collectibles][crate::CollectiblesSource]
/// instead of state where at all possible. This API may be removed in the future.
///
/// An [internally-mutable] type, similar to [`RefCell`][std::cell::RefCell] or [`Mutex`], that
/// can be stored inside a [`VcValueType`].
///
/// **[`State`] should only be used with [`OperationVc`] and types that implement
/// [`OperationValue`]**.
///
/// Setting values inside a [`State`] bypasses the normal argument and return value tracking
/// that tracks child function calls and re-runs tasks until their values have settled. That
/// system is needed for [strong consistency]. [`OperationVc`] ensures that function calls are
/// reconnected with the parent/child call graph.
///
/// When reading a `State` with [`State::get`], the state itself (though not any values inside of
/// it) is marked as a dependency of the current task.
///
/// A change to the value notifies the tasks that read it with [`State::get`] and fires the
/// state's serialization invalidator. A change means:
/// - `set` with a different value;
/// - `set_unconditionally`;
/// - `update_conditionally` whose closure returns `true`;
/// - a mutable borrow through the returned guard.
///
/// [internally-mutable]: https://doc.rust-lang.org/book/ch15-05-interior-mutability.html
/// [`VcValueType`]: crate::VcValueType
/// [strong consistency]: crate::OperationVc::read_strongly_consistent
/// [`OperationVc`]: crate::OperationVc
/// [`OperationValue`]: crate::OperationValue
#[derive(Serialize, Deserialize)]
pub struct State<T> {
    // Obtained from `mark_stateful()` in `State::new` and invalidated whenever the value
    // changes. Presumably this flags the owning data for re-serialization; confirm.
    serialization_invalidator: SerializationInvalidator,
    // The current value plus the invalidators of tasks that read it via `get`.
    inner: Mutex<StateInner<T>>,
}
127
128impl<T: Debug> Debug for State<T> {
129    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
130        f.debug_struct("State")
131            .field("value", &self.inner.lock().value)
132            .finish()
133    }
134}
135
136impl<T: TraceRawVcs> TraceRawVcs for State<T> {
137    fn trace_raw_vcs(&self, trace_context: &mut crate::trace::TraceRawVcsContext) {
138        self.inner.lock().value.trace_raw_vcs(trace_context);
139    }
140}
141
142impl<T: Default + OperationValue> Default for State<T> {
143    fn default() -> Self {
144        // Need to be explicit to ensure marking as stateful.
145        Self::new(Default::default())
146    }
147}
148
149impl<T> PartialEq for State<T> {
150    fn eq(&self, _other: &Self) -> bool {
151        false
152    }
153}
154impl<T> Eq for State<T> {}
155
156impl<T> State<T> {
157    pub fn new(value: T) -> Self
158    where
159        T: OperationValue,
160    {
161        Self {
162            serialization_invalidator: mark_stateful(),
163            inner: Mutex::new(StateInner::new(value)),
164        }
165    }
166
167    /// Gets the current value of the state. The current task will be registered
168    /// as dependency of the state and will be invalidated when the state
169    /// changes.
170    pub fn get(&self) -> StateRef<'_, T> {
171        let invalidator = get_invalidator();
172        let mut inner = self.inner.lock();
173        inner.add_invalidator(invalidator);
174        StateRef {
175            serialization_invalidator: Some(&self.serialization_invalidator),
176            inner,
177            mutated: false,
178        }
179    }
180
181    /// Gets the current value of the state. Untracked.
182    pub fn get_untracked(&self) -> StateRef<'_, T> {
183        let inner = self.inner.lock();
184        StateRef {
185            serialization_invalidator: Some(&self.serialization_invalidator),
186            inner,
187            mutated: false,
188        }
189    }
190
191    /// Sets the current state without comparing it with the old value. This
192    /// should only be used if one is sure that the value has changed.
193    pub fn set_unconditionally(&self, value: T) {
194        {
195            let mut inner = self.inner.lock();
196            inner.set_unconditionally(value);
197        }
198        self.serialization_invalidator.invalidate();
199    }
200
201    /// Updates the current state with the `update` function. The `update`
202    /// function need to return `true` when the value was modified. Exposing
203    /// the current value from the `update` function is not allowed and will
204    /// result in incorrect cache invalidation.
205    pub fn update_conditionally(&self, update: impl FnOnce(&mut T) -> bool) {
206        {
207            let mut inner = self.inner.lock();
208            if !inner.update_conditionally(update) {
209                return;
210            }
211        }
212        self.serialization_invalidator.invalidate();
213    }
214}
215
216impl<T: PartialEq> State<T> {
217    /// Update the current state when the `value` is different from the current
218    /// value. `T` must implement [PartialEq] for this to work.
219    pub fn set(&self, value: T) {
220        {
221            let mut inner = self.inner.lock();
222            if !inner.set(value) {
223                return;
224            }
225        }
226        self.serialization_invalidator.invalidate();
227    }
228}
229
230pub struct TransientState<T> {
231    inner: Mutex<StateInner<Option<T>>>,
232}
233
234impl<T> Serialize for TransientState<T> {
235    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
236    where
237        S: serde::Serializer,
238    {
239        Serialize::serialize(&(), serializer)
240    }
241}
242
243impl<'de, T> Deserialize<'de> for TransientState<T> {
244    fn deserialize<D>(deserializer: D) -> Result<TransientState<T>, D::Error>
245    where
246        D: serde::Deserializer<'de>,
247    {
248        let () = Deserialize::deserialize(deserializer)?;
249        Ok(TransientState {
250            inner: Mutex::new(StateInner::new(Default::default())),
251        })
252    }
253}
254
255impl<T: Debug> Debug for TransientState<T> {
256    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
257        f.debug_struct("TransientState")
258            .field("value", &self.inner.lock().value)
259            .finish()
260    }
261}
262
263impl<T: TraceRawVcs> TraceRawVcs for TransientState<T> {
264    fn trace_raw_vcs(&self, trace_context: &mut crate::trace::TraceRawVcsContext) {
265        self.inner.lock().value.trace_raw_vcs(trace_context);
266    }
267}
268
269impl<T> Default for TransientState<T> {
270    fn default() -> Self {
271        // Need to be explicit to ensure marking as stateful.
272        Self::new()
273    }
274}
275
276impl<T> PartialEq for TransientState<T> {
277    fn eq(&self, _other: &Self) -> bool {
278        false
279    }
280}
281impl<T> Eq for TransientState<T> {}
282
283impl<T> TransientState<T> {
284    pub fn new() -> Self {
285        mark_stateful();
286        Self {
287            inner: Mutex::new(StateInner::new(None)),
288        }
289    }
290
291    /// Gets the current value of the state. The current task will be registered
292    /// as dependency of the state and will be invalidated when the state
293    /// changes.
294    pub fn get(&self) -> StateRef<'_, Option<T>> {
295        mark_session_dependent();
296        let invalidator = get_invalidator();
297        let mut inner = self.inner.lock();
298        inner.add_invalidator(invalidator);
299        StateRef {
300            serialization_invalidator: None,
301            inner,
302            mutated: false,
303        }
304    }
305
306    /// Gets the current value of the state. Untracked.
307    pub fn get_untracked(&self) -> StateRef<'_, Option<T>> {
308        let inner = self.inner.lock();
309        StateRef {
310            serialization_invalidator: None,
311            inner,
312            mutated: false,
313        }
314    }
315
316    /// Sets the current state without comparing it with the old value. This
317    /// should only be used if one is sure that the value has changed.
318    pub fn set_unconditionally(&self, value: T) {
319        let mut inner = self.inner.lock();
320        inner.set_unconditionally(Some(value));
321    }
322
323    /// Unset the current value without comparing it with the old value. This
324    /// should only be used if one is sure that the value has changed.
325    pub fn unset_unconditionally(&self) {
326        let mut inner = self.inner.lock();
327        inner.set_unconditionally(None);
328    }
329
330    /// Updates the current state with the `update` function. The `update`
331    /// function need to return `true` when the value was modified. Exposing
332    /// the current value from the `update` function is not allowed and will
333    /// result in incorrect cache invalidation.
334    pub fn update_conditionally(&self, update: impl FnOnce(&mut Option<T>) -> bool) {
335        let mut inner = self.inner.lock();
336        inner.update_conditionally(update);
337    }
338}
339
340impl<T: PartialEq> TransientState<T> {
341    /// Update the current state when the `value` is different from the current
342    /// value. `T` must implement [PartialEq] for this to work.
343    pub fn set(&self, value: T) {
344        let mut inner = self.inner.lock();
345        inner.set(Some(value));
346    }
347
348    /// Unset the current value.
349    pub fn unset(&self) {
350        let mut inner = self.inner.lock();
351        inner.set(None);
352    }
353}