Skip to main content

turbo_tasks/graph/
graph_traversal.rs

1use std::future::Future;
2
3use anyhow::Result;
4use futures::{StreamExt, stream::FuturesUnordered};
5use rustc_hash::FxHashSet;
6
7use super::{Visit, VisitControlFlow, graph_store::GraphStore, with_future::With};
8
/// A set of nodes (e.g. modules) that were already visited and should be skipped, together with
/// their subgraphs.
#[derive(Clone, Debug)]
pub struct VisitedNodes<T>(pub FxHashSet<T>);

// Implemented by hand instead of derived: `#[derive(Default)]` would add a `T: Default` bound,
// even though an empty set can be created for any node type `T`.
impl<T> Default for VisitedNodes<T> {
    fn default() -> Self {
        Self(FxHashSet::default())
    }
}
12
13/// [`GraphTraversal`] is a utility type that can be used to traverse a graph of
14/// nodes, where each node can have a variable number of outgoing edges.
15///
16/// The traversal is done in parallel, and the order of the nodes in the traversal
17/// result is determined by the [`GraphStore`] parameter.
18pub trait GraphTraversal: GraphStore + Sized {
19    fn visit<VisitImpl, Impl>(
20        self,
21        root_nodes: impl IntoIterator<Item = Self::Node>,
22        visit: VisitImpl,
23    ) -> impl Future<Output = GraphTraversalResult<Result<Self>>> + Send
24    where
25        VisitImpl: Visit<Self::Node, Self::Edge, Impl> + Send,
26        Impl: Send;
27}
28
29impl<Store> GraphTraversal for Store
30where
31    Store: GraphStore,
32{
33    /// Visits the graph starting from the given `roots`, and returns a future
34    /// that will resolve to the traversal result.
35    fn visit<VisitImpl, Impl>(
36        mut self,
37        root_nodes: impl IntoIterator<Item = Self::Node>,
38        mut visit: VisitImpl,
39    ) -> impl Future<Output = GraphTraversalResult<Result<Self>>> + Send
40    where
41        VisitImpl: Visit<Self::Node, Self::Edge, Impl> + Send,
42        Impl: Send,
43    {
44        let mut futures = FuturesUnordered::new();
45
46        // Populate `futures` with all the roots, `root_nodes` isn't required to be `Send`, so this
47        // has to happen outside of the future. We could require `root_nodes` to be `Send` in the
48        // future.
49        for node in root_nodes {
50            match visit.visit(&node, None) {
51                VisitControlFlow::Continue => {
52                    if let Some(handle) = self.try_enter(&node) {
53                        let span = visit.span(&node, None);
54                        futures.push(With::new(visit.edges(&node), span, handle));
55                    }
56                    self.insert(None, node);
57                }
58                VisitControlFlow::Skip => {
59                    self.insert(None, node);
60                }
61                VisitControlFlow::Exclude => {
62                    // do nothing
63                }
64            }
65        }
66
67        async move {
68            let mut result = Ok(());
69            loop {
70                match futures.next().await {
71                    Some((parent_node, span, Ok(edges))) => {
72                        let _guard = span.enter();
73                        for (node, edge) in edges {
74                            match visit.visit(&node, Some(&edge)) {
75                                VisitControlFlow::Continue => {
76                                    if let Some(handle) = self.try_enter(&node) {
77                                        let span = visit.span(&node, Some(&edge));
78                                        let edges_future = visit.edges(&node);
79                                        futures.push(With::new(edges_future, span, handle));
80                                    }
81                                    self.insert(Some((&parent_node, edge)), node);
82                                }
83                                VisitControlFlow::Skip => {
84                                    self.insert(Some((&parent_node, edge)), node);
85                                }
86                                VisitControlFlow::Exclude => {
87                                    // do nothing
88                                }
89                            }
90                        }
91                    }
92                    Some((_, _, Err(err))) => {
93                        result = Err(err);
94                    }
95                    None => {
96                        return GraphTraversalResult::Completed(result.map(|()| self));
97                    }
98                }
99            }
100        }
101    }
102}
103
/// Result of [`GraphTraversal::visit`].
pub enum GraphTraversalResult<T> {
    /// The traversal ran until no work was left; carries its outcome.
    Completed(T),
}

impl<T> GraphTraversalResult<T> {
    /// Consumes the result and returns the value carried by [`GraphTraversalResult::Completed`].
    pub fn completed(self) -> T {
        // Irrefutable destructure: `Completed` is the only variant.
        let Self::Completed(value) = self;
        value
    }
}