1use std::{
2 cell::SyncUnsafeCell,
3 fs::File,
4 io::Write,
5 mem::{replace, take},
6 path::PathBuf,
7 sync::atomic::{AtomicU32, AtomicU64, Ordering},
8};
9
10use anyhow::{Context, Result};
11use byteorder::{BE, WriteBytesExt};
12use either::Either;
13use parking_lot::Mutex;
14use smallvec::SmallVec;
15use thread_local::ThreadLocal;
16
17use crate::{
18 FamilyConfig, ValueBuffer,
19 collector::Collector,
20 collector_entry::CollectorEntry,
21 compression::{checksum_block, compress_into_buffer},
22 constants::{MAX_MEDIUM_VALUE_SIZE, THREAD_LOCAL_SIZE_SHIFT},
23 key::StoreKey,
24 meta_file::MetaEntryFlags,
25 meta_file_builder::MetaFileBuilder,
26 parallel_scheduler::ParallelScheduler,
27 static_sorted_file_builder::{StaticSortedFileBuilderMeta, write_static_stored_file},
28};
29
/// Per-thread buffering state for a [`WriteBatch`], reached through a
/// `ThreadLocal<SyncUnsafeCell<..>>` so the hot write path needs no locking.
struct ThreadLocalState<K: StoreKey + Send, const FAMILIES: usize> {
    /// One small collector per family, lazily created on first use and
    /// spilled into the global collectors when full (see
    /// `thread_local_collector_mut`).
    collectors: [Option<Collector<K, THREAD_LOCAL_SIZE_SHIFT>>; FAMILIES],
    /// Blob files created by this thread as (sequence number, file handle);
    /// gathered into the result during `finish`.
    new_blob_files: Vec<(u32, File)>,
}
42
/// Number of shards a global collector is split into once it overflows.
/// Must be a power of two so a shard index can be taken from the top bits of
/// the 64-bit key hash.
const COLLECTOR_SHARDS: usize = 4;
/// Right-shift applied to `entry.key.hash` to select a shard: keeps the top
/// `log2(COLLECTOR_SHARDS)` bits (64 - 2 = 62 for 4 shards, giving 0..=3).
const COLLECTOR_SHARD_SHIFT: usize =
    u64::BITS as usize - COLLECTOR_SHARDS.trailing_zeros() as usize;
46
/// Everything produced by [`WriteBatch::finish`]: the files created by the
/// batch plus bookkeeping needed to commit them.
pub(crate) struct FinishResult {
    /// The last sequence number handed out during this batch.
    pub(crate) sequence_number: u32,
    /// Newly written meta files as (sequence number, file handle).
    pub(crate) new_meta_files: Vec<(u32, File)>,
    /// Newly written SST files as (sequence number, file handle).
    pub(crate) new_sst_files: Vec<(u32, File)>,
    /// Newly written blob files as (sequence number, file handle).
    pub(crate) new_blob_files: Vec<(u32, File)>,
    /// Total number of entries recorded across all meta files.
    pub(crate) keys_written: u64,
}
59
/// Cross-thread collector state for a single family (kept behind a `Mutex`).
enum GlobalCollectorState<K: StoreKey + Send> {
    /// A single collector, used until it fills up for the first time.
    Unsharded(Collector<K>),
    /// After the unsharded collector overflowed: one collector per hash
    /// shard (selected via `COLLECTOR_SHARD_SHIFT`), which bounds individual
    /// SST sizes and lets `flush`/`finish` write the shards in parallel.
    Sharded([Collector<K>; COLLECTOR_SHARDS]),
}
67
/// A batch of writes (puts/deletes across `FAMILIES` key families) buffered
/// in memory — first per thread, then in per-family global collectors — and
/// turned into SST, blob, and meta files when finished.
pub struct WriteBatch<K: StoreKey + Send, S: ParallelScheduler, const FAMILIES: usize> {
    /// Scheduler used to parallelize flushing and file writing.
    parallel_scheduler: S,
    /// Directory into which every new file of this batch is written.
    db_path: PathBuf,
    /// Per-family configuration; outside of SST creation it is only read by
    /// the `verify_sst_content` feature, hence the `dead_code` allowance.
    #[cfg_attr(not(feature = "verify_sst_content"), allow(dead_code))]
    family_configs: [FamilyConfig; FAMILIES],
    /// Last sequence number handed out; incremented for every new file.
    current_sequence_number: AtomicU32,
    /// Per-thread buffering state (see `thread_local_state` for the safety
    /// reasoning behind the unsynchronized access).
    thread_locals: ThreadLocal<SyncUnsafeCell<ThreadLocalState<K, FAMILIES>>>,
    /// One global collector (possibly sharded) per family.
    collectors: [Mutex<GlobalCollectorState<K>>; FAMILIES],
    /// Metadata of the SST files written so far, per family; consumed by
    /// `finish` to build the meta files.
    meta_collectors: [Mutex<Vec<(u32, StaticSortedFileBuilderMeta<'static>)>>; FAMILIES],
    /// SST files written so far as (sequence number, file handle).
    new_sst_files: Mutex<Vec<(u32, File)>>,
}
89
90impl<K: StoreKey + Send + Sync, S: ParallelScheduler, const FAMILIES: usize>
91 WriteBatch<K, S, FAMILIES>
92{
93 pub(crate) fn new(
95 path: PathBuf,
96 current: u32,
97 parallel_scheduler: S,
98 family_configs: [FamilyConfig; FAMILIES],
99 ) -> Self {
100 const {
101 assert!(FAMILIES <= usize_from_u32(u32::MAX));
102 };
103 Self {
104 parallel_scheduler,
105 db_path: path,
106 family_configs,
107 current_sequence_number: AtomicU32::new(current),
108 thread_locals: ThreadLocal::new(),
109 collectors: [(); FAMILIES]
110 .map(|_| Mutex::new(GlobalCollectorState::Unsharded(Collector::new()))),
111 meta_collectors: [(); FAMILIES].map(|_| Mutex::new(Vec::new())),
112 new_sst_files: Mutex::new(Vec::new()),
113 }
114 }
115
    /// Returns this thread's mutable [`ThreadLocalState`], creating it on
    /// first access.
    ///
    /// SAFETY of handing out `&mut` from `&self`: each `SyncUnsafeCell` is
    /// keyed by the current thread inside the `ThreadLocal`, so in normal
    /// operation only this thread derives a reference from it here. The
    /// exceptions are `flush` (which is `unsafe` and iterates all cells) and
    /// `finish` (which takes `&mut self`), which carry their own exclusivity
    /// requirements.
    #[allow(clippy::mut_from_ref)]
    fn thread_local_state(&self) -> &mut ThreadLocalState<K, FAMILIES> {
        let cell = self.thread_locals.get_or(|| {
            SyncUnsafeCell::new(ThreadLocalState {
                // Collectors are created lazily per family on first write.
                collectors: [const { None }; FAMILIES],
                new_blob_files: Vec::new(),
            })
        });
        unsafe { &mut *cell.get() }
    }
128
129 fn thread_local_collector_mut<'l>(
131 &self,
132 state: &'l mut ThreadLocalState<K, FAMILIES>,
133 family: u32,
134 ) -> Result<&'l mut Collector<K, THREAD_LOCAL_SIZE_SHIFT>> {
135 debug_assert!(usize_from_u32(family) < FAMILIES);
136 let collector =
137 state.collectors[usize_from_u32(family)].get_or_insert_with(|| Collector::new());
138 if collector.is_full() {
139 self.flush_thread_local_collector(family, collector)?;
140 }
141 Ok(collector)
142 }
143
144 #[tracing::instrument(level = "trace", skip(self, collector))]
145 fn flush_thread_local_collector(
146 &self,
147 family: u32,
148 collector: &mut Collector<K, THREAD_LOCAL_SIZE_SHIFT>,
149 ) -> Result<()> {
150 let mut full_collectors = SmallVec::<[_; 2]>::new();
151 {
152 let mut global_collector_state = self.collectors[usize_from_u32(family)].lock();
153 for entry in collector.drain() {
154 match &mut *global_collector_state {
155 GlobalCollectorState::Unsharded(collector) => {
156 collector.add_entry(entry);
157 if collector.is_full() {
158 let mut shards: [Collector<K>; 4] =
160 [(); COLLECTOR_SHARDS].map(|_| Collector::new());
161 for entry in collector.drain() {
162 let shard = (entry.key.hash >> COLLECTOR_SHARD_SHIFT) as usize;
163 shards[shard].add_entry(entry);
164 }
165 for collector in shards.iter_mut() {
168 if collector.is_full() {
169 full_collectors
170 .push(replace(&mut *collector, Collector::new()));
171 }
172 }
173 *global_collector_state = GlobalCollectorState::Sharded(shards);
174 }
175 }
176 GlobalCollectorState::Sharded(shards) => {
177 let shard = (entry.key.hash >> COLLECTOR_SHARD_SHIFT) as usize;
178 let collector = &mut shards[shard];
179 collector.add_entry(entry);
180 if collector.is_full() {
181 full_collectors.push(replace(&mut *collector, Collector::new()));
182 }
183 }
184 }
185 }
186 }
187 for mut global_collector in full_collectors {
206 let sst = self.create_sst_file(
208 family,
209 global_collector.sorted(self.family_configs[usize_from_u32(family)].kind),
210 )?;
211 self.new_sst_files.lock().push(sst);
212 drop(global_collector);
213 }
214 Ok(())
215 }
216
217 pub fn put(&self, family: u32, key: K, value: ValueBuffer<'_>) -> Result<()> {
219 let state = self.thread_local_state();
220 let collector = self.thread_local_collector_mut(state, family)?;
221 if value.len() <= MAX_MEDIUM_VALUE_SIZE {
222 collector.put(key, value);
223 } else {
224 let (blob, file) = self.create_blob(&value)?;
225 collector.put_blob(key, blob);
226 state.new_blob_files.push((blob, file));
227 }
228 Ok(())
229 }
230
231 pub fn delete(&self, family: u32, key: K) -> Result<()> {
233 let state = self.thread_local_state();
234 let collector = self.thread_local_collector_mut(state, family)?;
235 collector.delete(key);
236 Ok(())
237 }
238
    /// Flushes all buffered entries of `family` — thread-local and global —
    /// into SST files.
    ///
    /// # Safety
    ///
    /// This iterates over *every* thread's `ThreadLocalState` and takes the
    /// family's collector out of it, so the caller must guarantee that no
    /// other thread is concurrently writing to this batch (`put`/`delete`)
    /// while `flush` runs.
    #[tracing::instrument(level = "trace", skip(self))]
    pub unsafe fn flush(&self, family: u32) -> Result<()> {
        // Collect the non-empty thread-local collectors of this family.
        let mut collectors = Vec::new();
        for cell in self.thread_locals.iter() {
            // SAFETY (relies on the fn contract): no other thread accesses
            // its state concurrently, so this exclusive borrow is sound.
            let state = unsafe { &mut *cell.get() };
            if let Some(collector) = state.collectors[usize_from_u32(family)].take()
                && !collector.is_empty()
            {
                collectors.push(collector);
            }
        }

        // Drain them into the global collector(s) in parallel.
        self.parallel_scheduler
            .try_parallel_for_each_owned(collectors, |mut collector| {
                self.flush_thread_local_collector(family, &mut collector)?;
                drop(collector);
                anyhow::Ok(())
            })?;

        // Now write whatever accumulated in the global collector state.
        let mut collector_state = self.collectors[usize_from_u32(family)].lock();
        match &mut *collector_state {
            GlobalCollectorState::Unsharded(collector) => {
                if !collector.is_empty() {
                    let sst = self.create_sst_file(
                        family,
                        collector.sorted(self.family_configs[usize_from_u32(family)].kind),
                    )?;
                    collector.clear();
                    self.new_sst_files.lock().push(sst);
                }
            }
            GlobalCollectorState::Sharded(_) => {
                // Take the shards out, resetting the family to a fresh
                // unsharded collector, so the shards can be processed in
                // parallel without holding borrows into the mutex.
                let GlobalCollectorState::Sharded(mut shards) = replace(
                    &mut *collector_state,
                    GlobalCollectorState::Unsharded(Collector::new()),
                ) else {
                    // The outer match arm already proved the Sharded variant.
                    unreachable!();
                };
                self.parallel_scheduler
                    .try_parallel_for_each_mut(&mut shards, |collector| {
                        if !collector.is_empty() {
                            let sst = self.create_sst_file(
                                family,
                                collector.sorted(self.family_configs[usize_from_u32(family)].kind),
                            )?;
                            collector.clear();
                            self.new_sst_files.lock().push(sst);
                            collector.drop_contents();
                        }
                        anyhow::Ok(())
                    })?;
            }
        }

        Ok(())
    }
304
    /// Finishes the batch: flushes everything still buffered, writes the
    /// per-family meta files, and returns all files created by the batch.
    ///
    /// Takes `&mut self`, so no thread can still be writing; thread-local
    /// state is therefore drained safely via `iter_mut`/`get_mut`.
    /// `get_accessed_key_hashes` supplies, per family, the AMQF filter of
    /// key hashes to embed in that family's meta file.
    #[tracing::instrument(level = "trace", skip_all)]
    pub(crate) fn finish(
        &mut self,
        get_accessed_key_hashes: impl Fn(u32) -> qfilter::Filter + Send + Sync,
    ) -> Result<FinishResult> {
        let mut new_blob_files = Vec::new();

        {
            // Phase 1: move every thread's collectors into per-family lists
            // (also gathering the blob files the threads created), then
            // drain them into the global collectors in parallel.
            let _span = tracing::trace_span!("flush thread local collectors").entered();
            let mut collectors = [const { Vec::new() }; FAMILIES];
            for cell in self.thread_locals.iter_mut() {
                let state = cell.get_mut();
                new_blob_files.append(&mut state.new_blob_files);
                for (family, thread_local_collector) in state.collectors.iter_mut().enumerate() {
                    if let Some(collector) = thread_local_collector.take()
                        && !collector.is_empty()
                    {
                        collectors[family].push(collector);
                    }
                }
            }
            let to_flush = collectors
                .into_iter()
                .enumerate()
                .flat_map(|(family, collector)| {
                    collector
                        .into_iter()
                        .map(move |collector| (family as u32, collector))
                })
                .collect::<Vec<_>>();
            self.parallel_scheduler.try_parallel_for_each_owned(
                to_flush,
                |(family, mut collector)| {
                    self.flush_thread_local_collector(family, &mut collector)?;
                    drop(collector);
                    anyhow::Ok(())
                },
            )?;
        }

        let _span = tracing::trace_span!("flush collectors").entered();

        // Phase 2: replace the global collector state with fresh collectors
        // and write the remaining contents to SST files in parallel.
        let mut new_sst_files = take(self.new_sst_files.get_mut());
        let shared_new_sst_files = Mutex::new(&mut new_sst_files);

        let new_collectors =
            [(); FAMILIES].map(|_| Mutex::new(GlobalCollectorState::Unsharded(Collector::new())));
        let collectors = replace(&mut self.collectors, new_collectors);
        let collectors = collectors
            .into_iter()
            .enumerate()
            .flat_map(|(family, state)| {
                // Flatten sharded state into (family, collector) pairs so
                // shards of one family are written in parallel too.
                let collector = state.into_inner();
                match collector {
                    GlobalCollectorState::Unsharded(collector) => {
                        Either::Left([(family, collector)].into_iter())
                    }
                    GlobalCollectorState::Sharded(shards) => {
                        Either::Right(shards.into_iter().map(move |collector| (family, collector)))
                    }
                }
            })
            .collect::<Vec<_>>();
        self.parallel_scheduler.try_parallel_for_each_owned(
            collectors,
            |(family, mut collector)| {
                let family = family as u32;
                if !collector.is_empty() {
                    let sst = self.create_sst_file(
                        family,
                        collector.sorted(self.family_configs[usize_from_u32(family)].kind),
                    )?;
                    collector.clear();
                    drop(collector);
                    shared_new_sst_files.lock().push(sst);
                }
                anyhow::Ok(())
            },
        )?;

        // Phase 3: build one meta file per family that produced SST files,
        // in parallel, while counting the total number of entries written.
        let new_meta_collectors = [(); FAMILIES].map(|_| Mutex::new(Vec::new()));
        let meta_collectors = replace(&mut self.meta_collectors, new_meta_collectors);
        let keys_written = AtomicU64::new(0);
        let file_to_write = meta_collectors
            .into_iter()
            .map(|mutex| mutex.into_inner())
            .enumerate()
            .filter(|(_, sst_files)| !sst_files.is_empty())
            .collect::<Vec<_>>();
        let new_meta_files = self
            .parallel_scheduler
            .parallel_map_collect_owned::<_, _, Result<Vec<_>>>(
                file_to_write,
                |(family, sst_files)| {
                    let family = family as u32;
                    let mut entries = 0;
                    let mut builder = MetaFileBuilder::new(family);
                    for (seq, sst) in sst_files {
                        entries += sst.entries;
                        builder.add(seq, sst);
                    }
                    keys_written.fetch_add(entries, Ordering::Relaxed);
                    let accessed_key_hashes = get_accessed_key_hashes(family);
                    builder.set_used_key_hashes_amqf(accessed_key_hashes);
                    // fetch_add returns the previous value; +1 is the
                    // sequence number reserved for this meta file.
                    let seq = self.current_sequence_number.fetch_add(1, Ordering::SeqCst) + 1;
                    let file = builder.write(&self.db_path, seq)?;
                    Ok((seq, file))
                },
            )?;

        // The final counter value is the highest sequence number handed out.
        let seq = self.current_sequence_number.load(Ordering::SeqCst);
        Ok(FinishResult {
            sequence_number: seq,
            new_meta_files,
            new_sst_files,
            new_blob_files,
            keys_written: keys_written.into_inner(),
        })
    }
430
431 #[tracing::instrument(level = "trace", skip(self, value), fields(value_len = value.len()))]
434 fn create_blob(&self, value: &[u8]) -> Result<(u32, File)> {
435 let seq = self.current_sequence_number.fetch_add(1, Ordering::SeqCst) + 1;
436 let mut compressed = Vec::new();
437 compress_into_buffer(value, &mut compressed)
438 .context("Compression of value for blob file failed")?;
439
440 let mut buffer = Vec::with_capacity(8 + compressed.len());
441 buffer.write_u32::<BE>(value.len() as u32)?;
442 buffer.write_u32::<BE>(checksum_block(&compressed))?;
443 buffer.extend_from_slice(&compressed);
444
445 let file = self.db_path.join(format!("{seq:08}.blob"));
446 let mut file = File::create(&file).context("Unable to create blob file")?;
447 file.write_all(&buffer)
448 .context("Unable to write blob file")?;
449 file.flush().context("Unable to flush blob file")?;
450 Ok((seq, file))
451 }
452
    /// Writes the sorted entries of a collector to a new, sequence-numbered
    /// SST file and records its metadata for the family's meta file.
    ///
    /// `collector_data` is `(sorted entries, total key size)`; the total key
    /// size is currently unused here.
    #[tracing::instrument(level = "trace", skip(self, collector_data))]
    fn create_sst_file(
        &self,
        family: u32,
        collector_data: (&[CollectorEntry<K>], usize),
    ) -> Result<(u32, File)> {
        let (entries, _total_key_size) = collector_data;
        // fetch_add returns the previous value; +1 is our reserved number.
        let seq = self.current_sequence_number.fetch_add(1, Ordering::SeqCst) + 1;

        let path = self.db_path.join(format!("{seq:08}.sst"));
        let (meta, file) = self
            .parallel_scheduler
            .block_in_place(|| write_static_stored_file(entries, &path, MetaEntryFlags::FRESH))
            .with_context(|| format!("Unable to write SST file {seq:08}.sst"))?;

        // Optional debug verification: re-open the file just written and
        // check that every entry can be looked up with the expected value.
        #[cfg(feature = "verify_sst_content")]
        {
            use core::panic;

            use crate::{
                collector_entry::CollectorEntryValue,
                key::hash_key,
                lookup_entry::LookupValue,
                static_sorted_file::{
                    BlockCache, SstLookupResult, StaticSortedFile, StaticSortedFileMetaData,
                },
                static_sorted_file_builder::Entry,
            };

            // Ensure the data is on disk before reading it back.
            file.sync_all()?;
            let sst = StaticSortedFile::open(
                &self.db_path,
                StaticSortedFileMetaData {
                    sequence_number: seq,
                    block_count: meta.block_count,
                },
            )?;
            let cache2 = BlockCache::with(
                10,
                u64::MAX,
                Default::default(),
                Default::default(),
                Default::default(),
            );
            let cache3 = BlockCache::with(
                10,
                u64::MAX,
                Default::default(),
                Default::default(),
                Default::default(),
            );
            let mut key_buf = Vec::new();
            let family_config = self.family_configs[usize_from_u32(family)].kind;
            for entry in entries {
                entry.write_key_to(&mut key_buf);
                let result = sst
                    .lookup::<_, true>(hash_key(&key_buf), &key_buf, &cache2, &cache3)
                    .expect("key found");
                key_buf.clear();
                match result {
                    SstLookupResult::Found(values) => {
                        if values.len() > 1 {
                            use crate::FamilyKind;

                            // Multiple values per key are only valid for
                            // multi-value families.
                            assert!(
                                values.len() == 1 || family_config == FamilyKind::MultiValue,
                                "only multi-value tables can have more than one value, got {} \
                                 values",
                                values.len()
                            )
                        }
                        match &entry.value {
                            CollectorEntryValue::Large { blob } => {
                                assert!(
                                    values.contains(&LookupValue::Blob {
                                        sequence_number: *blob
                                    }),
                                    "we wrote a blob but did not read it"
                                );
                            }
                            CollectorEntryValue::Deleted => assert!(
                                values.first() == Some(&LookupValue::Deleted),
                                "we wrote a deleted tombstone but it was not first in results"
                            ),
                            v => {
                                assert!(
                                    values.into_iter().any(|lv| {
                                        if let LookupValue::Slice { value } = lv {
                                            &*value == v.as_bytes().unwrap()
                                        } else {
                                            false
                                        }
                                    }),
                                    "we wrote a slice of bytes but did not read it"
                                )
                            }
                        }
                    }
                    SstLookupResult::NotFound => panic!("All keys must exist"),
                }
            }
        }

        // Record the SST metadata so `finish` can build this family's meta
        // file.
        self.meta_collectors[usize_from_u32(family)]
            .lock()
            .push((seq, meta));

        Ok((seq, file))
    }
564}
565
/// Losslessly converts a `u32` into a `usize`.
#[inline(always)]
const fn usize_from_u32(value: u32) -> usize {
    const {
        // `<=` (not `<`): the cast is lossless whenever usize is at least as
        // wide as u32, which also holds on 32-bit targets. The strict `<`
        // would needlessly fail to compile there.
        assert!(u32::BITS <= usize::BITS);
    };
    value as usize
}