1use std::{
2 fmt::Debug,
3 mem::take,
4 ops::{Deref, DerefMut},
5};
6
7use auto_hash_map::AutoSet;
8use bincode::{Decode, Encode};
9use parking_lot::{Mutex, MutexGuard};
10use tracing::trace_span;
11
12use crate::{
13 Invalidator, OperationValue, SerializationInvalidator, get_invalidator,
14 get_serialization_invalidator, manager::with_turbo_tasks, trace::TraceRawVcs,
15};
16
/// The lock-protected contents of a [`State`]: the current value plus the
/// invalidators registered by tracked reads ([`State::get`]).
///
/// Field order matters: this struct derives bincode `Encode`/`Decode`, which
/// encode fields in declaration order, so reordering fields changes the
/// serialized layout.
#[derive(Encode, Decode)]
struct StateInner<T> {
    /// The value currently held by the state.
    value: T,
    /// Invalidators added via `add_invalidator`; drained with `take` whenever the
    /// value is mutated.
    invalidators: AutoSet<Invalidator>,
}
22
23impl<T> StateInner<T> {
24 pub fn new(value: T) -> Self {
25 Self {
26 value,
27 invalidators: AutoSet::new(),
28 }
29 }
30
31 pub fn add_invalidator(&mut self, invalidator: Invalidator) {
32 self.invalidators.insert(invalidator);
33 }
34
35 #[must_use]
39 fn set_unconditionally(&mut self, value: T) -> AutoSet<Invalidator> {
40 self.value = value;
41 take(&mut self.invalidators)
42 }
43
44 #[must_use]
47 fn update_conditionally(
48 &mut self,
49 update: impl FnOnce(&mut T) -> bool,
50 ) -> Option<AutoSet<Invalidator>> {
51 if !update(&mut self.value) {
52 return None;
53 }
54 Some(take(&mut self.invalidators))
55 }
56}
57
58impl<T: PartialEq> StateInner<T> {
59 #[must_use]
62 fn set(&mut self, value: T) -> Option<AutoSet<Invalidator>> {
63 if self.value == value {
64 return None;
65 }
66 self.value = value;
67 Some(take(&mut self.invalidators))
68 }
69}
70
71fn notify_mutated(
79 invalidators: AutoSet<Invalidator>,
80 serialization_invalidator: Option<&SerializationInvalidator>,
81) {
82 if invalidators.is_empty() && serialization_invalidator.is_none() {
83 return;
84 }
85 let _span = trace_span!("state value changed").entered();
86 with_turbo_tasks(|tt| {
87 for invalidator in invalidators {
88 invalidator.invalidate(&**tt);
89 }
90 if let Some(serialization_invalidator) = serialization_invalidator {
91 tt.invalidate_serialization(serialization_invalidator.task());
92 }
93 });
94}
95
96pub struct StateRef<'a, T> {
97 serialization_invalidator: Option<&'a SerializationInvalidator>,
98 inner: Option<MutexGuard<'a, StateInner<T>>>,
102 mutated: bool,
103}
104
105impl<'a, T> StateRef<'a, T> {
106 fn new(
107 inner: MutexGuard<'a, StateInner<T>>,
108 serialization_invalidator: Option<&'a SerializationInvalidator>,
109 ) -> Self {
110 Self {
111 serialization_invalidator,
112 inner: Some(inner),
113 mutated: false,
114 }
115 }
116
117 fn inner(&self) -> &StateInner<T> {
118 self.inner.as_deref().expect("inner only None during Drop")
119 }
120
121 fn inner_mut(&mut self) -> &mut StateInner<T> {
122 self.inner
123 .as_deref_mut()
124 .expect("inner only None during Drop")
125 }
126}
127
128impl<T> Deref for StateRef<'_, T> {
129 type Target = T;
130
131 fn deref(&self) -> &Self::Target {
132 &self.inner().value
133 }
134}
135
136impl<T> DerefMut for StateRef<'_, T> {
137 fn deref_mut(&mut self) -> &mut Self::Target {
138 self.mutated = true;
139 &mut self.inner_mut().value
140 }
141}
142
143impl<T> Drop for StateRef<'_, T> {
144 fn drop(&mut self) {
145 if !self.mutated {
146 return;
147 }
148 let mut guard = self.inner.take().expect("Drop only called once");
154 let invalidators = take(&mut guard.invalidators);
155 drop(guard);
156 notify_mutated(invalidators, self.serialization_invalidator);
157 }
158}
159
160pub mod parking_lot_mutex_bincode {
161 use bincode::{
162 BorrowDecode,
163 de::{BorrowDecoder, Decoder},
164 enc::Encoder,
165 error::{DecodeError, EncodeError},
166 };
167
168 use super::*;
169
170 pub fn encode<T: Encode, E: Encoder>(
171 mutex: &Mutex<T>,
172 encoder: &mut E,
173 ) -> Result<(), EncodeError> {
174 mutex.lock().encode(encoder)
175 }
176
177 pub fn decode<Context, T: Decode<Context>, D: Decoder<Context = Context>>(
178 decoder: &mut D,
179 ) -> Result<Mutex<T>, DecodeError> {
180 Ok(Mutex::new(T::decode(decoder)?))
181 }
182
183 pub fn borrow_decode<
184 'de,
185 Context,
186 T: BorrowDecode<'de, Context>,
187 D: BorrowDecoder<'de, Context = Context>,
188 >(
189 decoder: &mut D,
190 ) -> Result<Mutex<T>, DecodeError> {
191 Ok(Mutex::new(T::borrow_decode(decoder)?))
192 }
193}
194
/// A mutex-protected value whose tracked readers are invalidated when it changes.
///
/// Reading and registering:
/// - [`State::get`] registers the invalidator returned by `get_invalidator()`, if
///   any.
/// - [`State::get_untracked`] reads without registering.
///
/// Mutating: [`State::set`], [`State::set_unconditionally`],
/// [`State::update_conditionally`], and dropping a mutated [`StateRef`] all do
/// the same thing. They drain the registered invalidators and pass them, with
/// `serialization_invalidator`, to `notify_mutated`.
///
/// Field order matters: this struct derives bincode `Encode`/`Decode`, which
/// encode fields in declaration order.
#[derive(Encode, Decode)]
pub struct State<T> {
    /// Obtained from `get_serialization_invalidator()` in [`State::new`] and
    /// notified on every mutation.
    serialization_invalidator: SerializationInvalidator,
    /// The value and its registered invalidators. Serialized through the lock by
    /// `parking_lot_mutex_bincode`.
    #[bincode(with = "parking_lot_mutex_bincode")]
    inner: Mutex<StateInner<T>>,
}
224
225impl<T: Debug> Debug for State<T> {
226 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
227 f.debug_struct("State")
228 .field("value", &self.inner.lock().value)
229 .finish()
230 }
231}
232
233impl<T: TraceRawVcs> TraceRawVcs for State<T> {
234 fn trace_raw_vcs(&self, trace_context: &mut crate::trace::TraceRawVcsContext) {
235 self.inner.lock().value.trace_raw_vcs(trace_context);
236 }
237}
238
239impl<T: Default + OperationValue> Default for State<T> {
240 fn default() -> Self {
241 Self::new(Default::default())
243 }
244}
245
246impl<T> PartialEq for State<T> {
247 fn eq(&self, _other: &Self) -> bool {
248 false
249 }
250}
251impl<T> Eq for State<T> {}
252
253impl<T> State<T> {
254 pub fn new(value: T) -> Self
255 where
256 T: OperationValue,
257 {
258 Self {
259 serialization_invalidator: get_serialization_invalidator(),
260 inner: Mutex::new(StateInner::new(value)),
261 }
262 }
263
264 pub fn get(&self) -> StateRef<'_, T> {
268 let invalidator = get_invalidator();
269 let mut inner = self.inner.lock();
270 if let Some(invalidator) = invalidator {
271 inner.add_invalidator(invalidator);
272 }
273 StateRef::new(inner, Some(&self.serialization_invalidator))
274 }
275
276 pub fn get_untracked(&self) -> StateRef<'_, T> {
278 let inner = self.inner.lock();
279 StateRef::new(inner, Some(&self.serialization_invalidator))
280 }
281
282 pub fn set_unconditionally(&self, value: T) {
285 let invalidators = {
286 let mut inner = self.inner.lock();
287 inner.set_unconditionally(value)
288 };
289 notify_mutated(invalidators, Some(&self.serialization_invalidator));
290 }
291
292 pub fn update_conditionally(&self, update: impl FnOnce(&mut T) -> bool) {
297 let Some(invalidators) = ({
298 let mut inner = self.inner.lock();
299 inner.update_conditionally(update)
300 }) else {
301 return;
302 };
303 notify_mutated(invalidators, Some(&self.serialization_invalidator));
304 }
305}
306
307impl<T: PartialEq> State<T> {
308 pub fn set(&self, value: T) {
311 let Some(invalidators) = ({
312 let mut inner = self.inner.lock();
313 inner.set(value)
314 }) else {
315 return;
316 };
317 notify_mutated(invalidators, Some(&self.serialization_invalidator));
318 }
319}