1use std::{
2 cell::SyncUnsafeCell,
3 fs::File,
4 io::Write,
5 mem::{replace, take},
6 path::PathBuf,
7 sync::atomic::{AtomicU32, AtomicU64, Ordering},
8};
9
10use anyhow::{Context, Result};
11use byteorder::{BE, WriteBytesExt};
12use either::Either;
13use parking_lot::Mutex;
14use smallvec::SmallVec;
15use thread_local::ThreadLocal;
16
17use crate::{
18 FamilyConfig, ValueBuffer,
19 collector::Collector,
20 collector_entry::CollectorEntry,
21 compression::{checksum_block, compress_into_buffer},
22 constants::{MAX_MEDIUM_VALUE_SIZE, THREAD_LOCAL_SIZE_SHIFT},
23 db::WriteOperationGuard,
24 key::StoreKey,
25 meta_file::MetaEntryFlags,
26 meta_file_builder::MetaFileBuilder,
27 parallel_scheduler::ParallelScheduler,
28 static_sorted_file_builder::{StaticSortedFileBuilderMeta, write_static_stored_file},
29};
30
/// Per-thread scratch state of a write batch, reached without locking through
/// a `ThreadLocal<SyncUnsafeCell<..>>` slot owned by the writing thread.
struct ThreadLocalState<K: StoreKey + Send, const FAMILIES: usize> {
    // One small (THREAD_LOCAL_SIZE_SHIFT-sized) collector per family,
    // created lazily on first put/delete for that family.
    collectors: [Option<Collector<K, THREAD_LOCAL_SIZE_SHIFT>>; FAMILIES],
    // Blob files written by this thread as (sequence number, file handle);
    // drained into the batch result by `finish`.
    new_blob_files: Vec<(u32, File)>,
}
43
/// Number of shards the global collector is split into once it overflows.
/// Must be a power of two, since `COLLECTOR_SHARD_SHIFT` derives the shard
/// index from `trailing_zeros`.
const COLLECTOR_SHARDS: usize = 4;
/// Right-shift applied to a key's 64-bit hash to pick a shard, i.e. the shard
/// index is the top `log2(COLLECTOR_SHARDS)` bits of the hash.
const COLLECTOR_SHARD_SHIFT: usize =
    u64::BITS as usize - COLLECTOR_SHARDS.trailing_zeros() as usize;
47
/// Files and statistics produced when a write batch is finished.
pub(crate) struct FinishResult {
    // Highest sequence number handed out by this batch.
    pub(crate) sequence_number: u32,
    // Newly written meta files as (sequence number, file handle).
    pub(crate) new_meta_files: Vec<(u32, File)>,
    // Newly written SST files as (sequence number, file handle).
    pub(crate) new_sst_files: Vec<(u32, File)>,
    // Newly written blob files as (sequence number, file handle).
    pub(crate) new_blob_files: Vec<(u32, File)>,
    // Total number of entries written across all meta files.
    pub(crate) keys_written: u64,
}
60
/// State of a family's global (cross-thread) collector.
enum GlobalCollectorState<K: StoreKey + Send> {
    /// Initial state: a single collector receives all entries.
    Unsharded(Collector<K>),
    /// After the unsharded collector filled up once, entries are spread over
    /// shards selected by the top bits of the key hash.
    Sharded([Collector<K>; COLLECTOR_SHARDS]),
}
68
/// A batch of puts/deletes that is buffered in memory (thread-locally first,
/// then in per-family global collectors) and turned into SST, blob and meta
/// files when finished.
pub struct WriteBatch<'db, K: StoreKey + Send, S: ParallelScheduler, const FAMILIES: usize> {
    // Guard marking a write operation as in progress on the database.
    _write_guard: WriteOperationGuard<'db>,
    // Scheduler used to parallelize flushing and file writing.
    parallel_scheduler: S,
    // Directory all database files are written into.
    db_path: PathBuf,
    // Per-family configuration (name, kind).
    #[cfg_attr(not(feature = "verify_sst_content"), allow(dead_code))]
    family_configs: [FamilyConfig; FAMILIES],
    // Last sequence number handed out; incremented for every new file.
    current_sequence_number: AtomicU32,
    // Per-thread buffers; see `ThreadLocalState`.
    thread_locals: ThreadLocal<SyncUnsafeCell<ThreadLocalState<K, FAMILIES>>>,
    // Global collector per family, fed by full thread-local collectors.
    collectors: [Mutex<GlobalCollectorState<K>>; FAMILIES],
    // Metadata of finished SST files per family, consumed by `finish`.
    meta_collectors: [Mutex<Vec<(u32, StaticSortedFileBuilderMeta<'static>)>>; FAMILIES],
    // Finished SST files as (sequence number, file handle).
    new_sst_files: Mutex<Vec<(u32, File)>>,
}
92
93impl<'db, K: StoreKey + Send + Sync, S: ParallelScheduler, const FAMILIES: usize>
94 WriteBatch<'db, K, S, FAMILIES>
95{
    /// Creates a new write batch rooted at `path`.
    ///
    /// `current` is the latest sequence number already used by the database;
    /// all files created by this batch allocate sequence numbers after it.
    pub(crate) fn new(
        write_guard: WriteOperationGuard<'db>,
        path: PathBuf,
        current: u32,
        parallel_scheduler: S,
        family_configs: [FamilyConfig; FAMILIES],
    ) -> Self {
        // Compile-time check that every family index fits into a u32.
        const {
            assert!(FAMILIES <= usize_from_u32(u32::MAX));
        };
        Self {
            _write_guard: write_guard,
            parallel_scheduler,
            db_path: path,
            family_configs,
            current_sequence_number: AtomicU32::new(current),
            thread_locals: ThreadLocal::new(),
            collectors: [(); FAMILIES]
                .map(|_| Mutex::new(GlobalCollectorState::Unsharded(Collector::new()))),
            meta_collectors: [(); FAMILIES].map(|_| Mutex::new(Vec::new())),
            new_sst_files: Mutex::new(Vec::new()),
        }
    }
120
    /// Marks the write operation as successfully committed on the underlying
    /// write guard.
    pub(crate) fn mark_succeeded(&mut self) {
        self._write_guard.success();
    }
128
    /// Returns mutable access to the calling thread's local state, creating it
    /// on first use.
    // `mut_from_ref` is allowed because the reference is derived from a
    // thread-local cell that only the current thread writes through here.
    #[allow(clippy::mut_from_ref)]
    fn thread_local_state(&self) -> &mut ThreadLocalState<K, FAMILIES> {
        let cell = self.thread_locals.get_or(|| {
            SyncUnsafeCell::new(ThreadLocalState {
                collectors: [const { None }; FAMILIES],
                new_blob_files: Vec::new(),
            })
        });
        // SAFETY: this cell belongs to the current thread's thread-local slot,
        // so no other thread creates a reference through this path.
        // NOTE(review): `flush` also dereferences these cells from another
        // thread; its `unsafe` contract must exclude concurrent writers —
        // confirm against the callers of `flush`.
        unsafe { &mut *cell.get() }
    }
141
142 fn thread_local_collector_mut<'l>(
144 &self,
145 state: &'l mut ThreadLocalState<K, FAMILIES>,
146 family: u32,
147 ) -> Result<&'l mut Collector<K, THREAD_LOCAL_SIZE_SHIFT>> {
148 debug_assert!(usize_from_u32(family) < FAMILIES);
149 let collector =
150 state.collectors[usize_from_u32(family)].get_or_insert_with(|| Collector::new());
151 if collector.is_full() {
152 self.flush_thread_local_collector(family, collector)?;
153 }
154 Ok(collector)
155 }
156
157 #[tracing::instrument(level = "trace", skip(self, collector), fields(family_name = self.family_configs[usize_from_u32(family)].name))]
158 fn flush_thread_local_collector(
159 &self,
160 family: u32,
161 collector: &mut Collector<K, THREAD_LOCAL_SIZE_SHIFT>,
162 ) -> Result<()> {
163 let mut full_collectors = SmallVec::<[_; 2]>::new();
164 {
165 let mut global_collector_state = self.collectors[usize_from_u32(family)].lock();
166 for entry in collector.drain() {
167 match &mut *global_collector_state {
168 GlobalCollectorState::Unsharded(collector) => {
169 collector.add_entry(entry);
170 if collector.is_full() {
171 let mut shards: [Collector<K>; 4] =
173 [(); COLLECTOR_SHARDS].map(|_| Collector::new());
174 for entry in collector.drain() {
175 let shard = (entry.key.hash >> COLLECTOR_SHARD_SHIFT) as usize;
176 shards[shard].add_entry(entry);
177 }
178 for collector in shards.iter_mut() {
181 if collector.is_full() {
182 full_collectors
183 .push(replace(&mut *collector, Collector::new()));
184 }
185 }
186 *global_collector_state = GlobalCollectorState::Sharded(shards);
187 }
188 }
189 GlobalCollectorState::Sharded(shards) => {
190 let shard = (entry.key.hash >> COLLECTOR_SHARD_SHIFT) as usize;
191 let collector = &mut shards[shard];
192 collector.add_entry(entry);
193 if collector.is_full() {
194 full_collectors.push(replace(&mut *collector, Collector::new()));
195 }
196 }
197 }
198 }
199 }
200 for mut global_collector in full_collectors {
219 let sst = self.create_sst_file(
221 family,
222 global_collector.sorted(self.family_configs[usize_from_u32(family)].kind),
223 )?;
224 self.new_sst_files.lock().push(sst);
225 drop(global_collector);
226 }
227 Ok(())
228 }
229
230 pub fn put(&self, family: u32, key: K, value: ValueBuffer<'_>) -> Result<()> {
232 let state = self.thread_local_state();
233 let collector = self.thread_local_collector_mut(state, family)?;
234 if value.len() <= MAX_MEDIUM_VALUE_SIZE {
235 collector.put(key, value);
236 } else {
237 let (blob, file) = self.create_blob(&value)?;
238 collector.put_blob(key, blob);
239 state.new_blob_files.push((blob, file));
240 }
241 Ok(())
242 }
243
    /// Records a deletion tombstone for `key` in `family`.
    pub fn delete(&self, family: u32, key: K) -> Result<()> {
        let state = self.thread_local_state();
        let collector = self.thread_local_collector_mut(state, family)?;
        collector.delete(key);
        Ok(())
    }
251
    /// Flushes all buffered entries of `family` to SST files: thread-local
    /// collectors are drained into the global collector, which is then
    /// written out (per shard, in parallel, if it has been sharded).
    ///
    /// # Safety
    ///
    /// This dereferences every thread's state through its `SyncUnsafeCell`.
    /// The caller must guarantee that no other thread concurrently mutates
    /// its thread-local state (i.e. no concurrent `put`/`delete`/`flush` on
    /// this batch) — NOTE(review): inferred from the unsafe cell access
    /// below; confirm against the callers.
    #[tracing::instrument(level = "trace", skip(self), fields(family_name = self.family_configs[usize_from_u32(family)].name))]
    pub unsafe fn flush(&self, family: u32) -> Result<()> {
        // Steal all non-empty thread-local collectors for this family.
        let mut collectors = Vec::new();
        for cell in self.thread_locals.iter() {
            // SAFETY: see the function-level safety contract.
            let state = unsafe { &mut *cell.get() };
            if let Some(collector) = state.collectors[usize_from_u32(family)].take()
                && !collector.is_empty()
            {
                collectors.push(collector);
            }
        }

        // Drain the stolen collectors into the global collector in parallel.
        self.parallel_scheduler
            .try_parallel_for_each_owned(collectors, |mut collector| {
                self.flush_thread_local_collector(family, &mut collector)?;
                drop(collector);
                anyhow::Ok(())
            })?;

        let mut collector_state = self.collectors[usize_from_u32(family)].lock();
        match &mut *collector_state {
            GlobalCollectorState::Unsharded(collector) => {
                if !collector.is_empty() {
                    let sst = self.create_sst_file(
                        family,
                        collector.sorted(self.family_configs[usize_from_u32(family)].kind),
                    )?;
                    collector.clear();
                    self.new_sst_files.lock().push(sst);
                }
            }
            GlobalCollectorState::Sharded(_) => {
                // Take the shards out of the mutex (resetting the state to a
                // fresh unsharded collector) so they can be written in
                // parallel without holding the lock borrow per shard.
                let GlobalCollectorState::Sharded(mut shards) = replace(
                    &mut *collector_state,
                    GlobalCollectorState::Unsharded(Collector::new()),
                ) else {
                    unreachable!();
                };
                self.parallel_scheduler
                    .try_parallel_for_each_mut(&mut shards, |collector| {
                        if !collector.is_empty() {
                            let sst = self.create_sst_file(
                                family,
                                collector.sorted(self.family_configs[usize_from_u32(family)].kind),
                            )?;
                            collector.clear();
                            self.new_sst_files.lock().push(sst);
                            collector.drop_contents();
                        }
                        anyhow::Ok(())
                    })?;
            }
        }

        Ok(())
    }
317
    /// Finishes the write batch: flushes every thread-local and global
    /// collector, writes one meta file per family that produced SST files,
    /// and returns all newly created files plus statistics.
    ///
    /// `get_accessed_key_hashes` supplies the AMQF filter of key hashes for a
    /// family; it is embedded in that family's meta file.
    #[tracing::instrument(level = "trace", skip_all)]
    pub(crate) fn finish(
        &mut self,
        get_accessed_key_hashes: impl Fn(u32) -> qfilter::Filter + Send + Sync,
    ) -> Result<FinishResult> {
        let mut new_blob_files = Vec::new();

        {
            let _span = tracing::trace_span!("flush thread local collectors").entered();
            // Gather the remaining non-empty thread-local collectors, grouped
            // by family, and collect all blob files the threads wrote.
            // `&mut self` gives exclusive access, so `get_mut` is safe here.
            let mut collectors = [const { Vec::new() }; FAMILIES];
            for cell in self.thread_locals.iter_mut() {
                let state = cell.get_mut();
                new_blob_files.append(&mut state.new_blob_files);
                for (family, thread_local_collector) in state.collectors.iter_mut().enumerate() {
                    if let Some(collector) = thread_local_collector.take()
                        && !collector.is_empty()
                    {
                        collectors[family].push(collector);
                    }
                }
            }
            let to_flush = collectors
                .into_iter()
                .enumerate()
                .flat_map(|(family, collector)| {
                    collector
                        .into_iter()
                        .map(move |collector| (family as u32, collector))
                })
                .collect::<Vec<_>>();
            self.parallel_scheduler.try_parallel_for_each_owned(
                to_flush,
                |(family, mut collector)| {
                    self.flush_thread_local_collector(family, &mut collector)?;
                    drop(collector);
                    anyhow::Ok(())
                },
            )?;
        }

        let _span = tracing::trace_span!("flush collectors").entered();

        let mut new_sst_files = take(self.new_sst_files.get_mut());
        let shared_new_sst_files = Mutex::new(&mut new_sst_files);

        // Swap the global collectors out and write every non-empty collector
        // (or shard) to an SST file in parallel.
        let new_collectors =
            [(); FAMILIES].map(|_| Mutex::new(GlobalCollectorState::Unsharded(Collector::new())));
        let collectors = replace(&mut self.collectors, new_collectors);
        let collectors = collectors
            .into_iter()
            .enumerate()
            .flat_map(|(family, state)| {
                let collector = state.into_inner();
                match collector {
                    GlobalCollectorState::Unsharded(collector) => {
                        Either::Left([(family, collector)].into_iter())
                    }
                    GlobalCollectorState::Sharded(shards) => {
                        Either::Right(shards.into_iter().map(move |collector| (family, collector)))
                    }
                }
            })
            .collect::<Vec<_>>();
        self.parallel_scheduler.try_parallel_for_each_owned(
            collectors,
            |(family, mut collector)| {
                let family = family as u32;
                if !collector.is_empty() {
                    let sst = self.create_sst_file(
                        family,
                        collector.sorted(self.family_configs[usize_from_u32(family)].kind),
                    )?;
                    collector.clear();
                    drop(collector);
                    shared_new_sst_files.lock().push(sst);
                }
                anyhow::Ok(())
            },
        )?;

        // Build one meta file per family that produced SST files, summing the
        // entry counts into the `keys_written` statistic.
        let new_meta_collectors = [(); FAMILIES].map(|_| Mutex::new(Vec::new()));
        let meta_collectors = replace(&mut self.meta_collectors, new_meta_collectors);
        let keys_written = AtomicU64::new(0);
        let file_to_write = meta_collectors
            .into_iter()
            .map(|mutex| mutex.into_inner())
            .enumerate()
            .filter(|(_, sst_files)| !sst_files.is_empty())
            .collect::<Vec<_>>();
        let new_meta_files = self
            .parallel_scheduler
            .parallel_map_collect_owned::<_, _, Result<Vec<_>>>(
                file_to_write,
                |(family, sst_files)| {
                    let family = family as u32;
                    let mut entries = 0;
                    let mut builder = MetaFileBuilder::new(family);
                    for (seq, sst) in sst_files {
                        entries += sst.entries;
                        builder.add(seq, sst);
                    }
                    keys_written.fetch_add(entries, Ordering::Relaxed);
                    let accessed_key_hashes = get_accessed_key_hashes(family);
                    builder.set_used_key_hashes_amqf(accessed_key_hashes);
                    // Allocate a fresh sequence number for the meta file.
                    let seq = self.current_sequence_number.fetch_add(1, Ordering::SeqCst) + 1;
                    let file = builder.write(&self.db_path, seq)?;
                    Ok((seq, file))
                },
            )?;

        // All files are written; the current counter is the batch's final
        // sequence number.
        let seq = self.current_sequence_number.load(Ordering::SeqCst);
        Ok(FinishResult {
            sequence_number: seq,
            new_meta_files,
            new_sst_files,
            new_blob_files,
            keys_written: keys_written.into_inner(),
        })
    }
443
444 #[tracing::instrument(level = "trace", skip(self, value), fields(value_len = value.len()))]
447 fn create_blob(&self, value: &[u8]) -> Result<(u32, File)> {
448 let seq = self.current_sequence_number.fetch_add(1, Ordering::SeqCst) + 1;
449 let mut compressed = Vec::new();
450 compress_into_buffer(value, &mut compressed)
451 .context("Compression of value for blob file failed")?;
452
453 let mut buffer = Vec::with_capacity(8 + compressed.len());
454 buffer.write_u32::<BE>(value.len() as u32)?;
455 buffer.write_u32::<BE>(checksum_block(&compressed))?;
456 buffer.extend_from_slice(&compressed);
457
458 let file = self.db_path.join(format!("{seq:08}.blob"));
459 let mut file = File::create(&file).context("Unable to create blob file")?;
460 file.write_all(&buffer)
461 .context("Unable to write blob file")?;
462 file.flush().context("Unable to flush blob file")?;
463 Ok((seq, file))
464 }
465
    /// Writes the sorted entries of `collector_data` for `family` into a new
    /// SST file, records its metadata for the family's meta file, and returns
    /// the file's sequence number and handle.
    #[tracing::instrument(level = "trace", skip(self, collector_data), fields(family_name = self.family_configs[usize_from_u32(family)].name))]
    fn create_sst_file(
        &self,
        family: u32,
        collector_data: (&[CollectorEntry<K>], usize),
    ) -> Result<(u32, File)> {
        let (entries, _total_key_size) = collector_data;
        let seq = self.current_sequence_number.fetch_add(1, Ordering::SeqCst) + 1;

        let path = self.db_path.join(format!("{seq:08}.sst"));
        let (meta, file) = self
            .parallel_scheduler
            .block_in_place(|| write_static_stored_file(entries, &path, MetaEntryFlags::FRESH))
            .with_context(|| format!("Unable to write SST file {seq:08}.sst"))?;

        // Optional self-check (feature "verify_sst_content"): re-open the file
        // just written and verify every entry can be looked up again with the
        // value that was stored.
        #[cfg(feature = "verify_sst_content")]
        {
            use core::panic;

            use crate::{
                collector_entry::CollectorEntryValue,
                key::hash_key,
                lookup_entry::LookupValue,
                static_sorted_file::{
                    BlockCache, SstLookupResult, StaticSortedFile, StaticSortedFileMetaData,
                },
                static_sorted_file_builder::Entry,
            };

            file.sync_all()?;
            let sst = StaticSortedFile::open(
                &self.db_path,
                StaticSortedFileMetaData {
                    sequence_number: seq,
                    block_count: meta.block_count,
                },
            )?;
            let cache2 = BlockCache::with(
                10,
                u64::MAX,
                Default::default(),
                Default::default(),
                Default::default(),
            );
            let cache3 = BlockCache::with(
                10,
                u64::MAX,
                Default::default(),
                Default::default(),
                Default::default(),
            );
            let mut key_buf = Vec::new();
            let family_config = self.family_configs[usize_from_u32(family)].kind;
            for entry in entries {
                entry.write_key_to(&mut key_buf);
                let result = sst
                    .lookup::<_, true>(hash_key(&key_buf), &key_buf, &cache2, &cache3)
                    .expect("key found");
                key_buf.clear();
                match result {
                    SstLookupResult::Found(values) => {
                        if values.len() > 1 {
                            use crate::FamilyKind;

                            assert!(
                                values.len() == 1 || family_config == FamilyKind::MultiValue,
                                "only multi-value tables can have more than one value, got {} \
                                 values",
                                values.len()
                            )
                        }
                        match &entry.value {
                            CollectorEntryValue::Large { blob } => {
                                assert!(
                                    values.contains(&LookupValue::Blob {
                                        sequence_number: *blob
                                    }),
                                    "we wrote a blob but did not read it"
                                );
                            }
                            CollectorEntryValue::Deleted => assert!(
                                values.first() == Some(&LookupValue::Deleted),
                                "we wrote a deleted tombstone but it was not first in results"
                            ),
                            v => {
                                assert!(
                                    values.into_iter().any(|lv| {
                                        if let LookupValue::Slice { value } = lv {
                                            &*value == v.as_bytes().unwrap()
                                        } else {
                                            false
                                        }
                                    }),
                                    "we wrote a slice of bytes but did not read it"
                                )
                            }
                        }
                    }
                    SstLookupResult::NotFound => panic!("All keys must exist"),
                }
            }
        }

        // Register the SST metadata so `finish` can build the meta file.
        self.meta_collectors[usize_from_u32(family)]
            .lock()
            .push((seq, meta));

        Ok((seq, file))
    }
577}
578
/// Losslessly converts a `u32` into a `usize`.
///
/// The const assertion proves at compile time that `usize` can represent
/// every `u32` value, making the `as` cast lossless. Using `<=` (rather than
/// the previous `<`) keeps this compiling on 32-bit targets such as wasm32,
/// where both types are 32 bits wide and the cast is still exact.
#[inline(always)]
const fn usize_from_u32(value: u32) -> usize {
    const {
        assert!(u32::BITS <= usize::BITS);
    };
    value as usize
}