Skip to main content

turbo_tasks/
state.rs

1use std::{
2    fmt::Debug,
3    mem::take,
4    ops::{Deref, DerefMut},
5};
6
7use auto_hash_map::AutoSet;
8use bincode::{Decode, Encode};
9use parking_lot::{Mutex, MutexGuard};
10use tracing::trace_span;
11
12use crate::{
13    Invalidator, OperationValue, SerializationInvalidator, get_invalidator,
14    get_serialization_invalidator, manager::with_turbo_tasks, trace::TraceRawVcs,
15};
16
17#[derive(Encode, Decode)]
18struct StateInner<T> {
19    value: T,
20    invalidators: AutoSet<Invalidator>,
21}
22
23impl<T> StateInner<T> {
24    pub fn new(value: T) -> Self {
25        Self {
26            value,
27            invalidators: AutoSet::new(),
28        }
29    }
30
31    pub fn add_invalidator(&mut self, invalidator: Invalidator) {
32        self.invalidators.insert(invalidator);
33    }
34
35    /// Sets the value and returns the drained invalidators. The caller MUST
36    /// run them via [`run_invalidators`] *after* dropping the [`Mutex`] guard
37    /// — calling [`Invalidator::invalidate`] may grab locks in the backend which can lead to cycles
38    #[must_use]
39    fn set_unconditionally(&mut self, value: T) -> AutoSet<Invalidator> {
40        self.value = value;
41        take(&mut self.invalidators)
42    }
43
44    /// See [`Self::set_unconditionally`] for the locking contract on the
45    /// returned invalidators.
46    #[must_use]
47    fn update_conditionally(
48        &mut self,
49        update: impl FnOnce(&mut T) -> bool,
50    ) -> Option<AutoSet<Invalidator>> {
51        if !update(&mut self.value) {
52            return None;
53        }
54        Some(take(&mut self.invalidators))
55    }
56}
57
58impl<T: PartialEq> StateInner<T> {
59    /// See [`Self::set_unconditionally`] for the locking contract on the
60    /// returned invalidators.
61    #[must_use]
62    fn set(&mut self, value: T) -> Option<AutoSet<Invalidator>> {
63        if self.value == value {
64            return None;
65        }
66        self.value = value;
67        Some(take(&mut self.invalidators))
68    }
69}
70
71/// Notifies the backend that the [`State`] has been mutated: runs every
72/// dependent [`Invalidator`] and invalidates the serialized state. Must be
73/// called *outside* the [`StateInner`] mutex guard; see
74/// [`StateInner::set_unconditionally`] for why.
75///
76/// Both notifications resolve `TURBO_TASKS` from a task-local, so we do them
77/// inside a single [`with_turbo_tasks`] call to amortize that lookup.
78fn notify_mutated(
79    invalidators: AutoSet<Invalidator>,
80    serialization_invalidator: Option<&SerializationInvalidator>,
81) {
82    if invalidators.is_empty() && serialization_invalidator.is_none() {
83        return;
84    }
85    let _span = trace_span!("state value changed").entered();
86    with_turbo_tasks(|tt| {
87        for invalidator in invalidators {
88            invalidator.invalidate(&**tt);
89        }
90        if let Some(serialization_invalidator) = serialization_invalidator {
91            tt.invalidate_serialization(serialization_invalidator.task());
92        }
93    });
94}
95
96pub struct StateRef<'a, T> {
97    serialization_invalidator: Option<&'a SerializationInvalidator>,
98    // `Option` so `Drop` can `take()` the guard and release it before running
99    // invalidators. Always `Some` for the lifetime of the `StateRef` outside
100    // of `Drop`.
101    inner: Option<MutexGuard<'a, StateInner<T>>>,
102    mutated: bool,
103}
104
105impl<'a, T> StateRef<'a, T> {
106    fn new(
107        inner: MutexGuard<'a, StateInner<T>>,
108        serialization_invalidator: Option<&'a SerializationInvalidator>,
109    ) -> Self {
110        Self {
111            serialization_invalidator,
112            inner: Some(inner),
113            mutated: false,
114        }
115    }
116
117    fn inner(&self) -> &StateInner<T> {
118        self.inner.as_deref().expect("inner only None during Drop")
119    }
120
121    fn inner_mut(&mut self) -> &mut StateInner<T> {
122        self.inner
123            .as_deref_mut()
124            .expect("inner only None during Drop")
125    }
126}
127
128impl<T> Deref for StateRef<'_, T> {
129    type Target = T;
130
131    fn deref(&self) -> &Self::Target {
132        &self.inner().value
133    }
134}
135
136impl<T> DerefMut for StateRef<'_, T> {
137    fn deref_mut(&mut self) -> &mut Self::Target {
138        self.mutated = true;
139        &mut self.inner_mut().value
140    }
141}
142
143impl<T> Drop for StateRef<'_, T> {
144    fn drop(&mut self) {
145        if !self.mutated {
146            return;
147        }
148        // Drain invalidators while we still hold the guard, then drop the
149        // guard before running them. Running invalidators reaches into the
150        // backend and acquires task-storage shard locks, and the snapshot
151        // path takes the State mutex while holding such a shard lock — so
152        // running them under the guard is a lock-order inversion.
153        let mut guard = self.inner.take().expect("Drop only called once");
154        let invalidators = take(&mut guard.invalidators);
155        drop(guard);
156        notify_mutated(invalidators, self.serialization_invalidator);
157    }
158}
159
160pub mod parking_lot_mutex_bincode {
161    use bincode::{
162        BorrowDecode,
163        de::{BorrowDecoder, Decoder},
164        enc::Encoder,
165        error::{DecodeError, EncodeError},
166    };
167
168    use super::*;
169
170    pub fn encode<T: Encode, E: Encoder>(
171        mutex: &Mutex<T>,
172        encoder: &mut E,
173    ) -> Result<(), EncodeError> {
174        mutex.lock().encode(encoder)
175    }
176
177    pub fn decode<Context, T: Decode<Context>, D: Decoder<Context = Context>>(
178        decoder: &mut D,
179    ) -> Result<Mutex<T>, DecodeError> {
180        Ok(Mutex::new(T::decode(decoder)?))
181    }
182
183    pub fn borrow_decode<
184        'de,
185        Context,
186        T: BorrowDecode<'de, Context>,
187        D: BorrowDecoder<'de, Context = Context>,
188    >(
189        decoder: &mut D,
190    ) -> Result<Mutex<T>, DecodeError> {
191        Ok(Mutex::new(T::borrow_decode(decoder)?))
192    }
193}
194
195/// **This API violates core assumption of turbo-tasks, is believed to be unsound, and there's no
196/// plan fix it.** You should prefer to use [collectibles][crate::CollectiblesSource] instead of
197/// state where at all possible. This API may be removed in the future.
198///
199/// An [internally-mutable] type, similar to [`RefCell`][std::cell::RefCell] or [`Mutex`] that can
200/// be stored inside a [`VcValueType`].
201///
202/// **[`State`] should only be used with [`OperationVc`] and types that implement
203/// [`OperationValue`]**.
204///
205/// Setting values inside a [`State`] bypasses the normal argument and return value tracking
206/// that's tracks child function calls and re-runs tasks until their values settled. That system is
207/// needed for [strong consistency]. [`OperationVc`] ensures that function calls are reconnected
208/// with the parent/child call graph.
209///
210/// When reading a `State` with [`State::get`], the state itself (though not any values inside of
211/// it) is marked as a dependency of the current task.
212///
213/// [internally-mutable]: https://doc.rust-lang.org/book/ch15-05-interior-mutability.html
214/// [`VcValueType`]: crate::VcValueType
215/// [strong consistency]: crate::OperationVc::read_strongly_consistent
216/// [`OperationVc`]: crate::OperationVc
217/// [`OperationValue`]: crate::OperationValue
218#[derive(Encode, Decode)]
219pub struct State<T> {
220    serialization_invalidator: SerializationInvalidator,
221    #[bincode(with = "parking_lot_mutex_bincode")]
222    inner: Mutex<StateInner<T>>,
223}
224
225impl<T: Debug> Debug for State<T> {
226    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
227        f.debug_struct("State")
228            .field("value", &self.inner.lock().value)
229            .finish()
230    }
231}
232
233impl<T: TraceRawVcs> TraceRawVcs for State<T> {
234    fn trace_raw_vcs(&self, trace_context: &mut crate::trace::TraceRawVcsContext) {
235        self.inner.lock().value.trace_raw_vcs(trace_context);
236    }
237}
238
239impl<T: Default + OperationValue> Default for State<T> {
240    fn default() -> Self {
241        // Need to be explicit to ensure marking as stateful.
242        Self::new(Default::default())
243    }
244}
245
246impl<T> PartialEq for State<T> {
247    fn eq(&self, _other: &Self) -> bool {
248        false
249    }
250}
251impl<T> Eq for State<T> {}
252
253impl<T> State<T> {
254    pub fn new(value: T) -> Self
255    where
256        T: OperationValue,
257    {
258        Self {
259            serialization_invalidator: get_serialization_invalidator(),
260            inner: Mutex::new(StateInner::new(value)),
261        }
262    }
263
264    /// Gets the current value of the state. The current task will be registered
265    /// as dependency of the state and will be invalidated when the state
266    /// changes.
267    pub fn get(&self) -> StateRef<'_, T> {
268        let invalidator = get_invalidator();
269        let mut inner = self.inner.lock();
270        if let Some(invalidator) = invalidator {
271            inner.add_invalidator(invalidator);
272        }
273        StateRef::new(inner, Some(&self.serialization_invalidator))
274    }
275
276    /// Gets the current value of the state. Untracked.
277    pub fn get_untracked(&self) -> StateRef<'_, T> {
278        let inner = self.inner.lock();
279        StateRef::new(inner, Some(&self.serialization_invalidator))
280    }
281
282    /// Sets the current state without comparing it with the old value. This
283    /// should only be used if one is sure that the value has changed.
284    pub fn set_unconditionally(&self, value: T) {
285        let invalidators = {
286            let mut inner = self.inner.lock();
287            inner.set_unconditionally(value)
288        };
289        notify_mutated(invalidators, Some(&self.serialization_invalidator));
290    }
291
292    /// Updates the current state with the `update` function. The `update`
293    /// function need to return `true` when the value was modified. Exposing
294    /// the current value from the `update` function is not allowed and will
295    /// result in incorrect cache invalidation.
296    pub fn update_conditionally(&self, update: impl FnOnce(&mut T) -> bool) {
297        let Some(invalidators) = ({
298            let mut inner = self.inner.lock();
299            inner.update_conditionally(update)
300        }) else {
301            return;
302        };
303        notify_mutated(invalidators, Some(&self.serialization_invalidator));
304    }
305}
306
307impl<T: PartialEq> State<T> {
308    /// Update the current state when the `value` is different from the current
309    /// value. `T` must implement [PartialEq] for this to work.
310    pub fn set(&self, value: T) {
311        let Some(invalidators) = ({
312            let mut inner = self.inner.lock();
313            inner.set(value)
314        }) else {
315            return;
316        };
317        notify_mutated(invalidators, Some(&self.serialization_invalidator));
318    }
319}