1use std::{
2 any::type_name,
3 fmt::Debug,
4 mem::take,
5 ops::{Deref, DerefMut},
6};
7
8use auto_hash_map::AutoSet;
9use parking_lot::{Mutex, MutexGuard};
10use serde::{Deserialize, Serialize};
11use tracing::trace_span;
12
13use crate::{
14 Invalidator, OperationValue, SerializationInvalidator, get_invalidator, mark_session_dependent,
15 mark_stateful, trace::TraceRawVcs,
16};
17
/// Shared core of [`State`] and [`TransientState`]: the current value plus the
/// invalidators registered by readers of that value.
///
/// NOTE: this is serde-derived, so the field order below is part of the
/// serialized representation — do not reorder fields casually.
#[derive(Serialize, Deserialize)]
struct StateInner<T> {
    /// The current value.
    value: T,
    /// Invalidators registered via `add_invalidator`. The whole set is taken
    /// (emptied) and each entry invoked whenever the value changes.
    invalidators: AutoSet<Invalidator>,
}
23
24impl<T> StateInner<T> {
25 pub fn new(value: T) -> Self {
26 Self {
27 value,
28 invalidators: AutoSet::new(),
29 }
30 }
31
32 pub fn add_invalidator(&mut self, invalidator: Invalidator) {
33 self.invalidators.insert(invalidator);
34 }
35
36 pub fn set_unconditionally(&mut self, value: T) {
37 self.value = value;
38 let _span = trace_span!("state value changed", value_type = type_name::<T>()).entered();
39 for invalidator in take(&mut self.invalidators) {
40 invalidator.invalidate();
41 }
42 }
43
44 pub fn update_conditionally(&mut self, update: impl FnOnce(&mut T) -> bool) -> bool {
45 if !update(&mut self.value) {
46 return false;
47 }
48 let _span = trace_span!("state value changed", value_type = type_name::<T>()).entered();
49 for invalidator in take(&mut self.invalidators) {
50 invalidator.invalidate();
51 }
52 true
53 }
54}
55
56impl<T: PartialEq> StateInner<T> {
57 pub fn set(&mut self, value: T) -> bool {
58 if self.value == value {
59 return false;
60 }
61 let _span = trace_span!("state value changed", value_type = type_name::<T>()).entered();
62 self.value = value;
63 for invalidator in take(&mut self.invalidators) {
64 invalidator.invalidate();
65 }
66 true
67 }
68}
69
70pub struct StateRef<'a, T> {
71 serialization_invalidator: Option<&'a SerializationInvalidator>,
72 inner: MutexGuard<'a, StateInner<T>>,
73 mutated: bool,
74}
75
76impl<T> Deref for StateRef<'_, T> {
77 type Target = T;
78
79 fn deref(&self) -> &Self::Target {
80 &self.inner.value
81 }
82}
83
84impl<T> DerefMut for StateRef<'_, T> {
85 fn deref_mut(&mut self) -> &mut Self::Target {
86 self.mutated = true;
87 &mut self.inner.value
88 }
89}
90
91impl<T> Drop for StateRef<'_, T> {
92 fn drop(&mut self) {
93 if self.mutated {
94 let _span = trace_span!("state value changed", value_type = type_name::<T>()).entered();
95 for invalidator in take(&mut self.inner.invalidators) {
96 invalidator.invalidate();
97 }
98 if let Some(serialization_invalidator) = self.serialization_invalidator {
99 serialization_invalidator.invalidate();
100 }
101 }
102 }
103}
104
/// A mutable value with change tracking.
///
/// [`State::get`] registers the invalidator returned by `get_invalidator()` (if
/// any). Every change fires the registered invalidators and then this state's
/// serialization invalidator. [`State::new`] requires `T: OperationValue` and
/// calls `mark_stateful()`.
///
/// NOTE: serde-derived — field order is part of the serialized representation.
#[derive(Serialize, Deserialize)]
pub struct State<T> {
    /// Obtained from `mark_stateful()` at construction; invoked after every
    /// change to the value.
    serialization_invalidator: SerializationInvalidator,
    /// The value and its registered invalidators. `get`/`get_untracked` keep
    /// this locked for as long as the returned `StateRef` is alive.
    inner: Mutex<StateInner<T>>,
}
133
134impl<T: Debug> Debug for State<T> {
135 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
136 f.debug_struct("State")
137 .field("value", &self.inner.lock().value)
138 .finish()
139 }
140}
141
142impl<T: TraceRawVcs> TraceRawVcs for State<T> {
143 fn trace_raw_vcs(&self, trace_context: &mut crate::trace::TraceRawVcsContext) {
144 self.inner.lock().value.trace_raw_vcs(trace_context);
145 }
146}
147
148impl<T: Default + OperationValue> Default for State<T> {
149 fn default() -> Self {
150 Self::new(Default::default())
152 }
153}
154
155impl<T> PartialEq for State<T> {
156 fn eq(&self, _other: &Self) -> bool {
157 false
158 }
159}
160impl<T> Eq for State<T> {}
161
162impl<T> State<T> {
163 pub fn new(value: T) -> Self
164 where
165 T: OperationValue,
166 {
167 Self {
168 serialization_invalidator: mark_stateful(),
169 inner: Mutex::new(StateInner::new(value)),
170 }
171 }
172
173 pub fn get(&self) -> StateRef<'_, T> {
177 let invalidator = get_invalidator();
178 let mut inner = self.inner.lock();
179 if let Some(invalidator) = invalidator {
180 inner.add_invalidator(invalidator);
181 }
182 StateRef {
183 serialization_invalidator: Some(&self.serialization_invalidator),
184 inner,
185 mutated: false,
186 }
187 }
188
189 pub fn get_untracked(&self) -> StateRef<'_, T> {
191 let inner = self.inner.lock();
192 StateRef {
193 serialization_invalidator: Some(&self.serialization_invalidator),
194 inner,
195 mutated: false,
196 }
197 }
198
199 pub fn set_unconditionally(&self, value: T) {
202 {
203 let mut inner = self.inner.lock();
204 inner.set_unconditionally(value);
205 }
206 self.serialization_invalidator.invalidate();
207 }
208
209 pub fn update_conditionally(&self, update: impl FnOnce(&mut T) -> bool) {
214 {
215 let mut inner = self.inner.lock();
216 if !inner.update_conditionally(update) {
217 return;
218 }
219 }
220 self.serialization_invalidator.invalidate();
221 }
222}
223
224impl<T: PartialEq> State<T> {
225 pub fn set(&self, value: T) {
228 {
229 let mut inner = self.inner.lock();
230 if !inner.set(value) {
231 return;
232 }
233 }
234 self.serialization_invalidator.invalidate();
235 }
236}
237
/// Like [`State`], but the value is never persisted: it serializes as `()` and
/// deserializes to an empty (`None`) state. Its `get` also calls
/// `mark_session_dependent()`.
pub struct TransientState<T> {
    /// The value (`None` when unset) and its registered invalidators.
    inner: Mutex<StateInner<Option<T>>>,
}
241
242impl<T> Serialize for TransientState<T> {
243 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
244 where
245 S: serde::Serializer,
246 {
247 Serialize::serialize(&(), serializer)
248 }
249}
250
251impl<'de, T> Deserialize<'de> for TransientState<T> {
252 fn deserialize<D>(deserializer: D) -> Result<TransientState<T>, D::Error>
253 where
254 D: serde::Deserializer<'de>,
255 {
256 let () = Deserialize::deserialize(deserializer)?;
257 Ok(TransientState {
258 inner: Mutex::new(StateInner::new(Default::default())),
259 })
260 }
261}
262
263impl<T: Debug> Debug for TransientState<T> {
264 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
265 f.debug_struct("TransientState")
266 .field("value", &self.inner.lock().value)
267 .finish()
268 }
269}
270
271impl<T: TraceRawVcs> TraceRawVcs for TransientState<T> {
272 fn trace_raw_vcs(&self, trace_context: &mut crate::trace::TraceRawVcsContext) {
273 self.inner.lock().value.trace_raw_vcs(trace_context);
274 }
275}
276
277impl<T> Default for TransientState<T> {
278 fn default() -> Self {
279 Self::new()
281 }
282}
283
284impl<T> PartialEq for TransientState<T> {
285 fn eq(&self, _other: &Self) -> bool {
286 false
287 }
288}
289impl<T> Eq for TransientState<T> {}
290
291impl<T> TransientState<T> {
292 pub fn new() -> Self {
293 mark_stateful();
294 Self {
295 inner: Mutex::new(StateInner::new(None)),
296 }
297 }
298
299 pub fn get(&self) -> StateRef<'_, Option<T>> {
303 mark_session_dependent();
304 let invalidator = get_invalidator();
305 let mut inner = self.inner.lock();
306 if let Some(invalidator) = invalidator {
307 inner.add_invalidator(invalidator);
308 }
309 StateRef {
310 serialization_invalidator: None,
311 inner,
312 mutated: false,
313 }
314 }
315
316 pub fn get_untracked(&self) -> StateRef<'_, Option<T>> {
318 let inner = self.inner.lock();
319 StateRef {
320 serialization_invalidator: None,
321 inner,
322 mutated: false,
323 }
324 }
325
326 pub fn set_unconditionally(&self, value: T) {
329 let mut inner = self.inner.lock();
330 inner.set_unconditionally(Some(value));
331 }
332
333 pub fn unset_unconditionally(&self) {
336 let mut inner = self.inner.lock();
337 inner.set_unconditionally(None);
338 }
339
340 pub fn update_conditionally(&self, update: impl FnOnce(&mut Option<T>) -> bool) {
345 let mut inner = self.inner.lock();
346 inner.update_conditionally(update);
347 }
348}
349
350impl<T: PartialEq> TransientState<T> {
351 pub fn set(&self, value: T) {
354 let mut inner = self.inner.lock();
355 inner.set(Some(value));
356 }
357
358 pub fn unset(&self) {
360 let mut inner = self.inner.lock();
361 inner.set(None);
362 }
363}