1use std::{
2 cmp::max,
3 fmt::{Display, Formatter},
4 hash::Hash,
5 sync::LazyLock,
6 time::{Duration, SystemTime},
7};
8
9use anyhow::Result;
10use quick_cache::sync::Cache;
11use turbo_rcstr::RcStr;
12use turbo_tasks::{
13 Completion, FxIndexSet, InvalidationReason, InvalidationReasonKind, Invalidator, ReadRef,
14 ResolvedVc, Vc, duration_span, util::StaticOrArc,
15};
16
17use crate::{FetchError, FetchResult, HttpResponse, HttpResponseBody};
18
/// Capacity handed to `quick_cache` when creating `CLIENT_CACHE`, i.e. roughly how many
/// distinct client configurations can keep a built `reqwest::Client` around at once.
const MAX_CLIENTS: usize = 16;
/// Process-wide cache of built `reqwest::Client`s, keyed by the `FetchClientConfig` they were
/// built from, so that fetches sharing a configuration reuse the same client instead of
/// constructing a new one per request.
static CLIENT_CACHE: LazyLock<Cache<ReadRef<FetchClientConfig>, reqwest::Client>> =
    LazyLock::new(|| Cache::new(MAX_CLIENTS));
22
/// Configuration for fetching remote resources.
///
/// Besides controlling caching of fetched results, a `ReadRef` of this value is the key of
/// the process-wide `reqwest::Client` cache, so equal configurations share one client.
#[turbo_tasks::value(shared)]
// `Hash` is needed because the config is used as a `CLIENT_CACHE` key.
#[derive(Hash)]
pub struct FetchClientConfig {
    /// Lower bound applied to a response's `Cache-Control: max-age` when computing the
    /// deadline after which a fetched result is invalidated. Responses without a usable
    /// `max-age` are not affected by this value.
    pub min_cache_control: Duration,
}
38
39impl Default for FetchClientConfig {
40 fn default() -> Self {
41 Self {
42 min_cache_control: Duration::from_secs(60 * 60),
43 }
44 }
45}
46
47impl FetchClientConfig {
48 pub fn try_get_cached_reqwest_client(
61 self: ReadRef<FetchClientConfig>,
62 ) -> reqwest::Result<reqwest::Client> {
63 CLIENT_CACHE.get_or_insert_with(&self, {
64 let this = ReadRef::clone(&self);
65 move || this.try_build_uncached_reqwest_client()
66 })
67 }
68
69 fn try_build_uncached_reqwest_client(&self) -> reqwest::Result<reqwest::Client> {
70 #[allow(unused_mut)]
71 let mut builder = reqwest::Client::builder();
72 #[cfg(any(target_os = "linux", all(windows, not(target_arch = "aarch64"))))]
73 {
74 use std::sync::Once;
75 static ONCE: Once = Once::new();
76 ONCE.call_once(|| {
77 rustls::crypto::ring::default_provider()
78 .install_default()
79 .unwrap()
80 });
81 builder = builder.tls_backend_rustls();
82 }
83 #[cfg(all(windows, target_arch = "aarch64"))]
84 {
85 builder = builder.tls_backend_native();
86 }
87 #[cfg(target_os = "linux")]
88 {
89 builder = builder.tls_certs_merge(webpki_root_certs::TLS_SERVER_ROOT_CERTS.iter().map(
93 |der| {
94 reqwest::Certificate::from_der(der)
95 .expect("webpki_root_certs should parse correctly")
96 },
97 ))
98 }
99 builder.build()
100 }
101}
102
/// Invalidation reason used when a fetched result's `max-age`-derived deadline has elapsed
/// and the result is invalidated so that it gets fetched again.
#[derive(PartialEq, Eq, Hash)]
pub(crate) struct HttpTimeout;
106
107impl InvalidationReason for HttpTimeout {
108 fn kind(&self) -> Option<StaticOrArc<dyn InvalidationReasonKind>> {
109 Some(StaticOrArc::Static(&HTTP_TIMEOUT_KIND))
110 }
111}
112
113impl Display for HttpTimeout {
114 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
115 write!(f, "http max-age timeout")
116 }
117}
118
/// Invalidation-reason kind shared by all `HttpTimeout` reasons, used to summarize them as a
/// group.
#[derive(PartialEq, Eq, Hash)]
struct HttpTimeoutKind;

/// The single `HttpTimeoutKind` instance returned (as a `'static` reference) by
/// `HttpTimeout::kind`.
static HTTP_TIMEOUT_KIND: HttpTimeoutKind = HttpTimeoutKind;
124
125impl InvalidationReasonKind for HttpTimeoutKind {
126 fn fmt(
127 &self,
128 reasons: &FxIndexSet<StaticOrArc<dyn InvalidationReason>>,
129 f: &mut Formatter<'_>,
130 ) -> std::fmt::Result {
131 write!(f, "{} fetches timed out", reasons.len())
132 }
133}
134
/// Output of `fetch_inner`. It carries the fetch outcome and, for responses with a usable
/// `max-age`, the information `fetch` needs to schedule the result's invalidation.
#[turbo_tasks::value(shared)]
struct FetchInnerResult {
    /// The successful HTTP response, or the `FetchError` the request produced.
    result: ResolvedVc<FetchResult>,
    /// Obtained via `turbo_tasks::get_invalidator()` inside `fetch_inner` on the `max-age`
    /// path; `None` on every other path.
    invalidator: Option<Invalidator>,
    /// Absolute wall-clock expiry time, in whole seconds since the UNIX epoch, or `None` when
    /// no timed expiry applies.
    deadline_secs: Option<u64>,
}
147
148#[turbo_tasks::value_impl]
149impl FetchClientConfig {
150 #[turbo_tasks::function(network)]
154 async fn fetch_inner(
155 self: Vc<FetchClientConfig>,
156 url: RcStr,
157 user_agent: Option<RcStr>,
158 ) -> Result<Vc<FetchInnerResult>> {
159 let url_ref = &*url;
160 let this = self.await?;
161 let min_cache_control_secs = this.min_cache_control;
162 let response_result: reqwest::Result<(HttpResponse, Option<u64>)> = async move {
163 let reqwest_client = this.try_get_cached_reqwest_client()?;
164
165 let mut builder = reqwest_client.get(url_ref);
166 if let Some(user_agent) = user_agent {
167 builder = builder.header("User-Agent", user_agent.as_str());
168 }
169
170 let response = {
171 let _span = duration_span!("fetch request", url = url_ref);
172 builder.send().await
173 }
174 .and_then(|r| r.error_for_status())?;
175
176 let status = response.status().as_u16();
177 let max_age = parse_cache_control(response.headers());
178
179 let body = {
180 let _span = duration_span!("fetch response", url = url_ref);
181 response.bytes().await?
182 }
183 .to_vec();
184
185 Ok((
186 HttpResponse {
187 status,
188 body: HttpResponseBody(body).resolved_cell(),
189 },
190 max_age,
191 ))
192 }
193 .await;
194
195 match response_result {
196 Ok((resp, max_age_secs)) => {
197 if let Some(max_age_secs) = max_age_secs {
198 let max_age_secs = max(max_age_secs, min_cache_control_secs.as_secs());
199 let deadline_secs = SystemTime::now()
202 .duration_since(SystemTime::UNIX_EPOCH)
203 .ok()
205 .map(|d| d.as_secs() + max_age_secs);
206 let invalidator = turbo_tasks::get_invalidator();
207 Ok(FetchInnerResult {
208 result: ResolvedVc::cell(Ok(resp.resolved_cell())),
209 invalidator,
210 deadline_secs,
211 }
212 .cell())
213 } else {
214 Completion::session_dependent().await?;
215 Ok(FetchInnerResult {
216 result: ResolvedVc::cell(Ok(resp.resolved_cell())),
217 invalidator: None,
218 deadline_secs: None,
219 }
220 .cell())
221 }
222 }
223 Err(err) => {
224 Completion::session_dependent().await?;
228 Ok(FetchInnerResult {
229 result: ResolvedVc::cell(Err(
230 FetchError::from_reqwest_error(&err, &url).resolved_cell()
231 )),
232 invalidator: None,
233 deadline_secs: None,
234 }
235 .cell())
236 }
237 }
238 }
239
240 #[turbo_tasks::function(network, session_dependent)]
250 pub async fn fetch(
251 self: Vc<FetchClientConfig>,
252 url: RcStr,
253 user_agent: Option<RcStr>,
254 ) -> Result<Vc<FetchResult>> {
255 let FetchInnerResult {
256 result,
257 deadline_secs,
258 invalidator,
259 } = *self.fetch_inner(url, user_agent).await?;
260
261 if turbo_tasks::turbo_tasks().is_tracking_dependencies()
268 && let (Some(deadline_secs), Some(invalidator)) = (deadline_secs, invalidator)
269 {
270 if let Ok(now) = SystemTime::now().duration_since(SystemTime::UNIX_EPOCH) {
273 let remaining = Duration::from_secs(deadline_secs.saturating_sub(now.as_secs()));
274 turbo_tasks::spawn(async move {
287 tokio::time::sleep(remaining).await;
288 invalidator
289 .invalidate_with_reason(&*turbo_tasks::turbo_tasks(), HttpTimeout {});
290 });
291 }
292 }
293
294 Ok(*result)
295 }
296}
297
298fn parse_cache_control(headers: &reqwest::header::HeaderMap) -> Option<u64> {
302 let value = headers.get(reqwest::header::CACHE_CONTROL)?.to_str().ok()?;
303 let mut max_age = None;
304 for directive in value.split(',') {
305 let (key, val) = {
306 if let Some(index) = directive.find('=') {
307 (directive[0..index].trim(), Some(&directive[index + 1..]))
308 } else {
309 (directive.trim(), None)
310 }
311 };
312 if key.eq_ignore_ascii_case("max-age")
313 && let Some(val) = val
314 {
315 max_age = val.trim().parse().ok();
316 } else if key.eq_ignore_ascii_case("no-cache") || key.eq_ignore_ascii_case("no-store") {
317 return None;
318 }
319 }
320 max_age
321}
322
323#[doc(hidden)]
324pub fn __test_only_reqwest_client_cache_clear() {
325 CLIENT_CACHE.clear()
326}
327
328#[doc(hidden)]
329pub fn __test_only_reqwest_client_cache_len() -> usize {
330 CLIENT_CACHE.len()
331}
332
333#[cfg(test)]
334mod tests {
335 use reqwest::header::{CACHE_CONTROL, HeaderMap, HeaderValue};
336
337 use super::parse_cache_control;
338
339 fn headers(value: &str) -> HeaderMap {
340 let mut h = HeaderMap::new();
341 h.insert(CACHE_CONTROL, HeaderValue::from_str(value).unwrap());
342 h
343 }
344
345 #[test]
346 fn max_age() {
347 assert_eq!(parse_cache_control(&headers("max-age=300")), Some(300));
348 assert_eq!(parse_cache_control(&headers("MAX-AGE = 300")), Some(300));
349 assert_eq!(
350 parse_cache_control(&headers("public, max-age=3600, must-revalidate")),
351 Some(3600)
352 );
353 }
354
355 #[test]
356 fn no_cache_headers() {
357 assert_eq!(parse_cache_control(&headers("NO-CACHE")), None);
358 assert_eq!(parse_cache_control(&headers("no-cache")), None);
359 assert_eq!(parse_cache_control(&headers("no-store")), None);
360 assert_eq!(parse_cache_control(&headers("max-age=300, no-store")), None);
361 assert_eq!(parse_cache_control(&HeaderMap::new()), None);
362 }
363}