Unnamed repository; edit this file 'description' to name the repository.
Diffstat (limited to 'helix-syntax/src/tree_sitter/query.rs')
| -rw-r--r-- | helix-syntax/src/tree_sitter/query.rs | 451 |
1 files changed, 451 insertions, 0 deletions
diff --git a/helix-syntax/src/tree_sitter/query.rs b/helix-syntax/src/tree_sitter/query.rs new file mode 100644 index 00000000..69a39417 --- /dev/null +++ b/helix-syntax/src/tree_sitter/query.rs @@ -0,0 +1,451 @@ +use std::fmt::{self, Display}; +use std::ops::Range; +use std::path::{Path, PathBuf}; +use std::ptr::NonNull; +use std::{slice, str}; + +use crate::tree_sitter::query::predicate::{InvalidPredicateError, Predicate, TextPredicate}; +use crate::tree_sitter::Grammar; + +mod predicate; +mod property; + +pub enum UserPredicate<'a> { + IsPropertySet { + negate: bool, + key: &'a str, + val: Option<&'a str>, + }, + SetProperty { + key: &'a str, + val: Option<&'a str>, + }, + Other(Predicate<'a>), +} + +impl Display for UserPredicate<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match *self { + UserPredicate::IsPropertySet { negate, key, val } => { + let predicate = if negate { "is-not?" } else { "is?" }; + write!(f, " ({predicate} {key} {})", val.unwrap_or("")) + } + UserPredicate::SetProperty { key, val } => { + write!(f, "(set! {key} {})", val.unwrap_or("")) + } + UserPredicate::Other(ref predicate) => { + write!(f, "{}", predicate.name()) + } + } + } +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct Pattern(pub(crate) u32); + +impl Pattern { + pub const SENTINEL: Pattern = Pattern(u32::MAX); + pub fn idx(&self) -> usize { + self.0 as usize + } +} + +pub enum QueryData {} + +#[derive(Debug)] +pub(super) struct PatternData { + text_predicates: Range<u32>, +} + +#[derive(Debug)] +pub struct Query { + pub(crate) raw: NonNull<QueryData>, + num_captures: u32, + num_strings: u32, + text_predicates: Vec<TextPredicate>, + patterns: Box<[PatternData]>, +} + +impl Query { + /// Create a new query from a string containing one or more S-expression + /// patterns. + /// + /// The query is associated with a particular grammar, and can only be run + /// on syntax nodes parsed with that grammar. References to Queries can be + /// shared between multiple threads. + pub fn new( + grammar: Grammar, + source: &str, + path: impl AsRef<Path>, + mut custom_predicate: impl FnMut(Pattern, UserPredicate) -> Result<(), InvalidPredicateError>, + ) -> Result<Self, ParseError> { + assert!( + source.len() <= i32::MAX as usize, + "TreeSitter queries must be smaller then 2 GiB (is {})", + source.len() as f64 / 1024.0 / 1024.0 / 1024.0 + ); + let mut error_offset = 0u32; + let mut error_kind = RawQueryError::None; + let bytes = source.as_bytes(); + + // Compile the query. + let ptr = unsafe { + ts_query_new( + grammar, + bytes.as_ptr(), + bytes.len() as u32, + &mut error_offset, + &mut error_kind, + ) + }; + + let Some(raw) = ptr else { + let offset = error_offset as usize; + let error_word = || { + source[offset..] + .chars() + .take_while(|&c| c.is_alphanumeric() || matches!(c, '_' | '-')) + .collect() + }; + let err = match error_kind { + RawQueryError::NodeType => { + let node: String = error_word(); + ParseError::InvalidNodeType { + location: ParserErrorLocation::new( + source, + path.as_ref(), + offset, + node.chars().count(), + ), + node, + } + } + RawQueryError::Field => { + let field = error_word(); + ParseError::InvalidFieldName { + location: ParserErrorLocation::new( + source, + path.as_ref(), + offset, + field.chars().count(), + ), + field, + } + } + RawQueryError::Capture => { + let capture = error_word(); + ParseError::InvalidCaptureName { + location: ParserErrorLocation::new( + source, + path.as_ref(), + offset, + capture.chars().count(), + ), + capture, + } + } + RawQueryError::Syntax => ParseError::SyntaxError(ParserErrorLocation::new( + source, + path.as_ref(), + offset, + 0, + )), + RawQueryError::Structure => ParseError::ImpossiblePattern( + ParserErrorLocation::new(source, path.as_ref(), offset, 0), + ), + RawQueryError::None => { + unreachable!("tree-sitter returned a null pointer but did not set an error") + } + RawQueryError::Language => unreachable!("should be handled at grammar load"), + }; + return Err(err); + }; + + // I am not going to bother with safety comments here, all of these are + // safe as long as TS is not buggy because raw is a properly constructed query + let num_captures = unsafe { ts_query_capture_count(raw) }; + let num_strings = unsafe { ts_query_string_count(raw) }; + let num_patterns = unsafe { ts_query_pattern_count(raw) }; + + let mut query = Query { + raw, + num_captures, + num_strings, + text_predicates: Vec::new(), + patterns: Box::default(), + }; + let patterns: Result<_, ParseError> = (0..num_patterns) + .map(|pattern| { + query + .parse_pattern_predicates(Pattern(pattern), &mut custom_predicate) + .map_err(|err| ParseError::InvalidPredicate { + message: err.msg.into(), + location: ParserErrorLocation::new( + source, + path.as_ref(), + unsafe { ts_query_start_byte_for_pattern(query.raw, pattern) as usize }, + 0, + ), + }) + }) + .collect(); + query.patterns = patterns?; + Ok(query) + } + + #[inline] + fn get_string(&self, str: QueryStr) -> &str { + let value_id = str.0; + // need an assertions because the ts c api does not do bounds check + assert!(value_id <= self.num_captures, "invalid value index"); + unsafe { + let mut len = 0; + let ptr = ts_query_string_value_for_id(self.raw, value_id, &mut len); + let data = slice::from_raw_parts(ptr, len as usize); + // safety: we only allow passing valid str(ings) as arguments to query::new + // name is always a substring of that. Treesitter does proper utf8 segmentation + // so any substrings it produces are codepoint aligned and therefore valid utf8 + str::from_utf8_unchecked(data) + } + } + + #[inline] + pub fn capture_name(&self, capture_idx: Capture) -> &str { + let capture_idx = capture_idx.0; + // need an assertions because the ts c api does not do bounds check + assert!(capture_idx <= self.num_captures, "invalid capture index"); + let mut length = 0; + unsafe { + let ptr = ts_query_capture_name_for_id(self.raw, capture_idx, &mut length); + let name = slice::from_raw_parts(ptr, length as usize); + // safety: we only allow passing valid str(ings) as arguments to query::new + // name is always a substring of that. Treesitter does proper utf8 segmentation + // so any substrings it produces are codepoint aligned and therefore valid utf8 + str::from_utf8_unchecked(name) + } + } + + #[inline] + pub fn captures(&self) -> impl ExactSizeIterator<Item = (Capture, &str)> { + (0..self.num_captures).map(|cap| (Capture(cap), self.capture_name(Capture(cap)))) + } + + #[inline] + pub fn num_captures(&self) -> u32 { + self.num_captures + } + + #[inline] + pub fn get_capture(&self, capture_name: &str) -> Option<Capture> { + for capture in 0..self.num_captures { + if capture_name == self.capture_name(Capture(capture)) { + return Some(Capture(capture)); + } + } + None + } + + pub(crate) fn pattern_text_predicates(&self, pattern_idx: u16) -> &[TextPredicate] { + let range = self.patterns[pattern_idx as usize].text_predicates.clone(); + &self.text_predicates[range.start as usize..range.end as usize] + } + + /// Get the byte offset where the given pattern starts in the query's + /// source. + #[doc(alias = "ts_query_start_byte_for_pattern")] + #[must_use] + pub fn start_byte_for_pattern(&self, pattern: Pattern) -> usize { + assert!( + pattern.0 < self.text_predicates.len() as u32, + "Pattern index is {pattern_index} but the pattern count is {}", + self.text_predicates.len(), + ); + unsafe { ts_query_start_byte_for_pattern(self.raw, pattern.0) as usize } + } + + /// Get the number of patterns in the query. + #[must_use] + pub fn pattern_count(&self) -> usize { + unsafe { ts_query_pattern_count(self.raw) as usize } + } + /// Get the number of patterns in the query. + #[must_use] + pub fn patterns(&self) -> impl ExactSizeIterator<Item = Pattern> { + (0..self.pattern_count() as u32).map(Pattern) + } +} + +impl Drop for Query { + fn drop(&mut self) { + unsafe { ts_query_delete(self.raw) } + } +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +#[repr(transparent)] +pub struct Capture(u32); + +impl Capture { + pub fn name(self, query: &Query) -> &str { + query.capture_name(self) + } + pub fn idx(self) -> usize { + self.0 as usize + } +} + +/// A reference to a string stroed in a query +#[derive(Clone, Copy, Debug)] +pub struct QueryStr(u32); + +impl QueryStr { + pub fn get(self, query: &Query) -> &str { + query.get_string(self) + } +} + +#[derive(Debug, PartialEq, Eq)] +pub struct ParserErrorLocation { + pub path: PathBuf, + /// at which line the error occured + pub line: usize, + /// at which codepoints/columns the errors starts in the line + pub column: usize, + /// how many codepoints/columns the error takes up + pub len: usize, + line_content: String, +} + +impl ParserErrorLocation { + pub fn new(source: &str, path: &Path, offset: usize, len: usize) -> ParserErrorLocation { + let (line, line_content) = source[..offset] + .split('\n') + .map(|line| line.strip_suffix('\r').unwrap_or(line)) + .enumerate() + .last() + .unwrap_or((0, "")); + let column = line_content.chars().count(); + ParserErrorLocation { + path: path.to_owned(), + line, + column, + len, + line_content: line_content.to_owned(), + } + } +} + +impl Display for ParserErrorLocation { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + writeln!( + f, + " --> {}:{}:{}", + self.path.display(), + self.line, + self.column + )?; + let line = self.line.to_string(); + let prefix = format!(" {:width$} |", "", width = line.len()); + writeln!(f, "{prefix}")?; + writeln!(f, " {line} | {}", self.line_content)?; + writeln!( + f, + "{prefix}{:width$}{:^<len$}", + "", + "^", + width = self.column, + len = self.len + )?; + writeln!(f, "{prefix}") + } +} + +#[derive(thiserror::Error, Debug, PartialEq, Eq)] +pub enum ParseError { + #[error("unexpected EOF")] + UnexpectedEof, + #[error("invalid query syntax\n{0}")] + SyntaxError(ParserErrorLocation), + #[error("invalid node type {node:?}\n{location}")] + InvalidNodeType { + node: String, + location: ParserErrorLocation, + }, + #[error("invalid field name {field:?}\n{location}")] + InvalidFieldName { + field: String, + location: ParserErrorLocation, + }, + #[error("invalid capture name {capture:?}\n{location}")] + InvalidCaptureName { + capture: String, + location: ParserErrorLocation, + }, + #[error("{message}\n{location}")] + InvalidPredicate { + message: String, + location: ParserErrorLocation, + }, + #[error("invalid predicate\n{0}")] + ImpossiblePattern(ParserErrorLocation), +} + +#[repr(C)] +enum RawQueryError { + None = 0, + Syntax = 1, + NodeType = 2, + Field = 3, + Capture = 4, + Structure = 5, + Language = 6, +} + +extern "C" { + /// Create a new query from a string containing one or more S-expression + /// patterns. The query is associated with a particular language, and can + /// only be run on syntax nodes parsed with that language. If all of the + /// given patterns are valid, this returns a [`TSQuery`]. If a pattern is + /// invalid, this returns `NULL`, and provides two pieces of information + /// about the problem: 1. The byte offset of the error is written to + /// the `error_offset` parameter. 2. The type of error is written to the + /// `error_type` parameter. + fn ts_query_new( + grammar: Grammar, + source: *const u8, + source_len: u32, + error_offset: &mut u32, + error_type: &mut RawQueryError, + ) -> Option<NonNull<QueryData>>; + + /// Delete a query, freeing all of the memory that it used. + fn ts_query_delete(query: NonNull<QueryData>); + + /// Get the number of patterns, captures, or string literals in the query. + fn ts_query_pattern_count(query: NonNull<QueryData>) -> u32; + fn ts_query_capture_count(query: NonNull<QueryData>) -> u32; + fn ts_query_string_count(query: NonNull<QueryData>) -> u32; + + /// Get the byte offset where the given pattern starts in the query's + /// source. This can be useful when combining queries by concatenating their + /// source code strings. + fn ts_query_start_byte_for_pattern(query: NonNull<QueryData>, pattern_index: u32) -> u32; + + // fn ts_query_is_pattern_rooted(query: NonNull<QueryData>, pattern_index: u32) -> bool; + // fn ts_query_is_pattern_non_local(query: NonNull<QueryData>, pattern_index: u32) -> bool; + // fn ts_query_is_pattern_guaranteed_at_step(query: NonNull<QueryData>, byte_offset: u32) -> bool; + /// Get the name and length of one of the query's captures, or one of the + /// query's string literals. Each capture and string is associated with a + /// numeric id based on the order that it appeared in the query's source. + fn ts_query_capture_name_for_id( + query: NonNull<QueryData>, + index: u32, + length: &mut u32, + ) -> *const u8; + + fn ts_query_string_value_for_id( + self_: NonNull<QueryData>, + index: u32, + length: &mut u32, + ) -> *const u8; +} |