turbo_tasks_backend/backend/storage.rs
1use std::{
2 cell::Cell,
3 fmt::{Display, Formatter},
4 hash::{BuildHasher, Hash},
5 ops::{Deref, DerefMut},
6 sync::{
7 Arc,
8 atomic::{AtomicBool, AtomicU64, Ordering},
9 },
10};
11
12use thread_local::ThreadLocal;
13use tracing::span::Id;
14use turbo_bincode::TurboBincodeBuffer;
15use turbo_tasks::{FxDashMap, TaskId, backend::CachedTaskTypeArc, event::Event, parallel};
16
17use crate::{
18 backend::storage_schema::{
19 DropPartialOutcome, KeyEvictability, TaskStorage, UnevictableReason, ValueEvictability,
20 },
21 backing_storage::SnapshotItem,
22 database::key_value_database::KeySpace,
23 utils::{
24 dash_map_drop_contents::drop_contents,
25 dash_map_multi::{RefMut, get_multiple_mut},
26 dash_map_raw_entry::{TryLockAndRemove, try_lock_and_remove},
27 },
28};
29
/// Which part(s) of a task's stored state an operation refers to.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TaskDataCategory {
    /// Only the meta fields.
    Meta,
    /// Only the data fields.
    Data,
    /// Both the meta and the data fields.
    All,
}
36
37/// Counts of tasks evicted at each level.
38#[derive(Debug, Default)]
39pub struct EvictionCounts {
40 pub key_evictions: usize,
41 pub full: usize,
42 pub data_and_meta: usize,
43 pub data_only: usize,
44 pub meta_only: usize,
45 /// Per-reason counts of tasks we considered but could not evict, indexed by
46 /// `UnevictableReason::index()`.
47 pub unevictable_reasons: [usize; UnevictableReason::COUNT],
48}
49
50impl std::ops::AddAssign for EvictionCounts {
51 fn add_assign(&mut self, rhs: Self) {
52 self.key_evictions += rhs.key_evictions;
53 self.full += rhs.full;
54 self.data_and_meta += rhs.data_and_meta;
55 self.data_only += rhs.data_only;
56 self.meta_only += rhs.meta_only;
57 for i in 0..UnevictableReason::COUNT {
58 self.unevictable_reasons[i] += rhs.unevictable_reasons[i];
59 }
60 }
61}
62
63impl Display for EvictionCounts {
64 /// Compact `field=value,...` form used as a single tracing span field so that
65 /// adding a new counter or `UnevictableReason` variant doesn't require updating
66 /// the span field list.
67 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
68 let skipped: usize = self.unevictable_reasons.iter().sum();
69 write!(
70 f,
71 "task_cache_evictions={},full={},data_and_meta={},data_only={},meta_only={},skipped={}",
72 self.key_evictions,
73 self.full,
74 self.data_and_meta,
75 self.data_only,
76 self.meta_only,
77 skipped,
78 )?;
79 for reason in UnevictableReason::ALL {
80 write!(
81 f,
82 ",{}={}",
83 reason.span_name(),
84 self.unevictable_reasons[reason.index()],
85 )?;
86 }
87 Ok(())
88 }
89}
90
91impl TaskDataCategory {
92 pub fn includes_data(self) -> bool {
93 matches!(self, TaskDataCategory::Data | TaskDataCategory::All)
94 }
95
96 pub fn includes_meta(self) -> bool {
97 matches!(self, TaskDataCategory::Meta | TaskDataCategory::All)
98 }
99}
100
/// A single concrete category of task state — meta or data, never both.
///
/// Unlike `TaskDataCategory` there is no `All` variant, so each value maps to exactly one
/// persisted key space (see `SpecificTaskDataCategory::key_space`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SpecificTaskDataCategory {
    /// The meta fields of a task.
    Meta,
    /// The data fields of a task.
    Data,
}
106
107impl From<SpecificTaskDataCategory> for TaskDataCategory {
108 fn from(category: SpecificTaskDataCategory) -> Self {
109 match category {
110 SpecificTaskDataCategory::Meta => TaskDataCategory::Meta,
111 SpecificTaskDataCategory::Data => TaskDataCategory::Data,
112 }
113 }
114}
115
116impl SpecificTaskDataCategory {
117 /// Returns the KeySpace for storing data of this category
118 pub fn key_space(self) -> KeySpace {
119 match self {
120 SpecificTaskDataCategory::Meta => KeySpace::TaskMeta,
121 SpecificTaskDataCategory::Data => KeySpace::TaskData,
122 }
123 }
124}
125
126pub struct Storage {
127 snapshot_mode: AtomicBool,
128 /// Per-shard counts of tasks with modified flags set. Incremented when a task
129 /// transitions from unmodified to modified (outside snapshot mode). Reset to zero when
130 /// snapshot mode begins, and re-incremented in `end_snapshot` for tasks that still have
131 /// modifications (promoted from `modified_during_snapshot`). Used to skip unmodified shards
132 /// in `take_snapshot`, avoiding unnecessary iteration and enabling early returns
133 ///
134 /// Indexed by `map.determine_shard(map.hash_usize(&key))` and guaranteed by construction so
135 /// that `shard_modified_counts.len()==map.shards().len()`
136 ///
137 /// Should only be modified while holding the corresponding dashmap shard lock.
138 shard_modified_counts: Box<[AtomicU64]>,
139 /// Stores snapshots of task state for tasks accessed during snapshot mode.
140 /// - `Some(snapshot)`: Task was modified before snapshot mode and accessed again during it.
141 /// Contains a copy of the pre-snapshot state that needs to be persisted.
142 /// - `None`: Task was first modified during snapshot mode (not part of current snapshot). Will
143 /// be marked as modified at the beginning of the next snapshot cycle.
144 ///
145 /// Lock Ordering: `snapshots` locks are acquired **after** `map` locks (see the comment on
146 /// `map` below). Holding a `snapshots` shard write lock and then trying to take a `map` shard
147 /// write lock is forbidden — it would deadlock against `track_modification_internal` /
148 /// `SnapshotShardIter::next`, which take map first.
149 ///
150 /// Shard Invariant: `snapshots` is constructed with the same `shard_amount`, the same key
151 /// type (`TaskId`), and the same stateless hasher (`FxBuildHasher`) as `map`. Therefore shard
152 /// index `N` in `snapshots` corresponds exactly to shard index `N` in `map`: any `TaskId`
153 /// present in `snapshots.shards()[N]` (if present in `map` at all) is in `map.shards()[N]`.
154 /// Code that walks both maps in parallel (e.g. `end_snapshot`) relies on this to lock pairs
155 /// of shards by index instead of going through the top-level `DashMap` accessors.
156 snapshots: FxDashMap<TaskId, Option<Box<TaskStorage>>>,
157 /// The main storage map
158 ///
159 /// Lock Ordering: Task creation acquires a `task_cache` lock and then inserts into this map.
160 /// Because both datastructures are sharded on different keys, the locks are not 'strictly'
161 /// ordered but we should treat them as such
162 /// Acquiring locks in the opposite order should be defensive
163 ///
164 /// Lock Ordering vs. `snapshots`: `map` locks are acquired **before** `snapshots` locks.
165 /// `track_modification_internal` and `SnapshotShardIter::next` both hold a `map` shard write
166 /// lock (via `StorageWriteGuard` / `map.get_mut`) and then take a `snapshots` shard lock.
167 /// `end_snapshot` must lock in the same order — see the shard-zipping pattern there.
168 map: FxDashMap<TaskId, Box<TaskStorage>>,
169 /// A shared event notified whenever any task finishes restoring (successfully or not).
170 ///
171 /// Threads waiting for another thread's in-progress restore subscribe to this event,
172 /// then re-check the specific task's `restoring`/`restored` bits after waking.
173 pub(crate) restored: Event,
174 /// Maps `CachedTaskType` → `TaskId` for deduplication of persistent task creation.
175 /// This is backed by the TaskCache table in the database.
176 ///
177 /// LockOrdering: See the comments on [map].
178 pub task_cache: FxDashMap<CachedTaskTypeArc, TaskId>,
179}
180
181impl Storage {
182 pub fn new(shard_amount: usize, small_preallocation: bool) -> Self {
183 let map_capacity: usize = if small_preallocation {
184 1024
185 } else {
186 1024 * 1024
187 };
188
189 let map = FxDashMap::with_capacity_and_hasher_and_shard_amount(
190 map_capacity,
191 Default::default(),
192 shard_amount,
193 );
194 let shard_modified_counts = (0..shard_amount)
195 .map(|_| AtomicU64::new(0))
196 .collect::<Vec<_>>()
197 .into_boxed_slice();
198 Self {
199 snapshot_mode: AtomicBool::new(false),
200 shard_modified_counts,
201 snapshots: FxDashMap::with_capacity_and_hasher_and_shard_amount(
202 // We expect very few updates to this map since it will only happen when updates
203 // race with snapshots. This never happens in a build and only rarely happens in
204 // dev sessions
205 0,
206 Default::default(),
207 shard_amount,
208 ),
209 map,
210 restored: Event::new(|| || "Storage::restored".to_string()),
211 task_cache: FxDashMap::default(),
212 }
213 }
214
215 /// Returns the shard index for the given key in the `map` DashMap.
216 fn shard_index(&self, key: &TaskId) -> usize {
217 let hash = self.map.hash_usize(key);
218 self.map.determine_shard(hash)
219 }
220
221 /// Promote `modified_during_snapshot` → `modified` flags on a task, and increment the
222 /// per-shard modified count if the task was not already marked as modified.
223 ///
224 /// This is used after persisting a snapshot: _during_snapshot flags represent changes
225 /// that occurred concurrently and were not included in the persisted snapshot, so they
226 /// must be carried forward as `modified` for the next snapshot cycle.
227 fn promote_during_snapshot_flags(&self, task: &mut TaskStorage, shard_idx: usize) {
228 let already_modified = task.flags.any_modified();
229 let mut promoted = false;
230 if task.flags.meta_modified_during_snapshot() {
231 task.flags.set_meta_modified_during_snapshot(false);
232 task.flags.set_meta_modified(true);
233 promoted = true;
234 }
235 if task.flags.data_modified_during_snapshot() {
236 task.flags.set_data_modified_during_snapshot(false);
237 task.flags.set_data_modified(true);
238 promoted = true;
239 }
240 if !already_modified && promoted {
241 self.shard_modified_counts[shard_idx].fetch_add(1, Ordering::Relaxed);
242 }
243 }
244
245 /// Mark a newly allocated task as restored (skip DB queries) and new (include in persistence
246 /// snapshots). Optionally sets the `persistent_task_type` eagerly so it's available for
247 /// persistence snapshots without needing to propagate it through `connect_child`.
248 pub fn initialize_new_task(&self, task_id: TaskId, task_type: Option<CachedTaskTypeArc>) {
249 let mut task = self.access_mut(task_id);
250 task.flags.set_restored(TaskDataCategory::All);
251 task.flags.set_new_task(true);
252 if let Some(task_type) = task_type {
253 task.set_persistent_task_type(task_type);
254 if !task_id.is_transient() {
255 task.track_modification(SpecificTaskDataCategory::Data, "persistent_task_type");
256 }
257 }
258 }
259
260 /// Processes every modified item (resp. a snapshot of it) with the given function and returns
261 /// the results. Ends snapshot mode when the returned `SnapshotGuard` (held by each shard) is
262 /// dropped.
263 ///
264 /// `process` is called while holding a read lock on the task storage, so it can access
265 /// the TaskStorage directly without cloning.
266 ///
267 /// Both callbacks receive a mutable scratch buffer that can be reused across iterations
268 /// to avoid repeated allocations.
269 ///
270 /// The returned shards implement `IntoIterator`. Empty shards (no modified or snapshot
271 /// entries) are filtered out, but shards may still yield no items if all entries produce
272 /// empty `SnapshotItem`s (this is rare and only happens under error conditions).
273 pub fn take_snapshot<
274 'l,
275 P: for<'a> Fn(TaskId, &'a TaskStorage, &mut TurboBincodeBuffer) -> SnapshotItem + Sync,
276 >(
277 &'l self,
278 guard: SnapshotGuard<'l>,
279 process: &'l P,
280 ) -> Vec<SnapshotShard<'l, P>> {
281 let guard = Arc::new(guard);
282
283 let shards: Vec<_> = self.map.shards().iter().enumerate().collect();
284
285 // The number of shards is much larger than the number of threads, so the effect of the
286 // locks held is negligible.
287 parallel::map_collect::<_, _, Vec<_>>(&shards, |&(shard_idx, shard)| {
288 // Check how many modifications there are in this shard, because we have entered
289 // snapshot_mode, there are no racing writes
290 // So we can safely clear it out now that we are processing the modifications
291 let modified_count = self.shard_modified_counts[shard_idx].swap(0, Ordering::Relaxed);
292 if modified_count == 0 {
293 return None;
294 }
295 let mut modified = Vec::with_capacity(modified_count as usize);
296 {
297 let shard_guard = shard.read();
298 // Safety: shard_guard must outlive the iterator.
299 for bucket in unsafe { shard_guard.iter() } {
300 // Safety: the guard guarantees that the bucket is not removed and the ptr
301 // is valid.
302 let (key, shared_value) = unsafe { bucket.as_ref() };
303 let flags = &shared_value.get().flags;
304 // Only check modified flags here — transient tasks never have
305 // modified flags set (track_modification guards against it), so
306 // this naturally excludes them. new_task is always
307 // accompanied by modified flags (set_persistent_task_type calls
308 // track_modification), so any_modified() is sufficient.
309 if flags.any_modified() {
310 if key.is_transient() {
311 debug_assert!(
312 false,
313 "found a modified transient task: {:?}",
314 shared_value.get().get_persistent_task_type()
315 );
316 continue;
317 }
318
319 modified.push(*key);
320 }
321 }
322 // Safety: shard_guard must outlive the iterator.
323 drop(shard_guard);
324 }
325
326 debug_assert!(!modified.is_empty());
327
328 Some(SnapshotShard {
329 shard_idx,
330 modified,
331 storage: self,
332 process,
333 _guard: guard.clone(),
334 })
335 })
336 .into_iter()
337 .flatten()
338 .collect()
339 }
340
341 /// Enter snapshot mode and return a guard that will call `end_snapshot` on drop.
342 ///
343 /// Returns whether any shard has modifications. Per-shard counts are reset
344 /// in `take_snapshot` as each shard is processed, not here — resetting eagerly
345 /// would lose track of modifications for shards that haven't been persisted yet.
346 ///
347 /// Safety invariant: `start_snapshot` and `end_snapshot` are always called
348 /// sequentially within a single `snapshot_and_persist` invocation (the sole
349 /// caller). There is no concurrent snapshot lifecycle, so they cannot race.
350 pub fn start_snapshot(&self) -> (SnapshotGuard<'_>, bool) {
351 // Enter snapshot mode first so concurrent track_modification calls switch
352 // to the _during_snapshot path and stop incrementing shard_modified_counts.
353 self.snapshot_mode.store(true, Ordering::Release);
354 // Check if any shard has modifications. Don't reset counts here —
355 // take_snapshot resets per-shard counts as it processes each shard,
356 // which avoids losing track of modifications for shards that haven't
357 // been persisted yet.
358 let has_modifications = self
359 .shard_modified_counts
360 .iter()
361 .any(|c| c.load(Ordering::Relaxed) > 0);
362 (SnapshotGuard::new(self), has_modifications)
363 }
364
    /// End snapshot mode.
    ///
    /// Modified/new flags of persisted tasks are already cleared incrementally by
    /// `SnapshotShardIter::next`, so no full-map scan is needed here. This method only:
    /// 1. Leaves snapshot mode so new modifications set the modified flags directly.
    /// 2. Promotes `modified_during_snapshot` → `modified` for the tasks still listed in the
    ///    small `snapshots` map (tasks changed during snapshot mode whose entry was not
    ///    already consumed by `SnapshotShardIter::next`), and empties that map.
    ///
    /// Every `map` shard is write-locked once, paired with its `snapshots` shard. Because
    /// writers hold the `map` shard lock while inserting into `snapshots`, this also waits
    /// out any writer that still observed `snapshot_mode == true`.
    fn end_snapshot(&self) {
        // Leave snapshot mode first. After this, concurrent track_modification calls
        // will set modified flags directly instead of going through the snapshots map.
        self.snapshot_mode.store(false, Ordering::Release);

        // Promote modified_during_snapshot → modified for tasks that are still in the
        // snapshots map. That map should be small (only tasks concurrently changed during
        // snapshot mode). Promotion increments the per-shard modified counts as needed.

        // Lock Ordering: `map` shards must be acquired BEFORE `snapshots` shards, matching
        // `track_modification_internal` and `SnapshotShardIter::next`. Draining `snapshots`
        // first and then looking tasks up through `self.map.get_mut` would invert that order.
        // A concurrent `track_modification` holding map[N] and about to insert into
        // snapshots[N] could then deadlock against it in the `snapshot_mode = false` race
        // window.
        //
        // Shard pairing: `map` and `snapshots` are constructed with the same `shard_amount`,
        // same `TaskId` keys, and the same stateless `FxBuildHasher`. Therefore shard `N` in
        // `snapshots` pairs with shard `N` in `map`: every key drained from `snapshots[N]` (if
        // it still exists in `map`) lives in `map[N]`. We zip them and lock each pair in order.
        let map_shards = self.map.shards();
        let snapshot_shards = self.snapshots.shards();
        debug_assert_eq!(
            map_shards.len(),
            snapshot_shards.len(),
            "map and snapshots must share shard count for zipped locking; see Shard Invariant on \
             `snapshots` field"
        );

        let shard_indices: Vec<usize> = (0..map_shards.len()).collect();
        parallel::for_each(&shard_indices, |&shard_idx| {
            let map_shard = &map_shards[shard_idx];
            let snap_shard = &snapshot_shards[shard_idx];

            // Acquire in documented order: map first, snapshots second.
            let map_guard = map_shard.write();
            let mut snap_guard = snap_shard.write();

            // Both `Some` copies and `None` placeholders are discarded; only the task's own
            // flags matter for promotion.
            for (key, _) in snap_guard.drain() {
                // The key is in this shard's `map` (or absent entirely), by the shard
                // invariant above. Resolve directly in the held map guard rather than going
                // through `self.map.get_mut`, which would attempt to re-acquire this shard's
                // write lock and would also obscure the pairing.
                let hash = self.map.hasher().hash_one(key);
                if let Some(bucket) = map_guard.find(hash, |(k, _)| *k == key) {
                    // SAFETY: We hold `map_shard`'s write lock for the duration of this
                    // access, so the bucket pointer is valid and no other thread can alias it.
                    let (_, shared_value) = unsafe { bucket.as_mut() };
                    self.promote_during_snapshot_flags(shared_value.get_mut(), shard_idx);
                }
            }
            // Release the shard's backing allocation if it grew beyond 1024 slots; the shard
            // is empty after the drain, so no rehashing is needed to shrink it to zero.
            if snap_guard.capacity() > 1024 {
                snap_guard.shrink_to(0, |_entry| {
                    unreachable!("nothing is hashed when resizing an empty shard to zero");
                });
            }

            // Release in reverse acquisition order.
            drop(snap_guard);
            drop(map_guard);
        });
    }
437
438 /// Returns true if actively snapshotting (modifications should go to snapshots map).
439 /// Returns false if inactive (modifications go to modified list).
440 fn snapshot_mode(&self) -> bool {
441 self.snapshot_mode.load(Ordering::Acquire)
442 }
443
444 pub fn access_mut(&self, key: TaskId) -> StorageWriteGuard<'_> {
445 let inner = match self.map.entry(key) {
446 dashmap::mapref::entry::Entry::Occupied(e) => e.into_ref(),
447 dashmap::mapref::entry::Entry::Vacant(e) => e.insert(Box::new(TaskStorage::new())),
448 };
449 StorageWriteGuard {
450 storage: self,
451 inner: inner.into(),
452 }
453 }
454
455 pub fn access_pair_mut(
456 &self,
457 key1: TaskId,
458 key2: TaskId,
459 ) -> (StorageWriteGuard<'_>, StorageWriteGuard<'_>) {
460 let (a, b) = get_multiple_mut(&self.map, key1, key2, || Box::new(TaskStorage::new()));
461 (
462 StorageWriteGuard {
463 storage: self,
464 inner: a,
465 },
466 StorageWriteGuard {
467 storage: self,
468 inner: b,
469 },
470 )
471 }
472
473 pub fn drop_contents(&self) {
474 drop_contents(&self.map);
475 drop_contents(&self.snapshots);
476 drop_contents(&self.task_cache);
477 }
478
    /// Evict tasks from in-memory storage after a successful snapshot.
    ///
    /// Walks every `map` shard in parallel, under that shard's write lock. Transient tasks
    /// are skipped and counted as `UnevictableReason::Transient`. For every other task, the
    /// two results of `TaskStorage::evictability()` are applied:
    /// - `KeyEvictability::Evictable`: remove the task's entry from `task_cache` (counted in
    ///   `key_evictions`). If that `task_cache` shard is contended, the removal is deferred
    ///   until the map shard lock has been released.
    /// - `ValueEvictability::Evictable { meta, data }`: drop the selected fields via
    ///   `drop_partial`. If nothing remains, the task is removed from the map (`full`).
    ///   Otherwise it is counted as `data_and_meta`, `data_only` or `meta_only`.
    /// - `ValueEvictability::Unevictable(reason)`: counted in `unevictable_reasons`.
    ///
    /// Afterwards, `map` shards that are less than half full are shrunk. The same is done for
    /// `task_cache` shards when more keys were evicted than remain. The totals are recorded
    /// on the `evict_after_snapshot` tracing span and returned.
    ///
    /// Must be called when NOT in snapshot mode (i.e., after `end_snapshot()`).
    pub fn evict_after_snapshot(&self, parent_span: Option<Id>) -> EvictionCounts {
        let span = tracing::trace_span!(
            parent: parent_span,
            "evict_after_snapshot",
            total_task_cache_keys = self.task_cache.len(),
            total_map_keys = self.map.len(),
            counts = tracing::field::Empty,
        )
        .entered();
        debug_assert!(
            !self.snapshot_mode(),
            "evict_after_snapshot must not be called during snapshot mode"
        );

        let counts: Vec<EvictionCounts> = parallel::map_collect(self.map.shards(), |shard| {
            let mut shard = shard.write();
            let mut evicted = EvictionCounts::default();
            // task_cache removals that we couldn't perform inline because the target shard
            // was contended. We defer them until after the map shard lock is released to
            // avoid a lock cycle with get_or_create_persistent_task, which takes task_cache
            // before map. `Vec::new` does not allocate until the first push.
            let mut deferred_task_cache_removals: Vec<CachedTaskTypeArc> = Vec::new();
            // SAFETY: We hold the write lock for the duration of iteration.
            for bucket in unsafe { shard.iter() } {
                // SAFETY: The write lock guard outlives the bucket reference.
                let (task_id, task) = unsafe { bucket.as_mut() };
                if task_id.is_transient() {
                    evicted.unevictable_reasons[UnevictableReason::Transient.index()] += 1;
                    continue;
                }
                let (key_evictability, value_evictability) = task.get().evictability();
                match key_evictability {
                    KeyEvictability::Evictable => {
                        // The task type is persisted to backing storage (new_task = false),
                        // so task_cache is a pure perf cache. Remove it now; it will be
                        // re-populated by task_by_type() on the next cache miss.
                        let task_type = task.get().get_persistent_task_type().unwrap();
                        // `get_or_create_task` takes `task_cache` and then `storage.map`, while
                        // we already hold a map shard, so we must never block on `task_cache`
                        // here. Only *try* the lock; if it is contended, the removal is
                        // deferred until after this map shard is released. Trying inline is
                        // just an optimization to keep `deferred_task_cache_removals` small.
                        match try_lock_and_remove(&self.task_cache, task_type.as_ref()) {
                            TryLockAndRemove::Removed => {
                                evicted.key_evictions += 1;
                            }
                            TryLockAndRemove::NotFound => {
                                // Generally this should be rare, it more or less implies something
                                // else is concurrently holding the Arc
                            }
                            TryLockAndRemove::WouldBlock => {
                                // Contention, to avoid a deadlock just defer
                                deferred_task_cache_removals.push(task_type.clone());
                            }
                        }
                    }
                    KeyEvictability::AlreadyEvicted | KeyEvictability::Unevictable => {}
                }
                match value_evictability {
                    ValueEvictability::Evictable { meta, data } => {
                        match task.get_mut().drop_partial(data, meta) {
                            DropPartialOutcome::Empty => {
                                // SAFETY: `bucket` was just yielded by the iterator of this
                                // shard, whose write lock we hold, and `task` is not used after
                                // this point. NOTE(review): relies on hashbrown's raw iterator
                                // tolerating erasure of the most recently yielded bucket —
                                // confirm against the pinned hashbrown version.
                                unsafe {
                                    shard.erase(bucket);
                                }
                                evicted.full += 1;
                            }
                            DropPartialOutcome::HasResidue => {
                                if data && meta {
                                    evicted.data_and_meta += 1;
                                } else if data {
                                    evicted.data_only += 1;
                                } else {
                                    debug_assert!(meta);
                                    evicted.meta_only += 1;
                                }
                            }
                        }
                    }
                    ValueEvictability::Unevictable(reason) => {
                        evicted.unevictable_reasons[reason.index()] += 1;
                    }
                }
            }
            // Shrink the shard if it's less than half full, to reclaim slack capacity
            // after bulk evictions. We already hold the write lock, so this is free
            // from a locking perspective. TaskId hashing is cheap (it's just an integer).
            let len = shard.len();
            if shard.capacity() > len * 2 {
                shard.shrink_to(len, |(k, _v)| self.map.hasher().hash_one(k));
            }
            // Release the map shard lock before draining deferred removals so that a thread
            // holding a task_cache shard lock and waiting on this map shard can make progress.
            drop(shard);
            for task_type in deferred_task_cache_removals {
                if self.task_cache.remove(task_type.as_ref()).is_some() {
                    evicted.key_evictions += 1;
                }
            }
            evicted
        });

        let mut totals = EvictionCounts::default();
        for evicted in counts {
            totals += evicted;
        }
        // Shrink task_cache only when we evicted more entries than remain — i.e. the map
        // is less than half full. Rehashing each surviving CachedTaskType isn't free, so
        // we gate it on meaningful slack. Within that, walk shards in parallel and shrink
        // each one independently if it is itself less than half full.
        if totals.key_evictions > self.task_cache.len() {
            parallel::for_each(self.task_cache.shards(), |shard| {
                let mut shard = shard.write();
                let len = shard.len();
                if shard.capacity() > len * 2 {
                    shard.shrink_to(len, |(k, _v)| self.task_cache.hasher().hash_one(k));
                }
            });
        }
        span.record("counts", tracing::field::display(&totals));

        totals
    }
612}
613
614pub struct StorageWriteGuard<'a> {
615 storage: &'a Storage,
616 inner: RefMut<'a, TaskId, Box<TaskStorage>>,
617}
618
619impl StorageWriteGuard<'_> {
620 /// Tracks mutation of this task
621 #[inline(always)]
622 pub fn track_modification(
623 &mut self,
624 category: SpecificTaskDataCategory,
625 #[allow(unused_variables)] name: &str,
626 ) {
627 debug_assert!(
628 !self.inner.key().is_transient(),
629 "transient task_ids should never be enqueued to be persisted"
630 );
631 self.track_modification_internal(
632 category,
633 #[cfg(feature = "trace_task_modification")]
634 name,
635 );
636 }
637
638 fn track_modification_internal(
639 &mut self,
640 category: SpecificTaskDataCategory,
641 #[cfg(feature = "trace_task_modification")] name: &str,
642 ) {
643 // Transient tasks are never persisted, so tracking modifications is meaningless.
644 // All callers (TaskGuard, invalidate_serialization) already
645 // guard against this, but we enforce it here as defense-in-depth.
646 debug_assert!(
647 !self.inner.key().is_transient(),
648 "track_modification called on transient task {:?}",
649 self.inner.key()
650 );
651 let flags = &self.inner.flags;
652 if flags.is_modified_during_snapshot(category) {
653 // We can early return since `end_snapshot` is responsible for reconciling.
654 return;
655 }
656 #[cfg(feature = "trace_task_modification")]
657 let _span = (!modified).then(|| tracing::trace_span!("mark_modified", name).entered());
658 match (self.storage.snapshot_mode(), flags.is_modified(category)) {
659 (false, false) => {
660 // Not in snapshot mode and item is unmodified
661 if !flags.any_modified() {
662 let shard_idx = self.storage.shard_index(self.inner.key());
663 self.storage.shard_modified_counts[shard_idx].fetch_add(1, Ordering::Relaxed);
664 }
665 self.inner.flags.set_modified(category, true);
666 }
667 (false, true) => {
668 // Not in snapshot mode and item is already modified
669 // Do nothing
670 }
671 (true, false) => {
672 // In snapshot mode and item is unmodified (so it's not part of the snapshot)
673 // Mark it so it gets re-added as Modified after this snapshot completes.
674 // Insert a None entry into snapshots so end_snapshot discovers this task
675 // and promotes its _during_snapshot flags.
676 if !flags.any_modified_during_snapshot() {
677 self.storage.snapshots.insert(*self.inner.key(), None);
678 }
679 self.inner
680 .flags
681 .set_modified_during_snapshot(category, true);
682 }
683 (true, true) => {
684 // In snapshot mode and item is modified (so it's part of the snapshot)
685 // We need to store the original version that is part of the snapshot
686 if !flags.any_modified_during_snapshot() {
687 // Snapshot all non-transient fields, carrying the modified bits into
688 // the copy so the iterator knows which categories to persist.
689 let mut snapshot = self.inner.clone_snapshot();
690 snapshot.flags.set_data_modified(flags.data_modified());
691 snapshot.flags.set_meta_modified(flags.meta_modified());
692 snapshot.flags.set_new_task(flags.new_task());
693 self.storage
694 .snapshots
695 .insert(*self.inner.key(), Some(Box::new(snapshot)));
696 }
697 self.inner
698 .flags
699 .set_modified_during_snapshot(category, true);
700 }
701 }
702 }
703}
704
705impl Deref for StorageWriteGuard<'_> {
706 type Target = TaskStorage;
707
708 fn deref(&self) -> &Self::Target {
709 &self.inner
710 }
711}
712
713impl DerefMut for StorageWriteGuard<'_> {
714 fn deref_mut(&mut self) -> &mut Self::Target {
715 &mut self.inner
716 }
717}
718
/// Initial capacity for each per-thread scratch buffer (`TurboBincodeBuffer::with_capacity`).
/// Chosen from metrics of a large application: it covers roughly 98% of encoded values
/// without any resize.
const SCRATCH_BUFFER_INITIAL_SIZE: usize = 4 * 1024;
722
723/// State machine for a per-thread scratch buffer slot.
724///
725/// Transitions:
726/// - `Uninit` → `Taken` (first take)
727/// - `Available` → `Taken` (subsequent takes)
728/// - `Taken` → `Available` (return)
729///
730/// Any other transition is a bug (e.g. double-take or double-return).
731#[derive(Default)]
732enum ScratchBufferSlot {
733 /// No buffer has been allocated on this thread yet.
734 #[default]
735 Uninit,
736 /// The buffer is currently checked out.
737 Taken,
738 /// The buffer is available for reuse.
739 Available(TurboBincodeBuffer),
740}
741
742pub struct SnapshotGuard<'l> {
743 storage: &'l Storage,
744 /// Per-thread scratch buffers for encoding task data. Buffers are taken
745 /// by `SnapshotShardIter` on creation and returned on drop, allowing reuse
746 /// across multiple shards processed by the same thread. When the guard is
747 /// dropped (after all iterators are done), the `ThreadLocal` drops too,
748 /// freeing all buffers.
749 scratch_buffers: ThreadLocal<Cell<ScratchBufferSlot>>,
750}
751
752impl<'l> SnapshotGuard<'l> {
753 fn new(storage: &'l Storage) -> Self {
754 Self {
755 storage,
756 scratch_buffers: ThreadLocal::new(),
757 }
758 }
759
760 fn take_scratch_buffer(&self) -> TurboBincodeBuffer {
761 let cell = self.scratch_buffers.get_or_default();
762 match cell.take() {
763 ScratchBufferSlot::Available(buf) => {
764 cell.set(ScratchBufferSlot::Taken);
765 buf
766 }
767 ScratchBufferSlot::Uninit => {
768 cell.set(ScratchBufferSlot::Taken);
769 TurboBincodeBuffer::with_capacity(SCRATCH_BUFFER_INITIAL_SIZE)
770 }
771 ScratchBufferSlot::Taken => {
772 panic!("scratch buffer taken twice without being returned");
773 }
774 }
775 }
776
777 fn return_scratch_buffer(&self, buffer: TurboBincodeBuffer) {
778 let cell = self.scratch_buffers.get_or_default();
779 match cell.take() {
780 ScratchBufferSlot::Taken => cell.set(ScratchBufferSlot::Available(buffer)),
781 ScratchBufferSlot::Available(_) => {
782 panic!("scratch buffer returned without being taken (already available)");
783 }
784 ScratchBufferSlot::Uninit => {
785 panic!("scratch buffer returned without being taken (uninit)");
786 }
787 }
788 }
789}
790
791impl Drop for SnapshotGuard<'_> {
792 fn drop(&mut self) {
793 self.storage.end_snapshot();
794 }
795}
796
797pub struct SnapshotShard<'l, P> {
798 shard_idx: usize,
799 modified: Vec<TaskId>,
800 storage: &'l Storage,
801 process: &'l P,
802 /// Held for its `Drop` impl — ensures snapshot mode ends when all shards are done.
803 _guard: Arc<SnapshotGuard<'l>>,
804}
805
806impl<'l, P> IntoIterator for SnapshotShard<'l, P>
807where
808 P: Fn(TaskId, &TaskStorage, &mut TurboBincodeBuffer) -> SnapshotItem + Sync,
809{
810 type Item = SnapshotItem;
811 type IntoIter = SnapshotShardIter<'l, P>;
812
813 fn into_iter(self) -> Self::IntoIter {
814 let buffer = self._guard.take_scratch_buffer();
815 SnapshotShardIter {
816 shard: self,
817 buffer,
818 }
819 }
820}
821
822/// Iterator over a single shard's snapshot items. Holds a thread-local scratch
823/// buffer for the duration of iteration and returns it on drop.
824pub struct SnapshotShardIter<'l, P> {
825 shard: SnapshotShard<'l, P>,
826 buffer: TurboBincodeBuffer,
827}
828
829impl<'l, P> Iterator for SnapshotShardIter<'l, P>
830where
831 P: Fn(TaskId, &TaskStorage, &mut TurboBincodeBuffer) -> SnapshotItem + Sync,
832{
833 type Item = SnapshotItem;
834
835 fn next(&mut self) -> Option<Self::Item> {
836 if let Some(task_id) = self.shard.modified.pop() {
837 let mut inner = self.shard.storage.map.get_mut(&task_id).unwrap();
838 // If the task was re-modified during snapshot, the snapshots map may
839 // hold a pre-modification copy we must serialize instead of the live
840 // data. Remove the entry so end_snapshot doesn't double-promote it;
841 // we promote manually below.
842 let item = if inner.flags.any_modified_during_snapshot() {
843 match self.shard.storage.snapshots.remove(&task_id) {
844 Some((_, Some(snapshot))) => {
845 (self.shard.process)(task_id, &snapshot, &mut self.buffer)
846 }
847 Some((_, None)) | None => {
848 (self.shard.process)(task_id, &inner, &mut self.buffer)
849 }
850 }
851 } else {
852 (self.shard.process)(task_id, &inner, &mut self.buffer)
853 };
854 // Clear the modified flags that were captured into the snapshot copy,
855 // then promote modified_during_snapshot → modified so the task stays
856 // dirty for the next snapshot cycle.
857 inner.flags.set_data_modified(false);
858 inner.flags.set_meta_modified(false);
859 inner.flags.set_new_task(false);
860 self.shard
861 .storage
862 .promote_during_snapshot_flags(&mut inner, self.shard.shard_idx);
863 return Some(item);
864 }
865 None
866 }
867}
868
869impl<P> Drop for SnapshotShardIter<'_, P> {
870 fn drop(&mut self) {
871 self.shard
872 ._guard
873 .return_scratch_buffer(std::mem::take(&mut self.buffer));
874 }
875}
876
#[cfg(test)]
mod tests {
    use turbo_bincode::TurboBincodeBuffer;
    use turbo_tasks::TaskId;

    use super::{SpecificTaskDataCategory, Storage};
    use crate::backing_storage::SnapshotItem;

    /// Builds a `TaskId` for tests. Panics if `id` is 0.
    ///
    /// The helper does not check the transient bit; callers must pass ids
    /// without it.
    fn non_transient_task(id: u32) -> TaskId {
        // TRANSIENT_TASK_BIT is 0x8000_0000; any id without that bit is non-transient.
        TaskId::new(id).expect("id must be non-zero")
    }

    /// A process fn that returns a non-empty `SnapshotItem` (`meta` is `Some`).
    ///
    /// It ignores the `TaskStorage` and buffer it is given, so tests using it
    /// can count items but cannot tell which copy of a task was encoded.
    ///
    /// NOTE(review): the original rationale was to avoid items being skipped
    /// via an "encoding failed" path. `SnapshotShardIter::next` returns
    /// whatever this fn produces without filtering, so confirm which consumer
    /// treats an empty item as an encoding failure.
    fn dummy_process(
        task_id: TaskId,
        _: &super::TaskStorage,
        _: &mut TurboBincodeBuffer,
    ) -> SnapshotItem {
        SnapshotItem {
            task_id,
            meta: Some(TurboBincodeBuffer::default()),
            data: None,
            task_type_hash: None,
        }
    }

    /// Regression test: a task modified before a snapshot and then modified *again* during
    /// snapshot iteration must serialize the pre-snapshot state and carry the during-snapshot
    /// modification forward to the next cycle.
    ///
    /// Sequence of events:
    /// 1. Task is modified (data_modified = true) → added to shard_modified_counts.
    /// 2. `start_snapshot` puts us in snapshot mode.
    /// 3. `take_snapshot` scans the shard: task has `any_modified()=true` → goes into the
    ///    `modified` list.
    /// 4. **Between scan and iteration**: `track_modification` is called on the same category.
    ///    This is the `(true, true)` branch: already modified AND in snapshot mode. A snapshot
    ///    copy of the pre-second-modification state is stored in `snapshots` as `Some(copy)`,
    ///    and `data_modified_during_snapshot` is set.
    /// 5. `SnapshotShardIter::next` processes the task from the `modified` list and detects
    ///    `any_modified_during_snapshot()=true`. It removes the `Some(copy)` entry from
    ///    `snapshots` and encodes that copy. It then clears the live pre-snapshot modified
    ///    flags and promotes `data_modified_during_snapshot → data_modified` for the next cycle.
    ///
    /// `dummy_process` ignores its `TaskStorage` argument, so this test checks the item count
    /// and flag bookkeeping, not the encoded contents.
    // `end_snapshot` uses `parallel::for_each` which calls `block_in_place` internally,
    // requiring a multi-threaded Tokio runtime.
    #[tokio::test(flavor = "multi_thread")]
    async fn modify_during_snapshot_clears_live_modified_flags() {
        let storage = Storage::new(2, true);
        let task_id = non_transient_task(1);

        // Step 1: modify the task outside snapshot mode (data_modified = true).
        {
            let mut guard = storage.access_mut(task_id);
            guard.track_modification(SpecificTaskDataCategory::Data, "test");
        }

        // Step 2: enter snapshot mode.
        let (snapshot_guard, has_modifications) = storage.start_snapshot();
        assert!(has_modifications);

        // Step 3: `take_snapshot` scans the shard. At this point the task has
        // `any_modified()=true` and `any_modified_during_snapshot()=false`, so it
        // goes into the `modified` list inside the returned `SnapshotShard`.
        let shards = storage.take_snapshot(snapshot_guard, &dummy_process);

        // Step 4: the scan is done but the iterator has not been consumed yet.
        // Modify the task again. We're still in snapshot mode and the task is
        // already modified, so the `(true, true)` branch runs: it creates a
        // snapshot copy (carrying the modified bits) and sets
        // `data_modified_during_snapshot=true`.
        {
            let mut guard = storage.access_mut(task_id);
            guard.track_modification(SpecificTaskDataCategory::Data, "test");
            // We should have set a snapshot bit
            assert!(guard.flags.data_modified_during_snapshot())
        }

        // Step 5: consume the iterator. It removes the snapshots entry and
        // encodes the pre-snapshot copy from it. Then it clears the live
        // modified flags and promotes
        // `data_modified_during_snapshot → data_modified` for the next cycle.
        let items: Vec<_> = shards
            .into_iter()
            .flat_map(|shard| shard.into_iter())
            .collect();

        // Exactly one item, for our task, should be produced. `dummy_process`
        // ignores its input, so this does not prove which copy was encoded.
        assert_eq!(items.len(), 1);
        assert_eq!(items[0].task_id, task_id);

        {
            let guard = storage.access_mut(task_id);
            // The iterator should have promoted modified_during_snapshot → modified.
            assert!(guard.flags.data_modified());
        }

        // The during-snapshot modification must be reflected in shard_modified_counts so
        // the next snapshot cycle picks it up. Verify by starting another snapshot.
        let (_guard2, has_modifications) = storage.start_snapshot();
        assert!(
            has_modifications,
            "shard_modified_counts must be non-zero after promoting modified_during_snapshot"
        );
    }

    /// Regression test for the `(true, false)` during-snapshot case: a task modified in one
    /// category before a snapshot, then modified in a *different* category during snapshot
    /// iteration, must not panic and must carry both modifications forward correctly.
    ///
    /// Sequence of events:
    /// 1. Task meta is modified (meta_modified = true).
    /// 2. `start_snapshot` puts us in snapshot mode.
    /// 3. `take_snapshot` scans the shard: task goes into the `modified` list.
    /// 4. Task data is modified during snapshot → `(true, false)` branch: data was not
    ///    previously modified, so `snapshots` gets a `None` entry and
    ///    `data_modified_during_snapshot` is set.
    /// 5. `SnapshotShardIter::next` processes the task and finds
    ///    `any_modified_during_snapshot()`. It removes the `None` entry from `snapshots` and
    ///    encodes from the live data.
    ///
    ///    This is presumably correct because only meta carried a pre-snapshot modified flag,
    ///    and meta was not touched during the snapshot (step 4 asserts
    ///    `!meta_modified_during_snapshot()`). Live meta therefore still matches the state at
    ///    scan time. Live data has changed, but its pre-snapshot flag is not set.
    ///
    ///    Finally, `next` clears the pre-snapshot flags and promotes
    ///    `data_modified_during_snapshot → data_modified`.
    #[tokio::test(flavor = "multi_thread")]
    async fn modify_different_category_during_snapshot() {
        let storage = Storage::new(2, true);
        let task_id = non_transient_task(1);

        // Step 1: modify meta only, outside snapshot mode.
        {
            let mut guard = storage.access_mut(task_id);
            guard.track_modification(SpecificTaskDataCategory::Meta, "test");
            assert!(guard.flags.meta_modified());
            assert!(!guard.flags.data_modified());
        }

        // Step 2: enter snapshot mode.
        let (snapshot_guard, has_modifications) = storage.start_snapshot();
        assert!(has_modifications);

        // Step 3: take_snapshot — task goes into modified list (meta_modified = true).
        let shards = storage.take_snapshot(snapshot_guard, &dummy_process);

        // Step 4: modify data during snapshot. The `(true, false)` branch fires:
        // data was not previously modified, so snapshots gets a None entry.
        {
            let mut guard = storage.access_mut(task_id);
            guard.track_modification(SpecificTaskDataCategory::Data, "test");
            assert!(guard.flags.data_modified_during_snapshot());
            assert!(!guard.flags.meta_modified_during_snapshot());
        }

        // Step 5: consume the iterator — must not panic.
        let items: Vec<_> = shards
            .into_iter()
            .flat_map(|shard| shard.into_iter())
            .collect();

        assert_eq!(items.len(), 1);
        assert_eq!(items[0].task_id, task_id);

        {
            let guard = storage.access_mut(task_id);
            // meta_modified was cleared by the iterator (it was the pre-snapshot flag).
            assert!(!guard.flags.meta_modified());
            // data_modified_during_snapshot was promoted to data_modified.
            assert!(guard.flags.data_modified());
            assert!(!guard.flags.data_modified_during_snapshot());
        }

        // Next snapshot cycle must pick up the promoted data_modified.
        let (_guard2, has_modifications) = storage.start_snapshot();
        assert!(
            has_modifications,
            "shard_modified_counts must be non-zero after promoting data_modified_during_snapshot"
        );
    }
}