Skip to main content

turbopack_core/module_graph/
module_batches.rs

1use std::{
2    collections::{VecDeque, hash_map::Entry},
3    hash::BuildHasherDefault,
4    mem::take,
5};
6
7use anyhow::{Context, Result, bail};
8use bincode::{Decode, Encode};
9use either::Either;
10use petgraph::graph::{DiGraph, EdgeIndex, NodeIndex};
11use rustc_hash::{FxHashMap, FxHashSet, FxHasher};
12use serde::{Deserialize, Serialize};
13use tracing::Instrument;
14use turbo_prehash::BuildHasherExt;
15use turbo_tasks::{
16    FxIndexMap, FxIndexSet, NonLocalValue, ResolvedVc, TryJoinIterExt, ValueToString, Vc,
17    trace::TraceRawVcs, turbobail,
18};
19
20use crate::{
21    chunk::{ChunkableModule, ChunkingType},
22    module::Module,
23    module_graph::{
24        GraphTraversalAction, ModuleGraph,
25        chunk_group_info::{ChunkGroupInfo, ChunkGroupKey, RoaringBitmapWrapper},
26        module_batch::{ModuleBatch, ModuleBatchGroup, ModuleOrBatch},
27        traced_di_graph::{TracedDiGraph, iter_neighbors_rev},
28    },
29};
30#[turbo_tasks::value(task_input)]
31#[derive(Debug, Clone, Default, Hash)]
32pub struct BatchingConfig {
33    /// Use a heuristic based on the module path to create batches. It aims for batches of a good
34    /// size.
35    pub use_heuristic: bool,
36}
37
38#[turbo_tasks::value_impl]
39impl BatchingConfig {
40    #[turbo_tasks::function]
41    pub fn new(config: BatchingConfig) -> Vc<Self> {
42        config.cell()
43    }
44}
45
46#[derive(Debug, Clone, Serialize, Deserialize, TraceRawVcs, NonLocalValue)]
47pub struct ModuleBatchesGraphEdge {
48    pub ty: ChunkingType,
49    pub module: Option<ResolvedVc<Box<dyn Module>>>,
50}
51
52#[derive(Debug, Clone, TraceRawVcs, NonLocalValue, Encode, Decode)]
53struct EntriesList(
54    #[bincode(with = "turbo_bincode::indexset")] pub FxIndexSet<ResolvedVc<Box<dyn Module>>>,
55);
56
57#[turbo_tasks::value(cell = "new", eq = "manual")]
58pub struct ModuleBatchesGraph {
59    graph: TracedDiGraph<ModuleOrBatch, ModuleBatchesGraphEdge>,
60
61    // NodeIndex isn't necessarily stable (because of swap_remove), but we never remove nodes.
62    //
63    // HashMaps have nondeterministic order, but this map is only used for lookups and not
64    // iteration.
65    //
66    // This contains Vcs, but they are already contained in the graph, so no need to trace this.
67    #[turbo_tasks(trace_ignore)]
68    #[bincode(with_serde)]
69    entries: FxHashMap<ResolvedVc<Box<dyn Module>>, NodeIndex>,
70    batch_groups: FxHashMap<ModuleOrBatch, ResolvedVc<ModuleBatchGroup>>,
71
72    /// For chunk groups where the postorder of entries is different than the order of the
73    /// `ChunkGroup::entries()` this contains Some with the postorder list of entries of that chunk
74    /// group. The index in this list corresponds to the index in the
75    /// chunk_group_info.chunk_groups.
76    ordered_entries: Vec<Option<EntriesList>>,
77}
78
79impl ModuleBatchesGraph {
80    pub async fn get_entry_index(&self, entry: ResolvedVc<Box<dyn Module>>) -> Result<NodeIndex> {
81        let Some(entry) = self.entries.get(&entry) else {
82            if cfg!(debug_assertions) {
83                let possible_entries = format!(
84                    "{:#?}",
85                    self.entries
86                        .keys()
87                        .map(|e| e.ident().to_string())
88                        .try_join()
89                        .await?
90                );
91                turbobail!(
92                    "Entry {} is not in graph (possible entries: {})",
93                    entry.ident(),
94                    possible_entries
95                );
96            } else {
97                bail!("Entry is not in graph");
98            }
99        };
100        Ok(*entry)
101    }
102
103    pub fn get_ordered_entries<'l>(
104        &'l self,
105        chunk_group_info: &'l ChunkGroupInfo,
106        idx: usize,
107    ) -> impl Iterator<Item = ResolvedVc<Box<dyn Module>>> + 'l {
108        if let Some(EntriesList(ordered_entries)) = self
109            .ordered_entries
110            .get(idx)
111            .as_ref()
112            .and_then(|o| o.as_ref())
113        {
114            if let Some(chunk_group) = chunk_group_info.chunk_groups.get_index(idx) {
115                debug_assert_eq!(ordered_entries.len(), chunk_group.entries_count());
116            }
117            Either::Left(Either::Left(ordered_entries.iter().copied()))
118        } else if let Some(chunk_group) = chunk_group_info.chunk_groups.get_index(idx) {
119            Either::Right(chunk_group.entries())
120        } else {
121            Either::Left(Either::Right(std::iter::empty()))
122        }
123    }
124
125    pub fn get_batch_group(
126        &self,
127        module_or_batch: &ModuleOrBatch,
128    ) -> Option<ResolvedVc<ModuleBatchGroup>> {
129        self.batch_groups.get(module_or_batch).copied()
130    }
131
132    pub async fn get_entry(&self, entry: ResolvedVc<Box<dyn Module>>) -> Result<ModuleOrBatch> {
133        let entry = self.get_entry_index(entry).await?;
134        Ok(*self.graph.node_weight(entry).unwrap())
135    }
136
137    // Clippy complains but there's a type error without the bound
138    #[allow(clippy::implied_bounds_in_impls)]
139    /// Traverses all reachable edges in dfs order. The preorder visitor can be used to
140    /// forward state down the graph, and to skip subgraphs
141    ///
142    /// Use this to collect batches/modules in evaluation order.
143    ///
144    /// Target nodes can be revisited (once per incoming edge).
145    /// Edges are traversed in normal order, so should correspond to reference order.
146    ///
147    /// * `entries` - The entry modules to start the traversal from
148    /// * `state` - The state to be passed to the visitors
149    /// * `visit_preorder` - Called before visiting the children of a node.
150    ///    - Receives: (originating &ModuleBatchesGraphNode, edge &ChunkingType), target
151    ///      &ModuleBatchesGraphNode, state &S
152    ///    - Can return [GraphTraversalAction]s to control the traversal
153    /// * `visit_postorder` - Called after visiting the children of a node. Return
154    ///    - Receives: (originating &ModuleBatchesGraphNode, edge &ChunkingType), target
155    ///      &ModuleBatchesGraphNode, state &S
156    pub fn traverse_edges_from_entries_dfs<'a, S>(
157        &'a self,
158        entries: impl IntoIterator<
159            Item = NodeIndex,
160            IntoIter = impl Iterator<Item = NodeIndex> + DoubleEndedIterator,
161        >,
162        state: &mut S,
163        mut visit_preorder: impl FnMut(
164            Option<(&'a ModuleOrBatch, &'a ModuleBatchesGraphEdge)>,
165            &'a ModuleOrBatch,
166            &mut S,
167        ) -> Result<GraphTraversalAction>,
168        mut visit_postorder: impl FnMut(
169            Option<(&'a ModuleOrBatch, &'a ModuleBatchesGraphEdge)>,
170            &'a ModuleOrBatch,
171            &mut S,
172        ),
173    ) -> Result<()> {
174        let graph = &self.graph;
175
176        enum ReverseDFSPass {
177            Visit,
178            ExpandAndVisit,
179        }
180
181        let entries = entries.into_iter();
182        #[allow(clippy::type_complexity)] // This is a temporary internal structure
183        let mut stack: Vec<(ReverseDFSPass, Option<(NodeIndex, EdgeIndex)>, NodeIndex)> = entries
184            .rev()
185            .map(|e| (ReverseDFSPass::ExpandAndVisit, None, e))
186            .collect();
187        let mut expanded = FxHashSet::default();
188        while let Some((pass, parent, current)) = stack.pop() {
189            let parent_arg = parent.map(|(node, edge)| {
190                (
191                    graph.node_weight(node).unwrap(),
192                    graph.edge_weight(edge).unwrap(),
193                )
194            });
195            match pass {
196                ReverseDFSPass::Visit => {
197                    let current_node = graph.node_weight(current).unwrap();
198                    visit_postorder(parent_arg, current_node, state);
199                }
200                ReverseDFSPass::ExpandAndVisit => {
201                    let current_node = graph.node_weight(current).unwrap();
202                    let action = visit_preorder(parent_arg, current_node, state)?;
203                    if action == GraphTraversalAction::Exclude {
204                        continue;
205                    }
206                    stack.push((ReverseDFSPass::Visit, parent, current));
207                    if action == GraphTraversalAction::Continue && expanded.insert(current) {
208                        stack.extend(iter_neighbors_rev(graph, current).map(|(edge, child)| {
209                            (ReverseDFSPass::ExpandAndVisit, Some((current, edge)), child)
210                        }));
211                    }
212                }
213            }
214        }
215
216        Ok(())
217    }
218}
219
220type PreBatchIndex = usize;
221
222#[derive(Hash, PartialEq, Eq, Clone, Debug)]
223enum PreBatchItem {
224    ParallelModule(ResolvedVc<Box<dyn Module>>),
225    ParallelReference(PreBatchIndex),
226    NonParallelEdge(ChunkingType, ResolvedVc<Box<dyn Module>>),
227}
228
229struct PreBatch {
230    items: FxIndexSet<PreBatchItem>,
231    chunk_groups: RoaringBitmapWrapper,
232}
233
234impl PreBatch {
235    fn new(chunk_groups: RoaringBitmapWrapper) -> Self {
236        Self {
237            items: FxIndexSet::default(),
238            chunk_groups,
239        }
240    }
241}
242
243struct TraversalState<'l> {
244    items: Vec<PreBatchItem>,
245    this: &'l mut PreBatches,
246}
247
248struct PreBatches {
249    boundary_modules: FxHashSet<ResolvedVc<Box<dyn Module>>>,
250    batches: Vec<PreBatch>,
251    entries: FxHashMap<ResolvedVc<Box<dyn Module>>, PreBatchIndex>,
252    single_module_entries: FxIndexSet<ResolvedVc<Box<dyn Module>>>,
253}
254
255impl PreBatches {
256    fn new() -> Self {
257        Self {
258            boundary_modules: FxHashSet::default(),
259            batches: Vec::new(),
260            entries: FxHashMap::default(),
261            single_module_entries: FxIndexSet::default(),
262        }
263    }
264
265    fn ensure_pre_batch_for_module(
266        &mut self,
267        module: ResolvedVc<Box<dyn Module>>,
268        module_chunk_groups: &FxHashMap<ResolvedVc<Box<dyn Module>>, RoaringBitmapWrapper>,
269        queue: &mut VecDeque<(ResolvedVc<Box<dyn Module>>, PreBatchIndex)>,
270    ) -> Result<PreBatchIndex> {
271        Ok(match self.entries.entry(module) {
272            Entry::Vacant(e) => {
273                let index = self.batches.len();
274                queue.push_back((module, index));
275                let chunk_groups = module_chunk_groups
276                    .get(&module)
277                    .context("all modules need to have chunk group info")?;
278                let batch = PreBatch::new((*chunk_groups).clone());
279                self.batches.push(batch);
280                e.insert(index);
281                index
282            }
283            Entry::Occupied(e) => *e.get(),
284        })
285    }
286
287    async fn get_pre_batch_items(
288        &mut self,
289        entry: ResolvedVc<Box<dyn Module>>,
290        module_chunk_groups: &FxHashMap<ResolvedVc<Box<dyn Module>>, RoaringBitmapWrapper>,
291        module_graph: &ModuleGraph,
292        queue: &mut VecDeque<(ResolvedVc<Box<dyn Module>>, PreBatchIndex)>,
293    ) -> Result<Vec<PreBatchItem>> {
294        let mut state = TraversalState {
295            items: Vec::new(),
296            this: self,
297        };
298        let mut visited = FxHashSet::default();
299        module_graph.traverse_edges_dfs(
300            std::iter::once(entry),
301            &mut state,
302            |parent_info, node, state| {
303                let ty = parent_info.map_or(
304                    &ChunkingType::Parallel {
305                        inherit_async: false,
306                        hoisted: false,
307                    },
308                    |(_, ty)| &ty.chunking_type,
309                );
310                let module = node;
311                if !ty.is_parallel() {
312                    state.items.push(PreBatchItem::NonParallelEdge(
313                        ty.without_inherit_async(),
314                        module,
315                    ));
316                    return Ok(GraphTraversalAction::Exclude);
317                }
318                if visited.insert(module) {
319                    if parent_info.is_some() && state.this.boundary_modules.contains(&module) {
320                        let idx = state.this.ensure_pre_batch_for_module(
321                            module,
322                            module_chunk_groups,
323                            queue,
324                        )?;
325                        state.items.push(PreBatchItem::ParallelReference(idx));
326                        return Ok(GraphTraversalAction::Exclude);
327                    }
328                    Ok(GraphTraversalAction::Continue)
329                } else {
330                    Ok(GraphTraversalAction::Exclude)
331                }
332            },
333            |_, node, state| {
334                let item = PreBatchItem::ParallelModule(node);
335                state.items.push(item);
336                Ok(())
337            },
338            false,
339        )?;
340        Ok(state.items)
341    }
342}
343
344pub async fn compute_module_batches(
345    module_graph: Vc<ModuleGraph>,
346    _config: &BatchingConfig,
347) -> Result<Vc<ModuleBatchesGraph>> {
348    let outer_span = tracing::info_span!(
349        "compute module batches",
350        initial_pre_batch_items = tracing::field::Empty,
351        initial_pre_batches = tracing::field::Empty,
352        extracted_shared_items = tracing::field::Empty,
353        batches = tracing::field::Empty,
354        modules = tracing::field::Empty,
355        edges = tracing::field::Empty
356    );
357    let span = outer_span.clone();
358    async move {
359        let chunk_group_info = module_graph.chunk_group_info().await?;
360        let module_chunk_groups = chunk_group_info.module_chunk_groups.await?;
361        let module_graph = module_graph.await?;
362
363        let mut pre_batches = PreBatches::new();
364
365        // Walk the module graph and mark all modules that are boundary modules (referenced from a
366        // different chunk group bitmap)
367        module_graph.traverse_edges_unordered(|parent, node| {
368            if let Some((parent, ty)) = parent {
369                let std::collections::hash_set::Entry::Vacant(entry) =
370                    pre_batches.boundary_modules.entry(node)
371                else {
372                    // Already a boundary module, can skip check
373                    return Ok(());
374                };
375                if ty.chunking_type.is_parallel() {
376                    let parent_chunk_groups = module_chunk_groups
377                        .get(&parent)
378                        .context("all modules need to have chunk group info")?;
379                    let chunk_groups = module_chunk_groups
380                        .get(&node)
381                        .context("all modules need to have chunk group info")?;
382                    if parent_chunk_groups != chunk_groups {
383                        // This is a boundary module
384                        entry.insert();
385                    }
386                } else {
387                    entry.insert();
388                }
389            }
390            Ok(())
391        })?;
392
393        // All entries are boundary modules too
394        for chunk_group in &chunk_group_info.chunk_groups {
395            for entry in chunk_group.entries() {
396                pre_batches.boundary_modules.insert(entry);
397            }
398        }
399
400        // Pre batches would be incorrect with cycles, so we need to opt-out of pre batches for
401        // cycles that include boundary modules
402        module_graph.traverse_cycles(
403            |ref_data| ref_data.chunking_type.is_parallel(),
404            |cycle| {
405                if cycle.len() > 1
406                    && cycle
407                        .iter()
408                        .any(|node| pre_batches.boundary_modules.contains(node))
409                {
410                    pre_batches
411                        .boundary_modules
412                        .extend(cycle.iter().map(|node| **node));
413                }
414                Ok(())
415            },
416        )?;
417
418        let mut queue: VecDeque<(ResolvedVc<Box<dyn Module>>, PreBatchIndex)> = VecDeque::new();
419
420        let mut chunk_group_indices_with_merged_children = FxHashSet::default();
421
422        // Start with the entries
423        for chunk_group in &chunk_group_info.chunk_groups {
424            for entry in chunk_group.entries() {
425                pre_batches.ensure_pre_batch_for_module(entry, &module_chunk_groups, &mut queue)?;
426            }
427            if let Some(parent) = chunk_group.get_merged_parent() {
428                chunk_group_indices_with_merged_children.insert(parent);
429            }
430        }
431
432        let mut initial_pre_batch_items = 0;
433        // Fill all pre batches
434        while let Some((chunkable_module, idx)) = queue.pop_front() {
435            let items = pre_batches
436                .get_pre_batch_items(
437                    chunkable_module,
438                    &module_chunk_groups,
439                    &module_graph,
440                    &mut queue,
441                )
442                .await?;
443            initial_pre_batch_items += items.len();
444            let batch = &mut pre_batches.batches[idx];
445            batch.items.extend(items);
446        }
447        span.record("initial_pre_batch_items", initial_pre_batch_items);
448        span.record("initial_pre_batches", pre_batches.batches.len());
449
450        // Figure out the order of all merged groups
451        let mut ordered_entries: Vec<Option<EntriesList>> =
452            vec![None; chunk_group_info.chunk_groups.len()];
453        for (i, chunk_group) in chunk_group_info.chunk_groups.iter().enumerate() {
454            if !chunk_group_indices_with_merged_children.contains(&i) {
455                continue;
456            }
457            let mut merged_modules: FxHashMap<ChunkingType, FxIndexSet<_>> = FxHashMap::default();
458            let mut stack = ordered_entries[i]
459                .as_ref()
460                .map_or_else(
461                    || Either::Left(chunk_group.entries()),
462                    |v| Either::Right(v.0.iter().copied()),
463                )
464                .map(|module| {
465                    let idx = *pre_batches
466                        .entries
467                        .get(&module)
468                        .context("could not prebatch for module")?;
469                    Ok((idx, 0))
470                })
471                .collect::<Result<Vec<_>>>()?;
472            stack.reverse();
473            let mut visited = FxHashSet::default();
474            while let Some((idx, mut pos)) = stack.pop() {
475                let batch = &pre_batches.batches[idx];
476                while let Some(item) = batch.items.get_index(pos) {
477                    match item {
478                        PreBatchItem::ParallelModule(_) => {}
479                        PreBatchItem::ParallelReference(other_idx) => {
480                            if visited.insert(*other_idx) {
481                                stack.push((idx, pos + 1));
482                                stack.push((*other_idx, 0));
483                                break;
484                            }
485                        }
486                        PreBatchItem::NonParallelEdge(chunking_type, module) => {
487                            if chunking_type.is_merged() {
488                                merged_modules
489                                    .entry(chunking_type.clone())
490                                    .or_default()
491                                    .insert(*module);
492                            }
493                        }
494                    }
495                    pos += 1;
496                }
497            }
498            if !merged_modules.is_empty() {
499                for (ty, merged_modules) in merged_modules {
500                    let chunk_group_key = match ty {
501                        ChunkingType::Isolated {
502                            merge_tag: Some(merge_tag),
503                            ..
504                        } => ChunkGroupKey::IsolatedMerged {
505                            parent: i.into(),
506                            merge_tag: merge_tag.clone(),
507                        },
508                        ChunkingType::Shared {
509                            merge_tag: Some(merge_tag),
510                            ..
511                        } => ChunkGroupKey::SharedMerged {
512                            parent: i.into(),
513                            merge_tag: merge_tag.clone(),
514                        },
515                        _ => unreachable!(),
516                    };
517                    let idx = chunk_group_info
518                        .chunk_group_keys
519                        .get_index_of(&chunk_group_key)
520                        .context("could not find chunk group key for merged chunk group")?;
521                    ordered_entries[idx] = Some(EntriesList(merged_modules));
522                }
523            }
524        }
525
526        // Create a map of parallel module to the batches they are contained in.
527        let mut parallel_module_to_pre_batch: FxIndexMap<_, Vec<PreBatchIndex>> =
528            FxIndexMap::default();
529
530        // Fill the map and also fill up the single_module_entries
531        for (idx, pre_batch) in pre_batches.batches.iter().enumerate() {
532            for item in &pre_batch.items {
533                match item {
534                    PreBatchItem::ParallelModule(module) => {
535                        parallel_module_to_pre_batch
536                            .entry(*module)
537                            .or_default()
538                            .push(idx);
539                    }
540                    PreBatchItem::NonParallelEdge(_, module) => {
541                        if !pre_batches.entries.contains_key(module) {
542                            pre_batches.single_module_entries.insert(*module);
543                        }
544                    }
545                    PreBatchItem::ParallelReference(_) => {}
546                }
547            }
548        }
549
550        // We never want a module to occur in multiple batches.
551
552        let mut extracted_shared_items = 0;
553        // Extract shared modules into separate batches
554        for i in 0..parallel_module_to_pre_batch.len() {
555            let (&module, batches) = parallel_module_to_pre_batch
556                .get_index(i)
557                .context("could not find parallel module to pre batch index")?;
558            if batches.len() > 1 {
559                // Create a new batch for the shared modules
560                let batches_with_item_index = batches
561                    .iter()
562                    .map(|&idx| {
563                        let batch_items = &pre_batches.batches[idx].items;
564                        let item_idx = batch_items
565                            .get_index_of(&PreBatchItem::ParallelModule(module))
566                            .context("could not find batch item index for parallel module")?;
567                        Ok((idx, item_idx))
568                    })
569                    .collect::<Result<Vec<_>>>()?;
570                let mut selected_items = 1;
571                fn get_item_at(
572                    pre_batches: &PreBatches,
573                    batch_idx: PreBatchIndex,
574                    item_idx: usize,
575                ) -> Option<&PreBatchItem> {
576                    pre_batches.batches[batch_idx].items.get_index(item_idx)
577                }
578                // Select more matching items that are equal in all batches that contain the shared
579                // module(s)
580                loop {
581                    if let Some(PreBatchItem::ParallelModule(next_module)) = get_item_at(
582                        &pre_batches,
583                        batches_with_item_index[0].0,
584                        batches_with_item_index[0].1 + selected_items,
585                    ) && parallel_module_to_pre_batch
586                        .get(next_module)
587                        .context("could not find pre batch for parallel module")?
588                        .len()
589                        == batches.len()
590                        && batches_with_item_index[1..]
591                            .iter()
592                            .all(|&(batch_idx, item_idx)| {
593                                get_item_at(&pre_batches, batch_idx, item_idx + selected_items)
594                                    == Some(&PreBatchItem::ParallelModule(*next_module))
595                            })
596                    {
597                        selected_items += 1;
598                        continue;
599                    }
600                    break;
601                }
602                extracted_shared_items += selected_items;
603
604                // Check if a batch is completely selected. In that case we can replace all other
605                // occurrences with a reference to that batch
606                let exact_match = batches_with_item_index
607                    .iter()
608                    .find(|&&(batch_idx, item_idx)| {
609                        item_idx == 0
610                            && pre_batches.batches[batch_idx].items.len() == selected_items
611                    });
612                if let Some(&(exact_match, _)) = exact_match {
613                    // Replace all other occurrences with a reference to the exact match
614                    for &(batch_index, item_start) in batches_with_item_index.iter() {
615                        if batch_index != exact_match {
616                            pre_batches.batches[batch_index].items.splice(
617                                item_start..item_start + selected_items,
618                                std::iter::once(PreBatchItem::ParallelReference(exact_match)),
619                            );
620                        }
621                    }
622                    for item in pre_batches.batches[exact_match].items.iter() {
623                        if let PreBatchItem::ParallelModule(module) = item {
624                            parallel_module_to_pre_batch
625                                .get_mut(module)
626                                .context("could not find pre batch for parallel module")?
627                                .clear();
628                        }
629                    }
630                } else {
631                    // Create a new batch of the shared part and replace all occurrences with a
632                    // reference to that batch
633                    let first_batch_index = batches_with_item_index[0].0;
634                    let first_batch_item_index = batches_with_item_index[0].1;
635                    let new_batch_index = pre_batches.batches.len();
636                    let mut new_batch =
637                        PreBatch::new(pre_batches.batches[first_batch_index].chunk_groups.clone());
638                    new_batch
639                        .items
640                        .extend(pre_batches.batches[first_batch_index].items.splice(
641                            first_batch_item_index..first_batch_item_index + selected_items,
642                            std::iter::once(PreBatchItem::ParallelReference(new_batch_index)),
643                        ));
644                    for item in new_batch.items.iter() {
645                        if let PreBatchItem::ParallelModule(module) = item {
646                            parallel_module_to_pre_batch
647                                .get_mut(module)
648                                .context("could not find pre batch for parallel module")?
649                                .clear();
650                        }
651                    }
652                    pre_batches.batches.push(new_batch);
653                    for &(batch_index, item_start) in batches_with_item_index[1..].iter() {
654                        pre_batches.batches[batch_index].items.splice(
655                            item_start..item_start + selected_items,
656                            std::iter::once(PreBatchItem::ParallelReference(new_batch_index)),
657                        );
658                    }
659                }
660            }
661        }
662        span.record("extracted_shared_items", extracted_shared_items);
663
664        // Now every module is only in one batch
665
666        let mut edges_count = 0;
667
668        // Since batches can only have references followed by a list of parallel chunkable modules,
669        // we need to split batches that have modules before references.
670        for i in 0..pre_batches.batches.len() {
671            let items = take(&mut pre_batches.batches[i].items);
672            let mut new_items =
673                FxIndexSet::with_capacity_and_hasher(items.len(), Default::default());
674            enum Mode {
675                ParallelChunkableModule,
676                Other,
677            }
678            let mut mode = Mode::Other;
679            for item in items {
680                let chunkable_module = if let PreBatchItem::ParallelModule(module) = &item {
681                    ResolvedVc::try_downcast::<Box<dyn ChunkableModule>>(*module)
682                } else {
683                    None
684                };
685                let item = if let PreBatchItem::ParallelModule(module) = item {
686                    if chunkable_module.is_some() {
687                        PreBatchItem::ParallelModule(module)
688                    } else {
689                        pre_batches.single_module_entries.insert(module);
690                        PreBatchItem::NonParallelEdge(
691                            ChunkingType::Parallel {
692                                inherit_async: false,
693                                hoisted: false,
694                            },
695                            module,
696                        )
697                    }
698                } else {
699                    item
700                };
701                match (&mode, chunkable_module) {
702                    (_, Some(_)) => {
703                        mode = Mode::ParallelChunkableModule;
704                        new_items.insert(item);
705                    }
706                    (Mode::Other, _) => {
707                        edges_count += 1;
708                        new_items.insert(item);
709                    }
710                    (Mode::ParallelChunkableModule, _) => {
711                        // Split the batch
712                        let idx = pre_batches.batches.len();
713                        let mut new_batch =
714                            PreBatch::new(pre_batches.batches[i].chunk_groups.clone());
715                        new_batch.items.extend(new_items.drain(..));
716                        pre_batches.batches.push(new_batch);
717                        edges_count += 1;
718                        new_items.insert(PreBatchItem::ParallelReference(idx));
719                        if chunkable_module.is_some() {
720                            new_items.insert(item);
721                        } else {
722                            edges_count += 1;
723                            mode = Mode::Other;
724                            new_items.insert(item);
725                        }
726                    }
727                }
728            }
729            pre_batches.batches[i].items = new_items;
730        }
731        span.record("pre_batches", pre_batches.batches.len());
732
733        // Now batches are in the correct shape. We can create the real batches and the graph.
734
735        // Create the graph
736        let mut graph: DiGraph<ModuleOrBatch, ModuleBatchesGraphEdge, u32> =
737            petgraph::graph::DiGraph::with_capacity(
738                pre_batches.batches.len() + pre_batches.single_module_entries.len(),
739                edges_count,
740            );
741
742        // Create the Vc<ModuleBatch> instances
743        let batches = pre_batches
744            .batches
745            .iter_mut()
746            .enumerate()
747            .map(async |(i, pre_batch)| {
748                let mut modules = pre_batch.items.iter().filter_map(|item| {
749                    if let PreBatchItem::ParallelModule(module) = item {
750                        ResolvedVc::try_downcast(*module)
751                    } else {
752                        None
753                    }
754                });
755                let Some(first) = modules.next() else {
756                    return Ok(ModuleOrBatch::None(i));
757                };
758                if let Some(second) = modules.next() {
759                    let batch = ModuleBatch::new(
760                        [first, second]
761                            .into_iter()
762                            .chain(modules)
763                            .map(|m| *m)
764                            .collect::<Vec<_>>(),
765                        Some(pre_batch.chunk_groups.clone()),
766                    );
767                    Ok(ModuleOrBatch::Batch(batch.to_resolved().await?))
768                } else {
769                    Ok(ModuleOrBatch::Module(ResolvedVc::upcast(first)))
770                }
771            })
772            .try_join()
773            .await?;
774
775        // Create the batch groups by grouping batches with the same chunk groups
776        let mut batch_groups: FxHashMap<_, Vec<_>> = FxHashMap::default();
777        for (i, pre_batch) in pre_batches.batches.iter().enumerate() {
778            let key = BuildHasherDefault::<FxHasher>::default().prehash(&pre_batch.chunk_groups);
779            let batch = batches[i];
780            batch_groups.entry(key).or_default().push(batch);
781        }
782        for &module in &pre_batches.single_module_entries {
783            let chunk_groups = module_chunk_groups
784                .get(&module)
785                .context("all modules need to have chunk group info")?;
786            let key = BuildHasherDefault::<FxHasher>::default().prehash(chunk_groups);
787            batch_groups
788                .entry(key)
789                .or_default()
790                .push(ModuleOrBatch::Module(module));
791        }
792
793        // Create the batch group instances
794        let batch_groups = batch_groups
795            .into_iter()
796            .map(async |(key, items)| {
797                if items.len() == 1 {
798                    Ok(Either::Left(std::iter::empty()))
799                } else {
800                    let batch_group = ModuleBatchGroup::new(items.clone(), (*key).clone())
801                        .to_resolved()
802                        .await?;
803                    Ok(Either::Right(
804                        items.into_iter().map(move |item| (item, batch_group)),
805                    ))
806                }
807            })
808            .try_join()
809            .await?
810            .into_iter()
811            .flatten()
812            .collect::<FxHashMap<_, _>>();
813
814        // Insert batches into the graph and store the NodeIndices
815        let mut batches_count = 0;
816        let mut modules_count = 0;
817        let batch_indices = batches
818            .into_iter()
819            .map(|batch| {
820                match &batch {
821                    ModuleOrBatch::Batch(_) => batches_count += 1,
822                    ModuleOrBatch::Module(_) => modules_count += 1,
823                    ModuleOrBatch::None(_) => {}
824                }
825                graph.add_node(batch)
826            })
827            .collect::<Vec<_>>();
828
829        // Also insert single modules into the graph and store the NodeIndices
830        let single_module_indices = pre_batches
831            .single_module_entries
832            .iter()
833            .map(|module| graph.add_node(ModuleOrBatch::Module(*module)))
834            .collect::<Vec<_>>();
835
836        span.record("batches", batches_count);
837        modules_count += pre_batches.single_module_entries.len();
838        span.record("modules", modules_count);
839        span.record("edges", edges_count);
840
841        // Add all the edges to the graph
842        for (i, pre_batch) in pre_batches.batches.into_iter().enumerate() {
843            let index = batch_indices[i];
844            let items = pre_batch.items;
845            for item in items {
846                match item {
847                    PreBatchItem::ParallelReference(idx) => {
848                        graph.add_edge(
849                            index,
850                            batch_indices[idx],
851                            ModuleBatchesGraphEdge {
852                                ty: ChunkingType::Parallel {
853                                    inherit_async: false,
854                                    hoisted: false,
855                                },
856                                module: None,
857                            },
858                        );
859                    }
860                    PreBatchItem::NonParallelEdge(ty, module) => {
861                        if let Some(batch) = pre_batches.entries.get(&module).copied() {
862                            graph.add_edge(
863                                index,
864                                batch_indices[batch],
865                                ModuleBatchesGraphEdge {
866                                    ty,
867                                    module: Some(module),
868                                },
869                            );
870                            continue;
871                        }
872                        let idx = pre_batches
873                            .single_module_entries
874                            .get_index_of(&module)
875                            .context("could not find single module entry index")?;
876                        let idx = single_module_indices[idx];
877                        graph.add_edge(
878                            index,
879                            idx,
880                            ModuleBatchesGraphEdge {
881                                ty,
882                                module: Some(module),
883                            },
884                        );
885                    }
886                    PreBatchItem::ParallelModule(_) => {}
887                }
888            }
889        }
890
891        debug_assert_eq!(graph.capacity().0, graph.node_count());
892        debug_assert_eq!(graph.capacity().1, graph.edge_count());
893
894        // Find the NodeIndices for our entries of the graph
895        let mut entries = FxHashMap::default();
896        for chunk_group in &chunk_group_info.chunk_groups {
897            for module in chunk_group.entries() {
898                if let Some(batch) = pre_batches.entries.get(&module).copied() {
899                    entries.insert(module, batch_indices[batch]);
900                    continue;
901                }
902                let idx = pre_batches
903                    .single_module_entries
904                    .get_index_of(&module)
905                    .context("could not find single module entry index")?;
906                let idx = single_module_indices[idx];
907                entries.insert(module, idx);
908            }
909        }
910
911        Ok(ModuleBatchesGraph {
912            graph: TracedDiGraph(graph),
913            entries,
914            batch_groups,
915            ordered_entries,
916        }
917        .cell())
918    }
919    .instrument(outer_span)
920    .await
921}