// turbo_tasks/message_queue.rs

1use std::{any::Any, collections::VecDeque, fmt::Display, sync::Arc, time::Duration};
2
3use dashmap::DashMap;
4use tokio::sync::{Mutex, mpsc};
5
6pub trait CompilationEvent: Sync + Send + Any {
7    fn type_name(&self) -> &'static str;
8    fn severity(&self) -> Severity;
9    fn message(&self) -> String;
10    fn to_json(&self) -> String;
11}
12
/// Upper bound on retained event history and on each subscriber channel's buffer.
const MAX_QUEUE_SIZE: usize = 256;
14
15type ArcMx<T> = Arc<Mutex<T>>;
16type CompilationEventChannel = mpsc::Sender<Arc<dyn CompilationEvent>>;
17
/// Key under which subscriber channels are registered.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
enum EventChannelType {
    /// Subscribers that receive every event.
    Global,
    /// Subscribers for events whose `type_name()` equals the contained string.
    Type(String),
}
23
24pub struct CompilationEventQueue {
25    event_history: ArcMx<VecDeque<Arc<dyn CompilationEvent>>>,
26    subscribers: Arc<DashMap<EventChannelType, Vec<CompilationEventChannel>>>,
27}
28
29impl Default for CompilationEventQueue {
30    fn default() -> Self {
31        let subscribers = DashMap::new();
32        subscribers.insert(
33            EventChannelType::Global,
34            Vec::<CompilationEventChannel>::new(),
35        );
36
37        Self {
38            event_history: Arc::new(Mutex::new(VecDeque::with_capacity(MAX_QUEUE_SIZE))),
39            subscribers: Arc::new(subscribers),
40        }
41    }
42}
43
44impl CompilationEventQueue {
45    pub fn send(
46        &self,
47        message: Arc<dyn CompilationEvent>,
48    ) -> Result<(), mpsc::error::SendError<Arc<dyn CompilationEvent>>> {
49        let event_history = self.event_history.clone();
50        let subscribers = self.subscribers.clone();
51        let message_clone = message.clone();
52
53        // Spawn a task to handle the async operations
54        tokio::spawn(async move {
55            // Store the message in history
56            let mut history = event_history.lock().await;
57            if history.len() >= MAX_QUEUE_SIZE {
58                history.pop_front();
59            }
60            history.push_back(message_clone.clone());
61
62            // Send to all active receivers of the same message type
63            if let Some(mut type_subscribers) = subscribers.get_mut(&EventChannelType::Type(
64                message_clone.type_name().to_owned(),
65            )) {
66                let mut removal_indices = Vec::new();
67                for (ix, sender) in type_subscribers.iter().enumerate() {
68                    if sender.send(message_clone.clone()).await.is_err() {
69                        removal_indices.push(ix);
70                    }
71                }
72
73                for ix in removal_indices.iter().rev() {
74                    type_subscribers.remove(*ix);
75                }
76            }
77
78            // Send to all global message subscribers
79            let mut all_channel = subscribers.get_mut(&EventChannelType::Global).unwrap();
80            let mut removal_indices = Vec::new();
81            for (ix, sender) in all_channel.iter_mut().enumerate() {
82                if sender.send(message_clone.clone()).await.is_err() {
83                    removal_indices.push(ix);
84                }
85            }
86
87            for ix in removal_indices.iter().rev() {
88                all_channel.remove(*ix);
89            }
90        });
91
92        Ok(())
93    }
94
95    pub fn subscribe(
96        &self,
97        event_types: Option<Vec<String>>,
98    ) -> mpsc::Receiver<Arc<dyn CompilationEvent>> {
99        let (tx, rx) = mpsc::channel(MAX_QUEUE_SIZE);
100        let subscribers = self.subscribers.clone();
101        let event_history = self.event_history.clone();
102        let tx_clone = tx.clone();
103
104        // Spawn a task to handle the async operations
105        tokio::spawn(async move {
106            // Store the sender
107            if let Some(event_types) = event_types {
108                for event_type in event_types.iter() {
109                    let mut type_subscribers = subscribers
110                        .entry(EventChannelType::Type(event_type.clone()))
111                        .or_default();
112                    type_subscribers.push(tx_clone.clone());
113                }
114
115                for event in event_history.lock().await.iter() {
116                    if event_types.contains(&event.type_name().to_string()) {
117                        let _ = tx_clone.send(event.clone()).await;
118                    }
119                }
120            } else {
121                let mut global_subscribers =
122                    subscribers.entry(EventChannelType::Global).or_default();
123                global_subscribers.push(tx_clone.clone());
124
125                for event in event_history.lock().await.iter() {
126                    let _ = tx_clone.send(event.clone()).await;
127                }
128            }
129        });
130
131        rx
132    }
133}
134
135#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash, serde::Serialize, serde::Deserialize)]
136pub enum Severity {
137    Info,
138    Trace,
139    Warning,
140    Error,
141    Fatal,
142    Event,
143}
144
145impl Display for Severity {
146    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
147        match self {
148            Severity::Info => write!(f, "INFO"),
149            Severity::Trace => write!(f, "TRACE"),
150            Severity::Warning => write!(f, "WARNING"),
151            Severity::Error => write!(f, "ERROR"),
152            Severity::Fatal => write!(f, "FATAL"),
153            Severity::Event => write!(f, "EVENT"),
154        }
155    }
156}
157
158#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
159/// Compilation event that is used to log the duration of a task
160pub struct TimingEvent {
161    /// Message of the event without the timing information
162    ///
163    /// Example:
164    /// ```rust
165    /// let event = TimingEvent::new("Compiled successfully".to_string(), Duration::from_millis(100));
166    /// let message = event.message();
167    /// assert_eq!(message, "Compiled successfully in 100ms");
168    /// ```
169    pub message: String,
170    /// Duration in milliseconds
171    pub duration: Duration,
172}
173
174impl TimingEvent {
175    pub fn new(message: String, duration: Duration) -> Self {
176        Self { message, duration }
177    }
178}
179
180impl CompilationEvent for TimingEvent {
181    fn type_name(&self) -> &'static str {
182        "TimingEvent"
183    }
184
185    fn severity(&self) -> Severity {
186        Severity::Event
187    }
188
189    fn message(&self) -> String {
190        let duration_secs = self.duration.as_secs_f64();
191        let duration_string = if duration_secs > 120.0 {
192            format!("{:.1}min", duration_secs / 60.0)
193        } else if duration_secs > 40.0 {
194            format!("{duration_secs:.0}s")
195        } else if duration_secs > 2.0 {
196            format!("{duration_secs:.1}s")
197        } else {
198            format!("{:.0}ms", duration_secs * 1000.0)
199        };
200        format!("{} in {}", self.message, duration_string)
201    }
202
203    fn to_json(&self) -> String {
204        serde_json::to_string(self).unwrap()
205    }
206}
207
208#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
209pub struct DiagnosticEvent {
210    pub message: String,
211    pub severity: Severity,
212}
213
214impl DiagnosticEvent {
215    pub fn new(severity: Severity, message: String) -> Self {
216        Self { message, severity }
217    }
218}
219
220impl CompilationEvent for DiagnosticEvent {
221    fn type_name(&self) -> &'static str {
222        "DiagnosticEvent"
223    }
224
225    fn severity(&self) -> Severity {
226        self.severity
227    }
228
229    fn message(&self) -> String {
230        self.message.clone()
231    }
232
233    fn to_json(&self) -> String {
234        serde_json::to_string(self).unwrap()
235    }
236}
237
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_timing_event_string_formatting() {
        // Includes the exact unit-switch thresholds (2s, 40s, 120s). The
        // comparisons are strict, so a duration equal to a threshold stays in
        // the smaller unit.
        let cases = [
            (Duration::from_nanos(1588), "0ms"),
            (Duration::from_nanos(1022616), "1ms"),
            (Duration::from_millis(100), "100ms"),
            (Duration::from_millis(1000), "1000ms"),
            (Duration::from_secs(2), "2000ms"),
            (Duration::from_millis(10000), "10.0s"),
            (Duration::from_millis(20381), "20.4s"),
            (Duration::from_secs(40), "40.0s"),
            (Duration::from_secs(60), "60s"),
            (Duration::from_secs(100), "100s"),
            (Duration::from_secs(120), "120s"),
            (Duration::from_secs(125), "2.1min"),
        ];

        for (duration, expected) in cases {
            let event = TimingEvent::new("Compiled successfully".to_string(), duration);
            assert_eq!(
                event.message(),
                format!("Compiled successfully in {expected}"),
                "unexpected formatting for {duration:?}"
            );
        }
    }

    #[test]
    fn test_severity_display() {
        assert_eq!(Severity::Info.to_string(), "INFO");
        assert_eq!(Severity::Trace.to_string(), "TRACE");
        assert_eq!(Severity::Warning.to_string(), "WARNING");
        assert_eq!(Severity::Error.to_string(), "ERROR");
        assert_eq!(Severity::Fatal.to_string(), "FATAL");
        assert_eq!(Severity::Event.to_string(), "EVENT");
    }
}