Unnamed repository; edit this file 'description' to name the repository.
Diffstat (limited to 'helix-syntax/src/tree_sitter/query/predicate.rs')
| -rw-r--r-- | helix-syntax/src/tree_sitter/query/predicate.rs | 448 |
1 files changed, 448 insertions, 0 deletions
diff --git a/helix-syntax/src/tree_sitter/query/predicate.rs b/helix-syntax/src/tree_sitter/query/predicate.rs new file mode 100644 index 00000000..91f3dabe --- /dev/null +++ b/helix-syntax/src/tree_sitter/query/predicate.rs @@ -0,0 +1,448 @@ +use std::error::Error; +use std::iter::zip; +use std::ops::Range; +use std::ptr::NonNull; +use std::{fmt, slice}; + +use crate::tree_sitter::query::property::QueryProperty; +use crate::tree_sitter::query::{ + Capture, Pattern, PatternData, Query, QueryData, QueryStr, UserPredicate, +}; +use crate::tree_sitter::query_cursor::MatchedNode; +use crate::tree_sitter::TsInput; + +use regex_cursor::engines::meta::Regex; +use regex_cursor::Cursor; + +macro_rules! bail { + ($($args:tt)*) => {{ + return Err(InvalidPredicateError {msg: format!($($args)*).into() }) + }} +} + +macro_rules! ensure { + ($cond: expr, $($args:tt)*) => {{ + if !$cond { + return Err(InvalidPredicateError { msg: format!($($args)*).into() }) + } + }} +} + +#[derive(Debug)] +pub(super) enum TextPredicateKind { + EqString(QueryStr), + EqCapture(Capture), + MatchString(Regex), + AnyString(Box<[QueryStr]>), +} + +#[derive(Debug)] +pub(crate) struct TextPredicate { + capture: Capture, + kind: TextPredicateKind, + negated: bool, + match_all: bool, +} + +fn input_matches_str<I: TsInput>(str: &str, range: Range<usize>, input: &mut I) -> bool { + if str.len() != range.len() { + return false; + } + let mut str = str.as_bytes(); + let cursor = input.cursor_at(range.start); + let start_in_chunk = range.start - cursor.offset(); + if range.end - cursor.offset() <= cursor.chunk().len() { + // hotpath + return &cursor.chunk()[start_in_chunk..range.end - cursor.offset()] == str; + } + if cursor.chunk()[start_in_chunk..] != str[..cursor.chunk().len() - start_in_chunk] { + return false; + } + str = &str[..cursor.chunk().len() - start_in_chunk]; + while cursor.advance() { + if str.len() <= cursor.chunk().len() { + return &cursor.chunk()[..range.end - cursor.offset()] == str; + } + if &str[..cursor.chunk().len()] != cursor.chunk() { + return false; + } + str = &str[cursor.chunk().len()..] + } + // buggy cursor/invalid range + false +} + +fn inputs_match<I: TsInput>(str: &str, range: Range<usize>, input: &mut I) -> bool { + if str.len() != range.len() { + return false; + } + let mut str = str.as_bytes(); + let cursor = input.cursor_at(range.start); + let start_in_chunk = range.start - cursor.offset(); + if range.end - cursor.offset() <= cursor.chunk().len() { + // hotpath + return &cursor.chunk()[start_in_chunk..range.end - cursor.offset()] == str; + } + if cursor.chunk()[start_in_chunk..] != str[..cursor.chunk().len() - start_in_chunk] { + return false; + } + str = &str[..cursor.chunk().len() - start_in_chunk]; + while cursor.advance() { + if str.len() <= cursor.chunk().len() { + return &cursor.chunk()[..range.end - cursor.offset()] == str; + } + if &str[..cursor.chunk().len()] != cursor.chunk() { + return false; + } + str = &str[cursor.chunk().len()..] + } + // buggy cursor/invalid range + false +} + +impl TextPredicate { + /// handlers match_all and negated + fn satisfied_helper(&self, mut nodes: impl Iterator<Item = bool>) -> bool { + if self.match_all { + nodes.all(|matched| matched != self.negated) + } else { + nodes.any(|matched| matched != self.negated) + } + } + + pub fn satsified<I: TsInput>( + &self, + input: &mut I, + matched_nodes: &[MatchedNode], + query: &Query, + ) -> bool { + let mut capture_nodes = matched_nodes + .iter() + .filter(|matched_node| matched_node.capture == self.capture); + match self.kind { + TextPredicateKind::EqString(str) => self.satisfied_helper(capture_nodes.map(|node| { + let range = node.syntax_node.byte_range(); + input_matches_str(query.get_string(str), range.clone(), input) + })), + TextPredicateKind::EqCapture(other_capture) => { + let mut other_nodes = matched_nodes + .iter() + .filter(|matched_node| matched_node.capture == other_capture); + + let res = self.satisfied_helper(zip(&mut capture_nodes, &mut other_nodes).map( + |(node1, node2)| { + let range1 = node1.syntax_node.byte_range(); + let range2 = node2.syntax_node.byte_range(); + input.eq(range1, range2) + }, + )); + let consumed_all = capture_nodes.next().is_none() && other_nodes.next().is_none(); + res && (!self.match_all || consumed_all) + } + TextPredicateKind::MatchString(ref regex) => { + self.satisfied_helper(capture_nodes.map(|node| { + let range = node.syntax_node.byte_range(); + let input = regex_cursor::Input::new(input.cursor_at(range.start)).range(range); + regex.is_match(input) + })) + } + TextPredicateKind::AnyString(ref strings) => { + let strings = strings.iter().map(|&str| query.get_string(str)); + self.satisfied_helper(capture_nodes.map(|node| { + let range = node.syntax_node.byte_range(); + strings + .clone() + .filter(|str| str.len() == range.len()) + .any(|str| input_matches_str(str, range.clone(), input)) + })) + } + } + } +} + +impl Query { + pub(super) fn parse_pattern_predicates( + &mut self, + pattern: Pattern, + mut custom_predicate: impl FnMut(Pattern, UserPredicate) -> Result<(), InvalidPredicateError>, + ) -> Result<PatternData, InvalidPredicateError> { + let text_predicate_start = self.text_predicates.len() as u32; + + let predicate_steps = unsafe { + let mut len = 0u32; + let raw_predicates = ts_query_predicates_for_pattern(self.raw, pattern.0, &mut len); + (len != 0) + .then(|| slice::from_raw_parts(raw_predicates, len as usize)) + .unwrap_or_default() + }; + let predicates = predicate_steps + .split(|step| step.kind == PredicateStepKind::Done) + .filter(|predicate| !predicate.is_empty()); + + for predicate in predicates { + let predicate = unsafe { Predicate::new(self, predicate)? }; + + match predicate.name() { + "eq?" | "not-eq?" | "any-eq?" | "any-not-eq?" => { + predicate.check_arg_count(2)?; + let capture_idx = predicate.capture_arg(0)?; + let arg2 = predicate.arg(1); + + let negated = matches!(predicate.name(), "not-eq?" | "not-any-eq?"); + let match_all = matches!(predicate.name(), "eq?" | "not-eq?"); + let kind = match arg2 { + PredicateArg::Capture(capture) => TextPredicateKind::EqCapture(capture), + PredicateArg::String(str) => TextPredicateKind::EqString(str), + }; + self.text_predicates.push(TextPredicate { + capture: capture_idx, + kind, + negated, + match_all, + }); + } + + "match?" | "not-match?" | "any-match?" | "any-not-match?" => { + predicate.check_arg_count(2)?; + let capture_idx = predicate.capture_arg(0)?; + let regex = predicate.query_str_arg(1)?.get(self); + + let negated = matches!(predicate.name(), "not-match?" | "any-not-match?"); + let match_all = matches!(predicate.name(), "match?" | "not-match?"); + let regex = match Regex::new(regex) { + Ok(regex) => regex, + Err(err) => bail!("invalid regex '{regex}', {err}"), + }; + self.text_predicates.push(TextPredicate { + capture: capture_idx, + kind: TextPredicateKind::MatchString(regex), + negated, + match_all, + }); + } + + "set!" => { + let property = QueryProperty::parse(&predicate)?; + custom_predicate( + pattern, + UserPredicate::SetProperty { + key: property.key.get(&self), + val: property.val.map(|val| val.get(&self)), + }, + )? + } + "is-not?" | "is?" => { + let property = QueryProperty::parse(&predicate)?; + custom_predicate( + pattern, + UserPredicate::IsPropertySet { + negate: predicate.name() == "is-not?", + key: property.key.get(&self), + val: property.val.map(|val| val.get(&self)), + }, + )? + } + + "any-of?" | "not-any-of?" => { + predicate.check_min_arg_count(1)?; + let capture = predicate.capture_arg(0)?; + let negated = predicate.name() == "not-any-of?"; + let values: Result<_, InvalidPredicateError> = (1..predicate.num_args()) + .map(|i| predicate.query_str_arg(i)) + .collect(); + self.text_predicates.push(TextPredicate { + capture, + kind: TextPredicateKind::AnyString(values?), + negated, + match_all: false, + }); + } + + // is and is-not are better handeled as custom predicates since interpreting is context dependent + // "is?" => property_predicates.push((QueryProperty::parse(&predicate), false)), + // "is-not?" => property_predicates.push((QueryProperty::parse(&predicate), true)), + _ => custom_predicate(pattern, UserPredicate::Other(predicate))?, + } + } + Ok(PatternData { + text_predicates: text_predicate_start..self.text_predicates.len() as u32, + }) + } +} + +pub enum PredicateArg { + Capture(Capture), + String(QueryStr), +} + +pub struct Predicate<'a> { + pub name: QueryStr, + args: &'a [PredicateStep], + query: &'a Query, +} + +impl<'a> Predicate<'a> { + unsafe fn new( + query: &'a Query, + predicate: &'a [PredicateStep], + ) -> Result<Predicate<'a>, InvalidPredicateError> { + ensure!( + predicate[0].kind == PredicateStepKind::String, + "expected predicate to start with a function name. Got @{}.", + Capture(predicate[0].value_id).name(query) + ); + let operator_name = QueryStr(predicate[0].value_id); + Ok(Predicate { + name: operator_name, + args: &predicate[1..], + query, + }) + } + + pub fn name(&self) -> &str { + self.name.get(self.query) + } + + pub fn check_arg_count(&self, n: usize) -> Result<(), InvalidPredicateError> { + ensure!( + self.args.len() == n, + "expected {n} arguments for #{}, got {}", + self.name(), + self.args.len() + ); + Ok(()) + } + + pub fn check_min_arg_count(&self, n: usize) -> Result<(), InvalidPredicateError> { + ensure!( + n <= self.args.len(), + "expected at least {n} arguments for #{}, got {}", + self.name(), + self.args.len() + ); + Ok(()) + } + + pub fn check_max_arg_count(&self, n: usize) -> Result<(), InvalidPredicateError> { + ensure!( + self.args.len() <= n, + "expected at most {n} arguments for #{}, got {}", + self.name(), + self.args.len() + ); + Ok(()) + } + + pub fn query_str_arg(&self, i: usize) -> Result<QueryStr, InvalidPredicateError> { + match self.arg(i) { + PredicateArg::String(str) => Ok(str), + PredicateArg::Capture(capture) => bail!( + "{i}. argument to #{} must be a literal, got capture @{:?}", + self.name(), + capture.name(self.query) + ), + } + } + + pub fn str_arg(&self, i: usize) -> Result<&str, InvalidPredicateError> { + Ok(self.query_str_arg(i)?.get(self.query)) + } + + pub fn num_args(&self) -> usize { + self.args.len() + } + + pub fn capture_arg(&self, i: usize) -> Result<Capture, InvalidPredicateError> { + match self.arg(i) { + PredicateArg::Capture(capture) => Ok(capture), + PredicateArg::String(str) => bail!( + "{i}. argument to #{} expected a capture, got literal {:?}", + self.name(), + str.get(self.query) + ), + } + } + + pub fn arg(&self, i: usize) -> PredicateArg { + self.args[i].try_into().unwrap() + } + + pub fn args(&self) -> impl Iterator<Item = PredicateArg> + '_ { + self.args.iter().map(|&arg| arg.try_into().unwrap()) + } +} + +#[derive(Debug)] +pub struct InvalidPredicateError { + pub(super) msg: Box<str>, +} + +impl From<String> for InvalidPredicateError { + fn from(value: String) -> Self { + InvalidPredicateError { + msg: value.into_boxed_str(), + } + } +} + +impl<'a> From<&'a str> for InvalidPredicateError { + fn from(value: &'a str) -> Self { + InvalidPredicateError { msg: value.into() } + } +} + +impl fmt::Display for InvalidPredicateError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(&self.msg) + } +} + +impl Error for InvalidPredicateError {} + +#[repr(C)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum PredicateStepKind { + Done = 0, + Capture = 1, + String = 2, +} + +#[repr(C)] +#[derive(Debug, Clone, Copy)] +struct PredicateStep { + kind: PredicateStepKind, + value_id: u32, +} + +impl TryFrom<PredicateStep> for PredicateArg { + type Error = (); + + fn try_from(step: PredicateStep) -> Result<Self, Self::Error> { + match step.kind { + PredicateStepKind::String => Ok(PredicateArg::String(QueryStr(step.value_id))), + PredicateStepKind::Capture => Ok(PredicateArg::Capture(Capture(step.value_id))), + PredicateStepKind::Done => Err(()), + } + } +} + +extern "C" { + /// Get all of the predicates for the given pattern in the query. The + /// predicates are represented as a single array of steps. There are three + /// types of steps in this array, which correspond to the three legal values + /// for the `type` field: - `TSQueryPredicateStepTypeCapture` - Steps with + /// this type represent names of captures. Their `value_id` can be used + /// with the [`ts_query_capture_name_for_id`] function to obtain the name + /// of the capture. - `TSQueryPredicateStepTypeString` - Steps with this + /// type represent literal strings. Their `value_id` can be used with the + /// [`ts_query_string_value_for_id`] function to obtain their string value. + /// - `TSQueryPredicateStepTypeDone` - Steps with this type are *sentinels* + /// that represent the end of an individual predicate. If a pattern has two + /// predicates, then there will be two steps with this `type` in the array. + fn ts_query_predicates_for_pattern( + query: NonNull<QueryData>, + pattern_index: u32, + step_count: &mut u32, + ) -> *const PredicateStep; + +} |