turbo_static/
main.rs

1use std::{
2    error::Error,
3    fs,
4    path::PathBuf,
5    sync::{
6        Arc,
7        atomic::{AtomicBool, Ordering},
8    },
9};
10
11use call_resolver::CallResolver;
12use clap::Parser;
13use identifier::{Identifier, IdentifierReference};
14use itertools::Itertools;
15use rustc_hash::{FxHashMap, FxHashSet};
16use syn::visit::Visit;
17use visitor::CallingStyleVisitor;
18
19use crate::visitor::CallingStyle;
20
21mod call_resolver;
22mod identifier;
23mod lsp_client;
24mod visitor;
25
// Command-line options for the analyzer.
//
// Plain `//` comments are used on purpose here: with clap's derive, `///`
// comments on the struct or its fields are turned into `--help` text.
//
// NOTE(review): `reparse` is never read in `main`, so `--reparse` appears to
// have no effect at the moment — confirm before relying on it.
#[derive(Parser)]
struct Opt {
    // Folders to scan recursively for tasks; they are also handed to
    // `lsp_client::RAClient::start`. At least one is required.
    #[clap(required = true)]
    paths: Vec<PathBuf>,

    /// reparse all files
    #[clap(long)]
    reparse: bool,

    // When set, `main` replaces the call resolver with `cleared()` before use.
    /// reindex all files
    #[clap(long)]
    reindex: bool,
}
39
40fn main() -> Result<(), Box<dyn Error>> {
41    tracing_subscriber::fmt::init();
42    let opt = Opt::parse();
43
44    let mut connection = lsp_client::RAClient::new();
45    connection.start(&opt.paths);
46
47    let call_resolver = CallResolver::new(&mut connection, Some("call_resolver.bincode".into()));
48    let mut call_resolver = if opt.reindex {
49        call_resolver.cleared()
50    } else {
51        call_resolver
52    };
53
54    let halt = Arc::new(AtomicBool::new(false));
55    let halt_clone = halt.clone();
56    ctrlc::set_handler({
57        move || {
58            halt_clone.store(true, Ordering::SeqCst);
59        }
60    })?;
61
62    tracing::info!("getting tasks");
63    let mut tasks = get_all_tasks(&opt.paths);
64    let dep_tree = resolve_tasks(&mut tasks, &mut call_resolver, halt.clone());
65    let concurrency = resolve_concurrency(&tasks, &dep_tree, halt.clone());
66
67    write_dep_tree(&tasks, concurrency, std::path::Path::new("graph.cypherl"));
68
69    if halt.load(Ordering::Relaxed) {
70        tracing::info!("ctrl-c detected, exiting");
71    }
72
73    Ok(())
74}
75
76/// search the given folders recursively and attempt to find all tasks inside
77#[tracing::instrument(skip_all)]
78fn get_all_tasks(folders: &[PathBuf]) -> FxHashMap<Identifier, Vec<String>> {
79    let mut out = FxHashMap::default();
80
81    for folder in folders {
82        let walker = ignore::Walk::new(folder);
83        for entry in walker {
84            let entry = entry.unwrap();
85            let rs_file = if let Some(true) = entry.file_type().map(|t| t.is_file()) {
86                let path = entry.path();
87                let ext = path.extension().unwrap_or_default();
88                if ext == "rs" {
89                    std::fs::canonicalize(path).unwrap()
90                } else {
91                    continue;
92                }
93            } else {
94                continue;
95            };
96
97            let file = fs::read_to_string(&rs_file).unwrap();
98            let lines = file.lines();
99            let mut occurences = vec![];
100
101            tracing::debug!("processing {}", rs_file.display());
102
103            for ((_, line), (line_no, _)) in lines.enumerate().tuple_windows() {
104                if line.contains("turbo_tasks::function") {
105                    tracing::debug!("found at {:?}:L{}", rs_file, line_no);
106                    occurences.push(line_no + 1);
107                }
108            }
109
110            if occurences.is_empty() {
111                continue;
112            }
113
114            // parse the file using syn and get the span of the functions
115            let file = syn::parse_file(&file).unwrap();
116            let occurences_count = occurences.len();
117            let mut visitor = visitor::TaskVisitor::new();
118            syn::visit::visit_file(&mut visitor, &file);
119            if visitor.results.len() != occurences_count {
120                tracing::warn!(
121                    "file {:?} passed the heuristic with {:?} but the visitor found {:?}",
122                    rs_file,
123                    occurences_count,
124                    visitor.results.len()
125                );
126            }
127
128            out.extend(
129                visitor
130                    .results
131                    .into_iter()
132                    .map(move |(ident, tags)| ((rs_file.clone(), ident).into(), tags)),
133            )
134        }
135    }
136
137    out
138}
139
140/// Given a list of tasks, get all the tasks that call that one
141fn resolve_tasks(
142    tasks: &mut FxHashMap<Identifier, Vec<String>>,
143    client: &mut CallResolver,
144    halt: Arc<AtomicBool>,
145) -> FxHashMap<Identifier, Vec<IdentifierReference>> {
146    tracing::info!(
147        "found {} tasks, of which {} cached",
148        tasks.len(),
149        client.cached_count()
150    );
151
152    let mut unresolved = tasks.keys().cloned().collect::<FxHashSet<_>>();
153    let mut resolved = FxHashMap::default();
154
155    while let Some(top) = unresolved.iter().next().cloned() {
156        unresolved.remove(&top);
157
158        let callers = client.resolve(&top);
159
160        // add all non-task callers to the unresolved list if they are not in the
161        // resolved list
162        for caller in callers.iter() {
163            if !resolved.contains_key(&caller.identifier)
164                && !unresolved.contains(&caller.identifier)
165            {
166                tracing::debug!("adding {} to unresolved", caller.identifier);
167                unresolved.insert(caller.identifier.to_owned());
168            }
169        }
170        resolved.insert(top.to_owned(), callers);
171
172        if halt.load(Ordering::Relaxed) {
173            break;
174        }
175    }
176
177    resolved
178}
179
180/// given a map of tasks and functions that call it, produce a map of tasks and
181/// those tasks that it calls
182///
183/// returns a list of pairs with a task, the task that calls it, and the calling
184/// style
185fn resolve_concurrency(
186    task_list: &FxHashMap<Identifier, Vec<String>>,
187    dep_tree: &FxHashMap<Identifier, Vec<IdentifierReference>>, // pairs of tasks and call trees
188    halt: Arc<AtomicBool>,
189) -> Vec<(Identifier, Identifier, CallingStyle)> {
190    // println!("{:?}", dep_tree);
191    // println!("{:#?}", task_list);
192
193    let mut edges = vec![];
194
195    for (ident, references) in dep_tree {
196        for reference in references {
197            #[allow(clippy::map_entry)] // This doesn't insert into dep_tree, so entry isn't useful
198            if !dep_tree.contains_key(&reference.identifier) {
199                // this is a task that is not in the task list
200                // so we can't resolve it
201                tracing::error!("missing task for {}: {}", ident, reference.identifier);
202                for task in task_list.keys() {
203                    if task.name == reference.identifier.name {
204                        // we found a task that is not in the task list
205                        // so we can't resolve it
206                        tracing::trace!("- found {}", task);
207                        continue;
208                    }
209                }
210                continue;
211            } else {
212                // load the source file and get the calling style
213                let target = IdentifierReference {
214                    identifier: ident.clone(),
215                    references: reference.references.clone(),
216                };
217                let mut visitor = CallingStyleVisitor::new(target);
218                tracing::info!("looking for {} from {}", ident, reference.identifier);
219                let file =
220                    syn::parse_file(&fs::read_to_string(&reference.identifier.path).unwrap())
221                        .unwrap();
222                visitor.visit_file(&file);
223
224                edges.push((
225                    ident.clone(),
226                    reference.identifier.clone(),
227                    visitor.result().unwrap_or(CallingStyle::Once),
228                ));
229            }
230
231            if halt.load(Ordering::Relaxed) {
232                break;
233            }
234        }
235    }
236
237    // parse each fn between parent and child and get the max calling style
238
239    edges
240}
241
242/// Write the dep tree into the given file using cypher syntax
243fn write_dep_tree(
244    task_list: &FxHashMap<Identifier, Vec<String>>,
245    dep_tree: Vec<(Identifier, Identifier, CallingStyle)>,
246    out: &std::path::Path,
247) {
248    use std::io::Write;
249
250    let mut node_ids = FxHashMap::default();
251    let mut counter = 0;
252
253    let mut file = std::fs::File::create(out).unwrap();
254
255    let empty = vec![];
256
257    // collect all tasks as well as all intermediate nodes
258    // tasks come last to ensure the tags are preserved
259    let node_list = dep_tree
260        .iter()
261        .flat_map(|(dest, src, _)| [(src, &empty), (dest, &empty)])
262        .chain(task_list)
263        .collect::<FxHashMap<_, _>>();
264
265    for (ident, tags) in node_list {
266        counter += 1;
267
268        let label = if !task_list.contains_key(ident) {
269            "Function"
270        } else if tags.contains(&"fs".to_string()) || tags.contains(&"network".to_string()) {
271            "ImpureTask"
272        } else {
273            "Task"
274        };
275
276        _ = writeln!(
277            file,
278            "CREATE (n_{}:{} {{name: '{}', file: '{}', line: {}, tags: [{}]}})",
279            counter,
280            label,
281            ident.name,
282            ident.path,
283            ident.range.start.line,
284            tags.iter().map(|t| format!("\"{t}\"")).join(",")
285        );
286        node_ids.insert(ident, counter);
287    }
288
289    for (dest, src, style) in &dep_tree {
290        let style = match style {
291            CallingStyle::Once => "ONCE",
292            CallingStyle::ZeroOrOnce => "ZERO_OR_ONCE",
293            CallingStyle::ZeroOrMore => "ZERO_OR_MORE",
294            CallingStyle::OneOrMore => "ONE_OR_MORE",
295        };
296
297        let src_id = *node_ids.get(src).unwrap();
298        let dst_id = *node_ids.get(dest).unwrap();
299
300        _ = writeln!(file, "CREATE (n_{src_id})-[:{style}]->(n_{dst_id})",);
301    }
302}