// turbo_tasks_backend/backend/storage.rs

1use std::{
2    hash::Hash,
3    ops::{Deref, DerefMut},
4    sync::{Arc, atomic::AtomicBool},
5};
6
7use smallvec::SmallVec;
8use turbo_bincode::TurboBincodeBuffer;
9use turbo_tasks::{FxDashMap, TaskId, parallel};
10
11use crate::{
12    backend::storage_schema::TaskStorage,
13    database::key_value_database::KeySpace,
14    utils::{
15        dash_map_drop_contents::drop_contents,
16        dash_map_multi::{RefMut, get_multiple_mut},
17    },
18};
19
/// Selects which category of a task's storage an operation applies to:
/// metadata, data, or both.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TaskDataCategory {
    /// Only the meta portion of the task storage.
    Meta,
    /// Only the data portion of the task storage.
    Data,
    /// Both meta and data.
    All,
}
26
27impl TaskDataCategory {
28    pub fn into_specific(self) -> SpecificTaskDataCategory {
29        match self {
30            TaskDataCategory::Meta => SpecificTaskDataCategory::Meta,
31            TaskDataCategory::Data => SpecificTaskDataCategory::Data,
32            TaskDataCategory::All => unreachable!(),
33        }
34    }
35
36    pub fn includes_data(self) -> bool {
37        matches!(self, TaskDataCategory::Data | TaskDataCategory::All)
38    }
39
40    pub fn includes_meta(self) -> bool {
41        matches!(self, TaskDataCategory::Meta | TaskDataCategory::All)
42    }
43}
44
/// Like [`TaskDataCategory`], but without the combined `All` variant: always
/// exactly one category.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SpecificTaskDataCategory {
    Meta,
    Data,
}
50
51impl SpecificTaskDataCategory {
52    /// Returns the KeySpace for storing data of this category
53    pub fn key_space(self) -> KeySpace {
54        match self {
55            SpecificTaskDataCategory::Meta => KeySpace::TaskMeta,
56            SpecificTaskDataCategory::Data => KeySpace::TaskData,
57        }
58    }
59}
60
/// State of an entry in `Storage::modified`. Tasks without an entry are
/// considered unmodified.
enum ModifiedState {
    /// It was modified before snapshot mode was entered, but it was not accessed during snapshot
    /// mode.
    Modified,
    /// Snapshot(Some):
    /// It was modified before snapshot mode was entered and it was accessed again during snapshot
    /// mode. A copy of the version of the item when snapshot mode was entered is stored here.
    /// The `TaskStorage` contains only persistent fields (via `clone_snapshot()`), and has
    /// `meta_modified`/`data_modified` flags set to indicate which categories need serializing.
    /// Snapshot(None):
    /// It was not modified before snapshot mode was entered, but it was accessed during snapshot
    /// mode. Or the snapshot was already taken out by the snapshot operation.
    Snapshot(Option<Box<TaskStorage>>),
}
75
/// Concurrent per-task storage with support for taking snapshots of modified
/// items while writes continue.
pub struct Storage {
    /// True while a snapshot is being taken. Mutating accessors check this to
    /// decide whether a copy of the item must be preserved first.
    snapshot_mode: AtomicBool,
    /// Modification tracking per task; see [`ModifiedState`]. Tasks with no
    /// entry here are unmodified.
    modified: FxDashMap<TaskId, ModifiedState>,
    /// The live task storage, keyed by task id.
    map: FxDashMap<TaskId, Box<TaskStorage>>,
}
81
impl Storage {
    /// Creates the storage with the given number of dashmap shards.
    /// `small_preallocation` selects much smaller initial capacities for both
    /// maps.
    pub fn new(shard_amount: usize, small_preallocation: bool) -> Self {
        let map_capacity: usize = if small_preallocation {
            1024
        } else {
            1024 * 1024
        };
        let modified_capacity: usize = if small_preallocation { 0 } else { 1024 };

        Self {
            snapshot_mode: AtomicBool::new(false),
            modified: FxDashMap::with_capacity_and_hasher_and_shard_amount(
                modified_capacity,
                Default::default(),
                shard_amount,
            ),
            map: FxDashMap::with_capacity_and_hasher_and_shard_amount(
                map_capacity,
                Default::default(),
                shard_amount,
            ),
        }
    }

    /// Processes every modified item (resp. a snapshot of it) with the given functions and returns
    /// the results. Ends snapshot mode afterwards.
    /// preprocess is potentially called within a lock, so it should be fast.
    /// process is called outside of locks, so it could do more expensive operations.
    /// Both process and process_snapshot receive a mutable scratch buffer that can be reused
    /// across iterations to avoid repeated allocations.
    pub fn take_snapshot<
        'l,
        T,
        R,
        PP: for<'a> Fn(TaskId, &'a TaskStorage) -> T + Sync,
        P: Fn(TaskId, T, &mut TurboBincodeBuffer) -> R + Sync,
        PS: Fn(TaskId, Box<TaskStorage>, &mut TurboBincodeBuffer) -> R + Sync,
    >(
        &'l self,
        preprocess: &'l PP,
        process: &'l P,
        process_snapshot: &'l PS,
    ) -> Vec<SnapshotShard<'l, PP, P, PS>> {
        if !self.snapshot_mode() {
            self.start_snapshot();
        }

        // Every `SnapshotShard` below receives a clone of this guard; dropping
        // the last clone runs `end_snapshot` (see `impl Drop for SnapshotGuard`).
        let guard = Arc::new(SnapshotGuard { storage: self });

        // The number of shards is much larger than the number of threads, so the effect of the
        // locks held is negligible.
        parallel::map_collect::<_, _, Vec<_>>(self.modified.shards(), |shard| {
            let mut direct_snapshots: Vec<(TaskId, Box<TaskStorage>)> = Vec::new();
            let mut modified: SmallVec<[TaskId; 4]> = SmallVec::new();
            {
                // Take the snapshots from the modified map
                let guard = shard.write();
                // Safety: guard must outlive the iterator.
                for bucket in unsafe { guard.iter() } {
                    // Safety: the guard guarantees that the bucket is not removed and the ptr
                    // is valid.
                    let (key, shared_value) = unsafe { bucket.as_mut() };
                    let modified_state = shared_value.get_mut();
                    match modified_state {
                        ModifiedState::Modified => {
                            modified.push(*key);
                        }
                        ModifiedState::Snapshot(snapshot) => {
                            // Stored copies are taken out eagerly; `None` is left
                            // behind to mark them as already extracted.
                            if let Some(snapshot) = snapshot.take() {
                                direct_snapshots.push((*key, snapshot));
                            }
                        }
                    }
                }
                // Safety: guard must outlive the iterator.
                drop(guard);
            }
            /// How big of a buffer to allocate initially.  Based on metrics from a large
            /// application this should cover about 98% of values with no resizes
            const SCRATCH_BUFFER_SIZE: usize = 4096;
            SnapshotShard {
                direct_snapshots,
                modified,
                storage: self,
                guard: Some(guard.clone()),
                process,
                preprocess,
                process_snapshot,
                scratch_buffer: TurboBincodeBuffer::with_capacity(SCRATCH_BUFFER_SIZE),
            }
        })
    }

    /// Start snapshot mode.
    pub fn start_snapshot(&self) {
        self.snapshot_mode
            .store(true, std::sync::atomic::Ordering::Release);
    }

    /// End snapshot mode.
    /// Items that have snapshots will be kept as modified since they have been accessed during the
    /// snapshot mode. Items that are modified will be removed and considered as unmodified.
    /// When items are accessed in future they will be marked as modified.
    fn end_snapshot(&self) {
        // We are still in snapshot mode, so all accessed items would be stored as snapshot.
        // This means we can start by removing all modified items.
        let mut removed_modified = Vec::new();
        self.modified.retain(|key, inner| {
            if matches!(inner, ModifiedState::Modified) {
                removed_modified.push(*key);
                false
            } else {
                true
            }
        });

        // We also need to unset all the modified flags.
        for key in removed_modified {
            if let Some(mut inner) = self.map.get_mut(&key) {
                inner.flags.set_data_modified(false);
                inner.flags.set_meta_modified(false);
            }
        }

        // Now modified only contains snapshots.
        // We leave snapshot mode. Any access would be stored as modified and not as snapshot.
        self.snapshot_mode
            .store(false, std::sync::atomic::Ordering::Release);

        // We can change all the snapshots to modified now.
        let mut removed_snapshots = Vec::new();
        for mut item in self.modified.iter_mut() {
            match item.value() {
                ModifiedState::Snapshot(_) => {
                    removed_snapshots.push(*item.key());
                    *item.value_mut() = ModifiedState::Modified;
                }
                ModifiedState::Modified => {
                    // This means it was concurrently modified.
                    // It's already in the correct state.
                }
            }
        }

        // And update the flags
        for key in removed_snapshots {
            if let Some(mut inner) = self.map.get_mut(&key) {
                if inner.flags.meta_snapshot() {
                    inner.flags.set_meta_snapshot(false);
                    inner.flags.set_meta_modified(true);
                }
                if inner.flags.data_snapshot() {
                    inner.flags.set_data_snapshot(false);
                    inner.flags.set_data_modified(true);
                }
            }
        }

        // Remove excessive capacity in modified
        self.modified.shrink_to_fit();
    }

    /// Whether snapshot mode is currently active.
    fn snapshot_mode(&self) -> bool {
        self.snapshot_mode
            .load(std::sync::atomic::Ordering::Acquire)
    }

    /// Returns a write guard for the task's storage, inserting an empty
    /// `TaskStorage` if the task has no entry yet.
    pub fn access_mut(&self, key: TaskId) -> StorageWriteGuard<'_> {
        let inner = match self.map.entry(key) {
            dashmap::mapref::entry::Entry::Occupied(e) => e.into_ref(),
            dashmap::mapref::entry::Entry::Vacant(e) => e.insert(Box::new(TaskStorage::new())),
        };
        StorageWriteGuard {
            storage: self,
            inner: inner.into(),
        }
    }

    /// Returns write guards for two tasks at once, inserting empty entries as
    /// needed. Acquisition goes through `get_multiple_mut` (presumably to avoid
    /// a shard-lock deadlock when both keys hash to the same shard — see that
    /// helper).
    pub fn access_pair_mut(
        &self,
        key1: TaskId,
        key2: TaskId,
    ) -> (StorageWriteGuard<'_>, StorageWriteGuard<'_>) {
        let (a, b) = get_multiple_mut(&self.map, key1, key2, || Box::new(TaskStorage::new()));
        (
            StorageWriteGuard {
                storage: self,
                inner: a,
            },
            StorageWriteGuard {
                storage: self,
                inner: b,
            },
        )
    }

    /// Drops the contents of both maps via the `drop_contents` helper.
    pub fn drop_contents(&self) {
        drop_contents(&self.map);
        drop_contents(&self.modified);
    }
}
283
/// A write lock on a single task's storage. Mutations made through this guard
/// should be reported via [`StorageWriteGuard::track_modification`] so they are
/// picked up by snapshots.
pub struct StorageWriteGuard<'a> {
    storage: &'a Storage,
    inner: RefMut<'a, TaskId, Box<TaskStorage>>,
}
288
impl StorageWriteGuard<'_> {
    /// Tracks mutation of this task
    #[inline(always)]
    pub fn track_modification(
        &mut self,
        category: SpecificTaskDataCategory,
        #[allow(unused_variables)] name: &str,
    ) {
        // `name` is only forwarded when modification tracing is compiled in.
        self.track_modification_internal(
            category,
            #[cfg(feature = "trace_task_modification")]
            name,
        );
    }

    /// Records a mutation of the given category: updates this item's
    /// modified/snapshot flags and, on the first relevant mutation, registers
    /// the task in `Storage::modified` (preserving a copy of the item when it
    /// is part of an in-progress snapshot).
    fn track_modification_internal(
        &mut self,
        category: SpecificTaskDataCategory,
        #[cfg(feature = "trace_task_modification")] name: &str,
    ) {
        let flags = &self.inner.flags;
        // Already snapshotted for this category in the current snapshot cycle:
        // nothing more to record.
        if flags.is_snapshot(category) {
            return;
        }
        let modified = flags.is_modified(category);
        #[cfg(feature = "trace_task_modification")]
        let _span = (!modified).then(|| tracing::trace_span!("mark_modified", name).entered());
        match (self.storage.snapshot_mode(), modified) {
            (false, false) => {
                // Not in snapshot mode and item is unmodified
                if !flags.any_snapshot() && !flags.any_modified() {
                    self.storage
                        .modified
                        .insert(*self.inner.key(), ModifiedState::Modified);
                }
                self.inner.flags.set_modified(category, true);
            }
            (false, true) => {
                // Not in snapshot mode and item is already modified
                // Do nothing
            }
            (true, false) => {
                // In snapshot mode and item is unmodified (so it's not part of the snapshot)
                if !flags.any_snapshot() {
                    self.storage
                        .modified
                        .insert(*self.inner.key(), ModifiedState::Snapshot(None));
                }
                self.inner.flags.set_snapshot(category, true);
            }
            (true, true) => {
                // In snapshot mode and item is modified (so it's part of the snapshot)
                // We need to store the original version that is part of the snapshot
                if !flags.any_snapshot() {
                    // Snapshot all non-transient fields but keep the modified bits.
                    let mut snapshot = self.inner.clone_snapshot();
                    snapshot.flags.set_data_modified(flags.data_modified());
                    snapshot.flags.set_meta_modified(flags.meta_modified());
                    self.storage.modified.insert(
                        *self.inner.key(),
                        ModifiedState::Snapshot(Some(Box::new(snapshot))),
                    );
                }
                self.inner.flags.set_snapshot(category, true);
            }
        }
    }
}
357
impl Deref for StorageWriteGuard<'_> {
    type Target = TaskStorage;

    // Derefs through the dashmap `RefMut` and the `Box` to the task storage.
    fn deref(&self) -> &Self::Target {
        &self.inner
    }
}
365
impl DerefMut for StorageWriteGuard<'_> {
    // Mutable counterpart of `Deref`: note that direct mutation does NOT call
    // `track_modification`; callers are responsible for reporting changes.
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.inner
    }
}
371
/// Ends snapshot mode on the storage when dropped. Shared via `Arc` between all
/// `SnapshotShard`s of one `take_snapshot` call, so the last shard to finish
/// (or be dropped) triggers the cleanup.
pub struct SnapshotGuard<'l> {
    storage: &'l Storage,
}
375
impl Drop for SnapshotGuard<'_> {
    fn drop(&mut self) {
        // Leaves snapshot mode and converts remaining snapshot entries back to
        // plain modified state (see `Storage::end_snapshot`).
        self.storage.end_snapshot();
    }
}
381
/// One shard's worth of snapshot work produced by `Storage::take_snapshot`.
/// Iterating yields processed results; the shared `guard` keeps snapshot mode
/// alive until every shard is exhausted or dropped.
pub struct SnapshotShard<'l, PP, P, PS> {
    /// Snapshots already extracted from the modified map for this shard.
    direct_snapshots: Vec<(TaskId, Box<TaskStorage>)>,
    /// Task ids that were only marked modified (no stored snapshot yet).
    modified: SmallVec<[TaskId; 4]>,
    storage: &'l Storage,
    /// Shared snapshot-mode guard; cleared once this shard is exhausted.
    guard: Option<Arc<SnapshotGuard<'l>>>,
    process: &'l P,
    preprocess: &'l PP,
    process_snapshot: &'l PS,
    /// Scratch buffer for encoding task data, reused across iterations to avoid allocations
    scratch_buffer: TurboBincodeBuffer,
}
393
impl<'l, T, R, PP, P, PS> Iterator for SnapshotShard<'l, PP, P, PS>
where
    PP: for<'a> Fn(TaskId, &'a TaskStorage) -> T + Sync,
    P: Fn(TaskId, T, &mut TurboBincodeBuffer) -> R + Sync,
    PS: Fn(TaskId, Box<TaskStorage>, &mut TurboBincodeBuffer) -> R + Sync,
{
    type Item = R;

    fn next(&mut self) -> Option<Self::Item> {
        // Snapshots that were extracted up front are processed first.
        if let Some((task_id, snapshot)) = self.direct_snapshots.pop() {
            return Some((self.process_snapshot)(
                task_id,
                snapshot,
                &mut self.scratch_buffer,
            ));
        }
        // Then the merely-modified tasks. A loop, because an entry may yield
        // nothing (its snapshot was already taken out by the snapshot
        // operation; see `ModifiedState::Snapshot(None)`).
        while let Some(task_id) = self.modified.pop() {
            let inner = self.storage.map.get(&task_id).unwrap();
            if !inner.flags.any_snapshot() {
                // Not accessed during snapshot mode: read the live item.
                let preprocessed = (self.preprocess)(task_id, &inner);
                // Release the map lock before the potentially expensive
                // `process` call.
                drop(inner);
                return Some((self.process)(
                    task_id,
                    preprocessed,
                    &mut self.scratch_buffer,
                ));
            } else {
                drop(inner);
                // Accessed during snapshot mode: take the preserved copy, if
                // it is still present.
                let maybe_snapshot = {
                    let mut modified_state = self.storage.modified.get_mut(&task_id).unwrap();
                    let ModifiedState::Snapshot(snapshot) = &mut *modified_state else {
                        unreachable!("The snapshot bit was set, so it must be in Snapshot state");
                    };
                    snapshot.take()
                };
                if let Some(snapshot) = maybe_snapshot {
                    return Some((self.process_snapshot)(
                        task_id,
                        snapshot,
                        &mut self.scratch_buffer,
                    ));
                }
            }
        }
        // Shard exhausted: release the guard. When the last shard drops its
        // clone, `end_snapshot` runs.
        self.guard = None;
        None
    }
}