Unnamed repository; edit this file 'description' to name the repository.
Diffstat (limited to 'crates/ide-assists/src/handlers/extract_function.rs')
| -rw-r--r-- | crates/ide-assists/src/handlers/extract_function.rs | 111 |
1 files changed, 96 insertions, 15 deletions
diff --git a/crates/ide-assists/src/handlers/extract_function.rs b/crates/ide-assists/src/handlers/extract_function.rs index 124ef509fb..fa5bb39c54 100644 --- a/crates/ide-assists/src/handlers/extract_function.rs +++ b/crates/ide-assists/src/handlers/extract_function.rs @@ -92,11 +92,13 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext<'_>) -> Op let anchor = if self_param.is_some() { Anchor::Method } else { Anchor::Freestanding }; let insert_after = node_to_insert_after(&body, anchor)?; + let trait_name = ast::Trait::cast(insert_after.clone()).and_then(|trait_| trait_.name()); let semantics_scope = ctx.sema.scope(&insert_after)?; let module = semantics_scope.module(); let edition = semantics_scope.krate().edition(ctx.db()); - let (container_info, contains_tail_expr) = body.analyze_container(&ctx.sema, edition)?; + let (container_info, contains_tail_expr) = + body.analyze_container(&ctx.sema, edition, trait_name)?; let ret_ty = body.return_ty(ctx)?; let control_flow = body.external_control_flow(ctx, &container_info)?; @@ -181,6 +183,7 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext<'_>) -> Op builder.add_tabstop_before(cap, name); } + // FIXME: wrap non-adt types let fn_def = match fun.self_param_adt(ctx) { Some(adt) if anchor == Anchor::Method && !has_impl_wrapper => { fn_def.indent(1.into()); @@ -377,6 +380,7 @@ struct ControlFlow<'db> { struct ContainerInfo<'db> { is_const: bool, parent_loop: Option<SyntaxNode>, + trait_name: Option<ast::Type>, /// The function's return type, const's type etc. ret_type: Option<hir::Type<'db>>, generic_param_lists: Vec<ast::GenericParamList>, @@ -838,6 +842,7 @@ impl FunctionBody { &self, sema: &Semantics<'db, RootDatabase>, edition: Edition, + trait_name: Option<ast::Name>, ) -> Option<(ContainerInfo<'db>, bool)> { let mut ancestors = self.parent()?.ancestors(); let infer_expr_opt = |expr| sema.type_of_expr(&expr?).map(TypeInfo::adjusted); @@ -924,6 +929,9 @@ impl FunctionBody { false }; + // FIXME: make trait arguments + let trait_name = trait_name.map(|name| make::ty_path(make::ext::ident_path(&name.text()))); + let parent = self.parent()?; let parents = generic_parents(&parent); let generic_param_lists = parents.iter().filter_map(|it| it.generic_param_list()).collect(); @@ -934,6 +942,7 @@ impl FunctionBody { ContainerInfo { is_const, parent_loop, + trait_name, ret_type: ty, generic_param_lists, where_clauses, @@ -1419,14 +1428,18 @@ fn fixup_call_site(builder: &mut SourceChangeBuilder, body: &FunctionBody) { fn make_call(ctx: &AssistContext<'_>, fun: &Function<'_>, indent: IndentLevel) -> SyntaxNode { let ret_ty = fun.return_type(ctx); - let args = make::arg_list(fun.params.iter().map(|param| param.to_arg(ctx, fun.mods.edition))); let name = fun.name.clone(); - let mut call_expr = if fun.self_param.is_some() { + let args = fun.params.iter().map(|param| param.to_arg(ctx, fun.mods.edition)); + let mut call_expr = if fun.make_this_param().is_some() { + let self_arg = make::expr_path(make::ext::ident_path("self")); + let func = make::expr_path(make::path_unqualified(make::path_segment(name))); + make::expr_call(func, make::arg_list(Some(self_arg).into_iter().chain(args))).into() + } else if fun.self_param.is_some() { let self_arg = make::expr_path(make::ext::ident_path("self")); - make::expr_method_call(self_arg, name, args).into() + make::expr_method_call(self_arg, name, make::arg_list(args)).into() } else { let func = make::expr_path(make::path_unqualified(make::path_segment(name))); - make::expr_call(func, args).into() + make::expr_call(func, make::arg_list(args)).into() }; let handler = FlowHandler::from_ret_ty(fun, &ret_ty); @@ -1729,9 +1742,28 @@ impl<'db> Function<'db> { module: hir::Module, edition: Edition, ) -> ast::ParamList { - let self_param = self.self_param.clone(); + let this_param = self.make_this_param().map(|f| f()); + let self_param = self.self_param.clone().filter(|_| this_param.is_none()); let params = self.params.iter().map(|param| param.to_param(ctx, module, edition)); - make::param_list(self_param, params) + make::param_list(self_param, this_param.into_iter().chain(params)) + } + + fn make_this_param(&self) -> Option<impl FnOnce() -> ast::Param> { + if let Some(name) = self.mods.trait_name.clone() + && let Some(self_param) = &self.self_param + { + Some(|| { + let bounds = make::type_bound_list([make::type_bound(name)]); + let pat = make::path_pat(make::ext::ident_path("this")); + let mut ty = make::impl_trait_type(bounds.unwrap()).into(); + if self_param.amp_token().is_some() { + ty = make::ty_ref(ty, self_param.mut_token().is_some()); + } + make::param(pat, ty) + }) + } else { + None + } } fn make_ret_ty(&self, ctx: &AssistContext<'_>, module: hir::Module) -> Option<ast::RetType> { @@ -1806,10 +1838,12 @@ fn make_body( ) -> ast::BlockExpr { let ret_ty = fun.return_type(ctx); let handler = FlowHandler::from_ret_ty(fun, &ret_ty); + let to_this_param = fun.self_param.clone().filter(|_| fun.make_this_param().is_some()); let block = match &fun.body { FunctionBody::Expr(expr) => { - let expr = rewrite_body_segment(ctx, &fun.params, &handler, expr.syntax()); + let expr = + rewrite_body_segment(ctx, to_this_param, &fun.params, &handler, expr.syntax()); let expr = ast::Expr::cast(expr).expect("Body segment should be an expr"); match expr { ast::Expr::BlockExpr(block) => { @@ -1847,7 +1881,7 @@ fn make_body( .filter(|it| text_range.contains_range(it.text_range())) .map(|it| match &it { syntax::NodeOrToken::Node(n) => syntax::NodeOrToken::Node( - rewrite_body_segment(ctx, &fun.params, &handler, n), + rewrite_body_segment(ctx, to_this_param.clone(), &fun.params, &handler, n), ), _ => it, }) @@ -1997,11 +2031,13 @@ fn make_ty(ty: &hir::Type<'_>, ctx: &AssistContext<'_>, module: hir::Module) -> fn rewrite_body_segment( ctx: &AssistContext<'_>, + to_this_param: Option<ast::SelfParam>, params: &[Param<'_>], handler: &FlowHandler<'_>, syntax: &SyntaxNode, ) -> SyntaxNode { - let syntax = fix_param_usages(ctx, params, syntax); + let to_this_param = to_this_param.and_then(|it| ctx.sema.to_def(&it)); + let syntax = fix_param_usages(ctx, to_this_param, params, syntax); update_external_control_flow(handler, &syntax); syntax } @@ -2009,30 +2045,46 @@ fn rewrite_body_segment( /// change all usages to account for added `&`/`&mut` for some params fn fix_param_usages( ctx: &AssistContext<'_>, + to_this_param: Option<Local>, params: &[Param<'_>], syntax: &SyntaxNode, ) -> SyntaxNode { let mut usages_for_param: Vec<(&Param<'_>, Vec<ast::Expr>)> = Vec::new(); + let mut usages_for_self_param: Vec<ast::Expr> = Vec::new(); let tm = TreeMutator::new(syntax); + let reference_filter = |reference: &FileReference| { + syntax + .text_range() + .contains_range(reference.range) + .then_some(()) + .and_then(|_| path_element_of_reference(syntax, reference)) + .map(|expr| tm.make_mut(&expr)) + }; + if let Some(self_param) = to_this_param { + usages_for_self_param = LocalUsages::find_local_usages(ctx, self_param) + .iter() + .filter_map(reference_filter) + .collect(); + } for param in params { if !param.kind().is_ref() { continue; } let usages = LocalUsages::find_local_usages(ctx, param.var); - let usages = usages - .iter() - .filter(|reference| syntax.text_range().contains_range(reference.range)) - .filter_map(|reference| path_element_of_reference(syntax, reference)) - .map(|expr| tm.make_mut(&expr)); + let usages = usages.iter().filter_map(reference_filter); usages_for_param.push((param, usages.unique().collect())); } let res = tm.make_syntax_mut(syntax); + for self_usage in usages_for_self_param { + let this_expr = make::expr_path(make::ext::ident_path("this")).clone_for_update(); + ted::replace(self_usage.syntax(), this_expr.syntax()); + } for (param, usages) in usages_for_param { for usage in usages { match usage.syntax().ancestors().skip(1).find_map(ast::Expr::cast) { @@ -2940,6 +2992,35 @@ impl S { } #[test] + fn method_in_trait() { + check_assist( + extract_function, + r#" +trait Foo { + fn f(&self) -> i32; + + fn foo(&self) -> i32 { + $0self.f()+self.f()$0 + } +} +"#, + r#" +trait Foo { + fn f(&self) -> i32; + + fn foo(&self) -> i32 { + fun_name(self) + } +} + +fn $0fun_name(this: &impl Foo) -> i32 { + this.f()+this.f() +} +"#, + ); + } + + #[test] fn variable_defined_inside_and_used_after_no_ret() { check_assist( extract_function, |