Unnamed repository; edit this file 'description' to name the repository.
Diffstat (limited to 'helix-syntax/src/tree_sitter/query.rs')
| -rw-r--r-- | helix-syntax/src/tree_sitter/query.rs | 574 |
1 files changed, 574 insertions, 0 deletions
diff --git a/helix-syntax/src/tree_sitter/query.rs b/helix-syntax/src/tree_sitter/query.rs new file mode 100644 index 00000000..44a7fa3c --- /dev/null +++ b/helix-syntax/src/tree_sitter/query.rs @@ -0,0 +1,574 @@ +use std::fmt::Display; +use std::iter::zip; +use std::path::{Path, PathBuf}; +use std::ptr::NonNull; +use std::{slice, str}; + +use regex_cursor::engines::meta::Regex; + +use crate::tree_sitter::Grammar; + +macro_rules! bail { + ($($args:tt)*) => {{ + return Err(format!($($args)*)) + }} +} + +macro_rules! ensure { + ($cond: expr, $($args:tt)*) => {{ + if !$cond { + return Err(format!($($args)*)) + } + }} +} + +#[derive(Debug)] +enum TextPredicateCaptureKind { + EqString(u32), + EqCapture(u32), + MatchString(Regex), + AnyString(Box<[Box<str>]>), +} + +struct TextPredicateCapture { + capture_idx: u32, + kind: TextPredicateCaptureKind, + negated: bool, + match_all: bool, +} + +pub enum QueryData {} +pub struct Query { + raw: NonNull<QueryData>, + num_captures: u32, +} + +impl Query { + /// Create a new query from a string containing one or more S-expression + /// patterns. + /// + /// The query is associated with a particular grammar, and can only be run + /// on syntax nodes parsed with that grammar. References to Queries can be + /// shared between multiple threads. + pub fn new(grammar: Grammar, source: &str, path: impl AsRef<Path>) -> Result<Self, ParseError> { + assert!( + source.len() <= i32::MAX as usize, + "TreeSitter queries must be smaller then 2 GiB (is {})", + source.len() as f64 / 1024.0 / 1024.0 / 1024.0 + ); + let mut error_offset = 0u32; + let mut error_kind = RawQueryError::None; + let bytes = source.as_bytes(); + + // Compile the query. + let ptr = unsafe { + ts_query_new( + grammar, + bytes.as_ptr(), + bytes.len() as u32, + &mut error_offset, + &mut error_kind, + ) + }; + + let Some(raw) = ptr else { + let offset = error_offset as usize; + let error_word = || { + source[offset..] + .chars() + .take_while(|&c| c.is_alphanumeric() || matches!(c, '_' | '-')) + .collect() + }; + let err = match error_kind { + RawQueryError::NodeType => { + let node: String = error_word(); + ParseError::InvalidNodeType { + location: ParserErrorLocation::new( + source, + path.as_ref(), + offset, + node.chars().count(), + ), + node, + } + } + RawQueryError::Field => { + let field = error_word(); + ParseError::InvalidFieldName { + location: ParserErrorLocation::new( + source, + path.as_ref(), + offset, + field.chars().count(), + ), + field, + } + } + RawQueryError::Capture => { + let capture = error_word(); + ParseError::InvalidCaptureName { + location: ParserErrorLocation::new( + source, + path.as_ref(), + offset, + capture.chars().count(), + ), + capture, + } + } + RawQueryError::Syntax => ParseError::SyntaxError(ParserErrorLocation::new( + source, + path.as_ref(), + offset, + 0, + )), + RawQueryError::Structure => ParseError::ImpossiblePattern( + ParserErrorLocation::new(source, path.as_ref(), offset, 0), + ), + RawQueryError::None => { + unreachable!("tree-sitter returned a null pointer but did not set an error") + } + RawQueryError::Language => unreachable!("should be handled at grammar load"), + }; + return Err(err) + }; + + // I am not going to bother with safety comments here, all of these are + // safe as long as TS is not buggy because raw is a properly constructed query + let num_captures = unsafe { ts_query_capture_count(raw) }; + + Ok(Query { raw, num_captures }) + } + + fn parse_predicates(&mut self) { + let pattern_count = unsafe { ts_query_pattern_count(self.raw) }; + + let mut text_predicates = Vec::with_capacity(pattern_count as usize); + let mut property_predicates = Vec::with_capacity(pattern_count as usize); + let mut property_settings = Vec::with_capacity(pattern_count as usize); + let mut general_predicates = Vec::with_capacity(pattern_count as usize); + + for i in 0..pattern_count {} + } + + fn parse_predicate(&self, pattern_index: u32) -> Result<(), String> { + let mut text_predicates = Vec::new(); + let mut property_predicates = Vec::new(); + let mut property_settings = Vec::new(); + let mut general_predicates = Vec::new(); + for predicate in self.predicates(pattern_index) { + let predicate = unsafe { Predicate::new(self, predicate)? }; + + // Build a predicate for each of the known predicate function names. + match predicate.operator_name { + "eq?" | "not-eq?" | "any-eq?" | "any-not-eq?" => { + predicate.check_arg_count(2)?; + let capture_idx = predicate.get_arg(0, PredicateArg::Capture)?; + let (arg2, arg2_kind) = predicate.get_any_arg(1); + + let negated = matches!(predicate.operator_name, "not-eq?" | "not-any-eq?"); + let match_all = matches!(predicate.operator_name, "eq?" | "not-eq?"); + let kind = match arg2_kind { + PredicateArg::Capture => TextPredicateCaptureKind::EqCapture(arg2), + PredicateArg::String => TextPredicateCaptureKind::EqString(arg2), + }; + text_predicates.push(TextPredicateCapture { + capture_idx, + kind, + negated, + match_all, + }); + } + + "match?" | "not-match?" | "any-match?" | "any-not-match?" => { + predicate.check_arg_count(2)?; + let capture_idx = predicate.get_arg(0, PredicateArg::Capture)?; + let regex = predicate.get_str_arg(1)?; + + let negated = + matches!(predicate.operator_name, "not-match?" | "any-not-match?"); + let match_all = matches!(predicate.operator_name, "match?" | "not-match?"); + let regex = match Regex::new(regex) { + Ok(regex) => regex, + Err(err) => bail!("invalid regex '{regex}', {err}"), + }; + text_predicates.push(TextPredicateCapture { + capture_idx, + kind: TextPredicateCaptureKind::MatchString(regex), + negated, + match_all, + }); + } + + "set!" => property_settings.push(Self::parse_property( + row, + operator_name, + &capture_names, + &string_values, + &p[1..], + )?), + + "is?" | "is-not?" => property_predicates.push(( + Self::parse_property( + row, + operator_name, + &capture_names, + &string_values, + &p[1..], + )?, + operator_name == "is?", + )), + + "any-of?" | "not-any-of?" => { + if p.len() < 2 { + return Err(predicate_error(row, format!( + "Wrong number of arguments to #any-of? predicate. Expected at least 1, got {}.", + p.len() - 1 + ))); + } + if p[1].type_ != TYPE_CAPTURE { + return Err(predicate_error(row, format!( + "First argument to #any-of? predicate must be a capture name. Got literal \"{}\".", + string_values[p[1].value_id as usize], + ))); + } + + let is_positive = operator_name == "any-of?"; + let mut values = Vec::new(); + for arg in &p[2..] { + if arg.type_ == TYPE_CAPTURE { + return Err(predicate_error(row, format!( + "Arguments to #any-of? predicate must be literals. Got capture @{}.", + capture_names[arg.value_id as usize], + ))); + } + values.push(string_values[arg.value_id as usize]); + } + text_predicates.push(TextPredicateCapture::AnyString( + p[1].value_id, + values + .iter() + .map(|x| (*x).to_string().into()) + .collect::<Vec<_>>() + .into(), + is_positive, + )); + } + + _ => general_predicates.push(QueryPredicate { + operator: operator_name.to_string().into(), + args: p[1..] + .iter() + .map(|a| { + if a.type_ == TYPE_CAPTURE { + QueryPredicateArg::Capture(a.value_id) + } else { + QueryPredicateArg::String( + string_values[a.value_id as usize].to_string().into(), + ) + } + }) + .collect(), + }), + } + } + + text_predicates_vec.push(text_predicates.into()); + property_predicates_vec.push(property_predicates.into()); + property_settings_vec.push(property_settings.into()); + general_predicates_vec.push(general_predicates.into()); + } + + fn predicates<'a>( + &'a self, + pattern_index: u32, + ) -> impl Iterator<Item = &'a [PredicateStep]> + 'a { + let predicate_steps = unsafe { + let mut len = 0u32; + let raw_predicates = ts_query_predicates_for_pattern(self.raw, pattern_index, &mut len); + (len != 0) + .then(|| slice::from_raw_parts(raw_predicates, len as usize)) + .unwrap_or_default() + }; + predicate_steps + .split(|step| step.kind == PredicateStepKind::Done) + .filter(|predicate| !predicate.is_empty()) + } + + /// Safety: value_idx must be a valid string id (in bounds) for this query and pattern_index + unsafe fn get_pattern_string(&self, value_id: u32) -> &str { + unsafe { + let mut len = 0; + let ptr = ts_query_string_value_for_id(self.raw, value_id, &mut len); + let data = slice::from_raw_parts(ptr, len as usize); + // safety: we only allow passing valid str(ings) as arguments to query::new + // name is always a substring of that. Treesitter does proper utf8 segmentation + // so any substrings it produces are codepoint aligned and therefore valid utf8 + str::from_utf8_unchecked(data) + } + } + + #[inline] + pub fn capture_name(&self, capture_idx: u32) -> &str { + // this one needs an assertions because the ts c api is inconsisent + // and unsafe, other functions do have checks and would return null + assert!(capture_idx <= self.num_captures, "invalid capture index"); + let mut length = 0; + unsafe { + let ptr = ts_query_capture_name_for_id(self.raw, capture_idx, &mut length); + let name = slice::from_raw_parts(ptr, length as usize); + // safety: we only allow passing valid str(ings) as arguments to query::new + // name is always a substring of that. Treesitter does proper utf8 segmentation + // so any substrings it produces are codepoint aligned and therefore valid utf8 + str::from_utf8_unchecked(name) + } + } +} + +#[derive(Debug, PartialEq, Eq)] +pub struct ParserErrorLocation { + pub path: PathBuf, + /// at which line the error occured + pub line: usize, + /// at which codepoints/columns the errors starts in the line + pub column: usize, + /// how many codepoints/columns the error takes up + pub len: usize, + line_content: String, +} + +impl ParserErrorLocation { + pub fn new(source: &str, path: &Path, offset: usize, len: usize) -> ParserErrorLocation { + let (line, line_content) = source[..offset] + .split('\n') + .map(|line| line.strip_suffix('\r').unwrap_or(line)) + .enumerate() + .last() + .unwrap_or((0, "")); + let column = line_content.chars().count(); + ParserErrorLocation { + path: path.to_owned(), + line, + column, + len, + line_content: line_content.to_owned(), + } + } +} + +impl Display for ParserErrorLocation { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + writeln!( + f, + " --> {}:{}:{}", + self.path.display(), + self.line, + self.column + )?; + let line = self.line.to_string(); + let prefix = format_args!(" {:width$} |", "", width = line.len()); + writeln!(f, "{prefix}"); + writeln!(f, " {line} | {}", self.line_content)?; + writeln!( + f, + "{prefix}{:width$}{:^<len$}", + "", + "^", + width = self.column, + len = self.len + )?; + writeln!(f, "{prefix}") + } +} + +#[derive(thiserror::Error, Debug, PartialEq, Eq)] +pub enum ParseError { + #[error("unexpected EOF")] + UnexpectedEof, + #[error("invalid query syntax\n{0}")] + SyntaxError(ParserErrorLocation), + #[error("invalid node type {node:?}\n{location}")] + InvalidNodeType { + node: String, + location: ParserErrorLocation, + }, + #[error("invalid field name {field:?}\n{location}")] + InvalidFieldName { + field: String, + location: ParserErrorLocation, + }, + #[error("invalid capture name {capture:?}\n{location}")] + InvalidCaptureName { + capture: String, + location: ParserErrorLocation, + }, + #[error("{message}\n{location}")] + InvalidPredicate { + message: String, + location: ParserErrorLocation, + }, + #[error("invalid predicate\n{0}")] + ImpossiblePattern(ParserErrorLocation), +} + +#[repr(C)] +enum RawQueryError { + None = 0, + Syntax = 1, + NodeType = 2, + Field = 3, + Capture = 4, + Structure = 5, + Language = 6, +} + +#[repr(C)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum PredicateStepKind { + Done = 0, + Capture = 1, + String = 2, +} + +#[repr(C)] +struct PredicateStep { + kind: PredicateStepKind, + value_id: u32, +} + +struct Predicate<'a> { + operator_name: &'a str, + args: &'a [PredicateStep], + query: &'a Query, +} + +impl<'a> Predicate<'a> { + unsafe fn new( + query: &'a Query, + predicate: &'a [PredicateStep], + ) -> Result<Predicate<'a>, String> { + ensure!( + predicate[0].kind == PredicateStepKind::String, + "expected predicate to start with a function name. Got @{}.", + query.capture_name(predicate[0].value_id) + ); + let operator_name = query.get_pattern_string(predicate[0].value_id); + Ok(Predicate { + operator_name, + args: &predicate[1..], + query, + }) + } + pub fn check_arg_count(&self, n: usize) -> Result<(), String> { + ensure!( + self.args.len() == n, + "expected {n} arguments for #{}, got {}", + self.operator_name, + self.args.len() + ); + Ok(()) + } + + pub fn get_arg(&self, i: usize, expect: PredicateArg) -> Result<u32, String> { + let (val, actual) = self.get_any_arg(i); + match (actual, expect) { + (PredicateArg::Capture, PredicateArg::String) => bail!( + "{i}. argument to #{} expected a capture, got literal {val:?}", + self.operator_name + ), + (PredicateArg::String, PredicateArg::Capture) => bail!( + "{i}. argument to #{} must be a literal, got capture @{val:?}", + self.operator_name + ), + _ => (), + }; + Ok(val) + } + pub fn get_str_arg(&self, i: usize) -> Result<&'a str, String> { + let arg = self.get_arg(i, PredicateArg::String)?; + unsafe { Ok(self.query.get_pattern_string(arg)) } + } + + pub fn get_any_arg(&self, i: usize) -> (u32, PredicateArg) { + match self.args[i].kind { + PredicateStepKind::String => unsafe { (self.args[i].value_id, PredicateArg::String) }, + PredicateStepKind::Capture => (self.args[i].value_id, PredicateArg::Capture), + PredicateStepKind::Done => unreachable!(), + } + } +} + +enum PredicateArg { + Capture, + String, +} + +extern "C" { + /// Create a new query from a string containing one or more S-expression + /// patterns. The query is associated with a particular language, and can + /// only be run on syntax nodes parsed with that language. If all of the + /// given patterns are valid, this returns a [`TSQuery`]. If a pattern is + /// invalid, this returns `NULL`, and provides two pieces of information + /// about the problem: 1. The byte offset of the error is written to + /// the `error_offset` parameter. 2. The type of error is written to the + /// `error_type` parameter. + pub fn ts_query_new( + grammar: Grammar, + source: *const u8, + source_len: u32, + error_offset: &mut u32, + error_type: &mut RawQueryError, + ) -> Option<NonNull<QueryData>>; + + /// Delete a query, freeing all of the memory that it used. + pub fn ts_query_delete(query: NonNull<QueryData>); + + /// Get the number of patterns, captures, or string literals in the query. + pub fn ts_query_pattern_count(query: NonNull<QueryData>) -> u32; + pub fn ts_query_capture_count(query: NonNull<QueryData>) -> u32; + pub fn ts_query_string_count(query: NonNull<QueryData>) -> u32; + + /// Get the byte offset where the given pattern starts in the query's + /// source. This can be useful when combining queries by concatenating their + /// source code strings. + pub fn ts_query_start_byte_for_pattern(query: NonNull<QueryData>, pattern_index: u32) -> u32; + + /// Get all of the predicates for the given pattern in the query. The + /// predicates are represented as a single array of steps. There are three + /// types of steps in this array, which correspond to the three legal values + /// for the `type` field: - `TSQueryPredicateStepTypeCapture` - Steps with + /// this type represent names of captures. Their `value_id` can be used + /// with the [`ts_query_capture_name_for_id`] function to obtain the name + /// of the capture. - `TSQueryPredicateStepTypeString` - Steps with this + /// type represent literal strings. Their `value_id` can be used with the + /// [`ts_query_string_value_for_id`] function to obtain their string value. + /// - `TSQueryPredicateStepTypeDone` - Steps with this type are *sentinels* + /// that represent the end of an individual predicate. If a pattern has two + /// predicates, then there will be two steps with this `type` in the array. + pub fn ts_query_predicates_for_pattern( + query: NonNull<QueryData>, + pattern_index: u32, + step_count: &mut u32, + ) -> *const PredicateStep; + + pub fn ts_query_is_pattern_rooted(query: NonNull<QueryData>, pattern_index: u32) -> bool; + pub fn ts_query_is_pattern_non_local(query: NonNull<QueryData>, pattern_index: u32) -> bool; + pub fn ts_query_is_pattern_guaranteed_at_step( + query: NonNull<QueryData>, + byte_offset: u32, + ) -> bool; + /// Get the name and length of one of the query's captures, or one of the + /// query's string literals. Each capture and string is associated with a + /// numeric id based on the order that it appeared in the query's source. + pub fn ts_query_capture_name_for_id( + query: NonNull<QueryData>, + index: u32, + length: &mut u32, + ) -> *const u8; + + pub fn ts_query_string_value_for_id( + self_: NonNull<QueryData>, + index: u32, + length: &mut u32, + ) -> *const u8; +} |