// turbopack_core/module_graph/module_batches.rs

1use std::{
2    collections::{VecDeque, hash_map::Entry},
3    hash::BuildHasherDefault,
4    mem::take,
5};
6
7use anyhow::{Context, Result};
8use bincode::{Decode, Encode};
9use either::Either;
10use petgraph::graph::{DiGraph, EdgeIndex, NodeIndex};
11use rustc_hash::{FxHashMap, FxHashSet, FxHasher};
12use serde::{Deserialize, Serialize};
13use tracing::Instrument;
14use turbo_prehash::BuildHasherExt;
15use turbo_tasks::{
16    FxIndexMap, FxIndexSet, NonLocalValue, ResolvedVc, TaskInput, TryJoinIterExt, ValueToString,
17    Vc, trace::TraceRawVcs, turbobail,
18};
19
20use crate::{
21    chunk::{ChunkableModule, ChunkingType},
22    module::Module,
23    module_graph::{
24        GraphTraversalAction, ModuleGraph,
25        chunk_group_info::{ChunkGroupInfo, ChunkGroupKey, RoaringBitmapWrapper},
26        module_batch::{ModuleBatch, ModuleBatchGroup, ModuleOrBatch},
27        traced_di_graph::{TracedDiGraph, iter_neighbors_rev},
28    },
29};
/// Configuration controlling how modules are grouped into batches.
#[turbo_tasks::value]
#[derive(Debug, Clone, Default, TaskInput, Hash)]
pub struct BatchingConfig {
    /// Use a heuristic based on the module path to create batches. It aims for batches of a good
    /// size.
    pub use_heuristic: bool,
}
37
#[turbo_tasks::value_impl]
impl BatchingConfig {
    /// Wraps the given config in a task cell so it can be passed around as a `Vc`.
    #[turbo_tasks::function]
    pub fn new(config: BatchingConfig) -> Vc<Self> {
        config.cell()
    }
}
45
/// Edge payload of the [`ModuleBatchesGraph`]: the chunking type of the original
/// module reference and, when one is tracked for this edge, the referenced module.
#[derive(Debug, Clone, Serialize, Deserialize, TraceRawVcs, NonLocalValue)]
pub struct ModuleBatchesGraphEdge {
    // Chunking type of the reference this edge was created from.
    pub ty: ChunkingType,
    // The referenced module, if tracked for this edge.
    // NOTE(review): under which chunking types this is `Some` is determined by the
    // graph construction (not fully visible here) — confirm before relying on it.
    pub module: Option<ResolvedVc<Box<dyn Module>>>,
}
51
/// An ordered, deduplicated list of entry modules for a chunk group, encoded with
/// bincode's indexset support.
#[derive(Debug, Clone, TraceRawVcs, NonLocalValue, Encode, Decode)]
struct EntriesList(
    #[bincode(with = "turbo_bincode::indexset")] pub FxIndexSet<ResolvedVc<Box<dyn Module>>>,
);
56
/// The result of module batching: a graph whose nodes are single modules or module
/// batches and whose edges carry the chunking type of the underlying references.
#[turbo_tasks::value(cell = "new", eq = "manual")]
pub struct ModuleBatchesGraph {
    /// The batches/modules and the (typed) references between them.
    graph: TracedDiGraph<ModuleOrBatch, ModuleBatchesGraphEdge>,

    // NodeIndex isn't necessarily stable (because of swap_remove), but we never remove nodes.
    //
    // HashMaps have nondeterministic order, but this map is only used for lookups and not
    // iteration.
    //
    // This contains Vcs, but they are already contained in the graph, so no need to trace this.
    #[turbo_tasks(trace_ignore)]
    #[bincode(with_serde)]
    entries: FxHashMap<ResolvedVc<Box<dyn Module>>, NodeIndex>,
    /// Maps a module or batch to the batch group it belongs to, if any.
    batch_groups: FxHashMap<ModuleOrBatch, ResolvedVc<ModuleBatchGroup>>,

    /// For chunk groups where the postorder of entries is different than the order of the
    /// `ChunkGroup::entries()` this contains Some with the postorder list of entries of that chunk
    /// group. The index in this list corresponds to the index in the
    /// chunk_group_info.chunk_groups.
    ordered_entries: Vec<Option<EntriesList>>,
}
78
impl ModuleBatchesGraph {
    /// Resolves the graph `NodeIndex` for the given entry module. Fails with an error
    /// listing all known entries when the module is not an entry of this graph.
    pub async fn get_entry_index(&self, entry: ResolvedVc<Box<dyn Module>>) -> Result<NodeIndex> {
        let Some(entry) = self.entries.get(&entry) else {
            // Build a debug listing of every known entry's ident for the error message.
            let possible_entries = format!(
                "{:#?}",
                self.entries
                    .keys()
                    .map(|e| e.ident().to_string())
                    .try_join()
                    .await?
            );
            turbobail!(
                "Entry {} is not in graph (possible entries: {})",
                entry.ident(),
                possible_entries
            );
        };
        Ok(*entry)
    }

    /// Returns the entries of the chunk group at `idx`, preferring the precomputed
    /// postorder list (`ordered_entries`) when one exists and otherwise falling back
    /// to the chunk group's declared entry order. Yields nothing when `idx` is out of
    /// range of `chunk_group_info.chunk_groups`.
    pub fn get_ordered_entries<'l>(
        &'l self,
        chunk_group_info: &'l ChunkGroupInfo,
        idx: usize,
    ) -> impl Iterator<Item = ResolvedVc<Box<dyn Module>>> + 'l {
        // NOTE(review): the `.as_ref()` after `.get(idx)` is redundant — `.get` already
        // returns `Option<&Option<EntriesList>>`, so `.and_then(|o| o.as_ref())` alone
        // would suffice. Behavior is unchanged either way.
        if let Some(EntriesList(ordered_entries)) = self
            .ordered_entries
            .get(idx)
            .as_ref()
            .and_then(|o| o.as_ref())
        {
            if let Some(chunk_group) = chunk_group_info.chunk_groups.get_index(idx) {
                // The reordered list must cover exactly the chunk group's entries.
                debug_assert_eq!(ordered_entries.len(), chunk_group.entries_count());
            }
            Either::Left(Either::Left(ordered_entries.iter().copied()))
        } else if let Some(chunk_group) = chunk_group_info.chunk_groups.get_index(idx) {
            Either::Right(chunk_group.entries())
        } else {
            Either::Left(Either::Right(std::iter::empty()))
        }
    }

    /// Looks up the batch group containing the given module or batch, if any.
    pub fn get_batch_group(
        &self,
        module_or_batch: &ModuleOrBatch,
    ) -> Option<ResolvedVc<ModuleBatchGroup>> {
        self.batch_groups.get(module_or_batch).copied()
    }

    /// Returns the graph node (module or batch) that the given entry module maps to.
    pub async fn get_entry(&self, entry: ResolvedVc<Box<dyn Module>>) -> Result<ModuleOrBatch> {
        let entry = self.get_entry_index(entry).await?;
        Ok(*self.graph.node_weight(entry).unwrap())
    }

    // Clippy complains but there's a type error without the bound
    #[allow(clippy::implied_bounds_in_impls)]
    /// Traverses all reachable edges in dfs order. The preorder visitor can be used to
    /// forward state down the graph, and to skip subgraphs
    ///
    /// Use this to collect batches/modules in evaluation order.
    ///
    /// Target nodes can be revisited (once per incoming edge).
    /// Edges are traversed in normal order, so should correspond to reference order.
    ///
    /// * `entries` - The entry modules to start the traversal from
    /// * `state` - The state to be passed to the visitors
    /// * `visit_preorder` - Called before visiting the children of a node.
    ///    - Receives: (originating &ModuleBatchesGraphNode, edge &ChunkingType), target
    ///      &ModuleBatchesGraphNode, state &S
    ///    - Can return [GraphTraversalAction]s to control the traversal
    /// * `visit_postorder` - Called after visiting the children of a node. Return
    ///    - Receives: (originating &ModuleBatchesGraphNode, edge &ChunkingType), target
    ///      &ModuleBatchesGraphNode, state &S
    pub fn traverse_edges_from_entries_dfs<'a, S>(
        &'a self,
        entries: impl IntoIterator<
            Item = NodeIndex,
            IntoIter = impl Iterator<Item = NodeIndex> + DoubleEndedIterator,
        >,
        state: &mut S,
        mut visit_preorder: impl FnMut(
            Option<(&'a ModuleOrBatch, &'a ModuleBatchesGraphEdge)>,
            &'a ModuleOrBatch,
            &mut S,
        ) -> Result<GraphTraversalAction>,
        mut visit_postorder: impl FnMut(
            Option<(&'a ModuleOrBatch, &'a ModuleBatchesGraphEdge)>,
            &'a ModuleOrBatch,
            &mut S,
        ),
    ) -> Result<()> {
        let graph = &self.graph;

        // Two-phase explicit DFS: each node is first expanded (preorder), then seen
        // again as a plain `Visit` (postorder) after all its children were processed.
        enum ReverseDFSPass {
            Visit,
            ExpandAndVisit,
        }

        let entries = entries.into_iter();
        #[allow(clippy::type_complexity)] // This is a temporary internal structure
        // Entries are pushed in reverse so popping processes them in original order.
        let mut stack: Vec<(ReverseDFSPass, Option<(NodeIndex, EdgeIndex)>, NodeIndex)> = entries
            .rev()
            .map(|e| (ReverseDFSPass::ExpandAndVisit, None, e))
            .collect();
        // Nodes whose children were already pushed. Such nodes are still re-visited
        // once per incoming edge, but only expanded the first time.
        let mut expanded = FxHashSet::default();
        while let Some((pass, parent, current)) = stack.pop() {
            let parent_arg = parent.map(|(node, edge)| {
                (
                    graph.node_weight(node).unwrap(),
                    graph.edge_weight(edge).unwrap(),
                )
            });
            match pass {
                ReverseDFSPass::Visit => {
                    let current_node = graph.node_weight(current).unwrap();
                    visit_postorder(parent_arg, current_node, state);
                }
                ReverseDFSPass::ExpandAndVisit => {
                    let current_node = graph.node_weight(current).unwrap();
                    let action = visit_preorder(parent_arg, current_node, state)?;
                    if action == GraphTraversalAction::Exclude {
                        continue;
                    }
                    // Schedule the postorder visit *below* the children on the stack,
                    // so it fires only after all of them have been popped.
                    stack.push((ReverseDFSPass::Visit, parent, current));
                    if action == GraphTraversalAction::Continue && expanded.insert(current) {
                        // Neighbors are iterated in reverse so popping yields them in
                        // normal (reference) order.
                        stack.extend(iter_neighbors_rev(graph, current).map(|(edge, child)| {
                            (ReverseDFSPass::ExpandAndVisit, Some((current, edge)), child)
                        }));
                    }
                }
            }
        }

        Ok(())
    }
}
215
/// Index of a [`PreBatch`] within `PreBatches::batches`.
type PreBatchIndex = usize;

/// One item of a pre-batch.
#[derive(Hash, PartialEq, Eq, Clone, Debug)]
enum PreBatchItem {
    /// A module included inline via a parallel reference.
    ParallelModule(ResolvedVc<Box<dyn Module>>),
    /// A parallel reference to another pre-batch, by index.
    ParallelReference(PreBatchIndex),
    /// An edge to a module that is not referenced in parallel (see `ChunkingType`).
    NonParallelEdge(ChunkingType, ResolvedVc<Box<dyn Module>>),
}
224
/// A batch under construction: an ordered, deduplicated set of items plus the chunk
/// groups this batch belongs to.
struct PreBatch {
    // Items in traversal order; the set semantics deduplicate repeated items.
    items: FxIndexSet<PreBatchItem>,
    // Bitmap of the chunk groups that contain this batch.
    chunk_groups: RoaringBitmapWrapper,
}
229
230impl PreBatch {
231    fn new(chunk_groups: RoaringBitmapWrapper) -> Self {
232        Self {
233            items: FxIndexSet::default(),
234            chunk_groups,
235        }
236    }
237}
238
/// Mutable state threaded through the pre-batch item traversal.
struct TraversalState<'l> {
    // Items collected so far, in traversal order.
    items: Vec<PreBatchItem>,
    // The `PreBatches` being built (borrowed mutably so the visitor closures can
    // create new pre-batches while collecting items).
    this: &'l mut PreBatches,
}
243
/// Intermediate state while computing batches: modules are grouped into
/// "pre-batches" before the final `ModuleBatchesGraph` is constructed.
struct PreBatches {
    // Modules that start a new pre-batch: chunk-group entries, modules referenced
    // across a chunk-group boundary, and members of cycles containing such modules.
    boundary_modules: FxHashSet<ResolvedVc<Box<dyn Module>>>,
    // All pre-batches, addressed by `PreBatchIndex`.
    batches: Vec<PreBatch>,
    // Maps a boundary module to the index of the pre-batch it starts.
    entries: FxHashMap<ResolvedVc<Box<dyn Module>>, PreBatchIndex>,
    // Modules that become standalone graph nodes instead of batch members.
    single_module_entries: FxIndexSet<ResolvedVc<Box<dyn Module>>>,
}
250
impl PreBatches {
    /// Creates an empty set of pre-batches.
    fn new() -> Self {
        Self {
            boundary_modules: FxHashSet::default(),
            batches: Vec::new(),
            entries: FxHashMap::default(),
            single_module_entries: FxIndexSet::default(),
        }
    }

    /// Returns the pre-batch index for `module`, creating a new empty pre-batch (and
    /// queueing the module for later item collection) if none exists yet.
    ///
    /// Errors when `module_chunk_groups` has no entry for `module`.
    fn ensure_pre_batch_for_module(
        &mut self,
        module: ResolvedVc<Box<dyn Module>>,
        module_chunk_groups: &FxHashMap<ResolvedVc<Box<dyn Module>>, RoaringBitmapWrapper>,
        queue: &mut VecDeque<(ResolvedVc<Box<dyn Module>>, PreBatchIndex)>,
    ) -> Result<PreBatchIndex> {
        Ok(match self.entries.entry(module) {
            Entry::Vacant(e) => {
                let index = self.batches.len();
                // Items of the new batch are collected later when the queue is drained.
                queue.push_back((module, index));
                let chunk_groups = module_chunk_groups
                    .get(&module)
                    .context("all modules need to have chunk group info")?;
                let batch = PreBatch::new((*chunk_groups).clone());
                self.batches.push(batch);
                e.insert(index);
                index
            }
            Entry::Occupied(e) => *e.get(),
        })
    }

    /// Collects the items of the pre-batch rooted at `entry` by a DFS over parallel
    /// edges of the module graph. Non-parallel edges are recorded as
    /// [`PreBatchItem::NonParallelEdge`] without being followed; boundary modules
    /// become [`PreBatchItem::ParallelReference`]s (creating their pre-batch and
    /// pushing them onto `queue` if needed). Parallel modules are emitted in
    /// postorder.
    async fn get_pre_batch_items(
        &mut self,
        entry: ResolvedVc<Box<dyn Module>>,
        module_chunk_groups: &FxHashMap<ResolvedVc<Box<dyn Module>>, RoaringBitmapWrapper>,
        module_graph: &ModuleGraph,
        queue: &mut VecDeque<(ResolvedVc<Box<dyn Module>>, PreBatchIndex)>,
    ) -> Result<Vec<PreBatchItem>> {
        let mut state = TraversalState {
            items: Vec::new(),
            this: self,
        };
        let mut visited = FxHashSet::default();
        module_graph.traverse_edges_dfs(
            std::iter::once(entry),
            &mut state,
            |parent_info, node, state| {
                // The entry itself has no incoming edge; treat it as a plain parallel
                // reference.
                let ty = parent_info.map_or(
                    &ChunkingType::Parallel {
                        inherit_async: false,
                        hoisted: false,
                    },
                    |(_, ty)| &ty.chunking_type,
                );
                let module = node;
                if !ty.is_parallel() {
                    // Record the edge but don't descend into the target.
                    state.items.push(PreBatchItem::NonParallelEdge(
                        ty.without_inherit_async(),
                        module,
                    ));
                    return Ok(GraphTraversalAction::Exclude);
                }
                if visited.insert(module) {
                    // Boundary modules (other than the entry) get their own pre-batch
                    // and are referenced instead of inlined.
                    if parent_info.is_some() && state.this.boundary_modules.contains(&module) {
                        let idx = state.this.ensure_pre_batch_for_module(
                            module,
                            module_chunk_groups,
                            queue,
                        )?;
                        state.items.push(PreBatchItem::ParallelReference(idx));
                        return Ok(GraphTraversalAction::Exclude);
                    }
                    Ok(GraphTraversalAction::Continue)
                } else {
                    Ok(GraphTraversalAction::Exclude)
                }
            },
            |_, node, state| {
                // Postorder: inlined parallel modules end up in evaluation order.
                let item = PreBatchItem::ParallelModule(node);
                state.items.push(item);
                Ok(())
            },
        )?;
        Ok(state.items)
    }
}
338
339pub async fn compute_module_batches(
340    module_graph: Vc<ModuleGraph>,
341    _config: &BatchingConfig,
342) -> Result<Vc<ModuleBatchesGraph>> {
343    let outer_span = tracing::info_span!(
344        "compute module batches",
345        initial_pre_batch_items = tracing::field::Empty,
346        initial_pre_batches = tracing::field::Empty,
347        extracted_shared_items = tracing::field::Empty,
348        batches = tracing::field::Empty,
349        modules = tracing::field::Empty,
350        edges = tracing::field::Empty
351    );
352    let span = outer_span.clone();
353    async move {
354        let chunk_group_info = module_graph.chunk_group_info().await?;
355        let module_chunk_groups = chunk_group_info.module_chunk_groups.await?;
356        let module_graph = module_graph.await?;
357
358        let mut pre_batches = PreBatches::new();
359
360        // Walk the module graph and mark all modules that are boundary modules (referenced from a
361        // different chunk group bitmap)
362        module_graph.traverse_edges_unordered(|parent, node| {
363            if let Some((parent, ty)) = parent {
364                let std::collections::hash_set::Entry::Vacant(entry) =
365                    pre_batches.boundary_modules.entry(node)
366                else {
367                    // Already a boundary module, can skip check
368                    return Ok(());
369                };
370                if ty.chunking_type.is_parallel() {
371                    let parent_chunk_groups = module_chunk_groups
372                        .get(&parent)
373                        .context("all modules need to have chunk group info")?;
374                    let chunk_groups = module_chunk_groups
375                        .get(&node)
376                        .context("all modules need to have chunk group info")?;
377                    if parent_chunk_groups != chunk_groups {
378                        // This is a boundary module
379                        entry.insert();
380                    }
381                } else {
382                    entry.insert();
383                }
384            }
385            Ok(())
386        })?;
387
388        // All entries are boundary modules too
389        for chunk_group in &chunk_group_info.chunk_groups {
390            for entry in chunk_group.entries() {
391                pre_batches.boundary_modules.insert(entry);
392            }
393        }
394
395        // Pre batches would be incorrect with cycles, so we need to opt-out of pre batches for
396        // cycles that include boundary modules
397        module_graph.traverse_cycles(
398            |ref_data| ref_data.chunking_type.is_parallel(),
399            |cycle| {
400                if cycle.len() > 1
401                    && cycle
402                        .iter()
403                        .any(|node| pre_batches.boundary_modules.contains(node))
404                {
405                    pre_batches
406                        .boundary_modules
407                        .extend(cycle.iter().map(|node| **node));
408                }
409                Ok(())
410            },
411        )?;
412
413        let mut queue: VecDeque<(ResolvedVc<Box<dyn Module>>, PreBatchIndex)> = VecDeque::new();
414
415        let mut chunk_group_indices_with_merged_children = FxHashSet::default();
416
417        // Start with the entries
418        for chunk_group in &chunk_group_info.chunk_groups {
419            for entry in chunk_group.entries() {
420                pre_batches.ensure_pre_batch_for_module(entry, &module_chunk_groups, &mut queue)?;
421            }
422            if let Some(parent) = chunk_group.get_merged_parent() {
423                chunk_group_indices_with_merged_children.insert(parent);
424            }
425        }
426
427        let mut initial_pre_batch_items = 0;
428        // Fill all pre batches
429        while let Some((chunkable_module, idx)) = queue.pop_front() {
430            let items = pre_batches
431                .get_pre_batch_items(
432                    chunkable_module,
433                    &module_chunk_groups,
434                    &module_graph,
435                    &mut queue,
436                )
437                .await?;
438            initial_pre_batch_items += items.len();
439            let batch = &mut pre_batches.batches[idx];
440            batch.items.extend(items);
441        }
442        span.record("initial_pre_batch_items", initial_pre_batch_items);
443        span.record("initial_pre_batches", pre_batches.batches.len());
444
445        // Figure out the order of all merged groups
446        let mut ordered_entries: Vec<Option<EntriesList>> =
447            vec![None; chunk_group_info.chunk_groups.len()];
448        for (i, chunk_group) in chunk_group_info.chunk_groups.iter().enumerate() {
449            if !chunk_group_indices_with_merged_children.contains(&i) {
450                continue;
451            }
452            let mut merged_modules: FxHashMap<ChunkingType, FxIndexSet<_>> = FxHashMap::default();
453            let mut stack = ordered_entries[i]
454                .as_ref()
455                .map_or_else(
456                    || Either::Left(chunk_group.entries()),
457                    |v| Either::Right(v.0.iter().copied()),
458                )
459                .map(|module| {
460                    let idx = *pre_batches
461                        .entries
462                        .get(&module)
463                        .context("could not prebatch for module")?;
464                    Ok((idx, 0))
465                })
466                .collect::<Result<Vec<_>>>()?;
467            stack.reverse();
468            let mut visited = FxHashSet::default();
469            while let Some((idx, mut pos)) = stack.pop() {
470                let batch = &pre_batches.batches[idx];
471                while let Some(item) = batch.items.get_index(pos) {
472                    match item {
473                        PreBatchItem::ParallelModule(_) => {}
474                        PreBatchItem::ParallelReference(other_idx) => {
475                            if visited.insert(*other_idx) {
476                                stack.push((idx, pos + 1));
477                                stack.push((*other_idx, 0));
478                                break;
479                            }
480                        }
481                        PreBatchItem::NonParallelEdge(chunking_type, module) => {
482                            if chunking_type.is_merged() {
483                                merged_modules
484                                    .entry(chunking_type.clone())
485                                    .or_default()
486                                    .insert(*module);
487                            }
488                        }
489                    }
490                    pos += 1;
491                }
492            }
493            if !merged_modules.is_empty() {
494                for (ty, merged_modules) in merged_modules {
495                    let chunk_group_key = match ty {
496                        ChunkingType::Isolated {
497                            merge_tag: Some(merge_tag),
498                            ..
499                        } => ChunkGroupKey::IsolatedMerged {
500                            parent: i.into(),
501                            merge_tag: merge_tag.clone(),
502                        },
503                        ChunkingType::Shared {
504                            merge_tag: Some(merge_tag),
505                            ..
506                        } => ChunkGroupKey::SharedMerged {
507                            parent: i.into(),
508                            merge_tag: merge_tag.clone(),
509                        },
510                        _ => unreachable!(),
511                    };
512                    let idx = chunk_group_info
513                        .chunk_group_keys
514                        .get_index_of(&chunk_group_key)
515                        .context("could not find chunk group key for merged chunk group")?;
516                    ordered_entries[idx] = Some(EntriesList(merged_modules));
517                }
518            }
519        }
520
521        // Create a map of parallel module to the batches they are contained in.
522        let mut parallel_module_to_pre_batch: FxIndexMap<_, Vec<PreBatchIndex>> =
523            FxIndexMap::default();
524
525        // Fill the map and also fill up the single_module_entries
526        for (idx, pre_batch) in pre_batches.batches.iter().enumerate() {
527            for item in &pre_batch.items {
528                match item {
529                    PreBatchItem::ParallelModule(module) => {
530                        parallel_module_to_pre_batch
531                            .entry(*module)
532                            .or_default()
533                            .push(idx);
534                    }
535                    PreBatchItem::NonParallelEdge(_, module) => {
536                        if !pre_batches.entries.contains_key(module) {
537                            pre_batches.single_module_entries.insert(*module);
538                        }
539                    }
540                    PreBatchItem::ParallelReference(_) => {}
541                }
542            }
543        }
544
545        // We never want a module to occur in multiple batches.
546
547        let mut extracted_shared_items = 0;
548        // Extract shared modules into separate batches
549        for i in 0..parallel_module_to_pre_batch.len() {
550            let (&module, batches) = parallel_module_to_pre_batch
551                .get_index(i)
552                .context("could not find parallel module to pre batch index")?;
553            if batches.len() > 1 {
554                // Create a new batch for the shared modules
555                let batches_with_item_index = batches
556                    .iter()
557                    .map(|&idx| {
558                        let batch_items = &pre_batches.batches[idx].items;
559                        let item_idx = batch_items
560                            .get_index_of(&PreBatchItem::ParallelModule(module))
561                            .context("could not find batch item index for parallel module")?;
562                        Ok((idx, item_idx))
563                    })
564                    .collect::<Result<Vec<_>>>()?;
565                let mut selected_items = 1;
566                fn get_item_at(
567                    pre_batches: &PreBatches,
568                    batch_idx: PreBatchIndex,
569                    item_idx: usize,
570                ) -> Option<&PreBatchItem> {
571                    pre_batches.batches[batch_idx].items.get_index(item_idx)
572                }
573                // Select more matching items that are equal in all batches that contain the shared
574                // module(s)
575                loop {
576                    if let Some(PreBatchItem::ParallelModule(next_module)) = get_item_at(
577                        &pre_batches,
578                        batches_with_item_index[0].0,
579                        batches_with_item_index[0].1 + selected_items,
580                    ) && parallel_module_to_pre_batch
581                        .get(next_module)
582                        .context("could not find pre batch for parallel module")?
583                        .len()
584                        == batches.len()
585                        && batches_with_item_index[1..]
586                            .iter()
587                            .all(|&(batch_idx, item_idx)| {
588                                get_item_at(&pre_batches, batch_idx, item_idx + selected_items)
589                                    == Some(&PreBatchItem::ParallelModule(*next_module))
590                            })
591                    {
592                        selected_items += 1;
593                        continue;
594                    }
595                    break;
596                }
597                extracted_shared_items += selected_items;
598
599                // Check if a batch is completely selected. In that case we can replace all other
600                // occurrences with a reference to that batch
601                let exact_match = batches_with_item_index
602                    .iter()
603                    .find(|&&(batch_idx, item_idx)| {
604                        item_idx == 0
605                            && pre_batches.batches[batch_idx].items.len() == selected_items
606                    });
607                if let Some(&(exact_match, _)) = exact_match {
608                    // Replace all other occurrences with a reference to the exact match
609                    for &(batch_index, item_start) in batches_with_item_index.iter() {
610                        if batch_index != exact_match {
611                            pre_batches.batches[batch_index].items.splice(
612                                item_start..item_start + selected_items,
613                                std::iter::once(PreBatchItem::ParallelReference(exact_match)),
614                            );
615                        }
616                    }
617                    for item in pre_batches.batches[exact_match].items.iter() {
618                        if let PreBatchItem::ParallelModule(module) = item {
619                            parallel_module_to_pre_batch
620                                .get_mut(module)
621                                .context("could not find pre batch for parallel module")?
622                                .clear();
623                        }
624                    }
625                } else {
626                    // Create a new batch of the shared part and replace all occurrences with a
627                    // reference to that batch
628                    let first_batch_index = batches_with_item_index[0].0;
629                    let first_batch_item_index = batches_with_item_index[0].1;
630                    let new_batch_index = pre_batches.batches.len();
631                    let mut new_batch =
632                        PreBatch::new(pre_batches.batches[first_batch_index].chunk_groups.clone());
633                    new_batch
634                        .items
635                        .extend(pre_batches.batches[first_batch_index].items.splice(
636                            first_batch_item_index..first_batch_item_index + selected_items,
637                            std::iter::once(PreBatchItem::ParallelReference(new_batch_index)),
638                        ));
639                    for item in new_batch.items.iter() {
640                        if let PreBatchItem::ParallelModule(module) = item {
641                            parallel_module_to_pre_batch
642                                .get_mut(module)
643                                .context("could not find pre batch for parallel module")?
644                                .clear();
645                        }
646                    }
647                    pre_batches.batches.push(new_batch);
648                    for &(batch_index, item_start) in batches_with_item_index[1..].iter() {
649                        pre_batches.batches[batch_index].items.splice(
650                            item_start..item_start + selected_items,
651                            std::iter::once(PreBatchItem::ParallelReference(new_batch_index)),
652                        );
653                    }
654                }
655            }
656        }
657        span.record("extracted_shared_items", extracted_shared_items);
658
659        // Now every module is only in one batch
660
661        let mut edges_count = 0;
662
663        // Since batches can only have references followed by a list of parallel chunkable modules,
664        // we need to split batches that have modules before references.
665        for i in 0..pre_batches.batches.len() {
666            let items = take(&mut pre_batches.batches[i].items);
667            let mut new_items =
668                FxIndexSet::with_capacity_and_hasher(items.len(), Default::default());
669            enum Mode {
670                ParallelChunkableModule,
671                Other,
672            }
673            let mut mode = Mode::Other;
674            for item in items {
675                let chunkable_module = if let PreBatchItem::ParallelModule(module) = &item {
676                    ResolvedVc::try_downcast::<Box<dyn ChunkableModule>>(*module)
677                } else {
                    None
                };
                // Normalize: a parallel module that is not chunkable cannot live inside a
                // batch, so demote it to a synthetic non-parallel edge and track it as a
                // single-module entry instead.
                let item = if let PreBatchItem::ParallelModule(module) = item {
                    if chunkable_module.is_some() {
                        PreBatchItem::ParallelModule(module)
                    } else {
                        pre_batches.single_module_entries.insert(module);
                        PreBatchItem::NonParallelEdge(
                            ChunkingType::Parallel {
                                inherit_async: false,
                                hoisted: false,
                            },
                            module,
                        )
                    }
                } else {
                    item
                };
                // State machine over the item sequence: runs of parallel chunkable modules
                // stay together; anything else either joins the current run (Mode::Other)
                // or forces the accumulated run to be split off into its own batch.
                match (&mode, chunkable_module) {
                    (_, Some(_)) => {
                        mode = Mode::ParallelChunkableModule;
                        new_items.insert(item);
                    }
                    (Mode::Other, _) => {
                        // Non-chunkable item while not in a parallel run: it becomes an
                        // edge in the final graph, so count it now.
                        edges_count += 1;
                        new_items.insert(item);
                    }
                    (Mode::ParallelChunkableModule, _) => {
                        // Split the batch
                        // Move the accumulated parallel run into a fresh pre-batch that
                        // shares this batch's chunk groups, and reference it from here.
                        let idx = pre_batches.batches.len();
                        let mut new_batch =
                            PreBatch::new(pre_batches.batches[i].chunk_groups.clone());
                        new_batch.items.extend(new_items.drain(..));
                        pre_batches.batches.push(new_batch);
                        // The ParallelReference to the split-off batch is one graph edge.
                        edges_count += 1;
                        new_items.insert(PreBatchItem::ParallelReference(idx));
                        if chunkable_module.is_some() {
                            new_items.insert(item);
                        } else {
                            // The non-parallel item itself is also an edge.
                            edges_count += 1;
                            mode = Mode::Other;
                            new_items.insert(item);
                        }
                    }
                }
            }
            pre_batches.batches[i].items = new_items;
        }
        span.record("pre_batches", pre_batches.batches.len());

        // Now batches are in the correct shape. We can create the real batches and the graph.

        // Create the graph
        // Capacity is exact: one node per pre-batch plus one per single-module entry,
        // and `edges_count` edges — verified by the debug_asserts after construction.
        let mut graph: DiGraph<ModuleOrBatch, ModuleBatchesGraphEdge, u32> =
            petgraph::graph::DiGraph::with_capacity(
                pre_batches.batches.len() + pre_batches.single_module_entries.len(),
                edges_count,
            );

        // Create the Vc<ModuleBatch> instances
        // Per pre-batch: 0 parallel modules -> placeholder None node, 1 -> a plain
        // Module node (no batch wrapper needed), 2+ -> a real ModuleBatch.
        let batches = pre_batches
            .batches
            .iter_mut()
            .enumerate()
            .map(async |(i, pre_batch)| {
                let mut modules = pre_batch.items.iter().filter_map(|item| {
                    if let PreBatchItem::ParallelModule(module) = item {
                        ResolvedVc::try_downcast(*module)
                    } else {
                        None
                    }
                });
                let Some(first) = modules.next() else {
                    return Ok(ModuleOrBatch::None(i));
                };
                if let Some(second) = modules.next() {
                    let batch = ModuleBatch::new(
                        [first, second]
                            .into_iter()
                            .chain(modules)
                            .map(|m| *m)
                            .collect::<Vec<_>>(),
                        Some(pre_batch.chunk_groups.clone()),
                    );
                    Ok(ModuleOrBatch::Batch(batch.to_resolved().await?))
                } else {
                    Ok(ModuleOrBatch::Module(ResolvedVc::upcast(first)))
                }
            })
            .try_join()
            .await?;

        // Create the batch groups by grouping batches with the same chunk groups
        // Keys are pre-hashed chunk-group sets so grouping avoids rehashing the
        // (potentially large) bitmap sets on every map operation.
        let mut batch_groups: FxHashMap<_, Vec<_>> = FxHashMap::default();
        for (i, pre_batch) in pre_batches.batches.iter().enumerate() {
            let key = BuildHasherDefault::<FxHasher>::default().prehash(&pre_batch.chunk_groups);
            let batch = batches[i];
            batch_groups.entry(key).or_default().push(batch);
        }
        for &module in &pre_batches.single_module_entries {
            let chunk_groups = module_chunk_groups
                .get(&module)
                .context("all modules need to have chunk group info")?;
            let key = BuildHasherDefault::<FxHasher>::default().prehash(chunk_groups);
            batch_groups
                .entry(key)
                .or_default()
                .push(ModuleOrBatch::Module(module));
        }

        // Create the batch group instances
        // A group with a single item gets no ModuleBatchGroup; otherwise every item is
        // mapped to its shared group in the resulting lookup table.
        let batch_groups = batch_groups
            .into_iter()
            .map(async |(key, items)| {
                if items.len() == 1 {
                    Ok(Either::Left(std::iter::empty()))
                } else {
                    let batch_group = ModuleBatchGroup::new(items.clone(), (*key).clone())
                        .to_resolved()
                        .await?;
                    Ok(Either::Right(
                        items.into_iter().map(move |item| (item, batch_group)),
                    ))
                }
            })
            .try_join()
            .await?
            .into_iter()
            .flatten()
            .collect::<FxHashMap<_, _>>();

        // Insert batches into the graph and store the NodeIndices
        // Insertion order matches `pre_batches.batches`, so `batch_indices[i]` is the
        // node for pre-batch `i` — relied on by the edge-insertion loop below.
        let mut batches_count = 0;
        let mut modules_count = 0;
        let batch_indices = batches
            .into_iter()
            .map(|batch| {
                match &batch {
                    ModuleOrBatch::Batch(_) => batches_count += 1,
                    ModuleOrBatch::Module(_) => modules_count += 1,
                    ModuleOrBatch::None(_) => {}
                }
                graph.add_node(batch)
            })
            .collect::<Vec<_>>();

        // Also insert single modules into the graph and store the NodeIndices
        // Same positional contract: `single_module_indices[j]` corresponds to index `j`
        // in the `single_module_entries` index set (looked up via `get_index_of`).
        let single_module_indices = pre_batches
            .single_module_entries
            .iter()
            .map(|module| graph.add_node(ModuleOrBatch::Module(*module)))
            .collect::<Vec<_>>();

        span.record("batches", batches_count);
        modules_count += pre_batches.single_module_entries.len();
        span.record("modules", modules_count);
        span.record("edges", edges_count);

        // Add all the edges to the graph
        for (i, pre_batch) in pre_batches.batches.into_iter().enumerate() {
            let index = batch_indices[i];
            let items = pre_batch.items;
            for item in items {
                match item {
                    PreBatchItem::ParallelReference(idx) => {
                        // References between batches become plain parallel edges without
                        // a specific module attached.
                        graph.add_edge(
                            index,
                            batch_indices[idx],
                            ModuleBatchesGraphEdge {
                                ty: ChunkingType::Parallel {
                                    inherit_async: false,
                                    hoisted: false,
                                },
                                module: None,
                            },
                        );
                    }
                    PreBatchItem::NonParallelEdge(ty, module) => {
                        // The target is either an entry batch or a single-module node.
                        if let Some(batch) = pre_batches.entries.get(&module).copied() {
                            graph.add_edge(
                                index,
                                batch_indices[batch],
                                ModuleBatchesGraphEdge {
                                    ty,
                                    module: Some(module),
                                },
                            );
                            continue;
                        }
                        let idx = pre_batches
                            .single_module_entries
                            .get_index_of(&module)
                            .context("could not find single module entry index")?;
                        let idx = single_module_indices[idx];
                        graph.add_edge(
                            index,
                            idx,
                            ModuleBatchesGraphEdge {
                                ty,
                                module: Some(module),
                            },
                        );
                    }
                    // Parallel modules live inside the batch node itself; no edge needed.
                    PreBatchItem::ParallelModule(_) => {}
                }
            }
        }

        // Sanity-check that the pre-computed capacities were exact (no reallocation).
        debug_assert_eq!(graph.capacity().0, graph.node_count());
        debug_assert_eq!(graph.capacity().1, graph.edge_count());

        // Find the NodeIndices for our entries of the graph
        let mut entries = FxHashMap::default();
        for chunk_group in &chunk_group_info.chunk_groups {
            for module in chunk_group.entries() {
                if let Some(batch) = pre_batches.entries.get(&module).copied() {
                    entries.insert(module, batch_indices[batch]);
                    continue;
                }
                let idx = pre_batches
                    .single_module_entries
                    .get_index_of(&module)
                    .context("could not find single module entry index")?;
                let idx = single_module_indices[idx];
                entries.insert(module, idx);
            }
        }

        Ok(ModuleBatchesGraph {
            graph: TracedDiGraph(graph),
            entries,
            batch_groups,
            ordered_entries,
        }
        .cell())
    }
    .instrument(outer_span)
    .await
}