Unnamed repository; edit this file 'description' to name the repository.
Diffstat (limited to 'crates/syntax/src/syntax_editor.rs')
| -rw-r--r-- | crates/syntax/src/syntax_editor.rs | 171 |
1 files changed, 157 insertions, 14 deletions
diff --git a/crates/syntax/src/syntax_editor.rs b/crates/syntax/src/syntax_editor.rs index e6937e4d0f..8e4dc75d22 100644 --- a/crates/syntax/src/syntax_editor.rs +++ b/crates/syntax/src/syntax_editor.rs @@ -14,7 +14,10 @@ use std::{ use rowan::TextRange; use rustc_hash::FxHashMap; -use crate::{SyntaxElement, SyntaxNode, SyntaxToken}; +use crate::{ + AstNode, SyntaxElement, SyntaxKind, SyntaxNode, SyntaxToken, T, + ast::{self, edit::IndentLevel, syntax_factory::SyntaxFactory}, +}; mod edit_algo; mod edits; @@ -32,9 +35,37 @@ pub struct SyntaxEditor { } impl SyntaxEditor { - /// Creates a syntax editor to start editing from `root` - pub fn new(root: SyntaxNode) -> Self { - Self { root, changes: vec![], mappings: SyntaxMapping::default(), annotations: vec![] } + /// Creates a syntax editor from `root`. + /// + /// The returned `root` is guaranteed to be a detached, immutable node. + /// If the provided node is not a root (i.e., has a parent) or is already + /// mutable, it is cloned into a fresh subtree to satisfy syntax editor + /// invariants. + pub fn new(root: SyntaxNode) -> (Self, SyntaxNode) { + let mut root = root; + + if root.parent().is_some() || root.is_mutable() { + root = root.clone_subtree() + }; + + let editor = Self { + root: root.clone(), + changes: Vec::new(), + mappings: SyntaxMapping::default(), + annotations: Vec::new(), + }; + + (editor, root) + } + + /// Typed-node variant of [`SyntaxEditor::new`]. + pub fn with_ast_node<T>(root: &T) -> (Self, T) + where + T: AstNode, + { + let (editor, root) = Self::new(root.syntax().clone()); + + (editor, T::cast(root).unwrap()) } pub fn add_annotation(&mut self, element: impl Element, annotation: SyntaxAnnotation) { @@ -73,6 +104,34 @@ impl SyntaxEditor { self.changes.push(Change::InsertAll(position, elements)) } + pub fn insert_with_whitespace( + &mut self, + position: Position, + element: impl Element, + factory: &SyntaxFactory, + ) { + self.insert_all_with_whitespace(position, vec![element.syntax_element()], factory) + } + + pub fn insert_all_with_whitespace( + &mut self, + position: Position, + mut elements: Vec<SyntaxElement>, + factory: &SyntaxFactory, + ) { + if let Some(first) = elements.first() + && let Some(ws) = ws_before(&position, first, factory) + { + elements.insert(0, ws.into()); + } + if let Some(last) = elements.last() + && let Some(ws) = ws_after(&position, last, factory) + { + elements.push(ws.into()); + } + self.insert_all(position, elements) + } + pub fn delete(&mut self, element: impl Element) { let element = element.syntax_element(); debug_assert!(is_ancestor_or_self_of_element(&element, &self.root)); @@ -384,6 +443,86 @@ impl Element for SyntaxToken { } } +fn ws_before( + position: &Position, + new: &SyntaxElement, + factory: &SyntaxFactory, +) -> Option<SyntaxToken> { + let prev = match &position.repr { + PositionRepr::FirstChild(_) => return None, + PositionRepr::After(it) => it, + }; + + if prev.kind() == T!['{'] + && new.kind() == SyntaxKind::USE + && let Some(item_list) = prev.parent().and_then(ast::ItemList::cast) + { + let mut indent = IndentLevel::from_element(&item_list.syntax().clone().into()); + indent.0 += 1; + return Some(factory.whitespace(&format!("\n{indent}"))); + } + + if prev.kind() == T!['{'] + && ast::Stmt::can_cast(new.kind()) + && let Some(stmt_list) = prev.parent().and_then(ast::StmtList::cast) + { + let mut indent = IndentLevel::from_element(&stmt_list.syntax().clone().into()); + indent.0 += 1; + return Some(factory.whitespace(&format!("\n{indent}"))); + } + + ws_between(prev, new, factory) +} + +fn ws_after( + position: &Position, + new: &SyntaxElement, + factory: &SyntaxFactory, +) -> Option<SyntaxToken> { + let next = match &position.repr { + PositionRepr::FirstChild(parent) => parent.first_child_or_token()?, + PositionRepr::After(sibling) => sibling.next_sibling_or_token()?, + }; + ws_between(new, &next, factory) +} + +fn ws_between( + left: &SyntaxElement, + right: &SyntaxElement, + factory: &SyntaxFactory, +) -> Option<SyntaxToken> { + if left.kind() == SyntaxKind::WHITESPACE || right.kind() == SyntaxKind::WHITESPACE { + return None; + } + if right.kind() == T![;] || right.kind() == T![,] { + return None; + } + if left.kind() == T![<] || right.kind() == T![>] { + return None; + } + if left.kind() == T![&] && right.kind() == SyntaxKind::LIFETIME { + return None; + } + if right.kind() == SyntaxKind::GENERIC_ARG_LIST { + return None; + } + if right.kind() == SyntaxKind::USE { + let mut indent = IndentLevel::from_element(left); + if left.kind() == SyntaxKind::USE { + indent.0 = IndentLevel::from_element(right).0.max(indent.0); + } + return Some(factory.whitespace(&format!("\n{indent}"))); + } + if left.kind() == SyntaxKind::ATTR { + let mut indent = IndentLevel::from_element(right); + if right.kind() == SyntaxKind::ATTR { + indent.0 = IndentLevel::from_element(left).0.max(indent.0); + } + return Some(factory.whitespace(&format!("\n{indent}"))); + } + Some(factory.whitespace(" ")) +} + fn is_ancestor_or_self(node: &SyntaxNode, ancestor: &SyntaxNode) -> bool { node == ancestor || node.ancestors().any(|it| &it == ancestor) } @@ -420,10 +559,11 @@ mod tests { .into(), ); + let (mut editor, root) = SyntaxEditor::with_ast_node(&root); + let to_wrap = root.syntax().descendants().find_map(ast::TupleExpr::cast).unwrap(); let to_replace = root.syntax().descendants().find_map(ast::BinExpr::cast).unwrap(); - let mut editor = SyntaxEditor::new(root.syntax().clone()); let make = SyntaxFactory::with_mappings(); let name = make::name("var_name"); @@ -478,9 +618,8 @@ mod tests { None, ); + let (mut editor, root) = SyntaxEditor::with_ast_node(&root); let second_let = root.syntax().descendants().find_map(ast::LetStmt::cast).unwrap(); - - let mut editor = SyntaxEditor::new(root.syntax().clone()); let make = SyntaxFactory::without_mappings(); editor.insert( @@ -530,11 +669,12 @@ mod tests { ), ); + let (mut editor, root) = SyntaxEditor::with_ast_node(&root); + let inner_block = root.syntax().descendants().flat_map(ast::BlockExpr::cast).nth(1).unwrap(); let second_let = root.syntax().descendants().find_map(ast::LetStmt::cast).unwrap(); - let mut editor = SyntaxEditor::new(root.syntax().clone()); let make = SyntaxFactory::with_mappings(); let new_block_expr = make.block_expr([], Some(ast::Expr::BlockExpr(inner_block.clone()))); @@ -584,9 +724,9 @@ mod tests { None, ); - let inner_block = root.clone(); + let (mut editor, root) = SyntaxEditor::with_ast_node(&root); - let mut editor = SyntaxEditor::new(root.syntax().clone()); + let inner_block = root; let make = SyntaxFactory::with_mappings(); let new_block_expr = make.block_expr([], Some(ast::Expr::BlockExpr(inner_block.clone()))); @@ -632,7 +772,7 @@ mod tests { false, ); - let mut editor = SyntaxEditor::new(parent_fn.syntax().clone()); + let (mut editor, parent_fn) = SyntaxEditor::with_ast_node(&parent_fn); if let Some(ret_ty) = parent_fn.ret_type() { editor.delete(ret_ty.syntax().clone()); @@ -659,7 +799,8 @@ mod tests { let arg_list = make::arg_list([make::expr_literal("1").into(), make::expr_literal("2").into()]); - let mut editor = SyntaxEditor::new(arg_list.syntax().clone()); + let (mut editor, arg_list) = SyntaxEditor::with_ast_node(&arg_list); + let target_expr = make::token(parser::SyntaxKind::UNDERSCORE); for arg in arg_list.args() { @@ -677,7 +818,8 @@ mod tests { let arg_list = make::arg_list([make::expr_literal("1").into(), make::expr_literal("2").into()]); - let mut editor = SyntaxEditor::new(arg_list.syntax().clone()); + let (mut editor, arg_list) = SyntaxEditor::with_ast_node(&arg_list); + let target_expr = make::expr_literal("3").clone_for_update(); for arg in arg_list.args() { @@ -695,7 +837,8 @@ mod tests { let arg_list = make::arg_list([make::expr_literal("1").into(), make::expr_literal("2").into()]); - let mut editor = SyntaxEditor::new(arg_list.syntax().clone()); + let (mut editor, arg_list) = SyntaxEditor::with_ast_node(&arg_list); + let target_expr = make::ext::expr_unit().clone_for_update(); for arg in arg_list.args() { |