// turbo_tasks_backend/kv_backing_storage.rs

1use std::{
2    borrow::Borrow,
3    env,
4    path::PathBuf,
5    sync::{Arc, LazyLock, Mutex, PoisonError, Weak},
6};
7
8use anyhow::{Context, Result};
9use smallvec::SmallVec;
10use turbo_bincode::{
11    TurboBincodeBuffer, new_turbo_bincode_decoder, turbo_bincode_decode, turbo_bincode_encode,
12};
13use turbo_tasks::{
14    TaskId,
15    backend::CachedTaskType,
16    panic_hooks::{PanicHookGuard, register_panic_hook},
17    parallel,
18};
19use turbo_tasks_hash::Xxh3Hash64Hasher;
20
21use crate::{
22    GitVersionInfo,
23    backend::{AnyOperation, SpecificTaskDataCategory, storage_schema::TaskStorage},
24    backing_storage::{BackingStorage, BackingStorageSealed},
25    database::{
26        db_invalidation::{StartupCacheState, check_db_invalidation_and_cleanup, invalidate_db},
27        db_versioning::handle_db_versioning,
28        key_value_database::{KeySpace, KeyValueDatabase},
29        write_batch::{
30            BaseWriteBatch, ConcurrentWriteBatch, SerialWriteBatch, WriteBatch, WriteBatchRef,
31            WriteBuffer,
32        },
33    },
34    db_invalidation::invalidation_reasons,
35    utils::chunked_vec::ChunkedVec,
36};
37
/// Key in [`KeySpace::Infra`] under which the serialized list of uncompleted operations is stored.
const META_KEY_OPERATIONS: u32 = 0;
/// Key in [`KeySpace::Infra`] under which the next free task id counter is stored.
const META_KEY_NEXT_FREE_TASK_ID: u32 = 1;
40
/// Fixed-size database key holding the little-endian byte representation of a `u32`.
struct IntKey([u8; 4]);

impl IntKey {
    /// Encodes `value` as 4 little-endian bytes.
    fn new(value: u32) -> Self {
        IntKey(u32::to_le_bytes(value))
    }
}

impl AsRef<[u8]> for IntKey {
    /// Exposes the key as a byte slice for the key-value database APIs.
    fn as_ref(&self) -> &[u8] {
        let IntKey(bytes) = self;
        bytes
    }
}
54
55fn as_u32(bytes: impl Borrow<[u8]>) -> Result<u32> {
56    let n = u32::from_le_bytes(bytes.borrow().try_into()?);
57    Ok(n)
58}
59
60// We want to invalidate the cache on panic for most users, but this is a band-aid to underlying
61// problems in turbo-tasks.
62//
63// If we invalidate the cache upon panic and it "fixes" the issue upon restart, users typically
64// won't report bugs to us, and we'll never find root-causes for these problems.
65//
66// These overrides let us avoid the cache invalidation / error suppression within Vercel so that we
67// feel these pain points and fix the root causes of bugs.
/// Returns `true` when a panic should invalidate the on-disk cache.
///
/// The decision is computed once (lazily) from the environment: invalidation is enabled unless
/// `TURBO_ENGINE_SKIP_INVALIDATE_ON_PANIC` or `__NEXT_TEST_MODE` is set to a truthy value.
fn should_invalidate_on_panic() -> bool {
    // An env var counts as "falsy" when unset, empty, `0`, or `false`.
    fn env_is_falsy(key: &str) -> bool {
        match env::var_os(key) {
            None => true,
            Some(value) => ["".as_ref(), "0".as_ref(), "false".as_ref()].contains(&&*value),
        }
    }
    static SHOULD_INVALIDATE: LazyLock<bool> = LazyLock::new(|| {
        env_is_falsy("TURBO_ENGINE_SKIP_INVALIDATE_ON_PANIC") && env_is_falsy("__NEXT_TEST_MODE")
    });
    *SHOULD_INVALIDATE
}
78
/// Inner state of [`KeyValueDatabaseBackingStorage`], kept behind an `Arc` so the registered
/// panic hook can hold a [`Weak`] reference to it.
pub struct KeyValueDatabaseBackingStorageInner<T: KeyValueDatabase> {
    // The wrapped low-level key-value database.
    database: T,
    /// Used when calling [`BackingStorage::invalidate`]. Can be `None` in the memory-only/no-op
    /// storage case.
    base_path: Option<PathBuf>,
    /// Used to skip calling [`invalidate_db`] when the database has already been invalidated.
    invalidated: Mutex<bool>,
    /// We configure a panic hook to invalidate the cache. This guard cleans up our panic hook upon
    /// drop.
    _panic_hook_guard: Option<PanicHookGuard>,
}
90
/// Wraps a low-level key-value database into a higher-level [`BackingStorage`] type.
pub struct KeyValueDatabaseBackingStorage<T: KeyValueDatabase> {
    // wrapped so that `register_panic_hook` can hold a weak reference to `inner`.
    inner: Arc<KeyValueDatabaseBackingStorageInner<T>>,
}
95
/// A wrapper type used by [`crate::turbo_backing_storage`] and [`crate::noop_backing_storage`].
///
/// Wraps a low-level key-value database into a higher-level [`BackingStorage`] type.
impl<T: KeyValueDatabase> KeyValueDatabaseBackingStorage<T> {
    /// Creates a backing storage around an in-memory (or no-op) database.
    ///
    /// No `base_path` is recorded and no panic hook is registered, since there is no on-disk
    /// state that could need invalidation.
    pub(crate) fn new_in_memory(database: T) -> Self {
        Self {
            inner: Arc::new(KeyValueDatabaseBackingStorageInner {
                database,
                base_path: None,
                invalidated: Mutex::new(false),
                _panic_hook_guard: None,
            }),
        }
    }

    /// Handles boilerplate logic for an on-disk persisted database with versioning.
    ///
    /// - Creates a directory per version, with a maximum number of old versions and performs
    ///   automatic cleanup of old versions.
    /// - Checks for a database invalidation marker file, and cleans up the database as needed.
    /// - [Registers a dynamic panic hook][turbo_tasks::panic_hooks] to invalidate the database upon
    ///   a panic. This invalidates the database using [`invalidation_reasons::PANIC`].
    ///
    /// Along with returning a [`KeyValueDatabaseBackingStorage`], this returns a
    /// [`StartupCacheState`], which can be used by the application for logging information to the
    /// user or telemetry about the cache.
    pub(crate) fn open_versioned_on_disk(
        base_path: PathBuf,
        version_info: &GitVersionInfo,
        is_ci: bool,
        database: impl FnOnce(PathBuf) -> Result<T>,
    ) -> Result<(Self, StartupCacheState)>
    where
        T: Send + Sync + 'static,
    {
        let startup_cache_state = check_db_invalidation_and_cleanup(&base_path)
            .context("Failed to check database invalidation and cleanup")?;
        let versioned_path = handle_db_versioning(&base_path, version_info, is_ci)
            .context("Failed to handle database versioning")?;
        let database = (database)(versioned_path).context("Failed to open database")?;
        let backing_storage = Self {
            // `Arc::new_cyclic` lets the panic hook close over a `Weak` reference to the
            // very `inner` value being constructed here.
            inner: Arc::new_cyclic(
                move |weak_inner: &Weak<KeyValueDatabaseBackingStorageInner<T>>| {
                    let panic_hook_guard = if should_invalidate_on_panic() {
                        let weak_inner = weak_inner.clone();
                        Some(register_panic_hook(Box::new(move |_| {
                            let Some(inner) = weak_inner.upgrade() else {
                                return;
                            };
                            // If a panic happened that must mean something deep inside of turbopack
                            // or turbo-tasks failed, and it may be hard to recover. We don't want
                            // the cache to stick around, as that may persist bugs. Make a
                            // best-effort attempt to invalidate the database (ignoring failures).
                            let _ = inner.invalidate(invalidation_reasons::PANIC);
                        })))
                    } else {
                        None
                    };
                    KeyValueDatabaseBackingStorageInner {
                        database,
                        base_path: Some(base_path),
                        invalidated: Mutex::new(false),
                        _panic_hook_guard: panic_hook_guard,
                    }
                },
            ),
        };
        Ok((backing_storage, startup_cache_state))
    }
}
166
impl<T: KeyValueDatabase> KeyValueDatabaseBackingStorageInner<T> {
    /// Runs `f` with the provided read transaction, or — when `tx` is `None` — with a fresh
    /// short-lived read transaction that is dropped before returning.
    fn with_tx<R>(
        &self,
        tx: Option<&T::ReadTransaction<'_>>,
        f: impl FnOnce(&T::ReadTransaction<'_>) -> Result<R>,
    ) -> Result<R> {
        if let Some(tx) = tx {
            f(tx)
        } else {
            let tx = self.database.begin_read_transaction()?;
            let r = f(&tx)?;
            drop(tx);
            Ok(r)
        }
    }

    /// Marks the on-disk database as invalid (so it is discarded on next startup) and prevents
    /// further writes to it. Idempotent: only the first call performs the invalidation.
    ///
    /// `reason_code` is recorded in the invalidation marker (see [`invalidate_db`]).
    fn invalidate(&self, reason_code: &str) -> Result<()> {
        // `base_path` can be `None` for a `NoopKvDb`
        if let Some(base_path) = &self.base_path {
            // Invalidation could happen frequently if there's a bunch of panics. We only need to
            // invalidate once, so grab a lock.
            let mut invalidated_guard = self
                .invalidated
                .lock()
                .unwrap_or_else(PoisonError::into_inner);
            if *invalidated_guard {
                return Ok(());
            }
            // Invalidate first, as it's a very fast atomic operation. `prevent_writes` is allowed
            // to be slower (e.g. wait for a lock) and is allowed to corrupt the database with
            // partial writes.
            invalidate_db(base_path, reason_code)?;
            self.database.prevent_writes();
            // Avoid redundant invalidations from future panics
            *invalidated_guard = true;
        }
        Ok(())
    }

    /// Used to read the next free task ID from the database.
    ///
    /// Reads a `u32` value from the infra key space; returns `None` when the key is absent.
    fn get_infra_u32(&self, key: u32) -> Result<Option<u32>> {
        let tx = self.database.begin_read_transaction()?;
        self.database
            .get(&tx, KeySpace::Infra, IntKey::new(key).as_ref())?
            .map(as_u32)
            .transpose()
    }
}
215
impl<T: KeyValueDatabase + Send + Sync + 'static> BackingStorage
    for KeyValueDatabaseBackingStorage<T>
{
    /// Delegates to [`KeyValueDatabaseBackingStorageInner::invalidate`].
    fn invalidate(&self, reason_code: &str) -> Result<()> {
        self.inner.invalidate(reason_code)
    }
}
223
impl<T: KeyValueDatabase + Send + Sync + 'static> BackingStorageSealed
    for KeyValueDatabaseBackingStorage<T>
{
    type ReadTransaction<'l> = T::ReadTransaction<'l>;

    /// Returns the next free task id persisted by a previous session, or [`TaskId::MIN`] when
    /// the database has no such entry yet (fresh database).
    fn next_free_task_id(&self) -> Result<TaskId> {
        Ok(self
            .inner
            .get_infra_u32(META_KEY_NEXT_FREE_TASK_ID)
            .context("Unable to read next free task id from database")?
            .map_or(Ok(TaskId::MIN), TaskId::try_from)?)
    }

    /// Reads and decodes the list of operations that were still in flight when the last
    /// snapshot was saved (see [`save_infra`]). Returns an empty list when no entry exists.
    fn uncompleted_operations(&self) -> Result<Vec<AnyOperation>> {
        fn get(database: &impl KeyValueDatabase) -> Result<Vec<AnyOperation>> {
            let tx = database.begin_read_transaction()?;
            let Some(operations) = database.get(
                &tx,
                KeySpace::Infra,
                IntKey::new(META_KEY_OPERATIONS).as_ref(),
            )?
            else {
                return Ok(Vec::new());
            };
            let operations = turbo_bincode_decode(operations.borrow())?;
            Ok(operations)
        }
        get(&self.inner.database).context("Unable to read uncompleted operations from database")
    }

    /// Writes a full snapshot — task meta/data, task cache entries, and infra metadata
    /// (next free task id + uncompleted operations) — into a single write batch, committed
    /// atomically at the end.
    ///
    /// The write path depends on the batch kind the database provides: a concurrent batch is
    /// filled from multiple threads in parallel, a serial batch is filled sequentially from
    /// pre-serialized buffers.
    fn save_snapshot<I>(
        &self,
        operations: Vec<Arc<AnyOperation>>,
        task_cache_updates: Vec<ChunkedVec<(Arc<CachedTaskType>, TaskId)>>,
        snapshots: Vec<I>,
    ) -> Result<()>
    where
        I: Iterator<
                Item = (
                    TaskId,
                    Option<TurboBincodeBuffer>,
                    Option<TurboBincodeBuffer>,
                ),
            > + Send
            + Sync,
    {
        let _span = tracing::info_span!("save snapshot", operations = operations.len()).entered();
        let mut batch = self.inner.database.write_batch()?;
        // Start organizing the updates in parallel
        match &mut batch {
            // Concurrent batch: task data is written into the batch directly from worker
            // threads, then flushed per key space.
            &mut WriteBatch::Concurrent(ref batch, _) => {
                {
                    let _span = tracing::trace_span!("update task data").entered();
                    process_task_data(snapshots, Some(batch))?;
                    let span = tracing::trace_span!("flush task data").entered();
                    parallel::try_for_each(
                        &[KeySpace::TaskMeta, KeySpace::TaskData],
                        |&key_space| {
                            let _span = span.clone().entered();
                            // Safety: We already finished all processing of the task data and task
                            // meta
                            unsafe { batch.flush(key_space) }
                        },
                    )?;
                }

                let mut next_task_id = get_next_free_task_id::<
                    T::SerialWriteBatch<'_>,
                    T::ConcurrentWriteBatch<'_>,
                >(&mut WriteBatchRef::concurrent(batch))?;

                {
                    let _span = tracing::trace_span!(
                        "update task cache",
                        items = task_cache_updates.iter().map(|m| m.len()).sum::<usize>()
                    )
                    .entered();
                    // Write cache entries in parallel; each chunk reports its max task id so the
                    // persisted next-free-task-id counter stays ahead of every stored task.
                    let max_task_id = parallel::map_collect_owned::<_, _, Result<Vec<_>>>(
                        task_cache_updates,
                        |updates| {
                            let _span = _span.clone().entered();
                            let mut max_task_id = 0;
                            for (task_type, task_id) in updates {
                                let hash = compute_task_type_hash(&task_type);
                                let task_id: u32 = *task_id;

                                batch
                                    .put(
                                        KeySpace::TaskCache,
                                        WriteBuffer::Borrowed(&hash.to_le_bytes()),
                                        WriteBuffer::Borrowed(&task_id.to_le_bytes()),
                                    )
                                    .with_context(|| {
                                        format!(
                                            "Unable to write task cache {task_type:?} => {task_id}"
                                        )
                                    })?;
                                max_task_id = max_task_id.max(task_id);
                            }

                            Ok(max_task_id)
                        },
                    )?
                    .into_iter()
                    .max()
                    .unwrap_or(0);
                    next_task_id = next_task_id.max(max_task_id + 1);
                }

                save_infra::<T::SerialWriteBatch<'_>, T::ConcurrentWriteBatch<'_>>(
                    &mut WriteBatchRef::concurrent(batch),
                    next_task_id,
                    operations,
                )?;
            }
            // Serial batch: task items are serialized in parallel (buffered), then written
            // sequentially here.
            WriteBatch::Serial(batch) => {
                {
                    let _span = tracing::trace_span!("update tasks").entered();
                    let task_items =
                        process_task_data(snapshots, None::<&T::ConcurrentWriteBatch<'_>>)?;
                    for (task_id, meta, data) in task_items.into_iter().flatten() {
                        let key = IntKey::new(*task_id);
                        let key = key.as_ref();
                        if let Some(meta) = meta {
                            batch
                                .put(KeySpace::TaskMeta, WriteBuffer::Borrowed(key), meta)
                                .with_context(|| {
                                    format!("Unable to write meta items for {task_id}")
                                })?;
                        }
                        if let Some(data) = data {
                            batch
                                .put(KeySpace::TaskData, WriteBuffer::Borrowed(key), data)
                                .with_context(|| {
                                    format!("Unable to write data items for {task_id}")
                                })?;
                        }
                    }
                    batch.flush(KeySpace::TaskMeta)?;
                    batch.flush(KeySpace::TaskData)?;
                }

                let mut next_task_id = get_next_free_task_id::<
                    T::SerialWriteBatch<'_>,
                    T::ConcurrentWriteBatch<'_>,
                >(&mut WriteBatchRef::serial(batch))?;

                {
                    let _span = tracing::trace_span!(
                        "update task cache",
                        items = task_cache_updates.iter().map(|m| m.len()).sum::<usize>()
                    )
                    .entered();
                    for (task_type, task_id) in task_cache_updates.into_iter().flatten() {
                        let hash = compute_task_type_hash(&task_type);
                        let task_id = *task_id;

                        batch
                            .put(
                                KeySpace::TaskCache,
                                WriteBuffer::Borrowed(&hash.to_le_bytes()),
                                WriteBuffer::Borrowed(&task_id.to_le_bytes()),
                            )
                            .with_context(|| {
                                format!("Unable to write task cache {task_type:?} => {task_id}")
                            })?;
                        next_task_id = next_task_id.max(task_id + 1);
                    }
                }

                save_infra::<T::SerialWriteBatch<'_>, T::ConcurrentWriteBatch<'_>>(
                    &mut WriteBatchRef::serial(batch),
                    next_task_id,
                    operations,
                )?;
            }
        }

        {
            let _span = tracing::trace_span!("commit").entered();
            batch.commit().context("Unable to commit operations")?;
        }
        Ok(())
    }

    /// Begins a read transaction, swallowing errors (`None` on failure).
    fn start_read_transaction(&self) -> Option<Self::ReadTransaction<'_>> {
        self.inner.database.begin_read_transaction().ok()
    }

    /// Looks up all task ids stored under the hash of `task_type`. Multiple ids can be
    /// returned when distinct task types hash to the same 64-bit key (see the collision test
    /// in this file's `tests` module).
    ///
    /// NOTE(review): the `unsafe` contract is declared by `BackingStorageSealed`; presumably
    /// `tx`, when `Some`, must originate from `start_read_transaction` on this same instance —
    /// confirm against the trait documentation.
    unsafe fn lookup_task_candidates(
        &self,
        tx: Option<&T::ReadTransaction<'_>>,
        task_type: &CachedTaskType,
    ) -> Result<SmallVec<[TaskId; 1]>> {
        let inner = &*self.inner;
        fn lookup<D: KeyValueDatabase>(
            database: &D,
            tx: &D::ReadTransaction<'_>,
            task_type: &CachedTaskType,
        ) -> Result<SmallVec<[TaskId; 1]>> {
            let hash = compute_task_type_hash(task_type);
            let buffers = database.get_multiple(tx, KeySpace::TaskCache, &hash.to_le_bytes())?;

            let mut task_ids = SmallVec::with_capacity(buffers.len());
            for bytes in buffers {
                let bytes = bytes.borrow().try_into()?;
                let id = TaskId::try_from(u32::from_le_bytes(bytes)).unwrap();
                task_ids.push(id);
            }
            Ok(task_ids)
        }
        if inner.database.is_empty() {
            // Checking if the database is empty is a performance optimization
            // to avoid computing the hash.
            return Ok(SmallVec::new());
        }
        inner
            .with_tx(tx, |tx| lookup(&self.inner.database, tx, task_type))
            .with_context(|| format!("Looking up task id for {task_type:?} from database failed"))
    }

    /// Loads and decodes the stored data of one category for `task_id` into `storage`.
    /// Leaves `storage` untouched when no entry exists.
    unsafe fn lookup_data(
        &self,
        tx: Option<&T::ReadTransaction<'_>>,
        task_id: TaskId,
        category: SpecificTaskDataCategory,
        storage: &mut TaskStorage,
    ) -> Result<()> {
        let inner = &*self.inner;
        fn lookup<D: KeyValueDatabase>(
            database: &D,
            tx: &D::ReadTransaction<'_>,
            task_id: TaskId,
            category: SpecificTaskDataCategory,
            storage: &mut TaskStorage,
        ) -> Result<()> {
            let Some(bytes) =
                database.get(tx, category.key_space(), IntKey::new(*task_id).as_ref())?
            else {
                return Ok(());
            };
            let mut decoder = new_turbo_bincode_decoder(bytes.borrow());
            storage
                .decode(category, &mut decoder)
                .map_err(|e| anyhow::anyhow!("Failed to decode {category:?}: {e:?}"))
        }
        inner
            .with_tx(tx, |tx| {
                lookup(&inner.database, tx, task_id, category, storage)
            })
            .with_context(|| format!("Looking up task storage for {task_id} from database failed"))
    }

    /// Batched variant of [`Self::lookup_data`]: returns one [`TaskStorage`] per input id, in
    /// order. Ids with no stored entry yield an empty `TaskStorage`.
    unsafe fn batch_lookup_data(
        &self,
        tx: Option<&Self::ReadTransaction<'_>>,
        task_ids: &[TaskId],
        category: SpecificTaskDataCategory,
    ) -> Result<Vec<TaskStorage>> {
        let inner = &*self.inner;
        fn lookup<D: KeyValueDatabase>(
            database: &D,
            tx: &D::ReadTransaction<'_>,
            task_ids: &[TaskId],
            category: SpecificTaskDataCategory,
        ) -> Result<Vec<TaskStorage>> {
            // Keys must outlive the borrowed `&[u8]` views passed to `batch_get`.
            let int_keys: Vec<_> = task_ids.iter().map(|&id| IntKey::new(*id)).collect();
            let keys = int_keys.iter().map(|k| k.as_ref()).collect::<Vec<_>>();
            let bytes = database.batch_get(tx, category.key_space(), &keys)?;
            bytes
                .into_iter()
                .map(|opt_bytes| {
                    let mut storage = TaskStorage::new();
                    if let Some(bytes) = opt_bytes {
                        let mut decoder = new_turbo_bincode_decoder(bytes.borrow());
                        storage
                            .decode(category, &mut decoder)
                            .map_err(|e| anyhow::anyhow!("Failed to decode {category:?}: {e:?}"))?;
                    }
                    Ok(storage)
                })
                .collect::<Result<Vec<_>>>()
        }
        inner
            .with_tx(tx, |tx| lookup(&inner.database, tx, task_ids, category))
            .with_context(|| {
                format!(
                    "Looking up typed data for {} tasks from database failed",
                    task_ids.len()
                )
            })
    }

    /// Delegates shutdown to the underlying database.
    fn shutdown(&self) -> Result<()> {
        self.inner.database.shutdown()
    }
}
521
522fn get_next_free_task_id<'a, S, C>(
523    batch: &mut WriteBatchRef<'_, 'a, S, C>,
524) -> Result<u32, anyhow::Error>
525where
526    S: SerialWriteBatch<'a>,
527    C: ConcurrentWriteBatch<'a>,
528{
529    Ok(
530        match batch.get(
531            KeySpace::Infra,
532            IntKey::new(META_KEY_NEXT_FREE_TASK_ID).as_ref(),
533        )? {
534            Some(bytes) => u32::from_le_bytes(Borrow::<[u8]>::borrow(&bytes).try_into()?),
535            None => 1,
536        },
537    )
538}
539
540fn save_infra<'a, S, C>(
541    batch: &mut WriteBatchRef<'_, 'a, S, C>,
542    next_task_id: u32,
543    operations: Vec<Arc<AnyOperation>>,
544) -> Result<(), anyhow::Error>
545where
546    S: SerialWriteBatch<'a>,
547    C: ConcurrentWriteBatch<'a>,
548{
549    {
550        batch
551            .put(
552                KeySpace::Infra,
553                WriteBuffer::Borrowed(IntKey::new(META_KEY_NEXT_FREE_TASK_ID).as_ref()),
554                WriteBuffer::Borrowed(&next_task_id.to_le_bytes()),
555            )
556            .context("Unable to write next free task id")?;
557    }
558    {
559        let _span =
560            tracing::trace_span!("update operations", operations = operations.len()).entered();
561        let operations =
562            turbo_bincode_encode(&operations).context("Unable to serialize operations")?;
563        batch
564            .put(
565                KeySpace::Infra,
566                WriteBuffer::Borrowed(IntKey::new(META_KEY_OPERATIONS).as_ref()),
567                WriteBuffer::SmallVec(operations),
568            )
569            .context("Unable to write operations")?;
570    }
571    batch.flush(KeySpace::Infra)?;
572    Ok(())
573}
574
575/// Computes a deterministic 64-bit hash of a CachedTaskType for use as a TaskCache key.
576///
577/// This encodes the task type directly to a hasher, avoiding intermediate buffer allocation.
578/// The encoding is deterministic (function IDs from registry, bincode argument encoding).
579fn compute_task_type_hash(task_type: &CachedTaskType) -> u64 {
580    let mut hasher = Xxh3Hash64Hasher::new();
581    task_type.hash_encode(&mut hasher);
582    let hash = hasher.finish();
583    if cfg!(feature = "verify_serialization") {
584        task_type.hash_encode(&mut hasher);
585        let hash2 = hasher.finish();
586        assert_eq!(
587            hash, hash2,
588            "Hashing TaskType twice was non-deterministic: \n{:?}\ngot hashes {} != {}",
589            task_type, hash, hash2
590        );
591    }
592    hash
593}
594
/// Pre-serialized task items produced by [`process_task_data`] when no concurrent write batch
/// is available: per input chunk, a list of `(task_id, meta buffer, data buffer)` tuples.
type SerializedTasks = Vec<
    Vec<(
        TaskId,
        Option<WriteBuffer<'static>>,
        Option<WriteBuffer<'static>>,
    )>,
>;
602
603fn process_task_data<'a, B: ConcurrentWriteBatch<'a> + Send + Sync, I>(
604    tasks: Vec<I>,
605    batch: Option<&B>,
606) -> Result<SerializedTasks>
607where
608    I: Iterator<
609            Item = (
610                TaskId,
611                Option<TurboBincodeBuffer>,
612                Option<TurboBincodeBuffer>,
613            ),
614        > + Send
615        + Sync,
616{
617    parallel::map_collect_owned::<_, _, Result<Vec<_>>>(tasks, |tasks| {
618        let mut result = Vec::new();
619        for (task_id, meta, data) in tasks {
620            if let Some(batch) = batch {
621                let key = IntKey::new(*task_id);
622                let key = key.as_ref();
623                if let Some(meta) = meta {
624                    batch.put(
625                        KeySpace::TaskMeta,
626                        WriteBuffer::Borrowed(key),
627                        WriteBuffer::SmallVec(meta),
628                    )?;
629                }
630                if let Some(data) = data {
631                    batch.put(
632                        KeySpace::TaskData,
633                        WriteBuffer::Borrowed(key),
634                        WriteBuffer::SmallVec(data),
635                    )?;
636                }
637            } else {
638                // Store the new task data
639                result.push((
640                    task_id,
641                    meta.map(WriteBuffer::SmallVec),
642                    data.map(WriteBuffer::SmallVec),
643                ));
644            }
645        }
646
647        Ok(result)
648    })
649}
#[cfg(test)]
mod tests {
    use std::borrow::Borrow;

    use turbo_tasks::TaskId;

    use super::*;
    use crate::database::{
        key_value_database::KeyValueDatabase,
        turbo::TurboKeyValueDatabase,
        write_batch::{BaseWriteBatch, ConcurrentWriteBatch, WriteBatch, WriteBuffer},
    };

    /// Helper to write to the database using the concurrent batch API.
    ///
    /// Writes a single `hash -> task_id` entry into [`KeySpace::TaskCache`] and commits.
    fn write_task_cache_entry(
        db: &TurboKeyValueDatabase,
        hash: u64,
        task_id: TaskId,
    ) -> Result<()> {
        let batch = db.write_batch()?;
        match batch {
            WriteBatch::Concurrent(concurrent, _) => {
                concurrent.put(
                    KeySpace::TaskCache,
                    WriteBuffer::Borrowed(&hash.to_le_bytes()),
                    WriteBuffer::Borrowed(&(*task_id).to_le_bytes()),
                )?;
                concurrent.commit()?;
            }
            WriteBatch::Serial(_) => {
                // `TurboKeyValueDatabase` is expected to always hand out concurrent batches.
                panic!("Expected concurrent batch");
            }
        }
        Ok(())
    }

    /// Tests that `get_multiple` correctly returns multiple TaskIds when the same hash key
    /// is used (simulating a hash collision scenario).
    ///
    /// This is a lower-level test that verifies the database layer correctly handles
    /// the case where multiple task IDs are stored under the same hash key.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_hash_collision_returns_multiple_candidates() -> Result<()> {
        let tempdir = tempfile::tempdir()?;
        let path = tempdir.path();

        // Use is_short_session=true to disable background compaction (which requires turbo-tasks
        // context)
        let db = TurboKeyValueDatabase::new(path.to_path_buf(), false, true)?;

        // Simulate a hash collision by writing multiple TaskIds with the same hash key
        let collision_hash: u64 = 0xDEADBEEF;
        let task_id_1 = TaskId::try_from(100u32).unwrap();
        let task_id_2 = TaskId::try_from(200u32).unwrap();
        let task_id_3 = TaskId::try_from(300u32).unwrap();

        // Write three task IDs under the same hash key (simulating collision)
        // Each write creates a new SST file, so all three will be returned by get_multiple
        write_task_cache_entry(&db, collision_hash, task_id_1)?;
        write_task_cache_entry(&db, collision_hash, task_id_2)?;
        write_task_cache_entry(&db, collision_hash, task_id_3)?;

        // Now query using get_multiple - should return all three TaskIds
        let results = db.get_multiple(&(), KeySpace::TaskCache, &collision_hash.to_le_bytes())?;

        assert_eq!(
            results.len(),
            3,
            "Should return all 3 task IDs for the colliding hash"
        );

        // Convert results to TaskIds and verify all three are present
        let mut found_ids: Vec<TaskId> = results
            .iter()
            .map(|bytes| {
                let bytes: [u8; 4] = Borrow::<[u8]>::borrow(bytes).try_into().unwrap();
                TaskId::try_from(u32::from_le_bytes(bytes)).unwrap()
            })
            .collect();
        found_ids.sort_by_key(|id| **id);

        assert_eq!(found_ids, vec![task_id_1, task_id_2, task_id_3]);

        db.shutdown()?;
        Ok(())
    }
}