// turbopack_dev_server/update/server.rs

1use std::{
2    ops::ControlFlow,
3    pin::Pin,
4    task::{Context, Poll},
5};
6
7use anyhow::{Context as _, Error, Result};
8use futures::{SinkExt, prelude::*, ready, stream::FusedStream};
9use hyper::{HeaderMap, Uri, upgrade::Upgraded};
10use hyper_tungstenite::{HyperWebsocket, WebSocketStream, tungstenite::Message};
11use pin_project_lite::pin_project;
12use tokio::select;
13use tokio_stream::StreamMap;
14use tracing::{Level, instrument};
15use turbo_tasks::{
16    NonLocalValue, OperationVc, ReadRef, TransientInstance, TurboTasksApi, Vc, trace::TraceRawVcs,
17};
18use turbo_tasks_fs::json::parse_json_with_source_context;
19use turbopack_core::{error::PrettyPrintError, issue::IssueReporter, version::Update};
20use turbopack_ecmascript_hmr_protocol::{
21    ClientMessage, ClientUpdateInstruction, Issue, ResourceIdentifier,
22};
23
24use crate::{
25    SourceProvider,
26    source::{
27        Body,
28        request::SourceRequest,
29        resolve::{ResolveSourceRequestResult, resolve_source_request},
30    },
31    update::stream::{GetContentFn, UpdateStream, UpdateStreamItem},
32};
33
34/// A server that listens for updates and sends them to connected clients.
35pub(crate) struct UpdateServer<P: SourceProvider> {
36    source_provider: P,
37    #[allow(dead_code)]
38    issue_reporter: Vc<Box<dyn IssueReporter>>,
39}
40
41impl<P> UpdateServer<P>
42where
43    P: SourceProvider + NonLocalValue + TraceRawVcs + Clone + Send + Sync,
44{
45    /// Create a new update server with the given websocket and content source.
46    pub fn new(source_provider: P, issue_reporter: Vc<Box<dyn IssueReporter>>) -> Self {
47        Self {
48            source_provider,
49            issue_reporter,
50        }
51    }
52
53    /// Run the update server loop.
54    pub fn run(self, tt: &dyn TurboTasksApi, ws: HyperWebsocket) {
55        tt.run_once_process(Box::pin(async move {
56            if let Err(err) = self.run_internal(ws).await {
57                println!("[UpdateServer]: error {err:#}");
58            }
59            Ok(())
60        }));
61    }
62
63    #[instrument(level = Level::TRACE, skip_all, name = "UpdateServer::run_internal")]
64    async fn run_internal(self, ws: HyperWebsocket) -> Result<()> {
65        let mut client: UpdateClient = ws.await?.into();
66
67        let mut streams = StreamMap::new();
68
69        loop {
70            // most logic is in helper functions as rustfmt cannot format code inside the macro
71            select! {
72                message = client.try_next() => {
73                    if Self::on_message(
74                        &mut client,
75                        &mut streams,
76                        &self.source_provider,
77                        message?,
78                    ).await?.is_break() {
79                        break;
80                    }
81                }
82                Some((resource, update_result)) = streams.next() => {
83                    Self::on_stream(
84                        &mut client,
85                        &mut streams,
86                        resource,
87                        update_result,
88                    ).await?
89                }
90                else => break
91            }
92        }
93
94        Ok(())
95    }
96
97    /// Helper for `on_message` used to construct a `GetContentFn`. Argument must match
98    /// `get_content_capture`.
99    fn get_content(
100        (source_provider, request): &(P, SourceRequest),
101    ) -> OperationVc<ResolveSourceRequestResult> {
102        let request = request.clone();
103        let source = source_provider.get_source();
104        resolve_source_request(source, TransientInstance::new(request))
105    }
106
107    /// receives ClientMessages and passes subscriptions to `on_stream` via the `streams` map.
108    async fn on_message(
109        client: &mut UpdateClient,
110        streams: &mut StreamMap<ResourceIdentifier, UpdateStream>,
111        source_provider: &P,
112        message: Option<ClientMessage>,
113    ) -> Result<ControlFlow<()>> {
114        match message {
115            Some(ClientMessage::Subscribe { resource }) => {
116                let get_content_capture =
117                    (source_provider.clone(), resource_to_request(&resource)?);
118                match UpdateStream::new(
119                    resource.to_string().into(),
120                    TransientInstance::new(GetContentFn::new(
121                        get_content_capture,
122                        Self::get_content,
123                    )),
124                )
125                .await
126                {
127                    Ok(stream) => {
128                        streams.insert(resource, stream);
129                    }
130                    Err(err) => {
131                        eprintln!(
132                            "Failed to create update stream for {resource}: {}",
133                            PrettyPrintError(&err),
134                        );
135                        client
136                            .send(ClientUpdateInstruction::not_found(&resource))
137                            .await?;
138                    }
139                }
140            }
141            Some(ClientMessage::Unsubscribe { resource }) => {
142                streams.remove(&resource);
143            }
144            None => {
145                // WebSocket was closed, stop sending updates
146                return Ok(ControlFlow::Break(()));
147            }
148        }
149        Ok(ControlFlow::Continue(()))
150    }
151
152    async fn on_stream(
153        client: &mut UpdateClient,
154        streams: &mut StreamMap<ResourceIdentifier, UpdateStream>,
155        resource: ResourceIdentifier,
156        update_result: Result<ReadRef<UpdateStreamItem>>,
157    ) -> Result<()> {
158        match update_result {
159            Ok(update_item) => Self::send_update(client, streams, resource, &update_item).await,
160            Err(err) => {
161                eprintln!(
162                    "Failed to get update for {resource}: {}",
163                    PrettyPrintError(&err)
164                );
165                Ok(())
166            }
167        }
168    }
169
170    async fn send_update(
171        client: &mut UpdateClient,
172        streams: &mut StreamMap<ResourceIdentifier, UpdateStream>,
173        resource: ResourceIdentifier,
174        update_item: &UpdateStreamItem,
175    ) -> Result<()> {
176        match update_item {
177            UpdateStreamItem::NotFound => {
178                // If the resource was not found, we remove the stream and indicate that to the
179                // client.
180                streams.remove(&resource);
181                client
182                    .send(ClientUpdateInstruction::not_found(&resource))
183                    .await?;
184            }
185            UpdateStreamItem::Found { update, issues } => {
186                let issues = issues
187                    .iter()
188                    .map(|p| Issue::from(&**p))
189                    .collect::<Vec<Issue<'_>>>();
190                match &**update {
191                    Update::Partial(partial) => {
192                        let partial_instruction = &partial.instruction;
193                        client
194                            .send(ClientUpdateInstruction::partial(
195                                &resource,
196                                partial_instruction,
197                                &issues,
198                            ))
199                            .await?;
200                    }
201                    Update::Missing | Update::Total(_) => {
202                        client
203                            .send(ClientUpdateInstruction::restart(&resource, &issues))
204                            .await?;
205                    }
206                    Update::None => {
207                        client
208                            .send(ClientUpdateInstruction::issues(&resource, &issues))
209                            .await?;
210                    }
211                }
212            }
213        }
214
215        Ok(())
216    }
217}
218
219fn resource_to_request(resource: &ResourceIdentifier) -> Result<SourceRequest> {
220    let mut headers = HeaderMap::new();
221
222    if let Some(res_headers) = &resource.headers {
223        for (name, value) in res_headers {
224            headers.append(
225                hyper::header::HeaderName::from_bytes(name.as_bytes()).unwrap(),
226                hyper::header::HeaderValue::from_bytes(value.as_bytes()).unwrap(),
227            );
228        }
229    }
230
231    Ok(SourceRequest {
232        uri: Uri::try_from(format!("/{}", resource.path))?,
233        headers,
234        method: "GET".to_string(),
235        body: Body::new(vec![]),
236    })
237}
238
239pin_project! {
240    struct UpdateClient {
241        #[pin]
242        ws: WebSocketStream<Upgraded>,
243        ended: bool,
244    }
245}
246
247impl Stream for UpdateClient {
248    type Item = Result<ClientMessage>;
249
250    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
251        if self.ended {
252            return Poll::Ready(None);
253        }
254
255        let this = self.project();
256        let item = ready!(this.ws.poll_next(cx));
257
258        let msg = match item {
259            Some(Ok(Message::Text(msg))) => msg,
260            Some(Err(err)) => {
261                *this.ended = true;
262
263                let err = Error::new(err).context("reading from websocket");
264                return Poll::Ready(Some(Err(err)));
265            }
266            _ => {
267                *this.ended = true;
268                return Poll::Ready(None);
269            }
270        };
271
272        match parse_json_with_source_context(&msg).context("deserializing websocket message") {
273            Ok(msg) => Poll::Ready(Some(Ok(msg))),
274            Err(err) => {
275                *this.ended = true;
276
277                Poll::Ready(Some(Err(err)))
278            }
279        }
280    }
281}
282
283impl FusedStream for UpdateClient {
284    fn is_terminated(&self) -> bool {
285        self.ended || self.ws.is_terminated()
286    }
287}
288
289impl<'a> Sink<ClientUpdateInstruction<'a>> for UpdateClient {
290    type Error = Error;
291
292    fn poll_ready(
293        self: Pin<&mut Self>,
294        cx: &mut Context<'_>,
295    ) -> Poll<std::result::Result<(), Self::Error>> {
296        self.project()
297            .ws
298            .poll_ready(cx)
299            .map(|res| res.context("polling WebSocket ready"))
300    }
301
302    fn start_send(
303        self: Pin<&mut Self>,
304        item: ClientUpdateInstruction<'a>,
305    ) -> std::result::Result<(), Self::Error> {
306        let msg = Message::text(serde_json::to_string(&item)?);
307
308        self.project()
309            .ws
310            .start_send(msg)
311            .context("sending to WebSocket")
312    }
313
314    fn poll_flush(
315        self: Pin<&mut Self>,
316        cx: &mut Context<'_>,
317    ) -> Poll<std::result::Result<(), Self::Error>> {
318        self.project()
319            .ws
320            .poll_flush(cx)
321            .map(|res| res.context("flushing WebSocket"))
322    }
323
324    fn poll_close(
325        self: Pin<&mut Self>,
326        cx: &mut Context<'_>,
327    ) -> Poll<std::result::Result<(), Self::Error>> {
328        self.project()
329            .ws
330            .poll_close(cx)
331            .map(|res| res.context("closing WebSocket"))
332    }
333}
334
335impl From<WebSocketStream<Upgraded>> for UpdateClient {
336    fn from(ws: WebSocketStream<Upgraded>) -> Self {
337        Self { ws, ended: false }
338    }
339}