Unnamed repository; edit this file 'description' to name the repository.
Diffstat (limited to 'crates/syntax/src/ast/edit_in_place.rs')
-rw-r--r--crates/syntax/src/ast/edit_in_place.rs177
1 files changed, 172 insertions, 5 deletions
diff --git a/crates/syntax/src/ast/edit_in_place.rs b/crates/syntax/src/ast/edit_in_place.rs
index a85e1d1d9d..37d8212042 100644
--- a/crates/syntax/src/ast/edit_in_place.rs
+++ b/crates/syntax/src/ast/edit_in_place.rs
@@ -3,18 +3,17 @@
use std::iter::{empty, successors};
use parser::{SyntaxKind, T};
-use rowan::SyntaxElement;
use crate::{
algo::{self, neighbor},
ast::{self, edit::IndentLevel, make, HasGenericParams},
ted::{self, Position},
- AstNode, AstToken, Direction,
+ AstNode, AstToken, Direction, SyntaxElement,
SyntaxKind::{ATTR, COMMENT, WHITESPACE},
SyntaxNode, SyntaxToken,
};
-use super::HasName;
+use super::{HasArgList, HasName};
pub trait GenericParamsOwnerEdit: ast::HasGenericParams {
fn get_or_create_generic_param_list(&self) -> ast::GenericParamList;
@@ -362,6 +361,24 @@ impl ast::PathSegment {
}
}
+impl ast::MethodCallExpr {
+ pub fn get_or_create_generic_arg_list(&self) -> ast::GenericArgList {
+ if self.generic_arg_list().is_none() {
+ let generic_arg_list = make::turbofish_generic_arg_list(empty()).clone_for_update();
+
+ if let Some(arg_list) = self.arg_list() {
+ ted::insert_raw(
+ ted::Position::before(arg_list.syntax()),
+ generic_arg_list.syntax(),
+ );
+ } else {
+ ted::append_child(self.syntax(), generic_arg_list.syntax());
+ }
+ }
+ self.generic_arg_list().unwrap()
+ }
+}
+
impl Removable for ast::UseTree {
fn remove(&self) {
for dir in [Direction::Next, Direction::Prev] {
@@ -559,7 +576,7 @@ impl ast::AssocItemList {
None => (IndentLevel::single(), Position::last_child_of(self.syntax()), "\n"),
},
};
- let elements: Vec<SyntaxElement<_>> = vec![
+ let elements: Vec<SyntaxElement> = vec![
make::tokens::whitespace(&format!("{whitespace}{indent}")).into(),
item.syntax().clone().into(),
];
@@ -629,6 +646,50 @@ impl ast::MatchArmList {
}
}
+impl ast::LetStmt {
+ pub fn set_ty(&self, ty: Option<ast::Type>) {
+ match ty {
+ None => {
+ if let Some(colon_token) = self.colon_token() {
+ ted::remove(colon_token);
+ }
+
+ if let Some(existing_ty) = self.ty() {
+ if let Some(sibling) = existing_ty.syntax().prev_sibling_or_token() {
+ if sibling.kind() == SyntaxKind::WHITESPACE {
+ ted::remove(sibling);
+ }
+ }
+
+ ted::remove(existing_ty.syntax());
+ }
+
+ // Remove any trailing ws
+ if let Some(last) = self.syntax().last_token().filter(|it| it.kind() == WHITESPACE)
+ {
+ last.detach();
+ }
+ }
+ Some(new_ty) => {
+ if self.colon_token().is_none() {
+ ted::insert_raw(
+ Position::after(
+ self.pat().expect("let stmt should have a pattern").syntax(),
+ ),
+ make::token(T![:]),
+ );
+ }
+
+ if let Some(old_ty) = self.ty() {
+ ted::replace(old_ty.syntax(), new_ty.syntax());
+ } else {
+ ted::insert(Position::after(self.colon_token().unwrap()), new_ty.syntax());
+ }
+ }
+ }
+ }
+}
+
impl ast::RecordExprFieldList {
pub fn add_field(&self, field: ast::RecordExprField) {
let is_multiline = self.syntax().text().contains_char('\n');
@@ -753,7 +814,7 @@ impl ast::VariantList {
None => (IndentLevel::single(), Position::last_child_of(self.syntax())),
},
};
- let elements: Vec<SyntaxElement<_>> = vec![
+ let elements: Vec<SyntaxElement> = vec![
make::tokens::whitespace(&format!("{}{indent}", "\n")).into(),
variant.syntax().clone().into(),
ast::make::token(T![,]).into(),
@@ -788,6 +849,53 @@ fn normalize_ws_between_braces(node: &SyntaxNode) -> Option<()> {
Some(())
}
+impl ast::IdentPat {
+ pub fn set_pat(&self, pat: Option<ast::Pat>) {
+ match pat {
+ None => {
+ if let Some(at_token) = self.at_token() {
+ // Remove `@ Pat`
+ let start = at_token.clone().into();
+ let end = self
+ .pat()
+ .map(|it| it.syntax().clone().into())
+ .unwrap_or_else(|| at_token.into());
+
+ ted::remove_all(start..=end);
+
+ // Remove any trailing ws
+ if let Some(last) =
+ self.syntax().last_token().filter(|it| it.kind() == WHITESPACE)
+ {
+ last.detach();
+ }
+ }
+ }
+ Some(pat) => {
+ if let Some(old_pat) = self.pat() {
+ // Replace existing pattern
+ ted::replace(old_pat.syntax(), pat.syntax())
+ } else if let Some(at_token) = self.at_token() {
+ // Have an `@` token but not a pattern yet
+ ted::insert(ted::Position::after(at_token), pat.syntax());
+ } else {
+ // Don't have an `@`, should have a name
+ let name = self.name().unwrap();
+
+ ted::insert_all(
+ ted::Position::after(name.syntax()),
+ vec![
+ make::token(T![@]).into(),
+ make::tokens::single_space().into(),
+ pat.syntax().clone().into(),
+ ],
+ )
+ }
+ }
+ }
+ }
+}
+
pub trait HasVisibilityEdit: ast::HasVisibility {
fn set_visibility(&self, visbility: ast::Visibility) {
match self.visibility() {
@@ -890,6 +998,65 @@ mod tests {
}
#[test]
+ fn test_ident_pat_set_pat() {
+ #[track_caller]
+ fn check(before: &str, expected: &str, pat: Option<ast::Pat>) {
+ let pat = pat.map(|it| it.clone_for_update());
+
+ let ident_pat = ast_mut_from_text::<ast::IdentPat>(&format!("fn f() {{ {before} }}"));
+ ident_pat.set_pat(pat);
+
+ let after = ast_mut_from_text::<ast::IdentPat>(&format!("fn f() {{ {expected} }}"));
+ assert_eq!(ident_pat.to_string(), after.to_string());
+ }
+
+ // replacing
+ check("let a @ _;", "let a @ ();", Some(make::tuple_pat([]).into()));
+
+ // note: no trailing semicolon is added for the below tests since it
+ // seems to be picked up by the ident pat during error recovery?
+
+ // adding
+ check("let a ", "let a @ ()", Some(make::tuple_pat([]).into()));
+ check("let a @ ", "let a @ ()", Some(make::tuple_pat([]).into()));
+
+ // removing
+ check("let a @ ()", "let a", None);
+ check("let a @ ", "let a", None);
+ }
+
+ #[test]
+ fn test_let_stmt_set_ty() {
+ #[track_caller]
+ fn check(before: &str, expected: &str, ty: Option<ast::Type>) {
+ let ty = ty.map(|it| it.clone_for_update());
+
+ let let_stmt = ast_mut_from_text::<ast::LetStmt>(&format!("fn f() {{ {before} }}"));
+ let_stmt.set_ty(ty);
+
+ let after = ast_mut_from_text::<ast::LetStmt>(&format!("fn f() {{ {expected} }}"));
+ assert_eq!(let_stmt.to_string(), after.to_string(), "{let_stmt:#?}\n!=\n{after:#?}");
+ }
+
+ // adding
+ check("let a;", "let a: ();", Some(make::ty_tuple([])));
+ // no semicolon due to it being eaten during error recovery
+ check("let a:", "let a: ()", Some(make::ty_tuple([])));
+
+ // replacing
+ check("let a: u8;", "let a: ();", Some(make::ty_tuple([])));
+ check("let a: u8 = 3;", "let a: () = 3;", Some(make::ty_tuple([])));
+ check("let a: = 3;", "let a: () = 3;", Some(make::ty_tuple([])));
+
+ // removing
+ check("let a: u8;", "let a;", None);
+ check("let a:;", "let a;", None);
+
+ check("let a: u8 = 3;", "let a = 3;", None);
+ check("let a: = 3;", "let a = 3;", None);
+ }
+
+ #[test]
fn add_variant_to_empty_enum() {
let variant = make::variant(make::name("Bar"), None).clone_for_update();