// turbopack_core/module_graph/module_batches.rs

1use std::{
2    collections::{VecDeque, hash_map::Entry},
3    hash::BuildHasherDefault,
4    mem::take,
5};
6
7use anyhow::{Context, Result, bail};
8use either::Either;
9use petgraph::graph::{DiGraph, EdgeIndex, NodeIndex};
10use rustc_hash::{FxHashMap, FxHashSet, FxHasher};
11use serde::{Deserialize, Serialize};
12use tracing::Instrument;
13use turbo_prehash::BuildHasherExt;
14use turbo_tasks::{
15    FxIndexMap, FxIndexSet, NonLocalValue, ResolvedVc, TaskInput, TryJoinIterExt, ValueToString,
16    Vc, trace::TraceRawVcs,
17};
18
19use crate::{
20    chunk::{ChunkableModule, ChunkingType},
21    module::Module,
22    module_graph::{
23        GraphTraversalAction, ModuleGraph,
24        chunk_group_info::{ChunkGroupInfo, ChunkGroupKey, RoaringBitmapWrapper},
25        module_batch::{ModuleBatch, ModuleBatchGroup, ModuleOrBatch},
26        traced_di_graph::{TracedDiGraph, iter_neighbors_rev},
27    },
28};
29#[turbo_tasks::value]
30#[derive(Debug, Clone, Default, TaskInput, Hash)]
31pub struct BatchingConfig {
32    /// Use a heuristic based on the module path to create batches. It aims for batches of a good
33    /// size.
34    pub use_heuristic: bool,
35}
36
37#[turbo_tasks::value_impl]
38impl BatchingConfig {
39    #[turbo_tasks::function]
40    pub fn new(config: BatchingConfig) -> Vc<Self> {
41        config.cell()
42    }
43}
44
45#[derive(Debug, Clone, Serialize, Deserialize, TraceRawVcs, NonLocalValue)]
46pub struct ModuleBatchesGraphEdge {
47    pub ty: ChunkingType,
48    pub module: Option<ResolvedVc<Box<dyn Module>>>,
49}
50
51type EntriesList = FxIndexSet<ResolvedVc<Box<dyn Module>>>;
52
53#[turbo_tasks::value(cell = "new", eq = "manual", into = "new")]
54pub struct ModuleBatchesGraph {
55    graph: TracedDiGraph<ModuleOrBatch, ModuleBatchesGraphEdge>,
56
57    // NodeIndex isn't necessarily stable (because of swap_remove), but we never remove nodes.
58    //
59    // HashMaps have nondeterministic order, but this map is only used for lookups and not
60    // iteration.
61    //
62    // This contains Vcs, but they are already contained in the graph, so no need to trace this.
63    #[turbo_tasks(trace_ignore)]
64    entries: FxHashMap<ResolvedVc<Box<dyn Module>>, NodeIndex>,
65    batch_groups: FxHashMap<ModuleOrBatch, ResolvedVc<ModuleBatchGroup>>,
66
67    /// For chunk groups where the postorder of entries is different than the order of the
68    /// `ChunkGroup::entries()` this contains Some with the postorder list of entries of that chunk
69    /// group. The index in this list corresponds to the index in the
70    /// chunk_group_info.chunk_groups.
71    ordered_entries: Vec<Option<EntriesList>>,
72}
73
74impl ModuleBatchesGraph {
75    pub async fn get_entry_index(&self, entry: ResolvedVc<Box<dyn Module>>) -> Result<NodeIndex> {
76        let Some(entry) = self.entries.get(&entry) else {
77            bail!(
78                "Entry {} is not in graph (possible entries: {:#?})",
79                entry.ident().to_string().await?,
80                self.entries
81                    .keys()
82                    .map(|e| e.ident().to_string())
83                    .try_join()
84                    .await?
85            );
86        };
87        Ok(*entry)
88    }
89
90    pub fn get_ordered_entries<'l>(
91        &'l self,
92        chunk_group_info: &'l ChunkGroupInfo,
93        idx: usize,
94    ) -> impl Iterator<Item = ResolvedVc<Box<dyn Module>>> + 'l {
95        if let Some(ordered_entries) = self
96            .ordered_entries
97            .get(idx)
98            .as_ref()
99            .and_then(|o| o.as_ref())
100        {
101            if let Some(chunk_group) = chunk_group_info.chunk_groups.get_index(idx) {
102                debug_assert_eq!(ordered_entries.len(), chunk_group.entries_count());
103            }
104            Either::Left(Either::Left(ordered_entries.iter().copied()))
105        } else if let Some(chunk_group) = chunk_group_info.chunk_groups.get_index(idx) {
106            Either::Right(chunk_group.entries())
107        } else {
108            Either::Left(Either::Right(std::iter::empty()))
109        }
110    }
111
112    pub fn get_batch_group(
113        &self,
114        module_or_batch: &ModuleOrBatch,
115    ) -> Option<ResolvedVc<ModuleBatchGroup>> {
116        self.batch_groups.get(module_or_batch).copied()
117    }
118
119    pub async fn get_entry(&self, entry: ResolvedVc<Box<dyn Module>>) -> Result<ModuleOrBatch> {
120        let entry = self.get_entry_index(entry).await?;
121        Ok(*self.graph.node_weight(entry).unwrap())
122    }
123
124    // Clippy complains but there's a type error without the bound
125    #[allow(clippy::implied_bounds_in_impls)]
126    /// Traverses all reachable edges in topological order. The preorder visitor can be used to
127    /// forward state down the graph, and to skip subgraphs
128    ///
129    /// Use this to collect batches/modules in evaluation order.
130    ///
131    /// Target nodes can be revisited (once per incoming edge).
132    /// Edges are traversed in normal order, so should correspond to reference order.
133    ///
134    /// * `entry` - The entry module to start the traversal from
135    /// * `state` - The state to be passed to the visitors
136    /// * `visit_preorder` - Called before visiting the children of a node.
137    ///    - Receives: (originating &ModuleBatchesGraphNode, edge &ChunkingType), target
138    ///      &ModuleBatchesGraphNode, state &S
139    ///    - Can return [GraphTraversalAction]s to control the traversal
140    /// * `visit_postorder` - Called after visiting the children of a node. Return
141    ///    - Receives: (originating &ModuleBatchesGraphNode, edge &ChunkingType), target
142    ///      &ModuleBatchesGraphNode, state &S
143    pub fn traverse_edges_from_entries_topological<'a, S>(
144        &'a self,
145        entries: impl IntoIterator<
146            Item = NodeIndex,
147            IntoIter = impl Iterator<Item = NodeIndex> + DoubleEndedIterator,
148        >,
149        state: &mut S,
150        mut visit_preorder: impl FnMut(
151            Option<(&'a ModuleOrBatch, &'a ModuleBatchesGraphEdge)>,
152            &'a ModuleOrBatch,
153            &mut S,
154        ) -> Result<GraphTraversalAction>,
155        mut visit_postorder: impl FnMut(
156            Option<(&'a ModuleOrBatch, &'a ModuleBatchesGraphEdge)>,
157            &'a ModuleOrBatch,
158            &mut S,
159        ),
160    ) -> Result<()> {
161        let graph = &self.graph;
162
163        enum ReverseTopologicalPass {
164            Visit,
165            ExpandAndVisit,
166        }
167
168        let entries = entries.into_iter();
169        #[allow(clippy::type_complexity)] // This is a temporary internal structure
170        let mut stack: Vec<(
171            ReverseTopologicalPass,
172            Option<(NodeIndex, EdgeIndex)>,
173            NodeIndex,
174        )> = entries
175            .rev()
176            .map(|e| (ReverseTopologicalPass::ExpandAndVisit, None, e))
177            .collect();
178        let mut expanded = FxHashSet::default();
179        while let Some((pass, parent, current)) = stack.pop() {
180            let parent_arg = parent.map(|(node, edge)| {
181                (
182                    graph.node_weight(node).unwrap(),
183                    graph.edge_weight(edge).unwrap(),
184                )
185            });
186            match pass {
187                ReverseTopologicalPass::Visit => {
188                    let current_node = graph.node_weight(current).unwrap();
189                    visit_postorder(parent_arg, current_node, state);
190                }
191                ReverseTopologicalPass::ExpandAndVisit => {
192                    let current_node = graph.node_weight(current).unwrap();
193                    let action = visit_preorder(parent_arg, current_node, state)?;
194                    if action == GraphTraversalAction::Exclude {
195                        continue;
196                    }
197                    stack.push((ReverseTopologicalPass::Visit, parent, current));
198                    if action == GraphTraversalAction::Continue && expanded.insert(current) {
199                        stack.extend(iter_neighbors_rev(graph, current).map(|(edge, child)| {
200                            (
201                                ReverseTopologicalPass::ExpandAndVisit,
202                                Some((current, edge)),
203                                child,
204                            )
205                        }));
206                    }
207                }
208            }
209        }
210
211        Ok(())
212    }
213}
214
/// Index of a pre batch within `PreBatches::batches`.
type PreBatchIndex = usize;
216
217#[derive(Hash, PartialEq, Eq, Clone, Debug)]
218enum PreBatchItem {
219    ParallelModule(ResolvedVc<Box<dyn Module>>),
220    ParallelReference(PreBatchIndex),
221    NonParallelEdge(ChunkingType, ResolvedVc<Box<dyn Module>>),
222}
223
224struct PreBatch {
225    items: FxIndexSet<PreBatchItem>,
226    chunk_groups: RoaringBitmapWrapper,
227}
228
229impl PreBatch {
230    fn new(chunk_groups: RoaringBitmapWrapper) -> Self {
231        Self {
232            items: FxIndexSet::default(),
233            chunk_groups,
234        }
235    }
236}
237
238struct TraversalState<'l> {
239    items: Vec<PreBatchItem>,
240    this: &'l mut PreBatches,
241}
242
243struct PreBatches {
244    boundary_modules: FxHashSet<ResolvedVc<Box<dyn Module>>>,
245    batches: Vec<PreBatch>,
246    entries: FxHashMap<ResolvedVc<Box<dyn Module>>, PreBatchIndex>,
247    single_module_entries: FxIndexSet<ResolvedVc<Box<dyn Module>>>,
248}
249
250impl PreBatches {
251    fn new() -> Self {
252        Self {
253            boundary_modules: FxHashSet::default(),
254            batches: Vec::new(),
255            entries: FxHashMap::default(),
256            single_module_entries: FxIndexSet::default(),
257        }
258    }
259
260    fn ensure_pre_batch_for_module(
261        &mut self,
262        module: ResolvedVc<Box<dyn Module>>,
263        chunk_group_info: &ChunkGroupInfo,
264        queue: &mut VecDeque<(ResolvedVc<Box<dyn Module>>, PreBatchIndex)>,
265    ) -> Result<PreBatchIndex> {
266        Ok(match self.entries.entry(module) {
267            Entry::Vacant(e) => {
268                let index = self.batches.len();
269                queue.push_back((module, index));
270                let chunk_groups = chunk_group_info
271                    .module_chunk_groups
272                    .get(&module)
273                    .context("all modules need to have chunk group info")?;
274                let batch = PreBatch::new(chunk_groups.clone());
275                self.batches.push(batch);
276                e.insert(index);
277                index
278            }
279            Entry::Occupied(e) => *e.get(),
280        })
281    }
282
283    async fn get_pre_batch_items(
284        &mut self,
285        entry: ResolvedVc<Box<dyn Module>>,
286        chunk_group_info: &ChunkGroupInfo,
287        module_graph: &ModuleGraph,
288        queue: &mut VecDeque<(ResolvedVc<Box<dyn Module>>, PreBatchIndex)>,
289    ) -> Result<Vec<PreBatchItem>> {
290        let mut state = TraversalState {
291            items: Vec::new(),
292            this: self,
293        };
294        let mut visited = FxHashSet::default();
295        module_graph
296            .traverse_edges_from_entries_topological(
297                std::iter::once(ResolvedVc::upcast(entry)),
298                &mut state,
299                |parent_info, node, state| {
300                    let ty = parent_info.map_or(
301                        &ChunkingType::Parallel {
302                            inherit_async: false,
303                            hoisted: false,
304                        },
305                        |(_, ty)| &ty.chunking_type,
306                    );
307                    let module = node.module;
308                    if !ty.is_parallel() {
309                        state.items.push(PreBatchItem::NonParallelEdge(
310                            ty.without_inherit_async(),
311                            module,
312                        ));
313                        return Ok(GraphTraversalAction::Exclude);
314                    }
315                    if visited.insert(module) {
316                        if parent_info.is_some() && state.this.boundary_modules.contains(&module) {
317                            let idx = state.this.ensure_pre_batch_for_module(
318                                module,
319                                chunk_group_info,
320                                queue,
321                            )?;
322                            state.items.push(PreBatchItem::ParallelReference(idx));
323                            return Ok(GraphTraversalAction::Exclude);
324                        }
325                        Ok(GraphTraversalAction::Continue)
326                    } else {
327                        Ok(GraphTraversalAction::Exclude)
328                    }
329                },
330                |_, node, state| {
331                    let item = PreBatchItem::ParallelModule(node.module);
332                    state.items.push(item);
333                    Ok(())
334                },
335            )
336            .await?;
337        Ok(state.items)
338    }
339}
340
341pub async fn compute_module_batches(
342    module_graph: Vc<ModuleGraph>,
343    _config: &BatchingConfig,
344) -> Result<Vc<ModuleBatchesGraph>> {
345    let outer_span = tracing::info_span!(
346        "compute module batches",
347        initial_pre_batch_items = tracing::field::Empty,
348        initial_pre_batches = tracing::field::Empty,
349        extracted_shared_items = tracing::field::Empty,
350        batches = tracing::field::Empty,
351        modules = tracing::field::Empty,
352        edges = tracing::field::Empty
353    );
354    let span = outer_span.clone();
355    async move {
356        let chunk_group_info = module_graph.chunk_group_info().await?;
357        let module_graph = module_graph.await?;
358
359        let mut pre_batches = PreBatches::new();
360
361        // Walk the module graph and mark all modules that are boundary modules (referenced from a
362        // different chunk group bitmap)
363        module_graph
364            .traverse_all_edges_unordered(|(parent, ty), node| {
365                let std::collections::hash_set::Entry::Vacant(entry) =
366                    pre_batches.boundary_modules.entry(node.module)
367                else {
368                    // Already a boundary module, can skip check
369                    return Ok(());
370                };
371                if ty.chunking_type.is_parallel() {
372                    let parent_chunk_groups = chunk_group_info
373                        .module_chunk_groups
374                        .get(&parent.module)
375                        .context("all modules need to have chunk group info")?;
376                    let chunk_groups = chunk_group_info
377                        .module_chunk_groups
378                        .get(&node.module)
379                        .context("all modules need to have chunk group info")?;
380                    if parent_chunk_groups != chunk_groups {
381                        // This is a boundary module
382                        entry.insert();
383                    }
384                } else {
385                    entry.insert();
386                }
387                Ok(())
388            })
389            .await?;
390
391        // All entries are boundary modules too
392        for chunk_group in &chunk_group_info.chunk_groups {
393            for entry in chunk_group.entries() {
394                pre_batches.boundary_modules.insert(entry);
395            }
396        }
397
398        // Pre batches would be incorrect with cycles, so we need to opt-out of pre batches for
399        // cycles that include boundary modules
400        module_graph
401            .traverse_cycles(
402                |ref_data| ref_data.chunking_type.is_parallel(),
403                |cycle| {
404                    if cycle
405                        .iter()
406                        .any(|node| pre_batches.boundary_modules.contains(&node.module))
407                    {
408                        pre_batches
409                            .boundary_modules
410                            .extend(cycle.iter().map(|node| node.module));
411                    }
412                },
413            )
414            .await?;
415
416        let mut queue: VecDeque<(ResolvedVc<Box<dyn Module>>, PreBatchIndex)> = VecDeque::new();
417
418        let mut chunk_group_indices_with_merged_children = FxHashSet::default();
419
420        // Start with the entries
421        for chunk_group in &chunk_group_info.chunk_groups {
422            for entry in chunk_group.entries() {
423                if let Some(chunkable_module) = ResolvedVc::try_downcast(entry) {
424                    pre_batches.ensure_pre_batch_for_module(
425                        chunkable_module,
426                        &chunk_group_info,
427                        &mut queue,
428                    )?;
429                } else {
430                    pre_batches.single_module_entries.insert(entry);
431                }
432            }
433            if let Some(parent) = chunk_group.get_merged_parent() {
434                chunk_group_indices_with_merged_children.insert(parent);
435            }
436        }
437
438        let mut initial_pre_batch_items = 0;
439        // Fill all pre batches
440        while let Some((chunkable_module, idx)) = queue.pop_front() {
441            let items = pre_batches
442                .get_pre_batch_items(
443                    chunkable_module,
444                    &chunk_group_info,
445                    &module_graph,
446                    &mut queue,
447                )
448                .await?;
449            initial_pre_batch_items += items.len();
450            let batch = &mut pre_batches.batches[idx];
451            batch.items.extend(items);
452        }
453        span.record("initial_pre_batch_items", initial_pre_batch_items);
454        span.record("initial_pre_batches", pre_batches.batches.len());
455
456        // Figure out the order of all merged groups
457        let mut ordered_entries: Vec<Option<EntriesList>> =
458            vec![None; chunk_group_info.chunk_groups.len()];
459        for (i, chunk_group) in chunk_group_info.chunk_groups.iter().enumerate() {
460            if !chunk_group_indices_with_merged_children.contains(&i) {
461                continue;
462            }
463            let mut merged_modules: FxHashMap<ChunkingType, FxIndexSet<_>> = FxHashMap::default();
464            let mut stack = ordered_entries[i]
465                .as_ref()
466                .map_or_else(
467                    || Either::Left(chunk_group.entries()),
468                    |v| Either::Right(v.iter().copied()),
469                )
470                .filter_map(|module| {
471                    if let Some(chunkable_module) = ResolvedVc::try_downcast(module) {
472                        let idx = *pre_batches.entries.get(&chunkable_module).unwrap();
473                        Some((idx, 0))
474                    } else {
475                        None
476                    }
477                })
478                .collect::<Vec<_>>();
479            stack.reverse();
480            let mut visited = FxHashSet::default();
481            while let Some((idx, mut pos)) = stack.pop() {
482                let batch = &pre_batches.batches[idx];
483                while let Some(item) = batch.items.get_index(pos) {
484                    match item {
485                        PreBatchItem::ParallelModule(_) => {}
486                        PreBatchItem::ParallelReference(other_idx) => {
487                            if visited.insert(*other_idx) {
488                                stack.push((idx, pos + 1));
489                                stack.push((*other_idx, 0));
490                                break;
491                            }
492                        }
493                        PreBatchItem::NonParallelEdge(chunking_type, module) => {
494                            if chunking_type.is_merged() {
495                                merged_modules
496                                    .entry(chunking_type.clone())
497                                    .or_default()
498                                    .insert(*module);
499                            }
500                        }
501                    }
502                    pos += 1;
503                }
504            }
505            if !merged_modules.is_empty() {
506                for (ty, merged_modules) in merged_modules {
507                    let chunk_group_key = match ty {
508                        ChunkingType::Isolated {
509                            merge_tag: Some(merge_tag),
510                            ..
511                        } => ChunkGroupKey::IsolatedMerged {
512                            parent: i.into(),
513                            merge_tag: merge_tag.clone(),
514                        },
515                        ChunkingType::Shared {
516                            merge_tag: Some(merge_tag),
517                            ..
518                        } => ChunkGroupKey::SharedMerged {
519                            parent: i.into(),
520                            merge_tag: merge_tag.clone(),
521                        },
522                        _ => unreachable!(),
523                    };
524                    let idx = chunk_group_info
525                        .chunk_group_keys
526                        .get_index_of(&chunk_group_key)
527                        .unwrap();
528                    ordered_entries[idx] = Some(merged_modules);
529                }
530            }
531        }
532
533        // Create a map of parallel module to the batches they are contained in.
534        let mut parallel_module_to_pre_batch: FxIndexMap<_, Vec<PreBatchIndex>> =
535            FxIndexMap::default();
536
537        // Fill the map and also fill up the single_module_entries
538        for (idx, pre_batch) in pre_batches.batches.iter().enumerate() {
539            for item in &pre_batch.items {
540                match item {
541                    PreBatchItem::ParallelModule(module) => {
542                        parallel_module_to_pre_batch
543                            .entry(*module)
544                            .or_default()
545                            .push(idx);
546                    }
547                    PreBatchItem::NonParallelEdge(_, module) => {
548                        if let Some(chunkable_module) = ResolvedVc::try_downcast(*module) {
549                            if !pre_batches.entries.contains_key(&chunkable_module) {
550                                pre_batches.single_module_entries.insert(*module);
551                            }
552                        } else {
553                            pre_batches.single_module_entries.insert(*module);
554                        }
555                    }
556                    PreBatchItem::ParallelReference(_) => {}
557                }
558            }
559        }
560
561        // We never want a module to occur in multiple batches.
562
563        let mut extracted_shared_items = 0;
564        // Extract shared modules into separate batches
565        for i in 0..parallel_module_to_pre_batch.len() {
566            let (&module, batches) = parallel_module_to_pre_batch.get_index(i).unwrap();
567            if batches.len() > 1 {
568                // Create a new batch for the shared modules
569                let batches_with_item_index = batches
570                    .iter()
571                    .map(|&idx| {
572                        let batch_items = &pre_batches.batches[idx].items;
573                        let item_idx = batch_items
574                            .get_index_of(&PreBatchItem::ParallelModule(module))
575                            .unwrap();
576                        (idx, item_idx)
577                    })
578                    .collect::<Vec<_>>();
579                let mut selected_items = 1;
580                fn get_item_at(
581                    pre_batches: &PreBatches,
582                    batch_idx: PreBatchIndex,
583                    item_idx: usize,
584                ) -> Option<&PreBatchItem> {
585                    pre_batches.batches[batch_idx].items.get_index(item_idx)
586                }
587                // Select more matching items that are equal in all batches that contain the shared
588                // module(s)
589                loop {
590                    if let Some(PreBatchItem::ParallelModule(next_module)) = get_item_at(
591                        &pre_batches,
592                        batches_with_item_index[0].0,
593                        batches_with_item_index[0].1 + selected_items,
594                    ) && parallel_module_to_pre_batch.get(next_module).unwrap().len()
595                        == batches.len()
596                        && batches_with_item_index[1..]
597                            .iter()
598                            .all(|&(batch_idx, item_idx)| {
599                                get_item_at(&pre_batches, batch_idx, item_idx + selected_items)
600                                    == Some(&PreBatchItem::ParallelModule(*next_module))
601                            })
602                    {
603                        selected_items += 1;
604                        continue;
605                    }
606                    break;
607                }
608                extracted_shared_items += selected_items;
609
610                // Check if a batch is completely selected. In that case we can replace all other
611                // occurrences with a reference to that batch
612                let exact_match = batches_with_item_index
613                    .iter()
614                    .find(|&&(batch_idx, item_idx)| {
615                        item_idx == 0
616                            && pre_batches.batches[batch_idx].items.len() == selected_items
617                    });
618                if let Some(&(exact_match, _)) = exact_match {
619                    // Replace all other occurrences with a reference to the exact match
620                    for &(batch_index, item_start) in batches_with_item_index.iter() {
621                        if batch_index != exact_match {
622                            pre_batches.batches[batch_index].items.splice(
623                                item_start..item_start + selected_items,
624                                std::iter::once(PreBatchItem::ParallelReference(exact_match)),
625                            );
626                        }
627                    }
628                    for item in pre_batches.batches[exact_match].items.iter() {
629                        if let PreBatchItem::ParallelModule(module) = item {
630                            parallel_module_to_pre_batch
631                                .get_mut(module)
632                                .unwrap()
633                                .clear();
634                        }
635                    }
636                } else {
637                    // Create a new batch of the shared part and replace all occurrences with a
638                    // reference to that batch
639                    let first_batch_index = batches_with_item_index[0].0;
640                    let first_batch_item_index = batches_with_item_index[0].1;
641                    let new_batch_index = pre_batches.batches.len();
642                    let mut new_batch =
643                        PreBatch::new(pre_batches.batches[first_batch_index].chunk_groups.clone());
644                    new_batch
645                        .items
646                        .extend(pre_batches.batches[first_batch_index].items.splice(
647                            first_batch_item_index..first_batch_item_index + selected_items,
648                            std::iter::once(PreBatchItem::ParallelReference(new_batch_index)),
649                        ));
650                    for item in new_batch.items.iter() {
651                        if let PreBatchItem::ParallelModule(module) = item {
652                            parallel_module_to_pre_batch
653                                .get_mut(module)
654                                .unwrap()
655                                .clear();
656                        }
657                    }
658                    pre_batches.batches.push(new_batch);
659                    for &(batch_index, item_start) in batches_with_item_index[1..].iter() {
660                        pre_batches.batches[batch_index].items.splice(
661                            item_start..item_start + selected_items,
662                            std::iter::once(PreBatchItem::ParallelReference(new_batch_index)),
663                        );
664                    }
665                }
666            }
667        }
668        span.record("extracted_shared_items", extracted_shared_items);
669
670        // Now every module is only in one batch
671
672        let mut edges_count = 0;
673
674        // Since batches can only have references followed by a list of parallel chunkable modules,
675        // we need to split batches that have modules before references.
676        for i in 0..pre_batches.batches.len() {
677            let items = take(&mut pre_batches.batches[i].items);
678            let mut new_items =
679                FxIndexSet::with_capacity_and_hasher(items.len(), Default::default());
680            enum Mode {
681                ParallelChunkableModule,
682                Other,
683            }
684            let mut mode = Mode::Other;
685            for item in items {
686                let chunkable_module = if let PreBatchItem::ParallelModule(module) = &item {
687                    ResolvedVc::try_downcast::<Box<dyn ChunkableModule>>(*module)
688                } else {
689                    None
690                };
691                let item = if let PreBatchItem::ParallelModule(module) = item {
692                    if chunkable_module.is_some() {
693                        PreBatchItem::ParallelModule(module)
694                    } else {
695                        pre_batches.single_module_entries.insert(module);
696                        PreBatchItem::NonParallelEdge(
697                            ChunkingType::Parallel {
698                                inherit_async: false,
699                                hoisted: false,
700                            },
701                            module,
702                        )
703                    }
704                } else {
705                    item
706                };
707                match (&mode, chunkable_module) {
708                    (_, Some(_)) => {
709                        mode = Mode::ParallelChunkableModule;
710                        new_items.insert(item);
711                    }
712                    (Mode::Other, _) => {
713                        edges_count += 1;
714                        new_items.insert(item);
715                    }
716                    (Mode::ParallelChunkableModule, _) => {
717                        // Split the batch
718                        let idx = pre_batches.batches.len();
719                        let mut new_batch =
720                            PreBatch::new(pre_batches.batches[i].chunk_groups.clone());
721                        new_batch.items.extend(new_items.drain(..));
722                        pre_batches.batches.push(new_batch);
723                        edges_count += 1;
724                        new_items.insert(PreBatchItem::ParallelReference(idx));
725                        if chunkable_module.is_some() {
726                            new_items.insert(item);
727                        } else {
728                            edges_count += 1;
729                            mode = Mode::Other;
730                            new_items.insert(item);
731                        }
732                    }
733                }
734            }
735            pre_batches.batches[i].items = new_items;
736        }
737        span.record("pre_batches", pre_batches.batches.len());
738
739        // Now batches are in the correct shape. We can create the real batches and the graph.
740
741        // Create the graph
742        let mut graph: DiGraph<ModuleOrBatch, ModuleBatchesGraphEdge, u32> =
743            petgraph::graph::DiGraph::with_capacity(
744                pre_batches.batches.len() + pre_batches.single_module_entries.len(),
745                edges_count,
746            );
747
748        // Create the Vc<ModuleBatch> instances
749        let batches = pre_batches
750            .batches
751            .iter_mut()
752            .enumerate()
753            .map(async |(i, pre_batch)| {
754                let mut modules = pre_batch.items.iter().filter_map(|item| {
755                    if let PreBatchItem::ParallelModule(module) = item {
756                        ResolvedVc::try_downcast(*module)
757                    } else {
758                        None
759                    }
760                });
761                let Some(first) = modules.next() else {
762                    return Ok(ModuleOrBatch::None(i));
763                };
764                if let Some(second) = modules.next() {
765                    let batch = ModuleBatch::new(
766                        [first, second]
767                            .into_iter()
768                            .chain(modules)
769                            .map(|m| *m)
770                            .collect::<Vec<_>>(),
771                        Some(pre_batch.chunk_groups.clone()),
772                    );
773                    Ok(ModuleOrBatch::Batch(batch.to_resolved().await?))
774                } else {
775                    Ok(ModuleOrBatch::Module(ResolvedVc::upcast(first)))
776                }
777            })
778            .try_join()
779            .await?;
780
781        // Create the batch groups by grouping batches with the same chunk groups
782        let mut batch_groups: FxHashMap<_, Vec<_>> = FxHashMap::default();
783        for (i, pre_batch) in pre_batches.batches.iter().enumerate() {
784            let key =
785                BuildHasherDefault::<FxHasher>::default().prehash(pre_batch.chunk_groups.clone());
786            let batch = batches[i];
787            batch_groups.entry(key).or_default().push(batch);
788        }
789        for &module in &pre_batches.single_module_entries {
790            let chunk_groups = chunk_group_info
791                .module_chunk_groups
792                .get(&module)
793                .context("all modules need to have chunk group info")?;
794            let key = BuildHasherDefault::<FxHasher>::default().prehash(chunk_groups.clone());
795            batch_groups
796                .entry(key)
797                .or_default()
798                .push(ModuleOrBatch::Module(module));
799        }
800
801        // Create the batch group instances
802        let batch_groups = batch_groups
803            .into_iter()
804            .map(async |(key, items)| {
805                if items.len() == 1 {
806                    Ok(Either::Left(std::iter::empty()))
807                } else {
808                    let batch_group = ModuleBatchGroup::new(items.clone(), (*key).clone())
809                        .to_resolved()
810                        .await?;
811                    Ok(Either::Right(
812                        items.into_iter().map(move |item| (item, batch_group)),
813                    ))
814                }
815            })
816            .try_join()
817            .await?
818            .into_iter()
819            .flatten()
820            .collect::<FxHashMap<_, _>>();
821
822        // Insert batches into the graph and store the NodeIndices
823        let mut batches_count = 0;
824        let mut modules_count = 0;
825        let batch_indices = batches
826            .into_iter()
827            .map(|batch| {
828                match &batch {
829                    ModuleOrBatch::Batch(_) => batches_count += 1,
830                    ModuleOrBatch::Module(_) => modules_count += 1,
831                    ModuleOrBatch::None(_) => {}
832                }
833                graph.add_node(batch)
834            })
835            .collect::<Vec<_>>();
836
837        // Also insert single modules into the graph and store the NodeIndices
838        let single_module_indices = pre_batches
839            .single_module_entries
840            .iter()
841            .map(|module| graph.add_node(ModuleOrBatch::Module(*module)))
842            .collect::<Vec<_>>();
843
844        span.record("batches", batches_count);
845        modules_count += pre_batches.single_module_entries.len();
846        span.record("modules", modules_count);
847        span.record("edges", edges_count);
848
849        // Add all the edges to the graph
850        for (i, pre_batch) in pre_batches.batches.into_iter().enumerate() {
851            let index = batch_indices[i];
852            let items = pre_batch.items;
853            for item in items {
854                match item {
855                    PreBatchItem::ParallelReference(idx) => {
856                        graph.add_edge(
857                            index,
858                            batch_indices[idx],
859                            ModuleBatchesGraphEdge {
860                                ty: ChunkingType::Parallel {
861                                    inherit_async: false,
862                                    hoisted: false,
863                                },
864                                module: None,
865                            },
866                        );
867                    }
868                    PreBatchItem::NonParallelEdge(ty, module) => {
869                        if let Some(chunkable_module) = ResolvedVc::try_downcast(module)
870                            && let Some(batch) = pre_batches.entries.get(&chunkable_module).copied()
871                        {
872                            graph.add_edge(
873                                index,
874                                batch_indices[batch],
875                                ModuleBatchesGraphEdge {
876                                    ty,
877                                    module: Some(module),
878                                },
879                            );
880                            continue;
881                        }
882                        let idx = pre_batches
883                            .single_module_entries
884                            .get_index_of(&module)
885                            .unwrap();
886                        let idx = single_module_indices[idx];
887                        graph.add_edge(
888                            index,
889                            idx,
890                            ModuleBatchesGraphEdge {
891                                ty,
892                                module: Some(module),
893                            },
894                        );
895                    }
896                    PreBatchItem::ParallelModule(_) => {}
897                }
898            }
899        }
900
901        debug_assert_eq!(graph.capacity().0, graph.node_count());
902        debug_assert_eq!(graph.capacity().1, graph.edge_count());
903
904        // Find the NodeIndices for our entries of the graph
905        let mut entries = FxHashMap::default();
906        for chunk_group in &chunk_group_info.chunk_groups {
907            for module in chunk_group.entries() {
908                if let Some(chunkable_module) = ResolvedVc::try_downcast(module)
909                    && let Some(batch) = pre_batches.entries.get(&chunkable_module).copied()
910                {
911                    entries.insert(module, batch_indices[batch]);
912                    continue;
913                }
914                let idx = pre_batches
915                    .single_module_entries
916                    .get_index_of(&module)
917                    .unwrap();
918                let idx = single_module_indices[idx];
919                entries.insert(module, idx);
920            }
921        }
922
923        Ok(ModuleBatchesGraph {
924            graph: TracedDiGraph(graph),
925            entries,
926            batch_groups,
927            ordered_entries,
928        }
929        .cell())
930    }
931    .instrument(outer_span)
932    .await
933}