Unnamed repository; edit this file 'description' to name the repository.
Fix overlap edit on record to tuple assist uses self
Example --- ```rust struct $0Foo { field1: u32 } impl Foo { fn clone(&self) -> Self { Self { field1: self.field1 } } } ``` **Before this PR** Panic **After this PR** ```rust struct Foo(u32); impl Foo { fn clone(&self) -> Self { Self(self.0) } } ```
A4-Tacks 2 months ago
parent 9d2bd63 · commit 84e8972
-rw-r--r--crates/ide-assists/src/handlers/convert_named_struct_to_tuple_struct.rs210
-rw-r--r--crates/syntax/src/algo.rs16
2 files changed, 199 insertions, 27 deletions
diff --git a/crates/ide-assists/src/handlers/convert_named_struct_to_tuple_struct.rs b/crates/ide-assists/src/handlers/convert_named_struct_to_tuple_struct.rs
index 706cef4120..8dfd4598df 100644
--- a/crates/ide-assists/src/handlers/convert_named_struct_to_tuple_struct.rs
+++ b/crates/ide-assists/src/handlers/convert_named_struct_to_tuple_struct.rs
@@ -3,12 +3,13 @@ use std::ops::RangeInclusive;
use either::Either;
use ide_db::{defs::Definition, search::FileReference};
use syntax::{
- SyntaxElement, SyntaxKind, SyntaxNode, T, TextRange,
+ NodeOrToken, SyntaxElement, SyntaxKind, SyntaxNode, T, TextRange,
+ algo::next_non_trivia_token,
ast::{
self, AstNode, HasAttrs, HasGenericParams, HasVisibility, syntax_factory::SyntaxFactory,
},
match_ast,
- syntax_editor::{Position, SyntaxEditor},
+ syntax_editor::{Element, Position, SyntaxEditor},
};
use crate::{AssistContext, AssistId, Assists, assist_context::SourceChangeBuilder};
@@ -194,39 +195,32 @@ fn process_struct_name_reference(
// struct we want to edit.
return None;
}
- let make = SyntaxFactory::without_mappings();
// FIXME: Processing RecordPat and RecordExpr for unordered fields, and insert RestPat
let parent = full_path.syntax().parent()?;
match_ast! {
match parent {
ast::RecordPat(record_struct_pat) => {
- // When we failed to get the original range for the whole struct expression node,
+ // When we failed to get the original range for the whole struct pattern node,
// we can't provide any reasonable edit. Leave it untouched.
- let file_range = ctx.sema.original_range_opt(record_struct_pat.syntax())?;
- let new = make.tuple_struct_pat(
- record_struct_pat.path()?,
- record_struct_pat
- .record_pat_field_list()?
- .fields()
- .filter_map(|pat| pat.pat())
- .chain(record_struct_pat.record_pat_field_list()?
- .rest_pat()
- .map(Into::into))
+ record_to_tuple_struct_like(
+ ctx,
+ source,
+ edit,
+ record_struct_pat.record_pat_field_list()?,
+ |it| it.fields().filter_map(|it| it.name_ref()),
);
- edit.replace_all(cover_range(source, file_range.range), vec![new.syntax().clone().into()]);
},
ast::RecordExpr(record_expr) => {
- // When we failed to get the original range for the whole struct pattern node,
+ // When we failed to get the original range for the whole struct expression node,
// we can't provide any reasonable edit. Leave it untouched.
- let file_range = ctx.sema.original_range_opt(record_expr.syntax())?;
- let path = record_expr.path()?;
- let args = record_expr
- .record_expr_field_list()?
- .fields()
- .filter_map(|f| f.expr());
- let new = make.expr_call(make.expr_path(path), make.arg_list(args));
- edit.replace_all(cover_range(source, file_range.range), vec![new.syntax().clone().into()]);
+ record_to_tuple_struct_like(
+ ctx,
+ source,
+ edit,
+ record_expr.record_expr_field_list()?,
+ |it| it.fields().filter_map(|it| it.name_ref()),
+ );
},
_ => {}
}
@@ -235,6 +229,61 @@ fn process_struct_name_reference(
Some(())
}
+fn record_to_tuple_struct_like<T, I>(
+ ctx: &AssistContext<'_>,
+ source: &ast::SourceFile,
+ edit: &mut SyntaxEditor,
+ field_list: T,
+ fields: impl FnOnce(&T) -> I,
+) -> Option<()>
+where
+ T: AstNode,
+ I: IntoIterator<Item = ast::NameRef>,
+{
+ let make = SyntaxFactory::without_mappings();
+ let orig = ctx.sema.original_range_opt(field_list.syntax())?;
+ let list_range = cover_range(source, orig.range);
+
+ let l_curly = match list_range.start() {
+ NodeOrToken::Node(node) => node.first_token()?,
+ NodeOrToken::Token(t) => t.clone(),
+ };
+ let r_curly = match list_range.end() {
+ NodeOrToken::Node(node) => node.last_token()?,
+ NodeOrToken::Token(t) => t.clone(),
+ };
+
+ if l_curly.kind() == T!['{'] {
+ delete_whitespace(edit, l_curly.prev_token());
+ delete_whitespace(edit, l_curly.next_token());
+ edit.replace(l_curly, make.token(T!['(']));
+ }
+ if r_curly.kind() == T!['}'] {
+ delete_whitespace(edit, r_curly.prev_token());
+ edit.replace(r_curly, make.token(T![')']));
+ }
+
+ for name_ref in fields(&field_list) {
+ let Some(orig) = ctx.sema.original_range_opt(name_ref.syntax()) else { continue };
+ let name_range = cover_range(source, orig.range);
+
+ if let Some(colon) = next_non_trivia_token(name_range.end().clone())
+ && colon.kind() == T![:]
+ {
+ edit.delete(&colon);
+ edit.delete_all(name_range);
+
+ if let Some(next) = next_non_trivia_token(colon.clone())
+ && next.kind() != T!['}']
+ {
+ // Avoid overlapping delete whitespace on `{ field: }`
+ delete_whitespace(edit, colon.next_token());
+ }
+ }
+ }
+ Some(())
+}
+
fn edit_field_references(
ctx: &AssistContext<'_>,
builder: &mut SourceChangeBuilder,
@@ -271,8 +320,8 @@ fn edit_field_references(
fn cover_range(source: &ast::SourceFile, range: TextRange) -> RangeInclusive<SyntaxElement> {
let node = match source.syntax().covering_element(range) {
- syntax::NodeOrToken::Node(node) => node,
- syntax::NodeOrToken::Token(t) => t.parent().unwrap(),
+ NodeOrToken::Node(node) => node,
+ NodeOrToken::Token(t) => t.parent().unwrap(),
};
let mut iter = node.children_with_tokens().filter(|it| range.contains_range(it.text_range()));
let first = iter.next().unwrap_or(node.into());
@@ -280,6 +329,15 @@ fn cover_range(source: &ast::SourceFile, range: TextRange) -> RangeInclusive<Syn
first..=last
}
+fn delete_whitespace(edit: &mut SyntaxEditor, whitespace: Option<impl Element>) {
+ let Some(whitespace) = whitespace else { return };
+ let NodeOrToken::Token(token) = whitespace.syntax_element() else { return };
+
+ if token.kind() == SyntaxKind::WHITESPACE && !token.text().contains('\n') {
+ edit.delete(token);
+ }
+}
+
fn remove_trailing_comma(w: ast::WhereClause) -> SyntaxNode {
let w = w.syntax().clone_subtree();
let mut editor = SyntaxEditor::new(w.clone());
@@ -713,6 +771,102 @@ where
}
#[test]
+ fn convert_constructor_expr_uses_self() {
+ // regression test for #21595
+ check_assist(
+ convert_named_struct_to_tuple_struct,
+ r#"
+struct $0Foo { field1: u32 }
+impl Foo {
+ fn clone(&self) -> Self {
+ Self { field1: self.field1 }
+ }
+}"#,
+ r#"
+struct Foo(u32);
+impl Foo {
+ fn clone(&self) -> Self {
+ Self(self.0)
+ }
+}"#,
+ );
+
+ check_assist(
+ convert_named_struct_to_tuple_struct,
+ r#"
+macro_rules! id {
+ ($($t:tt)*) => { $($t)* }
+}
+struct $0Foo { field1: u32 }
+impl Foo {
+ fn clone(&self) -> Self {
+ id!(Self { field1: self.field1 })
+ }
+}"#,
+ r#"
+macro_rules! id {
+ ($($t:tt)*) => { $($t)* }
+}
+struct Foo(u32);
+impl Foo {
+ fn clone(&self) -> Self {
+ id!(Self(self.0))
+ }
+}"#,
+ );
+ }
+
+ #[test]
+ fn convert_pat_uses_self() {
+ // regression test for #21595
+ check_assist(
+ convert_named_struct_to_tuple_struct,
+ r#"
+enum Foo {
+ $0Value { field: &'static Foo },
+ Nil,
+}
+fn foo(foo: &Foo) {
+ if let Foo::Value { field: Foo::Value { field } } = foo {}
+}"#,
+ r#"
+enum Foo {
+ Value(&'static Foo),
+ Nil,
+}
+fn foo(foo: &Foo) {
+ if let Foo::Value(Foo::Value(field)) = foo {}
+}"#,
+ );
+
+ check_assist(
+ convert_named_struct_to_tuple_struct,
+ r#"
+macro_rules! id {
+ ($($t:tt)*) => { $($t)* }
+}
+enum Foo {
+ $0Value { field: &'static Foo },
+ Nil,
+}
+fn foo(foo: &Foo) {
+ if let id!(Foo::Value { field: Foo::Value { field } }) = foo {}
+}"#,
+ r#"
+macro_rules! id {
+ ($($t:tt)*) => { $($t)* }
+}
+enum Foo {
+ Value(&'static Foo),
+ Nil,
+}
+fn foo(foo: &Foo) {
+ if let id!(Foo::Value(Foo::Value(field))) = foo {}
+}"#,
+ );
+ }
+
+ #[test]
fn not_applicable_other_than_record_variant() {
check_assist_not_applicable(
convert_named_struct_to_tuple_struct,
@@ -1077,7 +1231,9 @@ struct Struct(i32);
fn test() {
id! {
- let s = Struct(42);
+ let s = Struct(
+ 42,
+ );
let Struct(value) = s;
let Struct(inner) = s;
}
diff --git a/crates/syntax/src/algo.rs b/crates/syntax/src/algo.rs
index 3ab9c90262..c679921b3f 100644
--- a/crates/syntax/src/algo.rs
+++ b/crates/syntax/src/algo.rs
@@ -132,3 +132,19 @@ pub fn previous_non_trivia_token(e: impl Into<SyntaxElement>) -> Option<SyntaxTo
}
None
}
+
+pub fn next_non_trivia_token(e: impl Into<SyntaxElement>) -> Option<SyntaxToken> {
+ let mut token = match e.into() {
+ SyntaxElement::Node(n) => n.last_token()?,
+ SyntaxElement::Token(t) => t,
+ }
+ .next_token();
+ while let Some(inner) = token {
+ if !inner.kind().is_trivia() {
+ return Some(inner);
+ } else {
+ token = inner.next_token();
+ }
+ }
+ None
+}