use std::{
    cell::SyncUnsafeCell,
    fs::File,
    io::Write,
    mem::{replace, take},
    path::PathBuf,
    sync::atomic::{AtomicU32, AtomicU64, Ordering},
};

use anyhow::{Context, Result};
use byteorder::{BE, WriteBytesExt};
use either::Either;
use parking_lot::Mutex;
use smallvec::SmallVec;
use thread_local::ThreadLocal;

use crate::{
    ValueBuffer,
    collector::Collector,
    collector_entry::CollectorEntry,
    compression::compress_into_buffer,
    constants::{MAX_MEDIUM_VALUE_SIZE, THREAD_LOCAL_SIZE_SHIFT},
    key::StoreKey,
    meta_file::MetaEntryFlags,
    meta_file_builder::MetaFileBuilder,
    parallel_scheduler::ParallelScheduler,
    static_sorted_file_builder::{StaticSortedFileBuilderMeta, write_static_stored_file},
};

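/// Thread-local state of a [`WriteBatch`]. Each writer thread buffers its
/// entries here before they are merged into the shared collectors.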
struct ThreadLocalState<K: StoreKey + Send, const FAMILIES: usize> {
    /// The thread-local collector for each family, created lazily on first write.
    collectors: [Option<Collector<K, THREAD_LOCAL_SIZE_SHIFT>>; FAMILIES],
    /// New blob files created by this thread as (sequence number, file) pairs.
    new_blob_files: Vec<(u32, File)>,
}

/// The number of shards the global collector is split into when it fills up.
/// Must be a power of two, since the shard index is derived from the top bits
/// of the key hash.
const COLLECTOR_SHARDS: usize = 4;
/// Shifting a 64-bit key hash right by this amount yields the shard index.
const COLLECTOR_SHARD_SHIFT: usize =
    u64::BITS as usize - COLLECTOR_SHARDS.trailing_zeros() as usize;

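/// The result of committing a `WriteBatch`, listing every file created while
/// the batch was open.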
pub(crate) struct FinishResult {
    /// The last sequence number used by the batch.
    pub(crate) sequence_number: u32,
    /// New meta files as (sequence number, file) pairs.
    pub(crate) new_meta_files: Vec<(u32, File)>,
    /// New SST files as (sequence number, file) pairs.
    pub(crate) new_sst_files: Vec<(u32, File)>,
    /// New blob files as (sequence number, file) pairs.
    pub(crate) new_blob_files: Vec<(u32, File)>,
    /// The number of keys written by the batch.
    pub(crate) keys_written: u64,
}

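/// The state of the global collector for a family. It starts unsharded and is
/// split into `COLLECTOR_SHARDS` hash-partitioned shards once it fills up.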
enum GlobalCollectorState<K: StoreKey + Send> {
    Unsharded(Collector<K>),
    /// A sharded collector. Entries are routed to a shard by the top bits of
    /// their key hash.
    Sharded([Collector<K>; COLLECTOR_SHARDS]),
}

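/// A write batch that buffers puts and deletes per thread, spills them into
/// shared per-family collectors, and writes SST, blob, and meta files when
/// finished.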
pub struct WriteBatch<K: StoreKey + Send, S: ParallelScheduler, const FAMILIES: usize> {
    /// The scheduler used to parallelize flush and finish work.
    parallel_scheduler: S,
    /// The database directory that all new files are written into.
    db_path: PathBuf,
    /// The most recently used sequence number.
    current_sequence_number: AtomicU32,
    /// The thread-local state of each writer thread.
    thread_locals: ThreadLocal<SyncUnsafeCell<ThreadLocalState<K, FAMILIES>>>,
    /// The global collector for each family.
    collectors: [Mutex<GlobalCollectorState<K>>; FAMILIES],
    /// Metadata of the SST files written so far, per family.
    meta_collectors: [Mutex<Vec<(u32, StaticSortedFileBuilderMeta<'static>)>>; FAMILIES],
    /// The SST files created so far.
    new_sst_files: Mutex<Vec<(u32, File)>>,
}

impl<K: StoreKey + Send + Sync, S: ParallelScheduler, const FAMILIES: usize>
    WriteBatch<K, S, FAMILIES>
{
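    /// Creates a new write batch for a database at the given path, continuing
    /// from the given sequence number.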
    pub(crate) fn new(path: PathBuf, current: u32, parallel_scheduler: S) -> Self {
        // The family index is passed around as `u32`, so the number of
        // families must fit into one.
        const {
            assert!(FAMILIES <= usize_from_u32(u32::MAX));
        };
        Self {
            parallel_scheduler,
            db_path: path,
            current_sequence_number: AtomicU32::new(current),
            thread_locals: ThreadLocal::new(),
            collectors: [(); FAMILIES]
                .map(|_| Mutex::new(GlobalCollectorState::Unsharded(Collector::new()))),
            meta_collectors: [(); FAMILIES].map(|_| Mutex::new(Vec::new())),
            new_sst_files: Mutex::new(Vec::new()),
        }
    }

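    /// Returns the thread-local state for the current thread, creating it on
    /// first access.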
    #[allow(clippy::mut_from_ref)]
    fn thread_local_state(&self) -> &mut ThreadLocalState<K, FAMILIES> {
        let cell = self.thread_locals.get_or(|| {
            SyncUnsafeCell::new(ThreadLocalState {
                collectors: [const { None }; FAMILIES],
                new_blob_files: Vec::new(),
            })
        });
        // SAFETY: The cell belongs to the current thread. Cross-thread access
        // only happens in `flush` and `finish`, which require that no writes
        // happen concurrently.
        unsafe { &mut *cell.get() }
    }

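    /// Returns the thread-local collector for a family, flushing it into the
    /// global collector first if it is full.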
    fn thread_local_collector_mut<'l>(
        &self,
        state: &'l mut ThreadLocalState<K, FAMILIES>,
        family: u32,
    ) -> Result<&'l mut Collector<K, THREAD_LOCAL_SIZE_SHIFT>> {
        debug_assert!(usize_from_u32(family) < FAMILIES);
        let collector =
            state.collectors[usize_from_u32(family)].get_or_insert_with(Collector::new);
        if collector.is_full() {
            self.flush_thread_local_collector(family, collector)?;
        }
        Ok(collector)
    }

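    /// Drains a full thread-local collector into the global collector of the
    /// family, writing out any global collectors that fill up in the process.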
    #[tracing::instrument(level = "trace", skip(self, collector))]
    fn flush_thread_local_collector(
        &self,
        family: u32,
        collector: &mut Collector<K, THREAD_LOCAL_SIZE_SHIFT>,
    ) -> Result<()> {
        let mut full_collectors = SmallVec::<[_; 2]>::new();
        {
            let mut global_collector_state = self.collectors[usize_from_u32(family)].lock();
            for entry in collector.drain() {
                match &mut *global_collector_state {
                    GlobalCollectorState::Unsharded(collector) => {
                        collector.add_entry(entry);
                        if collector.is_full() {
                            // The unsharded collector is full: redistribute its
                            // entries into hash-partitioned shards.
                            let mut shards: [Collector<K>; COLLECTOR_SHARDS] =
                                [(); COLLECTOR_SHARDS].map(|_| Collector::new());
                            for entry in collector.drain() {
                                let shard = (entry.key.hash >> COLLECTOR_SHARD_SHIFT) as usize;
                                shards[shard].add_entry(entry);
                            }
                            for collector in shards.iter_mut() {
                                if collector.is_full() {
                                    full_collectors
                                        .push(replace(&mut *collector, Collector::new()));
                                }
                            }
                            *global_collector_state = GlobalCollectorState::Sharded(shards);
                        }
                    }
                    GlobalCollectorState::Sharded(shards) => {
                        let shard = (entry.key.hash >> COLLECTOR_SHARD_SHIFT) as usize;
                        let collector = &mut shards[shard];
                        collector.add_entry(entry);
                        if collector.is_full() {
                            full_collectors.push(replace(&mut *collector, Collector::new()));
                        }
                    }
                }
            }
        }
        // Write out the full collectors without holding the global collector lock.
        for mut global_collector in full_collectors {
            let sst = self.create_sst_file(family, global_collector.sorted())?;
            self.new_sst_files.lock().push(sst);
            drop(global_collector);
        }
        Ok(())
    }

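    /// Puts a key-value pair into the write batch. Values larger than
    /// `MAX_MEDIUM_VALUE_SIZE` are written to a separate blob file.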
    pub fn put(&self, family: u32, key: K, value: ValueBuffer<'_>) -> Result<()> {
        let state = self.thread_local_state();
        let collector = self.thread_local_collector_mut(state, family)?;
        if value.len() <= MAX_MEDIUM_VALUE_SIZE {
            collector.put(key, value);
        } else {
            // Large values go into their own blob file; only the blob's
            // sequence number is stored with the key.
            let (blob, file) = self.create_blob(&value)?;
            collector.put_blob(key, blob);
            state.new_blob_files.push((blob, file));
        }
        Ok(())
    }

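    /// Puts a delete operation for a key into the write batch.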
    pub fn delete(&self, family: u32, key: K) -> Result<()> {
        let state = self.thread_local_state();
        let collector = self.thread_local_collector_mut(state, family)?;
        collector.delete(key);
        Ok(())
    }

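    /// Flushes all buffered entries of a family into SST files, reducing the
    /// memory held by the batch. Does not commit anything persistently.
    ///
    /// # Safety
    ///
    /// The caller must ensure that no `put` or `delete` operations run
    /// concurrently on the flushed family, since this reads the thread-local
    /// state of all threads.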
    #[tracing::instrument(level = "trace", skip(self))]
    pub unsafe fn flush(&self, family: u32) -> Result<()> {
        // Steal the non-empty thread-local collectors of this family from all
        // threads and merge them into the global collector in parallel.
        let mut collectors = Vec::new();
        for cell in self.thread_locals.iter() {
            let state = unsafe { &mut *cell.get() };
            if let Some(collector) = state.collectors[usize_from_u32(family)].take()
                && !collector.is_empty()
            {
                collectors.push(collector);
            }
        }

        self.parallel_scheduler
            .try_parallel_for_each_owned(collectors, |mut collector| {
                self.flush_thread_local_collector(family, &mut collector)?;
                drop(collector);
                anyhow::Ok(())
            })?;

        // Write the global collector (or its shards) of the family to SST files.
        let mut collector_state = self.collectors[usize_from_u32(family)].lock();
        match &mut *collector_state {
            GlobalCollectorState::Unsharded(collector) => {
                if !collector.is_empty() {
                    let sst = self.create_sst_file(family, collector.sorted())?;
                    collector.clear();
                    self.new_sst_files.lock().push(sst);
                }
            }
            GlobalCollectorState::Sharded(_) => {
                // Take the shards out and reset the state to unsharded, so the
                // shards can be written in parallel.
                let GlobalCollectorState::Sharded(mut shards) = replace(
                    &mut *collector_state,
                    GlobalCollectorState::Unsharded(Collector::new()),
                ) else {
                    unreachable!();
                };
                self.parallel_scheduler
                    .try_parallel_for_each_mut(&mut shards, |collector| {
                        if !collector.is_empty() {
                            let sst = self.create_sst_file(family, collector.sorted())?;
                            collector.clear();
                            self.new_sst_files.lock().push(sst);
                            collector.drop_contents();
                        }
                        anyhow::Ok(())
                    })?;
            }
        }

        Ok(())
    }

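    /// Finishes the write batch by flushing all collectors and writing one
    /// meta file per family. Returns every file created while the batch was
    /// open. `get_accessed_key_hashes` supplies the filter of key hashes
    /// accessed during the batch for a family, which is stored in the meta
    /// file.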
    #[tracing::instrument(level = "trace", skip_all)]
    pub(crate) fn finish(
        &mut self,
        get_accessed_key_hashes: impl Fn(u32) -> qfilter::Filter + Send + Sync,
    ) -> Result<FinishResult> {
        let mut new_blob_files = Vec::new();

        // Collect the non-empty thread-local collectors of all threads and
        // flush them into the global collectors in parallel.
        {
            let _span = tracing::trace_span!("flush thread local collectors").entered();
            let mut collectors = [const { Vec::new() }; FAMILIES];
            for cell in self.thread_locals.iter_mut() {
                let state = cell.get_mut();
                new_blob_files.append(&mut state.new_blob_files);
                for (family, thread_local_collector) in state.collectors.iter_mut().enumerate() {
                    if let Some(collector) = thread_local_collector.take()
                        && !collector.is_empty()
                    {
                        collectors[family].push(collector);
                    }
                }
            }
            let to_flush = collectors
                .into_iter()
                .enumerate()
                .flat_map(|(family, collector)| {
                    collector
                        .into_iter()
                        .map(move |collector| (family as u32, collector))
                })
                .collect::<Vec<_>>();
            self.parallel_scheduler.try_parallel_for_each_owned(
                to_flush,
                |(family, mut collector)| {
                    self.flush_thread_local_collector(family, &mut collector)?;
                    drop(collector);
                    anyhow::Ok(())
                },
            )?;
        }

        let _span = tracing::trace_span!("flush collectors").entered();

        // Take the global collectors out of the batch and write each one (or
        // each shard) to an SST file in parallel.
        let mut new_sst_files = take(self.new_sst_files.get_mut());
        let shared_new_sst_files = Mutex::new(&mut new_sst_files);

        let new_collectors =
            [(); FAMILIES].map(|_| Mutex::new(GlobalCollectorState::Unsharded(Collector::new())));
        let collectors = replace(&mut self.collectors, new_collectors);
        let collectors = collectors
            .into_iter()
            .enumerate()
            .flat_map(|(family, state)| {
                let collector = state.into_inner();
                match collector {
                    GlobalCollectorState::Unsharded(collector) => {
                        Either::Left([(family, collector)].into_iter())
                    }
                    GlobalCollectorState::Sharded(shards) => {
                        Either::Right(shards.into_iter().map(move |collector| (family, collector)))
                    }
                }
            })
            .collect::<Vec<_>>();
        self.parallel_scheduler.try_parallel_for_each_owned(
            collectors,
            |(family, mut collector)| {
                let family = family as u32;
                if !collector.is_empty() {
                    let sst = self.create_sst_file(family, collector.sorted())?;
                    collector.clear();
                    drop(collector);
                    shared_new_sst_files.lock().push(sst);
                }
                anyhow::Ok(())
            },
        )?;

        // Write one meta file per family that received new SST files.
        let new_meta_collectors = [(); FAMILIES].map(|_| Mutex::new(Vec::new()));
        let meta_collectors = replace(&mut self.meta_collectors, new_meta_collectors);
        let keys_written = AtomicU64::new(0);
        let file_to_write = meta_collectors
            .into_iter()
            .map(|mutex| mutex.into_inner())
            .enumerate()
            .filter(|(_, sst_files)| !sst_files.is_empty())
            .collect::<Vec<_>>();
        let new_meta_files = self
            .parallel_scheduler
            .parallel_map_collect_owned::<_, _, Result<Vec<_>>>(
                file_to_write,
                |(family, sst_files)| {
                    let family = family as u32;
                    let mut entries = 0;
                    let mut builder = MetaFileBuilder::new(family);
                    for (seq, sst) in sst_files {
                        entries += sst.entries;
                        builder.add(seq, sst);
                    }
                    keys_written.fetch_add(entries, Ordering::Relaxed);
                    let accessed_key_hashes = get_accessed_key_hashes(family);
                    builder.set_used_key_hashes_amqf(accessed_key_hashes);
                    let seq = self.current_sequence_number.fetch_add(1, Ordering::SeqCst) + 1;
                    let file = builder.write(&self.db_path, seq)?;
                    Ok((seq, file))
                },
            )?;

        let seq = self.current_sequence_number.load(Ordering::SeqCst);
        Ok(FinishResult {
            sequence_number: seq,
            new_meta_files,
            new_sst_files,
            new_blob_files,
            keys_written: keys_written.into_inner(),
        })
    }

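    /// Writes a large value to a new blob file. The value is compressed and
    /// prefixed with its uncompressed length as a big-endian `u32`.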
    #[tracing::instrument(level = "trace", skip(self, value), fields(value_len = value.len()))]
    fn create_blob(&self, value: &[u8]) -> Result<(u32, File)> {
        let seq = self.current_sequence_number.fetch_add(1, Ordering::SeqCst) + 1;
        let mut buffer = Vec::new();
        buffer.write_u32::<BE>(value.len() as u32)?;
        compress_into_buffer(value, None, true, &mut buffer)
            .context("Compression of value for blob file failed")?;

        let file = self.db_path.join(format!("{seq:08}.blob"));
        let mut file = File::create(&file).context("Unable to create blob file")?;
        file.write_all(&buffer)
            .context("Unable to write blob file")?;
        file.flush().context("Unable to flush blob file")?;
        Ok((seq, file))
    }

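    /// Writes the sorted entries of a collector to a new SST file and records
    /// its metadata for the family's meta file.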
    #[tracing::instrument(level = "trace", skip(self, collector_data))]
    fn create_sst_file(
        &self,
        family: u32,
        collector_data: (&[CollectorEntry<K>], usize),
    ) -> Result<(u32, File)> {
        let (entries, total_key_size) = collector_data;
        let seq = self.current_sequence_number.fetch_add(1, Ordering::SeqCst) + 1;

        let path = self.db_path.join(format!("{seq:08}.sst"));
        let (meta, file) = self
            .parallel_scheduler
            .block_in_place(|| {
                write_static_stored_file(entries, total_key_size, &path, MetaEntryFlags::FRESH)
            })
            .with_context(|| format!("Unable to write SST file {seq:08}.sst"))?;

        // When enabled, read every entry back from the freshly written file
        // and verify that it matches the collected data.
        #[cfg(feature = "verify_sst_content")]
        {
            use core::panic;

            use crate::{
                collector_entry::CollectorEntryValue,
                key::hash_key,
                lookup_entry::LookupValue,
                static_sorted_file::{
                    BlockCache, SstLookupResult, StaticSortedFile, StaticSortedFileMetaData,
                },
                static_sorted_file_builder::Entry,
            };

            file.sync_all()?;
            let sst = StaticSortedFile::open(
                &self.db_path,
                StaticSortedFileMetaData {
                    sequence_number: seq,
                    key_compression_dictionary_length: meta.key_compression_dictionary_length,
                    block_count: meta.block_count,
                },
            )?;
            let cache2 = BlockCache::with(
                10,
                u64::MAX,
                Default::default(),
                Default::default(),
                Default::default(),
            );
            let cache3 = BlockCache::with(
                10,
                u64::MAX,
                Default::default(),
                Default::default(),
                Default::default(),
            );
            let mut key_buf = Vec::new();
            for entry in entries {
                entry.write_key_to(&mut key_buf);
                let result = sst
                    .lookup(hash_key(&key_buf), &key_buf, &cache2, &cache3)
                    .expect("key found");
                key_buf.clear();
                match result {
                    SstLookupResult::Found(LookupValue::Deleted) => {}
                    SstLookupResult::Found(LookupValue::Slice {
                        value: lookup_value,
                    }) => {
                        let expected_value_slice = match &entry.value {
                            CollectorEntryValue::Small { value } => &**value,
                            CollectorEntryValue::Medium { value } => &**value,
                            _ => panic!("Unexpected value"),
                        };
                        assert_eq!(*lookup_value, *expected_value_slice);
                    }
                    SstLookupResult::Found(LookupValue::Blob { sequence_number: _ }) => {}
                    SstLookupResult::NotFound => panic!("All keys must exist"),
                }
            }
        }

        self.meta_collectors[usize_from_u32(family)]
            .lock()
            .push((seq, meta));

        Ok((seq, file))
    }
}

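/// Converts a `u32` to a `usize`, which is lossless on all supported targets.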
#[inline(always)]
const fn usize_from_u32(value: u32) -> usize {
    const {
        // `usize` must be at least as wide as `u32` for this conversion to be
        // lossless.
        assert!(u32::BITS <= usize::BITS);
    };
    value as usize
}