use std::{
    cell::SyncUnsafeCell,
    fs::File,
    io::Write,
    mem::{replace, take},
    path::PathBuf,
    sync::atomic::{AtomicU32, AtomicU64, Ordering},
};

use anyhow::{Context, Result};
use byteorder::{BE, WriteBytesExt};
use either::Either;
use parking_lot::Mutex;
use smallvec::SmallVec;
use thread_local::ThreadLocal;

use crate::{
    ValueBuffer,
    collector::Collector,
    collector_entry::CollectorEntry,
    compression::compress_into_buffer,
    constants::{MAX_MEDIUM_VALUE_SIZE, THREAD_LOCAL_SIZE_SHIFT},
    key::StoreKey,
    meta_file_builder::MetaFileBuilder,
    parallel_scheduler::ParallelScheduler,
    static_sorted_file_builder::{StaticSortedFileBuilderMeta, write_static_stored_file},
};

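/// The thread-local state of a [`WriteBatch`]. Each writing thread buffers
/// entries in its own small collectors (one per family) and tracks the blob
/// files it has created, keeping `put` and `delete` lock-free in the common
/// case.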
struct ThreadLocalState<K: StoreKey + Send, const FAMILIES: usize> {
    /// The thread-local collectors, one per family.
    collectors: [Option<Collector<K, THREAD_LOCAL_SIZE_SHIFT>>; FAMILIES],
    /// The blob files created by this thread, as (sequence number, file).
    new_blob_files: Vec<(u32, File)>,
}

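/// The number of shards a family's global collector is split into once it
/// fills up. Must be a power of two: the shard index is taken from the top
/// bits of the 64-bit key hash, e.g. with 4 shards the shift below is
/// `64 - 2 = 62`, so `hash >> COLLECTOR_SHARD_SHIFT` is always in `0..4`.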
const COLLECTOR_SHARDS: usize = 4;
const COLLECTOR_SHARD_SHIFT: usize =
    u64::BITS as usize - COLLECTOR_SHARDS.trailing_zeros() as usize;

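/// The files and statistics produced by committing a [`WriteBatch`].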
pub(crate) struct FinishResult {
    /// The sequence number reached when the batch was committed.
    pub(crate) sequence_number: u32,
    /// The meta files written for this batch, as (sequence number, file).
    pub(crate) new_meta_files: Vec<(u32, File)>,
    /// The SST files written for this batch, as (sequence number, file).
    pub(crate) new_sst_files: Vec<(u32, File)>,
    /// The blob files written for this batch, as (sequence number, file).
    pub(crate) new_blob_files: Vec<(u32, File)>,
    /// The total number of keys written by this batch.
    pub(crate) keys_written: u64,
}

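/// The global collector state of a single family. It starts unsharded and is
/// split into hash-based shards once it fills up, so full shards can be
/// flushed to SST files independently.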
enum GlobalCollectorState<K: StoreKey + Send> {
    /// Initial state: all entries of the family go into a single collector.
    Unsharded(Collector<K>),
    /// Overflow state: entries are distributed over `COLLECTOR_SHARDS`
    /// collectors, selected by the top bits of the key hash.
    Sharded([Collector<K>; COLLECTOR_SHARDS]),
}

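/// A write batch that collects puts and deletes across multiple key families
/// and commits them as new SST, blob, and meta files. Writes are buffered in
/// thread-local collectors and spilled into global (possibly sharded)
/// collectors before being written to disk.
///
/// A minimal crate-internal usage sketch; `MyKey`, `MyScheduler`, and the
/// value buffer are placeholders, not types provided by this module:
/// ```ignore
/// let mut batch: WriteBatch<MyKey, MyScheduler, 2> =
///     WriteBatch::new(db_path, current_sequence_number, scheduler);
/// batch.put(0, key_a, value_a)?; // family 0, small value stays in memory
/// batch.delete(1, key_b)?; // family 1
/// let FinishResult { new_sst_files, .. } = batch.finish()?;
/// ```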
pub struct WriteBatch<K: StoreKey + Send, S: ParallelScheduler, const FAMILIES: usize> {
    /// The scheduler used to parallelize flushing and file writing.
    parallel_scheduler: S,
    /// The directory that holds all files of the database.
    db_path: PathBuf,
    /// The most recently used sequence number.
    current_sequence_number: AtomicU32,
    /// The per-thread buffering state.
    thread_locals: ThreadLocal<SyncUnsafeCell<ThreadLocalState<K, FAMILIES>>>,
    /// The global collectors, one per family.
    collectors: [Mutex<GlobalCollectorState<K>>; FAMILIES],
    /// The metadata of the SST files written so far, grouped by family and
    /// consumed by `finish` to build the meta files.
    meta_collectors: [Mutex<Vec<(u32, StaticSortedFileBuilderMeta<'static>)>>; FAMILIES],
    /// The SST files written so far, as (sequence number, file).
    new_sst_files: Mutex<Vec<(u32, File)>>,
}

impl<K: StoreKey + Send + Sync, S: ParallelScheduler, const FAMILIES: usize>
    WriteBatch<K, S, FAMILIES>
{
    /// Creates a new write batch for the database at `path`, continuing from
    /// the given sequence number.
    pub(crate) fn new(path: PathBuf, current: u32, parallel_scheduler: S) -> Self {
        // Family indices are passed around as `u32`, so `FAMILIES` must fit.
        const {
            assert!(FAMILIES <= usize_from_u32(u32::MAX));
        };
        Self {
            parallel_scheduler,
            db_path: path,
            current_sequence_number: AtomicU32::new(current),
            thread_locals: ThreadLocal::new(),
            collectors: [(); FAMILIES]
                .map(|_| Mutex::new(GlobalCollectorState::Unsharded(Collector::new()))),
            meta_collectors: [(); FAMILIES].map(|_| Mutex::new(Vec::new())),
            new_sst_files: Mutex::new(Vec::new()),
        }
    }

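    /// Returns the thread-local state of the current thread, creating it on
    /// first use. Handing out `&mut` from `&self` is sound because each state
    /// is only ever touched by its owning thread on the write path; the
    /// exceptions (`flush`, `finish`) require exclusive access to the whole
    /// batch.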
    #[allow(clippy::mut_from_ref)]
    fn thread_local_state(&self) -> &mut ThreadLocalState<K, FAMILIES> {
        let cell = self.thread_locals.get_or(|| {
            SyncUnsafeCell::new(ThreadLocalState {
                collectors: [const { None }; FAMILIES],
                new_blob_files: Vec::new(),
            })
        });
        // Safety: the cell is keyed by the current thread; see the doc
        // comment above for why no other thread accesses it concurrently.
        unsafe { &mut *cell.get() }
    }

    fn thread_local_collector_mut<'l>(
        &self,
        state: &'l mut ThreadLocalState<K, FAMILIES>,
        family: u32,
    ) -> Result<&'l mut Collector<K, THREAD_LOCAL_SIZE_SHIFT>> {
        debug_assert!(usize_from_u32(family) < FAMILIES);
        let collector =
            state.collectors[usize_from_u32(family)].get_or_insert_with(Collector::new);
        if collector.is_full() {
            self.flush_thread_local_collector(family, collector)?;
        }
        Ok(collector)
    }

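    /// Drains a thread-local collector into the global collector of the given
    /// family. If the unsharded global collector overflows, its entries are
    /// redistributed into `COLLECTOR_SHARDS` hash-based shards. Shards that
    /// fill up are swapped out and written to SST files once the global lock
    /// has been released.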
    #[tracing::instrument(level = "trace", skip(self, collector))]
    fn flush_thread_local_collector(
        &self,
        family: u32,
        collector: &mut Collector<K, THREAD_LOCAL_SIZE_SHIFT>,
    ) -> Result<()> {
        let mut full_collectors = SmallVec::<[_; 2]>::new();
        {
            let mut global_collector_state = self.collectors[usize_from_u32(family)].lock();
            for entry in collector.drain() {
                match &mut *global_collector_state {
                    GlobalCollectorState::Unsharded(collector) => {
                        collector.add_entry(entry);
                        if collector.is_full() {
                            // The unsharded collector is full: redistribute
                            // its entries into hash-based shards.
                            let mut shards: [Collector<K>; COLLECTOR_SHARDS] =
                                [(); COLLECTOR_SHARDS].map(|_| Collector::new());
                            for entry in collector.drain() {
                                let shard = (entry.key.hash >> COLLECTOR_SHARD_SHIFT) as usize;
                                shards[shard].add_entry(entry);
                            }
                            for collector in shards.iter_mut() {
                                // A shard may already be full right after the
                                // split.
                                if collector.is_full() {
                                    full_collectors
                                        .push(replace(&mut *collector, Collector::new()));
                                }
                            }
                            *global_collector_state = GlobalCollectorState::Sharded(shards);
                        }
                    }
                    GlobalCollectorState::Sharded(shards) => {
                        let shard = (entry.key.hash >> COLLECTOR_SHARD_SHIFT) as usize;
                        let collector = &mut shards[shard];
                        collector.add_entry(entry);
                        if collector.is_full() {
                            full_collectors.push(replace(&mut *collector, Collector::new()));
                        }
                    }
                }
            }
        }
        // Full collectors are written to SST files only after the global lock
        // has been released.
        for mut global_collector in full_collectors {
            let sst = self.create_sst_file(family, global_collector.sorted())?;
            self.new_sst_files.lock().push(sst);
            drop(global_collector);
        }
        Ok(())
    }

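    /// Puts a key-value pair into the write batch. Values up to
    /// `MAX_MEDIUM_VALUE_SIZE` are buffered in the thread-local collector;
    /// larger values are immediately compressed into a standalone blob file.
    ///
    /// Safe to call concurrently from multiple threads.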
    pub fn put(&self, family: u32, key: K, value: ValueBuffer<'_>) -> Result<()> {
        let state = self.thread_local_state();
        let collector = self.thread_local_collector_mut(state, family)?;
        if value.len() <= MAX_MEDIUM_VALUE_SIZE {
            collector.put(key, value);
        } else {
            let (blob, file) = self.create_blob(&value)?;
            collector.put_blob(key, blob);
            state.new_blob_files.push((blob, file));
        }
        Ok(())
    }

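    /// Puts a delete operation for a key into the write batch.
    ///
    /// Safe to call concurrently from multiple threads.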
    pub fn delete(&self, family: u32, key: K) -> Result<()> {
        let state = self.thread_local_state();
        let collector = self.thread_local_collector_mut(state, family)?;
        collector.delete(key);
        Ok(())
    }

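    /// Flushes all buffered entries of the given family to SST files,
    /// draining the thread-local collectors of every thread and then the
    /// global collector.
    ///
    /// # Safety
    ///
    /// This method mutably accesses the thread-local state of *all* threads,
    /// so the caller must guarantee that no other thread concurrently calls
    /// `put`, `delete`, or `flush` on this write batch.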
    #[tracing::instrument(level = "trace", skip(self))]
    pub unsafe fn flush(&self, family: u32) -> Result<()> {
        // Steal the non-empty thread-local collectors of all threads for this
        // family. This is why the method is `unsafe`: it must not run
        // concurrently with `put` or `delete`.
        let mut collectors = Vec::new();
        for cell in self.thread_locals.iter() {
            let state = unsafe { &mut *cell.get() };
            if let Some(collector) = state.collectors[usize_from_u32(family)].take()
                && !collector.is_empty()
            {
                collectors.push(collector);
            }
        }

        self.parallel_scheduler
            .try_parallel_for_each_owned(collectors, |mut collector| {
                self.flush_thread_local_collector(family, &mut collector)?;
                drop(collector);
                anyhow::Ok(())
            })?;

        // Write the remaining entries of the global collector to SST files.
        let mut collector_state = self.collectors[usize_from_u32(family)].lock();
        match &mut *collector_state {
            GlobalCollectorState::Unsharded(collector) => {
                if !collector.is_empty() {
                    let sst = self.create_sst_file(family, collector.sorted())?;
                    collector.clear();
                    self.new_sst_files.lock().push(sst);
                }
            }
            GlobalCollectorState::Sharded(_) => {
                // Take the shards out and reset the family to an empty
                // unsharded collector; the shards are flushed in parallel and
                // discarded afterwards.
                let GlobalCollectorState::Sharded(mut shards) = replace(
                    &mut *collector_state,
                    GlobalCollectorState::Unsharded(Collector::new()),
                ) else {
                    unreachable!();
                };
                self.parallel_scheduler
                    .try_parallel_for_each_mut(&mut shards, |collector| {
                        if !collector.is_empty() {
                            let sst = self.create_sst_file(family, collector.sorted())?;
                            collector.clear();
                            self.new_sst_files.lock().push(sst);
                            collector.drop_contents();
                        }
                        anyhow::Ok(())
                    })?;
            }
        }

        Ok(())
    }

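    /// Commits the write batch: flushes all thread-local and global
    /// collectors, writes the remaining entries to SST files, builds one meta
    /// file per family that received writes, and returns the new files
    /// together with the final sequence number.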
    #[tracing::instrument(level = "trace", skip(self))]
    pub(crate) fn finish(&mut self) -> Result<FinishResult> {
        let mut new_blob_files = Vec::new();

        // Flush the thread-local collectors of all threads into the global
        // collectors.
        {
            let _span = tracing::trace_span!("flush thread local collectors").entered();
            let mut collectors = [const { Vec::new() }; FAMILIES];
            for cell in self.thread_locals.iter_mut() {
                let state = cell.get_mut();
                new_blob_files.append(&mut state.new_blob_files);
                for (family, thread_local_collector) in state.collectors.iter_mut().enumerate() {
                    if let Some(collector) = thread_local_collector.take()
                        && !collector.is_empty()
                    {
                        collectors[family].push(collector);
                    }
                }
            }
            let to_flush = collectors
                .into_iter()
                .enumerate()
                .flat_map(|(family, collector)| {
                    collector
                        .into_iter()
                        .map(move |collector| (family as u32, collector))
                })
                .collect::<Vec<_>>();
            self.parallel_scheduler.try_parallel_for_each_owned(
                to_flush,
                |(family, mut collector)| {
                    self.flush_thread_local_collector(family, &mut collector)?;
                    drop(collector);
                    anyhow::Ok(())
                },
            )?;
        }

        let _span = tracing::trace_span!("flush collectors").entered();

        // Collect the SST files written so far; flushing the global
        // collectors below appends to this list.
        let mut new_sst_files = take(self.new_sst_files.get_mut());
        let shared_new_sst_files = Mutex::new(&mut new_sst_files);

        // Write the remaining entries of the global collectors to SST files.
        let new_collectors =
            [(); FAMILIES].map(|_| Mutex::new(GlobalCollectorState::Unsharded(Collector::new())));
        let collectors = replace(&mut self.collectors, new_collectors);
        let collectors = collectors
            .into_iter()
            .enumerate()
            .flat_map(|(family, state)| {
                let collector = state.into_inner();
                match collector {
                    GlobalCollectorState::Unsharded(collector) => {
                        Either::Left([(family, collector)].into_iter())
                    }
                    GlobalCollectorState::Sharded(shards) => {
                        Either::Right(shards.into_iter().map(move |collector| (family, collector)))
                    }
                }
            })
            .collect::<Vec<_>>();
        self.parallel_scheduler.try_parallel_for_each_owned(
            collectors,
            |(family, mut collector)| {
                let family = family as u32;
                if !collector.is_empty() {
                    let sst = self.create_sst_file(family, collector.sorted())?;
                    collector.clear();
                    drop(collector);
                    shared_new_sst_files.lock().push(sst);
                }
                anyhow::Ok(())
            },
        )?;

        // Build one meta file per family that received writes.
        let new_meta_collectors = [(); FAMILIES].map(|_| Mutex::new(Vec::new()));
        let meta_collectors = replace(&mut self.meta_collectors, new_meta_collectors);
        let keys_written = AtomicU64::new(0);
        let file_to_write = meta_collectors
            .into_iter()
            .map(|mutex| mutex.into_inner())
            .enumerate()
            .filter(|(_, sst_files)| !sst_files.is_empty())
            .collect::<Vec<_>>();
        let new_meta_files = self
            .parallel_scheduler
            .parallel_map_collect_owned::<_, _, Result<Vec<_>>>(
                file_to_write,
                |(family, sst_files)| {
                    let family = family as u32;
                    let mut entries = 0;
                    let mut builder = MetaFileBuilder::new(family);
                    for (seq, sst) in sst_files {
                        entries += sst.entries;
                        builder.add(seq, sst);
                    }
                    keys_written.fetch_add(entries, Ordering::Relaxed);
                    let seq = self.current_sequence_number.fetch_add(1, Ordering::SeqCst) + 1;
                    let file = builder.write(&self.db_path, seq)?;
                    Ok((seq, file))
                },
            )?;

        // The final sequence number after all files have been written.
        let seq = self.current_sequence_number.load(Ordering::SeqCst);
        Ok(FinishResult {
            sequence_number: seq,
            new_meta_files,
            new_sst_files,
            new_blob_files,
            keys_written: keys_written.into_inner(),
        })
    }

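    /// Writes an oversized value to a standalone `{seq:08}.blob` file. The
    /// file contains the uncompressed value length as a big-endian `u32`,
    /// followed by the compressed value.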
    #[tracing::instrument(level = "trace", skip(self, value), fields(value_len = value.len()))]
    fn create_blob(&self, value: &[u8]) -> Result<(u32, File)> {
        let seq = self.current_sequence_number.fetch_add(1, Ordering::SeqCst) + 1;
        let mut buffer = Vec::new();
        buffer.write_u32::<BE>(value.len() as u32)?;
        compress_into_buffer(value, None, true, &mut buffer)
            .context("Compression of value for blob file failed")?;

        let path = self.db_path.join(format!("{seq:08}.blob"));
        let mut file = File::create(&path).context("Unable to create blob file")?;
        file.write_all(&buffer)
            .context("Unable to write blob file")?;
        file.flush().context("Unable to flush blob file")?;
        Ok((seq, file))
    }

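    /// Writes the sorted entries of a collector to a new `{seq:08}.sst` file
    /// and records its metadata so `finish` can reference it from the
    /// family's meta file.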
    #[tracing::instrument(level = "trace", skip(self, collector_data))]
    fn create_sst_file(
        &self,
        family: u32,
        collector_data: (&[CollectorEntry<K>], usize),
    ) -> Result<(u32, File)> {
        let (entries, total_key_size) = collector_data;
        let seq = self.current_sequence_number.fetch_add(1, Ordering::SeqCst) + 1;

        let path = self.db_path.join(format!("{seq:08}.sst"));
        let (meta, file) = self
            .parallel_scheduler
            .block_in_place(|| write_static_stored_file(entries, total_key_size, &path))
            .with_context(|| format!("Unable to write SST file {seq:08}.sst"))?;

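        // With the `verify_sst_content` feature enabled, re-open the file
        // that was just written and check that every entry can be looked up.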
        #[cfg(feature = "verify_sst_content")]
        {
            use core::panic;

            use crate::{
                collector_entry::CollectorEntryValue,
                key::hash_key,
                lookup_entry::LookupValue,
                static_sorted_file::{
                    BlockCache, SstLookupResult, StaticSortedFile, StaticSortedFileMetaData,
                },
                static_sorted_file_builder::Entry,
            };

            file.sync_all()?;
            let sst = StaticSortedFile::open(
                &self.db_path,
                StaticSortedFileMetaData {
                    sequence_number: seq,
                    key_compression_dictionary_length: meta.key_compression_dictionary_length,
                    block_count: meta.block_count,
                },
            )?;
            let cache2 = BlockCache::with(
                10,
                u64::MAX,
                Default::default(),
                Default::default(),
                Default::default(),
            );
            let cache3 = BlockCache::with(
                10,
                u64::MAX,
                Default::default(),
                Default::default(),
                Default::default(),
            );
            let mut key_buf = Vec::new();
            for entry in entries {
                entry.write_key_to(&mut key_buf);
                let result = sst
                    .lookup(hash_key(&key_buf), &key_buf, &cache2, &cache3)
                    .expect("key found");
                key_buf.clear();
                match result {
                    SstLookupResult::Found(LookupValue::Deleted) => {}
                    SstLookupResult::Found(LookupValue::Slice {
                        value: lookup_value,
                    }) => {
                        let expected_value_slice = match &entry.value {
                            CollectorEntryValue::Small { value } => &**value,
                            CollectorEntryValue::Medium { value } => &**value,
                            _ => panic!("Unexpected value"),
                        };
                        assert_eq!(*lookup_value, *expected_value_slice);
                    }
                    SstLookupResult::Found(LookupValue::Blob { sequence_number: _ }) => {}
                    SstLookupResult::NotFound => panic!("All keys must exist"),
                }
            }
        }

        self.meta_collectors[usize_from_u32(family)]
            .lock()
            .push((seq, meta));

        Ok((seq, file))
    }
}

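/// Lossless conversion from `u32` to `usize`, checked at compile time.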
#[inline(always)]
const fn usize_from_u32(value: u32) -> usize {
    const {
        // Lossless as long as `usize` is at least as wide as `u32`.
        assert!(u32::BITS <= usize::BITS);
    };
    value as usize
}