use std::fmt::{self, Display};
use std::ops::Range;
use std::path::{Path, PathBuf};
use std::ptr::NonNull;
use std::{slice, str};

use crate::tree_sitter::query::predicate::{InvalidPredicateError, Predicate, TextPredicate};
use crate::tree_sitter::Grammar;

mod predicate;
mod property;

/// A predicate handed to the `custom_predicate` callback of [`Query::new`].
pub enum UserPredicate<'a> {
    /// A property test; displayed as `is?`, or `is-not?` when `negate` is set.
    IsPropertySet {
        negate: bool,
        key: &'a str,
        val: Option<&'a str>,
    },
    /// A property assignment; displayed as `set!`.
    SetProperty {
        key: &'a str,
        val: Option<&'a str>,
    },
    /// Any other predicate, passed through unmodified.
    Other(Predicate<'a>),
}

impl Display for UserPredicate<'_> {
    /// Renders the predicate in S-expression form. A missing value is written
    /// as an empty string; `Other` predicates are rendered by name only.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match *self {
            UserPredicate::IsPropertySet { negate, key, val } => {
                let predicate = if negate { "is-not?" } else { "is?" };
                write!(f, " ({predicate} {key} {})", val.unwrap_or(""))
            }
            UserPredicate::SetProperty { key, val } => {
                write!(f, "(set! {key} {})", val.unwrap_or(""))
            }
            UserPredicate::Other(ref predicate) => {
                write!(f, "{}", predicate.name())
            }
        }
    }
}

/// Index of a pattern within a [`Query`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct Pattern(pub(crate) u32);

impl Pattern {
    /// Reserved value (`u32::MAX`).
    // NOTE(review): presumably marks "no pattern"; confirm against callers.
    pub const SENTINEL: Pattern = Pattern(u32::MAX);

    /// The pattern index widened to `usize`, suitable for slice indexing.
    pub fn idx(&self) -> usize {
        self.0 as usize
    }
}

/// Uninhabited marker type used as the pointee of the raw query pointer.
pub enum QueryData {}

/// Per-pattern bookkeeping stored in [`Query`].
#[derive(Debug)]
pub(super) struct PatternData {
    /// Sub-range of `Query::text_predicates` that belongs to this pattern.
    // NOTE(review): the element type `u32` is inferred from the `as usize`
    // casts at the use site; confirm against the `predicate` module.
    text_predicates: Range<u32>,
}

/// A compiled tree-sitter query together with its parsed predicates.
#[derive(Debug)]
pub struct Query {
    /// Raw query handle; released with `ts_query_delete` when dropped.
    pub(crate) raw: NonNull<QueryData>,
    /// Cached result of `ts_query_capture_count`.
    num_captures: u32,
    /// Cached result of `ts_query_string_count`.
    num_strings: u32,
    /// Text predicates of all patterns, indexed via `PatternData::text_predicates`.
    text_predicates: Vec<TextPredicate>,
    /// One entry per pattern in the query.
    patterns: Box<[PatternData]>,
}

impl Query {
/// Create a new query from a string containing one or more S-expression
/// patterns.
///
/// The query is associated with a particular grammar, and can only be run
/// on syntax nodes parsed with that grammar. References to Queries can be
/// shared between multiple threads.
pub fn new( grammar: Grammar, source: &str, path: impl AsRef, mut custom_predicate: impl FnMut(Pattern, UserPredicate) -> Result<(), InvalidPredicateError>, ) -> Result { assert!( source.len() <= i32::MAX as usize, "TreeSitter queries must be smaller then 2 GiB (is {})", source.len() as f64 / 1024.0 / 1024.0 / 1024.0 ); let mut error_offset = 0u32; let mut error_kind = RawQueryError::None; let bytes = source.as_bytes(); // Compile the query. let ptr = unsafe { ts_query_new( grammar, bytes.as_ptr(), bytes.len() as u32, &mut error_offset, &mut error_kind, ) }; let Some(raw) = ptr else { let offset = error_offset as usize; let error_word = || { source[offset..] .chars() .take_while(|&c| c.is_alphanumeric() || matches!(c, '_' | '-')) .collect() }; let err = match error_kind { RawQueryError::NodeType => { let node: String = error_word(); ParseError::InvalidNodeType { location: ParserErrorLocation::new( source, path.as_ref(), offset, node.chars().count(), ), node, } } RawQueryError::Field => { let field = error_word(); ParseError::InvalidFieldName { location: ParserErrorLocation::new( source, path.as_ref(), offset, field.chars().count(), ), field, } } RawQueryError::Capture => { let capture = error_word(); ParseError::InvalidCaptureName { location: ParserErrorLocation::new( source, path.as_ref(), offset, capture.chars().count(), ), capture, } } RawQueryError::Syntax => ParseError::SyntaxError(ParserErrorLocation::new( source, path.as_ref(), offset, 0, )), RawQueryError::Structure => ParseError::ImpossiblePattern( ParserErrorLocation::new(source, path.as_ref(), offset, 0), ), RawQueryError::None => { unreachable!("tree-sitter returned a null pointer but did not set an error") } RawQueryError::Language => unreachable!("should be handled at grammar load"), }; return Err(err); }; // I am not going to bother with safety comments here, all of these are // safe as long as TS is not buggy because raw is a properly constructed query let num_captures = unsafe { 
ts_query_capture_count(raw) }; let num_strings = unsafe { ts_query_string_count(raw) }; let num_patterns = unsafe { ts_query_pattern_count(raw) }; let mut query = Query { raw, num_captures, num_strings, text_predicates: Vec::new(), patterns: Box::default(), }; let patterns: Result<_, ParseError> = (0..num_patterns) .map(|pattern| { query .parse_pattern_predicates(Pattern(pattern), &mut custom_predicate) .map_err(|err| ParseError::InvalidPredicate { message: err.msg.into(), location: ParserErrorLocation::new( source, path.as_ref(), unsafe { ts_query_start_byte_for_pattern(query.raw, pattern) as usize }, 0, ), }) }) .collect(); query.patterns = patterns?; Ok(query) } #[inline] fn get_string(&self, str: QueryStr) -> &str { let value_id = str.0; // need an assertions because the ts c api does not do bounds check assert!(value_id <= self.num_captures, "invalid value index"); unsafe { let mut len = 0; let ptr = ts_query_string_value_for_id(self.raw, value_id, &mut len); let data = slice::from_raw_parts(ptr, len as usize); // safety: we only allow passing valid str(ings) as arguments to query::new // name is always a substring of that. Treesitter does proper utf8 segmentation // so any substrings it produces are codepoint aligned and therefore valid utf8 str::from_utf8_unchecked(data) } } #[inline] pub fn capture_name(&self, capture_idx: Capture) -> &str { let capture_idx = capture_idx.0; // need an assertions because the ts c api does not do bounds check assert!(capture_idx <= self.num_captures, "invalid capture index"); let mut length = 0; unsafe { let ptr = ts_query_capture_name_for_id(self.raw, capture_idx, &mut length); let name = slice::from_raw_parts(ptr, length as usize); // safety: we only allow passing valid str(ings) as arguments to query::new // name is always a substring of that. 
Treesitter does proper utf8 segmentation // so any substrings it produces are codepoint aligned and therefore valid utf8 str::from_utf8_unchecked(name) } } #[inline] pub fn captures(&self) -> impl ExactSizeIterator { (0..self.num_captures).map(|cap| (Capture(cap), self.capture_name(Capture(cap)))) } #[inline] pub fn num_captures(&self) -> u32 { self.num_captures } #[inline] pub fn get_capture(&self, capture_name: &str) -> Option { for capture in 0..self.num_captures { if capture_name == self.capture_name(Capture(capture)) { return Some(Capture(capture)); } } None } pub(crate) fn pattern_text_predicates(&self, pattern_idx: u16) -> &[TextPredicate] { let range = self.patterns[pattern_idx as usize].text_predicates.clone(); &self.text_predicates[range.start as usize..range.end as usize] } /// Get the byte offset where the given pattern starts in the query's /// source. #[doc(alias = "ts_query_start_byte_for_pattern")] #[must_use] pub fn start_byte_for_pattern(&self, pattern: Pattern) -> usize { assert!( pattern.0 < self.text_predicates.len() as u32, "Pattern index is {pattern_index} but the pattern count is {}", self.text_predicates.len(), ); unsafe { ts_query_start_byte_for_pattern(self.raw, pattern.0) as usize } } /// Get the number of patterns in the query. #[must_use] pub fn pattern_count(&self) -> usize { unsafe { ts_query_pattern_count(self.raw) as usize } } /// Get the number of patterns in the query. 
#[must_use] pub fn patterns(&self) -> impl ExactSizeIterator { (0..self.pattern_count() as u32).map(Pattern) } } impl Drop for Query { fn drop(&mut self) { unsafe { ts_query_delete(self.raw) } } } #[derive(Clone, Copy, Debug, PartialEq, Eq)] #[repr(transparent)] pub struct Capture(u32); impl Capture { pub fn name(self, query: &Query) -> &str { query.capture_name(self) } pub fn idx(self) -> usize { self.0 as usize } } /// A reference to a string stored in a query #[derive(Clone, Copy, Debug)] pub struct QueryStr(u32); impl QueryStr { pub fn get(self, query: &Query) -> &str { query.get_string(self) } } #[derive(Debug, PartialEq, Eq)] pub struct ParserErrorLocation { pub path: PathBuf, /// at which line the error occurred pub line: usize, /// at which codepoints/columns the error starts in the line pub column: usize, /// how many codepoints/columns the error takes up pub len: usize, line_content: String, } impl ParserErrorLocation { pub fn new(source: &str, path: &Path, offset: usize, len: usize) -> ParserErrorLocation { let (line, line_content) = source[..offset] .split('\n') .map(|line| line.strip_suffix('\r').unwrap_or(line)) .enumerate() .last() .unwrap_or((0, "")); let column = line_content.chars().count(); ParserErrorLocation { path: path.to_owned(), line, column, len, line_content: line_content.to_owned(), } } } impl Display for ParserErrorLocation { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { writeln!( f, " --> {}:{}:{}", self.path.display(), self.line, self.column )?; let line = self.line.to_string(); let prefix = format!(" {:width$} |", "", width = line.len()); writeln!(f, "{prefix}")?; writeln!(f, " {line} | {}", self.line_content)?; writeln!( f, "{prefix}{:width$}{:^ Option>; /// Delete a query, freeing all of the memory that it used. fn ts_query_delete(query: NonNull); /// Get the number of patterns, captures, or string literals in the query. 
fn ts_query_pattern_count(query: NonNull) -> u32; fn ts_query_capture_count(query: NonNull) -> u32; fn ts_query_string_count(query: NonNull) -> u32; /// Get the byte offset where the given pattern starts in the query's /// source. This can be useful when combining queries by concatenating their /// source code strings. fn ts_query_start_byte_for_pattern(query: NonNull, pattern_index: u32) -> u32; // fn ts_query_is_pattern_rooted(query: NonNull, pattern_index: u32) -> bool; // fn ts_query_is_pattern_non_local(query: NonNull, pattern_index: u32) -> bool; // fn ts_query_is_pattern_guaranteed_at_step(query: NonNull, byte_offset: u32) -> bool; /// Get the name and length of one of the query's captures, or one of the /// query's string literals. Each capture and string is associated with a /// numeric id based on the order that it appeared in the query's source. fn ts_query_capture_name_for_id( query: NonNull, index: u32, length: &mut u32, ) -> *const u8; fn ts_query_string_value_for_id( self_: NonNull, index: u32, length: &mut u32, ) -> *const u8; }