turbo_tasks/graph/
graph_traversal.rs

1use std::future::Future;
2
3use anyhow::Result;
4use futures::{StreamExt, stream::FuturesUnordered};
5use rustc_hash::FxHashSet;
6
7use super::{
8    SkipDuplicates, Visit, VisitControlFlow,
9    graph_store::{GraphNode, GraphStore, SkipDuplicatesWithKey},
10    with_future::With,
11};
12
13/// A list of modules that were already visited and should be skipped (including their subgraphs).
14#[derive(Clone, Default, Debug)]
15pub struct VisitedNodes<T>(pub FxHashSet<T>);
16
17/// [`GraphTraversal`] is a utility type that can be used to traverse a graph of
18/// nodes, where each node can have a variable number of outgoing edges.
19///
20/// The traversal is done in parallel, and the order of the nodes in the traversal
21/// result is determined by the [`GraphStore`] parameter.
22pub trait GraphTraversal: GraphStore + Sized {
23    fn visit<VisitImpl, Abort, Impl>(
24        self,
25        root_edges: impl IntoIterator<Item = VisitImpl::Edge>,
26        visit: VisitImpl,
27    ) -> impl Future<Output = GraphTraversalResult<Result<Self>, Abort>> + Send
28    where
29        VisitImpl: Visit<Self::Node, Abort, Impl> + Send,
30        Abort: Send,
31        Impl: Send;
32
33    fn skip_duplicates(self) -> SkipDuplicates<Self>;
34    fn skip_duplicates_with_visited_nodes(
35        self,
36        visited: VisitedNodes<Self::Node>,
37    ) -> SkipDuplicates<Self>;
38
39    fn skip_duplicates_with_key<
40        Key: Send + Eq + std::hash::Hash + Clone,
41        KeyExtractor: Send + Fn(&Self::Node) -> &Key,
42    >(
43        self,
44        key_extractor: KeyExtractor,
45    ) -> SkipDuplicatesWithKey<Self, Key, KeyExtractor>;
46}
47
48impl<Store> GraphTraversal for Store
49where
50    Store: GraphStore,
51{
52    /// Visits the graph starting from the given `roots`, and returns a future
53    /// that will resolve to the traversal result.
54    fn visit<VisitImpl, Abort, Impl>(
55        mut self,
56        root_edges: impl IntoIterator<Item = VisitImpl::Edge>,
57        mut visit: VisitImpl,
58    ) -> impl Future<Output = GraphTraversalResult<Result<Self>, Abort>> + Send
59    where
60        VisitImpl: Visit<Self::Node, Abort, Impl> + Send,
61        Abort: Send,
62        Impl: Send,
63    {
64        let mut futures = FuturesUnordered::new();
65        let mut root_abort = None;
66
67        // Populate `futures` with all the roots, `root_edges` isn't required to be `Send`, so this
68        // has to happen outside of the future. We could require `root_edges` to be `Send` in the
69        // future.
70        for edge in root_edges {
71            match visit.visit(edge) {
72                VisitControlFlow::Continue(node) => {
73                    if let Some((parent_handle, node_ref)) = self.insert(None, GraphNode(node)) {
74                        let span = visit.span(node_ref);
75                        futures.push(With::new(visit.edges(node_ref), span, parent_handle));
76                    }
77                }
78                VisitControlFlow::Skip(node) => {
79                    self.insert(None, GraphNode(node));
80                }
81                VisitControlFlow::Abort(abort) => {
82                    // this must be returned inside the `async` block below so that it's part of the
83                    // returned future
84                    root_abort = Some(abort)
85                }
86            }
87        }
88
89        async move {
90            if let Some(abort) = root_abort {
91                return GraphTraversalResult::Aborted(abort);
92            }
93            loop {
94                match futures.next().await {
95                    Some((parent_handle, span, Ok(edges))) => {
96                        let _guard = span.enter();
97                        for edge in edges {
98                            match visit.visit(edge) {
99                                VisitControlFlow::Continue(node) => {
100                                    if let Some((node_handle, node_ref)) =
101                                        self.insert(Some(parent_handle.clone()), GraphNode(node))
102                                    {
103                                        let span = visit.span(node_ref);
104                                        futures.push(With::new(
105                                            visit.edges(node_ref),
106                                            span,
107                                            node_handle,
108                                        ));
109                                    }
110                                }
111                                VisitControlFlow::Skip(node) => {
112                                    self.insert(Some(parent_handle.clone()), GraphNode(node));
113                                }
114                                VisitControlFlow::Abort(abort) => {
115                                    return GraphTraversalResult::Aborted(abort);
116                                }
117                            }
118                        }
119                    }
120                    Some((_, _, Err(err))) => {
121                        return GraphTraversalResult::Completed(Err(err));
122                    }
123                    None => {
124                        return GraphTraversalResult::Completed(Ok(self));
125                    }
126                }
127            }
128        }
129    }
130
131    fn skip_duplicates(self) -> SkipDuplicates<Self> {
132        SkipDuplicates::new(self)
133    }
134
135    fn skip_duplicates_with_visited_nodes(
136        self,
137        visited: VisitedNodes<Store::Node>,
138    ) -> SkipDuplicates<Self> {
139        SkipDuplicates::new_with_visited_nodes(self, visited.0)
140    }
141
142    fn skip_duplicates_with_key<
143        Key: Send + Eq + std::hash::Hash + Clone,
144        KeyExtractor: Send + Fn(&Self::Node) -> &Key,
145    >(
146        self,
147        key_extractor: KeyExtractor,
148    ) -> SkipDuplicatesWithKey<Self, Key, KeyExtractor> {
149        SkipDuplicatesWithKey::new(self, key_extractor)
150    }
151}
152
/// Outcome of a graph traversal.
///
/// A traversal either runs to completion, yielding `Completed`, or is stopped early by the
/// visitor, yielding `Aborted` with the visitor's abort value.
#[derive(Debug)]
pub enum GraphTraversalResult<Completed, Aborted> {
    /// The traversal ran until no edges were left to visit (or failed while doing so, when
    /// `Completed` is a `Result`).
    Completed(Completed),
    /// The visitor requested that the traversal stop early.
    Aborted(Aborted),
}
157
158impl<Completed> GraphTraversalResult<Completed, !> {
159    pub fn completed(self) -> Completed {
160        match self {
161            GraphTraversalResult::Completed(completed) => completed,
162            GraphTraversalResult::Aborted(_) => unreachable!("the type parameter `Aborted` is `!`"),
163        }
164    }
165}