1use std::{
2 any::type_name,
3 fmt::Debug,
4 mem::take,
5 ops::{Deref, DerefMut},
6};
7
8use auto_hash_map::AutoSet;
9use bincode::{Decode, Encode};
10use parking_lot::{Mutex, MutexGuard};
11use tracing::trace_span;
12
13use crate::{
14 Invalidator, OperationValue, SerializationInvalidator, get_invalidator, mark_session_dependent,
15 mark_stateful, trace::TraceRawVcs,
16};
17
18#[derive(Encode, Decode)]
19struct StateInner<T> {
20 value: T,
21 invalidators: AutoSet<Invalidator>,
22}
23
24impl<T> StateInner<T> {
25 pub fn new(value: T) -> Self {
26 Self {
27 value,
28 invalidators: AutoSet::new(),
29 }
30 }
31
32 pub fn add_invalidator(&mut self, invalidator: Invalidator) {
33 self.invalidators.insert(invalidator);
34 }
35
36 pub fn set_unconditionally(&mut self, value: T) {
37 self.value = value;
38 let _span = trace_span!("state value changed", value_type = type_name::<T>()).entered();
39 for invalidator in take(&mut self.invalidators) {
40 invalidator.invalidate();
41 }
42 }
43
44 pub fn update_conditionally(&mut self, update: impl FnOnce(&mut T) -> bool) -> bool {
45 if !update(&mut self.value) {
46 return false;
47 }
48 let _span = trace_span!("state value changed", value_type = type_name::<T>()).entered();
49 for invalidator in take(&mut self.invalidators) {
50 invalidator.invalidate();
51 }
52 true
53 }
54}
55
56impl<T: PartialEq> StateInner<T> {
57 pub fn set(&mut self, value: T) -> bool {
58 if self.value == value {
59 return false;
60 }
61 let _span = trace_span!("state value changed", value_type = type_name::<T>()).entered();
62 self.value = value;
63 for invalidator in take(&mut self.invalidators) {
64 invalidator.invalidate();
65 }
66 true
67 }
68}
69
70pub struct StateRef<'a, T> {
71 serialization_invalidator: Option<&'a SerializationInvalidator>,
72 inner: MutexGuard<'a, StateInner<T>>,
73 mutated: bool,
74}
75
76impl<T> Deref for StateRef<'_, T> {
77 type Target = T;
78
79 fn deref(&self) -> &Self::Target {
80 &self.inner.value
81 }
82}
83
84impl<T> DerefMut for StateRef<'_, T> {
85 fn deref_mut(&mut self) -> &mut Self::Target {
86 self.mutated = true;
87 &mut self.inner.value
88 }
89}
90
91impl<T> Drop for StateRef<'_, T> {
92 fn drop(&mut self) {
93 if self.mutated {
94 let _span = trace_span!("state value changed", value_type = type_name::<T>()).entered();
95 for invalidator in take(&mut self.inner.invalidators) {
96 invalidator.invalidate();
97 }
98 if let Some(serialization_invalidator) = self.serialization_invalidator {
99 serialization_invalidator.invalidate();
100 }
101 }
102 }
103}
104
105mod parking_lot_mutex_bincode {
106 use bincode::{
107 BorrowDecode,
108 de::{BorrowDecoder, Decoder},
109 enc::Encoder,
110 error::{DecodeError, EncodeError},
111 };
112
113 use super::*;
114
115 pub fn encode<T: Encode, E: Encoder>(
116 mutex: &Mutex<T>,
117 encoder: &mut E,
118 ) -> Result<(), EncodeError> {
119 mutex.lock().encode(encoder)
120 }
121
122 pub fn decode<Context, T: Decode<Context>, D: Decoder<Context = Context>>(
123 decoder: &mut D,
124 ) -> Result<Mutex<T>, DecodeError> {
125 Ok(Mutex::new(T::decode(decoder)?))
126 }
127
128 pub fn borrow_decode<
129 'de,
130 Context,
131 T: BorrowDecode<'de, Context>,
132 D: BorrowDecoder<'de, Context = Context>,
133 >(
134 decoder: &mut D,
135 ) -> Result<Mutex<T>, DecodeError> {
136 Ok(Mutex::new(T::borrow_decode(decoder)?))
137 }
138}
139
140#[derive(Encode, Decode)]
164pub struct State<T> {
165 serialization_invalidator: SerializationInvalidator,
166 #[bincode(with = "parking_lot_mutex_bincode")]
167 inner: Mutex<StateInner<T>>,
168}
169
170impl<T: Debug> Debug for State<T> {
171 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
172 f.debug_struct("State")
173 .field("value", &self.inner.lock().value)
174 .finish()
175 }
176}
177
178impl<T: TraceRawVcs> TraceRawVcs for State<T> {
179 fn trace_raw_vcs(&self, trace_context: &mut crate::trace::TraceRawVcsContext) {
180 self.inner.lock().value.trace_raw_vcs(trace_context);
181 }
182}
183
184impl<T: Default + OperationValue> Default for State<T> {
185 fn default() -> Self {
186 Self::new(Default::default())
188 }
189}
190
191impl<T> PartialEq for State<T> {
192 fn eq(&self, _other: &Self) -> bool {
193 false
194 }
195}
196impl<T> Eq for State<T> {}
197
198impl<T> State<T> {
199 pub fn new(value: T) -> Self
200 where
201 T: OperationValue,
202 {
203 Self {
204 serialization_invalidator: mark_stateful(),
205 inner: Mutex::new(StateInner::new(value)),
206 }
207 }
208
209 pub fn get(&self) -> StateRef<'_, T> {
213 let invalidator = get_invalidator();
214 let mut inner = self.inner.lock();
215 if let Some(invalidator) = invalidator {
216 inner.add_invalidator(invalidator);
217 }
218 StateRef {
219 serialization_invalidator: Some(&self.serialization_invalidator),
220 inner,
221 mutated: false,
222 }
223 }
224
225 pub fn get_untracked(&self) -> StateRef<'_, T> {
227 let inner = self.inner.lock();
228 StateRef {
229 serialization_invalidator: Some(&self.serialization_invalidator),
230 inner,
231 mutated: false,
232 }
233 }
234
235 pub fn set_unconditionally(&self, value: T) {
238 {
239 let mut inner = self.inner.lock();
240 inner.set_unconditionally(value);
241 }
242 self.serialization_invalidator.invalidate();
243 }
244
245 pub fn update_conditionally(&self, update: impl FnOnce(&mut T) -> bool) {
250 {
251 let mut inner = self.inner.lock();
252 if !inner.update_conditionally(update) {
253 return;
254 }
255 }
256 self.serialization_invalidator.invalidate();
257 }
258}
259
260impl<T: PartialEq> State<T> {
261 pub fn set(&self, value: T) {
264 {
265 let mut inner = self.inner.lock();
266 if !inner.set(value) {
267 return;
268 }
269 }
270 self.serialization_invalidator.invalidate();
271 }
272}
273
274#[derive(Encode, Decode)]
275#[bincode(bounds = "")]
276pub struct TransientState<T> {
277 #[bincode(skip, default = "default_transient_state_inner")]
278 inner: Mutex<StateInner<Option<T>>>,
279}
280
281fn default_transient_state_inner<T>() -> Mutex<StateInner<Option<T>>> {
282 Mutex::new(StateInner::new(None))
283}
284
285impl<T: Debug> Debug for TransientState<T> {
286 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
287 f.debug_struct("TransientState")
288 .field("value", &self.inner.lock().value)
289 .finish()
290 }
291}
292
293impl<T: TraceRawVcs> TraceRawVcs for TransientState<T> {
294 fn trace_raw_vcs(&self, trace_context: &mut crate::trace::TraceRawVcsContext) {
295 self.inner.lock().value.trace_raw_vcs(trace_context);
296 }
297}
298
299impl<T> Default for TransientState<T> {
300 fn default() -> Self {
301 Self::new()
303 }
304}
305
306impl<T> PartialEq for TransientState<T> {
307 fn eq(&self, _other: &Self) -> bool {
308 false
309 }
310}
311impl<T> Eq for TransientState<T> {}
312
313impl<T> TransientState<T> {
314 pub fn new() -> Self {
315 mark_stateful();
316 Self {
317 inner: Mutex::new(StateInner::new(None)),
318 }
319 }
320
321 pub fn get(&self) -> StateRef<'_, Option<T>> {
325 mark_session_dependent();
326 let invalidator = get_invalidator();
327 let mut inner = self.inner.lock();
328 if let Some(invalidator) = invalidator {
329 inner.add_invalidator(invalidator);
330 }
331 StateRef {
332 serialization_invalidator: None,
333 inner,
334 mutated: false,
335 }
336 }
337
338 pub fn get_untracked(&self) -> StateRef<'_, Option<T>> {
340 let inner = self.inner.lock();
341 StateRef {
342 serialization_invalidator: None,
343 inner,
344 mutated: false,
345 }
346 }
347
348 pub fn set_unconditionally(&self, value: T) {
351 let mut inner = self.inner.lock();
352 inner.set_unconditionally(Some(value));
353 }
354
355 pub fn unset_unconditionally(&self) {
358 let mut inner = self.inner.lock();
359 inner.set_unconditionally(None);
360 }
361
362 pub fn update_conditionally(&self, update: impl FnOnce(&mut Option<T>) -> bool) {
367 let mut inner = self.inner.lock();
368 inner.update_conditionally(update);
369 }
370}
371
372impl<T: PartialEq> TransientState<T> {
373 pub fn set(&self, value: T) {
376 let mut inner = self.inner.lock();
377 inner.set(Some(value));
378 }
379
380 pub fn unset(&self) {
382 let mut inner = self.inner.lock();
383 inner.set(None);
384 }
385}