turbopack_core/module_graph/
module_batches.rs

use std::{
    collections::{VecDeque, hash_map::Entry},
    hash::BuildHasherDefault,
    mem::take,
};

use anyhow::{Context, Result, bail};
use bincode::{Decode, Encode};
use either::Either;
use petgraph::graph::{DiGraph, EdgeIndex, NodeIndex};
use rustc_hash::{FxHashMap, FxHashSet, FxHasher};
use serde::{Deserialize, Serialize};
use tracing::Instrument;
use turbo_prehash::BuildHasherExt;
use turbo_tasks::{
    FxIndexMap, FxIndexSet, NonLocalValue, ResolvedVc, TaskInput, TryJoinIterExt, ValueToString,
    Vc, trace::TraceRawVcs,
};

use crate::{
    chunk::{ChunkableModule, ChunkingType},
    module::Module,
    module_graph::{
        GraphTraversalAction, ModuleGraph,
        chunk_group_info::{ChunkGroupInfo, ChunkGroupKey, RoaringBitmapWrapper},
        module_batch::{ModuleBatch, ModuleBatchGroup, ModuleOrBatch},
        traced_di_graph::{TracedDiGraph, iter_neighbors_rev},
    },
};

#[turbo_tasks::value]
#[derive(Debug, Clone, Default, TaskInput, Hash)]
pub struct BatchingConfig {
    /// Use a heuristic based on the module path to create batches. It aims for batches of a good
    /// size.
    pub use_heuristic: bool,
}

#[turbo_tasks::value_impl]
impl BatchingConfig {
    #[turbo_tasks::function]
    pub fn new(config: BatchingConfig) -> Vc<Self> {
        config.cell()
    }
}

#[derive(Debug, Clone, Serialize, Deserialize, TraceRawVcs, NonLocalValue)]
pub struct ModuleBatchesGraphEdge {
    pub ty: ChunkingType,
    pub module: Option<ResolvedVc<Box<dyn Module>>>,
}

#[derive(Debug, Clone, TraceRawVcs, NonLocalValue, Encode, Decode)]
struct EntriesList(
    #[bincode(with = "turbo_bincode::indexset")] pub FxIndexSet<ResolvedVc<Box<dyn Module>>>,
);

#[turbo_tasks::value(cell = "new", eq = "manual")]
pub struct ModuleBatchesGraph {
    graph: TracedDiGraph<ModuleOrBatch, ModuleBatchesGraphEdge>,

    // NodeIndex isn't necessarily stable (because of swap_remove), but we never remove nodes.
    //
    // HashMaps have nondeterministic order, but this map is only used for lookups and not
    // iteration.
    //
    // This contains Vcs, but they are already contained in the graph, so no need to trace this.
    #[turbo_tasks(trace_ignore)]
    #[bincode(with_serde)]
    entries: FxHashMap<ResolvedVc<Box<dyn Module>>, NodeIndex>,
    batch_groups: FxHashMap<ModuleOrBatch, ResolvedVc<ModuleBatchGroup>>,

    /// For chunk groups where the postorder of entries differs from the order of
    /// `ChunkGroup::entries()`, this contains `Some` with the postorder list of entries of that
    /// chunk group. The index in this list corresponds to the index in
    /// `chunk_group_info.chunk_groups`.
    ordered_entries: Vec<Option<EntriesList>>,
}

impl ModuleBatchesGraph {
    pub async fn get_entry_index(&self, entry: ResolvedVc<Box<dyn Module>>) -> Result<NodeIndex> {
        let Some(entry) = self.entries.get(&entry) else {
            bail!(
                "Entry {} is not in graph (possible entries: {:#?})",
                entry.ident().to_string().await?,
                self.entries
                    .keys()
                    .map(|e| e.ident().to_string())
                    .try_join()
                    .await?
            );
        };
        Ok(*entry)
    }

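    /// Returns the entries of the chunk group at index `idx` in the order in which they should be
    /// emitted: the reordered (postorder) entry list if one was computed for this chunk group,
    /// otherwise the plain `ChunkGroup::entries()` order (or an empty iterator if `idx` is out of
    /// bounds).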
    pub fn get_ordered_entries<'l>(
        &'l self,
        chunk_group_info: &'l ChunkGroupInfo,
        idx: usize,
    ) -> impl Iterator<Item = ResolvedVc<Box<dyn Module>>> + 'l {
        if let Some(EntriesList(ordered_entries)) = self
            .ordered_entries
            .get(idx)
            .as_ref()
            .and_then(|o| o.as_ref())
        {
            if let Some(chunk_group) = chunk_group_info.chunk_groups.get_index(idx) {
                debug_assert_eq!(ordered_entries.len(), chunk_group.entries_count());
            }
            Either::Left(Either::Left(ordered_entries.iter().copied()))
        } else if let Some(chunk_group) = chunk_group_info.chunk_groups.get_index(idx) {
            Either::Right(chunk_group.entries())
        } else {
            Either::Left(Either::Right(std::iter::empty()))
        }
    }

    pub fn get_batch_group(
        &self,
        module_or_batch: &ModuleOrBatch,
    ) -> Option<ResolvedVc<ModuleBatchGroup>> {
        self.batch_groups.get(module_or_batch).copied()
    }

    pub async fn get_entry(&self, entry: ResolvedVc<Box<dyn Module>>) -> Result<ModuleOrBatch> {
        let entry = self.get_entry_index(entry).await?;
        Ok(*self.graph.node_weight(entry).unwrap())
    }

    // Clippy complains but there's a type error without the bound
    #[allow(clippy::implied_bounds_in_impls)]
    /// Traverses all reachable edges in DFS order. The preorder visitor can be used to
    /// forward state down the graph and to skip subgraphs.
    ///
    /// Use this to collect batches/modules in evaluation order.
    ///
    /// Target nodes can be revisited (once per incoming edge).
    /// Edges are traversed in normal order, so this should correspond to reference order.
    ///
    /// * `entries` - The entry modules to start the traversal from
    /// * `state` - The state to be passed to the visitors
    /// * `visit_preorder` - Called before visiting the children of a node.
    ///    - Receives: the originating (&ModuleOrBatch, &ModuleBatchesGraphEdge) if any, the
    ///      target &ModuleOrBatch and the state &mut S
    ///    - Can return [GraphTraversalAction]s to control the traversal
    /// * `visit_postorder` - Called after visiting the children of a node.
    ///    - Receives: the originating (&ModuleOrBatch, &ModuleBatchesGraphEdge) if any, the
    ///      target &ModuleOrBatch and the state &mut S
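    ///
    /// A minimal usage sketch (not compiled; assumes `graph: &ModuleBatchesGraph` and an `entry`
    /// `NodeIndex` are already available):
    ///
    /// ```ignore
    /// let mut order: Vec<ModuleOrBatch> = Vec::new();
    /// graph.traverse_edges_from_entries_dfs(
    ///     std::iter::once(entry),
    ///     &mut order,
    ///     // Preorder: visit everything
    ///     |_parent, _node, _order| Ok(GraphTraversalAction::Continue),
    ///     // Postorder: collect nodes in evaluation order
    ///     |_parent, node, order| order.push(*node),
    /// )?;
    /// ```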
    pub fn traverse_edges_from_entries_dfs<'a, S>(
        &'a self,
        entries: impl IntoIterator<
            Item = NodeIndex,
            IntoIter = impl Iterator<Item = NodeIndex> + DoubleEndedIterator,
        >,
        state: &mut S,
        mut visit_preorder: impl FnMut(
            Option<(&'a ModuleOrBatch, &'a ModuleBatchesGraphEdge)>,
            &'a ModuleOrBatch,
            &mut S,
        ) -> Result<GraphTraversalAction>,
        mut visit_postorder: impl FnMut(
            Option<(&'a ModuleOrBatch, &'a ModuleBatchesGraphEdge)>,
            &'a ModuleOrBatch,
            &mut S,
        ),
    ) -> Result<()> {
        let graph = &self.graph;

        enum ReverseDFSPass {
            Visit,
            ExpandAndVisit,
        }

        let entries = entries.into_iter();
        #[allow(clippy::type_complexity)] // This is a temporary internal structure
        let mut stack: Vec<(ReverseDFSPass, Option<(NodeIndex, EdgeIndex)>, NodeIndex)> = entries
            .rev()
            .map(|e| (ReverseDFSPass::ExpandAndVisit, None, e))
            .collect();
        let mut expanded = FxHashSet::default();
        while let Some((pass, parent, current)) = stack.pop() {
            let parent_arg = parent.map(|(node, edge)| {
                (
                    graph.node_weight(node).unwrap(),
                    graph.edge_weight(edge).unwrap(),
                )
            });
            match pass {
                ReverseDFSPass::Visit => {
                    let current_node = graph.node_weight(current).unwrap();
                    visit_postorder(parent_arg, current_node, state);
                }
                ReverseDFSPass::ExpandAndVisit => {
                    let current_node = graph.node_weight(current).unwrap();
                    let action = visit_preorder(parent_arg, current_node, state)?;
                    if action == GraphTraversalAction::Exclude {
                        continue;
                    }
                    stack.push((ReverseDFSPass::Visit, parent, current));
                    if action == GraphTraversalAction::Continue && expanded.insert(current) {
                        stack.extend(iter_neighbors_rev(graph, current).map(|(edge, child)| {
                            (ReverseDFSPass::ExpandAndVisit, Some((current, edge)), child)
                        }));
                    }
                }
            }
        }

        Ok(())
    }
}

type PreBatchIndex = usize;

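/// A single item of a [`PreBatch`].
///
/// Illustrative example (hypothetical modules `a.js`, `b.js`, `lazy.js`): a pre batch whose entry
/// parallel-imports `a.js` and `b.js` and dynamically imports `lazy.js` would roughly contain
/// `ParallelModule(a.js), ParallelModule(b.js), NonParallelEdge(<async>, lazy.js)`; references to
/// other pre batches appear as `ParallelReference(index)`.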
#[derive(Hash, PartialEq, Eq, Clone, Debug)]
enum PreBatchItem {
    ParallelModule(ResolvedVc<Box<dyn Module>>),
    ParallelReference(PreBatchIndex),
    NonParallelEdge(ChunkingType, ResolvedVc<Box<dyn Module>>),
}

struct PreBatch {
    items: FxIndexSet<PreBatchItem>,
    chunk_groups: RoaringBitmapWrapper,
}

impl PreBatch {
    fn new(chunk_groups: RoaringBitmapWrapper) -> Self {
        Self {
            items: FxIndexSet::default(),
            chunk_groups,
        }
    }
}

struct TraversalState<'l> {
    items: Vec<PreBatchItem>,
    this: &'l mut PreBatches,
}

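/// Intermediate state used while computing the batches.
///
/// * `boundary_modules` - modules that get a pre batch of their own: chunk group entries and
///   modules whose chunk group bitmap differs from that of a parallel parent
/// * `batches` - all pre batches created so far
/// * `entries` - maps a boundary module to the index of its pre batch in `batches`
/// * `single_module_entries` - modules that are only referenced via non-parallel edges (or are
///   not chunkable) and therefore don't get a pre batch of their own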
struct PreBatches {
    boundary_modules: FxHashSet<ResolvedVc<Box<dyn Module>>>,
    batches: Vec<PreBatch>,
    entries: FxHashMap<ResolvedVc<Box<dyn Module>>, PreBatchIndex>,
    single_module_entries: FxIndexSet<ResolvedVc<Box<dyn Module>>>,
}

impl PreBatches {
    fn new() -> Self {
        Self {
            boundary_modules: FxHashSet::default(),
            batches: Vec::new(),
            entries: FxHashMap::default(),
            single_module_entries: FxIndexSet::default(),
        }
    }

    fn ensure_pre_batch_for_module(
        &mut self,
        module: ResolvedVc<Box<dyn Module>>,
        module_chunk_groups: &FxHashMap<ResolvedVc<Box<dyn Module>>, RoaringBitmapWrapper>,
        queue: &mut VecDeque<(ResolvedVc<Box<dyn Module>>, PreBatchIndex)>,
    ) -> Result<PreBatchIndex> {
        Ok(match self.entries.entry(module) {
            Entry::Vacant(e) => {
                let index = self.batches.len();
                queue.push_back((module, index));
                let chunk_groups = module_chunk_groups
                    .get(&module)
                    .context("all modules need to have chunk group info")?;
                let batch = PreBatch::new((*chunk_groups).clone());
                self.batches.push(batch);
                e.insert(index);
                index
            }
            Entry::Occupied(e) => *e.get(),
        })
    }

    async fn get_pre_batch_items(
        &mut self,
        entry: ResolvedVc<Box<dyn Module>>,
        module_chunk_groups: &FxHashMap<ResolvedVc<Box<dyn Module>>, RoaringBitmapWrapper>,
        module_graph: &ModuleGraph,
        queue: &mut VecDeque<(ResolvedVc<Box<dyn Module>>, PreBatchIndex)>,
    ) -> Result<Vec<PreBatchItem>> {
        let mut state = TraversalState {
            items: Vec::new(),
            this: self,
        };
        let mut visited = FxHashSet::default();
        module_graph.traverse_edges_dfs(
            std::iter::once(entry),
            &mut state,
            |parent_info, node, state| {
                let ty = parent_info.map_or(
                    &ChunkingType::Parallel {
                        inherit_async: false,
                        hoisted: false,
                    },
                    |(_, ty)| &ty.chunking_type,
                );
                let module = node;
                if !ty.is_parallel() {
                    state.items.push(PreBatchItem::NonParallelEdge(
                        ty.without_inherit_async(),
                        module,
                    ));
                    return Ok(GraphTraversalAction::Exclude);
                }
                if visited.insert(module) {
                    if parent_info.is_some() && state.this.boundary_modules.contains(&module) {
                        let idx = state.this.ensure_pre_batch_for_module(
                            module,
                            module_chunk_groups,
                            queue,
                        )?;
                        state.items.push(PreBatchItem::ParallelReference(idx));
                        return Ok(GraphTraversalAction::Exclude);
                    }
                    Ok(GraphTraversalAction::Continue)
                } else {
                    Ok(GraphTraversalAction::Exclude)
                }
            },
            |_, node, state| {
                let item = PreBatchItem::ParallelModule(node);
                state.items.push(item);
                Ok(())
            },
        )?;
        Ok(state.items)
    }
}

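/// Computes the [`ModuleBatchesGraph`] for the given module graph.
///
/// High-level steps (see the inline comments below): mark boundary modules, build pre batches by
/// following parallel edges from every chunk group entry, extract runs of modules that are shared
/// between multiple pre batches into their own batches, split batches so that references come
/// before the parallel modules, and finally turn the pre batches into `ModuleBatch` /
/// `ModuleBatchGroup` cells and a petgraph-based graph.
///
/// A minimal usage sketch (assumes a `module_graph: Vc<ModuleGraph>` is available):
///
/// ```ignore
/// let batches: Vc<ModuleBatchesGraph> =
///     compute_module_batches(module_graph, &BatchingConfig::default()).await?;
/// ```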
pub async fn compute_module_batches(
    module_graph: Vc<ModuleGraph>,
    _config: &BatchingConfig,
) -> Result<Vc<ModuleBatchesGraph>> {
    let outer_span = tracing::info_span!(
        "compute module batches",
        initial_pre_batch_items = tracing::field::Empty,
        initial_pre_batches = tracing::field::Empty,
        extracted_shared_items = tracing::field::Empty,
        pre_batches = tracing::field::Empty,
        batches = tracing::field::Empty,
        modules = tracing::field::Empty,
        edges = tracing::field::Empty
    );
    let span = outer_span.clone();
    async move {
        let chunk_group_info = module_graph.chunk_group_info().await?;
        let module_chunk_groups = chunk_group_info.module_chunk_groups.await?;
        let module_graph = module_graph.await?;

        let mut pre_batches = PreBatches::new();

        // Walk the module graph and mark all boundary modules: modules referenced via a
        // non-parallel edge, or via a parallel edge from a parent with a different chunk group
        // bitmap
        module_graph.traverse_edges_unordered(|parent, node| {
            if let Some((parent, ty)) = parent {
                let std::collections::hash_set::Entry::Vacant(entry) =
                    pre_batches.boundary_modules.entry(node)
                else {
                    // Already a boundary module, can skip check
                    return Ok(());
                };
                if ty.chunking_type.is_parallel() {
                    let parent_chunk_groups = module_chunk_groups
                        .get(&parent)
                        .context("all modules need to have chunk group info")?;
                    let chunk_groups = module_chunk_groups
                        .get(&node)
                        .context("all modules need to have chunk group info")?;
                    if parent_chunk_groups != chunk_groups {
                        // This is a boundary module
                        entry.insert();
                    }
                } else {
                    entry.insert();
                }
            }
            Ok(())
        })?;

        // All entries are boundary modules too
        for chunk_group in &chunk_group_info.chunk_groups {
            for entry in chunk_group.entries() {
                pre_batches.boundary_modules.insert(entry);
            }
        }

        // Pre batches would be incorrect with cycles, so we need to opt out of pre batching for
        // cycles that include boundary modules
        module_graph.traverse_cycles(
            |ref_data| ref_data.chunking_type.is_parallel(),
            |cycle| {
                if cycle.len() > 1
                    && cycle
                        .iter()
                        .any(|node| pre_batches.boundary_modules.contains(node))
                {
                    pre_batches
                        .boundary_modules
                        .extend(cycle.iter().map(|node| **node));
                }
                Ok(())
            },
        )?;

        let mut queue: VecDeque<(ResolvedVc<Box<dyn Module>>, PreBatchIndex)> = VecDeque::new();

        let mut chunk_group_indices_with_merged_children = FxHashSet::default();

        // Start with the entries
        for chunk_group in &chunk_group_info.chunk_groups {
            for entry in chunk_group.entries() {
                pre_batches.ensure_pre_batch_for_module(entry, &module_chunk_groups, &mut queue)?;
            }
            if let Some(parent) = chunk_group.get_merged_parent() {
                chunk_group_indices_with_merged_children.insert(parent);
            }
        }

        let mut initial_pre_batch_items = 0;
        // Fill all pre batches
        while let Some((chunkable_module, idx)) = queue.pop_front() {
            let items = pre_batches
                .get_pre_batch_items(
                    chunkable_module,
                    &module_chunk_groups,
                    &module_graph,
                    &mut queue,
                )
                .await?;
            initial_pre_batch_items += items.len();
            let batch = &mut pre_batches.batches[idx];
            batch.items.extend(items);
        }
        span.record("initial_pre_batch_items", initial_pre_batch_items);
        span.record("initial_pre_batches", pre_batches.batches.len());

        // Figure out the order of all merged groups
        let mut ordered_entries: Vec<Option<EntriesList>> =
            vec![None; chunk_group_info.chunk_groups.len()];
        for (i, chunk_group) in chunk_group_info.chunk_groups.iter().enumerate() {
            if !chunk_group_indices_with_merged_children.contains(&i) {
                continue;
            }
            let mut merged_modules: FxHashMap<ChunkingType, FxIndexSet<_>> = FxHashMap::default();
            let mut stack = ordered_entries[i]
                .as_ref()
                .map_or_else(
                    || Either::Left(chunk_group.entries()),
                    |v| Either::Right(v.0.iter().copied()),
                )
                .map(|module| {
                    let idx = *pre_batches
                        .entries
                        .get(&module)
                        .context("could not find pre batch for module")?;
                    Ok((idx, 0))
                })
                .collect::<Result<Vec<_>>>()?;
            stack.reverse();
            let mut visited = FxHashSet::default();
            while let Some((idx, mut pos)) = stack.pop() {
                let batch = &pre_batches.batches[idx];
                while let Some(item) = batch.items.get_index(pos) {
                    match item {
                        PreBatchItem::ParallelModule(_) => {}
                        PreBatchItem::ParallelReference(other_idx) => {
                            if visited.insert(*other_idx) {
                                stack.push((idx, pos + 1));
                                stack.push((*other_idx, 0));
                                break;
                            }
                        }
                        PreBatchItem::NonParallelEdge(chunking_type, module) => {
                            if chunking_type.is_merged() {
                                merged_modules
                                    .entry(chunking_type.clone())
                                    .or_default()
                                    .insert(*module);
                            }
                        }
                    }
                    pos += 1;
                }
            }
            if !merged_modules.is_empty() {
                for (ty, merged_modules) in merged_modules {
                    let chunk_group_key = match ty {
                        ChunkingType::Isolated {
                            merge_tag: Some(merge_tag),
                            ..
                        } => ChunkGroupKey::IsolatedMerged {
                            parent: i.into(),
                            merge_tag: merge_tag.clone(),
                        },
                        ChunkingType::Shared {
                            merge_tag: Some(merge_tag),
                            ..
                        } => ChunkGroupKey::SharedMerged {
                            parent: i.into(),
                            merge_tag: merge_tag.clone(),
                        },
                        _ => unreachable!(),
                    };
                    let idx = chunk_group_info
                        .chunk_group_keys
                        .get_index_of(&chunk_group_key)
                        .context("could not find chunk group key for merged chunk group")?;
                    ordered_entries[idx] = Some(EntriesList(merged_modules));
                }
            }
        }

        // Create a map from each parallel module to the batches it is contained in.
        let mut parallel_module_to_pre_batch: FxIndexMap<_, Vec<PreBatchIndex>> =
            FxIndexMap::default();

        // Fill the map and also fill the single_module_entries
        for (idx, pre_batch) in pre_batches.batches.iter().enumerate() {
            for item in &pre_batch.items {
                match item {
                    PreBatchItem::ParallelModule(module) => {
                        parallel_module_to_pre_batch
                            .entry(*module)
                            .or_default()
                            .push(idx);
                    }
                    PreBatchItem::NonParallelEdge(_, module) => {
                        if !pre_batches.entries.contains_key(module) {
                            pre_batches.single_module_entries.insert(*module);
                        }
                    }
                    PreBatchItem::ParallelReference(_) => {}
                }
            }
        }

        // We never want a module to occur in multiple batches.

        let mut extracted_shared_items = 0;
        // Extract shared modules into separate batches
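        // E.g. (illustrative): if two pre batches both contain the run of parallel modules
        // [M(x), M(y)], that run is either replaced by a reference to an existing batch that
        // consists of exactly that run, or moved into a new batch that both then reference.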
        for i in 0..parallel_module_to_pre_batch.len() {
            let (&module, batches) = parallel_module_to_pre_batch
                .get_index(i)
                .context("could not find parallel module to pre batch index")?;
            if batches.len() > 1 {
                // Create a new batch for the shared modules
                let batches_with_item_index = batches
                    .iter()
                    .map(|&idx| {
                        let batch_items = &pre_batches.batches[idx].items;
                        let item_idx = batch_items
                            .get_index_of(&PreBatchItem::ParallelModule(module))
                            .context("could not find batch item index for parallel module")?;
                        Ok((idx, item_idx))
                    })
                    .collect::<Result<Vec<_>>>()?;
                let mut selected_items = 1;
                fn get_item_at(
                    pre_batches: &PreBatches,
                    batch_idx: PreBatchIndex,
                    item_idx: usize,
                ) -> Option<&PreBatchItem> {
                    pre_batches.batches[batch_idx].items.get_index(item_idx)
                }
                // Select more matching items that are equal in all batches that contain the shared
                // module(s)
                loop {
                    if let Some(PreBatchItem::ParallelModule(next_module)) = get_item_at(
                        &pre_batches,
                        batches_with_item_index[0].0,
                        batches_with_item_index[0].1 + selected_items,
                    ) && parallel_module_to_pre_batch
                        .get(next_module)
                        .context("could not find pre batch for parallel module")?
                        .len()
                        == batches.len()
                        && batches_with_item_index[1..]
                            .iter()
                            .all(|&(batch_idx, item_idx)| {
                                get_item_at(&pre_batches, batch_idx, item_idx + selected_items)
                                    == Some(&PreBatchItem::ParallelModule(*next_module))
                            })
                    {
                        selected_items += 1;
                        continue;
                    }
                    break;
                }
                extracted_shared_items += selected_items;

                // Check if a batch is completely selected. In that case we can replace all other
                // occurrences with a reference to that batch
                let exact_match = batches_with_item_index
                    .iter()
                    .find(|&&(batch_idx, item_idx)| {
                        item_idx == 0
                            && pre_batches.batches[batch_idx].items.len() == selected_items
                    });
                if let Some(&(exact_match, _)) = exact_match {
                    // Replace all other occurrences with a reference to the exact match
                    for &(batch_index, item_start) in batches_with_item_index.iter() {
                        if batch_index != exact_match {
                            pre_batches.batches[batch_index].items.splice(
                                item_start..item_start + selected_items,
                                std::iter::once(PreBatchItem::ParallelReference(exact_match)),
                            );
                        }
                    }
                    for item in pre_batches.batches[exact_match].items.iter() {
                        if let PreBatchItem::ParallelModule(module) = item {
                            parallel_module_to_pre_batch
                                .get_mut(module)
                                .context("could not find pre batch for parallel module")?
                                .clear();
                        }
                    }
                } else {
                    // Create a new batch of the shared part and replace all occurrences with a
                    // reference to that batch
                    let first_batch_index = batches_with_item_index[0].0;
                    let first_batch_item_index = batches_with_item_index[0].1;
                    let new_batch_index = pre_batches.batches.len();
                    let mut new_batch =
                        PreBatch::new(pre_batches.batches[first_batch_index].chunk_groups.clone());
                    new_batch
                        .items
                        .extend(pre_batches.batches[first_batch_index].items.splice(
                            first_batch_item_index..first_batch_item_index + selected_items,
                            std::iter::once(PreBatchItem::ParallelReference(new_batch_index)),
                        ));
                    for item in new_batch.items.iter() {
                        if let PreBatchItem::ParallelModule(module) = item {
                            parallel_module_to_pre_batch
                                .get_mut(module)
                                .context("could not find pre batch for parallel module")?
                                .clear();
                        }
                    }
                    pre_batches.batches.push(new_batch);
                    for &(batch_index, item_start) in batches_with_item_index[1..].iter() {
                        pre_batches.batches[batch_index].items.splice(
                            item_start..item_start + selected_items,
                            std::iter::once(PreBatchItem::ParallelReference(new_batch_index)),
                        );
                    }
                }
            }
        }
        span.record("extracted_shared_items", extracted_shared_items);

        // Now every module is only in one batch

        let mut edges_count = 0;

        // Since batches can only have references followed by a list of parallel chunkable modules,
        // we need to split batches that have modules before references.
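        // E.g. (illustrative): the item list [M(a), Ref(1), M(b)] becomes a new batch [M(a)] plus
        // the items [Ref(new), Ref(1), M(b)], so that all references precede the parallel modules.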
        for i in 0..pre_batches.batches.len() {
            let items = take(&mut pre_batches.batches[i].items);
            let mut new_items =
                FxIndexSet::with_capacity_and_hasher(items.len(), Default::default());
            enum Mode {
                ParallelChunkableModule,
                Other,
            }
            let mut mode = Mode::Other;
            for item in items {
                let chunkable_module = if let PreBatchItem::ParallelModule(module) = &item {
                    ResolvedVc::try_downcast::<Box<dyn ChunkableModule>>(*module)
                } else {
                    None
                };
                let item = if let PreBatchItem::ParallelModule(module) = item {
                    if chunkable_module.is_some() {
                        PreBatchItem::ParallelModule(module)
                    } else {
                        pre_batches.single_module_entries.insert(module);
                        PreBatchItem::NonParallelEdge(
                            ChunkingType::Parallel {
                                inherit_async: false,
                                hoisted: false,
                            },
                            module,
                        )
                    }
                } else {
                    item
                };
                match (&mode, chunkable_module) {
                    (_, Some(_)) => {
                        mode = Mode::ParallelChunkableModule;
                        new_items.insert(item);
                    }
                    (Mode::Other, _) => {
                        edges_count += 1;
                        new_items.insert(item);
                    }
                    (Mode::ParallelChunkableModule, _) => {
                        // Split the batch
                        let idx = pre_batches.batches.len();
                        let mut new_batch =
                            PreBatch::new(pre_batches.batches[i].chunk_groups.clone());
                        new_batch.items.extend(new_items.drain(..));
                        pre_batches.batches.push(new_batch);
                        edges_count += 1;
                        new_items.insert(PreBatchItem::ParallelReference(idx));
                        if chunkable_module.is_some() {
                            new_items.insert(item);
                        } else {
                            edges_count += 1;
                            mode = Mode::Other;
                            new_items.insert(item);
                        }
                    }
                }
            }
            pre_batches.batches[i].items = new_items;
        }
        span.record("pre_batches", pre_batches.batches.len());

        // Now batches are in the correct shape. We can create the real batches and the graph.

        // Create the graph
        let mut graph: DiGraph<ModuleOrBatch, ModuleBatchesGraphEdge, u32> =
            petgraph::graph::DiGraph::with_capacity(
                pre_batches.batches.len() + pre_batches.single_module_entries.len(),
                edges_count,
            );

        // Create the Vc<ModuleBatch> instances
        let batches = pre_batches
            .batches
            .iter_mut()
            .enumerate()
            .map(async |(i, pre_batch)| {
                let mut modules = pre_batch.items.iter().filter_map(|item| {
                    if let PreBatchItem::ParallelModule(module) = item {
                        ResolvedVc::try_downcast(*module)
                    } else {
                        None
                    }
                });
                let Some(first) = modules.next() else {
                    return Ok(ModuleOrBatch::None(i));
                };
                if let Some(second) = modules.next() {
                    let batch = ModuleBatch::new(
                        [first, second]
                            .into_iter()
                            .chain(modules)
                            .map(|m| *m)
                            .collect::<Vec<_>>(),
                        Some(pre_batch.chunk_groups.clone()),
                    );
                    Ok(ModuleOrBatch::Batch(batch.to_resolved().await?))
                } else {
                    Ok(ModuleOrBatch::Module(ResolvedVc::upcast(first)))
                }
            })
            .try_join()
            .await?;

        // Create the batch groups by grouping batches with the same chunk groups
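        // E.g. (illustrative): all batches/modules that are needed by exactly the same set of
        // chunk groups end up in one ModuleBatchGroup; items that are alone in their group get no
        // batch group at all.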
        let mut batch_groups: FxHashMap<_, Vec<_>> = FxHashMap::default();
        for (i, pre_batch) in pre_batches.batches.iter().enumerate() {
            let key = BuildHasherDefault::<FxHasher>::default().prehash(&pre_batch.chunk_groups);
            let batch = batches[i];
            batch_groups.entry(key).or_default().push(batch);
        }
        for &module in &pre_batches.single_module_entries {
            let chunk_groups = module_chunk_groups
                .get(&module)
                .context("all modules need to have chunk group info")?;
            let key = BuildHasherDefault::<FxHasher>::default().prehash(chunk_groups);
            batch_groups
                .entry(key)
                .or_default()
                .push(ModuleOrBatch::Module(module));
        }

        // Create the batch group instances
        let batch_groups = batch_groups
            .into_iter()
            .map(async |(key, items)| {
                if items.len() == 1 {
                    Ok(Either::Left(std::iter::empty()))
                } else {
                    let batch_group = ModuleBatchGroup::new(items.clone(), (*key).clone())
                        .to_resolved()
                        .await?;
                    Ok(Either::Right(
                        items.into_iter().map(move |item| (item, batch_group)),
                    ))
                }
            })
            .try_join()
            .await?
            .into_iter()
            .flatten()
            .collect::<FxHashMap<_, _>>();

        // Insert batches into the graph and store the NodeIndices
        let mut batches_count = 0;
        let mut modules_count = 0;
        let batch_indices = batches
            .into_iter()
            .map(|batch| {
                match &batch {
                    ModuleOrBatch::Batch(_) => batches_count += 1,
                    ModuleOrBatch::Module(_) => modules_count += 1,
                    ModuleOrBatch::None(_) => {}
                }
                graph.add_node(batch)
            })
            .collect::<Vec<_>>();

        // Also insert single modules into the graph and store the NodeIndices
        let single_module_indices = pre_batches
            .single_module_entries
            .iter()
            .map(|module| graph.add_node(ModuleOrBatch::Module(*module)))
            .collect::<Vec<_>>();

        span.record("batches", batches_count);
        modules_count += pre_batches.single_module_entries.len();
        span.record("modules", modules_count);
        span.record("edges", edges_count);

        // Add all the edges to the graph
        for (i, pre_batch) in pre_batches.batches.into_iter().enumerate() {
            let index = batch_indices[i];
            let items = pre_batch.items;
            for item in items {
                match item {
                    PreBatchItem::ParallelReference(idx) => {
                        graph.add_edge(
                            index,
                            batch_indices[idx],
                            ModuleBatchesGraphEdge {
                                ty: ChunkingType::Parallel {
                                    inherit_async: false,
                                    hoisted: false,
                                },
                                module: None,
                            },
                        );
                    }
                    PreBatchItem::NonParallelEdge(ty, module) => {
                        if let Some(batch) = pre_batches.entries.get(&module).copied() {
                            graph.add_edge(
                                index,
                                batch_indices[batch],
                                ModuleBatchesGraphEdge {
                                    ty,
                                    module: Some(module),
                                },
                            );
                            continue;
                        }
                        let idx = pre_batches
                            .single_module_entries
                            .get_index_of(&module)
                            .context("could not find single module entry index")?;
                        let idx = single_module_indices[idx];
                        graph.add_edge(
                            index,
                            idx,
                            ModuleBatchesGraphEdge {
                                ty,
                                module: Some(module),
                            },
                        );
                    }
                    PreBatchItem::ParallelModule(_) => {}
                }
            }
        }

        debug_assert_eq!(graph.capacity().0, graph.node_count());
        debug_assert_eq!(graph.capacity().1, graph.edge_count());

        // Find the NodeIndices for our entries of the graph
        let mut entries = FxHashMap::default();
        for chunk_group in &chunk_group_info.chunk_groups {
            for module in chunk_group.entries() {
                if let Some(batch) = pre_batches.entries.get(&module).copied() {
                    entries.insert(module, batch_indices[batch]);
                    continue;
                }
                let idx = pre_batches
                    .single_module_entries
                    .get_index_of(&module)
                    .context("could not find single module entry index")?;
                let idx = single_module_indices[idx];
                entries.insert(module, idx);
            }
        }

        Ok(ModuleBatchesGraph {
            graph: TracedDiGraph(graph),
            entries,
            batch_groups,
            ordered_entries,
        }
        .cell())
    }
    .instrument(outer_span)
    .await
}