turbo_static/
call_resolver.rs

use std::{
    fs::OpenOptions,
    io::{BufReader, BufWriter, Write},
    path::{Path, PathBuf},
};

use rustc_hash::FxHashMap;

use crate::{Identifier, IdentifierReference, lsp_client::RAClient};
6
7/// A wrapper around a rust-analyzer client that can resolve call references.
8/// This is quite expensive so we cache the results in an on-disk key-value
9/// store.
10pub struct CallResolver<'a> {
11    client: &'a mut RAClient,
12    state: FxHashMap<Identifier, Vec<IdentifierReference>>,
13    path: Option<PathBuf>,
14}
15
16/// On drop, serialize the state to disk
17impl Drop for CallResolver<'_> {
18    fn drop(&mut self) {
19        let file = OpenOptions::new()
20            .create(true)
21            .truncate(false)
22            .write(true)
23            .open(self.path.as_ref().unwrap())
24            .unwrap();
25        bincode::serialize_into(file, &self.state).unwrap();
26    }
27}
28
29impl<'a> CallResolver<'a> {
30    pub fn new(client: &'a mut RAClient, path: Option<PathBuf>) -> Self {
31        // load bincode-encoded FxHashMap from path
32        let state = path
33            .as_ref()
34            .and_then(|path| {
35                let file = OpenOptions::new()
36                    .create(true)
37                    .truncate(false)
38                    .read(true)
39                    .write(true)
40                    .open(path)
41                    .unwrap();
42                let reader = std::io::BufReader::new(file);
43                bincode::deserialize_from::<_, FxHashMap<Identifier, Vec<IdentifierReference>>>(
44                    reader,
45                )
46                .inspect_err(|_| {
47                    tracing::warn!("failed to load existing cache, restarting");
48                })
49                .ok()
50            })
51            .unwrap_or_default();
52        Self {
53            client,
54            state,
55            path,
56        }
57    }
58
59    pub fn cached_count(&self) -> usize {
60        self.state.len()
61    }
62
63    pub fn cleared(mut self) -> Self {
64        // delete file if exists and clear state
65        self.state = Default::default();
66        if let Some(path) = self.path.as_ref() {
67            std::fs::remove_file(path).unwrap();
68        }
69        self
70    }
71
72    pub fn resolve(&mut self, ident: &Identifier) -> Vec<IdentifierReference> {
73        if let Some(data) = self.state.get(ident) {
74            tracing::info!("skipping {}", ident);
75            return data.to_owned();
76        };
77
78        tracing::info!("checking {}", ident);
79
80        let mut count = 0;
81        let _response = loop {
82            let Some(response) = self.client.request(lsp_server::Request {
83                id: 1.into(),
84                method: "textDocument/prepareCallHierarchy".to_string(),
85                params: serde_json::to_value(&lsp_types::CallHierarchyPrepareParams {
86                    text_document_position_params: lsp_types::TextDocumentPositionParams {
87                        position: ident.range.start,
88                        text_document: lsp_types::TextDocumentIdentifier {
89                            uri: lsp_types::Url::from_file_path(&ident.path).unwrap(),
90                        },
91                    },
92                    work_done_progress_params: lsp_types::WorkDoneProgressParams {
93                        work_done_token: Some(lsp_types::ProgressToken::String(
94                            "prepare".to_string(),
95                        )),
96                    },
97                })
98                .unwrap(),
99            }) else {
100                tracing::warn!("RA server shut down");
101                return vec![];
102            };
103
104            if let Some(Some(value)) = response.result.as_ref().map(|r| r.as_array()) {
105                if !value.is_empty() {
106                    break value.to_owned();
107                }
108                count += 1;
109            }
110
111            // textDocument/prepareCallHierarchy will sometimes return an empty array so try
112            // at most 5 times
113            if count > 5 {
114                tracing::warn!("discovered isolated task {}", ident);
115                break vec![];
116            }
117
118            std::thread::sleep(std::time::Duration::from_secs(1));
119        };
120
121        // callHierarchy/incomingCalls
122        let Some(response) = self.client.request(lsp_server::Request {
123            id: 1.into(),
124            method: "callHierarchy/incomingCalls".to_string(),
125            params: serde_json::to_value(lsp_types::CallHierarchyIncomingCallsParams {
126                partial_result_params: lsp_types::PartialResultParams::default(),
127                item: lsp_types::CallHierarchyItem {
128                    name: ident.name.to_owned(),
129                    kind: lsp_types::SymbolKind::FUNCTION,
130                    data: None,
131                    tags: None,
132                    detail: None,
133                    uri: lsp_types::Url::from_file_path(&ident.path).unwrap(),
134                    range: ident.range,
135                    selection_range: ident.range,
136                },
137                work_done_progress_params: lsp_types::WorkDoneProgressParams {
138                    work_done_token: Some(lsp_types::ProgressToken::String("prepare".to_string())),
139                },
140            })
141            .unwrap(),
142        }) else {
143            tracing::warn!("RA server shut down");
144            return vec![];
145        };
146
147        let links = if let Some(e) = response.error {
148            tracing::warn!("unable to resolve {}: {:?}", ident, e);
149            vec![]
150        } else {
151            let response: Result<Vec<lsp_types::CallHierarchyIncomingCall>, _> =
152                serde_path_to_error::deserialize(response.result.unwrap());
153
154            response
155                .unwrap()
156                .into_iter()
157                .map(|i| i.into())
158                .collect::<Vec<IdentifierReference>>()
159        };
160
161        tracing::debug!("links: {:?}", links);
162
163        self.state.insert(ident.to_owned(), links.clone());
164        links
165    }
166}