turbo_persistence/write_batch.rs

use std::{
    cell::SyncUnsafeCell,
    fs::File,
    io::Write,
    mem::{replace, take},
    path::PathBuf,
    sync::atomic::{AtomicU32, AtomicU64, Ordering},
};

use anyhow::{Context, Result};
use byteorder::{BE, WriteBytesExt};
use lzzzz::lz4::{self, ACC_LEVEL_DEFAULT};
use parking_lot::Mutex;
use rayon::{
    iter::{Either, IndexedParallelIterator, IntoParallelIterator, ParallelIterator},
    scope,
};
use smallvec::SmallVec;
use thread_local::ThreadLocal;
use tracing::Span;

use crate::{
    ValueBuffer,
    collector::Collector,
    collector_entry::CollectorEntry,
    constants::{MAX_MEDIUM_VALUE_SIZE, THREAD_LOCAL_SIZE_SHIFT},
    key::StoreKey,
    meta_file_builder::MetaFileBuilder,
    static_sorted_file_builder::{StaticSortedFileBuilder, StaticSortedFileBuilderMeta},
};

/// The thread local state of a `WriteBatch`. `FAMILIES` should fit within a `u32`.
//
// NOTE: This type *must* use `usize`, even though the real type used in storage is `u32` because
// there's no way to cast a `u32` to `usize` when declaring an array without the nightly
// `min_generic_const_args` feature.
struct ThreadLocalState<K: StoreKey + Send, const FAMILIES: usize> {
    /// The collectors for each family.
    collectors: [Option<Collector<K, THREAD_LOCAL_SIZE_SHIFT>>; FAMILIES],
    /// The list of new blob files that have been created.
    /// Tuple of (sequence number, file).
    new_blob_files: Vec<(u32, File)>,
}

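// The global collector for a family starts unsharded; once it fills up it is split into
// `COLLECTOR_SHARDS` collectors selected by the top bits of the 64-bit key hash. With 4 shards,
// `trailing_zeros(4) == 2`, so `COLLECTOR_SHARD_SHIFT == 64 - 2 == 62` and `hash >> 62` yields a
// shard index in `0..4`.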
const COLLECTOR_SHARDS: usize = 4;
const COLLECTOR_SHARD_SHIFT: usize =
    u64::BITS as usize - COLLECTOR_SHARDS.trailing_zeros() as usize;

/// The result of a `WriteBatch::finish` operation.
pub(crate) struct FinishResult {
    pub(crate) sequence_number: u32,
    /// Tuple of (sequence number, file).
    pub(crate) new_meta_files: Vec<(u32, File)>,
    /// Tuple of (sequence number, file).
    pub(crate) new_sst_files: Vec<(u32, File)>,
    /// Tuple of (sequence number, file).
    pub(crate) new_blob_files: Vec<(u32, File)>,
    /// Number of keys written in this batch.
    pub(crate) keys_written: u64,
}

enum GlobalCollectorState<K: StoreKey + Send> {
    /// Initial state. Single collector. Once the collector is full, we switch to sharded mode.
    Unsharded(Collector<K>),
    /// Sharded mode.
    /// We use multiple collectors, and select one based on the first bits of the key hash.
    Sharded([Collector<K>; COLLECTOR_SHARDS]),
}

/// A write batch. Writes are buffered in thread-local collectors and flushed into per-family
/// global collectors, which are sharded by key hash once they grow large. Full collectors are
/// written out as SST files, oversized values go to blob files, and `finish` flushes the
/// remaining data and writes the meta files.
pub struct WriteBatch<K: StoreKey + Send, const FAMILIES: usize> {
    /// The database path
    db_path: PathBuf,
    /// The current sequence number counter. Increased for every new SST file or blob file.
    current_sequence_number: AtomicU32,
    /// The thread local state.
    thread_locals: ThreadLocal<SyncUnsafeCell<ThreadLocalState<K, FAMILIES>>>,
    /// Collectors in use. The thread local collectors flush into these when they are full.
    collectors: [Mutex<GlobalCollectorState<K>>; FAMILIES],
    /// Meta file builders for each family.
    meta_collectors: [Mutex<Vec<(u32, StaticSortedFileBuilderMeta<'static>)>>; FAMILIES],
    /// The list of new SST files that have been created.
    /// Tuple of (sequence number, file).
    new_sst_files: Mutex<Vec<(u32, File)>>,
    /// Collectors that are currently unused, but have memory preallocated.
    idle_collectors: Mutex<Vec<Collector<K>>>,
    /// Collectors that are currently unused, but have memory preallocated.
    idle_thread_local_collectors: Mutex<Vec<Collector<K, THREAD_LOCAL_SIZE_SHIFT>>>,
}

impl<K: StoreKey + Send + Sync, const FAMILIES: usize> WriteBatch<K, FAMILIES> {
    /// Creates a new write batch for a database.
    pub(crate) fn new(path: PathBuf, current: u32) -> Self {
        const {
            assert!(FAMILIES <= usize_from_u32(u32::MAX));
        };
        Self {
            db_path: path,
            current_sequence_number: AtomicU32::new(current),
            thread_locals: ThreadLocal::new(),
            collectors: [(); FAMILIES]
                .map(|_| Mutex::new(GlobalCollectorState::Unsharded(Collector::new()))),
            meta_collectors: [(); FAMILIES].map(|_| Mutex::new(Vec::new())),
            new_sst_files: Mutex::new(Vec::new()),
            idle_collectors: Mutex::new(Vec::new()),
            idle_thread_local_collectors: Mutex::new(Vec::new()),
        }
    }

    /// Resets the write batch to a new sequence number. This is called when the WriteBatch is
    /// reused.
    pub(crate) fn reset(&mut self, current: u32) {
        self.current_sequence_number
            .store(current, Ordering::SeqCst);
    }

    /// Returns the thread local state for the current thread.
    #[allow(clippy::mut_from_ref)]
    fn thread_local_state(&self) -> &mut ThreadLocalState<K, FAMILIES> {
        let cell = self.thread_locals.get_or(|| {
            SyncUnsafeCell::new(ThreadLocalState {
                collectors: [const { None }; FAMILIES],
                new_blob_files: Vec::new(),
            })
        });
        // Safety: We know that the cell is only accessed from the current thread.
        unsafe { &mut *cell.get() }
    }

    /// Returns the collector for a family for the current thread.
    fn thread_local_collector_mut<'l>(
        &self,
        state: &'l mut ThreadLocalState<K, FAMILIES>,
        family: u32,
    ) -> Result<&'l mut Collector<K, THREAD_LOCAL_SIZE_SHIFT>> {
        debug_assert!(usize_from_u32(family) < FAMILIES);
        let collector = state.collectors[usize_from_u32(family)].get_or_insert_with(|| {
            self.idle_thread_local_collectors
                .lock()
                .pop()
                .unwrap_or_else(|| Collector::new())
        });
        if collector.is_full() {
            self.flush_thread_local_collector(family, collector)?;
        }
        Ok(collector)
    }

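    /// Drains the given thread-local collector into the global collector for `family`. When the
    /// unsharded global collector becomes full it is split into hash-sharded collectors, and any
    /// global collector that fills up is swapped out and written to a new SST file.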
    #[tracing::instrument(level = "trace", skip(self, collector))]
    fn flush_thread_local_collector(
        &self,
        family: u32,
        collector: &mut Collector<K, THREAD_LOCAL_SIZE_SHIFT>,
    ) -> Result<()> {
        let mut full_collectors = SmallVec::<[_; 2]>::new();
        {
            let mut global_collector_state = self.collectors[usize_from_u32(family)].lock();
            for entry in collector.drain() {
                match &mut *global_collector_state {
                    GlobalCollectorState::Unsharded(collector) => {
                        collector.add_entry(entry);
                        if collector.is_full() {
                            // When full, split the entries into shards.
                            let mut shards: [Collector<K>; COLLECTOR_SHARDS] =
                                [(); COLLECTOR_SHARDS].map(|_| Collector::new());
                            for entry in collector.drain() {
                                let shard = (entry.key.hash >> COLLECTOR_SHARD_SHIFT) as usize;
                                shards[shard].add_entry(entry);
                            }
                            // There is a rare edge case where all entries are in the same shard,
                            // and the collector is full after the split.
                            for collector in shards.iter_mut() {
                                if collector.is_full() {
                                    full_collectors
                                        .push(replace(&mut *collector, self.get_new_collector()));
                                }
                            }
                            *global_collector_state = GlobalCollectorState::Sharded(shards);
                        }
                    }
                    GlobalCollectorState::Sharded(shards) => {
                        let shard = (entry.key.hash >> COLLECTOR_SHARD_SHIFT) as usize;
                        let collector = &mut shards[shard];
                        collector.add_entry(entry);
                        if collector.is_full() {
                            full_collectors
                                .push(replace(&mut *collector, self.get_new_collector()));
                        }
                    }
                }
            }
        }
        for mut global_collector in full_collectors {
            // When the global collector is full, we create a new SST file.
            let sst = self.create_sst_file(family, global_collector.sorted())?;
            global_collector.clear();
            self.new_sst_files.lock().push(sst);
            self.dispose_collector(global_collector);
        }
        Ok(())
    }

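    /// Takes an idle global collector from the pool if one is available, otherwise allocates a
    /// new one.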
    fn get_new_collector(&self) -> Collector<K> {
        self.idle_collectors
            .lock()
            .pop()
            .unwrap_or_else(|| Collector::new())
    }

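    /// Returns a drained global collector to the idle pool so its preallocated memory can be
    /// reused.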
    fn dispose_collector(&self, collector: Collector<K>) {
        self.idle_collectors.lock().push(collector);
    }

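    /// Returns a drained thread-local collector to the idle pool so its preallocated memory can
    /// be reused.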
    fn dispose_thread_local_collector(&self, collector: Collector<K, THREAD_LOCAL_SIZE_SHIFT>) {
        self.idle_thread_local_collectors.lock().push(collector);
    }

    /// Puts a key-value pair into the write batch.
    pub fn put(&self, family: u32, key: K, value: ValueBuffer<'_>) -> Result<()> {
        let state = self.thread_local_state();
        let collector = self.thread_local_collector_mut(state, family)?;
        if value.len() <= MAX_MEDIUM_VALUE_SIZE {
            collector.put(key, value);
        } else {
            let (blob, file) = self.create_blob(&value)?;
            collector.put_blob(key, blob);
            state.new_blob_files.push((blob, file));
        }
        Ok(())
    }

    /// Puts a delete operation into the write batch.
    pub fn delete(&self, family: u32, key: K) -> Result<()> {
        let state = self.thread_local_state();
        let collector = self.thread_local_collector_mut(state, family)?;
        collector.delete(key);
        Ok(())
    }

    /// Flushes a family of the write batch, reducing the amount of buffered memory used.
    /// Does not commit any data persistently.
    ///
    /// # Safety
    ///
    /// Caller must ensure that no concurrent put or delete operation is happening on the flushed
    /// family.
    #[tracing::instrument(level = "trace", skip(self))]
    pub unsafe fn flush(&self, family: u32) -> Result<()> {
        // Flush the thread local collectors to the global collector.
        let mut collectors = Vec::new();
        for cell in self.thread_locals.iter() {
            let state = unsafe { &mut *cell.get() };
            if let Some(collector) = state.collectors[usize_from_u32(family)].take()
                && !collector.is_empty()
            {
                collectors.push(collector);
            }
        }

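        // Flush the captured thread-local collectors into the global collector in parallel,
        // re-entering the current span on each rayon worker.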
        let span = Span::current();
        collectors.into_par_iter().try_for_each(|mut collector| {
            let _span = span.clone().entered();
            self.flush_thread_local_collector(family, &mut collector)?;
            self.dispose_thread_local_collector(collector);
            anyhow::Ok(())
        })?;

        // Now we flush the global collector(s).
        let mut collector_state = self.collectors[usize_from_u32(family)].lock();
        match &mut *collector_state {
            GlobalCollectorState::Unsharded(collector) => {
                if !collector.is_empty() {
                    let sst = self.create_sst_file(family, collector.sorted())?;
                    collector.clear();
                    self.new_sst_files.lock().push(sst);
                }
            }
            GlobalCollectorState::Sharded(_) => {
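                // Reset the family back to a fresh unsharded collector and take ownership of the
                // old shards so they can be flushed to SST files in parallel.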
                let GlobalCollectorState::Sharded(shards) = replace(
                    &mut *collector_state,
                    GlobalCollectorState::Unsharded(self.get_new_collector()),
                ) else {
                    unreachable!();
                };
                shards.into_par_iter().try_for_each(|mut collector| {
                    let _span = span.clone().entered();
                    if !collector.is_empty() {
                        let sst = self.create_sst_file(family, collector.sorted())?;
                        collector.clear();
                        self.new_sst_files.lock().push(sst);
                        self.dispose_collector(collector);
                    }
                    anyhow::Ok(())
                })?;
            }
        }

        Ok(())
    }

    /// Finishes the write batch by returning the new sequence number and the new SST files. This
    /// writes all outstanding thread local data to disk.
    #[tracing::instrument(level = "trace", skip(self))]
    pub(crate) fn finish(&mut self) -> Result<FinishResult> {
        let mut new_blob_files = Vec::new();
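        // Errors from tasks spawned in the rayon scope below cannot be returned directly, so they
        // are stored in this shared slot and checked once all flushing has finished.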
        let shared_error = Mutex::new(Ok(()));

        // First, we flush all thread local collectors to the global collectors.
        scope(|scope| {
            let _span = tracing::trace_span!("flush thread local collectors").entered();
            let mut collectors = [const { Vec::new() }; FAMILIES];
            for cell in self.thread_locals.iter_mut() {
                let state = cell.get_mut();
                new_blob_files.append(&mut state.new_blob_files);
                for (family, thread_local_collector) in state.collectors.iter_mut().enumerate() {
                    if let Some(collector) = thread_local_collector.take()
                        && !collector.is_empty()
                    {
                        collectors[family].push(collector);
                    }
                }
            }
            for (family, thread_local_collectors) in collectors.into_iter().enumerate() {
                for mut collector in thread_local_collectors {
                    let this = &self;
                    let shared_error = &shared_error;
                    let span = Span::current();
                    scope.spawn(move |_| {
                        let _span = span.entered();
                        if let Err(err) =
                            this.flush_thread_local_collector(family as u32, &mut collector)
                        {
                            *shared_error.lock() = Err(err);
                        }
                        this.dispose_thread_local_collector(collector);
                    });
                }
            }
        });

        let _span = tracing::trace_span!("flush collectors").entered();

        // Now we reduce the global collectors in parallel
        let mut new_sst_files = take(self.new_sst_files.get_mut());
        let shared_new_sst_files = Mutex::new(&mut new_sst_files);

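        // Swap in fresh collectors so the WriteBatch can be reused afterwards; the collectors
        // taken out here are drained into SST files below.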
        let new_collectors = [(); FAMILIES]
            .map(|_| Mutex::new(GlobalCollectorState::Unsharded(self.get_new_collector())));
        let collectors = replace(&mut self.collectors, new_collectors);
        let span = Span::current();
        collectors
            .into_par_iter()
            .enumerate()
            .flat_map(|(family, state)| {
                let collector = state.into_inner();
                match collector {
                    GlobalCollectorState::Unsharded(collector) => {
                        Either::Left([(family, collector)].into_par_iter())
                    }
                    GlobalCollectorState::Sharded(shards) => Either::Right(
                        shards
                            .into_par_iter()
                            .map(move |collector| (family, collector)),
                    ),
                }
            })
            .try_for_each(|(family, mut collector)| {
                let _span = span.clone().entered();
                let family = family as u32;
                if !collector.is_empty() {
                    let sst = self.create_sst_file(family, collector.sorted())?;
                    collector.clear();
                    self.dispose_collector(collector);
                    shared_new_sst_files.lock().push(sst);
                }
                anyhow::Ok(())
            })?;

        shared_error.into_inner()?;

        // Now we need to write the new meta files.
        let new_meta_collectors = [(); FAMILIES].map(|_| Mutex::new(Vec::new()));
        let meta_collectors = replace(&mut self.meta_collectors, new_meta_collectors);
        let keys_written = AtomicU64::new(0);
        let new_meta_files = meta_collectors
            .into_par_iter()
            .map(|mutex| mutex.into_inner())
            .enumerate()
            .filter(|(_, sst_files)| !sst_files.is_empty())
            .map(|(family, sst_files)| {
                let family = family as u32;
                let mut entries = 0;
                let mut builder = MetaFileBuilder::new(family);
                for (seq, sst) in sst_files {
                    entries += sst.entries;
                    builder.add(seq, sst);
                }
                keys_written.fetch_add(entries, Ordering::Relaxed);
                let seq = self.current_sequence_number.fetch_add(1, Ordering::SeqCst) + 1;
                let file = builder.write(&self.db_path, seq)?;
                Ok((seq, file))
            })
            .collect::<Result<Vec<_>>>()?;

        // Finally we return the new files and sequence number.
        let seq = self.current_sequence_number.load(Ordering::SeqCst);
        Ok(FinishResult {
            sequence_number: seq,
            new_meta_files,
            new_sst_files,
            new_blob_files,
            keys_written: keys_written.into_inner(),
        })
    }

    /// Creates a new blob file with the given value.
    /// Returns a tuple of (sequence number, file).
    #[tracing::instrument(level = "trace", skip(self, value), fields(value_len = value.len()))]
    fn create_blob(&self, value: &[u8]) -> Result<(u32, File)> {
        let seq = self.current_sequence_number.fetch_add(1, Ordering::SeqCst) + 1;
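        // Blob file layout: the uncompressed value length as a big-endian u32, followed by the
        // LZ4-compressed value bytes.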
        let mut buffer = Vec::new();
        buffer.write_u32::<BE>(value.len() as u32)?;
        lz4::compress_to_vec(value, &mut buffer, ACC_LEVEL_DEFAULT)
            .context("Compression of value for blob file failed")?;

        let file = self.db_path.join(format!("{seq:08}.blob"));
        let mut file = File::create(&file).context("Unable to create blob file")?;
        file.write_all(&buffer)
            .context("Unable to write blob file")?;
        file.flush().context("Unable to flush blob file")?;
        Ok((seq, file))
    }

    /// Creates a new SST file with the given collector data.
    /// Returns a tuple of (sequence number, file).
    #[tracing::instrument(level = "trace", skip(self, collector_data))]
    fn create_sst_file(
        &self,
        family: u32,
        collector_data: (&[CollectorEntry<K>], usize, usize),
    ) -> Result<(u32, File)> {
        let (entries, total_key_size, total_value_size) = collector_data;
        let seq = self.current_sequence_number.fetch_add(1, Ordering::SeqCst) + 1;

        let builder = StaticSortedFileBuilder::new(entries, total_key_size, total_value_size)?;

        let path = self.db_path.join(format!("{seq:08}.sst"));
        let (meta, file) = builder
            .write(&path)
            .with_context(|| format!("Unable to write SST file {seq:08}.sst"))?;

        #[cfg(feature = "verify_sst_content")]
        {
            use core::panic;

            use crate::{
                collector_entry::CollectorEntryValue,
                key::hash_key,
                lookup_entry::LookupValue,
                static_sorted_file::{
                    BlockCache, SstLookupResult, StaticSortedFile, StaticSortedFileMetaData,
                },
                static_sorted_file_builder::Entry,
            };

            file.sync_all()?;
            let sst = StaticSortedFile::open(
                &self.db_path,
                StaticSortedFileMetaData {
                    sequence_number: seq,
                    key_compression_dictionary_length: meta.key_compression_dictionary_length,
                    value_compression_dictionary_length: meta.value_compression_dictionary_length,
                    block_count: meta.block_count,
                },
            )?;
            let cache2 = BlockCache::with(
                10,
                u64::MAX,
                Default::default(),
                Default::default(),
                Default::default(),
            );
            let cache3 = BlockCache::with(
                10,
                u64::MAX,
                Default::default(),
                Default::default(),
                Default::default(),
            );
            let mut key_buf = Vec::new();
            for entry in entries {
                entry.write_key_to(&mut key_buf);
                let result = sst
                    .lookup(hash_key(&key_buf), &key_buf, &cache2, &cache3)
                    .expect("key found");
                key_buf.clear();
                match result {
                    SstLookupResult::Found(LookupValue::Deleted) => {}
                    SstLookupResult::Found(LookupValue::Slice {
                        value: lookup_value,
                    }) => {
                        let expected_value_slice = match &entry.value {
                            CollectorEntryValue::Small { value } => &**value,
                            CollectorEntryValue::Medium { value } => &**value,
                            _ => panic!("Unexpected value"),
                        };
                        assert_eq!(*lookup_value, *expected_value_slice);
                    }
                    SstLookupResult::Found(LookupValue::Blob { sequence_number: _ }) => {}
                    SstLookupResult::NotFound => panic!("All keys must exist"),
                }
            }
        }

        self.meta_collectors[usize_from_u32(family)]
            .lock()
            .push((seq, meta));

        Ok((seq, file))
    }
}

#[inline(always)]
const fn usize_from_u32(value: u32) -> usize {
    // This should always be true, as we assume at least a 32-bit width architecture for Turbopack.
    // Since this is a const expression, we expect it to be compiled away.
    const {
        assert!(u32::BITS <= usize::BITS);
    };
    value as usize
}