turbopack_dev_server/update/server.rs

1use std::{
2    ops::ControlFlow,
3    pin::Pin,
4    task::{Context, Poll},
5};
6
7use anyhow::{Context as _, Error, Result};
8use futures::{SinkExt, prelude::*, ready, stream::FusedStream};
9use hyper::{HeaderMap, Uri, upgrade::Upgraded};
10use hyper_tungstenite::{HyperWebsocket, WebSocketStream, tungstenite::Message};
11use pin_project_lite::pin_project;
12use tokio::select;
13use tokio_stream::StreamMap;
14use tracing::{Level, instrument};
15use turbo_tasks::{
16    NonLocalValue, OperationVc, ReadRef, TransientInstance, TurboTasksApi, Vc, trace::TraceRawVcs,
17};
18use turbo_tasks_fs::json::parse_json_with_source_context;
19use turbopack_core::{error::PrettyPrintError, issue::IssueReporter, version::Update};
20use turbopack_ecmascript_hmr_protocol::{
21    ClientMessage, ClientUpdateInstruction, Issue, ResourceIdentifier,
22};
23
24use crate::{
25    SourceProvider,
26    source::{
27        Body,
28        request::SourceRequest,
29        resolve::{ResolveSourceRequestResult, resolve_source_request},
30    },
31    update::stream::{GetContentFn, UpdateStream, UpdateStreamItem},
32};
33
34/// A server that listens for updates and sends them to connected clients.
35pub(crate) struct UpdateServer<P: SourceProvider> {
36    source_provider: P,
37    #[allow(dead_code)]
38    issue_reporter: Vc<Box<dyn IssueReporter>>,
39}
40
41impl<P> UpdateServer<P>
42where
43    P: SourceProvider + NonLocalValue + TraceRawVcs + Clone + Send + Sync,
44{
45    /// Create a new update server with the given websocket and content source.
46    pub fn new(source_provider: P, issue_reporter: Vc<Box<dyn IssueReporter>>) -> Self {
47        Self {
48            source_provider,
49            issue_reporter,
50        }
51    }
52
53    /// Run the update server loop.
54    pub fn run(self, tt: &dyn TurboTasksApi, ws: HyperWebsocket) {
55        tt.start_once_process(Box::pin(async move {
56            if let Err(err) = self.run_internal(ws).await {
57                println!("[UpdateServer]: error {err:#}");
58            }
59        }));
60    }
61
62    #[instrument(level = Level::TRACE, skip_all, name = "UpdateServer::run_internal")]
63    async fn run_internal(self, ws: HyperWebsocket) -> Result<()> {
64        let mut client: UpdateClient = ws.await?.into();
65
66        let mut streams = StreamMap::new();
67
68        loop {
69            // most logic is in helper functions as rustfmt cannot format code inside the macro
70            select! {
71                message = client.try_next() => {
72                    if Self::on_message(
73                        &mut client,
74                        &mut streams,
75                        &self.source_provider,
76                        message?,
77                    ).await?.is_break() {
78                        break;
79                    }
80                }
81                Some((resource, update_result)) = streams.next() => {
82                    Self::on_stream(
83                        &mut client,
84                        &mut streams,
85                        resource,
86                        update_result,
87                    ).await?
88                }
89                else => break
90            }
91        }
92
93        Ok(())
94    }
95
96    /// Helper for `on_message` used to construct a `GetContentFn`. Argument must match
97    /// `get_content_capture`.
98    fn get_content(
99        (source_provider, request): &(P, SourceRequest),
100    ) -> OperationVc<ResolveSourceRequestResult> {
101        let request = request.clone();
102        let source = source_provider.get_source();
103        resolve_source_request(source, TransientInstance::new(request))
104    }
105
106    /// receives ClientMessages and passes subscriptions to `on_stream` via the `streams` map.
107    async fn on_message(
108        client: &mut UpdateClient,
109        streams: &mut StreamMap<ResourceIdentifier, UpdateStream>,
110        source_provider: &P,
111        message: Option<ClientMessage>,
112    ) -> Result<ControlFlow<()>> {
113        match message {
114            Some(ClientMessage::Subscribe { resource }) => {
115                let get_content_capture =
116                    (source_provider.clone(), resource_to_request(&resource)?);
117                match UpdateStream::new(
118                    resource.to_string().into(),
119                    TransientInstance::new(GetContentFn::new(
120                        get_content_capture,
121                        Self::get_content,
122                    )),
123                )
124                .await
125                {
126                    Ok(stream) => {
127                        streams.insert(resource, stream);
128                    }
129                    Err(err) => {
130                        eprintln!(
131                            "Failed to create update stream for {resource}: {}",
132                            PrettyPrintError(&err),
133                        );
134                        client
135                            .send(ClientUpdateInstruction::not_found(&resource))
136                            .await?;
137                    }
138                }
139            }
140            Some(ClientMessage::Unsubscribe { resource }) => {
141                streams.remove(&resource);
142            }
143            None => {
144                // WebSocket was closed, stop sending updates
145                return Ok(ControlFlow::Break(()));
146            }
147        }
148        Ok(ControlFlow::Continue(()))
149    }
150
151    async fn on_stream(
152        client: &mut UpdateClient,
153        streams: &mut StreamMap<ResourceIdentifier, UpdateStream>,
154        resource: ResourceIdentifier,
155        update_result: Result<ReadRef<UpdateStreamItem>>,
156    ) -> Result<()> {
157        match update_result {
158            Ok(update_item) => Self::send_update(client, streams, resource, &update_item).await,
159            Err(err) => {
160                eprintln!(
161                    "Failed to get update for {resource}: {}",
162                    PrettyPrintError(&err)
163                );
164                Ok(())
165            }
166        }
167    }
168
169    async fn send_update(
170        client: &mut UpdateClient,
171        streams: &mut StreamMap<ResourceIdentifier, UpdateStream>,
172        resource: ResourceIdentifier,
173        update_item: &UpdateStreamItem,
174    ) -> Result<()> {
175        match update_item {
176            UpdateStreamItem::NotFound => {
177                // If the resource was not found, we remove the stream and indicate that to the
178                // client.
179                streams.remove(&resource);
180                client
181                    .send(ClientUpdateInstruction::not_found(&resource))
182                    .await?;
183            }
184            UpdateStreamItem::Found { update, issues } => {
185                let issues = issues
186                    .iter()
187                    .map(|p| Issue::from(&**p))
188                    .collect::<Vec<Issue<'_>>>();
189                match &**update {
190                    Update::Partial(partial) => {
191                        let partial_instruction = &partial.instruction;
192                        client
193                            .send(ClientUpdateInstruction::partial(
194                                &resource,
195                                partial_instruction,
196                                &issues,
197                            ))
198                            .await?;
199                    }
200                    Update::Missing | Update::Total(_) => {
201                        client
202                            .send(ClientUpdateInstruction::restart(&resource, &issues))
203                            .await?;
204                    }
205                    Update::None => {
206                        client
207                            .send(ClientUpdateInstruction::issues(&resource, &issues))
208                            .await?;
209                    }
210                }
211            }
212        }
213
214        Ok(())
215    }
216}
217
218fn resource_to_request(resource: &ResourceIdentifier) -> Result<SourceRequest> {
219    let mut headers = HeaderMap::new();
220
221    if let Some(res_headers) = &resource.headers {
222        for (name, value) in res_headers {
223            headers.append(
224                hyper::header::HeaderName::from_bytes(name.as_bytes()).unwrap(),
225                hyper::header::HeaderValue::from_bytes(value.as_bytes()).unwrap(),
226            );
227        }
228    }
229
230    Ok(SourceRequest {
231        uri: Uri::try_from(format!("/{}", resource.path))?,
232        headers,
233        method: "GET".to_string(),
234        body: Body::new(vec![]),
235    })
236}
237
pin_project! {
    /// Client end of an update-server websocket connection.
    ///
    /// As a `Stream` it yields deserialized `ClientMessage`s. As a `Sink` it accepts
    /// `ClientUpdateInstruction`s, which are sent as JSON text messages.
    struct UpdateClient {
        // The upgraded websocket, pinned so it can be polled through the projection.
        #[pin]
        ws: WebSocketStream<Upgraded>,
        // Set once reading has produced an error or end-of-stream; from then on
        // `poll_next` returns `None` without touching the socket.
        ended: bool,
    }
}
245
246impl Stream for UpdateClient {
247    type Item = Result<ClientMessage>;
248
249    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
250        if self.ended {
251            return Poll::Ready(None);
252        }
253
254        let this = self.project();
255        let item = ready!(this.ws.poll_next(cx));
256
257        let msg = match item {
258            Some(Ok(Message::Text(msg))) => msg,
259            Some(Err(err)) => {
260                *this.ended = true;
261
262                let err = Error::new(err).context("reading from websocket");
263                return Poll::Ready(Some(Err(err)));
264            }
265            _ => {
266                *this.ended = true;
267                return Poll::Ready(None);
268            }
269        };
270
271        match parse_json_with_source_context(&msg).context("deserializing websocket message") {
272            Ok(msg) => Poll::Ready(Some(Ok(msg))),
273            Err(err) => {
274                *this.ended = true;
275
276                Poll::Ready(Some(Err(err)))
277            }
278        }
279    }
280}
281
282impl FusedStream for UpdateClient {
283    fn is_terminated(&self) -> bool {
284        self.ended || self.ws.is_terminated()
285    }
286}
287
288impl<'a> Sink<ClientUpdateInstruction<'a>> for UpdateClient {
289    type Error = Error;
290
291    fn poll_ready(
292        self: Pin<&mut Self>,
293        cx: &mut Context<'_>,
294    ) -> Poll<std::result::Result<(), Self::Error>> {
295        self.project()
296            .ws
297            .poll_ready(cx)
298            .map(|res| res.context("polling WebSocket ready"))
299    }
300
301    fn start_send(
302        self: Pin<&mut Self>,
303        item: ClientUpdateInstruction<'a>,
304    ) -> std::result::Result<(), Self::Error> {
305        let msg = Message::text(serde_json::to_string(&item)?);
306
307        self.project()
308            .ws
309            .start_send(msg)
310            .context("sending to WebSocket")
311    }
312
313    fn poll_flush(
314        self: Pin<&mut Self>,
315        cx: &mut Context<'_>,
316    ) -> Poll<std::result::Result<(), Self::Error>> {
317        self.project()
318            .ws
319            .poll_flush(cx)
320            .map(|res| res.context("flushing WebSocket"))
321    }
322
323    fn poll_close(
324        self: Pin<&mut Self>,
325        cx: &mut Context<'_>,
326    ) -> Poll<std::result::Result<(), Self::Error>> {
327        self.project()
328            .ws
329            .poll_close(cx)
330            .map(|res| res.context("closing WebSocket"))
331    }
332}
333
334impl From<WebSocketStream<Upgraded>> for UpdateClient {
335    fn from(ws: WebSocketStream<Upgraded>) -> Self {
336        Self { ws, ended: false }
337    }
338}