1use std::{borrow::Borrow, cmp::max, sync::Arc};
2
3use anyhow::{Context, Result, anyhow};
4use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator};
5use serde::Serialize;
6use smallvec::SmallVec;
7use tracing::Span;
8use turbo_tasks::{SessionId, TaskId, backend::CachedTaskType, turbo_tasks_scope};
9
10use crate::{
11 backend::{AnyOperation, TaskDataCategory},
12 backing_storage::BackingStorage,
13 data::CachedDataItem,
14 database::{
15 key_value_database::{KeySpace, KeyValueDatabase},
16 write_batch::{
17 BaseWriteBatch, ConcurrentWriteBatch, SerialWriteBatch, WriteBatch, WriteBatchRef,
18 WriteBuffer,
19 },
20 },
21 utils::chunked_vec::ChunkedVec,
22};
23
/// Shared pot (de)serialization configuration. Pinned to V4 compatibility so
/// data persisted by this module stays readable across pot upgrades.
const POT_CONFIG: pot::Config = pot::Config::new().compatibility(pot::Compatibility::V4);
25
26fn pot_serialize_small_vec<T: Serialize>(value: &T) -> pot::Result<SmallVec<[u8; 16]>> {
27 struct SmallVecWrite<'l>(&'l mut SmallVec<[u8; 16]>);
28 impl std::io::Write for SmallVecWrite<'_> {
29 #[inline]
30 fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
31 self.0.extend_from_slice(buf);
32 Ok(buf.len())
33 }
34
35 #[inline]
36 fn write_all(&mut self, buf: &[u8]) -> std::io::Result<()> {
37 self.0.extend_from_slice(buf);
38 Ok(())
39 }
40
41 #[inline]
42 fn flush(&mut self) -> std::io::Result<()> {
43 Ok(())
44 }
45 }
46
47 let mut output = SmallVec::new();
48 POT_CONFIG.serialize_into(value, SmallVecWrite(&mut output))?;
49 Ok(output)
50}
51
52fn pot_ser_symbol_map() -> pot::ser::SymbolMap {
53 pot::ser::SymbolMap::new().with_compatibility(pot::Compatibility::V4)
54}
55
/// Deserializer-side counterpart of [`pot_ser_symbol_map`]; only needed by
/// the round-trip checks behind the `verify_serialization` feature.
#[cfg(feature = "verify_serialization")]
fn pot_de_symbol_list<'l>() -> pot::de::SymbolList<'l> {
    pot::de::SymbolList::new()
}
60
// Well-known keys in the `Infra` key space, stored under their little-endian
// `IntKey` encoding.
/// Pot-serialized `Vec<AnyOperation>` of operations still in flight at the
/// last snapshot.
const META_KEY_OPERATIONS: u32 = 0;
/// Next task id to hand out, as little-endian `u32`.
const META_KEY_NEXT_FREE_TASK_ID: u32 = 1;
/// Most recently persisted session id, as little-endian `u32`.
const META_KEY_SESSION_ID: u32 = 2;
64
/// A `u32` database key in its canonical little-endian byte encoding.
struct IntKey([u8; 4]);

impl IntKey {
    /// Encodes `value` as the 4 little-endian bytes used for key lookups.
    fn new(value: u32) -> Self {
        IntKey(u32::to_le_bytes(value))
    }
}

impl AsRef<[u8]> for IntKey {
    fn as_ref(&self) -> &[u8] {
        self.0.as_slice()
    }
}
78
79fn as_u32(bytes: impl Borrow<[u8]>) -> Result<u32> {
80 let n = u32::from_le_bytes(bytes.borrow().try_into()?);
81 Ok(n)
82}
83
/// [`BackingStorage`] implementation that persists the task graph in any
/// [`KeyValueDatabase`].
pub struct KeyValueDatabaseBackingStorage<T: KeyValueDatabase> {
    database: T,
}
87
88impl<T: KeyValueDatabase> KeyValueDatabaseBackingStorage<T> {
89 pub fn new(database: T) -> Self {
90 Self { database }
91 }
92
93 fn with_tx<R>(
94 &self,
95 tx: Option<&T::ReadTransaction<'_>>,
96 f: impl FnOnce(&T::ReadTransaction<'_>) -> Result<R>,
97 ) -> Result<R> {
98 if let Some(tx) = tx {
99 f(tx)
100 } else {
101 let tx = self.database.begin_read_transaction()?;
102 let r = f(&tx)?;
103 drop(tx);
104 Ok(r)
105 }
106 }
107}
108
109fn get_infra_u32(database: &impl KeyValueDatabase, key: u32) -> Option<u32> {
110 let tx = database.begin_read_transaction().ok()?;
111 let value = database
112 .get(&tx, KeySpace::Infra, IntKey::new(key).as_ref())
113 .ok()?
114 .map(as_u32)?
115 .ok()?;
116 Some(value)
117}
118
impl<T: KeyValueDatabase + Send + Sync + 'static> BackingStorage
    for KeyValueDatabaseBackingStorage<T>
{
    type ReadTransaction<'l> = T::ReadTransaction<'l>;

    // Pure lifetime-shortening cast; delegated to the database implementation.
    fn lower_read_transaction<'l: 'i + 'r, 'i: 'r, 'r>(
        tx: &'r Self::ReadTransaction<'l>,
    ) -> &'r Self::ReadTransaction<'i> {
        T::lower_read_transaction(tx)
    }

    /// Reads the persisted next-free task id from the `Infra` key space.
    /// Falls back to 1 for a fresh or unreadable database (read errors are
    /// treated as "not stored").
    fn next_free_task_id(&self) -> TaskId {
        TaskId::try_from(get_infra_u32(&self.database, META_KEY_NEXT_FREE_TASK_ID).unwrap_or(1))
            .unwrap()
    }

    /// Returns the persisted session id plus one, so each run gets a fresh
    /// session id. Starts at 1 when nothing was stored yet.
    fn next_session_id(&self) -> SessionId {
        SessionId::try_from(get_infra_u32(&self.database, META_KEY_SESSION_ID).unwrap_or(0) + 1)
            .unwrap()
    }

    /// Loads the operations persisted as in-flight by the last snapshot.
    /// Read or deserialization failures yield an empty list instead of
    /// failing startup.
    fn uncompleted_operations(&self) -> Vec<AnyOperation> {
        fn get(database: &impl KeyValueDatabase) -> Result<Vec<AnyOperation>> {
            let tx = database.begin_read_transaction()?;
            let Some(operations) = database.get(
                &tx,
                KeySpace::Infra,
                IntKey::new(META_KEY_OPERATIONS).as_ref(),
            )?
            else {
                return Ok(Vec::new());
            };
            let operations = POT_CONFIG.deserialize(operations.borrow())?;
            Ok(operations)
        }
        get(&self.database).unwrap_or_default()
    }

    // Delegates to the free `serialize` function at the bottom of this file.
    fn serialize(task: TaskId, data: &Vec<CachedDataItem>) -> Result<SmallVec<[u8; 16]>> {
        serialize(task, data)
    }

    /// Persists one snapshot in a single write batch: per-task meta/data
    /// blobs, the task cache in both directions, the next free task id, the
    /// session id and the still-uncompleted operations. The batch is only
    /// committed at the very end, so a crash mid-snapshot leaves the
    /// previous state intact.
    fn save_snapshot<I>(
        &self,
        session_id: SessionId,
        operations: Vec<Arc<AnyOperation>>,
        task_cache_updates: Vec<ChunkedVec<(Arc<CachedTaskType>, TaskId)>>,
        snapshots: Vec<I>,
    ) -> Result<()>
    where
        I: Iterator<
                Item = (
                    TaskId,
                    Option<SmallVec<[u8; 16]>>,
                    Option<SmallVec<[u8; 16]>>,
                ),
            > + Send
            + Sync,
    {
        // NOTE(review): this span is created but never `.entered()`, unlike
        // every other span in this function — the child spans below will not
        // nest under it. Confirm whether `.entered()` was intended here.
        let _span = tracing::trace_span!("save snapshot", session_id = ?session_id, operations = operations.len());
        let mut batch = self.database.write_batch()?;

        // The two paths below do the same logical work; the concurrent batch
        // can be written to from rayon workers, the serial batch only from
        // this thread.
        match &mut batch {
            &mut WriteBatch::Concurrent(ref batch, _) => {
                {
                    let _span = tracing::trace_span!("update task data").entered();
                    // With a concurrent batch, task rows are written directly
                    // inside `process_task_data`; the returned Vec is unused.
                    process_task_data(snapshots, Some(batch))?;
                    let span = tracing::trace_span!("flush task data").entered();
                    [KeySpace::TaskMeta, KeySpace::TaskData]
                        .into_par_iter()
                        .try_for_each(|key_space| {
                            let _span = span.clone().entered();
                            // SAFETY: each key space is flushed from exactly
                            // one rayon task, after all task-data puts above
                            // completed. NOTE(review): confirm this matches
                            // `ConcurrentWriteBatch::flush`'s documented
                            // contract.
                            unsafe { batch.flush(key_space) }
                        })?;
                }

                let mut next_task_id = get_next_free_task_id::<
                    T::SerialWriteBatch<'_>,
                    T::ConcurrentWriteBatch<'_>,
                >(&mut WriteBatchRef::concurrent(batch))?;

                {
                    let _span = tracing::trace_span!(
                        "update task cache",
                        items = task_cache_updates.iter().map(|m| m.len()).sum::<usize>()
                    )
                    .entered();
                    // Each chunk writes its forward and reverse cache entries
                    // in parallel and reports the highest task id it saw so
                    // the next-free id can be advanced past it.
                    let result = task_cache_updates
                        .into_par_iter()
                        .with_max_len(1)
                        .map(|updates| {
                            let _span = _span.clone().entered();
                            let mut max_task_id = 0;

                            // Reused per chunk to avoid re-allocating for
                            // every task type.
                            let mut task_type_bytes = Vec::new();
                            for (task_type, task_id) in updates {
                                let task_id: u32 = *task_id;
                                serialize_task_type(&task_type, &mut task_type_bytes, task_id)?;

                                batch
                                    .put(
                                        KeySpace::ForwardTaskCache,
                                        WriteBuffer::Borrowed(&task_type_bytes),
                                        WriteBuffer::Borrowed(&task_id.to_le_bytes()),
                                    )
                                    .with_context(|| {
                                        anyhow!(
                                            "Unable to write task cache {task_type:?} => {task_id}"
                                        )
                                    })?;
                                batch
                                    .put(
                                        KeySpace::ReverseTaskCache,
                                        WriteBuffer::Borrowed(IntKey::new(task_id).as_ref()),
                                        WriteBuffer::Borrowed(&task_type_bytes),
                                    )
                                    .with_context(|| {
                                        anyhow!(
                                            "Unable to write task cache {task_id} => {task_type:?}"
                                        )
                                    })?;
                                max_task_id = max_task_id.max(task_id + 1);
                            }

                            Ok(max_task_id)
                        })
                        .reduce(
                            // Combine per-chunk maxima, propagating the first
                            // error encountered.
                            || Ok(0),
                            |a, b| -> anyhow::Result<_> {
                                let a_max = a?;
                                let b_max = b?;
                                Ok(max(a_max, b_max))
                            },
                        )?;
                    next_task_id = next_task_id.max(result);
                }

                save_infra::<T::SerialWriteBatch<'_>, T::ConcurrentWriteBatch<'_>>(
                    &mut WriteBatchRef::concurrent(batch),
                    next_task_id,
                    session_id,
                    operations,
                )?;
            }
            WriteBatch::Serial(batch) => {
                // Serialize task data on a scoped background thread while this
                // thread writes the task cache; the scope joins before
                // `task_items_result` is read below.
                let mut task_items_result = Ok(Vec::new());
                turbo_tasks::scope(|s| {
                    s.spawn(|_| {
                        // No batch is passed, so the serialized rows are
                        // returned for this thread to write afterwards.
                        task_items_result =
                            process_task_data(snapshots, None::<&T::ConcurrentWriteBatch<'_>>);
                    });

                    let mut next_task_id =
                        get_next_free_task_id::<
                            T::SerialWriteBatch<'_>,
                            T::ConcurrentWriteBatch<'_>,
                        >(&mut WriteBatchRef::serial(batch))?;

                    {
                        let _span = tracing::trace_span!(
                            "update task cache",
                            items = task_cache_updates.iter().map(|m| m.len()).sum::<usize>()
                        )
                        .entered();
                        // Buffer reused across all task types.
                        let mut task_type_bytes = Vec::new();
                        for (task_type, task_id) in task_cache_updates.into_iter().flatten() {
                            let task_id = *task_id;
                            serialize_task_type(&task_type, &mut task_type_bytes, task_id)?;

                            batch
                                .put(
                                    KeySpace::ForwardTaskCache,
                                    WriteBuffer::Borrowed(&task_type_bytes),
                                    WriteBuffer::Borrowed(&task_id.to_le_bytes()),
                                )
                                .with_context(|| {
                                    anyhow!("Unable to write task cache {task_type:?} => {task_id}")
                                })?;
                            batch
                                .put(
                                    KeySpace::ReverseTaskCache,
                                    WriteBuffer::Borrowed(IntKey::new(task_id).as_ref()),
                                    WriteBuffer::Borrowed(&task_type_bytes),
                                )
                                .with_context(|| {
                                    anyhow!("Unable to write task cache {task_id} => {task_type:?}")
                                })?;
                            next_task_id = next_task_id.max(task_id + 1);
                        }
                    }

                    save_infra::<T::SerialWriteBatch<'_>, T::ConcurrentWriteBatch<'_>>(
                        &mut WriteBatchRef::serial(batch),
                        next_task_id,
                        session_id,
                        operations,
                    )?;
                    anyhow::Ok(())
                })?;

                {
                    let _span = tracing::trace_span!("update tasks").entered();
                    // Write the rows serialized by the background thread.
                    for (task_id, meta, data) in task_items_result?.into_iter().flatten() {
                        let key = IntKey::new(*task_id);
                        let key = key.as_ref();
                        if let Some(meta) = meta {
                            batch
                                .put(KeySpace::TaskMeta, WriteBuffer::Borrowed(key), meta)
                                .with_context(|| {
                                    anyhow!("Unable to write meta items for {task_id}")
                                })?;
                        }
                        if let Some(data) = data {
                            batch
                                .put(KeySpace::TaskData, WriteBuffer::Borrowed(key), data)
                                .with_context(|| {
                                    anyhow!("Unable to write data items for {task_id}")
                                })?;
                        }
                    }
                }
            }
        }

        {
            let _span = tracing::trace_span!("commit").entered();
            batch
                .commit()
                .with_context(|| anyhow!("Unable to commit operations"))?;
        }
        Ok(())
    }

    fn start_read_transaction(&self) -> Option<Self::ReadTransaction<'_>> {
        self.database.begin_read_transaction().ok()
    }

    /// Looks up the task id stored for `task_type` in the forward task
    /// cache. Returns `None` for an empty database, a cache miss, or any
    /// error (errors are printed, not propagated).
    ///
    /// NOTE(review): `unsafe` per the `BackingStorage` trait — presumably
    /// `tx` must be a valid transaction of this database; confirm the exact
    /// contract at the trait definition.
    unsafe fn forward_lookup_task_cache(
        &self,
        tx: Option<&T::ReadTransaction<'_>>,
        task_type: &CachedTaskType,
    ) -> Option<TaskId> {
        fn lookup<D: KeyValueDatabase>(
            database: &D,
            tx: &D::ReadTransaction<'_>,
            task_type: &CachedTaskType,
        ) -> Result<Option<TaskId>> {
            // The pot-serialized task type is the forward-cache key.
            let task_type = POT_CONFIG.serialize(task_type)?;
            let Some(bytes) = database.get(tx, KeySpace::ForwardTaskCache, &task_type)? else {
                return Ok(None);
            };
            let bytes = bytes.borrow().try_into()?;
            let id = TaskId::try_from(u32::from_le_bytes(bytes)).unwrap();
            Ok(Some(id))
        }
        // Skip serialization and transaction work entirely for a fresh db.
        if self.database.is_empty() {
            return None;
        }
        let id = self
            .with_tx(tx, |tx| lookup(&self.database, tx, task_type))
            .inspect_err(|err| println!("Looking up task id for {task_type:?} failed: {err:?}"))
            .ok()??;
        Some(id)
    }

    /// Looks up the task type stored for `task_id` in the reverse task
    /// cache. Returns `None` on a miss or any error (errors are printed).
    ///
    /// NOTE(review): `unsafe` per the `BackingStorage` trait — see
    /// `forward_lookup_task_cache`.
    unsafe fn reverse_lookup_task_cache(
        &self,
        tx: Option<&T::ReadTransaction<'_>>,
        task_id: TaskId,
    ) -> Option<Arc<CachedTaskType>> {
        fn lookup<D: KeyValueDatabase>(
            database: &D,
            tx: &D::ReadTransaction<'_>,
            task_id: TaskId,
        ) -> Result<Option<Arc<CachedTaskType>>> {
            let Some(bytes) = database.get(
                tx,
                KeySpace::ReverseTaskCache,
                IntKey::new(*task_id).as_ref(),
            )?
            else {
                return Ok(None);
            };
            Ok(Some(POT_CONFIG.deserialize(bytes.borrow())?))
        }
        let result = self
            .with_tx(tx, |tx| lookup(&self.database, tx, task_id))
            .inspect_err(|err| println!("Looking up task type for {task_id} failed: {err:?}"))
            .ok()??;
        Some(result)
    }

    /// Reads and deserializes the meta or data blob for `task_id`. A missing
    /// entry or any error yields an empty `Vec` (errors are printed).
    /// `category` must be `Meta` or `Data`; `All` is a caller bug.
    ///
    /// NOTE(review): `unsafe` per the `BackingStorage` trait — see
    /// `forward_lookup_task_cache`.
    unsafe fn lookup_data(
        &self,
        tx: Option<&T::ReadTransaction<'_>>,
        task_id: TaskId,
        category: TaskDataCategory,
    ) -> Vec<CachedDataItem> {
        fn lookup<D: KeyValueDatabase>(
            database: &D,
            tx: &D::ReadTransaction<'_>,
            task_id: TaskId,
            category: TaskDataCategory,
        ) -> Result<Vec<CachedDataItem>> {
            let Some(bytes) = database.get(
                tx,
                match category {
                    TaskDataCategory::Meta => KeySpace::TaskMeta,
                    TaskDataCategory::Data => KeySpace::TaskData,
                    TaskDataCategory::All => unreachable!(),
                },
                IntKey::new(*task_id).as_ref(),
            )?
            else {
                return Ok(Vec::new());
            };
            let result: Vec<CachedDataItem> = POT_CONFIG.deserialize(bytes.borrow())?;
            Ok(result)
        }
        self.with_tx(tx, |tx| lookup(&self.database, tx, task_id, category))
            .inspect_err(|err| println!("Looking up data for {task_id} failed: {err:?}"))
            .unwrap_or_default()
    }

    fn shutdown(&self) -> Result<()> {
        self.database.shutdown()
    }
}
452
453fn get_next_free_task_id<'a, S, C>(
454 batch: &mut WriteBatchRef<'_, 'a, S, C>,
455) -> Result<u32, anyhow::Error>
456where
457 S: SerialWriteBatch<'a>,
458 C: ConcurrentWriteBatch<'a>,
459{
460 Ok(
461 match batch.get(
462 KeySpace::Infra,
463 IntKey::new(META_KEY_NEXT_FREE_TASK_ID).as_ref(),
464 )? {
465 Some(bytes) => u32::from_le_bytes(Borrow::<[u8]>::borrow(&bytes).try_into()?),
466 None => 1,
467 },
468 )
469}
470
471fn save_infra<'a, S, C>(
472 batch: &mut WriteBatchRef<'_, 'a, S, C>,
473 next_task_id: u32,
474 session_id: SessionId,
475 operations: Vec<Arc<AnyOperation>>,
476) -> Result<(), anyhow::Error>
477where
478 S: SerialWriteBatch<'a>,
479 C: ConcurrentWriteBatch<'a>,
480{
481 {
482 batch
483 .put(
484 KeySpace::Infra,
485 WriteBuffer::Borrowed(IntKey::new(META_KEY_NEXT_FREE_TASK_ID).as_ref()),
486 WriteBuffer::Borrowed(&next_task_id.to_le_bytes()),
487 )
488 .with_context(|| anyhow!("Unable to write next free task id"))?;
489 }
490 {
491 let _span = tracing::trace_span!("update session id", session_id = ?session_id).entered();
492 batch
493 .put(
494 KeySpace::Infra,
495 WriteBuffer::Borrowed(IntKey::new(META_KEY_SESSION_ID).as_ref()),
496 WriteBuffer::Borrowed(&session_id.to_le_bytes()),
497 )
498 .with_context(|| anyhow!("Unable to write next session id"))?;
499 }
500 {
501 let _span =
502 tracing::trace_span!("update operations", operations = operations.len()).entered();
503 let operations = pot_serialize_small_vec(&operations)
504 .with_context(|| anyhow!("Unable to serialize operations"))?;
505 batch
506 .put(
507 KeySpace::Infra,
508 WriteBuffer::Borrowed(IntKey::new(META_KEY_OPERATIONS).as_ref()),
509 WriteBuffer::SmallVec(operations),
510 )
511 .with_context(|| anyhow!("Unable to write operations"))?;
512 }
513 batch.flush(KeySpace::Infra)?;
514 Ok(())
515}
516
/// Serializes `task_type` into `task_type_bytes`. The buffer is cleared
/// first so callers can reuse one allocation across many task types.
///
/// With the `verify_serialization` feature enabled, the bytes are
/// additionally round-tripped through deserialization and a failure panics,
/// catching broken cache keys at write time instead of on the next restore.
fn serialize_task_type(
    task_type: &Arc<CachedTaskType>,
    mut task_type_bytes: &mut Vec<u8>,
    task_id: u32,
) -> Result<()> {
    task_type_bytes.clear();
    POT_CONFIG
        .serialize_into(&**task_type, &mut task_type_bytes)
        .with_context(|| anyhow!("Unable to serialize task {task_id} cache key {task_type:?}"))?;
    #[cfg(feature = "verify_serialization")]
    {
        let deserialize: Result<CachedTaskType, _> = serde_path_to_error::deserialize(
            &mut pot_de_symbol_list().deserializer_for_slice(&*task_type_bytes)?,
        );
        if let Err(err) = deserialize {
            println!("Task type would not be deserializable {task_id}: {err:?}\n{task_type:#?}");
            panic!("Task type would not be deserializable {task_id}: {err:?}");
        }
    }
    Ok(())
}
538
/// Per-chunk serialized task rows: `(task id, meta blob, data blob)`.
/// Produced by [`process_task_data`] when no concurrent batch is available
/// to write into directly.
type SerializedTasks = Vec<
    Vec<(
        TaskId,
        Option<WriteBuffer<'static>>,
        Option<WriteBuffer<'static>>,
    )>,
>;
546
/// Processes serialized task rows on the rayon pool.
///
/// With `batch = Some(..)` (concurrent path) each row's meta/data blobs are
/// written straight into the batch and the returned vectors stay empty.
/// With `batch = None` (serial path) the rows are collected and returned so
/// the caller can write them on its own thread.
///
/// The current tracing span, tokio runtime handle and turbo-tasks context
/// are re-entered on every rayon worker — presumably required by code that
/// runs during the puts; TODO confirm which callee depends on them.
fn process_task_data<'a, B: ConcurrentWriteBatch<'a> + Send + Sync, I>(
    tasks: Vec<I>,
    batch: Option<&B>,
) -> Result<SerializedTasks>
where
    I: Iterator<
            Item = (
                TaskId,
                Option<SmallVec<[u8; 16]>>,
                Option<SmallVec<[u8; 16]>>,
            ),
        > + Send
        + Sync,
{
    let span = Span::current();
    let turbo_tasks = turbo_tasks::turbo_tasks();
    let handle = tokio::runtime::Handle::current();
    tasks
        .into_par_iter()
        .map(|tasks| {
            // Re-establish the ambient contexts on this worker thread.
            let _span = span.clone().entered();
            let _guard = handle.clone().enter();
            turbo_tasks_scope(turbo_tasks.clone(), || {
                let mut result = Vec::new();
                for (task_id, meta, data) in tasks {
                    if let Some(batch) = batch {
                        let key = IntKey::new(*task_id);
                        let key = key.as_ref();
                        if let Some(meta) = meta {
                            batch.put(
                                KeySpace::TaskMeta,
                                WriteBuffer::Borrowed(key),
                                WriteBuffer::SmallVec(meta),
                            )?;
                        }
                        if let Some(data) = data {
                            batch.put(
                                KeySpace::TaskData,
                                WriteBuffer::Borrowed(key),
                                WriteBuffer::SmallVec(data),
                            )?;
                        }
                    } else {
                        // No batch to write into: hand the blobs back to the
                        // caller.
                        result.push((
                            task_id,
                            meta.map(WriteBuffer::SmallVec),
                            data.map(WriteBuffer::SmallVec),
                        ));
                    }
                }

                Ok(result)
            })
        })
        .collect::<Result<Vec<_>>>()
}
604
/// Serializes a task's data items, tolerating individual non-serializable
/// items.
///
/// Fast path: serialize the whole `Vec` in one go. If that fails — or
/// always, when `verify_serialization` is enabled, since the `Ok` arm is
/// compiled out — fall back to probing each item individually: optional
/// items that fail to serialize are dropped, a required item failing is a
/// hard error, and (under `verify_serialization`) items that serialize but
/// do not deserialize back are dropped with a diagnostic.
fn serialize(task: TaskId, data: &Vec<CachedDataItem>) -> Result<SmallVec<[u8; 16]>> {
    Ok(match pot_serialize_small_vec(data) {
        #[cfg(not(feature = "verify_serialization"))]
        Ok(value) => value,
        _ => {
            // First serialization error for a non-optional item, if any.
            let mut error = Ok(());
            let mut data = data.clone();
            data.retain(|item| {
                // Probe each item on its own to find the culprit(s).
                let mut buf = Vec::<u8>::new();
                let mut symbol_map = pot_ser_symbol_map();
                let mut serializer = symbol_map.serializer_for(&mut buf).unwrap();
                if let Err(err) = serde_path_to_error::serialize(&item, &mut serializer) {
                    if item.is_optional() {
                        #[cfg(feature = "verify_serialization")]
                        println!("Skipping non-serializable optional item for {task}: {item:?}");
                    } else {
                        error = Err(err).context({
                            anyhow!("Unable to serialize data item for {task}: {item:?}")
                        });
                    }
                    false
                } else {
                    #[cfg(feature = "verify_serialization")]
                    {
                        let deserialize: Result<CachedDataItem, _> =
                            serde_path_to_error::deserialize(
                                &mut pot_de_symbol_list().deserializer_for_slice(&buf).unwrap(),
                            );
                        if let Err(err) = deserialize {
                            println!(
                                "Data item would not be deserializable {task}: {err:?}\n{item:?}"
                            );
                            return false;
                        }
                    }
                    true
                }
            });
            error?;

            // Re-serialize with the offending items removed.
            pot_serialize_small_vec(&data)
                .with_context(|| anyhow!("Unable to serialize data items for {task}: {data:#?}"))?
        }
    })
}