1use std::mem::take;
2
3use crate::timestamp::Timestamp;
4
/// Number of not-yet-classified entries a node may hold directly before
/// `check_for_split` triggers a split of that node.
const SPLIT_COUNT: usize = 128;
/// Imbalance factor used by `rebalance`: a rotation is attempted when one
/// child's item count exceeds the other child's count multiplied by this
/// factor, plus the number of entries spanning the split point.
const BALANCE_THRESHOLD: usize = 3;
8
9pub struct SelfTimeTree<T> {
10 entries: Vec<SelfTimeEntry<T>>,
11 children: Option<Box<SelfTimeChildren<T>>>,
12 count: usize,
13}
14
15struct SelfTimeEntry<T> {
16 start: Timestamp,
17 end: Timestamp,
18 item: T,
19}
20
21struct SelfTimeChildren<T> {
22 left: SelfTimeTree<T>,
24 split_point: Timestamp,
25 right: SelfTimeTree<T>,
27 spanning_entries: usize,
29}
30
31impl<T> Default for SelfTimeTree<T> {
32 fn default() -> Self {
33 Self {
34 entries: Vec::new(),
35 children: None,
36 count: 0,
37 }
38 }
39}
40
41impl<T> SelfTimeTree<T> {
42 pub fn new() -> Self {
43 Self::default()
44 }
45
46 pub fn len(&self) -> usize {
47 self.count
48 }
49
50 pub fn insert(&mut self, start: Timestamp, end: Timestamp, item: T) {
51 self.count += 1;
52 self.entries.push(SelfTimeEntry { start, end, item });
53 self.check_for_split();
54 }
55
56 fn check_for_split(&mut self) {
57 if self.entries.len() >= SPLIT_COUNT {
58 let spanning_entries = if let Some(children) = &mut self.children {
59 children.spanning_entries
60 } else {
61 0
62 };
63 if self.entries.len() - spanning_entries >= SPLIT_COUNT {
64 self.split();
65 }
66 }
67 }
68
69 fn split(&mut self) {
70 debug_assert!(!self.entries.is_empty());
71 self.distribute_entries();
72 self.rebalance();
73 }
74
75 fn distribute_entries(&mut self) {
76 if self.children.is_none() {
77 let start = self.entries.iter().min_by_key(|e| e.start).unwrap().start;
78 let end = self.entries.iter().max_by_key(|e| e.end).unwrap().end;
79 let middle = (start + end) / 2;
80 self.children = Some(Box::new(SelfTimeChildren {
81 left: SelfTimeTree::new(),
82 split_point: middle,
83 right: SelfTimeTree::new(),
84 spanning_entries: 0,
85 }));
86 }
87 let Some(children) = &mut self.children else {
88 unreachable!();
89 };
90 let mut i = children.spanning_entries;
91 while i < self.entries.len() {
92 let SelfTimeEntry { start, end, .. } = self.entries[i];
93 if end <= children.split_point {
94 let SelfTimeEntry { start, end, item } = self.entries.swap_remove(i);
95 children.left.insert(start, end, item);
96 } else if start >= children.split_point {
97 let SelfTimeEntry { start, end, item } = self.entries.swap_remove(i);
98 children.right.insert(start, end, item);
99 } else {
100 self.entries.swap(i, children.spanning_entries);
101 children.spanning_entries += 1;
102 i += 1;
103 }
104 }
105 }
106
107 fn rebalance(&mut self) {
108 if let Some(box SelfTimeChildren {
109 left,
110 split_point,
111 right,
112 spanning_entries,
113 }) = &mut self.children
114 {
115 let SelfTimeTree {
116 count: left_count,
117 children: left_children,
118 entries: left_entries,
119 } = left;
120 let SelfTimeTree {
121 count: right_count,
122 children: right_children,
123 entries: right_entries,
124 } = right;
125 if *left_count > *right_count * BALANCE_THRESHOLD + *spanning_entries {
126 if let Some(box SelfTimeChildren {
133 left: left_left,
134 split_point: left_split_point,
135 right: left_right,
136 spanning_entries: _,
137 }) = left_children
138 {
139 *right = Self {
140 count: left_right.count + right.count,
141 entries: Vec::new(),
142 children: Some(Box::new(SelfTimeChildren {
143 left: take(left_right),
144 split_point: *split_point,
145 right: take(right),
146 spanning_entries: 0,
147 })),
148 };
149 *split_point = *left_split_point;
150 self.entries.append(left_entries);
151 *left = take(left_left);
152 *spanning_entries = 0;
153 self.distribute_entries();
154 }
155 } else if *right_count > *left_count * BALANCE_THRESHOLD + *spanning_entries {
156 if let Some(box SelfTimeChildren {
163 left: right_left,
164 split_point: right_split_point,
165 right: right_right,
166 spanning_entries: _,
167 }) = right_children
168 {
169 *left = Self {
170 count: left.count + right_left.count,
171 entries: Vec::new(),
172 children: Some(Box::new(SelfTimeChildren {
173 left: take(left),
174 split_point: *split_point,
175 right: take(right_left),
176 spanning_entries: 0,
177 })),
178 };
179 *split_point = *right_split_point;
180 self.entries.append(right_entries);
181 *right = take(right_right);
182 *spanning_entries = 0;
183 self.check_for_split();
184 }
185 }
186 }
187 }
188
/// Test-only helper: sums, over all stored intervals overlapping
/// `start..end`, the length of the overlap.
///
/// Despite the name, the result is an accumulated duration rather than a
/// number of entries.
#[cfg(test)]
pub fn lookup_range_count(&self, start: Timestamp, end: Timestamp) -> Timestamp {
    let mut covered = Timestamp::ZERO;
    let overlapping = self
        .entries
        .iter()
        .filter(|entry| entry.start < end && entry.end > start);
    for entry in overlapping {
        // Clip the entry to the queried range before measuring it.
        let clipped_start = std::cmp::max(entry.start, start);
        let clipped_end = std::cmp::min(entry.end, end);
        covered += clipped_end - clipped_start;
    }
    if let Some(children) = self.children.as_deref() {
        // Only descend into a subtree that can contain overlapping entries.
        if start < children.split_point {
            covered += children.left.lookup_range_count(start, end);
        }
        if end > children.split_point {
            covered += children.right.lookup_range_count(start, end);
        }
    }
    covered
}
210
211 pub fn lookup_range_corrected_time(&self, start: Timestamp, end: Timestamp) -> Timestamp {
212 let mut factor_times_1000 = 0u64;
213 #[derive(PartialEq, Eq, PartialOrd, Ord)]
214 enum Change {
215 Start,
216 End,
217 }
218 let mut current_count = 0;
219 let mut changes = Vec::new();
220 self.for_each_in_range(start, end, |s, e, _| {
221 if s <= start {
222 current_count += 1;
223 } else {
224 changes.push((s, Change::Start));
225 }
226 if e < end {
227 changes.push((e, Change::End));
228 }
229 });
230 changes.sort_unstable();
231 let mut current_ts = start;
232 for (ts, change) in changes {
233 if current_ts < ts {
234 let time_diff = ts - current_ts;
236 factor_times_1000 += *time_diff * 1000 / current_count;
237 current_ts = ts;
238 }
239 match change {
240 Change::Start => current_count += 1,
241 Change::End => current_count -= 1,
242 }
243 }
244 if current_ts < end {
245 let time_diff = end - current_ts;
246 factor_times_1000 += *time_diff * 1000 / current_count;
247 }
248 Timestamp::from_value(factor_times_1000 / 1000)
249 }
250
251 pub fn for_each_in_range(
252 &self,
253 start: Timestamp,
254 end: Timestamp,
255 mut f: impl FnMut(Timestamp, Timestamp, &T),
256 ) {
257 self.for_each_in_range_ref(start, end, &mut f);
258 }
259
260 fn for_each_in_range_ref(
261 &self,
262 start: Timestamp,
263 end: Timestamp,
264 f: &mut impl FnMut(Timestamp, Timestamp, &T),
265 ) {
266 for entry in &self.entries {
267 if entry.start < end && entry.end > start {
268 f(entry.start, entry.end, &entry.item);
269 }
270 }
271 if let Some(children) = &self.children {
272 if start < children.split_point {
273 children.left.for_each_in_range_ref(start, end, f);
274 }
275 if end > children.split_point {
276 children.right.for_each_in_range_ref(start, end, f);
277 }
278 }
279 }
280}
281
282#[cfg(test)]
283mod tests {
284 use super::*;
285
286 fn print_tree<T>(tree: &SelfTimeTree<T>, indent: usize) {
287 if let Some(children) = &tree.children {
288 println!(
289 "{}{} items (split at {}, {} overlapping, {} total)",
290 " ".repeat(indent),
291 tree.entries.len(),
292 children.split_point,
293 children.spanning_entries,
294 tree.count
295 );
296 print_tree(&children.left, indent + 2);
297 print_tree(&children.right, indent + 2);
298 } else {
299 println!(
300 "{}{} items ({} total)",
301 " ".repeat(indent),
302 tree.entries.len(),
303 tree.count
304 );
305 }
306 }
307
308 fn assert_balanced<T>(tree: &SelfTimeTree<T>) {
309 if let Some(children) = &tree.children {
310 let l = children.left.count;
311 let r = children.right.count;
312 let s = children.spanning_entries;
313 if (l > SPLIT_COUNT || r > SPLIT_COUNT)
314 && ((l > r * BALANCE_THRESHOLD + s) || (r > l * BALANCE_THRESHOLD + s))
315 {
316 print_tree(tree, 0);
317 panic!("Tree is not balanced");
318 }
319 assert_balanced(&children.left);
320 assert_balanced(&children.right);
321 }
322 }
323
/// Inserts adjacent, non-overlapping one-microsecond intervals in ascending
/// order and checks count, balance and total covered time.
#[test]
fn test_simple() {
    let count = 10000;
    let mut tree = SelfTimeTree::new();
    for i in 0..count {
        let from = Timestamp::from_micros(i);
        let to = Timestamp::from_micros(i + 1);
        tree.insert(from, to, i);
        assert_eq!(tree.count, (i + 1) as usize);
        assert_balanced(&tree);
    }
    let covered = tree.lookup_range_count(Timestamp::ZERO, Timestamp::from_micros(count));
    assert_eq!(covered, Timestamp::from_micros(count));
    print_tree(&tree, 0);
    assert_balanced(&tree);
}
340
/// Inserts the same one-microsecond intervals as `test_simple`, but in an
/// interleaved order: the decimal digits of the running index are reversed to
/// obtain the interval position.
#[test]
fn test_evenly() {
    let count = 10000;
    let mut tree = SelfTimeTree::new();
    for n in 0..count {
        // Digits of `n`, most significant first.
        let (a, b, c, d) = (n / 1000, n / 100 % 10, n / 10 % 10, n % 10);
        let i = d * 1000 + c * 100 + b * 10 + a;
        tree.insert(Timestamp::from_micros(i), Timestamp::from_micros(i + 1), i);
        assert_balanced(&tree);
    }
    let covered = tree.lookup_range_count(Timestamp::ZERO, Timestamp::from_micros(count));
    assert_eq!(covered, Timestamp::from_micros(count));
    print_tree(&tree, 0);
    assert_balanced(&tree);
}
363
/// Inserts intervals of 100 microseconds whose starts are one microsecond
/// apart, so that neighbouring intervals overlap.
#[test]
fn test_overlapping() {
    let count = 10000;
    let mut tree = SelfTimeTree::new();
    for i in 0..count {
        let from = Timestamp::from_micros(i);
        let to = Timestamp::from_micros(i + 100);
        tree.insert(from, to, i);
        assert_eq!(tree.count, (i + 1) as usize);
        assert_balanced(&tree);
    }
    let covered = tree.lookup_range_count(Timestamp::ZERO, Timestamp::from_micros(count + 100));
    assert_eq!(covered, Timestamp::from_micros(count * 100));
    print_tree(&tree, 0);
    assert_balanced(&tree);
}
384
/// Like `test_overlapping`, but with 500-microsecond intervals. Balance is
/// only checked once, after all insertions.
#[test]
fn test_overlapping_heavy() {
    let count = 10000;
    let mut tree = SelfTimeTree::new();
    for i in 0..count {
        let from = Timestamp::from_micros(i);
        let to = Timestamp::from_micros(i + 500);
        tree.insert(from, to, i);
        assert_eq!(tree.count, (i + 1) as usize);
    }
    let covered = tree.lookup_range_count(Timestamp::ZERO, Timestamp::from_micros(count + 500));
    assert_eq!(covered, Timestamp::from_micros(count * 500));
    print_tree(&tree, 0);
    assert_balanced(&tree);
}
404}