1use std::{
2 any::{Any, TypeId},
3 future::Future,
4 mem::replace,
5 panic,
6 pin::Pin,
7 sync::Arc,
8};
9
10use anyhow::Result;
11use auto_hash_map::AutoSet;
12use futures::{StreamExt, TryStreamExt};
13use parking_lot::Mutex;
14use rustc_hash::{FxHashMap, FxHashSet};
15use tokio::task_local;
16use tracing::Instrument;
17
18use crate::{
19 self as turbo_tasks, CollectiblesSource, NonLocalValue, ReadRef, ResolvedVc, TryJoinIterExt,
20 debug::ValueDebugFormat,
21 emit,
22 event::{Event, EventListener},
23 spawn,
24 trace::TraceRawVcs,
25 util::SharedError,
26};
27
/// Upper bound on how many effects `apply_effects` and `Effects::apply` drive at the same time.
const APPLY_EFFECTS_CONCURRENCY_LIMIT: usize = 1 << 10;
29
/// Marker value trait under which effects are emitted and collected (as `Box<dyn Effect>`
/// collectibles).
///
/// `EffectInstance` is the only implementor expected: `apply_effects` and `get_effects` panic if a
/// collected `Box<dyn Effect>` cannot be downcast to `EffectInstance`.
#[turbo_tasks::value_trait]
trait Effect {}
36
/// The boxed, pinned future an effect runs, resolving to `Ok(())` on success.
///
/// It is stored inside an `EffectInstance` cell and later moved onto a spawned task when the
/// effect is first applied.
type EffectFuture = Pin<Box<dyn Future<Output = Result<()>> + Send + Sync + 'static>>;
40
/// Payload of an effect that has not been started yet.
struct EffectInner {
    /// The future to spawn the first time the effect is applied.
    future: EffectFuture,
}
45
/// Lifecycle of an `EffectInstance`. `EffectInstance::apply` only moves it forward:
/// `NotStarted` -> `Started` -> `Finished`.
enum EffectState {
    /// The future has not been spawned yet.
    NotStarted(EffectInner),
    /// The future is running; waiters listen on this `Event`, which is notified once it finishes.
    Started(Event),
    /// The future completed; this result is cloned and returned to every later `apply` call.
    Finished(Result<(), SharedError>),
}
51
/// A single registered effect, stored as a turbo-tasks cell value (created by `effect`).
///
/// It holds a live future, so it opts out of serialization (`serialization = "none"`).
/// NOTE(review): `cell = "new"` / `eq = "manual"` presumably make every `cell()` call produce a
/// distinct cell instead of comparing against a previous value — confirm against the
/// `turbo_tasks::value` macro documentation.
#[turbo_tasks::value(serialization = "none", cell = "new", eq = "manual")]
struct EffectInstance {
    /// Lifecycle state of the effect; excluded from Vc tracing and debug formatting.
    #[turbo_tasks(trace_ignore, debug_ignore)]
    inner: Mutex<EffectState>,
}
58
impl EffectInstance {
    /// Creates an effect in the `NotStarted` state that will run `future` when first applied.
    fn new(future: impl Future<Output = Result<()>> + Send + Sync + 'static) -> Self {
        Self {
            inner: Mutex::new(EffectState::NotStarted(EffectInner {
                future: Box::pin(future),
            })),
        }
    }

    /// Runs the effect if nobody has started it yet, otherwise waits for / returns its result.
    ///
    /// The first caller moves the future out of `NotStarted`, spawns it (with the caller's
    /// `ApplyEffectsContext` propagated), caches the result in `Finished` and wakes all waiters.
    /// Concurrent callers that observe `Started` wait on the event and then re-check the state;
    /// later callers observe `Finished` and receive a clone of the cached result. The future
    /// therefore runs at most once.
    ///
    /// NOTE(review): if the caller that started the effect is dropped while awaiting
    /// `join_handle` (e.g. `try_for_each_concurrent` drops in-flight futures after another effect
    /// returns an error), the state stays `Started` and the event is never notified, so a later
    /// `apply` on the same instance would wait forever. Whether the spawned task keeps running
    /// depends on `crate::spawn` — TODO confirm and consider moving completion into the task.
    async fn apply(&self) -> Result<()> {
        loop {
            // Decision taken while holding the lock; acted upon after releasing it.
            enum State {
                Started(EventListener),
                NotStarted(EffectInner),
            }
            let state = {
                let mut guard = self.inner.lock();
                match &*guard {
                    EffectState::Started(event) => {
                        // The listener is registered while the lock is held. The finishing
                        // caller must take the same lock to switch to `Finished` before it
                        // notifies, so this wake-up cannot be missed.
                        let listener = event.listen();
                        State::Started(listener)
                    }
                    EffectState::Finished(result) => {
                        return result.clone().map_err(Into::into);
                    }
                    EffectState::NotStarted(_) => {
                        // Claim the effect: take the future out and publish `Started` so other
                        // callers wait instead of running it again.
                        // NOTE(review): the nested closure presumably lazily produces a debug
                        // description for the event — confirm against `Event::new`.
                        let EffectState::NotStarted(inner) = std::mem::replace(
                            &mut *guard,
                            EffectState::Started(Event::new(|| || "Effect".to_string())),
                        ) else {
                            unreachable!();
                        };
                        State::NotStarted(inner)
                    }
                }
            };
            match state {
                State::Started(listener) => {
                    // Wait, then loop to re-read the state (normally `Finished` by now).
                    listener.await;
                }
                State::NotStarted(EffectInner { future }) => {
                    // Run on a separate task, carrying over the current effects context.
                    let join_handle = spawn(ApplyEffectsContext::in_current_scope(future));
                    let result = match join_handle.await {
                        Err(err) => Err(SharedError::new(err)),
                        Ok(()) => Ok(()),
                    };
                    // Only this caller can move the state out of `Started`, hence unreachable.
                    let event = {
                        let mut guard = self.inner.lock();
                        let EffectState::Started(event) =
                            replace(&mut *guard, EffectState::Finished(result.clone()))
                        else {
                            unreachable!();
                        };
                        event
                    };
                    // Wake every waiter; they will observe `Finished`.
                    event.notify(usize::MAX);
                    return result.map_err(Into::into);
                }
            }
        }
    }
}
121
// `EffectInstance` is the implementation `effect` emits and the one `apply_effects` /
// `get_effects` downcast to.
#[turbo_tasks::value_impl]
impl Effect for EffectInstance {}
124
125pub fn effect(future: impl Future<Output = Result<()>> + Send + Sync + 'static) {
134 emit::<Box<dyn Effect>>(ResolvedVc::upcast(
135 EffectInstance::new(future).resolved_cell(),
136 ));
137}
138
139pub async fn apply_effects(source: impl CollectiblesSource) -> Result<()> {
157 let effects: AutoSet<ResolvedVc<Box<dyn Effect>>> = source.take_collectibles();
158 if effects.is_empty() {
159 return Ok(());
160 }
161 let span = tracing::info_span!("apply effects", count = effects.len());
162 APPLY_EFFECTS_CONTEXT
163 .scope(Default::default(), async move {
164 futures::stream::iter(effects)
166 .map(Ok)
167 .try_for_each_concurrent(APPLY_EFFECTS_CONCURRENCY_LIMIT, async |effect| {
168 let Some(effect) = ResolvedVc::try_downcast_type::<EffectInstance>(effect)
169 else {
170 panic!("Effect must only be implemented by EffectInstance");
171 };
172 effect.await?.apply().await
173 })
174 .await
175 })
176 .instrument(span)
177 .await
178}
179
180pub async fn get_effects(source: impl CollectiblesSource) -> Result<Effects> {
200 let effects: AutoSet<ResolvedVc<Box<dyn Effect>>> = source.take_collectibles();
201 let effects = effects
202 .into_iter()
203 .map(|effect| async move {
204 if let Some(effect) = ResolvedVc::try_downcast_type::<EffectInstance>(effect) {
205 Ok(effect.await?)
206 } else {
207 panic!("Effect must only be implemented by EffectInstance");
208 }
209 })
210 .try_join()
211 .await?;
212 Ok(Effects { effects })
213}
214
/// Effects taken from a collectibles source by `get_effects`, to be applied later with
/// `Effects::apply`.
///
/// Holding an `Effects` value does not run anything. Equality compares the identity of the
/// contained effect instances, not their contents.
#[derive(TraceRawVcs, Default, ValueDebugFormat, NonLocalValue)]
pub struct Effects {
    /// Read references to the collected effect instances; excluded from tracing and debug output.
    #[turbo_tasks(trace_ignore, debug_ignore)]
    effects: Vec<ReadRef<EffectInstance>>,
}
222
223impl PartialEq for Effects {
224 fn eq(&self, other: &Self) -> bool {
225 if self.effects.len() != other.effects.len() {
226 return false;
227 }
228 let effect_ptrs = self
229 .effects
230 .iter()
231 .map(ReadRef::ptr)
232 .collect::<FxHashSet<_>>();
233 other
234 .effects
235 .iter()
236 .all(|e| effect_ptrs.contains(&ReadRef::ptr(e)))
237 }
238}
239
// The identity comparison in `PartialEq` is reflexive: every entry of `self` is found among its
// own pointers.
impl Eq for Effects {}
241
242impl Effects {
243 pub async fn apply(&self) -> Result<()> {
245 let span = tracing::info_span!("apply effects", count = self.effects.len());
246 APPLY_EFFECTS_CONTEXT
247 .scope(Default::default(), async move {
248 futures::stream::iter(self.effects.iter())
250 .map(Ok)
251 .try_for_each_concurrent(APPLY_EFFECTS_CONCURRENCY_LIMIT, async |effect| {
252 effect.apply().await
253 })
254 .await
255 })
256 .instrument(span)
257 .await
258 }
259}
260
task_local! {
    /// Shared, mutable context for the effects currently being applied.
    ///
    /// A default context is installed by `apply_effects` and `Effects::apply`. It is carried over
    /// to spawned effect tasks by `ApplyEffectsContext::in_current_scope`.
    static APPLY_EFFECTS_CONTEXT: Arc<Mutex<ApplyEffectsContext>>;
}
265
/// Type-keyed storage shared by all effects applied within one `apply_effects` /
/// `Effects::apply` call.
#[derive(Default)]
pub struct ApplyEffectsContext {
    /// At most one value per type, keyed by the `TypeId` of the stored value's type.
    data: FxHashMap<TypeId, Box<dyn Any + Send + Sync>>,
}
270
271impl ApplyEffectsContext {
272 fn in_current_scope<F: Future>(f: F) -> impl Future<Output = F::Output> {
273 let current = Self::current();
274 APPLY_EFFECTS_CONTEXT.scope(current, f)
275 }
276
277 fn current() -> Arc<Mutex<Self>> {
278 APPLY_EFFECTS_CONTEXT
279 .try_with(|mutex| mutex.clone())
280 .expect("No effect context found")
281 }
282
283 fn with_context<T, F: FnOnce(&mut Self) -> T>(f: F) -> T {
284 APPLY_EFFECTS_CONTEXT
285 .try_with(|mutex| f(&mut mutex.lock()))
286 .expect("No effect context found")
287 }
288
289 pub fn set<T: Any + Send + Sync>(value: T) {
290 Self::with_context(|this| {
291 this.data.insert(TypeId::of::<T>(), Box::new(value));
292 })
293 }
294
295 pub fn with<T: Any + Send + Sync, R>(f: impl FnOnce(&mut T) -> R) -> Option<R> {
296 Self::with_context(|this| {
297 this.data
298 .get_mut(&TypeId::of::<T>())
299 .map(|value| {
300 unsafe { value.downcast_mut_unchecked() }
302 })
303 .map(f)
304 })
305 }
306
307 pub fn with_or_insert_with<T: Any + Send + Sync, R>(
308 insert_with: impl FnOnce() -> T,
309 f: impl FnOnce(&mut T) -> R,
310 ) -> R {
311 Self::with_context(|this| {
312 let value = this.data.entry(TypeId::of::<T>()).or_insert_with(|| {
313 let value = insert_with();
314 Box::new(value)
315 });
316 f(
317 unsafe { value.downcast_mut_unchecked() },
319 )
320 })
321 }
322}
323
#[cfg(test)]
mod tests {
    use crate::{CollectiblesSource, apply_effects, get_effects};

    /// Compile-time check that the futures returned by `apply_effects` and `get_effects` are
    /// `Send + Sync` whenever the collectibles source is.
    #[test]
    // The helper functions are never called; they exist only to be type-checked.
    #[allow(dead_code)]
    fn is_sync_and_send() {
        fn require_send_sync<V: Sync + Send>(_: V) {}
        fn apply_effects_future<S: CollectiblesSource + Send + Sync>(source: S) {
            require_send_sync(apply_effects(source));
        }
        fn get_effects_future<S: CollectiblesSource + Send + Sync>(source: S) {
            require_send_sync(get_effects(source));
        }
    }
}