Unnamed repository; edit this file 'description' to name the repository.
Diffstat (limited to 'helix-syntax/src/tree_sitter/query/predicate.rs')
-rw-r--r--helix-syntax/src/tree_sitter/query/predicate.rs137
1 files changed, 129 insertions, 8 deletions
diff --git a/helix-syntax/src/tree_sitter/query/predicate.rs b/helix-syntax/src/tree_sitter/query/predicate.rs
index 8fac6cf7..7a2f858e 100644
--- a/helix-syntax/src/tree_sitter/query/predicate.rs
+++ b/helix-syntax/src/tree_sitter/query/predicate.rs
@@ -1,11 +1,16 @@
use std::error::Error;
+use std::iter::zip;
+use std::ops::Range;
use std::ptr::NonNull;
use std::{fmt, slice};
use crate::tree_sitter::query::property::QueryProperty;
-use crate::tree_sitter::query::{Capture, Pattern, Query, QueryData, QueryStr};
+use crate::tree_sitter::query::{Capture, Pattern, PatternData, Query, QueryData, QueryStr};
+use crate::tree_sitter::query_cursor::MatchedNode;
+use crate::tree_sitter::TsInput;
use regex_cursor::engines::meta::Regex;
+use regex_cursor::Cursor;
macro_rules! bail {
($($args:tt)*) => {{
@@ -29,25 +34,141 @@ pub(super) enum TextPredicateKind {
AnyString(Box<[QueryStr]>),
}
-pub(super) struct TextPredicate {
+pub(crate) struct TextPredicate {
capture: Capture,
kind: TextPredicateKind,
negated: bool,
match_all: bool,
}
+fn input_matches_str<I: TsInput>(str: &str, range: Range<usize>, input: &mut I) -> bool {
+ if str.len() != range.len() {
+ return false;
+ }
+ let mut str = str.as_bytes();
+ let cursor = input.cursor_at(range.start);
+ let start_in_chunk = range.start - cursor.offset();
+ if range.end - cursor.offset() <= cursor.chunk().len() {
+ // hotpath
+ return &cursor.chunk()[start_in_chunk..range.end - cursor.offset()] == str;
+ }
+ if cursor.chunk()[start_in_chunk..] != str[..cursor.chunk().len() - start_in_chunk] {
+ return false;
+ }
+ str = &str[..cursor.chunk().len() - start_in_chunk];
+ while cursor.advance() {
+ if str.len() <= cursor.chunk().len() {
+ return &cursor.chunk()[..range.end - cursor.offset()] == str;
+ }
+ if &str[..cursor.chunk().len()] != cursor.chunk() {
+ return false;
+ }
+ str = &str[cursor.chunk().len()..]
+ }
+ // buggy cursor/invalid range
+ false
+}
+
+fn inputs_match<I: TsInput>(str: &str, range: Range<usize>, input: &mut I) -> bool {
+ if str.len() != range.len() {
+ return false;
+ }
+ let mut str = str.as_bytes();
+ let cursor = input.cursor_at(range.start);
+ let start_in_chunk = range.start - cursor.offset();
+ if range.end - cursor.offset() <= cursor.chunk().len() {
+ // hotpath
+ return &cursor.chunk()[start_in_chunk..range.end - cursor.offset()] == str;
+ }
+ if cursor.chunk()[start_in_chunk..] != str[..cursor.chunk().len() - start_in_chunk] {
+ return false;
+ }
+ str = &str[..cursor.chunk().len() - start_in_chunk];
+ while cursor.advance() {
+ if str.len() <= cursor.chunk().len() {
+ return &cursor.chunk()[..range.end - cursor.offset()] == str;
+ }
+ if &str[..cursor.chunk().len()] != cursor.chunk() {
+ return false;
+ }
+ str = &str[cursor.chunk().len()..]
+ }
+ // buggy cursor/invalid range
+ false
+}
+
+impl TextPredicate {
+ /// handlers match_all and negated
+ fn satisfied_helper(&self, mut nodes: impl Iterator<Item = bool>) -> bool {
+ if self.match_all {
+ nodes.all(|matched| matched != self.negated)
+ } else {
+ nodes.any(|matched| matched != self.negated)
+ }
+ }
+
+ pub fn satsified<I: TsInput>(
+ &self,
+ input: &mut I,
+ matched_nodes: &[MatchedNode],
+ query: &Query,
+ ) -> bool {
+ let mut capture_nodes = matched_nodes
+ .iter()
+ .filter(|matched_node| matched_node.capture == self.capture);
+ match self.kind {
+ TextPredicateKind::EqString(str) => self.satisfied_helper(capture_nodes.map(|node| {
+ let range = node.syntax_node.byte_range();
+ input_matches_str(query.get_string(str), range.clone(), input)
+ })),
+ TextPredicateKind::EqCapture(other_capture) => {
+ let mut other_nodes = matched_nodes
+ .iter()
+ .filter(|matched_node| matched_node.capture == other_capture);
+
+ let res = self.satisfied_helper(zip(&mut capture_nodes, &mut other_nodes).map(
+ |(node1, node2)| {
+ let range1 = node1.syntax_node.byte_range();
+ let range2 = node2.syntax_node.byte_range();
+ input.eq(range1, range2)
+ },
+ ));
+ let consumed_all = capture_nodes.next().is_none() && other_nodes.next().is_none();
+ res && (!self.match_all || consumed_all)
+ }
+ TextPredicateKind::MatchString(ref regex) => {
+ self.satisfied_helper(capture_nodes.map(|node| {
+ let range = node.syntax_node.byte_range();
+ let input = regex_cursor::Input::new(input.cursor_at(range.start)).range(range);
+ regex.is_match(input)
+ }))
+ }
+ TextPredicateKind::AnyString(ref strings) => {
+ let strings = strings.iter().map(|&str| query.get_string(str));
+ self.satisfied_helper(capture_nodes.map(|node| {
+ let range = node.syntax_node.byte_range();
+ strings
+ .clone()
+ .filter(|str| str.len() == range.len())
+ .any(|str| input_matches_str(str, range.clone(), input))
+ }))
+ }
+ }
+ }
+}
+
impl Query {
pub(super) fn parse_pattern_predicates(
&mut self,
- pattern_index: u32,
- mut custom_predicate: impl FnMut(Predicate) -> Result<(), InvalidPredicateError>,
- ) -> Result<Pattern, InvalidPredicateError> {
+ pattern: Pattern,
+ mut custom_predicate: impl FnMut(Pattern, Predicate) -> Result<(), InvalidPredicateError>,
+ ) -> Result<PatternData, InvalidPredicateError> {
let text_predicate_start = self.text_predicates.len() as u32;
let property_start = self.properties.len() as u32;
let predicate_steps = unsafe {
let mut len = 0u32;
- let raw_predicates = ts_query_predicates_for_pattern(self.raw, pattern_index, &mut len);
+ let raw_predicates = ts_query_predicates_for_pattern(self.raw, pattern.0, &mut len);
(len != 0)
.then(|| slice::from_raw_parts(raw_predicates, len as usize))
.unwrap_or_default()
@@ -118,10 +239,10 @@ impl Query {
// is and is-not are better handeled as custom predicates since interpreting is context dependent
// "is?" => property_predicates.push((QueryProperty::parse(&predicate), false)),
// "is-not?" => property_predicates.push((QueryProperty::parse(&predicate), true)),
- _ => custom_predicate(predicate)?,
+ _ => custom_predicate(pattern, predicate)?,
}
}
- Ok(Pattern {
+ Ok(PatternData {
text_predicates: text_predicate_start..self.text_predicates.len() as u32,
properties: property_start..self.properties.len() as u32,
})