turbo_tasks/
serialization_invalidation.rs

1use std::{
2    hash::Hash,
3    sync::{Arc, Weak},
4};
5
6use bincode::{
7    Decode, Encode,
8    de::Decoder,
9    enc::Encoder,
10    error::{DecodeError, EncodeError},
11    impl_borrow_decode,
12};
13use serde::{Deserialize, Serialize, de::Visitor};
14use tokio::runtime::Handle;
15
16use crate::{TaskId, TurboTasksApi, manager::with_turbo_tasks, trace::TraceRawVcs};
17
/// A cloneable handle that can ask turbo-tasks to invalidate the serialization of one task.
///
/// Besides the task id it carries everything [`SerializationInvalidator::invalidate`] needs to
/// reach the API: a weak reference to the turbo-tasks instance and a Tokio runtime handle.
#[derive(Clone)]
pub struct SerializationInvalidator {
    /// The task whose serialization gets invalidated.
    task: TaskId,
    /// Weak reference to the turbo-tasks API; held weakly, so it may fail to upgrade once the
    /// instance has been dropped.
    turbo_tasks: Weak<dyn TurboTasksApi>,
    /// Tokio runtime handle that is entered while the invalidation call is made.
    handle: Handle,
}
24
25impl Hash for SerializationInvalidator {
26    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
27        self.task.hash(state);
28    }
29}
30
31impl PartialEq for SerializationInvalidator {
32    fn eq(&self, other: &Self) -> bool {
33        self.task == other.task
34    }
35}
36
// Marker impl: equality is defined purely by the task id (see the `PartialEq` impl).
impl Eq for SerializationInvalidator {}
38
39impl SerializationInvalidator {
40    pub fn invalidate(&self) {
41        let SerializationInvalidator {
42            task,
43            turbo_tasks,
44            handle,
45        } = self;
46        let _guard = handle.enter();
47        if let Some(turbo_tasks) = turbo_tasks.upgrade() {
48            turbo_tasks.invalidate_serialization(*task);
49        }
50    }
51
52    pub(crate) fn new(task_id: TaskId) -> Self {
53        Self {
54            task: task_id,
55            turbo_tasks: with_turbo_tasks(Arc::downgrade),
56            handle: Handle::current(),
57        }
58    }
59}
60
61impl TraceRawVcs for SerializationInvalidator {
62    fn trace_raw_vcs(&self, _context: &mut crate::trace::TraceRawVcsContext) {
63        // nothing here
64    }
65}
66
67impl Serialize for SerializationInvalidator {
68    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
69    where
70        S: serde::Serializer,
71    {
72        serializer.serialize_newtype_struct("SerializationInvalidator", &self.task)
73    }
74}
75
76impl<'de> Deserialize<'de> for SerializationInvalidator {
77    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
78    where
79        D: serde::Deserializer<'de>,
80    {
81        struct V;
82
83        impl<'de> Visitor<'de> for V {
84            type Value = SerializationInvalidator;
85
86            fn expecting(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
87                write!(f, "an SerializationInvalidator")
88            }
89
90            fn visit_newtype_struct<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
91            where
92                D: serde::Deserializer<'de>,
93            {
94                Ok(SerializationInvalidator {
95                    task: TaskId::deserialize(deserializer)?,
96                    turbo_tasks: with_turbo_tasks(Arc::downgrade),
97                    handle: tokio::runtime::Handle::current(),
98                })
99            }
100        }
101        deserializer.deserialize_newtype_struct("SerializationInvalidator", V)
102    }
103}
104
105impl Encode for SerializationInvalidator {
106    fn encode<E: Encoder>(&self, encoder: &mut E) -> Result<(), EncodeError> {
107        Encode::encode(&self.task, encoder)
108    }
109}
110
111impl<Context> Decode<Context> for SerializationInvalidator {
112    fn decode<D: Decoder<Context = Context>>(decoder: &mut D) -> Result<Self, DecodeError> {
113        Ok(SerializationInvalidator {
114            task: Decode::decode(decoder)?,
115            turbo_tasks: with_turbo_tasks(Arc::downgrade),
116            handle: tokio::runtime::Handle::current(),
117        })
118    }
119}
120
// Generates the bincode `BorrowDecode` impl for `SerializationInvalidator` via bincode's macro.
impl_borrow_decode!(SerializationInvalidator);