Unnamed repository; edit this file 'description' to name the repository.
Diffstat (limited to 'helix-syntax/src/tree_sitter/query.rs')
-rw-r--r--helix-syntax/src/tree_sitter/query.rs451
1 files changed, 451 insertions, 0 deletions
diff --git a/helix-syntax/src/tree_sitter/query.rs b/helix-syntax/src/tree_sitter/query.rs
new file mode 100644
index 00000000..69a39417
--- /dev/null
+++ b/helix-syntax/src/tree_sitter/query.rs
@@ -0,0 +1,451 @@
+use std::fmt::{self, Display};
+use std::ops::Range;
+use std::path::{Path, PathBuf};
+use std::ptr::NonNull;
+use std::{slice, str};
+
+use crate::tree_sitter::query::predicate::{InvalidPredicateError, Predicate, TextPredicate};
+use crate::tree_sitter::Grammar;
+
+mod predicate;
+mod property;
+
+/// A predicate that is handed to the `custom_predicate` callback of
+/// [`Query::new`].
+pub enum UserPredicate<'a> {
+ /// Property check; displayed as `is?`, or as `is-not?` when `negate` is set.
+ IsPropertySet {
+ negate: bool,
+ key: &'a str,
+ val: Option<&'a str>,
+ },
+ /// Property assignment; displayed as `set!`.
+ SetProperty {
+ key: &'a str,
+ val: Option<&'a str>,
+ },
+ /// Any predicate that is not one of the two property forms above.
+ Other(Predicate<'a>),
+}
+
+impl Display for UserPredicate<'_> {
+    /// Renders the predicate in query (S-expression) syntax.
+    ///
+    /// `IsPropertySet` becomes `(is? key val)` or `(is-not? key val)` and
+    /// `SetProperty` becomes `(set! key val)`. When the value is `None` it is
+    /// left out entirely instead of leaving a dangling space before the closing
+    /// parenthesis. Any other predicate is rendered as just its name.
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        match *self {
+            UserPredicate::IsPropertySet { negate, key, val } => {
+                let predicate = if negate { "is-not?" } else { "is?" };
+                // No leading space: the other variants do not emit one either.
+                match val {
+                    Some(val) => write!(f, "({predicate} {key} {val})"),
+                    None => write!(f, "({predicate} {key})"),
+                }
+            }
+            UserPredicate::SetProperty { key, val } => match val {
+                Some(val) => write!(f, "(set! {key} {val})"),
+                None => write!(f, "(set! {key})"),
+            },
+            UserPredicate::Other(ref predicate) => {
+                write!(f, "{}", predicate.name())
+            }
+        }
+    }
+}
+
+/// Index of a single pattern inside a [`Query`].
+#[derive(Clone, Copy, Debug, PartialEq, Eq)]
+pub struct Pattern(pub(crate) u32);
+
+impl Pattern {
+    /// Placeholder value that wraps `u32::MAX`.
+    pub const SENTINEL: Pattern = Pattern(u32::MAX);
+
+    /// Returns the wrapped index widened to `usize`.
+    pub fn idx(&self) -> usize {
+        let Pattern(raw) = *self;
+        raw as usize
+    }
+}
+
+/// Uninhabited marker type. In this file it only appears behind the
+/// `NonNull<QueryData>` pointers passed to and returned from the `ts_query_*`
+/// functions.
+pub enum QueryData {}
+
+/// Per-pattern data stored in `Query::patterns`.
+#[derive(Debug)]
+pub(super) struct PatternData {
+ /// Index range into `Query::text_predicates` that holds this pattern's text predicates.
+ text_predicates: Range<u32>,
+}
+
+/// A compiled tree-sitter query together with the predicate data parsed from it.
+///
+/// Owns the raw query: dropping a `Query` calls `ts_query_delete` on `raw`.
+#[derive(Debug)]
+pub struct Query {
+ /// Pointer returned by `ts_query_new`.
+ pub(crate) raw: NonNull<QueryData>,
+ /// Result of `ts_query_capture_count`, cached at construction.
+ num_captures: u32,
+ /// Result of `ts_query_string_count`, cached at construction.
+ num_strings: u32,
+ /// Text predicates of all patterns; `PatternData::text_predicates` indexes into this.
+ text_predicates: Vec<TextPredicate>,
+ /// One entry per pattern, indexed by pattern index.
+ patterns: Box<[PatternData]>,
+}
+
+impl Query {
+ /// Create a new query from a string containing one or more S-expression
+ /// patterns.
+ ///
+ /// The query is associated with a particular grammar, and can only be run
+ /// on syntax nodes parsed with that grammar. References to Queries can be
+ /// shared between multiple threads.
+ ///
+ /// `path` is only used to label error locations. `custom_predicate` is
+ /// forwarded to `parse_pattern_predicates` for every pattern of the query.
+ ///
+ /// # Errors
+ ///
+ /// Returns a [`ParseError`] describing the problem and its location if
+ /// tree-sitter rejects the query source, or `ParseError::InvalidPredicate`
+ /// if parsing the predicates of a pattern fails.
+ ///
+ /// # Panics
+ ///
+ /// Panics if `source` is longer than `i32::MAX` bytes.
+ pub fn new(
+ grammar: Grammar,
+ source: &str,
+ path: impl AsRef<Path>,
+ mut custom_predicate: impl FnMut(Pattern, UserPredicate) -> Result<(), InvalidPredicateError>,
+ ) -> Result<Self, ParseError> {
+ // The length is handed to tree-sitter as a `u32` below, so reject
+ // oversized sources up front.
+ assert!(
+ source.len() <= i32::MAX as usize,
+ "TreeSitter queries must be smaller then 2 GiB (is {})",
+ source.len() as f64 / 1024.0 / 1024.0 / 1024.0
+ );
+ // Out-parameters filled in by `ts_query_new` when compilation fails.
+ let mut error_offset = 0u32;
+ let mut error_kind = RawQueryError::None;
+ let bytes = source.as_bytes();
+
+ // Compile the query.
+ let ptr = unsafe {
+ ts_query_new(
+ grammar,
+ bytes.as_ptr(),
+ bytes.len() as u32,
+ &mut error_offset,
+ &mut error_kind,
+ )
+ };
+
+ // `None` means compilation failed: translate the raw error kind and byte
+ // offset into a `ParseError` that carries a source location.
+ let Some(raw) = ptr else {
+ let offset = error_offset as usize;
+ // The run of identifier-like characters (alphanumeric, `_`, `-`) that
+ // starts at the error offset, i.e. the offending name.
+ let error_word = || {
+ source[offset..]
+ .chars()
+ .take_while(|&c| c.is_alphanumeric() || matches!(c, '_' | '-'))
+ .collect()
+ };
+ let err = match error_kind {
+ RawQueryError::NodeType => {
+ let node: String = error_word();
+ ParseError::InvalidNodeType {
+ location: ParserErrorLocation::new(
+ source,
+ path.as_ref(),
+ offset,
+ node.chars().count(),
+ ),
+ node,
+ }
+ }
+ RawQueryError::Field => {
+ let field = error_word();
+ ParseError::InvalidFieldName {
+ location: ParserErrorLocation::new(
+ source,
+ path.as_ref(),
+ offset,
+ field.chars().count(),
+ ),
+ field,
+ }
+ }
+ RawQueryError::Capture => {
+ let capture = error_word();
+ ParseError::InvalidCaptureName {
+ location: ParserErrorLocation::new(
+ source,
+ path.as_ref(),
+ offset,
+ capture.chars().count(),
+ ),
+ capture,
+ }
+ }
+ // Syntax and structure errors have no associated name, so the
+ // location is reported with a length of zero.
+ RawQueryError::Syntax => ParseError::SyntaxError(ParserErrorLocation::new(
+ source,
+ path.as_ref(),
+ offset,
+ 0,
+ )),
+ RawQueryError::Structure => ParseError::ImpossiblePattern(
+ ParserErrorLocation::new(source, path.as_ref(), offset, 0),
+ ),
+ RawQueryError::None => {
+ unreachable!("tree-sitter returned a null pointer but did not set an error")
+ }
+ RawQueryError::Language => unreachable!("should be handled at grammar load"),
+ };
+ return Err(err);
+ };
+
+ // I am not going to bother with safety comments here, all of these are
+ // safe as long as TS is not buggy because raw is a properly constructed query
+ let num_captures = unsafe { ts_query_capture_count(raw) };
+ let num_strings = unsafe { ts_query_string_count(raw) };
+ let num_patterns = unsafe { ts_query_pattern_count(raw) };
+
+ // Build the `Query` before parsing predicates: if a predicate is rejected
+ // below, dropping `query` frees the raw tree-sitter query.
+ let mut query = Query {
+ raw,
+ num_captures,
+ num_strings,
+ text_predicates: Vec::new(),
+ patterns: Box::default(),
+ };
+ let patterns: Result<_, ParseError> = (0..num_patterns)
+ .map(|pattern| {
+ query
+ .parse_pattern_predicates(Pattern(pattern), &mut custom_predicate)
+ .map_err(|err| ParseError::InvalidPredicate {
+ message: err.msg.into(),
+ // Predicate errors are reported at the start of their pattern.
+ location: ParserErrorLocation::new(
+ source,
+ path.as_ref(),
+ unsafe { ts_query_start_byte_for_pattern(query.raw, pattern) as usize },
+ 0,
+ ),
+ })
+ })
+ .collect();
+ query.patterns = patterns?;
+ Ok(query)
+ }
+
+    /// Returns the string literal stored in the query under the id `str`.
+    ///
+    /// # Panics
+    ///
+    /// Panics if the id is not smaller than the number of string literals in
+    /// the query.
+    #[inline]
+    fn get_string(&self, str: QueryStr) -> &str {
+        let value_id = str.0;
+        // The tree-sitter C API does no bounds check, so it has to happen here.
+        // Valid ids are `0..num_strings`: checking against the capture count, or
+        // accepting `id == count`, would let an out-of-range id reach the C code.
+        assert!(value_id < self.num_strings, "invalid value index");
+        let mut len = 0;
+        unsafe {
+            let ptr = ts_query_string_value_for_id(self.raw, value_id, &mut len);
+            let data = slice::from_raw_parts(ptr, len as usize);
+            // safety: we only allow passing valid str(ings) as arguments to query::new
+            // name is always a substring of that. Treesitter does proper utf8 segmentation
+            // so any substrings it produces are codepoint aligned and therefore valid utf8
+            str::from_utf8_unchecked(data)
+        }
+    }
+
+    /// Returns the name of the capture with the given index.
+    ///
+    /// # Panics
+    ///
+    /// Panics if `capture_idx` is not smaller than the number of captures in
+    /// the query.
+    #[inline]
+    pub fn capture_name(&self, capture_idx: Capture) -> &str {
+        let capture_idx = capture_idx.0;
+        // The tree-sitter C API does no bounds check, so it has to happen here.
+        // Valid indices are `0..num_captures`; the bound itself is out of range.
+        assert!(capture_idx < self.num_captures, "invalid capture index");
+        let mut length = 0;
+        unsafe {
+            let ptr = ts_query_capture_name_for_id(self.raw, capture_idx, &mut length);
+            let name = slice::from_raw_parts(ptr, length as usize);
+            // safety: we only allow passing valid str(ings) as arguments to query::new
+            // name is always a substring of that. Treesitter does proper utf8 segmentation
+            // so any substrings it produces are codepoint aligned and therefore valid utf8
+            str::from_utf8_unchecked(name)
+        }
+    }
+
+    /// Iterates over every capture of the query together with its name, in
+    /// ascending index order.
+    #[inline]
+    pub fn captures(&self) -> impl ExactSizeIterator<Item = (Capture, &str)> {
+        (0..self.num_captures).map(|idx| {
+            let capture = Capture(idx);
+            (capture, self.capture_name(capture))
+        })
+    }
+
+ /// Returns the number of captures in the query, as reported by
+ /// `ts_query_capture_count` when the query was created.
+ #[inline]
+ pub fn num_captures(&self) -> u32 {
+ self.num_captures
+ }
+
+    /// Looks up a capture by name.
+    ///
+    /// Returns the first capture whose name equals `capture_name`, or `None`
+    /// when the query has no capture of that name.
+    #[inline]
+    pub fn get_capture(&self, capture_name: &str) -> Option<Capture> {
+        (0..self.num_captures)
+            .map(Capture)
+            .find(|&capture| capture_name == self.capture_name(capture))
+    }
+
+    /// Returns the text predicates that belong to the pattern at `pattern_idx`.
+    ///
+    /// # Panics
+    ///
+    /// Panics if `pattern_idx` is not a valid index into the pattern list.
+    pub(crate) fn pattern_text_predicates(&self, pattern_idx: u16) -> &[TextPredicate] {
+        let predicates = &self.patterns[usize::from(pattern_idx)].text_predicates;
+        let (first, end) = (predicates.start as usize, predicates.end as usize);
+        &self.text_predicates[first..end]
+    }
+
+    /// Get the byte offset where the given pattern starts in the query's
+    /// source.
+    ///
+    /// # Panics
+    ///
+    /// Panics if `pattern` is not a valid pattern index of this query.
+    #[doc(alias = "ts_query_start_byte_for_pattern")]
+    #[must_use]
+    pub fn start_byte_for_pattern(&self, pattern: Pattern) -> usize {
+        // `patterns` holds one entry per pattern, so its length is the pattern
+        // count. The number of text predicates is unrelated to it and must not
+        // be used as the bound.
+        let pattern_count = self.patterns.len();
+        assert!(
+            pattern.idx() < pattern_count,
+            "Pattern index is {} but the pattern count is {pattern_count}",
+            pattern.0,
+        );
+        unsafe { ts_query_start_byte_for_pattern(self.raw, pattern.0) as usize }
+    }
+
+    /// Returns how many patterns the query contains, as reported by
+    /// `ts_query_pattern_count`.
+    #[must_use]
+    pub fn pattern_count(&self) -> usize {
+        let count = unsafe { ts_query_pattern_count(self.raw) };
+        count as usize
+    }
+ /// Iterate over all patterns of the query in ascending index order.
+ #[must_use]
+ pub fn patterns(&self) -> impl ExactSizeIterator<Item = Pattern> {
+ (0..self.pattern_count() as u32).map(Pattern)
+ }
+}
+
+impl Drop for Query {
+ /// Frees the underlying tree-sitter query.
+ fn drop(&mut self) {
+ // `raw` comes from `ts_query_new` in `Query::new`; nothing else in this
+ // file deletes it.
+ unsafe { ts_query_delete(self.raw) }
+ }
+}
+
+/// Index of a capture inside a [`Query`].
+#[derive(Clone, Copy, Debug, PartialEq, Eq)]
+#[repr(transparent)]
+pub struct Capture(u32);
+
+impl Capture {
+ /// Returns the name of this capture in `query`; see [`Query::capture_name`].
+ pub fn name(self, query: &Query) -> &str {
+ query.capture_name(self)
+ }
+ /// Returns the wrapped index widened to `usize`.
+ pub fn idx(self) -> usize {
+ self.0 as usize
+ }
+}
+
+/// A reference to a string stored in a query.
+#[derive(Clone, Copy, Debug)]
+pub struct QueryStr(u32);
+
+impl QueryStr {
+ /// Resolves this reference to the string it denotes inside `query`.
+ pub fn get(self, query: &Query) -> &str {
+ query.get_string(self)
+ }
+}
+
+#[derive(Debug, PartialEq, Eq)]
+pub struct ParserErrorLocation {
+ pub path: PathBuf,
+ /// at which line the error occured
+ pub line: usize,
+ /// at which codepoints/columns the errors starts in the line
+ pub column: usize,
+ /// how many codepoints/columns the error takes up
+ pub len: usize,
+ line_content: String,
+}
+
+impl ParserErrorLocation {
+ pub fn new(source: &str, path: &Path, offset: usize, len: usize) -> ParserErrorLocation {
+ let (line, line_content) = source[..offset]
+ .split('\n')
+ .map(|line| line.strip_suffix('\r').unwrap_or(line))
+ .enumerate()
+ .last()
+ .unwrap_or((0, ""));
+ let column = line_content.chars().count();
+ ParserErrorLocation {
+ path: path.to_owned(),
+ line,
+ column,
+ len,
+ line_content: line_content.to_owned(),
+ }
+ }
+}
+
+impl Display for ParserErrorLocation {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ writeln!(
+ f,
+ " --> {}:{}:{}",
+ self.path.display(),
+ self.line,
+ self.column
+ )?;
+ let line = self.line.to_string();
+ let prefix = format!(" {:width$} |", "", width = line.len());
+ writeln!(f, "{prefix}")?;
+ writeln!(f, " {line} | {}", self.line_content)?;
+ writeln!(
+ f,
+ "{prefix}{:width$}{:^<len$}",
+ "",
+ "^",
+ width = self.column,
+ len = self.len
+ )?;
+ writeln!(f, "{prefix}")
+ }
+}
+
+/// Errors produced while compiling a query with [`Query::new`].
+#[derive(thiserror::Error, Debug, PartialEq, Eq)]
+pub enum ParseError {
+    #[error("unexpected EOF")]
+    UnexpectedEof,
+    /// tree-sitter reported a syntax error in the query source.
+    #[error("invalid query syntax\n{0}")]
+    SyntaxError(ParserErrorLocation),
+    /// The query names a node type that tree-sitter rejected.
+    #[error("invalid node type {node:?}\n{location}")]
+    InvalidNodeType {
+        node: String,
+        location: ParserErrorLocation,
+    },
+    /// The query names a field that tree-sitter rejected.
+    #[error("invalid field name {field:?}\n{location}")]
+    InvalidFieldName {
+        field: String,
+        location: ParserErrorLocation,
+    },
+    /// The query uses a capture name that tree-sitter rejected.
+    #[error("invalid capture name {capture:?}\n{location}")]
+    InvalidCaptureName {
+        capture: String,
+        location: ParserErrorLocation,
+    },
+    /// A predicate of the pattern starting at `location` was rejected.
+    #[error("{message}\n{location}")]
+    InvalidPredicate {
+        message: String,
+        location: ParserErrorLocation,
+    },
+    /// tree-sitter reported a structural error for a pattern. This is not a
+    /// predicate problem, so the message must not claim it is one.
+    #[error("impossible pattern\n{0}")]
+    ImpossiblePattern(ParserErrorLocation),
+}
+
+/// Error kind that `ts_query_new` writes through its `error_type`
+/// out-parameter. `Query::new` translates each value into a [`ParseError`].
+#[repr(C)]
+enum RawQueryError {
+ /// No error was recorded.
+ None = 0,
+ /// Mapped to `ParseError::SyntaxError`.
+ Syntax = 1,
+ /// Mapped to `ParseError::InvalidNodeType`.
+ NodeType = 2,
+ /// Mapped to `ParseError::InvalidFieldName`.
+ Field = 3,
+ /// Mapped to `ParseError::InvalidCaptureName`.
+ Capture = 4,
+ /// Mapped to `ParseError::ImpossiblePattern`.
+ Structure = 5,
+ /// Treated as unreachable by `Query::new` ("should be handled at grammar load").
+ Language = 6,
+}
+
+// Raw bindings to the tree-sitter query functions used in this file.
+extern "C" {
+ /// Create a new query from a string containing one or more S-expression
+ /// patterns. The query is associated with a particular language, and can
+ /// only be run on syntax nodes parsed with that language. If all of the
+ /// given patterns are valid, this returns a pointer to the new query. If a
+ /// pattern is invalid, this returns `NULL` (`None`), and provides two pieces
+ /// of information about the problem: 1. The byte offset of the error is
+ /// written to the `error_offset` parameter. 2. The type of error is written
+ /// to the `error_type` parameter.
+ fn ts_query_new(
+ grammar: Grammar,
+ source: *const u8,
+ source_len: u32,
+ error_offset: &mut u32,
+ error_type: &mut RawQueryError,
+ ) -> Option<NonNull<QueryData>>;
+
+ /// Delete a query, freeing all of the memory that it used.
+ fn ts_query_delete(query: NonNull<QueryData>);
+
+ /// Get the number of patterns in the query.
+ fn ts_query_pattern_count(query: NonNull<QueryData>) -> u32;
+ /// Get the number of captures in the query.
+ fn ts_query_capture_count(query: NonNull<QueryData>) -> u32;
+ /// Get the number of string literals in the query.
+ fn ts_query_string_count(query: NonNull<QueryData>) -> u32;
+
+ /// Get the byte offset where the given pattern starts in the query's
+ /// source. This can be useful when combining queries by concatenating their
+ /// source code strings.
+ fn ts_query_start_byte_for_pattern(query: NonNull<QueryData>, pattern_index: u32) -> u32;
+
+ // Bindings that are currently not needed:
+ // fn ts_query_is_pattern_rooted(query: NonNull<QueryData>, pattern_index: u32) -> bool;
+ // fn ts_query_is_pattern_non_local(query: NonNull<QueryData>, pattern_index: u32) -> bool;
+ // fn ts_query_is_pattern_guaranteed_at_step(query: NonNull<QueryData>, byte_offset: u32) -> bool;
+
+ /// Get the name and length of one of the query's captures. Each capture is
+ /// associated with a numeric id based on the order that it appeared in the
+ /// query's source. The callers in this file bounds-check `index` themselves
+ /// before calling.
+ fn ts_query_capture_name_for_id(
+ query: NonNull<QueryData>,
+ index: u32,
+ length: &mut u32,
+ ) -> *const u8;
+
+ /// Get the value and length of one of the query's string literals. Each
+ /// string is associated with a numeric id based on the order that it
+ /// appeared in the query's source. The callers in this file bounds-check
+ /// `index` themselves before calling.
+ fn ts_query_string_value_for_id(
+ self_: NonNull<QueryData>,
+ index: u32,
+ length: &mut u32,
+ ) -> *const u8;
+}