Skip to main content

turbo_tasks_backend/backend/
storage.rs

1use std::{
2    cell::Cell,
3    fmt::{Display, Formatter},
4    hash::{BuildHasher, Hash},
5    ops::{Deref, DerefMut},
6    sync::{
7        Arc,
8        atomic::{AtomicBool, AtomicU64, Ordering},
9    },
10};
11
12use thread_local::ThreadLocal;
13use tracing::span::Id;
14use turbo_bincode::TurboBincodeBuffer;
15use turbo_tasks::{FxDashMap, TaskId, backend::CachedTaskTypeArc, event::Event, parallel};
16
17use crate::{
18    backend::storage_schema::{
19        DropPartialOutcome, KeyEvictability, TaskStorage, UnevictableReason, ValueEvictability,
20    },
21    backing_storage::SnapshotItem,
22    database::key_value_database::KeySpace,
23    utils::{
24        dash_map_drop_contents::drop_contents,
25        dash_map_multi::{RefMut, get_multiple_mut},
26        dash_map_raw_entry::{TryLockAndRemove, try_lock_and_remove},
27    },
28};
29
/// Selects which category (or categories) of a task's stored state an operation applies to.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TaskDataCategory {
    /// Only the meta category.
    Meta,
    /// Only the data category.
    Data,
    /// Both the meta and the data category.
    All,
}
36
/// Counts of tasks evicted at each level.
///
/// `Storage::evict_after_snapshot` fills one instance per map shard and sums them with `+=`.
#[derive(Debug, Default)]
pub struct EvictionCounts {
    /// Number of entries removed from `Storage::task_cache`.
    pub key_evictions: usize,
    /// Number of tasks erased from the storage map entirely.
    pub full: usize,
    /// Number of tasks that stayed in the map after both data and meta were dropped.
    pub data_and_meta: usize,
    /// Number of tasks that stayed in the map after only data was dropped.
    pub data_only: usize,
    /// Number of tasks that stayed in the map after only meta was dropped.
    pub meta_only: usize,
    /// Per-reason counts of tasks we considered but could not evict, indexed by
    /// `UnevictableReason::index()`.
    pub unevictable_reasons: [usize; UnevictableReason::COUNT],
}
49
50impl std::ops::AddAssign for EvictionCounts {
51    fn add_assign(&mut self, rhs: Self) {
52        self.key_evictions += rhs.key_evictions;
53        self.full += rhs.full;
54        self.data_and_meta += rhs.data_and_meta;
55        self.data_only += rhs.data_only;
56        self.meta_only += rhs.meta_only;
57        for i in 0..UnevictableReason::COUNT {
58            self.unevictable_reasons[i] += rhs.unevictable_reasons[i];
59        }
60    }
61}
62
63impl Display for EvictionCounts {
64    /// Compact `field=value,...` form used as a single tracing span field so that
65    /// adding a new counter or `UnevictableReason` variant doesn't require updating
66    /// the span field list.
67    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
68        let skipped: usize = self.unevictable_reasons.iter().sum();
69        write!(
70            f,
71            "task_cache_evictions={},full={},data_and_meta={},data_only={},meta_only={},skipped={}",
72            self.key_evictions,
73            self.full,
74            self.data_and_meta,
75            self.data_only,
76            self.meta_only,
77            skipped,
78        )?;
79        for reason in UnevictableReason::ALL {
80            write!(
81                f,
82                ",{}={}",
83                reason.span_name(),
84                self.unevictable_reasons[reason.index()],
85            )?;
86        }
87        Ok(())
88    }
89}
90
91impl TaskDataCategory {
92    pub fn includes_data(self) -> bool {
93        matches!(self, TaskDataCategory::Data | TaskDataCategory::All)
94    }
95
96    pub fn includes_meta(self) -> bool {
97        matches!(self, TaskDataCategory::Meta | TaskDataCategory::All)
98    }
99}
100
/// Exactly one category of a task's stored state. Unlike `TaskDataCategory` there is no
/// `All` variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SpecificTaskDataCategory {
    /// The meta category.
    Meta,
    /// The data category.
    Data,
}
106
107impl From<SpecificTaskDataCategory> for TaskDataCategory {
108    fn from(category: SpecificTaskDataCategory) -> Self {
109        match category {
110            SpecificTaskDataCategory::Meta => TaskDataCategory::Meta,
111            SpecificTaskDataCategory::Data => TaskDataCategory::Data,
112        }
113    }
114}
115
116impl SpecificTaskDataCategory {
117    /// Returns the KeySpace for storing data of this category
118    pub fn key_space(self) -> KeySpace {
119        match self {
120            SpecificTaskDataCategory::Meta => KeySpace::TaskMeta,
121            SpecificTaskDataCategory::Data => KeySpace::TaskData,
122        }
123    }
124}
125
/// In-memory task storage: the sharded task map plus the bookkeeping used to snapshot
/// modified tasks while other threads keep modifying them.
pub struct Storage {
    /// `true` while a snapshot is in progress. Set by `start_snapshot` and cleared by
    /// `end_snapshot`.
    snapshot_mode: AtomicBool,
    /// Per-shard counts of tasks with modified flags set. Incremented when a task
    /// transitions from unmodified to modified (outside snapshot mode). Reset to zero per shard
    /// as `take_snapshot` processes that shard, and re-incremented in `end_snapshot` for tasks
    /// that still have modifications (promoted from `modified_during_snapshot`). Used to skip
    /// unmodified shards in `take_snapshot`, avoiding unnecessary iteration and enabling early
    /// returns.
    ///
    /// Indexed by `map.determine_shard(map.hash_usize(&key))` and guaranteed by construction so
    /// that `shard_modified_counts.len() == map.shards().len()`.
    ///
    /// Should only be modified while holding the corresponding dashmap shard lock.
    shard_modified_counts: Box<[AtomicU64]>,
    /// Stores snapshots of task state for tasks accessed during snapshot mode.
    /// - `Some(snapshot)`: Task was modified before snapshot mode and accessed again during it.
    ///   Contains a copy of the pre-snapshot state that needs to be persisted.
    /// - `None`: Task was first modified during snapshot mode (not part of current snapshot). Will
    ///   be marked as modified at the beginning of the next snapshot cycle.
    ///
    /// Lock Ordering: `snapshots` locks are acquired **after** `map` locks (see the comment on
    /// `map` below). Holding a `snapshots` shard write lock and then trying to take a `map` shard
    /// write lock is forbidden — it would deadlock against `track_modification_internal` /
    /// `SnapshotShardIter::next`, which take map first.
    ///
    /// Shard Invariant: `snapshots` is constructed with the same `shard_amount`, the same key
    /// type (`TaskId`), and the same stateless hasher (`FxBuildHasher`) as `map`. Therefore shard
    /// index `N` in `snapshots` corresponds exactly to shard index `N` in `map`: any `TaskId`
    /// present in `snapshots.shards()[N]` (if present in `map` at all) is in `map.shards()[N]`.
    /// Code that walks both maps in parallel (e.g. `end_snapshot`) relies on this to lock pairs
    /// of shards by index instead of going through the top-level `DashMap` accessors.
    snapshots: FxDashMap<TaskId, Option<Box<TaskStorage>>>,
    /// The main storage map.
    ///
    /// Lock Ordering: Task creation acquires a `task_cache` lock and then inserts into this map.
    /// Because both data structures are sharded on different keys, the locks are not 'strictly'
    /// ordered, but we should treat them as such. Code that needs the locks in the opposite
    /// order must be defensive (e.g. only try-lock `task_cache` while holding a `map` shard, as
    /// `evict_after_snapshot` does).
    ///
    /// Lock Ordering vs. `snapshots`: `map` locks are acquired **before** `snapshots` locks.
    /// `track_modification_internal` and `SnapshotShardIter::next` both hold a `map` shard write
    /// lock (via `StorageWriteGuard` / `map.get_mut`) and then take a `snapshots` shard lock.
    /// `end_snapshot` must lock in the same order — see the shard-zipping pattern there.
    map: FxDashMap<TaskId, Box<TaskStorage>>,
    /// A shared event notified whenever any task finishes restoring (successfully or not).
    ///
    /// Threads waiting for another thread's in-progress restore subscribe to this event,
    /// then re-check the specific task's `restoring`/`restored` bits after waking.
    pub(crate) restored: Event,
    /// Maps `CachedTaskType` → `TaskId` for deduplication of persistent task creation.
    /// This is backed by the TaskCache table in the database.
    ///
    /// Lock Ordering: See the comments on [map].
    pub task_cache: FxDashMap<CachedTaskTypeArc, TaskId>,
}
180
181impl Storage {
182    pub fn new(shard_amount: usize, small_preallocation: bool) -> Self {
183        let map_capacity: usize = if small_preallocation {
184            1024
185        } else {
186            1024 * 1024
187        };
188
189        let map = FxDashMap::with_capacity_and_hasher_and_shard_amount(
190            map_capacity,
191            Default::default(),
192            shard_amount,
193        );
194        let shard_modified_counts = (0..shard_amount)
195            .map(|_| AtomicU64::new(0))
196            .collect::<Vec<_>>()
197            .into_boxed_slice();
198        Self {
199            snapshot_mode: AtomicBool::new(false),
200            shard_modified_counts,
201            snapshots: FxDashMap::with_capacity_and_hasher_and_shard_amount(
202                // We expect very few updates to this map since it will only happen when updates
203                // race with snapshots.  This never happens in a build and only rarely happens in
204                // dev sessions
205                0,
206                Default::default(),
207                shard_amount,
208            ),
209            map,
210            restored: Event::new(|| || "Storage::restored".to_string()),
211            task_cache: FxDashMap::default(),
212        }
213    }
214
215    /// Returns the shard index for the given key in the `map` DashMap.
216    fn shard_index(&self, key: &TaskId) -> usize {
217        let hash = self.map.hash_usize(key);
218        self.map.determine_shard(hash)
219    }
220
221    /// Promote `modified_during_snapshot` → `modified` flags on a task, and increment the
222    /// per-shard modified count if the task was not already marked as modified.
223    ///
224    /// This is used after persisting a snapshot: _during_snapshot flags represent changes
225    /// that occurred concurrently and were not included in the persisted snapshot, so they
226    /// must be carried forward as `modified` for the next snapshot cycle.
227    fn promote_during_snapshot_flags(&self, task: &mut TaskStorage, shard_idx: usize) {
228        let already_modified = task.flags.any_modified();
229        let mut promoted = false;
230        if task.flags.meta_modified_during_snapshot() {
231            task.flags.set_meta_modified_during_snapshot(false);
232            task.flags.set_meta_modified(true);
233            promoted = true;
234        }
235        if task.flags.data_modified_during_snapshot() {
236            task.flags.set_data_modified_during_snapshot(false);
237            task.flags.set_data_modified(true);
238            promoted = true;
239        }
240        if !already_modified && promoted {
241            self.shard_modified_counts[shard_idx].fetch_add(1, Ordering::Relaxed);
242        }
243    }
244
245    /// Mark a newly allocated task as restored (skip DB queries) and new (include in persistence
246    /// snapshots). Optionally sets the `persistent_task_type` eagerly so it's available for
247    /// persistence snapshots without needing to propagate it through `connect_child`.
248    pub fn initialize_new_task(&self, task_id: TaskId, task_type: Option<CachedTaskTypeArc>) {
249        let mut task = self.access_mut(task_id);
250        task.flags.set_restored(TaskDataCategory::All);
251        task.flags.set_new_task(true);
252        if let Some(task_type) = task_type {
253            task.set_persistent_task_type(task_type);
254            if !task_id.is_transient() {
255                task.track_modification(SpecificTaskDataCategory::Data, "persistent_task_type");
256            }
257        }
258    }
259
    /// Processes every modified item (resp. a snapshot of it) with the given function and returns
    /// the results. Ends snapshot mode when the returned `SnapshotGuard` (held by each shard) is
    /// dropped.
    ///
    /// `process` is called while holding a read lock on the task storage, so it can access
    /// the TaskStorage directly without cloning.
    ///
    /// `process` receives a mutable scratch buffer that can be reused across iterations
    /// to avoid repeated allocations.
    ///
    /// The returned shards implement `IntoIterator`. Shards whose modified count is zero are
    /// filtered out, but shards may still yield no items if all entries produce
    /// empty `SnapshotItem`s (this is rare and only happens under error conditions).
    ///
    /// This method itself only collects the ids of modified tasks per shard; `process` is not
    /// invoked here but handed to each returned `SnapshotShard`.
    pub fn take_snapshot<
        'l,
        P: for<'a> Fn(TaskId, &'a TaskStorage, &mut TurboBincodeBuffer) -> SnapshotItem + Sync,
    >(
        &'l self,
        guard: SnapshotGuard<'l>,
        process: &'l P,
    ) -> Vec<SnapshotShard<'l, P>> {
        // Shared between all returned shards so the guard is dropped with the last of them.
        let guard = Arc::new(guard);

        let shards: Vec<_> = self.map.shards().iter().enumerate().collect();

        // The number of shards is much larger than the number of threads, so the effect of the
        // locks held is negligible.
        parallel::map_collect::<_, _, Vec<_>>(&shards, |&(shard_idx, shard)| {
            // Check how many modifications there are in this shard, because we have entered
            // snapshot_mode, there are no racing writes
            // So we can safely clear it out now that we are processing the modifications
            let modified_count = self.shard_modified_counts[shard_idx].swap(0, Ordering::Relaxed);
            if modified_count == 0 {
                return None;
            }
            // The count is an exact-size hint for the ids collected below.
            let mut modified = Vec::with_capacity(modified_count as usize);
            {
                let shard_guard = shard.read();
                // Safety: shard_guard must outlive the iterator.
                for bucket in unsafe { shard_guard.iter() } {
                    // Safety: the guard guarantees that the bucket is not removed and the ptr
                    // is valid.
                    let (key, shared_value) = unsafe { bucket.as_ref() };
                    let flags = &shared_value.get().flags;
                    // Only check modified flags here — transient tasks never have
                    // modified flags set (track_modification guards against it), so
                    // this naturally excludes them. new_task is always
                    // accompanied by modified flags (set_persistent_task_type calls
                    // track_modification), so any_modified() is sufficient.
                    if flags.any_modified() {
                        if key.is_transient() {
                            // Should be impossible; panic in debug builds, skip in release.
                            debug_assert!(
                                false,
                                "found a modified transient task: {:?}",
                                shared_value.get().get_persistent_task_type()
                            );
                            continue;
                        }

                        modified.push(*key);
                    }
                }
                // The iterator is finished at this point, so the read lock can be released.
                drop(shard_guard);
            }

            debug_assert!(!modified.is_empty());

            Some(SnapshotShard {
                shard_idx,
                modified,
                storage: self,
                process,
                _guard: guard.clone(),
            })
        })
        .into_iter()
        .flatten()
        .collect()
    }
340
341    /// Enter snapshot mode and return a guard that will call `end_snapshot` on drop.
342    ///
343    /// Returns whether any shard has modifications. Per-shard counts are reset
344    /// in `take_snapshot` as each shard is processed, not here — resetting eagerly
345    /// would lose track of modifications for shards that haven't been persisted yet.
346    ///
347    /// Safety invariant: `start_snapshot` and `end_snapshot` are always called
348    /// sequentially within a single `snapshot_and_persist` invocation (the sole
349    /// caller). There is no concurrent snapshot lifecycle, so they cannot race.
350    pub fn start_snapshot(&self) -> (SnapshotGuard<'_>, bool) {
351        // Enter snapshot mode first so concurrent track_modification calls switch
352        // to the _during_snapshot path and stop incrementing shard_modified_counts.
353        self.snapshot_mode.store(true, Ordering::Release);
354        // Check if any shard has modifications. Don't reset counts here —
355        // take_snapshot resets per-shard counts as it processes each shard,
356        // which avoids losing track of modifications for shards that haven't
357        // been persisted yet.
358        let has_modifications = self
359            .shard_modified_counts
360            .iter()
361            .any(|c| c.load(Ordering::Relaxed) > 0);
362        (SnapshotGuard::new(self), has_modifications)
363    }
364
    /// End snapshot mode.
    ///
    /// Modified/new flags on tasks are cleared incrementally during snapshot iteration
    /// (in `take_snapshot` for direct_snapshots, and in `SnapshotShardIter::next` for
    /// modified tasks), so no full-map scan is needed here.
    /// NOTE(review): the `take_snapshot` in this file only collects task ids and has no
    /// `direct_snapshots`; the flag clearing presumably lives in `SnapshotShardIter::next` —
    /// confirm and update this paragraph.
    ///
    /// This method only needs to:
    /// 1. Leave snapshot mode so new modifications go to the modified flags directly.
    /// 2. Promote `modified_during_snapshot` → `modified` for tasks that were accessed during
    ///    snapshot mode (tracked in the small `snapshots` map).
    fn end_snapshot(&self) {
        // Leave snapshot mode first. After this, concurrent track_modification calls
        // will set modified flags directly instead of going through the snapshots map.
        self.snapshot_mode.store(false, Ordering::Release);

        // Promote modified_during_snapshot → modified for tasks that had snapshots.
        // The snapshots map should be small (only tasks concurrently accessed during snapshot
        // mode). Increment the per-shard modified counts for promoted tasks.

        // Lock Ordering: we must acquire `map` shards BEFORE `snapshots` shards, matching the
        // order used by `track_modification_internal` and `SnapshotShardIter::next`. The
        // previous implementation drained `snapshots` first and then called `self.map.get_mut`,
        // which is the opposite order — a concurrent `track_modification` (holding map[N], about
        // to insert into snapshots[N]) could deadlock against it through the
        // `snapshot_mode = false` race window.
        //
        // Shard pairing: `map` and `snapshots` are constructed with the same `shard_amount`,
        // same `TaskId` keys, and the same stateless `FxBuildHasher`. Therefore shard `N` in
        // `snapshots` pairs with shard `N` in `map`: every key drained from `snapshots[N]` (if
        // it still exists in `map`) lives in `map[N]`. We zip them and lock each pair in order.
        let map_shards = self.map.shards();
        let snapshot_shards = self.snapshots.shards();
        debug_assert_eq!(
            map_shards.len(),
            snapshot_shards.len(),
            "map and snapshots must share shard count for zipped locking; see Shard Invariant on \
             `snapshots` field"
        );

        // Each shard pair is independent, so the pairs are processed in parallel.
        let shard_indices: Vec<usize> = (0..map_shards.len()).collect();
        parallel::for_each(&shard_indices, |&shard_idx| {
            let map_shard = &map_shards[shard_idx];
            let snap_shard = &snapshot_shards[shard_idx];

            // Acquire in documented order: map first, snapshots second.
            let map_guard = map_shard.write();
            let mut snap_guard = snap_shard.write();

            // The stored snapshot values are no longer needed; only the keys matter here.
            for (key, _) in snap_guard.drain() {
                // The key is in this shard's `map` (or absent entirely), by the shard
                // invariant above. Resolve directly in the held map guard rather than going
                // through `self.map.get_mut`, which would attempt to re-acquire this shard's
                // write lock and would also obscure the pairing.
                let hash = self.map.hasher().hash_one(key);
                if let Some(bucket) = map_guard.find(hash, |(k, _)| *k == key) {
                    // SAFETY: We hold `map_shard`'s write lock for the duration of this
                    // access, so the bucket pointer is valid and no other thread can alias it.
                    let (_, shared_value) = unsafe { bucket.as_mut() };
                    self.promote_during_snapshot_flags(shared_value.get_mut(), shard_idx);
                }
            }
            // If we are saving a non-trivial amount of memory just clear it out.
            if snap_guard.capacity() > 1024 {
                snap_guard.shrink_to(0, |_entry| {
                    unreachable!("nothing is hashed when resizing an empty shard to zero");
                });
            }

            // Release in reverse acquisition order.
            drop(snap_guard);
            drop(map_guard);
        });
    }
437
    /// Returns true if actively snapshotting (modifications are recorded via the
    /// `_during_snapshot` flags and the `snapshots` map).
    /// Returns false if inactive (modifications set the `modified` flags directly).
    fn snapshot_mode(&self) -> bool {
        self.snapshot_mode.load(Ordering::Acquire)
    }
443
444    pub fn access_mut(&self, key: TaskId) -> StorageWriteGuard<'_> {
445        let inner = match self.map.entry(key) {
446            dashmap::mapref::entry::Entry::Occupied(e) => e.into_ref(),
447            dashmap::mapref::entry::Entry::Vacant(e) => e.insert(Box::new(TaskStorage::new())),
448        };
449        StorageWriteGuard {
450            storage: self,
451            inner: inner.into(),
452        }
453    }
454
455    pub fn access_pair_mut(
456        &self,
457        key1: TaskId,
458        key2: TaskId,
459    ) -> (StorageWriteGuard<'_>, StorageWriteGuard<'_>) {
460        let (a, b) = get_multiple_mut(&self.map, key1, key2, || Box::new(TaskStorage::new()));
461        (
462            StorageWriteGuard {
463                storage: self,
464                inner: a,
465            },
466            StorageWriteGuard {
467                storage: self,
468                inner: b,
469            },
470        )
471    }
472
    /// Runs the `drop_contents` helper on `map`, `snapshots` and `task_cache`, in that order.
    ///
    /// NOTE(review): presumably this empties all three maps (e.g. for fast teardown) — confirm
    /// against `utils::dash_map_drop_contents`.
    pub fn drop_contents(&self) {
        drop_contents(&self.map);
        drop_contents(&self.snapshots);
        drop_contents(&self.task_cache);
    }
478
    /// Evict tasks from in-memory storage after a successful snapshot.
    ///
    /// Walks every `map` shard in parallel under its write lock. Transient tasks are skipped
    /// (counted as `UnevictableReason::Transient`). For every other task,
    /// `TaskStorage::evictability()` yields a key verdict and a value verdict:
    /// - `KeyEvictability::Evictable`: the task's `task_cache` entry is removed (inline if the
    ///   `task_cache` shard can be locked without blocking, otherwise after the map shard lock
    ///   has been released).
    /// - `ValueEvictability::Evictable { meta, data }`: the selected parts are dropped via
    ///   `drop_partial`. If nothing remains the task is erased from the map (`full`), otherwise
    ///   it is counted as `data_and_meta`, `data_only` or `meta_only`.
    /// - `ValueEvictability::Unevictable(reason)`: the task is left alone and counted under
    ///   `reason`.
    ///
    /// Afterwards shards of `map` (and, if enough keys were evicted, of `task_cache`) that are
    /// less than half full are shrunk.
    ///
    /// `parent_span` becomes the parent of the `evict_after_snapshot` trace span. The returned
    /// `EvictionCounts` is the sum over all shards and is also recorded on that span.
    ///
    /// Must be called when NOT in snapshot mode (i.e., after `end_snapshot()`).
    pub fn evict_after_snapshot(&self, parent_span: Option<Id>) -> EvictionCounts {
        let span = tracing::trace_span!(
            parent: parent_span,
            "evict_after_snapshot",
            total_task_cache_keys = self.task_cache.len(),
            total_map_keys = self.map.len(),
            counts = tracing::field::Empty,
        )
        .entered();
        debug_assert!(
            !self.snapshot_mode(),
            "evict_after_snapshot must not be called during snapshot mode"
        );

        let counts: Vec<EvictionCounts> = parallel::map_collect(self.map.shards(), |shard| {
            let mut shard = shard.write();
            let mut evicted = EvictionCounts::default();
            // task_cache removals that we couldn't perform inline because the target shard
            // was contended. We defer them until after the map shard lock is released to
            // avoid a lock cycle with get_or_create_persistent_task, which takes task_cache
            // before map. Allocated lazily on first conflict.
            let mut deferred_task_cache_removals: Vec<CachedTaskTypeArc> = Vec::new();
            // SAFETY: We hold the write lock for the duration of iteration.
            for bucket in unsafe { shard.iter() } {
                // SAFETY: The write lock guard outlives the bucket reference.
                let (task_id, task) = unsafe { bucket.as_mut() };
                if task_id.is_transient() {
                    evicted.unevictable_reasons[UnevictableReason::Transient.index()] += 1;
                    continue;
                }
                let (key_evictability, value_evictability) = task.get().evictability();
                match key_evictability {
                    KeyEvictability::Evictable => {
                        // The task type is persisted to backing storage (new_task = false),
                        // so task_cache is a pure perf cache. Remove it now; it will be
                        // re-populated by task_by_type() on the next cache miss.
                        // NOTE(review): the unwrap assumes `KeyEvictability::Evictable` is only
                        // reported for tasks that have a persistent task type — confirm in
                        // `TaskStorage::evictability`.
                        let task_type = task.get().get_persistent_task_type().unwrap();
                        // Only try to acquire the lock, if we cannot just remove at the end
                        // Because `get_or_create_task` acquires 'task_cache' then `storage.map` and
                        // we do the opposite we need to be defensive here.  Attempting here is just
                        // an optimization to avoid pushing into `deferred_task_cache_removals`
                        match try_lock_and_remove(&self.task_cache, task_type.as_ref()) {
                            TryLockAndRemove::Removed => {
                                evicted.key_evictions += 1;
                            }
                            TryLockAndRemove::NotFound => {
                                // Generally this should be rare, it more or less implies something
                                // else is concurrently holding the Arc
                            }
                            TryLockAndRemove::WouldBlock => {
                                // Contention, to avoid a deadlock just defer
                                deferred_task_cache_removals.push(task_type.clone());
                            }
                        }
                    }
                    KeyEvictability::AlreadyEvicted | KeyEvictability::Unevictable => {}
                }
                match value_evictability {
                    ValueEvictability::Evictable { meta, data } => {
                        match task.get_mut().drop_partial(data, meta) {
                            DropPartialOutcome::Empty => {
                                // SAFETY: `bucket` was just yielded by the iterator over this
                                // shard and we hold the shard's write lock.
                                // NOTE(review): relies on the raw table permitting erasure of an
                                // already-yielded bucket while iterating — confirm against the
                                // hashbrown `RawTable` documentation.
                                unsafe {
                                    shard.erase(bucket);
                                }
                                evicted.full += 1;
                            }
                            DropPartialOutcome::HasResidue => {
                                if data && meta {
                                    evicted.data_and_meta += 1;
                                } else if data {
                                    evicted.data_only += 1;
                                } else {
                                    debug_assert!(meta);
                                    evicted.meta_only += 1;
                                }
                            }
                        }
                    }
                    ValueEvictability::Unevictable(reason) => {
                        evicted.unevictable_reasons[reason.index()] += 1;
                    }
                }
            }
            // Shrink the shard if it's less than half full, to reclaim slack capacity
            // after bulk evictions. We already hold the write lock, so this is free
            // from a locking perspective. TaskId hashing is cheap (it's just an integer).
            let len = shard.len();
            if shard.capacity() > len * 2 {
                shard.shrink_to(len, |(k, _v)| self.map.hasher().hash_one(k));
            }
            // Release the map shard lock before draining deferred removals so that a thread
            // holding a task_cache shard lock and waiting on this map shard can make progress.
            drop(shard);
            for task_type in deferred_task_cache_removals {
                // Blocking removal is fine now: no map shard lock is held any more.
                if self.task_cache.remove(task_type.as_ref()).is_some() {
                    evicted.key_evictions += 1;
                }
            }
            evicted
        });

        // Fold the per-shard results into one total.
        let mut totals = EvictionCounts::default();
        for evicted in counts {
            totals += evicted;
        }
        // Shrink task_cache only when we evicted more entries than remain — i.e. the map
        // is less than half full. Rehashing each surviving CachedTaskType isn't free, so
        // we gate it on meaningful slack. Within that, walk shards in parallel and shrink
        // each one independently if it is itself less than half full.
        if totals.key_evictions > self.task_cache.len() {
            parallel::for_each(self.task_cache.shards(), |shard| {
                let mut shard = shard.write();
                let len = shard.len();
                if shard.capacity() > len * 2 {
                    shard.shrink_to(len, |(k, _v)| self.task_cache.hasher().hash_one(k));
                }
            });
        }
        span.record("counts", tracing::field::display(&totals));

        totals
    }
612}
613
/// Mutable handle to one task's `TaskStorage`, as returned by `Storage::access_mut` and
/// `Storage::access_pair_mut`.
///
/// It keeps a reference to the owning `Storage` so that `track_modification` can update the
/// storage-wide snapshot bookkeeping.
pub struct StorageWriteGuard<'a> {
    /// The storage this entry belongs to.
    storage: &'a Storage,
    /// The map entry for the task (key: `TaskId`, value: boxed `TaskStorage`).
    inner: RefMut<'a, TaskId, Box<TaskStorage>>,
}
618
619impl StorageWriteGuard<'_> {
    /// Tracks mutation of this task in the given `category`.
    ///
    /// `name` is only forwarded (and used for a trace span) when the
    /// `trace_task_modification` feature is enabled; otherwise it is ignored.
    ///
    /// Must not be called for transient tasks; this is checked with a `debug_assert!`.
    #[inline(always)]
    pub fn track_modification(
        &mut self,
        category: SpecificTaskDataCategory,
        #[allow(unused_variables)] name: &str,
    ) {
        debug_assert!(
            !self.inner.key().is_transient(),
            "transient task_ids should never be enqueued to be persisted"
        );
        self.track_modification_internal(
            category,
            #[cfg(feature = "trace_task_modification")]
            name,
        );
    }
637
638    fn track_modification_internal(
639        &mut self,
640        category: SpecificTaskDataCategory,
641        #[cfg(feature = "trace_task_modification")] name: &str,
642    ) {
643        // Transient tasks are never persisted, so tracking modifications is meaningless.
644        // All callers (TaskGuard, invalidate_serialization) already
645        // guard against this, but we enforce it here as defense-in-depth.
646        debug_assert!(
647            !self.inner.key().is_transient(),
648            "track_modification called on transient task {:?}",
649            self.inner.key()
650        );
651        let flags = &self.inner.flags;
652        if flags.is_modified_during_snapshot(category) {
653            // We can early return since `end_snapshot` is responsible for reconciling.
654            return;
655        }
656        #[cfg(feature = "trace_task_modification")]
657        let _span = (!modified).then(|| tracing::trace_span!("mark_modified", name).entered());
658        match (self.storage.snapshot_mode(), flags.is_modified(category)) {
659            (false, false) => {
660                // Not in snapshot mode and item is unmodified
661                if !flags.any_modified() {
662                    let shard_idx = self.storage.shard_index(self.inner.key());
663                    self.storage.shard_modified_counts[shard_idx].fetch_add(1, Ordering::Relaxed);
664                }
665                self.inner.flags.set_modified(category, true);
666            }
667            (false, true) => {
668                // Not in snapshot mode and item is already modified
669                // Do nothing
670            }
671            (true, false) => {
672                // In snapshot mode and item is unmodified (so it's not part of the snapshot)
673                // Mark it so it gets re-added as Modified after this snapshot completes.
674                // Insert a None entry into snapshots so end_snapshot discovers this task
675                // and promotes its _during_snapshot flags.
676                if !flags.any_modified_during_snapshot() {
677                    self.storage.snapshots.insert(*self.inner.key(), None);
678                }
679                self.inner
680                    .flags
681                    .set_modified_during_snapshot(category, true);
682            }
683            (true, true) => {
684                // In snapshot mode and item is modified (so it's part of the snapshot)
685                // We need to store the original version that is part of the snapshot
686                if !flags.any_modified_during_snapshot() {
687                    // Snapshot all non-transient fields, carrying the modified bits into
688                    // the copy so the iterator knows which categories to persist.
689                    let mut snapshot = self.inner.clone_snapshot();
690                    snapshot.flags.set_data_modified(flags.data_modified());
691                    snapshot.flags.set_meta_modified(flags.meta_modified());
692                    snapshot.flags.set_new_task(flags.new_task());
693                    self.storage
694                        .snapshots
695                        .insert(*self.inner.key(), Some(Box::new(snapshot)));
696                }
697                self.inner
698                    .flags
699                    .set_modified_during_snapshot(category, true);
700            }
701        }
702    }
703}
704
705impl Deref for StorageWriteGuard<'_> {
706    type Target = TaskStorage;
707
708    fn deref(&self) -> &Self::Target {
709        &self.inner
710    }
711}
712
713impl DerefMut for StorageWriteGuard<'_> {
714    fn deref_mut(&mut self) -> &mut Self::Target {
715        &mut self.inner
716    }
717}
718
/// Initial capacity requested for each per-thread scratch buffer. Based on metrics
/// from a large application this should cover about 98% of values with no resizes.
const SCRATCH_BUFFER_INITIAL_SIZE: usize = 4 * 1024;
722
/// State machine for a per-thread scratch buffer slot.
///
/// Transitions:
/// - `Uninit` → `Taken` (first take)
/// - `Available` → `Taken` (subsequent takes)
/// - `Taken` → `Available` (return)
///
/// Any other transition is a bug (e.g. double-take or double-return) and makes
/// `SnapshotGuard::take_scratch_buffer` / `SnapshotGuard::return_scratch_buffer` panic.
///
/// `Uninit` is the `Default`, so a slot that is lazily created through
/// `ThreadLocal::get_or_default` starts out in that state.
#[derive(Default)]
enum ScratchBufferSlot {
    /// No buffer has been allocated on this thread yet.
    #[default]
    Uninit,
    /// The buffer is currently checked out.
    Taken,
    /// The buffer is available for reuse.
    Available(TurboBincodeBuffer),
}
741
/// Guard for an in-progress snapshot of a [`Storage`].
///
/// Dropping the guard calls `Storage::end_snapshot`. It also owns the per-thread
/// scratch buffers used while encoding task data.
pub struct SnapshotGuard<'l> {
    /// Storage on which `end_snapshot` is called when the guard is dropped.
    storage: &'l Storage,
    /// Per-thread scratch buffers for encoding task data. Buffers are taken
    /// by `SnapshotShardIter` on creation and returned on drop, allowing reuse
    /// across multiple shards processed by the same thread. When the guard is
    /// dropped (after all iterators are done), the `ThreadLocal` drops too,
    /// freeing all buffers.
    scratch_buffers: ThreadLocal<Cell<ScratchBufferSlot>>,
}
751
752impl<'l> SnapshotGuard<'l> {
753    fn new(storage: &'l Storage) -> Self {
754        Self {
755            storage,
756            scratch_buffers: ThreadLocal::new(),
757        }
758    }
759
760    fn take_scratch_buffer(&self) -> TurboBincodeBuffer {
761        let cell = self.scratch_buffers.get_or_default();
762        match cell.take() {
763            ScratchBufferSlot::Available(buf) => {
764                cell.set(ScratchBufferSlot::Taken);
765                buf
766            }
767            ScratchBufferSlot::Uninit => {
768                cell.set(ScratchBufferSlot::Taken);
769                TurboBincodeBuffer::with_capacity(SCRATCH_BUFFER_INITIAL_SIZE)
770            }
771            ScratchBufferSlot::Taken => {
772                panic!("scratch buffer taken twice without being returned");
773            }
774        }
775    }
776
777    fn return_scratch_buffer(&self, buffer: TurboBincodeBuffer) {
778        let cell = self.scratch_buffers.get_or_default();
779        match cell.take() {
780            ScratchBufferSlot::Taken => cell.set(ScratchBufferSlot::Available(buffer)),
781            ScratchBufferSlot::Available(_) => {
782                panic!("scratch buffer returned without being taken (already available)");
783            }
784            ScratchBufferSlot::Uninit => {
785                panic!("scratch buffer returned without being taken (uninit)");
786            }
787        }
788    }
789}
790
/// Ends the snapshot when the guard is dropped.
impl Drop for SnapshotGuard<'_> {
    fn drop(&mut self) {
        // Shards hold the guard behind an `Arc`, so this runs once the last holder
        // has released it.
        self.storage.end_snapshot();
    }
}
796
/// The snapshot work for one storage shard: the tasks to serialize and the callback
/// that turns a task's storage into a `SnapshotItem`. Consume it with `into_iter()`.
pub struct SnapshotShard<'l, P> {
    /// Index of the shard; passed to `Storage::promote_during_snapshot_flags` after
    /// each task is processed.
    shard_idx: usize,
    /// Ids of the tasks still to be processed; the iterator pops them from the back.
    modified: Vec<TaskId>,
    /// Storage that owns the task map and the `snapshots` side table.
    storage: &'l Storage,
    /// Callback invoked once per task with the task id, the storage to serialize and
    /// a scratch buffer.
    process: &'l P,
    /// Held for its `Drop` impl — ensures snapshot mode ends when all shards are done.
    _guard: Arc<SnapshotGuard<'l>>,
}
805
806impl<'l, P> IntoIterator for SnapshotShard<'l, P>
807where
808    P: Fn(TaskId, &TaskStorage, &mut TurboBincodeBuffer) -> SnapshotItem + Sync,
809{
810    type Item = SnapshotItem;
811    type IntoIter = SnapshotShardIter<'l, P>;
812
813    fn into_iter(self) -> Self::IntoIter {
814        let buffer = self._guard.take_scratch_buffer();
815        SnapshotShardIter {
816            shard: self,
817            buffer,
818        }
819    }
820}
821
/// Iterator over a single shard's snapshot items. Holds a thread-local scratch
/// buffer for the duration of iteration and returns it on drop.
pub struct SnapshotShardIter<'l, P> {
    /// The shard being drained; also keeps the snapshot guard alive.
    shard: SnapshotShard<'l, P>,
    /// Scratch buffer handed to the `process` callback for every task.
    buffer: TurboBincodeBuffer,
}
828
829impl<'l, P> Iterator for SnapshotShardIter<'l, P>
830where
831    P: Fn(TaskId, &TaskStorage, &mut TurboBincodeBuffer) -> SnapshotItem + Sync,
832{
833    type Item = SnapshotItem;
834
835    fn next(&mut self) -> Option<Self::Item> {
836        if let Some(task_id) = self.shard.modified.pop() {
837            let mut inner = self.shard.storage.map.get_mut(&task_id).unwrap();
838            // If the task was re-modified during snapshot, the snapshots map may
839            // hold a pre-modification copy we must serialize instead of the live
840            // data. Remove the entry so end_snapshot doesn't double-promote it;
841            // we promote manually below.
842            let item = if inner.flags.any_modified_during_snapshot() {
843                match self.shard.storage.snapshots.remove(&task_id) {
844                    Some((_, Some(snapshot))) => {
845                        (self.shard.process)(task_id, &snapshot, &mut self.buffer)
846                    }
847                    Some((_, None)) | None => {
848                        (self.shard.process)(task_id, &inner, &mut self.buffer)
849                    }
850                }
851            } else {
852                (self.shard.process)(task_id, &inner, &mut self.buffer)
853            };
854            // Clear the modified flags that were captured into the snapshot copy,
855            // then promote modified_during_snapshot → modified so the task stays
856            // dirty for the next snapshot cycle.
857            inner.flags.set_data_modified(false);
858            inner.flags.set_meta_modified(false);
859            inner.flags.set_new_task(false);
860            self.shard
861                .storage
862                .promote_during_snapshot_flags(&mut inner, self.shard.shard_idx);
863            return Some(item);
864        }
865        None
866    }
867}
868
869impl<P> Drop for SnapshotShardIter<'_, P> {
870    fn drop(&mut self) {
871        self.shard
872            ._guard
873            .return_scratch_buffer(std::mem::take(&mut self.buffer));
874    }
875}
876
#[cfg(test)]
mod tests {
    use turbo_bincode::TurboBincodeBuffer;
    use turbo_tasks::TaskId;

    use super::{SpecificTaskDataCategory, Storage};
    use crate::backing_storage::SnapshotItem;

    /// Builds a persistent (non-transient) task id for use in the tests below.
    fn non_transient_task(id: u32) -> TaskId {
        // TRANSIENT_TASK_BIT is 0x8000_0000; any id without that bit is non-transient.
        TaskId::new(id).expect("id must be non-zero")
    }

    /// A process fn that returns a non-empty SnapshotItem so the iterator doesn't
    /// silently skip items via the "encoding failed" error path.
    fn dummy_process(
        task_id: TaskId,
        _: &super::TaskStorage,
        _: &mut TurboBincodeBuffer,
    ) -> SnapshotItem {
        SnapshotItem {
            task_id,
            meta: Some(TurboBincodeBuffer::default()),
            data: None,
            task_type_hash: None,
        }
    }

    /// Regression test: a task modified before a snapshot and then modified *again* during
    /// snapshot iteration must serialize the pre-snapshot state and carry the during-snapshot
    /// modification forward to the next cycle.
    ///
    /// Sequence of events:
    /// 1. Task is modified (data_modified = true) → added to shard_modified_counts.
    /// 2. `start_snapshot` puts us in snapshot mode.
    /// 3. `take_snapshot` scans the shard: task has `any_modified()=true` → goes into the
    ///    `modified` list.
    /// 4. **Between scan and iteration**: `track_modification` is called on the same category. This
    ///    is the `(true, true)` branch: already modified AND in snapshot mode. A snapshot copy of
    ///    the pre-second-modification state is stored in `snapshots` as `Some(copy)`, and
    ///    `data_modified_during_snapshot` is set.
    /// 5. `SnapshotShardIter::next` processes the task from the `modified` list, detects
    ///    `any_modified_during_snapshot()=true`, finds the `Some(copy)` in `snapshots`, encodes the
    ///    pre-snapshot copy, clears the live modified flags, removes the snapshots entry, and
    ///    promotes `data_modified_during_snapshot → data_modified` for the next cycle.
    // `end_snapshot` uses `parallel::for_each` which calls `block_in_place` internally,
    // requiring a multi-threaded Tokio runtime.
    #[tokio::test(flavor = "multi_thread")]
    async fn modify_during_snapshot_clears_live_modified_flags() {
        let storage = Storage::new(2, true);
        let task_id = non_transient_task(1);

        // Step 1: modify the task outside snapshot mode (data_modified = true).
        {
            let mut guard = storage.access_mut(task_id);
            guard.track_modification(SpecificTaskDataCategory::Data, "test");
        }

        // Step 2: enter snapshot mode.
        let (snapshot_guard, has_modifications) = storage.start_snapshot();
        assert!(has_modifications);

        // Step 3: `take_snapshot` scans the shard. At this point the task has
        // `any_modified()=true` and `any_modified_during_snapshot()=false`, so it
        // goes into the `modified` list inside the returned `SnapshotShard`.
        let shards = storage.take_snapshot(snapshot_guard, &dummy_process);

        // Step 4: now that the scan is done but before we consume the iterator,
        // modify the task again. We're still in snapshot mode, the task is already
        // modified → `(true, true)` branch: creates a snapshot copy (carrying the
        // modified bits) and sets `data_modified_during_snapshot=true`.
        {
            let mut guard = storage.access_mut(task_id);
            guard.track_modification(SpecificTaskDataCategory::Data, "test");
            // We should have set a snapshot bit
            assert!(guard.flags.data_modified_during_snapshot());
        }

        // Step 5: consume the iterator. The iterator encodes from the pre-snapshot copy,
        // clears the live modified flags, removes the snapshots entry, and promotes
        // `data_modified_during_snapshot → data_modified` for the next cycle.
        let items: Vec<_> = shards
            .into_iter()
            .flat_map(|shard| shard.into_iter())
            .collect();

        // The pre-snapshot snapshot copy should have been encoded and returned.
        assert_eq!(items.len(), 1);
        assert_eq!(items[0].task_id, task_id);

        {
            let guard = storage.access_mut(task_id);
            // The iterator should have promoted modified_during_snapshot → modified.
            assert!(guard.flags.data_modified());
            // Promotion must also clear the during-snapshot bit; a stale bit would make
            // later `track_modification` calls for this category return early.
            assert!(!guard.flags.data_modified_during_snapshot());
        }

        // The during-snapshot modification must be reflected in shard_modified_counts so
        // the next snapshot cycle picks it up. Verify by starting another snapshot.
        let (_guard2, has_modifications) = storage.start_snapshot();
        assert!(
            has_modifications,
            "shard_modified_counts must be non-zero after promoting modified_during_snapshot"
        );
    }

    /// Regression test for the `(true, false)` during-snapshot case: a task modified in one
    /// category before a snapshot, then modified in a *different* category during snapshot
    /// iteration, must not panic and must carry both modifications forward correctly.
    ///
    /// Sequence of events:
    /// 1. Task meta is modified (meta_modified = true).
    /// 2. `start_snapshot` puts us in snapshot mode.
    /// 3. `take_snapshot` scans the shard: task goes into the `modified` list.
    /// 4. Task data is modified during snapshot → `(true, false)` branch: data was not previously
    ///    modified, so `snapshots` gets a `None` entry and `data_modified_during_snapshot` is set.
    /// 5. `SnapshotShardIter::next` processes the task: finds `any_modified_during_snapshot()`,
    ///    sees `None` in snapshots, encodes from live data (correct — live data for the
    ///    unmodified-before-snapshot category is still the pre-snapshot state), clears pre-snapshot
    ///    flags, and promotes `data_modified_during_snapshot → data_modified`.
    #[tokio::test(flavor = "multi_thread")]
    async fn modify_different_category_during_snapshot() {
        let storage = Storage::new(2, true);
        let task_id = non_transient_task(1);

        // Step 1: modify meta only, outside snapshot mode.
        {
            let mut guard = storage.access_mut(task_id);
            guard.track_modification(SpecificTaskDataCategory::Meta, "test");
            assert!(guard.flags.meta_modified());
            assert!(!guard.flags.data_modified());
        }

        // Step 2: enter snapshot mode.
        let (snapshot_guard, has_modifications) = storage.start_snapshot();
        assert!(has_modifications);

        // Step 3: take_snapshot — task goes into modified list (meta_modified = true).
        let shards = storage.take_snapshot(snapshot_guard, &dummy_process);

        // Step 4: modify data during snapshot. The `(true, false)` branch fires:
        // data was not previously modified, so snapshots gets a None entry.
        {
            let mut guard = storage.access_mut(task_id);
            guard.track_modification(SpecificTaskDataCategory::Data, "test");
            assert!(guard.flags.data_modified_during_snapshot());
            assert!(!guard.flags.meta_modified_during_snapshot());
        }

        // Step 5: consume the iterator — must not panic.
        let items: Vec<_> = shards
            .into_iter()
            .flat_map(|shard| shard.into_iter())
            .collect();

        assert_eq!(items.len(), 1);
        assert_eq!(items[0].task_id, task_id);

        {
            let guard = storage.access_mut(task_id);
            // meta_modified was cleared by the iterator (it was the pre-snapshot flag).
            assert!(!guard.flags.meta_modified());
            // data_modified_during_snapshot was promoted to data_modified.
            assert!(guard.flags.data_modified());
            assert!(!guard.flags.data_modified_during_snapshot());
        }

        // Next snapshot cycle must pick up the promoted data_modified.
        let (_guard2, has_modifications) = storage.start_snapshot();
        assert!(
            has_modifications,
            "shard_modified_counts must be non-zero after promoting data_modified_during_snapshot"
        );
    }
}