turbo_tasks_backend/
kv_backing_storage.rs

1use std::{borrow::Borrow, cmp::max, sync::Arc};
2
3use anyhow::{Context, Result, anyhow};
4use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator};
5use serde::Serialize;
6use smallvec::SmallVec;
7use tracing::Span;
8use turbo_tasks::{SessionId, TaskId, backend::CachedTaskType, turbo_tasks_scope};
9
10use crate::{
11    backend::{AnyOperation, TaskDataCategory},
12    backing_storage::BackingStorage,
13    data::CachedDataItem,
14    database::{
15        key_value_database::{KeySpace, KeyValueDatabase},
16        write_batch::{
17            BaseWriteBatch, ConcurrentWriteBatch, SerialWriteBatch, WriteBatch, WriteBatchRef,
18            WriteBuffer,
19        },
20    },
21    utils::chunked_vec::ChunkedVec,
22};
23
// Shared serialization config for all persisted values in this file, pinned to
// pot's V4 wire compatibility so previously-persisted data stays readable.
const POT_CONFIG: pot::Config = pot::Config::new().compatibility(pot::Compatibility::V4);
25
26fn pot_serialize_small_vec<T: Serialize>(value: &T) -> pot::Result<SmallVec<[u8; 16]>> {
27    struct SmallVecWrite<'l>(&'l mut SmallVec<[u8; 16]>);
28    impl std::io::Write for SmallVecWrite<'_> {
29        #[inline]
30        fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
31            self.0.extend_from_slice(buf);
32            Ok(buf.len())
33        }
34
35        #[inline]
36        fn write_all(&mut self, buf: &[u8]) -> std::io::Result<()> {
37            self.0.extend_from_slice(buf);
38            Ok(())
39        }
40
41        #[inline]
42        fn flush(&mut self) -> std::io::Result<()> {
43            Ok(())
44        }
45    }
46
47    let mut output = SmallVec::new();
48    POT_CONFIG.serialize_into(value, SmallVecWrite(&mut output))?;
49    Ok(output)
50}
51
52fn pot_ser_symbol_map() -> pot::ser::SymbolMap {
53    pot::ser::SymbolMap::new().with_compatibility(pot::Compatibility::V4)
54}
55
// Fresh deserializer symbol list, used only by the round-trip checks that the
// `verify_serialization` feature enables.
#[cfg(feature = "verify_serialization")]
fn pot_de_symbol_list<'l>() -> pot::de::SymbolList<'l> {
    pot::de::SymbolList::new()
}
60
// Well-known keys inside `KeySpace::Infra` (encoded little-endian via `IntKey`).
const META_KEY_OPERATIONS: u32 = 0;
const META_KEY_NEXT_FREE_TASK_ID: u32 = 1;
const META_KEY_SESSION_ID: u32 = 2;
64
/// A u32 database key stored as exactly 4 little-endian bytes.
struct IntKey([u8; 4]);

impl IntKey {
    fn new(value: u32) -> Self {
        IntKey(u32::to_le_bytes(value))
    }
}

impl AsRef<[u8]> for IntKey {
    fn as_ref(&self) -> &[u8] {
        let IntKey(bytes) = self;
        bytes
    }
}
78
79fn as_u32(bytes: impl Borrow<[u8]>) -> Result<u32> {
80    let n = u32::from_le_bytes(bytes.borrow().try_into()?);
81    Ok(n)
82}
83
/// A `BackingStorage` implementation on top of any `KeyValueDatabase`.
pub struct KeyValueDatabaseBackingStorage<T: KeyValueDatabase> {
    // The wrapped key-value database all reads and writes go through.
    database: T,
}
87
88impl<T: KeyValueDatabase> KeyValueDatabaseBackingStorage<T> {
89    pub fn new(database: T) -> Self {
90        Self { database }
91    }
92
93    fn with_tx<R>(
94        &self,
95        tx: Option<&T::ReadTransaction<'_>>,
96        f: impl FnOnce(&T::ReadTransaction<'_>) -> Result<R>,
97    ) -> Result<R> {
98        if let Some(tx) = tx {
99            f(tx)
100        } else {
101            let tx = self.database.begin_read_transaction()?;
102            let r = f(&tx)?;
103            drop(tx);
104            Ok(r)
105        }
106    }
107}
108
109fn get_infra_u32(database: &impl KeyValueDatabase, key: u32) -> Option<u32> {
110    let tx = database.begin_read_transaction().ok()?;
111    let value = database
112        .get(&tx, KeySpace::Infra, IntKey::new(key).as_ref())
113        .ok()?
114        .map(as_u32)?
115        .ok()?;
116    Some(value)
117}
118
119impl<T: KeyValueDatabase + Send + Sync + 'static> BackingStorage
120    for KeyValueDatabaseBackingStorage<T>
121{
122    type ReadTransaction<'l> = T::ReadTransaction<'l>;
123
    /// Shortens the lifetime parameter of a read-transaction reference;
    /// delegated to the underlying database implementation.
    fn lower_read_transaction<'l: 'i + 'r, 'i: 'r, 'r>(
        tx: &'r Self::ReadTransaction<'l>,
    ) -> &'r Self::ReadTransaction<'i> {
        T::lower_read_transaction(tx)
    }
129
    /// Returns the persisted next free task id, defaulting to 1 on an empty
    /// or unreadable database (id 0 is never handed out).
    fn next_free_task_id(&self) -> TaskId {
        TaskId::try_from(get_infra_u32(&self.database, META_KEY_NEXT_FREE_TASK_ID).unwrap_or(1))
            .unwrap()
    }
134
    /// Returns the session id for this run: the last persisted session id
    /// plus one, so a fresh database starts at session 1.
    fn next_session_id(&self) -> SessionId {
        SessionId::try_from(get_infra_u32(&self.database, META_KEY_SESSION_ID).unwrap_or(0) + 1)
            .unwrap()
    }
139
140    fn uncompleted_operations(&self) -> Vec<AnyOperation> {
141        fn get(database: &impl KeyValueDatabase) -> Result<Vec<AnyOperation>> {
142            let tx = database.begin_read_transaction()?;
143            let Some(operations) = database.get(
144                &tx,
145                KeySpace::Infra,
146                IntKey::new(META_KEY_OPERATIONS).as_ref(),
147            )?
148            else {
149                return Ok(Vec::new());
150            };
151            let operations = POT_CONFIG.deserialize(operations.borrow())?;
152            Ok(operations)
153        }
154        get(&self.database).unwrap_or_default()
155    }
156
    /// Serializes a task's data items; delegates to the free `serialize`
    /// function at the bottom of this file.
    fn serialize(task: TaskId, data: &Vec<CachedDataItem>) -> Result<SmallVec<[u8; 16]>> {
        serialize(task, data)
    }
160
161    fn save_snapshot<I>(
162        &self,
163        session_id: SessionId,
164        operations: Vec<Arc<AnyOperation>>,
165        task_cache_updates: Vec<ChunkedVec<(Arc<CachedTaskType>, TaskId)>>,
166        snapshots: Vec<I>,
167    ) -> Result<()>
168    where
169        I: Iterator<
170                Item = (
171                    TaskId,
172                    Option<SmallVec<[u8; 16]>>,
173                    Option<SmallVec<[u8; 16]>>,
174                ),
175            > + Send
176            + Sync,
177    {
178        let _span = tracing::trace_span!("save snapshot", session_id = ?session_id, operations = operations.len());
179        let mut batch = self.database.write_batch()?;
180
181        // Start organizing the updates in parallel
182        match &mut batch {
183            &mut WriteBatch::Concurrent(ref batch, _) => {
184                {
185                    let _span = tracing::trace_span!("update task data").entered();
186                    process_task_data(snapshots, Some(batch))?;
187                    let span = tracing::trace_span!("flush task data").entered();
188                    [KeySpace::TaskMeta, KeySpace::TaskData]
189                        .into_par_iter()
190                        .try_for_each(|key_space| {
191                            let _span = span.clone().entered();
192                            // Safety: We already finished all processing of the task data and task
193                            // meta
194                            unsafe { batch.flush(key_space) }
195                        })?;
196                }
197
198                let mut next_task_id = get_next_free_task_id::<
199                    T::SerialWriteBatch<'_>,
200                    T::ConcurrentWriteBatch<'_>,
201                >(&mut WriteBatchRef::concurrent(batch))?;
202
203                {
204                    let _span = tracing::trace_span!(
205                        "update task cache",
206                        items = task_cache_updates.iter().map(|m| m.len()).sum::<usize>()
207                    )
208                    .entered();
209                    let result = task_cache_updates
210                        .into_par_iter()
211                        .with_max_len(1)
212                        .map(|updates| {
213                            let _span = _span.clone().entered();
214                            let mut max_task_id = 0;
215
216                            let mut task_type_bytes = Vec::new();
217                            for (task_type, task_id) in updates {
218                                let task_id: u32 = *task_id;
219                                serialize_task_type(&task_type, &mut task_type_bytes, task_id)?;
220
221                                batch
222                                    .put(
223                                        KeySpace::ForwardTaskCache,
224                                        WriteBuffer::Borrowed(&task_type_bytes),
225                                        WriteBuffer::Borrowed(&task_id.to_le_bytes()),
226                                    )
227                                    .with_context(|| {
228                                        anyhow!(
229                                            "Unable to write task cache {task_type:?} => {task_id}"
230                                        )
231                                    })?;
232                                batch
233                                    .put(
234                                        KeySpace::ReverseTaskCache,
235                                        WriteBuffer::Borrowed(IntKey::new(task_id).as_ref()),
236                                        WriteBuffer::Borrowed(&task_type_bytes),
237                                    )
238                                    .with_context(|| {
239                                        anyhow!(
240                                            "Unable to write task cache {task_id} => {task_type:?}"
241                                        )
242                                    })?;
243                                max_task_id = max_task_id.max(task_id + 1);
244                            }
245
246                            Ok(max_task_id)
247                        })
248                        .reduce(
249                            || Ok(0),
250                            |a, b| -> anyhow::Result<_> {
251                                let a_max = a?;
252                                let b_max = b?;
253                                Ok(max(a_max, b_max))
254                            },
255                        )?;
256                    next_task_id = next_task_id.max(result);
257                }
258
259                save_infra::<T::SerialWriteBatch<'_>, T::ConcurrentWriteBatch<'_>>(
260                    &mut WriteBatchRef::concurrent(batch),
261                    next_task_id,
262                    session_id,
263                    operations,
264                )?;
265            }
266            WriteBatch::Serial(batch) => {
267                let mut task_items_result = Ok(Vec::new());
268                turbo_tasks::scope(|s| {
269                    s.spawn(|_| {
270                        task_items_result =
271                            process_task_data(snapshots, None::<&T::ConcurrentWriteBatch<'_>>);
272                    });
273
274                    let mut next_task_id =
275                        get_next_free_task_id::<
276                            T::SerialWriteBatch<'_>,
277                            T::ConcurrentWriteBatch<'_>,
278                        >(&mut WriteBatchRef::serial(batch))?;
279
280                    {
281                        let _span = tracing::trace_span!(
282                            "update task cache",
283                            items = task_cache_updates.iter().map(|m| m.len()).sum::<usize>()
284                        )
285                        .entered();
286                        let mut task_type_bytes = Vec::new();
287                        for (task_type, task_id) in task_cache_updates.into_iter().flatten() {
288                            let task_id = *task_id;
289                            serialize_task_type(&task_type, &mut task_type_bytes, task_id)?;
290
291                            batch
292                                .put(
293                                    KeySpace::ForwardTaskCache,
294                                    WriteBuffer::Borrowed(&task_type_bytes),
295                                    WriteBuffer::Borrowed(&task_id.to_le_bytes()),
296                                )
297                                .with_context(|| {
298                                    anyhow!("Unable to write task cache {task_type:?} => {task_id}")
299                                })?;
300                            batch
301                                .put(
302                                    KeySpace::ReverseTaskCache,
303                                    WriteBuffer::Borrowed(IntKey::new(task_id).as_ref()),
304                                    WriteBuffer::Borrowed(&task_type_bytes),
305                                )
306                                .with_context(|| {
307                                    anyhow!("Unable to write task cache {task_id} => {task_type:?}")
308                                })?;
309                            next_task_id = next_task_id.max(task_id + 1);
310                        }
311                    }
312
313                    save_infra::<T::SerialWriteBatch<'_>, T::ConcurrentWriteBatch<'_>>(
314                        &mut WriteBatchRef::serial(batch),
315                        next_task_id,
316                        session_id,
317                        operations,
318                    )?;
319                    anyhow::Ok(())
320                })?;
321
322                {
323                    let _span = tracing::trace_span!("update tasks").entered();
324                    for (task_id, meta, data) in task_items_result?.into_iter().flatten() {
325                        let key = IntKey::new(*task_id);
326                        let key = key.as_ref();
327                        if let Some(meta) = meta {
328                            batch
329                                .put(KeySpace::TaskMeta, WriteBuffer::Borrowed(key), meta)
330                                .with_context(|| {
331                                    anyhow!("Unable to write meta items for {task_id}")
332                                })?;
333                        }
334                        if let Some(data) = data {
335                            batch
336                                .put(KeySpace::TaskData, WriteBuffer::Borrowed(key), data)
337                                .with_context(|| {
338                                    anyhow!("Unable to write data items for {task_id}")
339                                })?;
340                        }
341                    }
342                }
343            }
344        }
345
346        {
347            let _span = tracing::trace_span!("commit").entered();
348            batch
349                .commit()
350                .with_context(|| anyhow!("Unable to commit operations"))?;
351        }
352        Ok(())
353    }
354
    /// Opens a read transaction for use with the `lookup_*` methods;
    /// returns None if the database cannot start one.
    fn start_read_transaction(&self) -> Option<Self::ReadTransaction<'_>> {
        self.database.begin_read_transaction().ok()
    }
358
359    unsafe fn forward_lookup_task_cache(
360        &self,
361        tx: Option<&T::ReadTransaction<'_>>,
362        task_type: &CachedTaskType,
363    ) -> Option<TaskId> {
364        fn lookup<D: KeyValueDatabase>(
365            database: &D,
366            tx: &D::ReadTransaction<'_>,
367            task_type: &CachedTaskType,
368        ) -> Result<Option<TaskId>> {
369            let task_type = POT_CONFIG.serialize(task_type)?;
370            let Some(bytes) = database.get(tx, KeySpace::ForwardTaskCache, &task_type)? else {
371                return Ok(None);
372            };
373            let bytes = bytes.borrow().try_into()?;
374            let id = TaskId::try_from(u32::from_le_bytes(bytes)).unwrap();
375            Ok(Some(id))
376        }
377        if self.database.is_empty() {
378            // Checking if the database is empty is a performance optimization
379            // to avoid serializing the task type.
380            return None;
381        }
382        let id = self
383            .with_tx(tx, |tx| lookup(&self.database, tx, task_type))
384            .inspect_err(|err| println!("Looking up task id for {task_type:?} failed: {err:?}"))
385            .ok()??;
386        Some(id)
387    }
388
389    unsafe fn reverse_lookup_task_cache(
390        &self,
391        tx: Option<&T::ReadTransaction<'_>>,
392        task_id: TaskId,
393    ) -> Option<Arc<CachedTaskType>> {
394        fn lookup<D: KeyValueDatabase>(
395            database: &D,
396            tx: &D::ReadTransaction<'_>,
397            task_id: TaskId,
398        ) -> Result<Option<Arc<CachedTaskType>>> {
399            let Some(bytes) = database.get(
400                tx,
401                KeySpace::ReverseTaskCache,
402                IntKey::new(*task_id).as_ref(),
403            )?
404            else {
405                return Ok(None);
406            };
407            Ok(Some(POT_CONFIG.deserialize(bytes.borrow())?))
408        }
409        let result = self
410            .with_tx(tx, |tx| lookup(&self.database, tx, task_id))
411            .inspect_err(|err| println!("Looking up task type for {task_id} failed: {err:?}"))
412            .ok()??;
413        Some(result)
414    }
415
416    unsafe fn lookup_data(
417        &self,
418        tx: Option<&T::ReadTransaction<'_>>,
419        task_id: TaskId,
420        category: TaskDataCategory,
421    ) -> Vec<CachedDataItem> {
422        fn lookup<D: KeyValueDatabase>(
423            database: &D,
424            tx: &D::ReadTransaction<'_>,
425            task_id: TaskId,
426            category: TaskDataCategory,
427        ) -> Result<Vec<CachedDataItem>> {
428            let Some(bytes) = database.get(
429                tx,
430                match category {
431                    TaskDataCategory::Meta => KeySpace::TaskMeta,
432                    TaskDataCategory::Data => KeySpace::TaskData,
433                    TaskDataCategory::All => unreachable!(),
434                },
435                IntKey::new(*task_id).as_ref(),
436            )?
437            else {
438                return Ok(Vec::new());
439            };
440            let result: Vec<CachedDataItem> = POT_CONFIG.deserialize(bytes.borrow())?;
441            Ok(result)
442        }
443        self.with_tx(tx, |tx| lookup(&self.database, tx, task_id, category))
444            .inspect_err(|err| println!("Looking up data for {task_id} failed: {err:?}"))
445            .unwrap_or_default()
446    }
447
    /// Delegates shutdown to the underlying database.
    fn shutdown(&self) -> Result<()> {
        self.database.shutdown()
    }
451}
452
453fn get_next_free_task_id<'a, S, C>(
454    batch: &mut WriteBatchRef<'_, 'a, S, C>,
455) -> Result<u32, anyhow::Error>
456where
457    S: SerialWriteBatch<'a>,
458    C: ConcurrentWriteBatch<'a>,
459{
460    Ok(
461        match batch.get(
462            KeySpace::Infra,
463            IntKey::new(META_KEY_NEXT_FREE_TASK_ID).as_ref(),
464        )? {
465            Some(bytes) => u32::from_le_bytes(Borrow::<[u8]>::borrow(&bytes).try_into()?),
466            None => 1,
467        },
468    )
469}
470
471fn save_infra<'a, S, C>(
472    batch: &mut WriteBatchRef<'_, 'a, S, C>,
473    next_task_id: u32,
474    session_id: SessionId,
475    operations: Vec<Arc<AnyOperation>>,
476) -> Result<(), anyhow::Error>
477where
478    S: SerialWriteBatch<'a>,
479    C: ConcurrentWriteBatch<'a>,
480{
481    {
482        batch
483            .put(
484                KeySpace::Infra,
485                WriteBuffer::Borrowed(IntKey::new(META_KEY_NEXT_FREE_TASK_ID).as_ref()),
486                WriteBuffer::Borrowed(&next_task_id.to_le_bytes()),
487            )
488            .with_context(|| anyhow!("Unable to write next free task id"))?;
489    }
490    {
491        let _span = tracing::trace_span!("update session id", session_id = ?session_id).entered();
492        batch
493            .put(
494                KeySpace::Infra,
495                WriteBuffer::Borrowed(IntKey::new(META_KEY_SESSION_ID).as_ref()),
496                WriteBuffer::Borrowed(&session_id.to_le_bytes()),
497            )
498            .with_context(|| anyhow!("Unable to write next session id"))?;
499    }
500    {
501        let _span =
502            tracing::trace_span!("update operations", operations = operations.len()).entered();
503        let operations = pot_serialize_small_vec(&operations)
504            .with_context(|| anyhow!("Unable to serialize operations"))?;
505        batch
506            .put(
507                KeySpace::Infra,
508                WriteBuffer::Borrowed(IntKey::new(META_KEY_OPERATIONS).as_ref()),
509                WriteBuffer::SmallVec(operations),
510            )
511            .with_context(|| anyhow!("Unable to write operations"))?;
512    }
513    batch.flush(KeySpace::Infra)?;
514    Ok(())
515}
516
/// Serializes `task_type` into `task_type_bytes` (clearing any previous
/// contents first), producing the key bytes used by the forward task cache.
///
/// With the `verify_serialization` feature enabled, additionally checks that
/// the bytes can be deserialized again, and panics if they cannot.
fn serialize_task_type(
    task_type: &Arc<CachedTaskType>,
    mut task_type_bytes: &mut Vec<u8>,
    task_id: u32,
) -> Result<()> {
    task_type_bytes.clear();
    POT_CONFIG
        .serialize_into(&**task_type, &mut task_type_bytes)
        .with_context(|| anyhow!("Unable to serialize task {task_id} cache key {task_type:?}"))?;
    #[cfg(feature = "verify_serialization")]
    {
        // Round-trip check: the serialized bytes must deserialize back.
        let deserialize: Result<CachedTaskType, _> = serde_path_to_error::deserialize(
            &mut pot_de_symbol_list().deserializer_for_slice(&*task_type_bytes)?,
        );
        if let Err(err) = deserialize {
            println!("Task type would not be deserializable {task_id}: {err:?}\n{task_type:#?}");
            panic!("Task type would not be deserializable {task_id}: {err:?}");
        }
    }
    Ok(())
}
538
// Per-chunk rows produced by `process_task_data` when no concurrent write
// batch is available: (task id, serialized meta, serialized data).
type SerializedTasks = Vec<
    Vec<(
        TaskId,
        Option<WriteBuffer<'static>>,
        Option<WriteBuffer<'static>>,
    )>,
>;
546
/// Processes per-task snapshot payloads in parallel.
///
/// When a concurrent `batch` is given, meta/data payloads are written into it
/// directly and the returned inner Vecs stay empty. Without a batch (serial
/// path), the payloads are collected and returned as [`SerializedTasks`] for
/// the caller to write out afterwards.
fn process_task_data<'a, B: ConcurrentWriteBatch<'a> + Send + Sync, I>(
    tasks: Vec<I>,
    batch: Option<&B>,
) -> Result<SerializedTasks>
where
    I: Iterator<
            Item = (
                TaskId,
                Option<SmallVec<[u8; 16]>>,
                Option<SmallVec<[u8; 16]>>,
            ),
        > + Send
        + Sync,
{
    // Capture the current tracing span, turbo-tasks instance, and tokio
    // runtime handle so they can be re-established on the rayon workers.
    let span = Span::current();
    let turbo_tasks = turbo_tasks::turbo_tasks();
    let handle = tokio::runtime::Handle::current();
    tasks
        .into_par_iter()
        .map(|tasks| {
            let _span = span.clone().entered();
            let _guard = handle.clone().enter();
            turbo_tasks_scope(turbo_tasks.clone(), || {
                let mut result = Vec::new();
                for (task_id, meta, data) in tasks {
                    if let Some(batch) = batch {
                        // Concurrent path: write directly into the batch.
                        let key = IntKey::new(*task_id);
                        let key = key.as_ref();
                        if let Some(meta) = meta {
                            batch.put(
                                KeySpace::TaskMeta,
                                WriteBuffer::Borrowed(key),
                                WriteBuffer::SmallVec(meta),
                            )?;
                        }
                        if let Some(data) = data {
                            batch.put(
                                KeySpace::TaskData,
                                WriteBuffer::Borrowed(key),
                                WriteBuffer::SmallVec(data),
                            )?;
                        }
                    } else {
                        // Store the new task data
                        result.push((
                            task_id,
                            meta.map(WriteBuffer::SmallVec),
                            data.map(WriteBuffer::SmallVec),
                        ));
                    }
                }

                Ok(result)
            })
        })
        .collect::<Result<Vec<_>>>()
}
604
/// Serializes a task's data items with `POT_CONFIG`.
///
/// On failure (or always, when `verify_serialization` is enabled) it falls
/// back to a slower item-by-item path: non-serializable *optional* items are
/// dropped, while a non-serializable required item is an error.
fn serialize(task: TaskId, data: &Vec<CachedDataItem>) -> Result<SmallVec<[u8; 16]>> {
    Ok(match pot_serialize_small_vec(data) {
        #[cfg(not(feature = "verify_serialization"))]
        Ok(value) => value,
        _ => {
            // Fallback: serialize each item on its own to find the offender(s)
            // and retain only the items that round-trip.
            let mut error = Ok(());
            let mut data = data.clone();
            data.retain(|item| {
                let mut buf = Vec::<u8>::new();
                let mut symbol_map = pot_ser_symbol_map();
                let mut serializer = symbol_map.serializer_for(&mut buf).unwrap();
                if let Err(err) = serde_path_to_error::serialize(&item, &mut serializer) {
                    if item.is_optional() {
                        #[cfg(feature = "verify_serialization")]
                        println!("Skipping non-serializable optional item for {task}: {item:?}");
                    } else {
                        // Required item failed — record the error; `retain`
                        // still finishes so all failures get reported/skipped.
                        error = Err(err).context({
                            anyhow!("Unable to serialize data item for {task}: {item:?}")
                        });
                    }
                    false
                } else {
                    #[cfg(feature = "verify_serialization")]
                    {
                        // Also verify the bytes deserialize back; drop the
                        // item if they do not.
                        let deserialize: Result<CachedDataItem, _> =
                            serde_path_to_error::deserialize(
                                &mut pot_de_symbol_list().deserializer_for_slice(&buf).unwrap(),
                            );
                        if let Err(err) = deserialize {
                            println!(
                                "Data item would not be deserializable {task}: {err:?}\n{item:?}"
                            );
                            return false;
                        }
                    }
                    true
                }
            });
            error?;

            pot_serialize_small_vec(&data)
                .with_context(|| anyhow!("Unable to serialize data items for {task}: {data:#?}"))?
        }
    })
}