turbopack_core/module_graph/
module_batches.rs

use std::{
    collections::{VecDeque, hash_map::Entry},
    hash::BuildHasherDefault,
    mem::take,
};

use anyhow::{Context, Result, bail};
use either::Either;
use petgraph::graph::{DiGraph, EdgeIndex, NodeIndex};
use rustc_hash::{FxHashMap, FxHashSet, FxHasher};
use serde::{Deserialize, Serialize};
use tracing::Instrument;
use turbo_prehash::BuildHasherExt;
use turbo_tasks::{
    FxIndexMap, FxIndexSet, NonLocalValue, ResolvedVc, TaskInput, TryJoinIterExt, ValueToString,
    Vc, trace::TraceRawVcs,
};

use crate::{
    chunk::{ChunkableModule, ChunkingType},
    module::Module,
    module_graph::{
        GraphTraversalAction, ModuleGraph,
        chunk_group_info::{ChunkGroupInfo, ChunkGroupKey, RoaringBitmapWrapper},
        module_batch::{ModuleBatch, ModuleBatchGroup, ModuleOrBatch},
        traced_di_graph::{TracedDiGraph, iter_neighbors_rev},
    },
};

#[turbo_tasks::value]
#[derive(Debug, Clone, Default, TaskInput, Hash)]
pub struct BatchingConfig {
    /// Use a heuristic based on the module path to create batches. It aims for batches of a good
    /// size.
    pub use_heuristic: bool,
}

#[turbo_tasks::value_impl]
impl BatchingConfig {
    #[turbo_tasks::function]
    pub fn new(config: BatchingConfig) -> Vc<Self> {
        config.cell()
    }
}

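/// An edge in the [`ModuleBatchesGraph`]. `module` is the module the edge points to, or `None`
/// for a parallel reference to a whole batch.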
#[derive(Debug, Clone, Serialize, Deserialize, TraceRawVcs, NonLocalValue)]
pub struct ModuleBatchesGraphEdge {
    pub ty: ChunkingType,
    pub module: Option<ResolvedVc<Box<dyn Module>>>,
}

type EntriesList = FxIndexSet<ResolvedVc<Box<dyn Module>>>;

#[turbo_tasks::value(cell = "new", eq = "manual", into = "new")]
pub struct ModuleBatchesGraph {
    graph: TracedDiGraph<ModuleOrBatch, ModuleBatchesGraphEdge>,

    // NodeIndex isn't necessarily stable (because of swap_remove), but we never remove nodes.
    //
    // HashMaps have nondeterministic order, but this map is only used for lookups and not
    // iteration.
    //
    // This contains Vcs, but they are already contained in the graph, so no need to trace this.
    #[turbo_tasks(trace_ignore)]
    entries: FxHashMap<ResolvedVc<Box<dyn Module>>, NodeIndex>,
    batch_groups: FxHashMap<ModuleOrBatch, ResolvedVc<ModuleBatchGroup>>,

    /// For chunk groups where the postorder of entries differs from the order of
    /// `ChunkGroup::entries()`, this contains `Some` with the postorder list of entries of that
    /// chunk group. The index in this list corresponds to the index in
    /// `chunk_group_info.chunk_groups`.
    ordered_entries: Vec<Option<EntriesList>>,
}

impl ModuleBatchesGraph {
    pub async fn get_entry_index(&self, entry: ResolvedVc<Box<dyn Module>>) -> Result<NodeIndex> {
        let Some(entry) = self.entries.get(&entry) else {
            bail!(
                "Entry {} is not in graph (possible entries: {:#?})",
                entry.ident().to_string().await?,
                self.entries
                    .keys()
                    .map(|e| e.ident().to_string())
                    .try_join()
                    .await?
            );
        };
        Ok(*entry)
    }

    pub fn get_ordered_entries<'l>(
        &'l self,
        chunk_group_info: &'l ChunkGroupInfo,
        idx: usize,
    ) -> impl Iterator<Item = ResolvedVc<Box<dyn Module>>> + 'l {
        if let Some(ordered_entries) = self
            .ordered_entries
            .get(idx)
            .as_ref()
            .and_then(|o| o.as_ref())
        {
            if let Some(chunk_group) = chunk_group_info.chunk_groups.get_index(idx) {
                debug_assert_eq!(ordered_entries.len(), chunk_group.entries_count());
            }
            Either::Left(Either::Left(ordered_entries.iter().copied()))
        } else if let Some(chunk_group) = chunk_group_info.chunk_groups.get_index(idx) {
            Either::Right(chunk_group.entries())
        } else {
            Either::Left(Either::Right(std::iter::empty()))
        }
    }

    pub fn get_batch_group(
        &self,
        module_or_batch: &ModuleOrBatch,
    ) -> Option<ResolvedVc<ModuleBatchGroup>> {
        self.batch_groups.get(module_or_batch).copied()
    }

    pub async fn get_entry(&self, entry: ResolvedVc<Box<dyn Module>>) -> Result<ModuleOrBatch> {
        let entry = self.get_entry_index(entry).await?;
        Ok(*self.graph.node_weight(entry).unwrap())
    }

    // Clippy complains but there's a type error without the bound
    #[allow(clippy::implied_bounds_in_impls)]
    /// Traverses all reachable edges in DFS order. The preorder visitor can be used to forward
    /// state down the graph and to skip subgraphs.
    ///
    /// Use this to collect batches/modules in evaluation order.
    ///
    /// Target nodes can be revisited (once per incoming edge).
    /// Edges are traversed in normal order, so they should correspond to reference order.
    ///
    /// * `entries` - The entry modules to start the traversal from
    /// * `state` - The state to be passed to the visitors
    /// * `visit_preorder` - Called before visiting the children of a node.
    ///    - Receives: (originating `&ModuleOrBatch`, edge `&ModuleBatchesGraphEdge`), target
    ///      `&ModuleOrBatch`, state `&mut S`
    ///    - Can return [GraphTraversalAction]s to control the traversal
    /// * `visit_postorder` - Called after visiting the children of a node.
    ///    - Receives: (originating `&ModuleOrBatch`, edge `&ModuleBatchesGraphEdge`), target
    ///      `&ModuleOrBatch`, state `&mut S`
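    ///
    /// A minimal usage sketch (not a compiled doctest; `graph` and `entry_index` are assumed to
    /// exist, and the state is just a `Vec<ModuleOrBatch>` collecting nodes in postorder):
    ///
    /// ```ignore
    /// let mut order: Vec<ModuleOrBatch> = Vec::new();
    /// graph.traverse_edges_from_entries_dfs(
    ///     std::iter::once(entry_index),
    ///     &mut order,
    ///     |_parent, _node, _order| Ok(GraphTraversalAction::Continue),
    ///     |_parent, node, order| order.push(*node),
    /// )?;
    /// ```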
    pub fn traverse_edges_from_entries_dfs<'a, S>(
        &'a self,
        entries: impl IntoIterator<
            Item = NodeIndex,
            IntoIter = impl Iterator<Item = NodeIndex> + DoubleEndedIterator,
        >,
        state: &mut S,
        mut visit_preorder: impl FnMut(
            Option<(&'a ModuleOrBatch, &'a ModuleBatchesGraphEdge)>,
            &'a ModuleOrBatch,
            &mut S,
        ) -> Result<GraphTraversalAction>,
        mut visit_postorder: impl FnMut(
            Option<(&'a ModuleOrBatch, &'a ModuleBatchesGraphEdge)>,
            &'a ModuleOrBatch,
            &mut S,
        ),
    ) -> Result<()> {
        let graph = &self.graph;

        enum ReverseDFSPass {
            Visit,
            ExpandAndVisit,
        }

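        // Two-pass stack scheme: `ExpandAndVisit` runs the preorder visitor and pushes a `Visit`
        // marker below the children, so the postorder visitor runs once the whole subtree has
        // been processed. Entries and neighbors are pushed in reverse so they pop in forward
        // order.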
        let entries = entries.into_iter();
        #[allow(clippy::type_complexity)] // This is a temporary internal structure
        let mut stack: Vec<(ReverseDFSPass, Option<(NodeIndex, EdgeIndex)>, NodeIndex)> = entries
            .rev()
            .map(|e| (ReverseDFSPass::ExpandAndVisit, None, e))
            .collect();
        let mut expanded = FxHashSet::default();
        while let Some((pass, parent, current)) = stack.pop() {
            let parent_arg = parent.map(|(node, edge)| {
                (
                    graph.node_weight(node).unwrap(),
                    graph.edge_weight(edge).unwrap(),
                )
            });
            match pass {
                ReverseDFSPass::Visit => {
                    let current_node = graph.node_weight(current).unwrap();
                    visit_postorder(parent_arg, current_node, state);
                }
                ReverseDFSPass::ExpandAndVisit => {
                    let current_node = graph.node_weight(current).unwrap();
                    let action = visit_preorder(parent_arg, current_node, state)?;
                    if action == GraphTraversalAction::Exclude {
                        continue;
                    }
                    stack.push((ReverseDFSPass::Visit, parent, current));
                    if action == GraphTraversalAction::Continue && expanded.insert(current) {
                        stack.extend(iter_neighbors_rev(graph, current).map(|(edge, child)| {
                            (ReverseDFSPass::ExpandAndVisit, Some((current, edge)), child)
                        }));
                    }
                }
            }
        }

        Ok(())
    }
}

type PreBatchIndex = usize;

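/// One item of a [`PreBatch`], in evaluation order: a module belonging to the batch, a parallel
/// reference to another pre-batch, or a non-parallel edge leaving the batch.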
#[derive(Hash, PartialEq, Eq, Clone, Debug)]
enum PreBatchItem {
    ParallelModule(ResolvedVc<Box<dyn Module>>),
    ParallelReference(PreBatchIndex),
    NonParallelEdge(ChunkingType, ResolvedVc<Box<dyn Module>>),
}

struct PreBatch {
    items: FxIndexSet<PreBatchItem>,
    chunk_groups: RoaringBitmapWrapper,
}

impl PreBatch {
    fn new(chunk_groups: RoaringBitmapWrapper) -> Self {
        Self {
            items: FxIndexSet::default(),
            chunk_groups,
        }
    }
}

struct TraversalState<'l> {
    items: Vec<PreBatchItem>,
    this: &'l mut PreBatches,
}

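/// Intermediate state for batch computation: `boundary_modules` are modules at which a new
/// pre-batch has to start, `entries` maps each boundary module to the index of its pre-batch,
/// and `single_module_entries` collects modules that become standalone graph nodes instead of
/// batch members.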
struct PreBatches {
    boundary_modules: FxHashSet<ResolvedVc<Box<dyn Module>>>,
    batches: Vec<PreBatch>,
    entries: FxHashMap<ResolvedVc<Box<dyn Module>>, PreBatchIndex>,
    single_module_entries: FxIndexSet<ResolvedVc<Box<dyn Module>>>,
}

impl PreBatches {
    fn new() -> Self {
        Self {
            boundary_modules: FxHashSet::default(),
            batches: Vec::new(),
            entries: FxHashMap::default(),
            single_module_entries: FxIndexSet::default(),
        }
    }

    fn ensure_pre_batch_for_module(
        &mut self,
        module: ResolvedVc<Box<dyn Module>>,
        chunk_group_info: &ChunkGroupInfo,
        queue: &mut VecDeque<(ResolvedVc<Box<dyn Module>>, PreBatchIndex)>,
    ) -> Result<PreBatchIndex> {
        Ok(match self.entries.entry(module) {
            Entry::Vacant(e) => {
                let index = self.batches.len();
                queue.push_back((module, index));
                let chunk_groups = chunk_group_info
                    .module_chunk_groups
                    .get(&module)
                    .context("all modules need to have chunk group info")?;
                let batch = PreBatch::new(chunk_groups.clone());
                self.batches.push(batch);
                e.insert(index);
                index
            }
            Entry::Occupied(e) => *e.get(),
        })
    }

    async fn get_pre_batch_items(
        &mut self,
        entry: ResolvedVc<Box<dyn Module>>,
        chunk_group_info: &ChunkGroupInfo,
        module_graph: &ModuleGraph,
        queue: &mut VecDeque<(ResolvedVc<Box<dyn Module>>, PreBatchIndex)>,
    ) -> Result<Vec<PreBatchItem>> {
        let mut state = TraversalState {
            items: Vec::new(),
            this: self,
        };
        let mut visited = FxHashSet::default();
        module_graph
            .traverse_edges_from_entries_dfs(
                std::iter::once(ResolvedVc::upcast(entry)),
                &mut state,
                |parent_info, node, state| {
                    let ty = parent_info.map_or(
                        &ChunkingType::Parallel {
                            inherit_async: false,
                            hoisted: false,
                        },
                        |(_, ty)| &ty.chunking_type,
                    );
                    let module = node.module;
                    if !ty.is_parallel() {
                        state.items.push(PreBatchItem::NonParallelEdge(
                            ty.without_inherit_async(),
                            module,
                        ));
                        return Ok(GraphTraversalAction::Exclude);
                    }
                    if visited.insert(module) {
                        if parent_info.is_some() && state.this.boundary_modules.contains(&module) {
                            let idx = state.this.ensure_pre_batch_for_module(
                                module,
                                chunk_group_info,
                                queue,
                            )?;
                            state.items.push(PreBatchItem::ParallelReference(idx));
                            return Ok(GraphTraversalAction::Exclude);
                        }
                        Ok(GraphTraversalAction::Continue)
                    } else {
                        Ok(GraphTraversalAction::Exclude)
                    }
                },
                |_, node, state| {
                    let item = PreBatchItem::ParallelModule(node.module);
                    state.items.push(item);
                    Ok(())
                },
            )
            .await?;
        Ok(state.items)
    }
}

pub async fn compute_module_batches(
    module_graph: Vc<ModuleGraph>,
    _config: &BatchingConfig,
) -> Result<Vc<ModuleBatchesGraph>> {
    let outer_span = tracing::info_span!(
        "compute module batches",
        initial_pre_batch_items = tracing::field::Empty,
        initial_pre_batches = tracing::field::Empty,
        extracted_shared_items = tracing::field::Empty,
        // Declared here so that the later `span.record("pre_batches", ...)` isn't a no-op;
        // tracing ignores records for fields that weren't declared on the span.
        pre_batches = tracing::field::Empty,
        batches = tracing::field::Empty,
        modules = tracing::field::Empty,
        edges = tracing::field::Empty
    );
    let span = outer_span.clone();
    async move {
        let chunk_group_info = module_graph.chunk_group_info().await?;
        let module_graph = module_graph.await?;

        let mut pre_batches = PreBatches::new();

        // Walk the module graph and mark all modules that are boundary modules (referenced from a
        // different chunk group bitmap)
        module_graph
            .traverse_all_edges_unordered(|(parent, ty), node| {
                let std::collections::hash_set::Entry::Vacant(entry) =
                    pre_batches.boundary_modules.entry(node.module)
                else {
                    // Already a boundary module, can skip check
                    return Ok(());
                };
                if ty.chunking_type.is_parallel() {
                    let parent_chunk_groups = chunk_group_info
                        .module_chunk_groups
                        .get(&parent.module)
                        .context("all modules need to have chunk group info")?;
                    let chunk_groups = chunk_group_info
                        .module_chunk_groups
                        .get(&node.module)
                        .context("all modules need to have chunk group info")?;
                    if parent_chunk_groups != chunk_groups {
                        // This is a boundary module
                        entry.insert();
                    }
                } else {
                    entry.insert();
                }
                Ok(())
            })
            .await?;

        // All entries are boundary modules too
        for chunk_group in &chunk_group_info.chunk_groups {
            for entry in chunk_group.entries() {
                pre_batches.boundary_modules.insert(entry);
            }
        }

        // Pre batches would be incorrect with cycles, so we need to opt out of pre batches for
        // cycles that include boundary modules
        module_graph
            .traverse_cycles(
                |ref_data| ref_data.chunking_type.is_parallel(),
                |cycle| {
                    if cycle
                        .iter()
                        .any(|node| pre_batches.boundary_modules.contains(&node.module))
                    {
                        pre_batches
                            .boundary_modules
                            .extend(cycle.iter().map(|node| node.module));
                    }
                },
            )
            .await?;

        let mut queue: VecDeque<(ResolvedVc<Box<dyn Module>>, PreBatchIndex)> = VecDeque::new();

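        // Chunk groups that have merged children: their pre-batches are walked below to determine
        // the evaluation order of the merged child groups' entries.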
        let mut chunk_group_indices_with_merged_children = FxHashSet::default();

        // Start with the entries
        for chunk_group in &chunk_group_info.chunk_groups {
            for entry in chunk_group.entries() {
                if let Some(chunkable_module) = ResolvedVc::try_downcast(entry) {
                    pre_batches.ensure_pre_batch_for_module(
                        chunkable_module,
                        &chunk_group_info,
                        &mut queue,
                    )?;
                } else {
                    pre_batches.single_module_entries.insert(entry);
                }
            }
            if let Some(parent) = chunk_group.get_merged_parent() {
                chunk_group_indices_with_merged_children.insert(parent);
            }
        }

        let mut initial_pre_batch_items = 0;
        // Fill all pre batches
        while let Some((chunkable_module, idx)) = queue.pop_front() {
            let items = pre_batches
                .get_pre_batch_items(
                    chunkable_module,
                    &chunk_group_info,
                    &module_graph,
                    &mut queue,
                )
                .await?;
            initial_pre_batch_items += items.len();
            let batch = &mut pre_batches.batches[idx];
            batch.items.extend(items);
        }
        span.record("initial_pre_batch_items", initial_pre_batch_items);
        span.record("initial_pre_batches", pre_batches.batches.len());

        // Figure out the order of all merged groups
        let mut ordered_entries: Vec<Option<EntriesList>> =
            vec![None; chunk_group_info.chunk_groups.len()];
        for (i, chunk_group) in chunk_group_info.chunk_groups.iter().enumerate() {
            if !chunk_group_indices_with_merged_children.contains(&i) {
                continue;
            }
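            // Walk this chunk group's pre-batches depth-first in item order. Stack entries are
            // (pre-batch index, resume position): pushing (idx, pos + 1) before descending into a
            // reference acts as a continuation, so merged modules are collected in evaluation
            // order.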
            let mut merged_modules: FxHashMap<ChunkingType, FxIndexSet<_>> = FxHashMap::default();
            let mut stack = ordered_entries[i]
                .as_ref()
                .map_or_else(
                    || Either::Left(chunk_group.entries()),
                    |v| Either::Right(v.iter().copied()),
                )
                .filter_map(|module| {
                    if let Some(chunkable_module) = ResolvedVc::try_downcast(module) {
                        let idx = *pre_batches.entries.get(&chunkable_module).unwrap();
                        Some((idx, 0))
                    } else {
                        None
                    }
                })
                .collect::<Vec<_>>();
            stack.reverse();
            let mut visited = FxHashSet::default();
            while let Some((idx, mut pos)) = stack.pop() {
                let batch = &pre_batches.batches[idx];
                while let Some(item) = batch.items.get_index(pos) {
                    match item {
                        PreBatchItem::ParallelModule(_) => {}
                        PreBatchItem::ParallelReference(other_idx) => {
                            if visited.insert(*other_idx) {
                                stack.push((idx, pos + 1));
                                stack.push((*other_idx, 0));
                                break;
                            }
                        }
                        PreBatchItem::NonParallelEdge(chunking_type, module) => {
                            if chunking_type.is_merged() {
                                merged_modules
                                    .entry(chunking_type.clone())
                                    .or_default()
                                    .insert(*module);
                            }
                        }
                    }
                    pos += 1;
                }
            }
            if !merged_modules.is_empty() {
                for (ty, merged_modules) in merged_modules {
                    let chunk_group_key = match ty {
                        ChunkingType::Isolated {
                            merge_tag: Some(merge_tag),
                            ..
                        } => ChunkGroupKey::IsolatedMerged {
                            parent: i.into(),
                            merge_tag: merge_tag.clone(),
                        },
                        ChunkingType::Shared {
                            merge_tag: Some(merge_tag),
                            ..
                        } => ChunkGroupKey::SharedMerged {
                            parent: i.into(),
                            merge_tag: merge_tag.clone(),
                        },
                        _ => unreachable!(),
                    };
                    let idx = chunk_group_info
                        .chunk_group_keys
                        .get_index_of(&chunk_group_key)
                        .unwrap();
                    ordered_entries[idx] = Some(merged_modules);
                }
            }
        }

        // Create a map of parallel module to the batches they are contained in.
        let mut parallel_module_to_pre_batch: FxIndexMap<_, Vec<PreBatchIndex>> =
            FxIndexMap::default();

        // Fill the map and also fill up the single_module_entries
        for (idx, pre_batch) in pre_batches.batches.iter().enumerate() {
            for item in &pre_batch.items {
                match item {
                    PreBatchItem::ParallelModule(module) => {
                        parallel_module_to_pre_batch
                            .entry(*module)
                            .or_default()
                            .push(idx);
                    }
                    PreBatchItem::NonParallelEdge(_, module) => {
                        if let Some(chunkable_module) = ResolvedVc::try_downcast(*module) {
                            if !pre_batches.entries.contains_key(&chunkable_module) {
                                pre_batches.single_module_entries.insert(*module);
                            }
                        } else {
                            pre_batches.single_module_entries.insert(*module);
                        }
                    }
                    PreBatchItem::ParallelReference(_) => {}
                }
            }
        }

        // We never want a module to occur in multiple batches.

        let mut extracted_shared_items = 0;
        // Extract shared modules into separate batches
        for i in 0..parallel_module_to_pre_batch.len() {
            let (&module, batches) = parallel_module_to_pre_batch.get_index(i).unwrap();
            if batches.len() > 1 {
                // Create a new batch for the shared modules
                let batches_with_item_index = batches
                    .iter()
                    .map(|&idx| {
                        let batch_items = &pre_batches.batches[idx].items;
                        let item_idx = batch_items
                            .get_index_of(&PreBatchItem::ParallelModule(module))
                            .unwrap();
                        (idx, item_idx)
                    })
                    .collect::<Vec<_>>();
                let mut selected_items = 1;
                fn get_item_at(
                    pre_batches: &PreBatches,
                    batch_idx: PreBatchIndex,
                    item_idx: usize,
                ) -> Option<&PreBatchItem> {
                    pre_batches.batches[batch_idx].items.get_index(item_idx)
                }
                // Select more matching items that are equal in all batches that contain the shared
                // module(s)
                loop {
                    if let Some(PreBatchItem::ParallelModule(next_module)) = get_item_at(
                        &pre_batches,
                        batches_with_item_index[0].0,
                        batches_with_item_index[0].1 + selected_items,
                    ) && parallel_module_to_pre_batch.get(next_module).unwrap().len()
                        == batches.len()
                        && batches_with_item_index[1..]
                            .iter()
                            .all(|&(batch_idx, item_idx)| {
                                get_item_at(&pre_batches, batch_idx, item_idx + selected_items)
                                    == Some(&PreBatchItem::ParallelModule(*next_module))
                            })
                    {
                        selected_items += 1;
                        continue;
                    }
                    break;
                }
                extracted_shared_items += selected_items;

                // Check if a batch is completely selected. In that case we can replace all other
                // occurrences with a reference to that batch
                let exact_match = batches_with_item_index
                    .iter()
                    .find(|&&(batch_idx, item_idx)| {
                        item_idx == 0
                            && pre_batches.batches[batch_idx].items.len() == selected_items
                    });
                if let Some(&(exact_match, _)) = exact_match {
                    // Replace all other occurrences with a reference to the exact match
                    for &(batch_index, item_start) in batches_with_item_index.iter() {
                        if batch_index != exact_match {
                            pre_batches.batches[batch_index].items.splice(
                                item_start..item_start + selected_items,
                                std::iter::once(PreBatchItem::ParallelReference(exact_match)),
                            );
                        }
                    }
                    for item in pre_batches.batches[exact_match].items.iter() {
                        if let PreBatchItem::ParallelModule(module) = item {
                            parallel_module_to_pre_batch
                                .get_mut(module)
                                .unwrap()
                                .clear();
                        }
                    }
                } else {
                    // Create a new batch of the shared part and replace all occurrences with a
                    // reference to that batch
                    let first_batch_index = batches_with_item_index[0].0;
                    let first_batch_item_index = batches_with_item_index[0].1;
                    let new_batch_index = pre_batches.batches.len();
                    let mut new_batch =
                        PreBatch::new(pre_batches.batches[first_batch_index].chunk_groups.clone());
                    new_batch
                        .items
                        .extend(pre_batches.batches[first_batch_index].items.splice(
                            first_batch_item_index..first_batch_item_index + selected_items,
                            std::iter::once(PreBatchItem::ParallelReference(new_batch_index)),
                        ));
                    for item in new_batch.items.iter() {
                        if let PreBatchItem::ParallelModule(module) = item {
                            parallel_module_to_pre_batch
                                .get_mut(module)
                                .unwrap()
                                .clear();
                        }
                    }
                    pre_batches.batches.push(new_batch);
                    for &(batch_index, item_start) in batches_with_item_index[1..].iter() {
                        pre_batches.batches[batch_index].items.splice(
                            item_start..item_start + selected_items,
                            std::iter::once(PreBatchItem::ParallelReference(new_batch_index)),
                        );
                    }
                }
            }
        }
        span.record("extracted_shared_items", extracted_shared_items);

        // Now every module is only in one batch

        let mut edges_count = 0;

        // Since batches can only have references followed by a list of parallel chunkable modules,
        // we need to split batches that have modules before references.
        for i in 0..pre_batches.batches.len() {
            let items = take(&mut pre_batches.batches[i].items);
            let mut new_items =
                FxIndexSet::with_capacity_and_hasher(items.len(), Default::default());
            enum Mode {
                ParallelChunkableModule,
                Other,
            }
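            // `mode` tracks whether the previous item was a parallel chunkable module. Once such
            // a module has been seen, any following reference or non-parallel item splits the
            // items collected so far into a separate batch.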
            let mut mode = Mode::Other;
            for item in items {
                let chunkable_module = if let PreBatchItem::ParallelModule(module) = &item {
                    ResolvedVc::try_downcast::<Box<dyn ChunkableModule>>(*module)
                } else {
                    None
                };
                let item = if let PreBatchItem::ParallelModule(module) = item {
                    if chunkable_module.is_some() {
                        PreBatchItem::ParallelModule(module)
                    } else {
                        pre_batches.single_module_entries.insert(module);
                        PreBatchItem::NonParallelEdge(
                            ChunkingType::Parallel {
                                inherit_async: false,
                                hoisted: false,
                            },
                            module,
                        )
                    }
                } else {
                    item
                };
                match (&mode, chunkable_module) {
                    (_, Some(_)) => {
                        mode = Mode::ParallelChunkableModule;
                        new_items.insert(item);
                    }
                    (Mode::Other, _) => {
                        edges_count += 1;
                        new_items.insert(item);
                    }
                    (Mode::ParallelChunkableModule, _) => {
                        // Split the batch
                        let idx = pre_batches.batches.len();
                        let mut new_batch =
                            PreBatch::new(pre_batches.batches[i].chunk_groups.clone());
                        new_batch.items.extend(new_items.drain(..));
                        pre_batches.batches.push(new_batch);
                        edges_count += 1;
                        new_items.insert(PreBatchItem::ParallelReference(idx));
                        if chunkable_module.is_some() {
                            new_items.insert(item);
                        } else {
                            edges_count += 1;
                            mode = Mode::Other;
                            new_items.insert(item);
                        }
                    }
                }
            }
            pre_batches.batches[i].items = new_items;
        }
        span.record("pre_batches", pre_batches.batches.len());

        // Now batches are in the correct shape. We can create the real batches and the graph.

        // Create the graph
        let mut graph: DiGraph<ModuleOrBatch, ModuleBatchesGraphEdge, u32> =
            petgraph::graph::DiGraph::with_capacity(
                pre_batches.batches.len() + pre_batches.single_module_entries.len(),
                edges_count,
            );

        // Create the Vc<ModuleBatch> instances
        let batches = pre_batches
            .batches
            .iter_mut()
            .enumerate()
            .map(async |(i, pre_batch)| {
                let mut modules = pre_batch.items.iter().filter_map(|item| {
                    if let PreBatchItem::ParallelModule(module) = item {
                        ResolvedVc::try_downcast(*module)
                    } else {
                        None
                    }
                });
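                // Zero parallel modules: keep a placeholder slot. Exactly one: use the module
                // directly instead of creating a single-module batch. Two or more: create a real
                // ModuleBatch.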
                let Some(first) = modules.next() else {
                    return Ok(ModuleOrBatch::None(i));
                };
                if let Some(second) = modules.next() {
                    let batch = ModuleBatch::new(
                        [first, second]
                            .into_iter()
                            .chain(modules)
                            .map(|m| *m)
                            .collect::<Vec<_>>(),
                        Some(pre_batch.chunk_groups.clone()),
                    );
                    Ok(ModuleOrBatch::Batch(batch.to_resolved().await?))
                } else {
                    Ok(ModuleOrBatch::Module(ResolvedVc::upcast(first)))
                }
            })
            .try_join()
            .await?;

        // Create the batch groups by grouping batches with the same chunk groups
        let mut batch_groups: FxHashMap<_, Vec<_>> = FxHashMap::default();
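        // The map key is the prehashed chunk-group bitmap, so batches and single modules that are
        // visible to exactly the same chunk groups land in the same batch group.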
        for (i, pre_batch) in pre_batches.batches.iter().enumerate() {
            let key =
                BuildHasherDefault::<FxHasher>::default().prehash(pre_batch.chunk_groups.clone());
            let batch = batches[i];
            batch_groups.entry(key).or_default().push(batch);
        }
        for &module in &pre_batches.single_module_entries {
            let chunk_groups = chunk_group_info
                .module_chunk_groups
                .get(&module)
                .context("all modules need to have chunk group info")?;
            let key = BuildHasherDefault::<FxHasher>::default().prehash(chunk_groups.clone());
            batch_groups
                .entry(key)
                .or_default()
                .push(ModuleOrBatch::Module(module));
        }

        // Create the batch group instances
        let batch_groups = batch_groups
            .into_iter()
            .map(async |(key, items)| {
                if items.len() == 1 {
                    Ok(Either::Left(std::iter::empty()))
                } else {
                    let batch_group = ModuleBatchGroup::new(items.clone(), (*key).clone())
                        .to_resolved()
                        .await?;
                    Ok(Either::Right(
                        items.into_iter().map(move |item| (item, batch_group)),
                    ))
                }
            })
            .try_join()
            .await?
            .into_iter()
            .flatten()
            .collect::<FxHashMap<_, _>>();

        // Insert batches into the graph and store the NodeIndices
        let mut batches_count = 0;
        let mut modules_count = 0;
        let batch_indices = batches
            .into_iter()
            .map(|batch| {
                match &batch {
                    ModuleOrBatch::Batch(_) => batches_count += 1,
                    ModuleOrBatch::Module(_) => modules_count += 1,
                    ModuleOrBatch::None(_) => {}
                }
                graph.add_node(batch)
            })
            .collect::<Vec<_>>();

        // Also insert single modules into the graph and store the NodeIndices
        let single_module_indices = pre_batches
            .single_module_entries
            .iter()
            .map(|module| graph.add_node(ModuleOrBatch::Module(*module)))
            .collect::<Vec<_>>();

        span.record("batches", batches_count);
        modules_count += pre_batches.single_module_entries.len();
        span.record("modules", modules_count);
        span.record("edges", edges_count);

        // Add all the edges to the graph
        for (i, pre_batch) in pre_batches.batches.into_iter().enumerate() {
            let index = batch_indices[i];
            let items = pre_batch.items;
            for item in items {
                match item {
                    PreBatchItem::ParallelReference(idx) => {
                        graph.add_edge(
                            index,
                            batch_indices[idx],
                            ModuleBatchesGraphEdge {
                                ty: ChunkingType::Parallel {
                                    inherit_async: false,
                                    hoisted: false,
                                },
                                module: None,
                            },
                        );
                    }
                    PreBatchItem::NonParallelEdge(ty, module) => {
                        if let Some(chunkable_module) = ResolvedVc::try_downcast(module)
                            && let Some(batch) = pre_batches.entries.get(&chunkable_module).copied()
                        {
                            graph.add_edge(
                                index,
                                batch_indices[batch],
                                ModuleBatchesGraphEdge {
                                    ty,
                                    module: Some(module),
                                },
                            );
                            continue;
                        }
                        let idx = pre_batches
                            .single_module_entries
                            .get_index_of(&module)
                            .unwrap();
                        let idx = single_module_indices[idx];
                        graph.add_edge(
                            index,
                            idx,
                            ModuleBatchesGraphEdge {
                                ty,
                                module: Some(module),
                            },
                        );
                    }
                    PreBatchItem::ParallelModule(_) => {}
                }
            }
        }

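        // The graph was allocated with exact capacity above; verify that the predicted node and
        // edge counts match what was actually inserted.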
        debug_assert_eq!(graph.capacity().0, graph.node_count());
        debug_assert_eq!(graph.capacity().1, graph.edge_count());

        // Find the NodeIndices for the entries of the graph
        let mut entries = FxHashMap::default();
        for chunk_group in &chunk_group_info.chunk_groups {
            for module in chunk_group.entries() {
                if let Some(chunkable_module) = ResolvedVc::try_downcast(module)
                    && let Some(batch) = pre_batches.entries.get(&chunkable_module).copied()
                {
                    entries.insert(module, batch_indices[batch]);
                    continue;
                }
                let idx = pre_batches
                    .single_module_entries
                    .get_index_of(&module)
                    .unwrap();
                let idx = single_module_indices[idx];
                entries.insert(module, idx);
            }
        }

        Ok(ModuleBatchesGraph {
            graph: TracedDiGraph(graph),
            entries,
            batch_groups,
            ordered_entries,
        }
        .cell())
    }
    .instrument(outer_span)
    .await
}