turbo_tasks/
util.rs

1use std::{
2    error::Error as StdError,
3    fmt::{Debug, Display},
4    future::Future,
5    hash::{Hash, Hasher},
6    ops::Deref,
7    pin::Pin,
8    sync::Arc,
9    task::{Context, Poll},
10    time::Duration,
11};
12
13use anyhow::{Error, anyhow};
14use pin_project_lite::pin_project;
15use serde::{Deserialize, Deserializer, Serialize, Serializer};
16
17pub use super::{
18    id_factory::{IdFactory, IdFactoryWithReuse},
19    once_map::*,
20};
21
22/// A error struct that is backed by an Arc to allow cloning errors
23#[derive(Debug, Clone)]
24pub struct SharedError {
25    inner: Arc<anyhow::Error>,
26}
27
28impl SharedError {
29    pub fn new(err: anyhow::Error) -> Self {
30        Self {
31            inner: Arc::new(err),
32        }
33    }
34}
35
36impl AsRef<dyn StdError> for SharedError {
37    fn as_ref(&self) -> &(dyn StdError + 'static) {
38        let err: &anyhow::Error = &self.inner;
39        err.as_ref()
40    }
41}
42
43impl StdError for SharedError {
44    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
45        self.inner.source()
46    }
47
48    fn provide<'a>(&'a self, req: &mut std::error::Request<'a>) {
49        self.inner.provide(req);
50    }
51}
52
53impl Display for SharedError {
54    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
55        Display::fmt(&*self.inner, f)
56    }
57}
58
59impl From<Error> for SharedError {
60    fn from(e: Error) -> Self {
61        Self::new(e)
62    }
63}
64
65impl PartialEq for SharedError {
66    fn eq(&self, other: &Self) -> bool {
67        Arc::ptr_eq(&self.inner, &other.inner)
68    }
69}
70
71impl Eq for SharedError {}
72
73impl Serialize for SharedError {
74    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
75        let mut v = vec![self.to_string()];
76        let mut source = self.source();
77        while let Some(s) = source {
78            v.push(s.to_string());
79            source = s.source();
80        }
81        Serialize::serialize(&v, serializer)
82    }
83}
84
85impl<'de> Deserialize<'de> for SharedError {
86    fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
87        use serde::de::Error;
88        let mut messages = <Vec<String>>::deserialize(deserializer)?;
89        let mut e = match messages.pop() {
90            Some(e) => anyhow!(e),
91            None => return Err(Error::custom("expected at least 1 error message")),
92        };
93        while let Some(message) = messages.pop() {
94            e = e.context(message);
95        }
96        Ok(SharedError::new(e))
97    }
98}
99
100impl Deref for SharedError {
101    type Target = Arc<Error>;
102    fn deref(&self) -> &Self::Target {
103        &self.inner
104    }
105}
106
/// Wrapper that formats a [`Duration`] for human consumption.
///
/// `Display` is compact:
/// - whole seconds once the duration exceeds 10 whole seconds;
/// - otherwise, whole milliseconds once it exceeds 10 whole milliseconds;
/// - otherwise, fractional milliseconds at microsecond resolution.
///
/// `Debug` uses wider bands:
/// - whole seconds above 100 s;
/// - seconds with two decimals above 10 000 ms;
/// - whole milliseconds above 100 ms;
/// - fractional milliseconds below that.
pub struct FormatDuration(pub Duration);

impl Display for FormatDuration {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let secs = self.0.as_secs();
        let millis = self.0.as_millis();
        if secs > 10 {
            write!(f, "{secs}s")
        } else if millis > 10 {
            write!(f, "{millis}ms")
        } else {
            // Short durations keep sub-millisecond precision via microseconds.
            let fractional_millis = self.0.as_micros() as f32 / 1000.0;
            write!(f, "{fractional_millis}ms")
        }
    }
}

impl Debug for FormatDuration {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let secs = self.0.as_secs();
        let millis = self.0.as_millis();
        if secs > 100 {
            write!(f, "{secs}s")
        } else if millis > 10000 {
            let fractional_secs = millis as f32 / 1000.0;
            write!(f, "{fractional_secs:.2}s")
        } else if millis > 100 {
            write!(f, "{millis}ms")
        } else {
            let fractional_millis = self.0.as_micros() as f32 / 1000.0;
            write!(f, "{fractional_millis}ms")
        }
    }
}
139
/// Wrapper that formats a byte count for humans using binary (1024-based) units.
///
/// Counts below 1 KiB print as whole bytes (e.g. `1023B`). Larger counts are
/// scaled to the largest unit they reach (KiB, MiB or GiB) and printed with two
/// decimal places (e.g. `1.00KiB`, `1.50GiB`).
pub struct FormatBytes(pub usize);

impl Display for FormatBytes {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        const KB: usize = 1_024;
        const MB: usize = 1_024 * KB;
        const GB: usize = 1_024 * MB;
        let b = self.0;
        // `>=` (not `>`) so a value exactly on a unit boundary switches units:
        // 1024 bytes is `1.00KiB`, not `1024B`. `f64` avoids the precision loss
        // of `f32` for large counts.
        let (value, unit) = if b >= GB {
            (b as f64 / GB as f64, "GiB")
        } else if b >= MB {
            (b as f64 / MB as f64, "MiB")
        } else if b >= KB {
            (b as f64 / KB as f64, "KiB")
        } else {
            return write!(f, "{b}B");
        };
        write!(f, "{value:.2}{unit}")
    }
}
160
/// Smart pointer to a `T` that is either a `'static` borrow or a reference-counted [Arc].
///
/// Equality, hashing and formatting all delegate to the pointed-to `T`. A
/// `Static` value and a `Shared` value therefore compare equal (and hash alike)
/// when their targets are equal.
pub enum StaticOrArc<T: ?Sized + 'static> {
    /// A `'static` reference; cloning copies the reference.
    Static(&'static T),
    /// A reference-counted allocation; cloning bumps the reference count.
    Shared(Arc<T>),
}

impl<T: ?Sized + 'static> AsRef<T> for StaticOrArc<T> {
    fn as_ref(&self) -> &T {
        match *self {
            StaticOrArc::Static(borrowed) => borrowed,
            StaticOrArc::Shared(ref shared) => &**shared,
        }
    }
}

impl<T: ?Sized + 'static> From<&'static T> for StaticOrArc<T> {
    fn from(borrowed: &'static T) -> Self {
        StaticOrArc::Static(borrowed)
    }
}

impl<T: ?Sized + 'static> From<Arc<T>> for StaticOrArc<T> {
    fn from(shared: Arc<T>) -> Self {
        StaticOrArc::Shared(shared)
    }
}

/// Moves an owned value into a fresh [Arc].
impl<T: 'static> From<T> for StaticOrArc<T> {
    fn from(value: T) -> Self {
        StaticOrArc::Shared(Arc::new(value))
    }
}

impl<T: ?Sized + 'static> Deref for StaticOrArc<T> {
    type Target = T;

    fn deref(&self) -> &T {
        <Self as AsRef<T>>::as_ref(self)
    }
}

impl<T: ?Sized + 'static> Clone for StaticOrArc<T> {
    fn clone(&self) -> Self {
        match *self {
            StaticOrArc::Static(borrowed) => StaticOrArc::Static(borrowed),
            StaticOrArc::Shared(ref shared) => StaticOrArc::Shared(Arc::clone(shared)),
        }
    }
}

impl<T: ?Sized + PartialEq + 'static> PartialEq for StaticOrArc<T> {
    fn eq(&self, other: &Self) -> bool {
        T::eq(self, other)
    }
}

impl<T: ?Sized + PartialEq + Eq + 'static> Eq for StaticOrArc<T> {}

impl<T: ?Sized + Hash + 'static> Hash for StaticOrArc<T> {
    fn hash<H: Hasher>(&self, state: &mut H) {
        T::hash(self, state);
    }
}

impl<T: ?Sized + Display + 'static> Display for StaticOrArc<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        <T as Display>::fmt(self, f)
    }
}

impl<T: ?Sized + Debug + 'static> Debug for StaticOrArc<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        <T as Debug>::fmt(self, f)
    }
}
236
pin_project! {
    /// A future that wraps another future and applies a function on every poll call.
    ///
    /// Each time this future is polled, `wrapper` is called with the pinned inner
    /// `future` and the task context. Whatever `wrapper` returns is the result of
    /// that poll.
    pub struct WrapFuture<F, W> {
        // Function invoked on every poll; not structurally pinned.
        wrapper: W,
        // Inner future, structurally pinned so it can be handed to `wrapper`
        // as `Pin<&mut F>`.
        #[pin]
        future: F,
    }
}
245
246impl<F: Future, W: for<'a> Fn(Pin<&mut F>, &mut Context<'a>) -> Poll<F::Output>> WrapFuture<F, W> {
247    pub fn new(future: F, wrapper: W) -> Self {
248        Self { wrapper, future }
249    }
250}
251
252impl<F: Future, W: for<'a> Fn(Pin<&mut F>, &mut Context<'a>) -> Poll<F::Output>> Future
253    for WrapFuture<F, W>
254{
255    type Output = F::Output;
256
257    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
258        let this = self.project();
259        (this.wrapper)(this.future, cx)
260    }
261}