Unnamed repository; edit this file 'description' to name the repository.
Feat: extracted method from trait impl is placed in existing impl
Previously, when triggering a method extraction from within a trait impl block, then this would always create a new impl block for the struct, even if there already is one. Now, it'll put the extracted method in the matching existing block if it exists.
Tiddo Langerak 2022-08-10
parent d186986 · commit 7c56896
-rw-r--r--crates/ide-assists/src/handlers/extract_function.rs272
1 files changed, 268 insertions, 4 deletions
diff --git a/crates/ide-assists/src/handlers/extract_function.rs b/crates/ide-assists/src/handlers/extract_function.rs
index 52a55ead3a..40d0327ef7 100644
--- a/crates/ide-assists/src/handlers/extract_function.rs
+++ b/crates/ide-assists/src/handlers/extract_function.rs
@@ -109,8 +109,6 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext<'_>) -> Op
let params =
body.extracted_function_params(ctx, &container_info, locals_used.iter().copied());
- let extracted_from_trait_impl = body.extracted_from_trait_impl();
-
let name = make_function_name(&semantics_scope);
let fun = Function {
@@ -129,8 +127,13 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext<'_>) -> Op
builder.replace(target_range, make_call(ctx, &fun, old_indent));
+ let has_impl_wrapper = insert_after
+ .ancestors()
+ .find(|a| a.kind() == SyntaxKind::IMPL && a != &insert_after)
+ .is_some();
+
let fn_def = match fun.self_param_adt(ctx) {
- Some(adt) if extracted_from_trait_impl => {
+ Some(adt) if anchor == Anchor::Method && !has_impl_wrapper => {
let fn_def = format_function(ctx, module, &fun, old_indent, new_indent + 1);
generate_impl_text(&adt, &fn_def).replace("{\n\n", "{")
}
@@ -271,7 +274,7 @@ enum FunType {
}
/// Where to put extracted function definition
-#[derive(Debug)]
+#[derive(Debug, Eq, PartialEq, Clone, Copy)]
enum Anchor {
/// Extract free function and put right after current top-level function
Freestanding,
@@ -1244,6 +1247,15 @@ fn node_to_insert_after(body: &FunctionBody, anchor: Anchor) -> Option<SyntaxNod
while let Some(next_ancestor) = ancestors.next() {
match next_ancestor.kind() {
SyntaxKind::SOURCE_FILE => break,
+ SyntaxKind::IMPL => {
+ if body.extracted_from_trait_impl() && matches!(anchor, Anchor::Method) {
+ let impl_node = find_non_trait_impl(&next_ancestor);
+ let target_node = impl_node.as_ref().and_then(last_impl_member);
+ if target_node.is_some() {
+ return target_node;
+ }
+ }
+ }
SyntaxKind::ITEM_LIST if !matches!(anchor, Anchor::Freestanding) => continue,
SyntaxKind::ITEM_LIST => {
if ancestors.peek().map(SyntaxNode::kind) == Some(SyntaxKind::MODULE) {
@@ -1264,6 +1276,28 @@ fn node_to_insert_after(body: &FunctionBody, anchor: Anchor) -> Option<SyntaxNod
last_ancestor
}
+fn find_non_trait_impl(trait_impl: &SyntaxNode) -> Option<SyntaxNode> {
+ let impl_type = Some(impl_type_name(trait_impl)?);
+
+ let mut sibblings = trait_impl.parent()?.children();
+ sibblings.find(|s| impl_type_name(s) == impl_type && !is_trait_impl(s))
+}
+
+fn last_impl_member(impl_node: &SyntaxNode) -> Option<SyntaxNode> {
+ impl_node.children().find(|c| c.kind() == SyntaxKind::ASSOC_ITEM_LIST)?.last_child()
+}
+
+fn is_trait_impl(node: &SyntaxNode) -> bool {
+ match ast::Impl::cast(node.clone()) {
+ Some(c) => c.trait_().is_some(),
+ None => false,
+ }
+}
+
+fn impl_type_name(impl_node: &SyntaxNode) -> Option<String> {
+ Some(ast::Impl::cast(impl_node.clone())?.self_ty()?.to_string())
+}
+
fn make_call(ctx: &AssistContext<'_>, fun: &Function, indent: IndentLevel) -> String {
let ret_ty = fun.return_type(ctx);
@@ -5059,6 +5093,236 @@ impl Struct {
}
#[test]
+ fn extract_method_from_trait_with_existing_non_empty_impl_block() {
+ check_assist(
+ extract_function,
+ r#"
+struct Struct(i32);
+trait Trait {
+ fn bar(&self) -> i32;
+}
+
+impl Struct {
+ fn foo() {}
+}
+
+impl Trait for Struct {
+ fn bar(&self) -> i32 {
+ $0self.0 + 2$0
+ }
+}
+"#,
+ r#"
+struct Struct(i32);
+trait Trait {
+ fn bar(&self) -> i32;
+}
+
+impl Struct {
+ fn foo() {}
+
+ fn $0fun_name(&self) -> i32 {
+ self.0 + 2
+ }
+}
+
+impl Trait for Struct {
+ fn bar(&self) -> i32 {
+ self.fun_name()
+ }
+}
+"#,
+ )
+ }
+
+ #[test]
+ fn extract_function_from_trait_with_existing_non_empty_impl_block() {
+ check_assist(
+ extract_function,
+ r#"
+struct Struct(i32);
+trait Trait {
+ fn bar(&self) -> i32;
+}
+
+impl Struct {
+ fn foo() {}
+}
+
+impl Trait for Struct {
+ fn bar(&self) -> i32 {
+ let three_squared = $03 * 3$0;
+ self.0 + three_squared
+ }
+}
+"#,
+ r#"
+struct Struct(i32);
+trait Trait {
+ fn bar(&self) -> i32;
+}
+
+impl Struct {
+ fn foo() {}
+}
+
+impl Trait for Struct {
+ fn bar(&self) -> i32 {
+ let three_squared = fun_name();
+ self.0 + three_squared
+ }
+}
+
+fn $0fun_name() -> i32 {
+ 3 * 3
+}
+"#,
+ )
+ }
+
+ #[test]
+ fn extract_method_from_trait_with_multiple_existing_impl_blocks() {
+ check_assist(
+ extract_function,
+ r#"
+struct Struct(i32);
+struct StructBefore(i32);
+struct StructAfter(i32);
+trait Trait {
+ fn bar(&self) -> i32;
+}
+
+impl StructBefore {
+ fn foo(){}
+}
+
+impl Struct {
+ fn foo(){}
+}
+
+impl StructAfter {
+ fn foo(){}
+}
+
+impl Trait for Struct {
+ fn bar(&self) -> i32 {
+ $0self.0 + 2$0
+ }
+}
+"#,
+ r#"
+struct Struct(i32);
+struct StructBefore(i32);
+struct StructAfter(i32);
+trait Trait {
+ fn bar(&self) -> i32;
+}
+
+impl StructBefore {
+ fn foo(){}
+}
+
+impl Struct {
+ fn foo(){}
+
+ fn $0fun_name(&self) -> i32 {
+ self.0 + 2
+ }
+}
+
+impl StructAfter {
+ fn foo(){}
+}
+
+impl Trait for Struct {
+ fn bar(&self) -> i32 {
+ self.fun_name()
+ }
+}
+"#,
+ )
+ }
+
+ #[test]
+ fn extract_method_from_trait_with_multiple_existing_trait_impl_blocks() {
+ check_assist(
+ extract_function,
+ r#"
+struct Struct(i32);
+trait Trait {
+ fn bar(&self) -> i32;
+}
+trait TraitBefore {
+ fn before(&self) -> i32;
+}
+trait TraitAfter {
+ fn after(&self) -> i32;
+}
+
+impl TraitBefore for Struct {
+ fn before(&self) -> i32 {
+ 42
+ }
+}
+
+impl Struct {
+ fn foo(){}
+}
+
+impl TraitAfter for Struct {
+ fn after(&self) -> i32 {
+ 42
+ }
+}
+
+impl Trait for Struct {
+ fn bar(&self) -> i32 {
+ $0self.0 + 2$0
+ }
+}
+"#,
+ r#"
+struct Struct(i32);
+trait Trait {
+ fn bar(&self) -> i32;
+}
+trait TraitBefore {
+ fn before(&self) -> i32;
+}
+trait TraitAfter {
+ fn after(&self) -> i32;
+}
+
+impl TraitBefore for Struct {
+ fn before(&self) -> i32 {
+ 42
+ }
+}
+
+impl Struct {
+ fn foo(){}
+
+ fn $0fun_name(&self) -> i32 {
+ self.0 + 2
+ }
+}
+
+impl TraitAfter for Struct {
+ fn after(&self) -> i32 {
+ 42
+ }
+}
+
+impl Trait for Struct {
+ fn bar(&self) -> i32 {
+ self.fun_name()
+ }
+}
+"#,
+ )
+ }
+
+ #[test]
fn closure_arguments() {
check_assist(
extract_function,