turbo_tasks/task/
shared_reference.rs

1use std::{
2    any::Any,
3    fmt::{Debug, Display},
4    hash::Hash,
5    ops::Deref,
6};
7
8use anyhow::Result;
9use serde::{Deserialize, Serialize, ser::SerializeTuple};
10use unsize::CoerceUnsize;
11
12use crate::{
13    ValueTypeId, registry,
14    triomphe_utils::{coerce_to_any_send_sync, downcast_triomphe_arc},
15};
16
/// A reference to a piece of data.
///
/// Wraps a [`triomphe::Arc`] whose pointee has been erased to
/// `dyn Any + Send + Sync`. The concrete type can be recovered with
/// [`SharedReference::downcast`] or [`SharedReference::downcast_ref`].
///
/// The comparison and hashing impls in this file (`PartialEq`, `Ord`, `Hash`)
/// are written in terms of the pointer held by the `Arc`, not in terms of the
/// value it points to.
#[derive(Clone)]
pub struct SharedReference(pub triomphe::Arc<dyn Any + Send + Sync>);
20
21impl SharedReference {
22    pub fn new(data: triomphe::Arc<impl Any + Send + Sync>) -> Self {
23        Self(data.unsize(coerce_to_any_send_sync()))
24    }
25}
26
/// A reference to a piece of data with type information.
///
/// Field `0` is the [`ValueTypeId`] of the value and field `1` is the
/// type-erased [`SharedReference`] itself. The `Serialize` and `Deserialize`
/// impls in this file use the type id to look up the value type in the
/// registry.
#[derive(Clone, Hash, PartialEq, Eq, PartialOrd, Ord, Debug)]
pub struct TypedSharedReference(pub ValueTypeId, pub SharedReference);
30
31impl SharedReference {
32    pub fn downcast<T: Any + Send + Sync>(self) -> Result<triomphe::Arc<T>, Self> {
33        match downcast_triomphe_arc(self.0) {
34            Ok(data) => Ok(data),
35            Err(data) => Err(Self(data)),
36        }
37    }
38
39    pub fn downcast_ref<T: Any>(&self) -> Option<&T> {
40        self.0.downcast_ref()
41    }
42
43    pub fn into_typed(self, type_id: ValueTypeId) -> TypedSharedReference {
44        TypedSharedReference(type_id, self)
45    }
46}
47
48impl TypedSharedReference {
49    pub fn into_untyped(self) -> SharedReference {
50        self.1
51    }
52}
53
54impl Deref for TypedSharedReference {
55    type Target = SharedReference;
56
57    fn deref(&self) -> &Self::Target {
58        &self.1
59    }
60}
61
62impl Hash for SharedReference {
63    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
64        Hash::hash(&(&*self.0 as *const (dyn Any + Send + Sync)), state)
65    }
66}
67impl PartialEq for SharedReference {
68    // Must compare with PartialEq rather than std::ptr::addr_eq since the latter
69    // only compares their addresses.
70    #[allow(ambiguous_wide_pointer_comparisons)]
71    fn eq(&self, other: &Self) -> bool {
72        triomphe::Arc::ptr_eq(&self.0, &other.0)
73    }
74}
75impl Eq for SharedReference {}
76impl PartialOrd for SharedReference {
77    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
78        Some(self.cmp(other))
79    }
80}
81impl Ord for SharedReference {
82    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
83        Ord::cmp(
84            &(&*self.0 as *const (dyn Any + Send + Sync)).cast::<()>(),
85            &(&*other.0 as *const (dyn Any + Send + Sync)).cast::<()>(),
86        )
87    }
88}
89impl Debug for SharedReference {
90    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
91        f.debug_tuple("SharedReference").field(&self.0).finish()
92    }
93}
94
95impl Serialize for TypedSharedReference {
96    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
97    where
98        S: serde::Serializer,
99    {
100        let TypedSharedReference(ty, SharedReference(arc)) = self;
101        let value_type = registry::get_value_type(*ty);
102        if let Some(serializable) = value_type.any_as_serializable(arc) {
103            let mut t = serializer.serialize_tuple(2)?;
104            t.serialize_element(registry::get_value_type_global_name(*ty))?;
105            t.serialize_element(serializable)?;
106            t.end()
107        } else {
108            Err(serde::ser::Error::custom(format!(
109                "{:?} is not serializable",
110                registry::get_value_type_global_name(*ty)
111            )))
112        }
113    }
114}
115
116impl Display for SharedReference {
117    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
118        write!(f, "untyped value")
119    }
120}
121
122impl Display for TypedSharedReference {
123    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
124        write!(f, "value of type {}", registry::get_value_type(self.0).name)
125    }
126}
127
128impl<'de> Deserialize<'de> for TypedSharedReference {
129    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
130    where
131        D: serde::Deserializer<'de>,
132    {
133        struct Visitor;
134
135        impl<'de> serde::de::Visitor<'de> for Visitor {
136            type Value = TypedSharedReference;
137
138            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
139                formatter.write_str("a serializable shared reference")
140            }
141
142            fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
143            where
144                A: serde::de::SeqAccess<'de>,
145            {
146                if let Some(global_name) = seq.next_element()? {
147                    if let Some(ty) = registry::get_value_type_id_by_global_name(global_name) {
148                        if let Some(seed) = registry::get_value_type(ty).get_any_deserialize_seed()
149                        {
150                            if let Some(value) = seq.next_element_seed(seed)? {
151                                let arc = triomphe::Arc::<dyn Any + Send + Sync>::from(value);
152                                Ok(TypedSharedReference(ty, SharedReference(arc)))
153                            } else {
154                                Err(serde::de::Error::invalid_length(
155                                    1,
156                                    &"tuple with type and value",
157                                ))
158                            }
159                        } else {
160                            Err(serde::de::Error::custom(format!(
161                                "{ty} is not deserializable"
162                            )))
163                        }
164                    } else {
165                        Err(serde::de::Error::unknown_variant(global_name, &[]))
166                    }
167                } else {
168                    Err(serde::de::Error::invalid_length(
169                        0,
170                        &"tuple with type and value",
171                    ))
172                }
173            }
174        }
175
176        deserializer.deserialize_tuple(2, Visitor)
177    }
178}