1use std::{
2 hash::Hash,
3 ops::{Deref, DerefMut},
4 sync::{Arc, atomic::AtomicBool},
5};
6
7use smallvec::SmallVec;
8use turbo_bincode::TurboBincodeBuffer;
9use turbo_tasks::{FxDashMap, TaskId, parallel};
10
11use crate::{
12 backend::storage_schema::TaskStorage,
13 database::key_value_database::KeySpace,
14 utils::{
15 dash_map_drop_contents::drop_contents,
16 dash_map_multi::{RefMut, get_multiple_mut},
17 },
18};
19
/// Selects which part(s) of a task's stored state an operation refers to.
///
/// Meta and data are persisted in separate key spaces (see
/// [`SpecificTaskDataCategory::key_space`]), so callers frequently need to
/// address them independently or both at once.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TaskDataCategory {
    /// Only the task's meta information.
    Meta,
    /// Only the task's data.
    Data,
    /// Both meta information and data.
    All,
}
26
27impl TaskDataCategory {
28 pub fn into_specific(self) -> SpecificTaskDataCategory {
29 match self {
30 TaskDataCategory::Meta => SpecificTaskDataCategory::Meta,
31 TaskDataCategory::Data => SpecificTaskDataCategory::Data,
32 TaskDataCategory::All => unreachable!(),
33 }
34 }
35
36 pub fn includes_data(self) -> bool {
37 matches!(self, TaskDataCategory::Data | TaskDataCategory::All)
38 }
39
40 pub fn includes_meta(self) -> bool {
41 matches!(self, TaskDataCategory::Meta | TaskDataCategory::All)
42 }
43}
44
/// A concrete (non-`All`) task data category.
///
/// Unlike [`TaskDataCategory`] this type cannot express "both", so it can be
/// mapped one-to-one onto a database key space.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SpecificTaskDataCategory {
    /// The task's meta information.
    Meta,
    /// The task's data.
    Data,
}
50
51impl SpecificTaskDataCategory {
52 pub fn key_space(self) -> KeySpace {
54 match self {
55 SpecificTaskDataCategory::Meta => KeySpace::TaskMeta,
56 SpecificTaskDataCategory::Data => KeySpace::TaskData,
57 }
58 }
59}
60
/// How a task entry in `Storage::modified` relates to an in-progress snapshot.
enum ModifiedState {
    /// The task was modified outside of snapshot mode. The per-category
    /// modified flags on the task's entry in `Storage::map` record what
    /// changed.
    Modified,
    /// The task was modified while snapshot mode was active. `Some` holds a
    /// copy of the task state taken before the modification (created in
    /// `StorageWriteGuard::track_modification_internal`); it becomes `None`
    /// once a `SnapshotShard` consumes the copy.
    Snapshot(Option<Box<TaskStorage>>),
}
75
/// Concurrent in-memory store of per-task state with support for taking a
/// snapshot of all modifications (see [`Storage::take_snapshot`]).
pub struct Storage {
    /// `true` while a snapshot is in progress. Written with `Release` in
    /// `start_snapshot`/`end_snapshot`, read with `Acquire` in
    /// `snapshot_mode`.
    snapshot_mode: AtomicBool,
    /// Tasks touched since the last completed snapshot, with their
    /// snapshot-related state.
    modified: FxDashMap<TaskId, ModifiedState>,
    /// The actual task state, keyed by task id.
    map: FxDashMap<TaskId, Box<TaskStorage>>,
}
81
impl Storage {
    /// Creates the storage with `shard_amount` shards per dashmap.
    ///
    /// `small_preallocation` selects a small initial capacity (1 Ki task
    /// entries, no preallocated modified set) instead of the large default
    /// (1 Mi task entries).
    pub fn new(shard_amount: usize, small_preallocation: bool) -> Self {
        let map_capacity: usize = if small_preallocation {
            1024
        } else {
            1024 * 1024
        };
        let modified_capacity: usize = if small_preallocation { 0 } else { 1024 };

        Self {
            snapshot_mode: AtomicBool::new(false),
            modified: FxDashMap::with_capacity_and_hasher_and_shard_amount(
                modified_capacity,
                Default::default(),
                shard_amount,
            ),
            map: FxDashMap::with_capacity_and_hasher_and_shard_amount(
                map_capacity,
                Default::default(),
                shard_amount,
            ),
        }
    }

    /// Enters snapshot mode (if not already active) and returns one
    /// [`SnapshotShard`] per shard of the `modified` map.
    ///
    /// Each shard is an iterator yielding the result of `process` /
    /// `process_snapshot` for every task modified since the last snapshot.
    /// `preprocess` runs while the task's map entry is still locked; `process`
    /// runs after that lock is released (see `SnapshotShard::next`). Snapshot
    /// mode ends when the last returned shard finishes or is dropped, via the
    /// shared [`SnapshotGuard`].
    pub fn take_snapshot<
        'l,
        T,
        R,
        PP: for<'a> Fn(TaskId, &'a TaskStorage) -> T + Sync,
        P: Fn(TaskId, T, &mut TurboBincodeBuffer) -> R + Sync,
        PS: Fn(TaskId, Box<TaskStorage>, &mut TurboBincodeBuffer) -> R + Sync,
    >(
        &'l self,
        preprocess: &'l PP,
        process: &'l P,
        process_snapshot: &'l PS,
    ) -> Vec<SnapshotShard<'l, PP, P, PS>> {
        if !self.snapshot_mode() {
            self.start_snapshot();
        }

        // Every shard gets a clone of this guard; `end_snapshot` runs when the
        // last clone is dropped.
        let guard = Arc::new(SnapshotGuard { storage: self });

        parallel::map_collect::<_, _, Vec<_>>(self.modified.shards(), |shard| {
            let mut direct_snapshots: Vec<(TaskId, Box<TaskStorage>)> = Vec::new();
            let mut modified: SmallVec<[TaskId; 4]> = SmallVec::new();
            {
                // Write-lock the shard while draining its entries.
                let guard = shard.write();
                // SAFETY: the shard's write lock is held for the whole
                // iteration, giving exclusive access to its buckets.
                // NOTE(review): this relies on dashmap's raw shard API
                // contract — confirm against the dashmap version in use.
                for bucket in unsafe { guard.iter() } {
                    let (key, shared_value) = unsafe { bucket.as_mut() };
                    let modified_state = shared_value.get_mut();
                    match modified_state {
                        ModifiedState::Modified => {
                            // No pre-made copy: the shard iterator reads the
                            // live task state lazily.
                            modified.push(*key);
                        }
                        ModifiedState::Snapshot(snapshot) => {
                            // A pre-modification copy exists; take it so it is
                            // processed directly.
                            if let Some(snapshot) = snapshot.take() {
                                direct_snapshots.push((*key, snapshot));
                            }
                        }
                    }
                }
                drop(guard);
            }
            // Initial capacity of the per-shard serialization scratch buffer.
            const SCRATCH_BUFFER_SIZE: usize = 4096;
            SnapshotShard {
                direct_snapshots,
                modified,
                storage: self,
                guard: Some(guard.clone()),
                process,
                preprocess,
                process_snapshot,
                // Reused across every item yielded by this shard.
                scratch_buffer: TurboBincodeBuffer::with_capacity(SCRATCH_BUFFER_SIZE),
            }
        })
    }

    /// Enables snapshot mode. From now on, modifications preserve a copy of
    /// the pre-modification state where needed (see
    /// `StorageWriteGuard::track_modification_internal`).
    pub fn start_snapshot(&self) {
        self.snapshot_mode
            .store(true, std::sync::atomic::Ordering::Release);
    }

    /// Leaves snapshot mode and resets the modification bookkeeping.
    ///
    /// Called from [`SnapshotGuard`]'s `Drop` impl once all shards are done.
    fn end_snapshot(&self) {
        // Plain `Modified` entries were fully handled by the snapshot: remove
        // them and clear the per-task modified flags.
        let mut removed_modified = Vec::new();
        self.modified.retain(|key, inner| {
            if matches!(inner, ModifiedState::Modified) {
                removed_modified.push(*key);
                false
            } else {
                true
            }
        });

        for key in removed_modified {
            if let Some(mut inner) = self.map.get_mut(&key) {
                inner.flags.set_data_modified(false);
                inner.flags.set_meta_modified(false);
            }
        }

        // Leave snapshot mode *before* downgrading `Snapshot` entries so that
        // concurrent writers stop creating new snapshot copies.
        self.snapshot_mode
            .store(false, std::sync::atomic::Ordering::Release);

        // Tasks modified during the snapshot become plain `Modified`, to be
        // picked up by the next snapshot cycle.
        let mut removed_snapshots = Vec::new();
        for mut item in self.modified.iter_mut() {
            match item.value() {
                ModifiedState::Snapshot(_) => {
                    removed_snapshots.push(*item.key());
                    *item.value_mut() = ModifiedState::Modified;
                }
                ModifiedState::Modified => {
                    // Already in the right state for the next snapshot;
                    // nothing to do.
                }
            }
        }

        for key in removed_snapshots {
            if let Some(mut inner) = self.map.get_mut(&key) {
                // Convert snapshot bits back into modified bits so the change
                // is not lost for the next snapshot.
                if inner.flags.meta_snapshot() {
                    inner.flags.set_meta_snapshot(false);
                    inner.flags.set_meta_modified(true);
                }
                if inner.flags.data_snapshot() {
                    inner.flags.set_data_snapshot(false);
                    inner.flags.set_data_modified(true);
                }
            }
        }

        self.modified.shrink_to_fit();
    }

    /// Whether snapshot mode is active. `Acquire` pairs with the `Release`
    /// stores in `start_snapshot` / `end_snapshot`.
    fn snapshot_mode(&self) -> bool {
        self.snapshot_mode
            .load(std::sync::atomic::Ordering::Acquire)
    }

    /// Returns a write guard for `key`, inserting an empty `TaskStorage` if
    /// the task has no entry yet.
    pub fn access_mut(&self, key: TaskId) -> StorageWriteGuard<'_> {
        let inner = match self.map.entry(key) {
            dashmap::mapref::entry::Entry::Occupied(e) => e.into_ref(),
            dashmap::mapref::entry::Entry::Vacant(e) => e.insert(Box::new(TaskStorage::new())),
        };
        StorageWriteGuard {
            storage: self,
            inner: inner.into(),
        }
    }

    /// Returns write guards for two tasks at once, inserting empty entries as
    /// needed. NOTE(review): locking both keys safely (including the
    /// same-shard case) is delegated to `get_multiple_mut` — see
    /// `utils::dash_map_multi`.
    pub fn access_pair_mut(
        &self,
        key1: TaskId,
        key2: TaskId,
    ) -> (StorageWriteGuard<'_>, StorageWriteGuard<'_>) {
        let (a, b) = get_multiple_mut(&self.map, key1, key2, || Box::new(TaskStorage::new()));
        (
            StorageWriteGuard {
                storage: self,
                inner: a,
            },
            StorageWriteGuard {
                storage: self,
                inner: b,
            },
        )
    }

    /// Drops all stored task state and modification bookkeeping (see
    /// `utils::dash_map_drop_contents`).
    pub fn drop_contents(&self) {
        drop_contents(&self.map);
        drop_contents(&self.modified);
    }
}
283
/// Mutable access to a single task's storage.
///
/// Dereferences to [`TaskStorage`]. Callers must announce writes via
/// `track_modification` so snapshot bookkeeping stays correct.
pub struct StorageWriteGuard<'a> {
    /// Back-reference used to update `Storage::modified` and to read the
    /// current snapshot mode.
    storage: &'a Storage,
    /// The locked entry in `Storage::map`.
    inner: RefMut<'a, TaskId, Box<TaskStorage>>,
}
288
impl StorageWriteGuard<'_> {
    /// Records that `category` of this task is about to be modified.
    ///
    /// `name` is only used for tracing when the `trace_task_modification`
    /// feature is enabled; otherwise it is ignored.
    #[inline(always)]
    pub fn track_modification(
        &mut self,
        category: SpecificTaskDataCategory,
        #[allow(unused_variables)] name: &str,
    ) {
        self.track_modification_internal(
            category,
            #[cfg(feature = "trace_task_modification")]
            name,
        );
    }

    /// Updates the task's modified/snapshot flags and the `Storage::modified`
    /// map according to the current snapshot mode.
    fn track_modification_internal(
        &mut self,
        category: SpecificTaskDataCategory,
        #[cfg(feature = "trace_task_modification")] name: &str,
    ) {
        let flags = &self.inner.flags;
        // Already snapshotted for this category — nothing more to record.
        if flags.is_snapshot(category) {
            return;
        }
        let modified = flags.is_modified(category);
        #[cfg(feature = "trace_task_modification")]
        let _span = (!modified).then(|| tracing::trace_span!("mark_modified", name).entered());
        match (self.storage.snapshot_mode(), modified) {
            (false, false) => {
                // First modification of this task outside snapshot mode:
                // register it in the modified map (only once per task).
                if !flags.any_snapshot() && !flags.any_modified() {
                    self.storage
                        .modified
                        .insert(*self.inner.key(), ModifiedState::Modified);
                }
                self.inner.flags.set_modified(category, true);
            }
            (false, true) => {
                // Already marked modified — nothing to do.
            }
            (true, false) => {
                // Modified for the first time while a snapshot is running:
                // the pre-snapshot state had no changes to preserve, so no
                // copy is needed.
                if !flags.any_snapshot() {
                    self.storage
                        .modified
                        .insert(*self.inner.key(), ModifiedState::Snapshot(None));
                }
                self.inner.flags.set_snapshot(category, true);
            }
            (true, true) => {
                // The task already had unsnapshotted modifications: preserve a
                // copy of the current state (with its modified flags) before
                // it is overwritten.
                if !flags.any_snapshot() {
                    let mut snapshot = self.inner.clone_snapshot();
                    snapshot.flags.set_data_modified(flags.data_modified());
                    snapshot.flags.set_meta_modified(flags.meta_modified());
                    self.storage.modified.insert(
                        *self.inner.key(),
                        ModifiedState::Snapshot(Some(Box::new(snapshot))),
                    );
                }
                self.inner.flags.set_snapshot(category, true);
            }
        }
    }
}
357
358impl Deref for StorageWriteGuard<'_> {
359 type Target = TaskStorage;
360
361 fn deref(&self) -> &Self::Target {
362 &self.inner
363 }
364}
365
366impl DerefMut for StorageWriteGuard<'_> {
367 fn deref_mut(&mut self) -> &mut Self::Target {
368 &mut self.inner
369 }
370}
371
/// Keeps a snapshot alive; ends snapshot mode on the owning [`Storage`] when
/// dropped. Shared between all [`SnapshotShard`]s via `Arc`.
pub struct SnapshotGuard<'l> {
    storage: &'l Storage,
}
375
376impl Drop for SnapshotGuard<'_> {
377 fn drop(&mut self) {
378 self.storage.end_snapshot();
379 }
380}
381
/// One shard's worth of snapshot work, produced by [`Storage::take_snapshot`].
///
/// Iterating yields the processed result for each task of this shard.
pub struct SnapshotShard<'l, PP, P, PS> {
    /// Pre-made task copies taken while draining the shard; processed first.
    direct_snapshots: Vec<(TaskId, Box<TaskStorage>)>,
    /// Tasks whose live state still needs to be read from the storage.
    modified: SmallVec<[TaskId; 4]>,
    storage: &'l Storage,
    /// Keeps snapshot mode alive; set to `None` once iteration is exhausted.
    guard: Option<Arc<SnapshotGuard<'l>>>,
    process: &'l P,
    preprocess: &'l PP,
    process_snapshot: &'l PS,
    /// Serialization buffer reused for every yielded item of this shard.
    scratch_buffer: TurboBincodeBuffer,
}
393
394impl<'l, T, R, PP, P, PS> Iterator for SnapshotShard<'l, PP, P, PS>
395where
396 PP: for<'a> Fn(TaskId, &'a TaskStorage) -> T + Sync,
397 P: Fn(TaskId, T, &mut TurboBincodeBuffer) -> R + Sync,
398 PS: Fn(TaskId, Box<TaskStorage>, &mut TurboBincodeBuffer) -> R + Sync,
399{
400 type Item = R;
401
402 fn next(&mut self) -> Option<Self::Item> {
403 if let Some((task_id, snapshot)) = self.direct_snapshots.pop() {
404 return Some((self.process_snapshot)(
405 task_id,
406 snapshot,
407 &mut self.scratch_buffer,
408 ));
409 }
410 while let Some(task_id) = self.modified.pop() {
411 let inner = self.storage.map.get(&task_id).unwrap();
412 if !inner.flags.any_snapshot() {
413 let preprocessed = (self.preprocess)(task_id, &inner);
414 drop(inner);
415 return Some((self.process)(
416 task_id,
417 preprocessed,
418 &mut self.scratch_buffer,
419 ));
420 } else {
421 drop(inner);
422 let maybe_snapshot = {
423 let mut modified_state = self.storage.modified.get_mut(&task_id).unwrap();
424 let ModifiedState::Snapshot(snapshot) = &mut *modified_state else {
425 unreachable!("The snapshot bit was set, so it must be in Snapshot state");
426 };
427 snapshot.take()
428 };
429 if let Some(snapshot) = maybe_snapshot {
430 return Some((self.process_snapshot)(
431 task_id,
432 snapshot,
433 &mut self.scratch_buffer,
434 ));
435 }
436 }
437 }
438 self.guard = None;
439 None
440 }
441}