1use std::{
2 any::{Any, TypeId},
3 error::Error as StdError,
4 future::Future,
5 mem::replace,
6 panic,
7 pin::Pin,
8 sync::Arc,
9};
10
11use anyhow::Result;
12use auto_hash_map::AutoSet;
13use futures::{StreamExt, TryStreamExt};
14use parking_lot::Mutex;
15use rustc_hash::{FxHashMap, FxHashSet};
16use tokio::task_local;
17use tracing::Instrument;
18
19use crate::{
20 self as turbo_tasks, CollectiblesSource, NonLocalValue, ReadRef, ResolvedVc, TryJoinIterExt,
21 emit,
22 event::{Event, EventListener},
23 spawn,
24 trace::TraceRawVcs,
25};
26
/// Maximum number of effects that `apply_effects` and `Effects::apply` drive concurrently.
const APPLY_EFFECTS_CONCURRENCY_LIMIT: usize = 1024;
28
/// A side effect that can be handed to `emit_effect` and applied later.
pub trait Effect: TraceRawVcs + NonLocalValue + Send + Sync + 'static {
    /// Error type produced when applying the effect fails.
    type Error: EffectError;

    /// Performs the effect.
    ///
    /// The returned future must be `Send` and resolves to `Ok(())` on success or to
    /// `Self::Error` on failure.
    fn apply(&self) -> impl Future<Output = Result<(), Self::Error>> + Send;
}
40
41pub trait EffectError: StdError + TraceRawVcs + NonLocalValue + Send + Sync + 'static {}
55impl<T> EffectError for T where T: StdError + TraceRawVcs + NonLocalValue + Send + Sync + 'static {}
56
57trait DynEffect: TraceRawVcs + NonLocalValue + Send + Sync + 'static {
60 fn dyn_apply<'a>(&'a self) -> DynEffectApplyFuture<'a>;
61}
62
63impl<T> DynEffect for T
64where
65 T: Effect,
66{
67 fn dyn_apply<'a>(&'a self) -> DynEffectApplyFuture<'a> {
68 Box::pin(async move {
69 self.apply()
70 .await
71 .map_err(|err| Arc::new(err) as Arc<dyn EffectError>)
72 })
73 }
74}
75
/// Boxed, pinned, `Send` future returned by `DynEffect::dyn_apply`; it may borrow data for `'a`
/// and resolves to `Ok(())` or a shared, type-erased effect error.
type DynEffectApplyFuture<'a> =
    Pin<Box<dyn Future<Output = Result<(), Arc<dyn EffectError>>> + Send + 'a>>;
78
/// Method-less value trait under which effects are emitted and collected as collectibles.
/// `apply_effects` and `get_effects` expect every implementor to be an `EffectInstance`.
#[turbo_tasks::value_trait]
trait EffectCollectible {}
84
/// Lifecycle of the effect stored inside an `EffectInstance`.
#[derive(TraceRawVcs, NonLocalValue)]
enum EffectState {
    /// The effect has not been claimed by any `EffectInstance::apply` call yet.
    NotStarted(Box<dyn DynEffect>),
    /// An `apply` call has claimed the effect. Other callers listen on the `Event`, which is
    /// notified after the result has been stored.
    Started(Arc<dyn DynEffect>, Event),
    /// The effect has been applied; holds the outcome returned to every later `apply` call.
    Finished(Result<(), Arc<dyn EffectError>>),

    /// Placeholder stored while the boxed effect is moved out of `NotStarted`; `apply` treats
    /// observing it as unreachable.
    Invalid,
}
98
/// Cell value wrapping one emitted effect together with its application state.
#[turbo_tasks::value(serialization = "none", cell = "new", eq = "manual")]
struct EffectInstance {
    /// Current lifecycle state; guarded by a mutex because `apply` takes `&self`.
    #[turbo_tasks(debug_ignore)]
    inner: Mutex<EffectState>,
}
107
108impl EffectInstance {
109 fn new(effect: impl Effect) -> Self {
110 Self {
111 inner: Mutex::new(EffectState::NotStarted(
112 Box::new(effect) as Box<dyn DynEffect>
113 )),
114 }
115 }
116
117 async fn apply(&self) -> Result<()> {
118 loop {
119 enum State {
120 Started(EventListener),
121 NotStarted(Arc<dyn DynEffect>),
122 }
123 let state = {
124 let mut guard = self.inner.lock();
125 match &*guard {
126 EffectState::Started(_, event) => {
127 let listener = event.listen();
128 State::Started(listener)
129 }
130 EffectState::Finished(result) => {
131 return result.clone().map_err(Into::into);
132 }
133 EffectState::NotStarted(_) => {
134 let EffectState::NotStarted(effect) =
135 std::mem::replace(&mut *guard, EffectState::Invalid)
136 else {
137 unreachable!()
138 };
139 let effect: Arc<dyn DynEffect> = Arc::from(effect);
140 *guard = EffectState::Started(
141 effect.clone(),
142 Event::new(|| || "Effect".to_string()),
143 );
144 State::NotStarted(effect)
145 }
146 EffectState::Invalid => unreachable!(),
147 }
148 };
149 match state {
150 State::Started(listener) => listener.await,
151 State::NotStarted(effect) => {
152 let join_handle = spawn(ApplyEffectsContext::in_current_scope(async move {
156 effect.dyn_apply().await
157 }));
158 let result = match join_handle.await {
159 Err(err) => Err(err),
160 Ok(()) => Ok(()),
161 };
162 let event = {
163 let mut guard = self.inner.lock();
164 let EffectState::Started(_, event) =
165 replace(&mut *guard, EffectState::Finished(result.clone()))
166 else {
167 unreachable!();
168 };
169 event
170 };
171 event.notify(usize::MAX);
172 return result.map_err(Into::into);
173 }
174 }
175 }
176 }
177}
178
/// Lets `EffectInstance` cells be emitted and collected as `EffectCollectible` collectibles.
#[turbo_tasks::value_impl]
impl EffectCollectible for EffectInstance {}
181
182pub fn emit_effect(effect: impl Effect) {
191 emit::<Box<dyn EffectCollectible>>(ResolvedVc::upcast(
192 EffectInstance::new(effect).resolved_cell(),
193 ));
194}
195
196pub async fn apply_effects(source: impl CollectiblesSource) -> Result<()> {
214 let effects: AutoSet<ResolvedVc<Box<dyn EffectCollectible>>> = source.take_collectibles();
215 if effects.is_empty() {
216 return Ok(());
217 }
218 let span = tracing::info_span!("apply effects", count = effects.len());
219 APPLY_EFFECTS_CONTEXT
220 .scope(Default::default(), async move {
221 futures::stream::iter(effects)
223 .map(Ok)
224 .try_for_each_concurrent(APPLY_EFFECTS_CONCURRENCY_LIMIT, async |effect| {
225 let Some(effect) = ResolvedVc::try_downcast_type::<EffectInstance>(effect)
226 else {
227 panic!("Effect must only be implemented by EffectInstance");
228 };
229 effect.await?.apply().await
230 })
231 .await
232 })
233 .instrument(span)
234 .await
235}
236
237pub async fn get_effects(source: impl CollectiblesSource) -> Result<Effects> {
257 let effects: AutoSet<ResolvedVc<Box<dyn EffectCollectible>>> = source.take_collectibles();
258 let effects = effects
259 .into_iter()
260 .map(|effect| async move {
261 if let Some(effect) = ResolvedVc::try_downcast_type::<EffectInstance>(effect) {
262 Ok(effect.await?)
263 } else {
264 panic!("Effect must only be implemented by EffectInstance");
265 }
266 })
267 .try_join()
268 .await?;
269 Ok(Effects { effects })
270}
271
/// A captured set of effects, produced by `get_effects`, that can be applied later with
/// `Effects::apply`.
#[derive(Default)]
#[turbo_tasks::value(shared, eq = "manual", serialization = "none")]
pub struct Effects {
    /// The effect instances that were read when this value was built.
    #[turbo_tasks(debug_ignore)]
    effects: Vec<ReadRef<EffectInstance>>,
}
280
281impl PartialEq for Effects {
282 fn eq(&self, other: &Self) -> bool {
283 if self.effects.len() != other.effects.len() {
284 return false;
285 }
286 let effect_ptrs = self
287 .effects
288 .iter()
289 .map(ReadRef::ptr)
290 .collect::<FxHashSet<_>>();
291 other
292 .effects
293 .iter()
294 .all(|e| effect_ptrs.contains(&ReadRef::ptr(e)))
295 }
296}
297
298impl Eq for Effects {}
299
300impl Effects {
301 pub async fn apply(&self) -> Result<()> {
303 let span = tracing::info_span!("apply effects", count = self.effects.len());
304 APPLY_EFFECTS_CONTEXT
305 .scope(Default::default(), async move {
306 futures::stream::iter(self.effects.iter())
308 .map(Ok)
309 .try_for_each_concurrent(APPLY_EFFECTS_CONCURRENCY_LIMIT, async |effect| {
310 effect.apply().await
311 })
312 .await
313 })
314 .instrument(span)
315 .await
316 }
317}
318
// Task-local handle to the context shared by the effects applied within one
// `apply_effects` / `Effects::apply` call; `ApplyEffectsContext::in_current_scope` passes it on
// to spawned effect tasks.
task_local! {
    static APPLY_EFFECTS_CONTEXT: Arc<Mutex<ApplyEffectsContext>>;
}
323
/// Type-keyed storage that effects can use while they are being applied.
#[derive(Default)]
pub struct ApplyEffectsContext {
    /// At most one value per type, keyed by the `TypeId` of the stored value.
    data: FxHashMap<TypeId, Box<dyn Any + Send + Sync>>,
}
328
329impl ApplyEffectsContext {
330 fn in_current_scope<F: Future>(f: F) -> impl Future<Output = F::Output> {
331 let current = Self::current();
332 APPLY_EFFECTS_CONTEXT.scope(current, f)
333 }
334
335 fn current() -> Arc<Mutex<Self>> {
336 APPLY_EFFECTS_CONTEXT
337 .try_with(|mutex| mutex.clone())
338 .expect("No effect context found")
339 }
340
341 fn with_context<T, F: FnOnce(&mut Self) -> T>(f: F) -> T {
342 APPLY_EFFECTS_CONTEXT
343 .try_with(|mutex| f(&mut mutex.lock()))
344 .expect("No effect context found")
345 }
346
347 pub fn set<T: Any + Send + Sync>(value: T) {
348 Self::with_context(|this| {
349 this.data.insert(TypeId::of::<T>(), Box::new(value));
350 })
351 }
352
353 pub fn with<T: Any + Send + Sync, R>(f: impl FnOnce(&mut T) -> R) -> Option<R> {
354 Self::with_context(|this| {
355 this.data
356 .get_mut(&TypeId::of::<T>())
357 .map(|value| {
358 unsafe { value.downcast_unchecked_mut() }
360 })
361 .map(f)
362 })
363 }
364
365 pub fn with_or_insert_with<T: Any + Send + Sync, R>(
366 insert_with: impl FnOnce() -> T,
367 f: impl FnOnce(&mut T) -> R,
368 ) -> R {
369 Self::with_context(|this| {
370 let value = this.data.entry(TypeId::of::<T>()).or_insert_with(|| {
371 let value = insert_with();
372 Box::new(value)
373 });
374 f(
375 unsafe { value.downcast_unchecked_mut() },
377 )
378 })
379 }
380}
381
#[cfg(test)]
mod tests {
    use crate::{CollectiblesSource, apply_effects, get_effects};

    /// Compile-time check that the futures returned by `apply_effects` and `get_effects` are
    /// `Send + Sync` when the source is.
    ///
    /// The nested helper functions are never called; it is enough that they type-check, which
    /// is why `dead_code` is allowed.
    #[test]
    #[allow(dead_code)]
    fn is_sync_and_send() {
        fn assert_sync<T: Sync + Send>(_: T) {}
        fn check_apply_effects<T: CollectiblesSource + Send + Sync>(t: T) {
            assert_sync(apply_effects(t));
        }
        fn check_get_effects<T: CollectiblesSource + Send + Sync>(t: T) {
            assert_sync(get_effects(t));
        }
    }
}