Unnamed repository; edit this file 'description' to name the repository.
Diffstat (limited to 'crates/ide-assists/src/handlers/extract_function.rs')
-rw-r--r--crates/ide-assists/src/handlers/extract_function.rs270
1 files changed, 266 insertions, 4 deletions
diff --git a/crates/ide-assists/src/handlers/extract_function.rs b/crates/ide-assists/src/handlers/extract_function.rs
index 0605883584..c1e2f19ab1 100644
--- a/crates/ide-assists/src/handlers/extract_function.rs
+++ b/crates/ide-assists/src/handlers/extract_function.rs
@@ -109,8 +109,6 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext<'_>) -> Op
let params =
body.extracted_function_params(ctx, &container_info, locals_used.iter().copied());
- let extracted_from_trait_impl = body.extracted_from_trait_impl();
-
let name = make_function_name(&semantics_scope);
let fun = Function {
@@ -129,8 +127,11 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext<'_>) -> Op
builder.replace(target_range, make_call(ctx, &fun, old_indent));
+ let has_impl_wrapper =
+ insert_after.ancestors().any(|a| a.kind() == SyntaxKind::IMPL && a != insert_after);
+
let fn_def = match fun.self_param_adt(ctx) {
- Some(adt) if extracted_from_trait_impl => {
+ Some(adt) if anchor == Anchor::Method && !has_impl_wrapper => {
let fn_def = format_function(ctx, module, &fun, old_indent, new_indent + 1);
generate_impl_text(&adt, &fn_def).replace("{\n\n", "{")
}
@@ -272,7 +273,7 @@ enum FunType {
}
/// Where to put extracted function definition
-#[derive(Debug)]
+#[derive(Debug, Eq, PartialEq, Clone, Copy)]
enum Anchor {
/// Extract free function and put right after current top-level function
Freestanding,
@@ -1245,6 +1246,14 @@ fn node_to_insert_after(body: &FunctionBody, anchor: Anchor) -> Option<SyntaxNod
while let Some(next_ancestor) = ancestors.next() {
match next_ancestor.kind() {
SyntaxKind::SOURCE_FILE => break,
+ SyntaxKind::IMPL => {
+ if body.extracted_from_trait_impl() && matches!(anchor, Anchor::Method) {
+ let impl_node = find_non_trait_impl(&next_ancestor);
+ if let target_node @ Some(_) = impl_node.as_ref().and_then(last_impl_member) {
+ return target_node;
+ }
+ }
+ }
SyntaxKind::ITEM_LIST if !matches!(anchor, Anchor::Freestanding) => continue,
SyntaxKind::ITEM_LIST => {
if ancestors.peek().map(SyntaxNode::kind) == Some(SyntaxKind::MODULE) {
@@ -1265,6 +1274,29 @@ fn node_to_insert_after(body: &FunctionBody, anchor: Anchor) -> Option<SyntaxNod
last_ancestor
}
+fn find_non_trait_impl(trait_impl: &SyntaxNode) -> Option<ast::Impl> {
+ let as_impl = ast::Impl::cast(trait_impl.clone())?;
+ let impl_type = Some(impl_type_name(&as_impl)?);
+
+ let sibblings = trait_impl.parent()?.children();
+ sibblings
+ .filter_map(ast::Impl::cast)
+ .find(|s| impl_type_name(s) == impl_type && !is_trait_impl(s))
+}
+
+fn last_impl_member(impl_node: &ast::Impl) -> Option<SyntaxNode> {
+ let last_child = impl_node.assoc_item_list()?.assoc_items().last()?;
+ Some(last_child.syntax().clone())
+}
+
+fn is_trait_impl(node: &ast::Impl) -> bool {
+ node.trait_().is_some()
+}
+
+fn impl_type_name(impl_node: &ast::Impl) -> Option<String> {
+ Some(impl_node.self_ty()?.to_string())
+}
+
fn make_call(ctx: &AssistContext<'_>, fun: &Function, indent: IndentLevel) -> String {
let ret_ty = fun.return_type(ctx);
@@ -5052,6 +5084,236 @@ impl Struct {
}
#[test]
+ fn extract_method_from_trait_with_existing_non_empty_impl_block() {
+ check_assist(
+ extract_function,
+ r#"
+struct Struct(i32);
+trait Trait {
+ fn bar(&self) -> i32;
+}
+
+impl Struct {
+ fn foo() {}
+}
+
+impl Trait for Struct {
+ fn bar(&self) -> i32 {
+ $0self.0 + 2$0
+ }
+}
+"#,
+ r#"
+struct Struct(i32);
+trait Trait {
+ fn bar(&self) -> i32;
+}
+
+impl Struct {
+ fn foo() {}
+
+ fn $0fun_name(&self) -> i32 {
+ self.0 + 2
+ }
+}
+
+impl Trait for Struct {
+ fn bar(&self) -> i32 {
+ self.fun_name()
+ }
+}
+"#,
+ )
+ }
+
+ #[test]
+ fn extract_function_from_trait_with_existing_non_empty_impl_block() {
+ check_assist(
+ extract_function,
+ r#"
+struct Struct(i32);
+trait Trait {
+ fn bar(&self) -> i32;
+}
+
+impl Struct {
+ fn foo() {}
+}
+
+impl Trait for Struct {
+ fn bar(&self) -> i32 {
+ let three_squared = $03 * 3$0;
+ self.0 + three_squared
+ }
+}
+"#,
+ r#"
+struct Struct(i32);
+trait Trait {
+ fn bar(&self) -> i32;
+}
+
+impl Struct {
+ fn foo() {}
+}
+
+impl Trait for Struct {
+ fn bar(&self) -> i32 {
+ let three_squared = fun_name();
+ self.0 + three_squared
+ }
+}
+
+fn $0fun_name() -> i32 {
+ 3 * 3
+}
+"#,
+ )
+ }
+
+ #[test]
+ fn extract_method_from_trait_with_multiple_existing_impl_blocks() {
+ check_assist(
+ extract_function,
+ r#"
+struct Struct(i32);
+struct StructBefore(i32);
+struct StructAfter(i32);
+trait Trait {
+ fn bar(&self) -> i32;
+}
+
+impl StructBefore {
+ fn foo(){}
+}
+
+impl Struct {
+ fn foo(){}
+}
+
+impl StructAfter {
+ fn foo(){}
+}
+
+impl Trait for Struct {
+ fn bar(&self) -> i32 {
+ $0self.0 + 2$0
+ }
+}
+"#,
+ r#"
+struct Struct(i32);
+struct StructBefore(i32);
+struct StructAfter(i32);
+trait Trait {
+ fn bar(&self) -> i32;
+}
+
+impl StructBefore {
+ fn foo(){}
+}
+
+impl Struct {
+ fn foo(){}
+
+ fn $0fun_name(&self) -> i32 {
+ self.0 + 2
+ }
+}
+
+impl StructAfter {
+ fn foo(){}
+}
+
+impl Trait for Struct {
+ fn bar(&self) -> i32 {
+ self.fun_name()
+ }
+}
+"#,
+ )
+ }
+
+ #[test]
+ fn extract_method_from_trait_with_multiple_existing_trait_impl_blocks() {
+ check_assist(
+ extract_function,
+ r#"
+struct Struct(i32);
+trait Trait {
+ fn bar(&self) -> i32;
+}
+trait TraitBefore {
+ fn before(&self) -> i32;
+}
+trait TraitAfter {
+ fn after(&self) -> i32;
+}
+
+impl TraitBefore for Struct {
+ fn before(&self) -> i32 {
+ 42
+ }
+}
+
+impl Struct {
+ fn foo(){}
+}
+
+impl TraitAfter for Struct {
+ fn after(&self) -> i32 {
+ 42
+ }
+}
+
+impl Trait for Struct {
+ fn bar(&self) -> i32 {
+ $0self.0 + 2$0
+ }
+}
+"#,
+ r#"
+struct Struct(i32);
+trait Trait {
+ fn bar(&self) -> i32;
+}
+trait TraitBefore {
+ fn before(&self) -> i32;
+}
+trait TraitAfter {
+ fn after(&self) -> i32;
+}
+
+impl TraitBefore for Struct {
+ fn before(&self) -> i32 {
+ 42
+ }
+}
+
+impl Struct {
+ fn foo(){}
+
+ fn $0fun_name(&self) -> i32 {
+ self.0 + 2
+ }
+}
+
+impl TraitAfter for Struct {
+ fn after(&self) -> i32 {
+ 42
+ }
+}
+
+impl Trait for Struct {
+ fn bar(&self) -> i32 {
+ self.fun_name()
+ }
+}
+"#,
+ )
+ }
+
+ #[test]
fn closure_arguments() {
check_assist(
extract_function,