Unnamed repository; edit this file 'description' to name the repository.
Diffstat (limited to 'crates/ide-assists/src/handlers/inline_local_variable.rs')
-rw-r--r--crates/ide-assists/src/handlers/inline_local_variable.rs118
1 files changed, 85 insertions, 33 deletions
diff --git a/crates/ide-assists/src/handlers/inline_local_variable.rs b/crates/ide-assists/src/handlers/inline_local_variable.rs
index cc7bea5152..36eed290dc 100644
--- a/crates/ide-assists/src/handlers/inline_local_variable.rs
+++ b/crates/ide-assists/src/handlers/inline_local_variable.rs
@@ -5,7 +5,7 @@ use ide_db::{
EditionedFileId, RootDatabase,
};
use syntax::{
- ast::{self, AstNode, AstToken, HasName},
+ ast::{self, syntax_factory::SyntaxFactory, AstNode, AstToken, HasName},
SyntaxElement, TextRange,
};
@@ -43,22 +43,6 @@ pub(crate) fn inline_local_variable(acc: &mut Assists, ctx: &AssistContext<'_>)
}?;
let initializer_expr = let_stmt.initializer()?;
- let delete_range = delete_let.then(|| {
- if let Some(whitespace) = let_stmt
- .syntax()
- .next_sibling_or_token()
- .and_then(SyntaxElement::into_token)
- .and_then(ast::Whitespace::cast)
- {
- TextRange::new(
- let_stmt.syntax().text_range().start(),
- whitespace.syntax().text_range().end(),
- )
- } else {
- let_stmt.syntax().text_range()
- }
- });
-
let wrap_in_parens = references
.into_iter()
.filter_map(|FileReference { range, name, .. }| match name {
@@ -73,40 +57,60 @@ pub(crate) fn inline_local_variable(acc: &mut Assists, ctx: &AssistContext<'_>)
}
let usage_node =
name_ref.syntax().ancestors().find(|it| ast::PathExpr::can_cast(it.kind()));
- let usage_parent_option = usage_node.and_then(|it| it.parent());
+ let usage_parent_option = usage_node.as_ref().and_then(|it| it.parent());
let usage_parent = match usage_parent_option {
Some(u) => u,
- None => return Some((range, name_ref, false)),
+ None => return Some((name_ref, false)),
};
- Some((range, name_ref, initializer_expr.needs_parens_in(&usage_parent)))
+ let should_wrap = initializer_expr
+ .needs_parens_in_place_of(&usage_parent, usage_node.as_ref().unwrap());
+ Some((name_ref, should_wrap))
})
.collect::<Option<Vec<_>>>()?;
- let init_str = initializer_expr.syntax().text().to_string();
- let init_in_paren = format!("({init_str})");
-
let target = match target {
- ast::NameOrNameRef::Name(it) => it.syntax().text_range(),
- ast::NameOrNameRef::NameRef(it) => it.syntax().text_range(),
+ ast::NameOrNameRef::Name(it) => it.syntax().clone(),
+ ast::NameOrNameRef::NameRef(it) => it.syntax().clone(),
};
acc.add(
AssistId("inline_local_variable", AssistKind::RefactorInline),
"Inline variable",
- target,
+ target.text_range(),
move |builder| {
- if let Some(range) = delete_range {
- builder.delete(range);
+ let mut editor = builder.make_editor(&target);
+ if delete_let {
+ editor.delete(let_stmt.syntax());
+ if let Some(whitespace) = let_stmt
+ .syntax()
+ .next_sibling_or_token()
+ .and_then(SyntaxElement::into_token)
+ .and_then(ast::Whitespace::cast)
+ {
+ editor.delete(whitespace.syntax());
+ }
}
- for (range, name, should_wrap) in wrap_in_parens {
- let replacement = if should_wrap { &init_in_paren } else { &init_str };
- if ast::RecordExprField::for_field_name(&name).is_some() {
+
+ let make = SyntaxFactory::new();
+
+ for (name, should_wrap) in wrap_in_parens {
+ let replacement = if should_wrap {
+ make.expr_paren(initializer_expr.clone()).into()
+ } else {
+ initializer_expr.clone()
+ };
+
+ if let Some(record_field) = ast::RecordExprField::for_field_name(&name) {
cov_mark::hit!(inline_field_shorthand);
- builder.insert(range.end(), format!(": {replacement}"));
+ let replacement = make.record_expr_field(name, Some(replacement));
+ editor.replace(record_field.syntax(), replacement.syntax());
} else {
- builder.replace(range, replacement.clone())
+ editor.replace(name.syntax(), replacement.syntax());
}
}
+
+ editor.add_mappings(make.finish_with_mappings());
+ builder.add_file_edits(ctx.file_id(), editor);
},
)
}
@@ -942,4 +946,52 @@ fn main() {
"#,
);
}
+
+ #[test]
+ fn test_wrap_in_parens() {
+ check_assist(
+ inline_local_variable,
+ r#"
+fn main() {
+ let $0a = 123 < 456;
+ let b = !a;
+}
+"#,
+ r#"
+fn main() {
+ let b = !(123 < 456);
+}
+"#,
+ );
+ check_assist(
+ inline_local_variable,
+ r#"
+trait Foo {
+ fn foo(&self);
+}
+
+impl Foo for bool {
+ fn foo(&self) {}
+}
+
+fn main() {
+ let $0a = 123 < 456;
+ let b = a.foo();
+}
+"#,
+ r#"
+trait Foo {
+ fn foo(&self);
+}
+
+impl Foo for bool {
+ fn foo(&self) {}
+}
+
+fn main() {
+ let b = (123 < 456).foo();
+}
+"#,
+ );
+ }
}