// turbo_tasks/util.rs

use std::{
    cell::SyncUnsafeCell,
    error::Error as StdError,
    fmt::{Debug, Display},
    future::Future,
    hash::{Hash, Hasher},
    mem::ManuallyDrop,
    ops::Deref,
    pin::Pin,
    sync::{Arc, LazyLock},
    task::{Context, Poll},
    thread::available_parallelism,
    time::Duration,
};

use anyhow::{Error, anyhow};
use bincode::{
    Decode, Encode,
    de::Decoder,
    enc::Encoder,
    error::{DecodeError, EncodeError},
};
use pin_project_lite::pin_project;

pub use crate::{
    id_factory::{IdFactory, IdFactoryWithReuse},
    once_map::*,
};
/// An error struct that is backed by an [Arc] to allow cloning errors.
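///
/// # Example
///
/// A minimal sketch of the intended use (messages are illustrative):
///
/// ```ignore
/// use anyhow::anyhow;
///
/// // `anyhow::Error` is not `Clone`, but `SharedError` is, so the same
/// // error can be handed to multiple consumers.
/// let err = SharedError::new(anyhow!("failed to load config"));
/// let for_logging = err.clone();
/// assert_eq!(err.to_string(), for_logging.to_string());
/// ```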
#[derive(Debug, Clone)]
pub struct SharedError {
    inner: Arc<anyhow::Error>,
}

impl SharedError {
    pub fn new(err: anyhow::Error) -> Self {
        Self {
            inner: Arc::new(err),
        }
    }
}

impl AsRef<dyn StdError> for SharedError {
    fn as_ref(&self) -> &(dyn StdError + 'static) {
        let err: &anyhow::Error = &self.inner;
        err.as_ref()
    }
}

impl StdError for SharedError {
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        self.inner.source()
    }

    fn provide<'a>(&'a self, req: &mut std::error::Request<'a>) {
        self.inner.provide(req);
    }
}

impl Display for SharedError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        Display::fmt(&*self.inner, f)
    }
}

impl From<Error> for SharedError {
    fn from(e: Error) -> Self {
        Self::new(e)
    }
}

impl PartialEq for SharedError {
    fn eq(&self, other: &Self) -> bool {
        Arc::ptr_eq(&self.inner, &other.inner)
    }
}

impl Eq for SharedError {}

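// Note: the serialization below is lossy by design. Only the chain of error
// messages (`Display` output of each `source()`) is preserved; concrete error
// types, backtraces, and any `provide`d data are dropped.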
impl Encode for SharedError {
    fn encode<E: Encoder>(&self, encoder: &mut E) -> Result<(), EncodeError> {
        let mut v = vec![self.to_string()];
        let mut source = self.source();
        while let Some(s) = source {
            v.push(s.to_string());
            source = s.source();
        }
        Encode::encode(&v, encoder)
    }
}

impl<Context> Decode<Context> for SharedError {
    fn decode<D: Decoder<Context = Context>>(decoder: &mut D) -> Result<Self, DecodeError> {
        let mut messages = <Vec<String>>::decode(decoder)?;
        let msg = messages
            .pop()
            .ok_or(DecodeError::Other("expected at least 1 error message"))?;
        let mut e = anyhow!(msg);
        while let Some(message) = messages.pop() {
            e = e.context(message);
        }
        Ok(SharedError::new(e))
    }
}

bincode::impl_borrow_decode!(SharedError);

impl Deref for SharedError {
    type Target = Arc<Error>;
    fn deref(&self) -> &Self::Target {
        &self.inner
    }
}

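/// Wrapper that [Display]s a [Duration] with a unit picked by magnitude:
/// whole seconds above 10s, otherwise milliseconds, with sub-10ms values
/// shown as fractional milliseconds. For example:
///
/// ```ignore
/// assert_eq!(FormatDuration(Duration::from_secs(42)).to_string(), "42s");
/// assert_eq!(FormatDuration(Duration::from_millis(42)).to_string(), "42ms");
/// assert_eq!(FormatDuration(Duration::from_micros(500)).to_string(), "0.5ms");
/// ```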
pub struct FormatDuration(pub Duration);

impl Display for FormatDuration {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let s = self.0.as_secs();
        if s > 10 {
            return write!(f, "{s}s");
        }
        let ms = self.0.as_millis();
        if ms > 10 {
            return write!(f, "{ms}ms");
        }
        write!(f, "{}ms", (self.0.as_micros() as f32) / 1000.0)
    }
}

impl Debug for FormatDuration {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let s = self.0.as_secs();
        if s > 100 {
            return write!(f, "{s}s");
        }
        let ms = self.0.as_millis();
        if ms > 10000 {
            return write!(f, "{:.2}s", (ms as f32) / 1000.0);
        }
        if ms > 100 {
            return write!(f, "{ms}ms");
        }
        write!(f, "{}ms", (self.0.as_micros() as f32) / 1000.0)
    }
}

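/// Wrapper that [Display]s a byte count in binary units (B, KiB, MiB, GiB),
/// with two decimal places above the KiB threshold. For example:
///
/// ```ignore
/// assert_eq!(FormatBytes(1000).to_string(), "1000B");
/// assert_eq!(FormatBytes(1536).to_string(), "1.50KiB");
/// assert_eq!(FormatBytes(10 * 1024 * 1024).to_string(), "10.00MiB");
/// ```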
pub struct FormatBytes(pub usize);

impl Display for FormatBytes {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let b = self.0;
        const KB: usize = 1_024;
        const MB: usize = 1_024 * KB;
        const GB: usize = 1_024 * MB;
        if b > GB {
            return write!(f, "{:.2}GiB", ((b / MB) as f32) / 1_024.0);
        }
        if b > MB {
            return write!(f, "{:.2}MiB", ((b / KB) as f32) / 1_024.0);
        }
        if b > KB {
            return write!(f, "{:.2}KiB", (b as f32) / 1_024.0);
        }
        write!(f, "{b}B")
    }
}

/// Smart pointer that stores data either in an [Arc] or as a static reference.
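///
/// A minimal sketch of both variants (values are illustrative):
///
/// ```ignore
/// static GREETING: &str = "hello";
///
/// let a: StaticOrArc<str> = StaticOrArc::Static(GREETING);
/// let b: StaticOrArc<str> = StaticOrArc::Shared(Arc::from("hello"));
/// // Equality compares the pointed-to values, not the pointers.
/// assert_eq!(a, b);
/// ```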
pub enum StaticOrArc<T: ?Sized + 'static> {
    Static(&'static T),
    Shared(Arc<T>),
}

impl<T: ?Sized + 'static> AsRef<T> for StaticOrArc<T> {
    fn as_ref(&self) -> &T {
        match self {
            Self::Static(s) => s,
            Self::Shared(b) => b,
        }
    }
}

impl<T: ?Sized + 'static> From<&'static T> for StaticOrArc<T> {
    fn from(s: &'static T) -> Self {
        Self::Static(s)
    }
}

impl<T: ?Sized + 'static> From<Arc<T>> for StaticOrArc<T> {
    fn from(b: Arc<T>) -> Self {
        Self::Shared(b)
    }
}

impl<T: 'static> From<T> for StaticOrArc<T> {
    fn from(b: T) -> Self {
        Self::Shared(Arc::new(b))
    }
}

impl<T: ?Sized + 'static> Deref for StaticOrArc<T> {
    type Target = T;

    fn deref(&self) -> &Self::Target {
        self.as_ref()
    }
}

impl<T: ?Sized + 'static> Clone for StaticOrArc<T> {
    fn clone(&self) -> Self {
        match self {
            Self::Static(s) => Self::Static(s),
            Self::Shared(b) => Self::Shared(b.clone()),
        }
    }
}

impl<T: ?Sized + PartialEq + 'static> PartialEq for StaticOrArc<T> {
    fn eq(&self, other: &Self) -> bool {
        **self == **other
    }
}

impl<T: ?Sized + PartialEq + Eq + 'static> Eq for StaticOrArc<T> {}

impl<T: ?Sized + Hash + 'static> Hash for StaticOrArc<T> {
    fn hash<H: Hasher>(&self, state: &mut H) {
        (**self).hash(state);
    }
}

impl<T: ?Sized + Display + 'static> Display for StaticOrArc<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        (**self).fmt(f)
    }
}

impl<T: ?Sized + Debug + 'static> Debug for StaticOrArc<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        (**self).fmt(f)
    }
}

pin_project! {
    /// A future that wraps another future and applies a function on every poll call.
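    ///
    /// A minimal sketch of a wrapper that just forwards the poll
    /// (`some_future` is a stand-in for any future):
    ///
    /// ```ignore
    /// let wrapped = WrapFuture::new(some_future, |fut, cx| {
    ///     // Hook point: e.g. record timing around the inner poll.
    ///     fut.poll(cx)
    /// });
    /// ```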
    pub struct WrapFuture<F, W> {
        wrapper: W,
        #[pin]
        future: F,
    }
}

impl<F: Future, W: for<'a> Fn(Pin<&mut F>, &mut Context<'a>) -> Poll<F::Output>> WrapFuture<F, W> {
    pub fn new(future: F, wrapper: W) -> Self {
        Self { wrapper, future }
    }
}

impl<F: Future, W: for<'a> Fn(Pin<&mut F>, &mut Context<'a>) -> Poll<F::Output>> Future
    for WrapFuture<F, W>
{
    type Output = F::Output;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.project();
        (this.wrapper)(this.future, cx)
    }
}

/// Calculates a good chunk size for parallel processing based on the number of available threads.
/// This is used to ensure that the workload is evenly distributed across the threads.
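///
/// The chunk count targets four chunks per available thread (falling back to 16 when
/// parallelism cannot be determined). For example, on a hypothetical 8-thread machine the
/// target is 32 chunks, so `good_chunk_size(1000)` returns `1000.div_ceil(32)` = 32.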
pub fn good_chunk_size(len: usize) -> usize {
    static GOOD_CHUNK_COUNT: LazyLock<usize> =
        LazyLock::new(|| available_parallelism().map_or(16, |c| c.get() * 4));
    let min_chunk_count = *GOOD_CHUNK_COUNT;
    len.div_ceil(min_chunk_count)
}

/// Similar to [slice::chunks] but for owned data. Chunks are `Send` and `Sync`, so they can be
/// handed to other threads for parallel processing.
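///
/// A minimal usage sketch: splitting a `Vec` into owned chunks and flattening
/// them back preserves all items.
///
/// ```ignore
/// let chunks: Vec<Chunk<u32>> = into_chunks(vec![1, 2, 3, 4, 5], 2).collect();
/// assert_eq!(chunks.len(), 3); // chunk sizes are 2, 2, 1
/// let flat: Vec<u32> = chunks.into_iter().flatten().collect();
/// assert_eq!(flat, vec![1, 2, 3, 4, 5]);
/// ```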
pub fn into_chunks<T>(data: Vec<T>, chunk_size: usize) -> IntoChunks<T> {
    let (ptr, length, capacity) = data.into_raw_parts();
    // SAFETY: changing a pointer from T to SyncUnsafeCell<ManuallyDrop<..>> is safe as both types
    // have repr(transparent).
    let ptr = ptr as *mut SyncUnsafeCell<ManuallyDrop<T>>;
    // SAFETY: The ptr, length and capacity were from into_raw_parts(). This is the only place where
    // we use ptr.
    let data =
        unsafe { Vec::<SyncUnsafeCell<ManuallyDrop<T>>>::from_raw_parts(ptr, length, capacity) };
    IntoChunks {
        data: Arc::new(data),
        index: 0,
        chunk_size,
    }
}

pub struct IntoChunks<T> {
    data: Arc<Vec<SyncUnsafeCell<ManuallyDrop<T>>>>,
    index: usize,
    chunk_size: usize,
}

impl<T> Iterator for IntoChunks<T> {
    type Item = Chunk<T>;

    fn next(&mut self) -> Option<Self::Item> {
        if self.index < self.data.len() {
            let end = self.data.len().min(self.index + self.chunk_size);
            let item = Chunk {
                data: Arc::clone(&self.data),
                index: self.index,
                end,
            };
            self.index = end;
            Some(item)
        } else {
            None
        }
    }
}

impl<T> IntoChunks<T> {
    fn next_item(&mut self) -> Option<T> {
        if self.index < self.data.len() {
            // SAFETY: Items at `index` and beyond were never handed out to a `Chunk`, so we are
            // their only owner. Advancing the index ensures this item is not dropped a second time.
            let item = unsafe { ManuallyDrop::take(&mut *self.data[self.index].get()) };
            self.index += 1;
            Some(item)
        } else {
            None
        }
    }
}

impl<T> Drop for IntoChunks<T> {
    fn drop(&mut self) {
        // To avoid leaking memory we need to drop the remaining items
        while self.next_item().is_some() {}
    }
}

pub struct Chunk<T> {
    data: Arc<Vec<SyncUnsafeCell<ManuallyDrop<T>>>>,
    index: usize,
    end: usize,
}

impl<T> Iterator for Chunk<T> {
    type Item = T;

    fn next(&mut self) -> Option<Self::Item> {
        if self.index < self.end {
            // SAFETY: Only this `Chunk` can access indices in `index..end`, so we are the only
            // owner of this item. Advancing the index ensures it is not dropped a second time.
            let item = unsafe { ManuallyDrop::take(&mut *self.data[self.index].get()) };
            self.index += 1;
            Some(item)
        } else {
            None
        }
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        let len = self.len();
        (len, Some(len))
    }
}

impl<T> ExactSizeIterator for Chunk<T> {
    fn len(&self) -> usize {
        self.end - self.index
    }
}

impl<T> Drop for Chunk<T> {
    fn drop(&mut self) {
        // To avoid leaking memory we need to drop the remaining items
        while self.next().is_some() {}
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_chunk_iterator() {
        let data = [(); 10]
            .into_iter()
            .enumerate()
            .map(|(i, _)| Arc::new(i))
            .collect::<Vec<_>>();
        let mut chunks = into_chunks(data.clone(), 3);
        let mut first_chunk = chunks.next().unwrap();
        let second_chunk = chunks.next().unwrap();
        drop(chunks);
        assert_eq!(
            second_chunk.into_iter().map(|a| *a).collect::<Vec<_>>(),
            vec![3, 4, 5]
        );
        assert_eq!(*first_chunk.next().unwrap(), 0);
        assert_eq!(*first_chunk.next().unwrap(), 1);
        drop(first_chunk);
        for arc in data {
            assert_eq!(Arc::strong_count(&arc), 1);
        }
    }
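
    #[test]
    fn test_shared_error_roundtrip() {
        // A sketch of the lossy encode/decode roundtrip for `SharedError`
        // (assumes bincode's standard config; messages are illustrative).
        let err = SharedError::new(anyhow::anyhow!("root cause").context("outer context"));
        let config = bincode::config::standard();
        let bytes = bincode::encode_to_vec(&err, config).unwrap();
        let (decoded, _): (SharedError, usize) =
            bincode::decode_from_slice(&bytes, config).unwrap();
        // Only the message chain survives, so the decoded error displays the
        // same outermost context message as the original.
        assert_eq!(decoded.to_string(), err.to_string());
    }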
}