Unnamed repository; edit this file 'description' to name the repository.
Diffstat (limited to 'crates/syntax/src/syntax_editor.rs')
-rw-r--r--crates/syntax/src/syntax_editor.rs171
1 files changed, 157 insertions, 14 deletions
diff --git a/crates/syntax/src/syntax_editor.rs b/crates/syntax/src/syntax_editor.rs
index e6937e4d0f..8e4dc75d22 100644
--- a/crates/syntax/src/syntax_editor.rs
+++ b/crates/syntax/src/syntax_editor.rs
@@ -14,7 +14,10 @@ use std::{
use rowan::TextRange;
use rustc_hash::FxHashMap;
-use crate::{SyntaxElement, SyntaxNode, SyntaxToken};
+use crate::{
+ AstNode, SyntaxElement, SyntaxKind, SyntaxNode, SyntaxToken, T,
+ ast::{self, edit::IndentLevel, syntax_factory::SyntaxFactory},
+};
mod edit_algo;
mod edits;
@@ -32,9 +35,37 @@ pub struct SyntaxEditor {
}
impl SyntaxEditor {
- /// Creates a syntax editor to start editing from `root`
- pub fn new(root: SyntaxNode) -> Self {
- Self { root, changes: vec![], mappings: SyntaxMapping::default(), annotations: vec![] }
+ /// Creates a syntax editor from `root`.
+ ///
+ /// The returned `root` is guaranteed to be a detached, immutable node.
+ /// If the provided node is not a root (i.e., has a parent) or is already
+ /// mutable, it is cloned into a fresh subtree to satisfy syntax editor
+ /// invariants.
+ pub fn new(root: SyntaxNode) -> (Self, SyntaxNode) {
+ let mut root = root;
+
+ if root.parent().is_some() || root.is_mutable() {
+ root = root.clone_subtree()
+ };
+
+ let editor = Self {
+ root: root.clone(),
+ changes: Vec::new(),
+ mappings: SyntaxMapping::default(),
+ annotations: Vec::new(),
+ };
+
+ (editor, root)
+ }
+
+ /// Typed-node variant of [`SyntaxEditor::new`].
+ pub fn with_ast_node<T>(root: &T) -> (Self, T)
+ where
+ T: AstNode,
+ {
+ let (editor, root) = Self::new(root.syntax().clone());
+
+ (editor, T::cast(root).unwrap())
}
pub fn add_annotation(&mut self, element: impl Element, annotation: SyntaxAnnotation) {
@@ -73,6 +104,34 @@ impl SyntaxEditor {
self.changes.push(Change::InsertAll(position, elements))
}
+ pub fn insert_with_whitespace(
+ &mut self,
+ position: Position,
+ element: impl Element,
+ factory: &SyntaxFactory,
+ ) {
+ self.insert_all_with_whitespace(position, vec![element.syntax_element()], factory)
+ }
+
+ pub fn insert_all_with_whitespace(
+ &mut self,
+ position: Position,
+ mut elements: Vec<SyntaxElement>,
+ factory: &SyntaxFactory,
+ ) {
+ if let Some(first) = elements.first()
+ && let Some(ws) = ws_before(&position, first, factory)
+ {
+ elements.insert(0, ws.into());
+ }
+ if let Some(last) = elements.last()
+ && let Some(ws) = ws_after(&position, last, factory)
+ {
+ elements.push(ws.into());
+ }
+ self.insert_all(position, elements)
+ }
+
pub fn delete(&mut self, element: impl Element) {
let element = element.syntax_element();
debug_assert!(is_ancestor_or_self_of_element(&element, &self.root));
@@ -384,6 +443,86 @@ impl Element for SyntaxToken {
}
}
+fn ws_before(
+ position: &Position,
+ new: &SyntaxElement,
+ factory: &SyntaxFactory,
+) -> Option<SyntaxToken> {
+ let prev = match &position.repr {
+ PositionRepr::FirstChild(_) => return None,
+ PositionRepr::After(it) => it,
+ };
+
+ if prev.kind() == T!['{']
+ && new.kind() == SyntaxKind::USE
+ && let Some(item_list) = prev.parent().and_then(ast::ItemList::cast)
+ {
+ let mut indent = IndentLevel::from_element(&item_list.syntax().clone().into());
+ indent.0 += 1;
+ return Some(factory.whitespace(&format!("\n{indent}")));
+ }
+
+ if prev.kind() == T!['{']
+ && ast::Stmt::can_cast(new.kind())
+ && let Some(stmt_list) = prev.parent().and_then(ast::StmtList::cast)
+ {
+ let mut indent = IndentLevel::from_element(&stmt_list.syntax().clone().into());
+ indent.0 += 1;
+ return Some(factory.whitespace(&format!("\n{indent}")));
+ }
+
+ ws_between(prev, new, factory)
+}
+
+fn ws_after(
+ position: &Position,
+ new: &SyntaxElement,
+ factory: &SyntaxFactory,
+) -> Option<SyntaxToken> {
+ let next = match &position.repr {
+ PositionRepr::FirstChild(parent) => parent.first_child_or_token()?,
+ PositionRepr::After(sibling) => sibling.next_sibling_or_token()?,
+ };
+ ws_between(new, &next, factory)
+}
+
+fn ws_between(
+ left: &SyntaxElement,
+ right: &SyntaxElement,
+ factory: &SyntaxFactory,
+) -> Option<SyntaxToken> {
+ if left.kind() == SyntaxKind::WHITESPACE || right.kind() == SyntaxKind::WHITESPACE {
+ return None;
+ }
+ if right.kind() == T![;] || right.kind() == T![,] {
+ return None;
+ }
+ if left.kind() == T![<] || right.kind() == T![>] {
+ return None;
+ }
+ if left.kind() == T![&] && right.kind() == SyntaxKind::LIFETIME {
+ return None;
+ }
+ if right.kind() == SyntaxKind::GENERIC_ARG_LIST {
+ return None;
+ }
+ if right.kind() == SyntaxKind::USE {
+ let mut indent = IndentLevel::from_element(left);
+ if left.kind() == SyntaxKind::USE {
+ indent.0 = IndentLevel::from_element(right).0.max(indent.0);
+ }
+ return Some(factory.whitespace(&format!("\n{indent}")));
+ }
+ if left.kind() == SyntaxKind::ATTR {
+ let mut indent = IndentLevel::from_element(right);
+ if right.kind() == SyntaxKind::ATTR {
+ indent.0 = IndentLevel::from_element(left).0.max(indent.0);
+ }
+ return Some(factory.whitespace(&format!("\n{indent}")));
+ }
+ Some(factory.whitespace(" "))
+}
+
fn is_ancestor_or_self(node: &SyntaxNode, ancestor: &SyntaxNode) -> bool {
node == ancestor || node.ancestors().any(|it| &it == ancestor)
}
@@ -420,10 +559,11 @@ mod tests {
.into(),
);
+ let (mut editor, root) = SyntaxEditor::with_ast_node(&root);
+
let to_wrap = root.syntax().descendants().find_map(ast::TupleExpr::cast).unwrap();
let to_replace = root.syntax().descendants().find_map(ast::BinExpr::cast).unwrap();
- let mut editor = SyntaxEditor::new(root.syntax().clone());
let make = SyntaxFactory::with_mappings();
let name = make::name("var_name");
@@ -478,9 +618,8 @@ mod tests {
None,
);
+ let (mut editor, root) = SyntaxEditor::with_ast_node(&root);
let second_let = root.syntax().descendants().find_map(ast::LetStmt::cast).unwrap();
-
- let mut editor = SyntaxEditor::new(root.syntax().clone());
let make = SyntaxFactory::without_mappings();
editor.insert(
@@ -530,11 +669,12 @@ mod tests {
),
);
+ let (mut editor, root) = SyntaxEditor::with_ast_node(&root);
+
let inner_block =
root.syntax().descendants().flat_map(ast::BlockExpr::cast).nth(1).unwrap();
let second_let = root.syntax().descendants().find_map(ast::LetStmt::cast).unwrap();
- let mut editor = SyntaxEditor::new(root.syntax().clone());
let make = SyntaxFactory::with_mappings();
let new_block_expr = make.block_expr([], Some(ast::Expr::BlockExpr(inner_block.clone())));
@@ -584,9 +724,9 @@ mod tests {
None,
);
- let inner_block = root.clone();
+ let (mut editor, root) = SyntaxEditor::with_ast_node(&root);
- let mut editor = SyntaxEditor::new(root.syntax().clone());
+ let inner_block = root;
let make = SyntaxFactory::with_mappings();
let new_block_expr = make.block_expr([], Some(ast::Expr::BlockExpr(inner_block.clone())));
@@ -632,7 +772,7 @@ mod tests {
false,
);
- let mut editor = SyntaxEditor::new(parent_fn.syntax().clone());
+ let (mut editor, parent_fn) = SyntaxEditor::with_ast_node(&parent_fn);
if let Some(ret_ty) = parent_fn.ret_type() {
editor.delete(ret_ty.syntax().clone());
@@ -659,7 +799,8 @@ mod tests {
let arg_list =
make::arg_list([make::expr_literal("1").into(), make::expr_literal("2").into()]);
- let mut editor = SyntaxEditor::new(arg_list.syntax().clone());
+ let (mut editor, arg_list) = SyntaxEditor::with_ast_node(&arg_list);
+
let target_expr = make::token(parser::SyntaxKind::UNDERSCORE);
for arg in arg_list.args() {
@@ -677,7 +818,8 @@ mod tests {
let arg_list =
make::arg_list([make::expr_literal("1").into(), make::expr_literal("2").into()]);
- let mut editor = SyntaxEditor::new(arg_list.syntax().clone());
+ let (mut editor, arg_list) = SyntaxEditor::with_ast_node(&arg_list);
+
let target_expr = make::expr_literal("3").clone_for_update();
for arg in arg_list.args() {
@@ -695,7 +837,8 @@ mod tests {
let arg_list =
make::arg_list([make::expr_literal("1").into(), make::expr_literal("2").into()]);
- let mut editor = SyntaxEditor::new(arg_list.syntax().clone());
+ let (mut editor, arg_list) = SyntaxEditor::with_ast_node(&arg_list);
+
let target_expr = make::ext::expr_unit().clone_for_update();
for arg in arg_list.args() {