turbo_tasks/
state.rs

1use std::{
2    any::type_name,
3    fmt::Debug,
4    mem::take,
5    ops::{Deref, DerefMut},
6};
7
8use auto_hash_map::AutoSet;
9use bincode::{Decode, Encode};
10use parking_lot::{Mutex, MutexGuard};
11use tracing::trace_span;
12
13use crate::{
14    Invalidator, OperationValue, SerializationInvalidator, get_invalidator,
15    get_serialization_invalidator, manager::with_turbo_tasks, mark_session_dependent,
16    trace::TraceRawVcs,
17};
18
19#[derive(Encode, Decode)]
20struct StateInner<T> {
21    value: T,
22    invalidators: AutoSet<Invalidator>,
23}
24
25impl<T> StateInner<T> {
26    pub fn new(value: T) -> Self {
27        Self {
28            value,
29            invalidators: AutoSet::new(),
30        }
31    }
32
33    pub fn add_invalidator(&mut self, invalidator: Invalidator) {
34        self.invalidators.insert(invalidator);
35    }
36
37    pub fn set_unconditionally(&mut self, value: T) {
38        self.value = value;
39        let _span = trace_span!("state value changed", value_type = type_name::<T>()).entered();
40        let invalidators = take(&mut self.invalidators);
41        if !invalidators.is_empty() {
42            with_turbo_tasks(|tt| {
43                for invalidator in invalidators {
44                    invalidator.invalidate(&**tt);
45                }
46            });
47        }
48    }
49
50    pub fn update_conditionally(&mut self, update: impl FnOnce(&mut T) -> bool) -> bool {
51        if !update(&mut self.value) {
52            return false;
53        }
54        let _span = trace_span!("state value changed", value_type = type_name::<T>()).entered();
55        let invalidators = take(&mut self.invalidators);
56        if !invalidators.is_empty() {
57            with_turbo_tasks(|tt| {
58                for invalidator in invalidators {
59                    invalidator.invalidate(&**tt);
60                }
61            });
62        }
63        true
64    }
65}
66
67impl<T: PartialEq> StateInner<T> {
68    pub fn set(&mut self, value: T) -> bool {
69        if self.value == value {
70            return false;
71        }
72        let _span = trace_span!("state value changed", value_type = type_name::<T>()).entered();
73        self.value = value;
74        let invalidators = take(&mut self.invalidators);
75        if !invalidators.is_empty() {
76            with_turbo_tasks(|tt| {
77                for invalidator in invalidators {
78                    invalidator.invalidate(&**tt);
79                }
80            });
81        }
82        true
83    }
84}
85
86pub struct StateRef<'a, T> {
87    serialization_invalidator: Option<&'a SerializationInvalidator>,
88    inner: MutexGuard<'a, StateInner<T>>,
89    mutated: bool,
90}
91
92impl<T> Deref for StateRef<'_, T> {
93    type Target = T;
94
95    fn deref(&self) -> &Self::Target {
96        &self.inner.value
97    }
98}
99
100impl<T> DerefMut for StateRef<'_, T> {
101    fn deref_mut(&mut self) -> &mut Self::Target {
102        self.mutated = true;
103        &mut self.inner.value
104    }
105}
106
107impl<T> Drop for StateRef<'_, T> {
108    fn drop(&mut self) {
109        if self.mutated {
110            let _span = trace_span!("state value changed", value_type = type_name::<T>()).entered();
111            let invalidators = take(&mut self.inner.invalidators);
112            if !invalidators.is_empty() {
113                with_turbo_tasks(|tt| {
114                    for invalidator in invalidators {
115                        invalidator.invalidate(&**tt);
116                    }
117                });
118            }
119            if let Some(serialization_invalidator) = self.serialization_invalidator {
120                serialization_invalidator.invalidate();
121            }
122        }
123    }
124}
125
126mod parking_lot_mutex_bincode {
127    use bincode::{
128        BorrowDecode,
129        de::{BorrowDecoder, Decoder},
130        enc::Encoder,
131        error::{DecodeError, EncodeError},
132    };
133
134    use super::*;
135
136    pub fn encode<T: Encode, E: Encoder>(
137        mutex: &Mutex<T>,
138        encoder: &mut E,
139    ) -> Result<(), EncodeError> {
140        mutex.lock().encode(encoder)
141    }
142
143    pub fn decode<Context, T: Decode<Context>, D: Decoder<Context = Context>>(
144        decoder: &mut D,
145    ) -> Result<Mutex<T>, DecodeError> {
146        Ok(Mutex::new(T::decode(decoder)?))
147    }
148
149    pub fn borrow_decode<
150        'de,
151        Context,
152        T: BorrowDecode<'de, Context>,
153        D: BorrowDecoder<'de, Context = Context>,
154    >(
155        decoder: &mut D,
156    ) -> Result<Mutex<T>, DecodeError> {
157        Ok(Mutex::new(T::borrow_decode(decoder)?))
158    }
159}
160
161/// **This API violates core assumption of turbo-tasks, is believed to be unsound, and there's no
162/// plan fix it.** You should prefer to use [collectibles][crate::CollectiblesSource] instead of
163/// state where at all possible. This API may be removed in the future.
164///
165/// An [internally-mutable] type, similar to [`RefCell`][std::cell::RefCell] or [`Mutex`] that can
166/// be stored inside a [`VcValueType`].
167///
168/// **[`State`] should only be used with [`OperationVc`] and types that implement
169/// [`OperationValue`]**.
170///
171/// Setting values inside a [`State`] bypasses the normal argument and return value tracking
172/// that's tracks child function calls and re-runs tasks until their values settled. That system is
173/// needed for [strong consistency]. [`OperationVc`] ensures that function calls are reconnected
174/// with the parent/child call graph.
175///
176/// When reading a `State` with [`State::get`], the state itself (though not any values inside of
177/// it) is marked as a dependency of the current task.
178///
179/// [internally-mutable]: https://doc.rust-lang.org/book/ch15-05-interior-mutability.html
180/// [`VcValueType`]: crate::VcValueType
181/// [strong consistency]: crate::OperationVc::read_strongly_consistent
182/// [`OperationVc`]: crate::OperationVc
183/// [`OperationValue`]: crate::OperationValue
184#[derive(Encode, Decode)]
185pub struct State<T> {
186    serialization_invalidator: SerializationInvalidator,
187    #[bincode(with = "parking_lot_mutex_bincode")]
188    inner: Mutex<StateInner<T>>,
189}
190
191impl<T: Debug> Debug for State<T> {
192    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
193        f.debug_struct("State")
194            .field("value", &self.inner.lock().value)
195            .finish()
196    }
197}
198
199impl<T: TraceRawVcs> TraceRawVcs for State<T> {
200    fn trace_raw_vcs(&self, trace_context: &mut crate::trace::TraceRawVcsContext) {
201        self.inner.lock().value.trace_raw_vcs(trace_context);
202    }
203}
204
205impl<T: Default + OperationValue> Default for State<T> {
206    fn default() -> Self {
207        // Need to be explicit to ensure marking as stateful.
208        Self::new(Default::default())
209    }
210}
211
212impl<T> PartialEq for State<T> {
213    fn eq(&self, _other: &Self) -> bool {
214        false
215    }
216}
217impl<T> Eq for State<T> {}
218
219impl<T> State<T> {
220    pub fn new(value: T) -> Self
221    where
222        T: OperationValue,
223    {
224        Self {
225            serialization_invalidator: get_serialization_invalidator(),
226            inner: Mutex::new(StateInner::new(value)),
227        }
228    }
229
230    /// Gets the current value of the state. The current task will be registered
231    /// as dependency of the state and will be invalidated when the state
232    /// changes.
233    pub fn get(&self) -> StateRef<'_, T> {
234        let invalidator = get_invalidator();
235        let mut inner = self.inner.lock();
236        if let Some(invalidator) = invalidator {
237            inner.add_invalidator(invalidator);
238        }
239        StateRef {
240            serialization_invalidator: Some(&self.serialization_invalidator),
241            inner,
242            mutated: false,
243        }
244    }
245
246    /// Gets the current value of the state. Untracked.
247    pub fn get_untracked(&self) -> StateRef<'_, T> {
248        let inner = self.inner.lock();
249        StateRef {
250            serialization_invalidator: Some(&self.serialization_invalidator),
251            inner,
252            mutated: false,
253        }
254    }
255
256    /// Sets the current state without comparing it with the old value. This
257    /// should only be used if one is sure that the value has changed.
258    pub fn set_unconditionally(&self, value: T) {
259        {
260            let mut inner = self.inner.lock();
261            inner.set_unconditionally(value);
262        }
263        self.serialization_invalidator.invalidate();
264    }
265
266    /// Updates the current state with the `update` function. The `update`
267    /// function need to return `true` when the value was modified. Exposing
268    /// the current value from the `update` function is not allowed and will
269    /// result in incorrect cache invalidation.
270    pub fn update_conditionally(&self, update: impl FnOnce(&mut T) -> bool) {
271        {
272            let mut inner = self.inner.lock();
273            if !inner.update_conditionally(update) {
274                return;
275            }
276        }
277        self.serialization_invalidator.invalidate();
278    }
279}
280
281impl<T: PartialEq> State<T> {
282    /// Update the current state when the `value` is different from the current
283    /// value. `T` must implement [PartialEq] for this to work.
284    pub fn set(&self, value: T) {
285        {
286            let mut inner = self.inner.lock();
287            if !inner.set(value) {
288                return;
289            }
290        }
291        self.serialization_invalidator.invalidate();
292    }
293}
294
295#[derive(Encode, Decode)]
296#[bincode(bounds = "")]
297pub struct TransientState<T> {
298    #[bincode(skip, default = "default_transient_state_inner")]
299    inner: Mutex<StateInner<Option<T>>>,
300}
301
302fn default_transient_state_inner<T>() -> Mutex<StateInner<Option<T>>> {
303    Mutex::new(StateInner::new(None))
304}
305
306impl<T: Debug> Debug for TransientState<T> {
307    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
308        f.debug_struct("TransientState")
309            .field("value", &self.inner.lock().value)
310            .finish()
311    }
312}
313
314impl<T: TraceRawVcs> TraceRawVcs for TransientState<T> {
315    fn trace_raw_vcs(&self, trace_context: &mut crate::trace::TraceRawVcsContext) {
316        self.inner.lock().value.trace_raw_vcs(trace_context);
317    }
318}
319
320impl<T> Default for TransientState<T> {
321    fn default() -> Self {
322        // Need to be explicit to ensure marking as stateful.
323        Self::new()
324    }
325}
326
327impl<T> PartialEq for TransientState<T> {
328    fn eq(&self, _other: &Self) -> bool {
329        false
330    }
331}
332impl<T> Eq for TransientState<T> {}
333
334impl<T> TransientState<T> {
335    pub fn new() -> Self {
336        Self {
337            inner: Mutex::new(StateInner::new(None)),
338        }
339    }
340
341    /// Gets the current value of the state. The current task will be registered
342    /// as dependency of the state and will be invalidated when the state
343    /// changes.
344    pub fn get(&self) -> StateRef<'_, Option<T>> {
345        mark_session_dependent();
346        let invalidator = get_invalidator();
347        let mut inner = self.inner.lock();
348        if let Some(invalidator) = invalidator {
349            inner.add_invalidator(invalidator);
350        }
351        StateRef {
352            serialization_invalidator: None,
353            inner,
354            mutated: false,
355        }
356    }
357
358    /// Gets the current value of the state. Untracked.
359    pub fn get_untracked(&self) -> StateRef<'_, Option<T>> {
360        let inner = self.inner.lock();
361        StateRef {
362            serialization_invalidator: None,
363            inner,
364            mutated: false,
365        }
366    }
367
368    /// Sets the current state without comparing it with the old value. This
369    /// should only be used if one is sure that the value has changed.
370    pub fn set_unconditionally(&self, value: T) {
371        let mut inner = self.inner.lock();
372        inner.set_unconditionally(Some(value));
373    }
374
375    /// Updates the current state with the `update` function. The `update`
376    /// function need to return `true` when the value was modified. Exposing
377    /// the current value from the `update` function is not allowed and will
378    /// result in incorrect cache invalidation.
379    pub fn update_conditionally(&self, update: impl FnOnce(&mut Option<T>) -> bool) {
380        let mut inner = self.inner.lock();
381        inner.update_conditionally(update);
382    }
383}
384
385impl<T: PartialEq> TransientState<T> {
386    /// Update the current state when the `value` is different from the current
387    /// value. `T` must implement [PartialEq] for this to work.
388    pub fn set(&self, value: T) {
389        let mut inner = self.inner.lock();
390        inner.set(Some(value));
391    }
392
393    /// Unset the current value.
394    pub fn unset(&self) {
395        let mut inner = self.inner.lock();
396        inner.set(None);
397    }
398}