1#![feature(arbitrary_self_types_pointers)]
2
3use std::vec;
4
5use anyhow::{Result, bail};
6
7#[derive(Debug, Clone)]
11#[turbo_tasks::value(eq = "manual", shared)]
12#[serde(into = "RegexForm", try_from = "RegexForm")]
13pub struct EsRegex {
14    #[turbo_tasks(trace_ignore)]
15    delegate: EsRegexImpl,
16    pub pattern: String,
19    pub flags: String,
20}
21
22#[derive(Debug, Clone)]
23enum EsRegexImpl {
24    Regex(regex::Regex),
25    Regress(regress::Regex),
26}
27
28impl PartialEq for EsRegex {
33    fn eq(&self, other: &Self) -> bool {
34        self.pattern == other.pattern && self.flags == other.flags
35    }
36}
37impl Eq for EsRegex {}
38
39impl TryFrom<RegexForm> for EsRegex {
40    type Error = anyhow::Error;
41
42    fn try_from(value: RegexForm) -> std::result::Result<Self, Self::Error> {
43        EsRegex::new(&value.pattern, &value.flags)
44    }
45}
46
47#[derive(serde::Serialize, serde::Deserialize, Debug, Clone)]
49struct RegexForm {
50    pattern: String,
51    flags: String,
52}
53
54impl From<EsRegex> for RegexForm {
55    fn from(value: EsRegex) -> Self {
56        Self {
57            pattern: value.pattern,
58            flags: value.flags,
59        }
60    }
61}
62
63impl EsRegex {
64    pub fn new(pattern: &str, flags: &str) -> Result<Self> {
67        let pattern = pattern.replace("\\/", "/");
69
70        let mut applied_flags = String::new();
71        for flag in flags.chars() {
72            match flag {
73                'd' => {}
75                'g' => {}
77                'i' => applied_flags.push('i'),
79                'm' => applied_flags.push('m'),
81                's' => applied_flags.push('s'),
83                'u' => applied_flags.push('u'),
85                'y' => {}
87                _ => bail!("unsupported flag `{flag}` in regex: `{pattern}` with flags: `{flags}`"),
88            }
89        }
90
91        let regex = if !applied_flags.is_empty() {
92            regex::Regex::new(&format!("(?{applied_flags}){pattern}"))
93        } else {
94            regex::Regex::new(&pattern)
95        };
96
97        let delegate = match regex {
98            Ok(reg) => Ok(EsRegexImpl::Regex(reg)),
99            Err(_e) => {
100                match regress::Regex::with_flags(&pattern, regress::Flags::from(flags)) {
103                    Ok(reg) => Ok(EsRegexImpl::Regress(reg)),
104                    Err(e) => Err(e),
106                }
107            }
108        }?;
109        Ok(Self {
110            delegate,
111            pattern,
112            flags: flags.to_string(),
113        })
114    }
115
116    pub fn is_match(&self, haystack: &str) -> bool {
118        match &self.delegate {
119            EsRegexImpl::Regex(r) => r.is_match(haystack),
120            EsRegexImpl::Regress(r) => r.find(haystack).is_some(),
121        }
122    }
123
124    pub fn captures<'h>(&self, haystack: &'h str) -> Option<Captures<'h>> {
135        let delegate = match &self.delegate {
136            EsRegexImpl::Regex(r) => CapturesImpl::Regex {
137                captures: r.captures(haystack)?,
138                idx: 0,
139            },
140            EsRegexImpl::Regress(r) => {
141                let re_match = r.find(haystack)?;
142                CapturesImpl::Regress {
143                    captures_iter: re_match.captures.into_iter(),
144                    haystack,
145                    match_range: Some(re_match.range),
146                }
147            }
148        };
149        Some(Captures { delegate })
150    }
151}
152
153pub struct Captures<'h> {
154    delegate: CapturesImpl<'h>,
155}
156
157enum CapturesImpl<'h> {
158    Regex {
165        captures: regex::Captures<'h>,
166        idx: usize,
167    },
168    Regress {
170        captures_iter: vec::IntoIter<Option<regress::Range>>,
171        haystack: &'h str,
172        match_range: Option<regress::Range>,
173    },
174}
175
176impl<'h> Iterator for Captures<'h> {
177    type Item = Option<&'h str>;
178
179    fn next(&mut self) -> Option<Self::Item> {
180        match &mut self.delegate {
181            CapturesImpl::Regex { captures, idx } => {
182                if *idx >= captures.len() {
183                    None
184                } else {
185                    let capture = Some(captures.get(*idx).map(|sub_match| sub_match.as_str()));
186                    *idx += 1;
187                    capture
188                }
189            }
190            CapturesImpl::Regress {
191                captures_iter,
192                haystack,
193                match_range,
194            } => {
195                if let Some(range) = match_range.take() {
196                    Some(Some(&haystack[range]))
198                } else {
199                    Some(captures_iter.next()?.map(|range| &haystack[range]))
200                }
201            }
202        }
203    }
204}
205
#[cfg(test)]
mod tests {
    use super::{EsRegex, EsRegexImpl};

    /// Serializing and then deserializing yields an equal regex.
    #[test]
    fn round_trip_serialize() {
        let regex = EsRegex::new("[a-z]", "i").unwrap();
        let serialized = serde_json::to_string(&regex).unwrap();
        let parsed = serde_json::from_str::<EsRegex>(&serialized).unwrap();
        assert_eq!(regex, parsed);
    }

    /// A plain pattern is handled by the `regex` backend.
    #[test]
    fn es_regex_matches_simple() {
        let regex = EsRegex::new("a", "").unwrap();
        assert!(matches!(regex.delegate, EsRegexImpl::Regex { .. }));
        assert!(regex.is_match("a"));
    }

    /// Look-ahead is handled by the `regress` backend.
    #[test]
    fn es_regex_matches_negative_lookahead() {
        let regex = EsRegex::new("a(?!b)", "").unwrap();
        assert!(matches!(regex.delegate, EsRegexImpl::Regress { .. }));
        assert!(!regex.is_match("ab"));
        assert!(regex.is_match("ac"));
    }

    /// A pattern that neither backend accepts is reported as an error.
    #[test]
    fn invalid_regex() {
        assert!(EsRegex::new("*", "").is_err());
    }

    /// A flag outside the supported set is rejected.
    #[test]
    fn unsupported_flag_is_rejected() {
        assert!(EsRegex::new("a", "x").is_err());
    }

    /// An escaped forward slash is stored and matched as a plain slash.
    #[test]
    fn escaped_forward_slash_is_unescaped() {
        let regex = EsRegex::new(r"a\/b", "").unwrap();
        assert_eq!(regex.pattern, "a/b");
        assert!(regex.is_match("a/b"));
    }

    #[test]
    fn captures_with_regex() {
        let regex = EsRegex::new(r"(notmatched)|(\d{4})-(\d{2})-(\d{2})", "").unwrap();
        assert!(matches!(regex.delegate, EsRegexImpl::Regex { .. }));

        let captures = regex.captures("Today is 2024-01-15");
        assert!(captures.is_some());
        let caps: Vec<_> = captures.unwrap().collect();
        assert_eq!(caps.len(), 5);
        assert_eq!(caps[0], Some("2024-01-15"));
        // The first alternative did not participate in the match.
        assert_eq!(caps[1], None);
        assert_eq!(caps[2], Some("2024"));
        assert_eq!(caps[3], Some("01"));
        assert_eq!(caps[4], Some("15"));
    }

    #[test]
    fn captures_with_regress() {
        let regex = EsRegex::new(r"(\w+)(?=baz)", "").unwrap();
        assert!(matches!(regex.delegate, EsRegexImpl::Regress { .. }));

        let captures = regex.captures("foobar");
        assert!(captures.is_none());

        let captures = regex.captures("foobaz");
        assert!(captures.is_some());
        let caps: Vec<_> = captures.unwrap().collect();
        assert_eq!(caps.len(), 2);
        assert_eq!(caps[0], Some("foo"));
        assert_eq!(caps[1], Some("foo"));
    }
}