Unnamed repository; edit this file 'description' to name the repository.
Diffstat (limited to 'helix-syntax/src/tree_sitter/query/predicate.rs')
| -rw-r--r-- | helix-syntax/src/tree_sitter/query/predicate.rs | 137 |
1 files changed, 129 insertions, 8 deletions
diff --git a/helix-syntax/src/tree_sitter/query/predicate.rs b/helix-syntax/src/tree_sitter/query/predicate.rs index 8fac6cf7..7a2f858e 100644 --- a/helix-syntax/src/tree_sitter/query/predicate.rs +++ b/helix-syntax/src/tree_sitter/query/predicate.rs @@ -1,11 +1,16 @@ use std::error::Error; +use std::iter::zip; +use std::ops::Range; use std::ptr::NonNull; use std::{fmt, slice}; use crate::tree_sitter::query::property::QueryProperty; -use crate::tree_sitter::query::{Capture, Pattern, Query, QueryData, QueryStr}; +use crate::tree_sitter::query::{Capture, Pattern, PatternData, Query, QueryData, QueryStr}; +use crate::tree_sitter::query_cursor::MatchedNode; +use crate::tree_sitter::TsInput; use regex_cursor::engines::meta::Regex; +use regex_cursor::Cursor; macro_rules! bail { ($($args:tt)*) => {{ @@ -29,25 +34,141 @@ pub(super) enum TextPredicateKind { AnyString(Box<[QueryStr]>), } -pub(super) struct TextPredicate { +pub(crate) struct TextPredicate { capture: Capture, kind: TextPredicateKind, negated: bool, match_all: bool, } +fn input_matches_str<I: TsInput>(str: &str, range: Range<usize>, input: &mut I) -> bool { + if str.len() != range.len() { + return false; + } + let mut str = str.as_bytes(); + let cursor = input.cursor_at(range.start); + let start_in_chunk = range.start - cursor.offset(); + if range.end - cursor.offset() <= cursor.chunk().len() { + // hotpath + return &cursor.chunk()[start_in_chunk..range.end - cursor.offset()] == str; + } + if cursor.chunk()[start_in_chunk..] != str[..cursor.chunk().len() - start_in_chunk] { + return false; + } + str = &str[..cursor.chunk().len() - start_in_chunk]; + while cursor.advance() { + if str.len() <= cursor.chunk().len() { + return &cursor.chunk()[..range.end - cursor.offset()] == str; + } + if &str[..cursor.chunk().len()] != cursor.chunk() { + return false; + } + str = &str[cursor.chunk().len()..] + } + // buggy cursor/invalid range + false +} + +fn inputs_match<I: TsInput>(str: &str, range: Range<usize>, input: &mut I) -> bool { + if str.len() != range.len() { + return false; + } + let mut str = str.as_bytes(); + let cursor = input.cursor_at(range.start); + let start_in_chunk = range.start - cursor.offset(); + if range.end - cursor.offset() <= cursor.chunk().len() { + // hotpath + return &cursor.chunk()[start_in_chunk..range.end - cursor.offset()] == str; + } + if cursor.chunk()[start_in_chunk..] != str[..cursor.chunk().len() - start_in_chunk] { + return false; + } + str = &str[..cursor.chunk().len() - start_in_chunk]; + while cursor.advance() { + if str.len() <= cursor.chunk().len() { + return &cursor.chunk()[..range.end - cursor.offset()] == str; + } + if &str[..cursor.chunk().len()] != cursor.chunk() { + return false; + } + str = &str[cursor.chunk().len()..] + } + // buggy cursor/invalid range + false +} + +impl TextPredicate { + /// handlers match_all and negated + fn satisfied_helper(&self, mut nodes: impl Iterator<Item = bool>) -> bool { + if self.match_all { + nodes.all(|matched| matched != self.negated) + } else { + nodes.any(|matched| matched != self.negated) + } + } + + pub fn satsified<I: TsInput>( + &self, + input: &mut I, + matched_nodes: &[MatchedNode], + query: &Query, + ) -> bool { + let mut capture_nodes = matched_nodes + .iter() + .filter(|matched_node| matched_node.capture == self.capture); + match self.kind { + TextPredicateKind::EqString(str) => self.satisfied_helper(capture_nodes.map(|node| { + let range = node.syntax_node.byte_range(); + input_matches_str(query.get_string(str), range.clone(), input) + })), + TextPredicateKind::EqCapture(other_capture) => { + let mut other_nodes = matched_nodes + .iter() + .filter(|matched_node| matched_node.capture == other_capture); + + let res = self.satisfied_helper(zip(&mut capture_nodes, &mut other_nodes).map( + |(node1, node2)| { + let range1 = node1.syntax_node.byte_range(); + let range2 = node2.syntax_node.byte_range(); + input.eq(range1, range2) + }, + )); + let consumed_all = capture_nodes.next().is_none() && other_nodes.next().is_none(); + res && (!self.match_all || consumed_all) + } + TextPredicateKind::MatchString(ref regex) => { + self.satisfied_helper(capture_nodes.map(|node| { + let range = node.syntax_node.byte_range(); + let input = regex_cursor::Input::new(input.cursor_at(range.start)).range(range); + regex.is_match(input) + })) + } + TextPredicateKind::AnyString(ref strings) => { + let strings = strings.iter().map(|&str| query.get_string(str)); + self.satisfied_helper(capture_nodes.map(|node| { + let range = node.syntax_node.byte_range(); + strings + .clone() + .filter(|str| str.len() == range.len()) + .any(|str| input_matches_str(str, range.clone(), input)) + })) + } + } + } +} + impl Query { pub(super) fn parse_pattern_predicates( &mut self, - pattern_index: u32, - mut custom_predicate: impl FnMut(Predicate) -> Result<(), InvalidPredicateError>, - ) -> Result<Pattern, InvalidPredicateError> { + pattern: Pattern, + mut custom_predicate: impl FnMut(Pattern, Predicate) -> Result<(), InvalidPredicateError>, + ) -> Result<PatternData, InvalidPredicateError> { let text_predicate_start = self.text_predicates.len() as u32; let property_start = self.properties.len() as u32; let predicate_steps = unsafe { let mut len = 0u32; - let raw_predicates = ts_query_predicates_for_pattern(self.raw, pattern_index, &mut len); + let raw_predicates = ts_query_predicates_for_pattern(self.raw, pattern.0, &mut len); (len != 0) .then(|| slice::from_raw_parts(raw_predicates, len as usize)) .unwrap_or_default() @@ -118,10 +239,10 @@ impl Query { // is and is-not are better handeled as custom predicates since interpreting is context dependent // "is?" => property_predicates.push((QueryProperty::parse(&predicate), false)), // "is-not?" => property_predicates.push((QueryProperty::parse(&predicate), true)), - _ => custom_predicate(predicate)?, + _ => custom_predicate(pattern, predicate)?, } } - Ok(Pattern { + Ok(PatternData { text_predicates: text_predicate_start..self.text_predicates.len() as u32, properties: property_start..self.properties.len() as u32, }) |