// turbo_tasks/util.rs

use std::{
    cell::SyncUnsafeCell,
    error::Error as StdError,
    fmt::{Debug, Display},
    future::Future,
    hash::{Hash, Hasher},
    mem::ManuallyDrop,
    ops::Deref,
    pin::Pin,
    sync::{Arc, LazyLock},
    task::{Context, Poll},
    thread::available_parallelism,
    time::Duration,
};

use anyhow::{Error, anyhow};
use pin_project_lite::pin_project;
use serde::{Deserialize, Deserializer, Serialize, Serializer};

pub use super::{
    id_factory::{IdFactory, IdFactoryWithReuse},
    once_map::*,
};

/// An error struct that is backed by an [Arc] to allow cloning errors
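///
/// A minimal usage sketch (marked `ignore` because the import path is elided here):
///
/// ```ignore
/// use anyhow::anyhow;
///
/// let err = SharedError::new(anyhow!("root cause").context("while doing something"));
/// // Cloning is cheap: both handles share the same underlying anyhow::Error.
/// let cloned = err.clone();
/// assert_eq!(err, cloned); // PartialEq compares Arc identity, not message text.
/// assert_eq!(err.to_string(), "while doing something");
/// ```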
#[derive(Debug, Clone)]
pub struct SharedError {
    inner: Arc<anyhow::Error>,
}

impl SharedError {
    pub fn new(err: anyhow::Error) -> Self {
        Self {
            inner: Arc::new(err),
        }
    }
}

impl AsRef<dyn StdError> for SharedError {
    fn as_ref(&self) -> &(dyn StdError + 'static) {
        let err: &anyhow::Error = &self.inner;
        err.as_ref()
    }
}

impl StdError for SharedError {
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        self.inner.source()
    }

    fn provide<'a>(&'a self, req: &mut std::error::Request<'a>) {
        self.inner.provide(req);
    }
}

impl Display for SharedError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        Display::fmt(&*self.inner, f)
    }
}

impl From<Error> for SharedError {
    fn from(e: Error) -> Self {
        Self::new(e)
    }
}

impl PartialEq for SharedError {
    fn eq(&self, other: &Self) -> bool {
        Arc::ptr_eq(&self.inner, &other.inner)
    }
}

impl Eq for SharedError {}

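// Serialization flattens the error and its source chain into a list of messages (outermost
// message first); deserialization below rebuilds an equivalent chain by re-applying each
// message as context. Only the messages survive a round trip, not backtraces or the
// concrete error types.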
impl Serialize for SharedError {
    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
        let mut v = vec![self.to_string()];
        let mut source = self.source();
        while let Some(s) = source {
            v.push(s.to_string());
            source = s.source();
        }
        Serialize::serialize(&v, serializer)
    }
}

impl<'de> Deserialize<'de> for SharedError {
    fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
        use serde::de::Error;
        let mut messages = <Vec<String>>::deserialize(deserializer)?;
        let mut e = match messages.pop() {
            Some(e) => anyhow!(e),
            None => return Err(Error::custom("expected at least 1 error message")),
        };
        while let Some(message) = messages.pop() {
            e = e.context(message);
        }
        Ok(SharedError::new(e))
    }
}

impl Deref for SharedError {
    type Target = Arc<Error>;
    fn deref(&self) -> &Self::Target {
        &self.inner
    }
}

/// Wraps a [Duration] to render it in a compact, human-readable form via [Display] and
/// [Debug].
pub struct FormatDuration(pub Duration);

impl Display for FormatDuration {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let s = self.0.as_secs();
        if s > 10 {
            return write!(f, "{s}s");
        }
        let ms = self.0.as_millis();
        if ms > 10 {
            return write!(f, "{ms}ms");
        }
        write!(f, "{}ms", (self.0.as_micros() as f32) / 1000.0)
    }
}

impl Debug for FormatDuration {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let s = self.0.as_secs();
        if s > 100 {
            return write!(f, "{s}s");
        }
        let ms = self.0.as_millis();
        if ms > 10000 {
            return write!(f, "{:.2}s", (ms as f32) / 1000.0);
        }
        if ms > 100 {
            return write!(f, "{ms}ms");
        }
        write!(f, "{}ms", (self.0.as_micros() as f32) / 1000.0)
    }
}

/// Wraps a byte count to render it with binary units (B, KiB, MiB, GiB) via [Display].
pub struct FormatBytes(pub usize);

impl Display for FormatBytes {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let b = self.0;
        const KB: usize = 1_024;
        const MB: usize = 1_024 * KB;
        const GB: usize = 1_024 * MB;
        if b > GB {
            return write!(f, "{:.2}GiB", ((b / MB) as f32) / 1_024.0);
        }
        if b > MB {
            return write!(f, "{:.2}MiB", ((b / KB) as f32) / 1_024.0);
        }
        if b > KB {
            return write!(f, "{:.2}KiB", (b as f32) / 1_024.0);
        }
        write!(f, "{b}B")
    }
}

/// Smart pointer that stores data either in an [Arc] or as a static reference.
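///
/// A minimal sketch of both variants (marked `ignore` because the import path is elided
/// here):
///
/// ```ignore
/// // Static data is borrowed as-is...
/// let s: StaticOrArc<str> = "hello".into();
/// // ...while owned data is wrapped in a new Arc.
/// let n: StaticOrArc<u32> = 42.into();
/// assert_eq!(&*s, "hello");
/// assert_eq!(*n, 42);
/// ```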
pub enum StaticOrArc<T: ?Sized + 'static> {
    Static(&'static T),
    Shared(Arc<T>),
}

impl<T: ?Sized + 'static> AsRef<T> for StaticOrArc<T> {
    fn as_ref(&self) -> &T {
        match self {
            Self::Static(s) => s,
            Self::Shared(b) => b,
        }
    }
}

impl<T: ?Sized + 'static> From<&'static T> for StaticOrArc<T> {
    fn from(s: &'static T) -> Self {
        Self::Static(s)
    }
}

impl<T: ?Sized + 'static> From<Arc<T>> for StaticOrArc<T> {
    fn from(b: Arc<T>) -> Self {
        Self::Shared(b)
    }
}

impl<T: 'static> From<T> for StaticOrArc<T> {
    fn from(b: T) -> Self {
        Self::Shared(Arc::new(b))
    }
}

impl<T: ?Sized + 'static> Deref for StaticOrArc<T> {
    type Target = T;

    fn deref(&self) -> &Self::Target {
        self.as_ref()
    }
}

impl<T: ?Sized + 'static> Clone for StaticOrArc<T> {
    fn clone(&self) -> Self {
        match self {
            Self::Static(s) => Self::Static(s),
            Self::Shared(b) => Self::Shared(b.clone()),
        }
    }
}

impl<T: ?Sized + PartialEq + 'static> PartialEq for StaticOrArc<T> {
    fn eq(&self, other: &Self) -> bool {
        **self == **other
    }
}

impl<T: ?Sized + PartialEq + Eq + 'static> Eq for StaticOrArc<T> {}

impl<T: ?Sized + Hash + 'static> Hash for StaticOrArc<T> {
    fn hash<H: Hasher>(&self, state: &mut H) {
        (**self).hash(state);
    }
}

impl<T: ?Sized + Display + 'static> Display for StaticOrArc<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        (**self).fmt(f)
    }
}

impl<T: ?Sized + Debug + 'static> Debug for StaticOrArc<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        (**self).fmt(f)
    }
}

pin_project! {
    /// A future that wraps another future and applies a function on every poll call.
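    ///
    /// A sketch of counting how often the inner future gets polled (marked `ignore`; it
    /// assumes the `futures` crate is available for `block_on`):
    ///
    /// ```ignore
    /// use std::future::Future;
    /// use std::pin::Pin;
    /// use std::sync::atomic::{AtomicUsize, Ordering};
    /// use std::task::Context;
    ///
    /// let polls = AtomicUsize::new(0);
    /// let fut = WrapFuture::new(async { 42 }, |f: Pin<&mut _>, cx: &mut Context<'_>| {
    ///     // Count the poll, then delegate to the wrapped future.
    ///     polls.fetch_add(1, Ordering::Relaxed);
    ///     f.poll(cx)
    /// });
    /// assert_eq!(futures::executor::block_on(fut), 42);
    /// assert_eq!(polls.load(Ordering::Relaxed), 1);
    /// ```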
    pub struct WrapFuture<F, W> {
        wrapper: W,
        #[pin]
        future: F,
    }
}

impl<F: Future, W: for<'a> Fn(Pin<&mut F>, &mut Context<'a>) -> Poll<F::Output>> WrapFuture<F, W> {
    pub fn new(future: F, wrapper: W) -> Self {
        Self { wrapper, future }
    }
}

impl<F: Future, W: for<'a> Fn(Pin<&mut F>, &mut Context<'a>) -> Poll<F::Output>> Future
    for WrapFuture<F, W>
{
    type Output = F::Output;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.project();
        (this.wrapper)(this.future, cx)
    }
}

/// Calculates a good chunk size for parallel processing based on the number of available
/// threads, targeting roughly four chunks per available thread so that the workload is
/// evenly distributed across the threads.
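///
/// For example (the result is machine-dependent): with 8 available threads the target is
/// `8 * 4 = 32` chunks, so a 1000-element workload gets a chunk size of
/// `1000usize.div_ceil(32) == 32`.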
pub fn good_chunk_size(len: usize) -> usize {
    static GOOD_CHUNK_COUNT: LazyLock<usize> =
        LazyLock::new(|| available_parallelism().map_or(16, |c| c.get() * 4));
    let good_chunk_count = *GOOD_CHUNK_COUNT;
    len.div_ceil(good_chunk_count)
}

/// Similar to slice::chunks but for owned data. Chunks are Send and Sync, so they can be
/// moved across threads for parallel processing.
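///
/// A sketch of fanning chunks out to scoped threads (the data and chunk size are arbitrary;
/// marked `ignore` because the import path is elided):
///
/// ```ignore
/// let data = vec![1, 2, 3, 4, 5, 6];
/// std::thread::scope(|s| {
///     for chunk in into_chunks(data, 2) {
///         // Each Chunk owns its items and can be moved to another thread.
///         s.spawn(move || {
///             for item in chunk {
///                 println!("{item}");
///             }
///         });
///     }
/// });
/// ```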
pub fn into_chunks<T>(data: Vec<T>, chunk_size: usize) -> IntoChunks<T> {
    let (ptr, length, capacity) = data.into_raw_parts();
    // SAFETY: changing a pointer from T to SyncUnsafeCell<ManuallyDrop<..>> is safe as both types
    // have repr(transparent).
    let ptr = ptr as *mut SyncUnsafeCell<ManuallyDrop<T>>;
    // SAFETY: The ptr, length and capacity were from into_raw_parts(). This is the only place where
    // we use ptr.
    let data =
        unsafe { Vec::<SyncUnsafeCell<ManuallyDrop<T>>>::from_raw_parts(ptr, length, capacity) };
    IntoChunks {
        data: Arc::new(data),
        index: 0,
        chunk_size,
    }
}

pub struct IntoChunks<T> {
    data: Arc<Vec<SyncUnsafeCell<ManuallyDrop<T>>>>,
    index: usize,
    chunk_size: usize,
}

impl<T> Iterator for IntoChunks<T> {
    type Item = Chunk<T>;

    fn next(&mut self) -> Option<Self::Item> {
        if self.index < self.data.len() {
            let end = self.data.len().min(self.index + self.chunk_size);
            let item = Chunk {
                data: Arc::clone(&self.data),
                index: self.index,
                end,
            };
            self.index = end;
            Some(item)
        } else {
            None
        }
    }
}

impl<T> IntoChunks<T> {
    fn next_item(&mut self) -> Option<T> {
        if self.index < self.data.len() {
            // SAFETY: We have exclusive access to this slot, and advancing the index ensures
            // the item is never taken or dropped twice
            let item = unsafe { ManuallyDrop::take(&mut *self.data[self.index].get()) };
            self.index += 1;
            Some(item)
        } else {
            None
        }
    }
}

impl<T> Drop for IntoChunks<T> {
    fn drop(&mut self) {
        // To avoid leaking memory we need to drop the remaining items
        while self.next_item().is_some() {}
    }
}

pub struct Chunk<T> {
    data: Arc<Vec<SyncUnsafeCell<ManuallyDrop<T>>>>,
    index: usize,
    end: usize,
}

impl<T> Iterator for Chunk<T> {
    type Item = T;

    fn next(&mut self) -> Option<Self::Item> {
        if self.index < self.end {
            // SAFETY: This chunk has exclusive access to its range of slots, and advancing
            // the index ensures the item is never taken or dropped twice
            let item = unsafe { ManuallyDrop::take(&mut *self.data[self.index].get()) };
            self.index += 1;
            Some(item)
        } else {
            None
        }
    }
}

impl<T> Drop for Chunk<T> {
    fn drop(&mut self) {
        // To avoid leaking memory we need to drop the remaining items
        while self.next().is_some() {}
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_chunk_iterator() {
        let data = [(); 10]
            .into_iter()
            .enumerate()
            .map(|(i, _)| Arc::new(i))
            .collect::<Vec<_>>();
        let mut chunks = into_chunks(data.clone(), 3);
        let mut first_chunk = chunks.next().unwrap();
        let second_chunk = chunks.next().unwrap();
        drop(chunks);
        assert_eq!(
            second_chunk.into_iter().map(|a| *a).collect::<Vec<_>>(),
            vec![3, 4, 5]
        );
        assert_eq!(*first_chunk.next().unwrap(), 0);
        assert_eq!(*first_chunk.next().unwrap(), 1);
        drop(first_chunk);
        for arc in data {
            assert_eq!(Arc::strong_count(&arc), 1);
        }
    }
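
    // A few extra sanity checks for the formatting helpers; the expected strings simply
    // mirror the branch boundaries of the Display impls above.
    #[test]
    fn test_format_bytes() {
        assert_eq!(FormatBytes(512).to_string(), "512B");
        assert_eq!(FormatBytes(2 * 1024).to_string(), "2.00KiB");
        assert_eq!(FormatBytes(5 * 1024 * 1024).to_string(), "5.00MiB");
    }

    #[test]
    fn test_format_duration() {
        assert_eq!(FormatDuration(Duration::from_secs(42)).to_string(), "42s");
        assert_eq!(
            FormatDuration(Duration::from_millis(250)).to_string(),
            "250ms"
        );
    }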
}