// turbo_tasks/task/shared_reference.rs

1use std::{
2    any::Any,
3    fmt::{Debug, Display},
4    hash::Hash,
5    ops::Deref,
6};
7
8use anyhow::Result;
9use serde::{Deserialize, Serialize, ser::SerializeTuple};
10use unsize::CoerceUnsize;
11
12use crate::{
13    ValueTypeId, registry,
14    triomphe_utils::{coerce_to_any_send_sync, downcast_triomphe_arc},
15};
16
17/// A reference to a piece of data
18#[derive(Clone)]
19pub struct SharedReference(pub triomphe::Arc<dyn Any + Send + Sync>);
20
21impl SharedReference {
22    pub fn new(data: triomphe::Arc<impl Any + Send + Sync>) -> Self {
23        Self(data.unsize(coerce_to_any_send_sync()))
24    }
25}
26
27/// A reference to a piece of data with type information
28#[derive(Clone, Hash, PartialEq, Eq, PartialOrd, Ord, Debug)]
29pub struct TypedSharedReference {
30    pub type_id: ValueTypeId,
31    pub reference: SharedReference,
32}
33
34impl SharedReference {
35    pub fn downcast<T: Any + Send + Sync>(self) -> Result<triomphe::Arc<T>, Self> {
36        match downcast_triomphe_arc(self.0) {
37            Ok(data) => Ok(data),
38            Err(data) => Err(Self(data)),
39        }
40    }
41
42    pub fn downcast_ref<T: Any>(&self) -> Option<&T> {
43        self.0.downcast_ref()
44    }
45
46    pub fn into_typed(self, type_id: ValueTypeId) -> TypedSharedReference {
47        TypedSharedReference {
48            type_id,
49            reference: self,
50        }
51    }
52}
53
54impl TypedSharedReference {
55    pub fn into_untyped(self) -> SharedReference {
56        self.reference
57    }
58}
59
60impl Deref for TypedSharedReference {
61    type Target = SharedReference;
62
63    fn deref(&self) -> &Self::Target {
64        &self.reference
65    }
66}
67
68impl Hash for SharedReference {
69    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
70        Hash::hash(&(&*self.0 as *const (dyn Any + Send + Sync)), state)
71    }
72}
73impl PartialEq for SharedReference {
74    // Must compare with PartialEq rather than std::ptr::addr_eq since the latter
75    // only compares their addresses.
76    #[allow(ambiguous_wide_pointer_comparisons)]
77    fn eq(&self, other: &Self) -> bool {
78        triomphe::Arc::ptr_eq(&self.0, &other.0)
79    }
80}
81impl Eq for SharedReference {}
82impl PartialOrd for SharedReference {
83    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
84        Some(self.cmp(other))
85    }
86}
87impl Ord for SharedReference {
88    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
89        Ord::cmp(
90            &(&*self.0 as *const (dyn Any + Send + Sync)).cast::<()>(),
91            &(&*other.0 as *const (dyn Any + Send + Sync)).cast::<()>(),
92        )
93    }
94}
95impl Debug for SharedReference {
96    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
97        f.debug_tuple("SharedReference").field(&self.0).finish()
98    }
99}
100
101impl Serialize for TypedSharedReference {
102    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
103    where
104        S: serde::Serializer,
105    {
106        let TypedSharedReference {
107            type_id: ty,
108            reference: SharedReference(arc),
109        } = self;
110        let value_type = registry::get_value_type(*ty);
111        if let Some(serializable) = value_type.any_as_serializable(arc) {
112            let mut t = serializer.serialize_tuple(2)?;
113            t.serialize_element(registry::get_value_type_global_name(*ty))?;
114            t.serialize_element(serializable)?;
115            t.end()
116        } else {
117            Err(serde::ser::Error::custom(format!(
118                "{:?} is not serializable",
119                registry::get_value_type_global_name(*ty)
120            )))
121        }
122    }
123}
124
125impl Display for SharedReference {
126    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
127        write!(f, "untyped value")
128    }
129}
130
131impl Display for TypedSharedReference {
132    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
133        write!(
134            f,
135            "value of type {}",
136            registry::get_value_type(self.type_id).name
137        )
138    }
139}
140
141impl<'de> Deserialize<'de> for TypedSharedReference {
142    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
143    where
144        D: serde::Deserializer<'de>,
145    {
146        struct Visitor;
147
148        impl<'de> serde::de::Visitor<'de> for Visitor {
149            type Value = TypedSharedReference;
150
151            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
152                formatter.write_str("a serializable shared reference")
153            }
154
155            fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
156            where
157                A: serde::de::SeqAccess<'de>,
158            {
159                if let Some(global_name) = seq.next_element()? {
160                    if let Some(ty) = registry::get_value_type_id_by_global_name(global_name) {
161                        if let Some(seed) = registry::get_value_type(ty).get_any_deserialize_seed()
162                        {
163                            if let Some(value) = seq.next_element_seed(seed)? {
164                                let arc = triomphe::Arc::<dyn Any + Send + Sync>::from(value);
165                                Ok(TypedSharedReference {
166                                    type_id: ty,
167                                    reference: SharedReference(arc),
168                                })
169                            } else {
170                                Err(serde::de::Error::invalid_length(
171                                    1,
172                                    &"tuple with type and value",
173                                ))
174                            }
175                        } else {
176                            Err(serde::de::Error::custom(format!(
177                                "{ty} is not deserializable"
178                            )))
179                        }
180                    } else {
181                        Err(serde::de::Error::unknown_variant(global_name, &[]))
182                    }
183                } else {
184                    Err(serde::de::Error::invalid_length(
185                        0,
186                        &"tuple with type and value",
187                    ))
188                }
189            }
190        }
191
192        deserializer.deserialize_tuple(2, Visitor)
193    }
194}