1use std::{
2 fmt::Debug,
3 mem::take,
4 ops::{Deref, DerefMut},
5};
6
7use auto_hash_map::AutoSet;
8use parking_lot::{Mutex, MutexGuard};
9use serde::{Deserialize, Serialize};
10
11use crate::{
12 Invalidator, OperationValue, SerializationInvalidator, get_invalidator, mark_session_dependent,
13 mark_stateful, trace::TraceRawVcs,
14};
15
/// Shared storage behind `State` and `TransientState`: the current value plus the
/// invalidators registered by tracked reads (`State::get` / `TransientState::get`).
#[derive(Serialize, Deserialize)]
struct StateInner<T> {
    // The current value.
    value: T,
    // Invalidators added via `add_invalidator`. Whenever the value changes (through the
    // setters below or through a mutated `StateRef`), every entry is invalidated and the
    // set is emptied.
    invalidators: AutoSet<Invalidator>,
}
21
22impl<T> StateInner<T> {
23 pub fn new(value: T) -> Self {
24 Self {
25 value,
26 invalidators: AutoSet::new(),
27 }
28 }
29
30 pub fn add_invalidator(&mut self, invalidator: Invalidator) {
31 self.invalidators.insert(invalidator);
32 }
33
34 pub fn set_unconditionally(&mut self, value: T) {
35 self.value = value;
36 for invalidator in take(&mut self.invalidators) {
37 invalidator.invalidate();
38 }
39 }
40
41 pub fn update_conditionally(&mut self, update: impl FnOnce(&mut T) -> bool) -> bool {
42 if !update(&mut self.value) {
43 return false;
44 }
45 for invalidator in take(&mut self.invalidators) {
46 invalidator.invalidate();
47 }
48 true
49 }
50}
51
52impl<T: PartialEq> StateInner<T> {
53 pub fn set(&mut self, value: T) -> bool {
54 if self.value == value {
55 return false;
56 }
57 self.value = value;
58 for invalidator in take(&mut self.invalidators) {
59 invalidator.invalidate();
60 }
61 true
62 }
63}
64
/// Guard returned by the `get` / `get_untracked` accessors of `State` and `TransientState`.
///
/// It dereferences to the stored value and keeps the state's mutex locked for its whole
/// lifetime. Taking a mutable reference (`DerefMut`) marks it as mutated. On drop, a
/// mutated guard invalidates every registered invalidator and, if present, the
/// serialization invalidator.
pub struct StateRef<'a, T> {
    // `Some` when created by `State`, `None` when created by `TransientState`.
    serialization_invalidator: Option<&'a SerializationInvalidator>,
    // The held lock on the shared storage.
    inner: MutexGuard<'a, StateInner<T>>,
    // Set by `deref_mut`; checked in `Drop`.
    mutated: bool,
}
70
71impl<T> Deref for StateRef<'_, T> {
72 type Target = T;
73
74 fn deref(&self) -> &Self::Target {
75 &self.inner.value
76 }
77}
78
79impl<T> DerefMut for StateRef<'_, T> {
80 fn deref_mut(&mut self) -> &mut Self::Target {
81 self.mutated = true;
82 &mut self.inner.value
83 }
84}
85
86impl<T> Drop for StateRef<'_, T> {
87 fn drop(&mut self) {
88 if self.mutated {
89 for invalidator in take(&mut self.inner.invalidators) {
90 invalidator.invalidate();
91 }
92 if let Some(serialization_invalidator) = self.serialization_invalidator {
93 serialization_invalidator.invalidate();
94 }
95 }
96 }
97}
98
/// A mutex-guarded value whose tracked readers are invalidated when it changes.
///
/// `get` registers the invalidator returned by `get_invalidator()`. Any effective write
/// invalidates all registered invalidators and also the serialization invalidator obtained
/// from `mark_stateful()` in `new`. The effective writes are:
/// - `set` when the value differs;
/// - `set_unconditionally`;
/// - `update_conditionally` when the closure returns `true`;
/// - mutation through a `StateRef`.
#[derive(Serialize, Deserialize)]
pub struct State<T> {
    // Returned by `mark_stateful()` in `new`; invalidated after every effective write.
    serialization_invalidator: SerializationInvalidator,
    // The value and its registered invalidators.
    inner: Mutex<StateInner<T>>,
}
127
128impl<T: Debug> Debug for State<T> {
129 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
130 f.debug_struct("State")
131 .field("value", &self.inner.lock().value)
132 .finish()
133 }
134}
135
136impl<T: TraceRawVcs> TraceRawVcs for State<T> {
137 fn trace_raw_vcs(&self, trace_context: &mut crate::trace::TraceRawVcsContext) {
138 self.inner.lock().value.trace_raw_vcs(trace_context);
139 }
140}
141
142impl<T: Default + OperationValue> Default for State<T> {
143 fn default() -> Self {
144 Self::new(Default::default())
146 }
147}
148
149impl<T> PartialEq for State<T> {
150 fn eq(&self, _other: &Self) -> bool {
151 false
152 }
153}
154impl<T> Eq for State<T> {}
155
156impl<T> State<T> {
157 pub fn new(value: T) -> Self
158 where
159 T: OperationValue,
160 {
161 Self {
162 serialization_invalidator: mark_stateful(),
163 inner: Mutex::new(StateInner::new(value)),
164 }
165 }
166
167 pub fn get(&self) -> StateRef<'_, T> {
171 let invalidator = get_invalidator();
172 let mut inner = self.inner.lock();
173 inner.add_invalidator(invalidator);
174 StateRef {
175 serialization_invalidator: Some(&self.serialization_invalidator),
176 inner,
177 mutated: false,
178 }
179 }
180
181 pub fn get_untracked(&self) -> StateRef<'_, T> {
183 let inner = self.inner.lock();
184 StateRef {
185 serialization_invalidator: Some(&self.serialization_invalidator),
186 inner,
187 mutated: false,
188 }
189 }
190
191 pub fn set_unconditionally(&self, value: T) {
194 {
195 let mut inner = self.inner.lock();
196 inner.set_unconditionally(value);
197 }
198 self.serialization_invalidator.invalidate();
199 }
200
201 pub fn update_conditionally(&self, update: impl FnOnce(&mut T) -> bool) {
206 {
207 let mut inner = self.inner.lock();
208 if !inner.update_conditionally(update) {
209 return;
210 }
211 }
212 self.serialization_invalidator.invalidate();
213 }
214}
215
216impl<T: PartialEq> State<T> {
217 pub fn set(&self, value: T) {
220 {
221 let mut inner = self.inner.lock();
222 if !inner.set(value) {
223 return;
224 }
225 }
226 self.serialization_invalidator.invalidate();
227 }
228}
229
/// Like `State`, but the (optional) value is never written out.
///
/// It serializes as unit and always deserializes to an empty (`None`) state. It keeps no
/// serialization invalidator. Its tracked `get` additionally calls
/// `mark_session_dependent()`.
pub struct TransientState<T> {
    // The optional value and its registered invalidators.
    inner: Mutex<StateInner<Option<T>>>,
}
233
234impl<T> Serialize for TransientState<T> {
235 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
236 where
237 S: serde::Serializer,
238 {
239 Serialize::serialize(&(), serializer)
240 }
241}
242
243impl<'de, T> Deserialize<'de> for TransientState<T> {
244 fn deserialize<D>(deserializer: D) -> Result<TransientState<T>, D::Error>
245 where
246 D: serde::Deserializer<'de>,
247 {
248 let () = Deserialize::deserialize(deserializer)?;
249 Ok(TransientState {
250 inner: Mutex::new(StateInner::new(Default::default())),
251 })
252 }
253}
254
255impl<T: Debug> Debug for TransientState<T> {
256 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
257 f.debug_struct("TransientState")
258 .field("value", &self.inner.lock().value)
259 .finish()
260 }
261}
262
263impl<T: TraceRawVcs> TraceRawVcs for TransientState<T> {
264 fn trace_raw_vcs(&self, trace_context: &mut crate::trace::TraceRawVcsContext) {
265 self.inner.lock().value.trace_raw_vcs(trace_context);
266 }
267}
268
269impl<T> Default for TransientState<T> {
270 fn default() -> Self {
271 Self::new()
273 }
274}
275
276impl<T> PartialEq for TransientState<T> {
277 fn eq(&self, _other: &Self) -> bool {
278 false
279 }
280}
281impl<T> Eq for TransientState<T> {}
282
283impl<T> TransientState<T> {
284 pub fn new() -> Self {
285 mark_stateful();
286 Self {
287 inner: Mutex::new(StateInner::new(None)),
288 }
289 }
290
291 pub fn get(&self) -> StateRef<'_, Option<T>> {
295 mark_session_dependent();
296 let invalidator = get_invalidator();
297 let mut inner = self.inner.lock();
298 inner.add_invalidator(invalidator);
299 StateRef {
300 serialization_invalidator: None,
301 inner,
302 mutated: false,
303 }
304 }
305
306 pub fn get_untracked(&self) -> StateRef<'_, Option<T>> {
308 let inner = self.inner.lock();
309 StateRef {
310 serialization_invalidator: None,
311 inner,
312 mutated: false,
313 }
314 }
315
316 pub fn set_unconditionally(&self, value: T) {
319 let mut inner = self.inner.lock();
320 inner.set_unconditionally(Some(value));
321 }
322
323 pub fn unset_unconditionally(&self) {
326 let mut inner = self.inner.lock();
327 inner.set_unconditionally(None);
328 }
329
330 pub fn update_conditionally(&self, update: impl FnOnce(&mut Option<T>) -> bool) {
335 let mut inner = self.inner.lock();
336 inner.update_conditionally(update);
337 }
338}
339
340impl<T: PartialEq> TransientState<T> {
341 pub fn set(&self, value: T) {
344 let mut inner = self.inner.lock();
345 inner.set(Some(value));
346 }
347
348 pub fn unset(&self) {
350 let mut inner = self.inner.lock();
351 inner.set(None);
352 }
353}