turbo_tasks/
util.rs

1use std::{
2    error::Error as StdError,
3    fmt::{Debug, Display},
4    future::Future,
5    hash::{Hash, Hasher},
6    ops::Deref,
7    pin::Pin,
8    sync::Arc,
9    task::{Context, Poll},
10    time::Duration,
11};
12
13use anyhow::{Error, anyhow};
14use pin_project_lite::pin_project;
15use serde::{Deserialize, Deserializer, Serialize, Serializer};
16
17pub use super::{
18    id_factory::{IdFactory, IdFactoryWithReuse},
19    no_move_vec::NoMoveVec,
20    once_map::*,
21};
22
23/// A error struct that is backed by an Arc to allow cloning errors
24#[derive(Debug, Clone)]
25pub struct SharedError {
26    inner: Arc<anyhow::Error>,
27}
28
29impl SharedError {
30    pub fn new(err: anyhow::Error) -> Self {
31        Self {
32            inner: Arc::new(err),
33        }
34    }
35}
36
37impl StdError for SharedError {
38    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
39        self.inner.source()
40    }
41
42    fn provide<'a>(&'a self, req: &mut std::error::Request<'a>) {
43        self.inner.provide(req);
44    }
45}
46
47impl Display for SharedError {
48    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
49        Display::fmt(&*self.inner, f)
50    }
51}
52
53impl From<Error> for SharedError {
54    fn from(e: Error) -> Self {
55        Self::new(e)
56    }
57}
58
59impl PartialEq for SharedError {
60    fn eq(&self, other: &Self) -> bool {
61        Arc::ptr_eq(&self.inner, &other.inner)
62    }
63}
64
65impl Eq for SharedError {}
66
67impl Serialize for SharedError {
68    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
69        let mut v = vec![self.to_string()];
70        let mut source = self.source();
71        while let Some(s) = source {
72            v.push(s.to_string());
73            source = s.source();
74        }
75        Serialize::serialize(&v, serializer)
76    }
77}
78
79impl<'de> Deserialize<'de> for SharedError {
80    fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
81        use serde::de::Error;
82        let mut messages = <Vec<String>>::deserialize(deserializer)?;
83        let mut e = match messages.pop() {
84            Some(e) => anyhow!(e),
85            None => return Err(Error::custom("expected at least 1 error message")),
86        };
87        while let Some(message) = messages.pop() {
88            e = e.context(message);
89        }
90        Ok(SharedError::new(e))
91    }
92}
93
94impl Deref for SharedError {
95    type Target = Arc<Error>;
96    fn deref(&self) -> &Self::Target {
97        &self.inner
98    }
99}
100
/// Human-readable formatting for a [`Duration`], choosing the unit by magnitude.
///
/// `Display`: whole seconds when more than 10 whole seconds have elapsed; otherwise
/// whole milliseconds when more than 10 whole milliseconds have elapsed; otherwise
/// fractional milliseconds.
///
/// `Debug`: whole seconds above 100s, two-decimal seconds above 10000ms, whole
/// milliseconds above 100ms, and fractional milliseconds otherwise.
pub struct FormatDuration(pub Duration);

impl Display for FormatDuration {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self(duration) = self;
        let secs = duration.as_secs();
        let millis = duration.as_millis();
        if secs > 10 {
            write!(f, "{secs}s")
        } else if millis > 10 {
            write!(f, "{millis}ms")
        } else {
            // Short durations keep sub-millisecond precision via microseconds.
            let fractional_millis = duration.as_micros() as f32 / 1000.0;
            write!(f, "{fractional_millis}ms")
        }
    }
}

impl Debug for FormatDuration {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self(duration) = self;
        let secs = duration.as_secs();
        let millis = duration.as_millis();
        if secs > 100 {
            write!(f, "{secs}s")
        } else if millis > 10000 {
            let fractional_secs = millis as f32 / 1000.0;
            write!(f, "{fractional_secs:.2}s")
        } else if millis > 100 {
            write!(f, "{millis}ms")
        } else {
            let fractional_millis = duration.as_micros() as f32 / 1000.0;
            write!(f, "{fractional_millis}ms")
        }
    }
}
133
/// Human-readable formatting for a byte count using binary (1024-based) units.
///
/// A count of at least 1 GiB, 1 MiB or 1 KiB is printed with two decimals in the largest
/// unit it reaches, for example `1536` → `"1.50KiB"`. Smaller counts are printed as whole
/// bytes, for example `512` → `"512B"`.
pub struct FormatBytes(pub usize);

impl Display for FormatBytes {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        const KB: usize = 1_024;
        const MB: usize = 1_024 * KB;
        const GB: usize = 1_024 * MB;
        let b = self.0;
        // Use `>=` so exact unit sizes roll over to the larger unit. Exactly 1024 bytes is
        // "1.00KiB", not "1024B", which keeps the output consistent with 1025 bytes.
        if b >= GB {
            return write!(f, "{:.2}GiB", b as f64 / GB as f64);
        }
        if b >= MB {
            return write!(f, "{:.2}MiB", b as f64 / MB as f64);
        }
        if b >= KB {
            return write!(f, "{:.2}KiB", b as f64 / KB as f64);
        }
        write!(f, "{b}B")
    }
}
154
/// A cheaply clonable handle to a `T` that is either a `'static` reference or held
/// behind an [`Arc`].
///
/// Equality, hashing and formatting all look through the handle to the pointed-to
/// value. A `Static` and a `Shared` holding equal values therefore compare equal.
pub enum StaticOrArc<T: ?Sized + 'static> {
    Static(&'static T),
    Shared(Arc<T>),
}

impl<T: ?Sized + 'static> Deref for StaticOrArc<T> {
    type Target = T;

    fn deref(&self) -> &T {
        match self {
            StaticOrArc::Static(reference) => *reference,
            StaticOrArc::Shared(arc) => &**arc,
        }
    }
}

impl<T: ?Sized + 'static> AsRef<T> for StaticOrArc<T> {
    fn as_ref(&self) -> &T {
        // Deref coercion: resolves to the same target as `Deref::deref`.
        self
    }
}

impl<T: ?Sized + 'static> From<&'static T> for StaticOrArc<T> {
    fn from(reference: &'static T) -> Self {
        StaticOrArc::Static(reference)
    }
}

impl<T: ?Sized + 'static> From<Arc<T>> for StaticOrArc<T> {
    fn from(arc: Arc<T>) -> Self {
        StaticOrArc::Shared(arc)
    }
}

impl<T: 'static> From<T> for StaticOrArc<T> {
    /// Takes ownership of `value` and moves it into a new `Arc`.
    fn from(value: T) -> Self {
        let arc = Arc::new(value);
        StaticOrArc::Shared(arc)
    }
}

impl<T: ?Sized + 'static> Clone for StaticOrArc<T> {
    /// Copies the static reference, or bumps the `Arc` reference count; `T` itself is
    /// never cloned.
    fn clone(&self) -> Self {
        match self {
            StaticOrArc::Static(reference) => StaticOrArc::Static(*reference),
            StaticOrArc::Shared(arc) => StaticOrArc::Shared(Arc::clone(arc)),
        }
    }
}

impl<T: ?Sized + PartialEq + 'static> PartialEq for StaticOrArc<T> {
    fn eq(&self, other: &Self) -> bool {
        PartialEq::eq(&**self, &**other)
    }
}

impl<T: ?Sized + PartialEq + Eq + 'static> Eq for StaticOrArc<T> {}

impl<T: ?Sized + Hash + 'static> Hash for StaticOrArc<T> {
    fn hash<H: Hasher>(&self, state: &mut H) {
        Hash::hash(&**self, state)
    }
}

impl<T: ?Sized + Display + 'static> Display for StaticOrArc<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        Display::fmt(&**self, f)
    }
}

impl<T: ?Sized + Debug + 'static> Debug for StaticOrArc<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        Debug::fmt(&**self, f)
    }
}
230
231pin_project! {
232    /// A future that wraps another future and applies a function on every poll call.
233    pub struct WrapFuture<F, W> {
234        wrapper: W,
235        #[pin]
236        future: F,
237    }
238}
239
240impl<F: Future, W: for<'a> Fn(Pin<&mut F>, &mut Context<'a>) -> Poll<F::Output>> WrapFuture<F, W> {
241    pub fn new(future: F, wrapper: W) -> Self {
242        Self { wrapper, future }
243    }
244}
245
246impl<F: Future, W: for<'a> Fn(Pin<&mut F>, &mut Context<'a>) -> Poll<F::Output>> Future
247    for WrapFuture<F, W>
248{
249    type Output = F::Output;
250
251    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
252        let this = self.project();
253        (this.wrapper)(this.future, cx)
254    }
255}