1use std::{
2 cell::SyncUnsafeCell,
3 fs::File,
4 io::Write,
5 mem::{replace, take},
6 path::PathBuf,
7 sync::atomic::{AtomicU32, AtomicU64, Ordering},
8};
9
10use anyhow::{Context, Result};
11use byteorder::{BE, WriteBytesExt};
12use lzzzz::lz4::{self, ACC_LEVEL_DEFAULT};
13use parking_lot::Mutex;
14use rayon::{
15 iter::{Either, IndexedParallelIterator, IntoParallelIterator, ParallelIterator},
16 scope,
17};
18use smallvec::SmallVec;
19use thread_local::ThreadLocal;
20use tracing::Span;
21
22use crate::{
23 ValueBuffer,
24 collector::Collector,
25 collector_entry::CollectorEntry,
26 constants::{MAX_MEDIUM_VALUE_SIZE, THREAD_LOCAL_SIZE_SHIFT},
27 key::StoreKey,
28 meta_file_builder::MetaFileBuilder,
29 static_sorted_file_builder::{StaticSortedFileBuilder, StaticSortedFileBuilderMeta},
30};
31
/// Per-thread scratch state of a [`WriteBatch`]: entries are buffered in
/// thread-local collectors first to avoid contention on the shared,
/// per-family global collectors.
struct ThreadLocalState<K: StoreKey + Send, const FAMILIES: usize> {
    /// One optional collector per key family, lazily created on first use
    /// (reusing an idle collector from the pool when available).
    collectors: [Option<Collector<K, THREAD_LOCAL_SIZE_SHIFT>>; FAMILIES],
    /// Blob files created by this thread as `(sequence number, file)` pairs;
    /// collected into the final result by `finish`.
    new_blob_files: Vec<(u32, File)>,
}
44
/// Number of shards a global collector is split into once it overflows.
const COLLECTOR_SHARDS: usize = 4;
/// Right-shift applied to a key's 64-bit hash to pick a shard: keeps the top
/// `log2(COLLECTOR_SHARDS)` bits, yielding an index in `0..COLLECTOR_SHARDS`.
const COLLECTOR_SHARD_SHIFT: usize =
    u64::BITS as usize - COLLECTOR_SHARDS.trailing_zeros() as usize;
48
/// Files and statistics produced by [`WriteBatch::finish`].
pub(crate) struct FinishResult {
    /// The highest sequence number handed out by this batch.
    pub(crate) sequence_number: u32,
    /// Newly written meta files as `(sequence number, file)` pairs.
    pub(crate) new_meta_files: Vec<(u32, File)>,
    /// Newly written static sorted files as `(sequence number, file)` pairs.
    pub(crate) new_sst_files: Vec<(u32, File)>,
    /// Newly written blob files as `(sequence number, file)` pairs.
    pub(crate) new_blob_files: Vec<(u32, File)>,
    /// Total number of entries written across all families.
    pub(crate) keys_written: u64,
}
61
/// State of a per-family global collector. It starts as a single collector
/// and is split into `COLLECTOR_SHARDS` collectors (partitioned by the high
/// bits of the key hash) once the unsharded collector fills up.
enum GlobalCollectorState<K: StoreKey + Send> {
    Unsharded(Collector<K>),
    Sharded([Collector<K>; COLLECTOR_SHARDS]),
}
69
/// A batch of writes (puts/deletes) across `FAMILIES` key families that is
/// accumulated in memory and materialized into SST, blob and meta files on
/// `finish`.
pub struct WriteBatch<K: StoreKey + Send, const FAMILIES: usize> {
    /// Directory in which all produced files are created.
    db_path: PathBuf,
    /// Monotonically increasing counter used to name new files.
    current_sequence_number: AtomicU32,
    /// Per-thread buffering state; accessed mutably through `SyncUnsafeCell`
    /// (each thread only touches its own state outside of `flush`/`finish`).
    thread_locals: ThreadLocal<SyncUnsafeCell<ThreadLocalState<K, FAMILIES>>>,
    /// Shared per-family collectors that thread-local collectors drain into.
    collectors: [Mutex<GlobalCollectorState<K>>; FAMILIES],
    /// Per-family metadata of written SST files, consumed by `finish` to
    /// build the meta files.
    meta_collectors: [Mutex<Vec<(u32, StaticSortedFileBuilderMeta<'static>)>>; FAMILIES],
    /// SST files written so far as `(sequence number, file)` pairs.
    new_sst_files: Mutex<Vec<(u32, File)>>,
    /// Pool of drained global collectors kept for reuse.
    idle_collectors: Mutex<Vec<Collector<K>>>,
    /// Pool of drained thread-local collectors kept for reuse.
    idle_thread_local_collectors: Mutex<Vec<Collector<K, THREAD_LOCAL_SIZE_SHIFT>>>,
}
90
91impl<K: StoreKey + Send + Sync, const FAMILIES: usize> WriteBatch<K, FAMILIES> {
92 pub(crate) fn new(path: PathBuf, current: u32) -> Self {
94 const {
95 assert!(FAMILIES <= usize_from_u32(u32::MAX));
96 };
97 Self {
98 db_path: path,
99 current_sequence_number: AtomicU32::new(current),
100 thread_locals: ThreadLocal::new(),
101 collectors: [(); FAMILIES]
102 .map(|_| Mutex::new(GlobalCollectorState::Unsharded(Collector::new()))),
103 meta_collectors: [(); FAMILIES].map(|_| Mutex::new(Vec::new())),
104 new_sst_files: Mutex::new(Vec::new()),
105 idle_collectors: Mutex::new(Vec::new()),
106 idle_thread_local_collectors: Mutex::new(Vec::new()),
107 }
108 }
109
110 pub(crate) fn reset(&mut self, current: u32) {
113 self.current_sequence_number
114 .store(current, Ordering::SeqCst);
115 }
116
    /// Returns the calling thread's mutable [`ThreadLocalState`], creating it
    /// on first access.
    // Returning `&mut` from `&self` is intentional: the state lives in a
    // `SyncUnsafeCell` inside a `ThreadLocal`, so each thread gets its own
    // instance.
    #[allow(clippy::mut_from_ref)]
    fn thread_local_state(&self) -> &mut ThreadLocalState<K, FAMILIES> {
        let cell = self.thread_locals.get_or(|| {
            SyncUnsafeCell::new(ThreadLocalState {
                collectors: [const { None }; FAMILIES],
                new_blob_files: Vec::new(),
            })
        });
        // SAFETY: outside of `flush`/`finish` only the owning thread reaches
        // its cell (via `get_or`), so no aliasing `&mut` exists. `flush` is
        // `unsafe` and `finish` takes `&mut self`, which pushes the exclusivity
        // requirement onto those callers.
        unsafe { &mut *cell.get() }
    }
129
130 fn thread_local_collector_mut<'l>(
132 &self,
133 state: &'l mut ThreadLocalState<K, FAMILIES>,
134 family: u32,
135 ) -> Result<&'l mut Collector<K, THREAD_LOCAL_SIZE_SHIFT>> {
136 debug_assert!(usize_from_u32(family) < FAMILIES);
137 let collector = state.collectors[usize_from_u32(family)].get_or_insert_with(|| {
138 self.idle_thread_local_collectors
139 .lock()
140 .pop()
141 .unwrap_or_else(|| Collector::new())
142 });
143 if collector.is_full() {
144 self.flush_thread_local_collector(family, collector)?;
145 }
146 Ok(collector)
147 }
148
149 #[tracing::instrument(level = "trace", skip(self, collector))]
150 fn flush_thread_local_collector(
151 &self,
152 family: u32,
153 collector: &mut Collector<K, THREAD_LOCAL_SIZE_SHIFT>,
154 ) -> Result<()> {
155 let mut full_collectors = SmallVec::<[_; 2]>::new();
156 {
157 let mut global_collector_state = self.collectors[usize_from_u32(family)].lock();
158 for entry in collector.drain() {
159 match &mut *global_collector_state {
160 GlobalCollectorState::Unsharded(collector) => {
161 collector.add_entry(entry);
162 if collector.is_full() {
163 let mut shards: [Collector<K>; 4] =
165 [(); COLLECTOR_SHARDS].map(|_| Collector::new());
166 for entry in collector.drain() {
167 let shard = (entry.key.hash >> COLLECTOR_SHARD_SHIFT) as usize;
168 shards[shard].add_entry(entry);
169 }
170 for collector in shards.iter_mut() {
173 if collector.is_full() {
174 full_collectors
175 .push(replace(&mut *collector, self.get_new_collector()));
176 }
177 }
178 *global_collector_state = GlobalCollectorState::Sharded(shards);
179 }
180 }
181 GlobalCollectorState::Sharded(shards) => {
182 let shard = (entry.key.hash >> COLLECTOR_SHARD_SHIFT) as usize;
183 let collector = &mut shards[shard];
184 collector.add_entry(entry);
185 if collector.is_full() {
186 full_collectors
187 .push(replace(&mut *collector, self.get_new_collector()));
188 }
189 }
190 }
191 }
192 }
193 for mut global_collector in full_collectors {
194 let sst = self.create_sst_file(family, global_collector.sorted())?;
196 global_collector.clear();
197 self.new_sst_files.lock().push(sst);
198 self.dispose_collector(global_collector);
199 }
200 Ok(())
201 }
202
203 fn get_new_collector(&self) -> Collector<K> {
204 self.idle_collectors
205 .lock()
206 .pop()
207 .unwrap_or_else(|| Collector::new())
208 }
209
210 fn dispose_collector(&self, collector: Collector<K>) {
211 self.idle_collectors.lock().push(collector);
212 }
213
214 fn dispose_thread_local_collector(&self, collector: Collector<K, THREAD_LOCAL_SIZE_SHIFT>) {
215 self.idle_thread_local_collectors.lock().push(collector);
216 }
217
218 pub fn put(&self, family: u32, key: K, value: ValueBuffer<'_>) -> Result<()> {
220 let state = self.thread_local_state();
221 let collector = self.thread_local_collector_mut(state, family)?;
222 if value.len() <= MAX_MEDIUM_VALUE_SIZE {
223 collector.put(key, value);
224 } else {
225 let (blob, file) = self.create_blob(&value)?;
226 collector.put_blob(key, blob);
227 state.new_blob_files.push((blob, file));
228 }
229 Ok(())
230 }
231
232 pub fn delete(&self, family: u32, key: K) -> Result<()> {
234 let state = self.thread_local_state();
235 let collector = self.thread_local_collector_mut(state, family)?;
236 collector.delete(key);
237 Ok(())
238 }
239
    /// Flushes all buffered entries of `family` — thread-local and global —
    /// into SST files.
    ///
    /// # Safety
    ///
    /// This iterates over *all* thread-local states and mutates them through
    /// their `SyncUnsafeCell`s. The caller must guarantee that no other
    /// thread is concurrently writing to this batch, otherwise the `&mut`
    /// access below would alias another thread's live reference.
    #[tracing::instrument(level = "trace", skip(self))]
    pub unsafe fn flush(&self, family: u32) -> Result<()> {
        // Steal every non-empty thread-local collector of this family.
        let mut collectors = Vec::new();
        for cell in self.thread_locals.iter() {
            // SAFETY: exclusivity is guaranteed by the caller (see `# Safety`).
            let state = unsafe { &mut *cell.get() };
            if let Some(collector) = state.collectors[usize_from_u32(family)].take()
                && !collector.is_empty()
            {
                collectors.push(collector);
            }
        }

        // Merge the stolen collectors into the global state in parallel,
        // carrying the current tracing span into the rayon workers.
        let span = Span::current();
        collectors.into_par_iter().try_for_each(|mut collector| {
            let _span = span.clone().entered();
            self.flush_thread_local_collector(family, &mut collector)?;
            self.dispose_thread_local_collector(collector);
            anyhow::Ok(())
        })?;

        // Write out whatever remains in the global collector(s). A sharded
        // state is replaced by a fresh unsharded collector and its shards are
        // written in parallel.
        let mut collector_state = self.collectors[usize_from_u32(family)].lock();
        match &mut *collector_state {
            GlobalCollectorState::Unsharded(collector) => {
                if !collector.is_empty() {
                    let sst = self.create_sst_file(family, collector.sorted())?;
                    collector.clear();
                    self.new_sst_files.lock().push(sst);
                }
            }
            GlobalCollectorState::Sharded(_) => {
                // Take ownership of the shards by swapping in a fresh
                // unsharded collector.
                let GlobalCollectorState::Sharded(shards) = replace(
                    &mut *collector_state,
                    GlobalCollectorState::Unsharded(self.get_new_collector()),
                ) else {
                    unreachable!();
                };
                shards.into_par_iter().try_for_each(|mut collector| {
                    let _span = span.clone().entered();
                    if !collector.is_empty() {
                        let sst = self.create_sst_file(family, collector.sorted())?;
                        collector.clear();
                        self.new_sst_files.lock().push(sst);
                        self.dispose_collector(collector);
                    }
                    anyhow::Ok(())
                })?;
            }
        }

        Ok(())
    }
300
    /// Finalizes the batch: drains every thread-local and global collector,
    /// writes all remaining SST files, builds one meta file per non-empty
    /// family, and returns the produced files plus statistics.
    ///
    /// Takes `&mut self`, which proves exclusive access to all thread-local
    /// states and allows the batch to be reused afterwards (see `reset`).
    #[tracing::instrument(level = "trace", skip(self))]
    pub(crate) fn finish(&mut self) -> Result<FinishResult> {
        let mut new_blob_files = Vec::new();
        // Records the most recent error of the spawned flush tasks; checked
        // after the parallel phases complete.
        let shared_error = Mutex::new(Ok(()));

        // Phase 1: collect blob files and non-empty thread-local collectors
        // from every thread, then flush those collectors on a rayon scope.
        scope(|scope| {
            let _span = tracing::trace_span!("flush thread local collectors").entered();
            let mut collectors = [const { Vec::new() }; FAMILIES];
            for cell in self.thread_locals.iter_mut() {
                // `iter_mut` on the ThreadLocal gives safe exclusive access.
                let state = cell.get_mut();
                new_blob_files.append(&mut state.new_blob_files);
                for (family, thread_local_collector) in state.collectors.iter_mut().enumerate() {
                    if let Some(collector) = thread_local_collector.take()
                        && !collector.is_empty()
                    {
                        collectors[family].push(collector);
                    }
                }
            }
            for (family, thread_local_collectors) in collectors.into_iter().enumerate() {
                for mut collector in thread_local_collectors {
                    let this = &self;
                    let shared_error = &shared_error;
                    // Propagate the current tracing span into the task.
                    let span = Span::current();
                    scope.spawn(move |_| {
                        let _span = span.entered();
                        if let Err(err) =
                            this.flush_thread_local_collector(family as u32, &mut collector)
                        {
                            *shared_error.lock() = Err(err);
                        }
                        this.dispose_thread_local_collector(collector);
                    });
                }
            }
        });

        let _span = tracing::trace_span!("flush collectors").entered();

        // Phase 2: swap out the global collectors and write their remaining
        // entries to SST files in parallel.
        let mut new_sst_files = take(self.new_sst_files.get_mut());
        let shared_new_sst_files = Mutex::new(&mut new_sst_files);

        let new_collectors = [(); FAMILIES]
            .map(|_| Mutex::new(GlobalCollectorState::Unsharded(self.get_new_collector())));
        let collectors = replace(&mut self.collectors, new_collectors);
        let span = Span::current();
        collectors
            .into_par_iter()
            .enumerate()
            .flat_map(|(family, state)| {
                // Flatten sharded states into (family, collector) pairs so
                // all collectors are processed uniformly.
                let collector = state.into_inner();
                match collector {
                    GlobalCollectorState::Unsharded(collector) => {
                        Either::Left([(family, collector)].into_par_iter())
                    }
                    GlobalCollectorState::Sharded(shards) => Either::Right(
                        shards
                            .into_par_iter()
                            .map(move |collector| (family, collector)),
                    ),
                }
            })
            .try_for_each(|(family, mut collector)| {
                let _span = span.clone().entered();
                let family = family as u32;
                if !collector.is_empty() {
                    let sst = self.create_sst_file(family, collector.sorted())?;
                    collector.clear();
                    self.dispose_collector(collector);
                    shared_new_sst_files.lock().push(sst);
                }
                anyhow::Ok(())
            })?;

        // Surface any error from the phase-1 flush tasks.
        shared_error.into_inner()?;

        // Phase 3: build one meta file per non-empty family from the
        // accumulated SST metadata, counting the written entries.
        let new_meta_collectors = [(); FAMILIES].map(|_| Mutex::new(Vec::new()));
        let meta_collectors = replace(&mut self.meta_collectors, new_meta_collectors);
        let keys_written = AtomicU64::new(0);
        let new_meta_files = meta_collectors
            .into_par_iter()
            .map(|mutex| mutex.into_inner())
            .enumerate()
            .filter(|(_, sst_files)| !sst_files.is_empty())
            .map(|(family, sst_files)| {
                let family = family as u32;
                let mut entries = 0;
                let mut builder = MetaFileBuilder::new(family);
                for (seq, sst) in sst_files {
                    entries += sst.entries;
                    builder.add(seq, sst);
                }
                keys_written.fetch_add(entries, Ordering::Relaxed);
                // Reserve a fresh sequence number for the meta file itself.
                let seq = self.current_sequence_number.fetch_add(1, Ordering::SeqCst) + 1;
                let file = builder.write(&self.db_path, seq)?;
                Ok((seq, file))
            })
            .collect::<Result<Vec<_>>>()?;

        let seq = self.current_sequence_number.load(Ordering::SeqCst);
        Ok(FinishResult {
            sequence_number: seq,
            new_meta_files,
            new_sst_files,
            new_blob_files,
            keys_written: keys_written.into_inner(),
        })
    }
415
416 #[tracing::instrument(level = "trace", skip(self, value), fields(value_len = value.len()))]
419 fn create_blob(&self, value: &[u8]) -> Result<(u32, File)> {
420 let seq = self.current_sequence_number.fetch_add(1, Ordering::SeqCst) + 1;
421 let mut buffer = Vec::new();
422 buffer.write_u32::<BE>(value.len() as u32)?;
423 lz4::compress_to_vec(value, &mut buffer, ACC_LEVEL_DEFAULT)
424 .context("Compression of value for blob file failed")?;
425
426 let file = self.db_path.join(format!("{seq:08}.blob"));
427 let mut file = File::create(&file).context("Unable to create blob file")?;
428 file.write_all(&buffer)
429 .context("Unable to write blob file")?;
430 file.flush().context("Unable to flush blob file")?;
431 Ok((seq, file))
432 }
433
    /// Writes the given sorted collector data to a new `{seq:08}.sst` file
    /// and records its metadata for the family's meta file.
    ///
    /// `collector_data` is `(entries, total_key_size, total_value_size)` as
    /// produced by `Collector::sorted`.
    #[tracing::instrument(level = "trace", skip(self, collector_data))]
    fn create_sst_file(
        &self,
        family: u32,
        collector_data: (&[CollectorEntry<K>], usize, usize),
    ) -> Result<(u32, File)> {
        let (entries, total_key_size, total_value_size) = collector_data;
        // Reserve a fresh sequence number for the SST file.
        let seq = self.current_sequence_number.fetch_add(1, Ordering::SeqCst) + 1;

        let builder = StaticSortedFileBuilder::new(entries, total_key_size, total_value_size)?;

        let path = self.db_path.join(format!("{seq:08}.sst"));
        let (meta, file) = builder
            .write(&path)
            .with_context(|| format!("Unable to write SST file {seq:08}.sst"))?;

        // Debug-only verification: re-open the freshly written file and check
        // that every entry can be looked up and matches its source value.
        #[cfg(feature = "verify_sst_content")]
        {
            use core::panic;

            use crate::{
                collector_entry::CollectorEntryValue,
                key::hash_key,
                lookup_entry::LookupValue,
                static_sorted_file::{
                    BlockCache, SstLookupResult, StaticSortedFile, StaticSortedFileMetaData,
                },
                static_sorted_file_builder::Entry,
            };

            file.sync_all()?;
            let sst = StaticSortedFile::open(
                &self.db_path,
                StaticSortedFileMetaData {
                    sequence_number: seq,
                    key_compression_dictionary_length: meta.key_compression_dictionary_length,
                    value_compression_dictionary_length: meta.value_compression_dictionary_length,
                    block_count: meta.block_count,
                },
            )?;
            let cache2 = BlockCache::with(
                10,
                u64::MAX,
                Default::default(),
                Default::default(),
                Default::default(),
            );
            let cache3 = BlockCache::with(
                10,
                u64::MAX,
                Default::default(),
                Default::default(),
                Default::default(),
            );
            let mut key_buf = Vec::new();
            for entry in entries {
                entry.write_key_to(&mut key_buf);
                let result = sst
                    .lookup(hash_key(&key_buf), &key_buf, &cache2, &cache3)
                    .expect("key found");
                key_buf.clear();
                match result {
                    SstLookupResult::Found(LookupValue::Deleted) => {}
                    SstLookupResult::Found(LookupValue::Slice {
                        value: lookup_value,
                    }) => {
                        let expected_value_slice = match &entry.value {
                            CollectorEntryValue::Small { value } => &**value,
                            CollectorEntryValue::Medium { value } => &**value,
                            _ => panic!("Unexpected value"),
                        };
                        assert_eq!(*lookup_value, *expected_value_slice);
                    }
                    SstLookupResult::Found(LookupValue::Blob { sequence_number: _ }) => {}
                    SstLookupResult::NotFound => panic!("All keys must exist"),
                }
            }
        }

        // Record the SST metadata so `finish` can build the family's meta
        // file.
        self.meta_collectors[usize_from_u32(family)]
            .lock()
            .push((seq, meta));

        Ok((seq, file))
    }
521}
522
/// Losslessly converts a `u32` into a `usize`.
///
/// The const assertion proves at compile time that `usize` is at least 32
/// bits wide, so the `as` cast can never truncate. Using `<=` (rather than
/// a strict `<`) correctly admits 32-bit targets, where the conversion is
/// still lossless.
#[inline(always)]
const fn usize_from_u32(value: u32) -> usize {
    const {
        // `usize` must be able to represent every `u32` value.
        assert!(u32::BITS <= usize::BITS);
    };
    value as usize
}