1use std::{
2 any::type_name,
3 fmt::Debug,
4 mem::take,
5 ops::{Deref, DerefMut},
6};
7
8use auto_hash_map::AutoSet;
9use bincode::{Decode, Encode};
10use parking_lot::{Mutex, MutexGuard};
11use tracing::trace_span;
12
13use crate::{
14 Invalidator, OperationValue, SerializationInvalidator, get_invalidator,
15 get_serialization_invalidator, manager::with_turbo_tasks, mark_session_dependent,
16 trace::TraceRawVcs,
17};
18
19#[derive(Encode, Decode)]
20struct StateInner<T> {
21 value: T,
22 invalidators: AutoSet<Invalidator>,
23}
24
25impl<T> StateInner<T> {
26 pub fn new(value: T) -> Self {
27 Self {
28 value,
29 invalidators: AutoSet::new(),
30 }
31 }
32
33 pub fn add_invalidator(&mut self, invalidator: Invalidator) {
34 self.invalidators.insert(invalidator);
35 }
36
37 pub fn set_unconditionally(&mut self, value: T) {
38 self.value = value;
39 let _span = trace_span!("state value changed", value_type = type_name::<T>()).entered();
40 let invalidators = take(&mut self.invalidators);
41 if !invalidators.is_empty() {
42 with_turbo_tasks(|tt| {
43 for invalidator in invalidators {
44 invalidator.invalidate(&**tt);
45 }
46 });
47 }
48 }
49
50 pub fn update_conditionally(&mut self, update: impl FnOnce(&mut T) -> bool) -> bool {
51 if !update(&mut self.value) {
52 return false;
53 }
54 let _span = trace_span!("state value changed", value_type = type_name::<T>()).entered();
55 let invalidators = take(&mut self.invalidators);
56 if !invalidators.is_empty() {
57 with_turbo_tasks(|tt| {
58 for invalidator in invalidators {
59 invalidator.invalidate(&**tt);
60 }
61 });
62 }
63 true
64 }
65}
66
67impl<T: PartialEq> StateInner<T> {
68 pub fn set(&mut self, value: T) -> bool {
69 if self.value == value {
70 return false;
71 }
72 let _span = trace_span!("state value changed", value_type = type_name::<T>()).entered();
73 self.value = value;
74 let invalidators = take(&mut self.invalidators);
75 if !invalidators.is_empty() {
76 with_turbo_tasks(|tt| {
77 for invalidator in invalidators {
78 invalidator.invalidate(&**tt);
79 }
80 });
81 }
82 true
83 }
84}
85
86pub struct StateRef<'a, T> {
87 serialization_invalidator: Option<&'a SerializationInvalidator>,
88 inner: MutexGuard<'a, StateInner<T>>,
89 mutated: bool,
90}
91
92impl<T> Deref for StateRef<'_, T> {
93 type Target = T;
94
95 fn deref(&self) -> &Self::Target {
96 &self.inner.value
97 }
98}
99
100impl<T> DerefMut for StateRef<'_, T> {
101 fn deref_mut(&mut self) -> &mut Self::Target {
102 self.mutated = true;
103 &mut self.inner.value
104 }
105}
106
107impl<T> Drop for StateRef<'_, T> {
108 fn drop(&mut self) {
109 if self.mutated {
110 let _span = trace_span!("state value changed", value_type = type_name::<T>()).entered();
111 let invalidators = take(&mut self.inner.invalidators);
112 if !invalidators.is_empty() {
113 with_turbo_tasks(|tt| {
114 for invalidator in invalidators {
115 invalidator.invalidate(&**tt);
116 }
117 });
118 }
119 if let Some(serialization_invalidator) = self.serialization_invalidator {
120 serialization_invalidator.invalidate();
121 }
122 }
123 }
124}
125
126mod parking_lot_mutex_bincode {
127 use bincode::{
128 BorrowDecode,
129 de::{BorrowDecoder, Decoder},
130 enc::Encoder,
131 error::{DecodeError, EncodeError},
132 };
133
134 use super::*;
135
136 pub fn encode<T: Encode, E: Encoder>(
137 mutex: &Mutex<T>,
138 encoder: &mut E,
139 ) -> Result<(), EncodeError> {
140 mutex.lock().encode(encoder)
141 }
142
143 pub fn decode<Context, T: Decode<Context>, D: Decoder<Context = Context>>(
144 decoder: &mut D,
145 ) -> Result<Mutex<T>, DecodeError> {
146 Ok(Mutex::new(T::decode(decoder)?))
147 }
148
149 pub fn borrow_decode<
150 'de,
151 Context,
152 T: BorrowDecode<'de, Context>,
153 D: BorrowDecoder<'de, Context = Context>,
154 >(
155 decoder: &mut D,
156 ) -> Result<Mutex<T>, DecodeError> {
157 Ok(Mutex::new(T::borrow_decode(decoder)?))
158 }
159}
160
161#[derive(Encode, Decode)]
185pub struct State<T> {
186 serialization_invalidator: SerializationInvalidator,
187 #[bincode(with = "parking_lot_mutex_bincode")]
188 inner: Mutex<StateInner<T>>,
189}
190
191impl<T: Debug> Debug for State<T> {
192 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
193 f.debug_struct("State")
194 .field("value", &self.inner.lock().value)
195 .finish()
196 }
197}
198
199impl<T: TraceRawVcs> TraceRawVcs for State<T> {
200 fn trace_raw_vcs(&self, trace_context: &mut crate::trace::TraceRawVcsContext) {
201 self.inner.lock().value.trace_raw_vcs(trace_context);
202 }
203}
204
205impl<T: Default + OperationValue> Default for State<T> {
206 fn default() -> Self {
207 Self::new(Default::default())
209 }
210}
211
212impl<T> PartialEq for State<T> {
213 fn eq(&self, _other: &Self) -> bool {
214 false
215 }
216}
217impl<T> Eq for State<T> {}
218
219impl<T> State<T> {
220 pub fn new(value: T) -> Self
221 where
222 T: OperationValue,
223 {
224 Self {
225 serialization_invalidator: get_serialization_invalidator(),
226 inner: Mutex::new(StateInner::new(value)),
227 }
228 }
229
230 pub fn get(&self) -> StateRef<'_, T> {
234 let invalidator = get_invalidator();
235 let mut inner = self.inner.lock();
236 if let Some(invalidator) = invalidator {
237 inner.add_invalidator(invalidator);
238 }
239 StateRef {
240 serialization_invalidator: Some(&self.serialization_invalidator),
241 inner,
242 mutated: false,
243 }
244 }
245
246 pub fn get_untracked(&self) -> StateRef<'_, T> {
248 let inner = self.inner.lock();
249 StateRef {
250 serialization_invalidator: Some(&self.serialization_invalidator),
251 inner,
252 mutated: false,
253 }
254 }
255
256 pub fn set_unconditionally(&self, value: T) {
259 {
260 let mut inner = self.inner.lock();
261 inner.set_unconditionally(value);
262 }
263 self.serialization_invalidator.invalidate();
264 }
265
266 pub fn update_conditionally(&self, update: impl FnOnce(&mut T) -> bool) {
271 {
272 let mut inner = self.inner.lock();
273 if !inner.update_conditionally(update) {
274 return;
275 }
276 }
277 self.serialization_invalidator.invalidate();
278 }
279}
280
281impl<T: PartialEq> State<T> {
282 pub fn set(&self, value: T) {
285 {
286 let mut inner = self.inner.lock();
287 if !inner.set(value) {
288 return;
289 }
290 }
291 self.serialization_invalidator.invalidate();
292 }
293}
294
295#[derive(Encode, Decode)]
296#[bincode(bounds = "")]
297pub struct TransientState<T> {
298 #[bincode(skip, default = "default_transient_state_inner")]
299 inner: Mutex<StateInner<Option<T>>>,
300}
301
302fn default_transient_state_inner<T>() -> Mutex<StateInner<Option<T>>> {
303 Mutex::new(StateInner::new(None))
304}
305
306impl<T: Debug> Debug for TransientState<T> {
307 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
308 f.debug_struct("TransientState")
309 .field("value", &self.inner.lock().value)
310 .finish()
311 }
312}
313
314impl<T: TraceRawVcs> TraceRawVcs for TransientState<T> {
315 fn trace_raw_vcs(&self, trace_context: &mut crate::trace::TraceRawVcsContext) {
316 self.inner.lock().value.trace_raw_vcs(trace_context);
317 }
318}
319
320impl<T> Default for TransientState<T> {
321 fn default() -> Self {
322 Self::new()
324 }
325}
326
327impl<T> PartialEq for TransientState<T> {
328 fn eq(&self, _other: &Self) -> bool {
329 false
330 }
331}
332impl<T> Eq for TransientState<T> {}
333
334impl<T> TransientState<T> {
335 pub fn new() -> Self {
336 Self {
337 inner: Mutex::new(StateInner::new(None)),
338 }
339 }
340
341 pub fn get(&self) -> StateRef<'_, Option<T>> {
345 mark_session_dependent();
346 let invalidator = get_invalidator();
347 let mut inner = self.inner.lock();
348 if let Some(invalidator) = invalidator {
349 inner.add_invalidator(invalidator);
350 }
351 StateRef {
352 serialization_invalidator: None,
353 inner,
354 mutated: false,
355 }
356 }
357
358 pub fn get_untracked(&self) -> StateRef<'_, Option<T>> {
360 let inner = self.inner.lock();
361 StateRef {
362 serialization_invalidator: None,
363 inner,
364 mutated: false,
365 }
366 }
367
368 pub fn set_unconditionally(&self, value: T) {
371 let mut inner = self.inner.lock();
372 inner.set_unconditionally(Some(value));
373 }
374
375 pub fn update_conditionally(&self, update: impl FnOnce(&mut Option<T>) -> bool) {
380 let mut inner = self.inner.lock();
381 inner.update_conditionally(update);
382 }
383}
384
385impl<T: PartialEq> TransientState<T> {
386 pub fn set(&self, value: T) {
389 let mut inner = self.inner.lock();
390 inner.set(Some(value));
391 }
392
393 pub fn unset(&self) {
395 let mut inner = self.inner.lock();
396 inner.set(None);
397 }
398}