turbo_tasks/
state.rs

1use std::{
2    any::type_name,
3    fmt::Debug,
4    mem::take,
5    ops::{Deref, DerefMut},
6};
7
8use auto_hash_map::AutoSet;
9use bincode::{Decode, Encode};
10use parking_lot::{Mutex, MutexGuard};
11use tracing::trace_span;
12
13use crate::{
14    Invalidator, OperationValue, SerializationInvalidator, get_invalidator, mark_session_dependent,
15    mark_stateful, trace::TraceRawVcs,
16};
17
/// The value of a state cell together with the invalidators of the tasks that read it.
///
/// Tracked reads (`State::get` / `TransientState::get`) add invalidators to `invalidators`; every
/// path that changes `value` drains the set with `take` and fires each entry, so a reader is only
/// invalidated once per change until it reads (and re-registers) again.
#[derive(Encode, Decode)]
struct StateInner<T> {
    /// The current value.
    value: T,
    /// Invalidators registered by tracked reads since the last change.
    invalidators: AutoSet<Invalidator>,
}
23
24impl<T> StateInner<T> {
25    pub fn new(value: T) -> Self {
26        Self {
27            value,
28            invalidators: AutoSet::new(),
29        }
30    }
31
32    pub fn add_invalidator(&mut self, invalidator: Invalidator) {
33        self.invalidators.insert(invalidator);
34    }
35
36    pub fn set_unconditionally(&mut self, value: T) {
37        self.value = value;
38        let _span = trace_span!("state value changed", value_type = type_name::<T>()).entered();
39        for invalidator in take(&mut self.invalidators) {
40            invalidator.invalidate();
41        }
42    }
43
44    pub fn update_conditionally(&mut self, update: impl FnOnce(&mut T) -> bool) -> bool {
45        if !update(&mut self.value) {
46            return false;
47        }
48        let _span = trace_span!("state value changed", value_type = type_name::<T>()).entered();
49        for invalidator in take(&mut self.invalidators) {
50            invalidator.invalidate();
51        }
52        true
53    }
54}
55
56impl<T: PartialEq> StateInner<T> {
57    pub fn set(&mut self, value: T) -> bool {
58        if self.value == value {
59            return false;
60        }
61        let _span = trace_span!("state value changed", value_type = type_name::<T>()).entered();
62        self.value = value;
63        for invalidator in take(&mut self.invalidators) {
64            invalidator.invalidate();
65        }
66        true
67    }
68}
69
/// A locked view of a state's value, returned by the `get` / `get_untracked` methods.
///
/// The state's mutex stays locked for as long as this view is alive. Taking a mutable reference
/// through `DerefMut` marks the view as mutated; dropping a mutated view fires the registered
/// invalidators and, if present, the serialization invalidator.
pub struct StateRef<'a, T> {
    /// Invalidated on drop after a mutation. `None` for `TransientState`, whose value is skipped
    /// during serialization.
    serialization_invalidator: Option<&'a SerializationInvalidator>,
    /// Guard over the state's inner data; released when this view is dropped.
    inner: MutexGuard<'a, StateInner<T>>,
    /// Set by `deref_mut`; makes `drop` perform invalidation.
    mutated: bool,
}
75
76impl<T> Deref for StateRef<'_, T> {
77    type Target = T;
78
79    fn deref(&self) -> &Self::Target {
80        &self.inner.value
81    }
82}
83
84impl<T> DerefMut for StateRef<'_, T> {
85    fn deref_mut(&mut self) -> &mut Self::Target {
86        self.mutated = true;
87        &mut self.inner.value
88    }
89}
90
91impl<T> Drop for StateRef<'_, T> {
92    fn drop(&mut self) {
93        if self.mutated {
94            let _span = trace_span!("state value changed", value_type = type_name::<T>()).entered();
95            for invalidator in take(&mut self.inner.invalidators) {
96                invalidator.invalidate();
97            }
98            if let Some(serialization_invalidator) = self.serialization_invalidator {
99                serialization_invalidator.invalidate();
100            }
101        }
102    }
103}
104
105mod parking_lot_mutex_bincode {
106    use bincode::{
107        BorrowDecode,
108        de::{BorrowDecoder, Decoder},
109        enc::Encoder,
110        error::{DecodeError, EncodeError},
111    };
112
113    use super::*;
114
115    pub fn encode<T: Encode, E: Encoder>(
116        mutex: &Mutex<T>,
117        encoder: &mut E,
118    ) -> Result<(), EncodeError> {
119        mutex.lock().encode(encoder)
120    }
121
122    pub fn decode<Context, T: Decode<Context>, D: Decoder<Context = Context>>(
123        decoder: &mut D,
124    ) -> Result<Mutex<T>, DecodeError> {
125        Ok(Mutex::new(T::decode(decoder)?))
126    }
127
128    pub fn borrow_decode<
129        'de,
130        Context,
131        T: BorrowDecode<'de, Context>,
132        D: BorrowDecoder<'de, Context = Context>,
133    >(
134        decoder: &mut D,
135    ) -> Result<Mutex<T>, DecodeError> {
136        Ok(Mutex::new(T::borrow_decode(decoder)?))
137    }
138}
139
/// **This API violates core assumptions of turbo-tasks, is believed to be unsound, and there's no
/// plan to fix it.** You should prefer to use [collectibles][crate::CollectiblesSource] instead of
/// state where at all possible. This API may be removed in the future.
///
/// An [internally-mutable] type, similar to [`RefCell`][std::cell::RefCell] or [`Mutex`], that
/// can be stored inside a [`VcValueType`].
///
/// **[`State`] should only be used with [`OperationVc`] and types that implement
/// [`OperationValue`]**.
///
/// Setting values inside a [`State`] bypasses the normal argument and return value tracking
/// that tracks child function calls and re-runs tasks until their values have settled. That
/// system is needed for [strong consistency]. [`OperationVc`] ensures that function calls are
/// reconnected with the parent/child call graph.
///
/// When reading a `State` with [`State::get`], the state itself (though not any values inside of
/// it) is marked as a dependency of the current task.
///
/// [internally-mutable]: https://doc.rust-lang.org/book/ch15-05-interior-mutability.html
/// [`VcValueType`]: crate::VcValueType
/// [strong consistency]: crate::OperationVc::read_strongly_consistent
/// [`OperationVc`]: crate::OperationVc
/// [`OperationValue`]: crate::OperationValue
#[derive(Encode, Decode)]
pub struct State<T> {
    /// Obtained from `mark_stateful()` in `State::new`; invalidated after every change to the
    /// value (`set`, `set_unconditionally`, `update_conditionally`, or dropping a mutated
    /// `StateRef`).
    serialization_invalidator: SerializationInvalidator,
    /// The value and its registered readers. Serialized via `parking_lot_mutex_bincode`, which
    /// encodes the contents of the mutex.
    #[bincode(with = "parking_lot_mutex_bincode")]
    inner: Mutex<StateInner<T>>,
}
169
170impl<T: Debug> Debug for State<T> {
171    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
172        f.debug_struct("State")
173            .field("value", &self.inner.lock().value)
174            .finish()
175    }
176}
177
178impl<T: TraceRawVcs> TraceRawVcs for State<T> {
179    fn trace_raw_vcs(&self, trace_context: &mut crate::trace::TraceRawVcsContext) {
180        self.inner.lock().value.trace_raw_vcs(trace_context);
181    }
182}
183
184impl<T: Default + OperationValue> Default for State<T> {
185    fn default() -> Self {
186        // Need to be explicit to ensure marking as stateful.
187        Self::new(Default::default())
188    }
189}
190
191impl<T> PartialEq for State<T> {
192    fn eq(&self, _other: &Self) -> bool {
193        false
194    }
195}
196impl<T> Eq for State<T> {}
197
198impl<T> State<T> {
199    pub fn new(value: T) -> Self
200    where
201        T: OperationValue,
202    {
203        Self {
204            serialization_invalidator: mark_stateful(),
205            inner: Mutex::new(StateInner::new(value)),
206        }
207    }
208
209    /// Gets the current value of the state. The current task will be registered
210    /// as dependency of the state and will be invalidated when the state
211    /// changes.
212    pub fn get(&self) -> StateRef<'_, T> {
213        let invalidator = get_invalidator();
214        let mut inner = self.inner.lock();
215        if let Some(invalidator) = invalidator {
216            inner.add_invalidator(invalidator);
217        }
218        StateRef {
219            serialization_invalidator: Some(&self.serialization_invalidator),
220            inner,
221            mutated: false,
222        }
223    }
224
225    /// Gets the current value of the state. Untracked.
226    pub fn get_untracked(&self) -> StateRef<'_, T> {
227        let inner = self.inner.lock();
228        StateRef {
229            serialization_invalidator: Some(&self.serialization_invalidator),
230            inner,
231            mutated: false,
232        }
233    }
234
235    /// Sets the current state without comparing it with the old value. This
236    /// should only be used if one is sure that the value has changed.
237    pub fn set_unconditionally(&self, value: T) {
238        {
239            let mut inner = self.inner.lock();
240            inner.set_unconditionally(value);
241        }
242        self.serialization_invalidator.invalidate();
243    }
244
245    /// Updates the current state with the `update` function. The `update`
246    /// function need to return `true` when the value was modified. Exposing
247    /// the current value from the `update` function is not allowed and will
248    /// result in incorrect cache invalidation.
249    pub fn update_conditionally(&self, update: impl FnOnce(&mut T) -> bool) {
250        {
251            let mut inner = self.inner.lock();
252            if !inner.update_conditionally(update) {
253                return;
254            }
255        }
256        self.serialization_invalidator.invalidate();
257    }
258}
259
260impl<T: PartialEq> State<T> {
261    /// Update the current state when the `value` is different from the current
262    /// value. `T` must implement [PartialEq] for this to work.
263    pub fn set(&self, value: T) {
264        {
265            let mut inner = self.inner.lock();
266            if !inner.set(value) {
267                return;
268            }
269        }
270        self.serialization_invalidator.invalidate();
271    }
272}
273
/// Like [`State`], but its value is never serialized.
///
/// The value is skipped during encoding and comes back as `None` after decoding. Tracked reads
/// through [`TransientState::get`] also call `mark_session_dependent()`.
#[derive(Encode, Decode)]
// Empty bounds: `T` itself is never encoded or decoded (the only field is skipped).
#[bincode(bounds = "")]
pub struct TransientState<T> {
    /// Skipped by bincode; `default_transient_state_inner` supplies an empty (`None`) state with
    /// no readers when decoding.
    #[bincode(skip, default = "default_transient_state_inner")]
    inner: Mutex<StateInner<Option<T>>>,
}
280
281fn default_transient_state_inner<T>() -> Mutex<StateInner<Option<T>>> {
282    Mutex::new(StateInner::new(None))
283}
284
285impl<T: Debug> Debug for TransientState<T> {
286    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
287        f.debug_struct("TransientState")
288            .field("value", &self.inner.lock().value)
289            .finish()
290    }
291}
292
293impl<T: TraceRawVcs> TraceRawVcs for TransientState<T> {
294    fn trace_raw_vcs(&self, trace_context: &mut crate::trace::TraceRawVcsContext) {
295        self.inner.lock().value.trace_raw_vcs(trace_context);
296    }
297}
298
299impl<T> Default for TransientState<T> {
300    fn default() -> Self {
301        // Need to be explicit to ensure marking as stateful.
302        Self::new()
303    }
304}
305
306impl<T> PartialEq for TransientState<T> {
307    fn eq(&self, _other: &Self) -> bool {
308        false
309    }
310}
311impl<T> Eq for TransientState<T> {}
312
313impl<T> TransientState<T> {
314    pub fn new() -> Self {
315        mark_stateful();
316        Self {
317            inner: Mutex::new(StateInner::new(None)),
318        }
319    }
320
321    /// Gets the current value of the state. The current task will be registered
322    /// as dependency of the state and will be invalidated when the state
323    /// changes.
324    pub fn get(&self) -> StateRef<'_, Option<T>> {
325        mark_session_dependent();
326        let invalidator = get_invalidator();
327        let mut inner = self.inner.lock();
328        if let Some(invalidator) = invalidator {
329            inner.add_invalidator(invalidator);
330        }
331        StateRef {
332            serialization_invalidator: None,
333            inner,
334            mutated: false,
335        }
336    }
337
338    /// Gets the current value of the state. Untracked.
339    pub fn get_untracked(&self) -> StateRef<'_, Option<T>> {
340        let inner = self.inner.lock();
341        StateRef {
342            serialization_invalidator: None,
343            inner,
344            mutated: false,
345        }
346    }
347
348    /// Sets the current state without comparing it with the old value. This
349    /// should only be used if one is sure that the value has changed.
350    pub fn set_unconditionally(&self, value: T) {
351        let mut inner = self.inner.lock();
352        inner.set_unconditionally(Some(value));
353    }
354
355    /// Unset the current value without comparing it with the old value. This
356    /// should only be used if one is sure that the value has changed.
357    pub fn unset_unconditionally(&self) {
358        let mut inner = self.inner.lock();
359        inner.set_unconditionally(None);
360    }
361
362    /// Updates the current state with the `update` function. The `update`
363    /// function need to return `true` when the value was modified. Exposing
364    /// the current value from the `update` function is not allowed and will
365    /// result in incorrect cache invalidation.
366    pub fn update_conditionally(&self, update: impl FnOnce(&mut Option<T>) -> bool) {
367        let mut inner = self.inner.lock();
368        inner.update_conditionally(update);
369    }
370}
371
372impl<T: PartialEq> TransientState<T> {
373    /// Update the current state when the `value` is different from the current
374    /// value. `T` must implement [PartialEq] for this to work.
375    pub fn set(&self, value: T) {
376        let mut inner = self.inner.lock();
377        inner.set(Some(value));
378    }
379
380    /// Unset the current value.
381    pub fn unset(&self) {
382        let mut inner = self.inner.lock();
383        inner.set(None);
384    }
385}