turbo_static/
visitor.rs

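//! A visitor that traverses the AST and collects all functions or methods that
//! are annotated with `#[turbo_tasks::function]`, plus a visitor that determines
//! how often a given call site is executed within its enclosing function.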
3
4use std::{collections::VecDeque, ops::Add};
5
6use lsp_types::Range;
7use syn::{Expr, Meta, visit::Visit};
8
9use crate::identifier::Identifier;
10
11pub struct TaskVisitor {
12    /// the list of results as pairs of an identifier and its tags
13    pub results: Vec<(syn::Ident, Vec<String>)>,
14}
15
16impl TaskVisitor {
17    pub fn new() -> Self {
18        Self {
19            results: Default::default(),
20        }
21    }
22}
23
24impl Visit<'_> for TaskVisitor {
25    #[tracing::instrument(skip_all)]
26    fn visit_item_fn(&mut self, i: &syn::ItemFn) {
27        if let Some(tags) = extract_tags(i.attrs.iter()) {
28            tracing::trace!("L{}: {}", i.sig.ident.span().start().line, i.sig.ident,);
29            self.results.push((i.sig.ident.clone(), tags));
30        }
31    }
32
33    #[tracing::instrument(skip_all)]
34    fn visit_impl_item_fn(&mut self, i: &syn::ImplItemFn) {
35        if let Some(tags) = extract_tags(i.attrs.iter()) {
36            tracing::trace!("L{}: {}", i.sig.ident.span().start().line, i.sig.ident,);
37            self.results.push((i.sig.ident.clone(), tags));
38        }
39    }
40}
41
42fn extract_tags<'a>(mut meta: impl Iterator<Item = &'a syn::Attribute>) -> Option<Vec<String>> {
43    meta.find_map(|a| match &a.meta {
44        // path has two segments, turbo_tasks and function
45        Meta::Path(path) if path.segments.len() == 2 => {
46            let first = &path.segments[0];
47            let second = &path.segments[1];
48            (first.ident == "turbo_tasks" && second.ident == "function").then(std::vec::Vec::new)
49        }
50        Meta::List(list) if list.path.segments.len() == 2 => {
51            let first = &list.path.segments[0];
52            let second = &list.path.segments[1];
53            if first.ident != "turbo_tasks" || second.ident != "function" {
54                return None;
55            }
56
57            // collect ident tokens as args
58            let tags: Vec<_> = list
59                .tokens
60                .clone()
61                .into_iter()
62                .filter_map(|t| {
63                    if let proc_macro2::TokenTree::Ident(ident) = t {
64                        Some(ident.to_string())
65                    } else {
66                        None
67                    }
68                })
69                .collect();
70
71            Some(tags)
72        }
73        _ => {
74            tracing::trace!("skipping unknown annotation");
75            None
76        }
77    })
78}
79
/// How often a target function is called from within a source function.
///
/// Each discriminant is a bitset over the call counts the style admits:
/// `0b001` = zero times, `0b010` = exactly once, `0b100` = more than once.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd)]
// Pin the discriminant type so `bitset`'s `as u8` cast is lossless by construction.
#[repr(u8)]
pub enum CallingStyle {
    /// Called exactly once.
    Once = 0b0010,
    /// Called at most once.
    ZeroOrOnce = 0b0011,
    /// Called any number of times, including zero.
    ZeroOrMore = 0b0111,
    /// Called at least once.
    OneOrMore = 0b0110,
}

impl CallingStyle {
    /// Returns the bitset of call counts this style admits (its discriminant).
    const fn bitset(self) -> u8 {
        self as u8
    }
}

impl Add for CallingStyle {
    type Output = Self;

    /// Add two calling styles together to determine the calling style of the
    /// target function within the source function.
    ///
    /// Consider it as a bitset over properties; addition is the union:
    /// - 0b000: Nothing
    /// - 0b001: Zero
    /// - 0b010: Once
    /// - 0b011: Zero Or Once
    /// - 0b100: More Than Once
    /// - 0b101: Zero Or More Than Once (?)
    /// - 0b110: Once Or More
    /// - 0b111: Zero Or More
    ///
    /// Note that zero is not a valid calling style.
    // `+` is deliberately implemented as bitset union (`|`), not arithmetic.
    #[allow(clippy::suspicious_arithmetic_impl)]
    fn add(self, rhs: Self) -> Self {
        let bits = self.bitset() | rhs.bitset();
        match bits {
            0b0010 => CallingStyle::Once,
            0b0011 => CallingStyle::ZeroOrOnce,
            0b0111 => CallingStyle::ZeroOrMore,
            0b0110 => CallingStyle::OneOrMore,
            // Every variant has the `Once` bit (0b010) set, so any union does too.
            // That rules out the remaining 4 bitsets (null, zero, more than once,
            // zero or more than once): we never detect 'zero' or 'more than once'
            // on their own.
            _ => unreachable!("calling style bitset {bits:#06b} has no variant"),
        }
    }
}
128
129pub struct CallingStyleVisitor {
130    pub reference: crate::IdentifierReference,
131    state: VecDeque<CallingStyleVisitorState>,
132    halt: bool,
133}
134
135impl CallingStyleVisitor {
136    /// Create a new visitor that will traverse the AST and determine the
137    /// calling style of the target function within the source function.
138    pub fn new(reference: crate::IdentifierReference) -> Self {
139        Self {
140            reference,
141            state: Default::default(),
142            halt: false,
143        }
144    }
145
146    pub fn result(self) -> Option<CallingStyle> {
147        self.state
148            .into_iter()
149            .map(|b| match b {
150                CallingStyleVisitorState::Block => CallingStyle::Once,
151                CallingStyleVisitorState::Loop => CallingStyle::ZeroOrMore,
152                CallingStyleVisitorState::If => CallingStyle::ZeroOrOnce,
153                CallingStyleVisitorState::Closure => CallingStyle::ZeroOrMore,
154            })
155            .reduce(|a, b| a + b)
156    }
157}
158
/// One syntactic context on the visitor's stack that encloses the code being
/// visited.
#[derive(Debug, Clone, Copy)]
enum CallingStyleVisitorState {
    /// A function or method body; its contents run exactly once per call.
    Block,
    /// An `if` expression; its contents may run at most once.
    If,
    /// A loop; its contents may run any number of times.
    Loop,
    /// A closure body; it may be invoked any number of times.
    Closure,
}
166
167impl Visit<'_> for CallingStyleVisitor {
168    fn visit_item_fn(&mut self, i: &'_ syn::ItemFn) {
169        self.state.push_back(CallingStyleVisitorState::Block);
170        syn::visit::visit_item_fn(self, i);
171        if !self.halt {
172            self.state.pop_back();
173        }
174    }
175
176    fn visit_impl_item_fn(&mut self, i: &'_ syn::ImplItemFn) {
177        self.state.push_back(CallingStyleVisitorState::Block);
178        syn::visit::visit_impl_item_fn(self, i);
179        if !self.halt {
180            self.state.pop_back();
181        }
182    }
183
184    fn visit_expr_loop(&mut self, i: &'_ syn::ExprLoop) {
185        self.state.push_back(CallingStyleVisitorState::Loop);
186        syn::visit::visit_expr_loop(self, i);
187        if !self.halt {
188            self.state.pop_back();
189        }
190    }
191
192    fn visit_expr_for_loop(&mut self, i: &'_ syn::ExprForLoop) {
193        self.state.push_back(CallingStyleVisitorState::Loop);
194        syn::visit::visit_expr_for_loop(self, i);
195        if !self.halt {
196            self.state.pop_back();
197        }
198    }
199
200    fn visit_expr_if(&mut self, i: &'_ syn::ExprIf) {
201        self.state.push_back(CallingStyleVisitorState::If);
202        syn::visit::visit_expr_if(self, i);
203        if !self.halt {
204            self.state.pop_back();
205        }
206    }
207
208    fn visit_expr_closure(&mut self, i: &'_ syn::ExprClosure) {
209        self.state.push_back(CallingStyleVisitorState::Closure);
210        syn::visit::visit_expr_closure(self, i);
211        if !self.halt {
212            self.state.pop_back();
213        }
214    }
215
216    fn visit_expr_call(&mut self, i: &'_ syn::ExprCall) {
217        syn::visit::visit_expr_call(self, i);
218        if let Expr::Path(p) = i.func.as_ref() {
219            if let Some(last) = p.path.segments.last() {
220                if is_match(
221                    &self.reference.identifier,
222                    &last.ident,
223                    &self.reference.references,
224                ) {
225                    self.halt = true;
226                }
227            }
228        }
229    }
230
231    // to validate this, we first check if the name is the same and then compare it
232    // against any of the references we are holding
233    fn visit_expr_method_call(&mut self, i: &'_ syn::ExprMethodCall) {
234        if is_match(
235            &self.reference.identifier,
236            &i.method,
237            &self.reference.references,
238        ) {
239            self.halt = true;
240        }
241
242        syn::visit::visit_expr_method_call(self, i);
243    }
244}
245
246/// Check if some ident referenced by `check` is calling the `target` by
247/// looking it up in the list of known `ranges`.
248fn is_match(target: &Identifier, check: &syn::Ident, ranges: &[Range]) -> bool {
249    if target.equals_ident(check, false) {
250        let span = check.span();
251        // syn is 1-indexed, range is not
252        for reference in ranges {
253            if reference.start.line != span.start().line as u32 - 1 {
254                continue;
255            }
256
257            if reference.start.character != span.start().column as u32 {
258                continue;
259            }
260
261            if reference.end.line != span.end().line as u32 - 1 {
262                continue;
263            }
264
265            if reference.end.character != span.end().column as u32 {
266                continue;
267            }
268
269            // match, just exit the visitor
270            return true;
271        }
272    }
273
274    false
275}