Unnamed repository; edit this file 'description' to name the repository.
Diffstat (limited to 'crates/ide-assists/src/handlers/extract_function.rs')
-rw-r--r--crates/ide-assists/src/handlers/extract_function.rs111
1 files changed, 96 insertions, 15 deletions
diff --git a/crates/ide-assists/src/handlers/extract_function.rs b/crates/ide-assists/src/handlers/extract_function.rs
index 124ef509fb..fa5bb39c54 100644
--- a/crates/ide-assists/src/handlers/extract_function.rs
+++ b/crates/ide-assists/src/handlers/extract_function.rs
@@ -92,11 +92,13 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext<'_>) -> Op
let anchor = if self_param.is_some() { Anchor::Method } else { Anchor::Freestanding };
let insert_after = node_to_insert_after(&body, anchor)?;
+ let trait_name = ast::Trait::cast(insert_after.clone()).and_then(|trait_| trait_.name());
let semantics_scope = ctx.sema.scope(&insert_after)?;
let module = semantics_scope.module();
let edition = semantics_scope.krate().edition(ctx.db());
- let (container_info, contains_tail_expr) = body.analyze_container(&ctx.sema, edition)?;
+ let (container_info, contains_tail_expr) =
+ body.analyze_container(&ctx.sema, edition, trait_name)?;
let ret_ty = body.return_ty(ctx)?;
let control_flow = body.external_control_flow(ctx, &container_info)?;
@@ -181,6 +183,7 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext<'_>) -> Op
builder.add_tabstop_before(cap, name);
}
+ // FIXME: wrap non-adt types
let fn_def = match fun.self_param_adt(ctx) {
Some(adt) if anchor == Anchor::Method && !has_impl_wrapper => {
fn_def.indent(1.into());
@@ -377,6 +380,7 @@ struct ControlFlow<'db> {
struct ContainerInfo<'db> {
is_const: bool,
parent_loop: Option<SyntaxNode>,
+ trait_name: Option<ast::Type>,
/// The function's return type, const's type etc.
ret_type: Option<hir::Type<'db>>,
generic_param_lists: Vec<ast::GenericParamList>,
@@ -838,6 +842,7 @@ impl FunctionBody {
&self,
sema: &Semantics<'db, RootDatabase>,
edition: Edition,
+ trait_name: Option<ast::Name>,
) -> Option<(ContainerInfo<'db>, bool)> {
let mut ancestors = self.parent()?.ancestors();
let infer_expr_opt = |expr| sema.type_of_expr(&expr?).map(TypeInfo::adjusted);
@@ -924,6 +929,9 @@ impl FunctionBody {
false
};
+ // FIXME: make trait arguments
+ let trait_name = trait_name.map(|name| make::ty_path(make::ext::ident_path(&name.text())));
+
let parent = self.parent()?;
let parents = generic_parents(&parent);
let generic_param_lists = parents.iter().filter_map(|it| it.generic_param_list()).collect();
@@ -934,6 +942,7 @@ impl FunctionBody {
ContainerInfo {
is_const,
parent_loop,
+ trait_name,
ret_type: ty,
generic_param_lists,
where_clauses,
@@ -1419,14 +1428,18 @@ fn fixup_call_site(builder: &mut SourceChangeBuilder, body: &FunctionBody) {
fn make_call(ctx: &AssistContext<'_>, fun: &Function<'_>, indent: IndentLevel) -> SyntaxNode {
let ret_ty = fun.return_type(ctx);
- let args = make::arg_list(fun.params.iter().map(|param| param.to_arg(ctx, fun.mods.edition)));
let name = fun.name.clone();
- let mut call_expr = if fun.self_param.is_some() {
+ let args = fun.params.iter().map(|param| param.to_arg(ctx, fun.mods.edition));
+ let mut call_expr = if fun.make_this_param().is_some() {
+ let self_arg = make::expr_path(make::ext::ident_path("self"));
+ let func = make::expr_path(make::path_unqualified(make::path_segment(name)));
+ make::expr_call(func, make::arg_list(Some(self_arg).into_iter().chain(args))).into()
+ } else if fun.self_param.is_some() {
let self_arg = make::expr_path(make::ext::ident_path("self"));
- make::expr_method_call(self_arg, name, args).into()
+ make::expr_method_call(self_arg, name, make::arg_list(args)).into()
} else {
let func = make::expr_path(make::path_unqualified(make::path_segment(name)));
- make::expr_call(func, args).into()
+ make::expr_call(func, make::arg_list(args)).into()
};
let handler = FlowHandler::from_ret_ty(fun, &ret_ty);
@@ -1729,9 +1742,28 @@ impl<'db> Function<'db> {
module: hir::Module,
edition: Edition,
) -> ast::ParamList {
- let self_param = self.self_param.clone();
+ let this_param = self.make_this_param().map(|f| f());
+ let self_param = self.self_param.clone().filter(|_| this_param.is_none());
let params = self.params.iter().map(|param| param.to_param(ctx, module, edition));
- make::param_list(self_param, params)
+ make::param_list(self_param, this_param.into_iter().chain(params))
+ }
+
+ fn make_this_param(&self) -> Option<impl FnOnce() -> ast::Param> {
+ if let Some(name) = self.mods.trait_name.clone()
+ && let Some(self_param) = &self.self_param
+ {
+ Some(|| {
+ let bounds = make::type_bound_list([make::type_bound(name)]);
+ let pat = make::path_pat(make::ext::ident_path("this"));
+ let mut ty = make::impl_trait_type(bounds.unwrap()).into();
+ if self_param.amp_token().is_some() {
+ ty = make::ty_ref(ty, self_param.mut_token().is_some());
+ }
+ make::param(pat, ty)
+ })
+ } else {
+ None
+ }
}
fn make_ret_ty(&self, ctx: &AssistContext<'_>, module: hir::Module) -> Option<ast::RetType> {
@@ -1806,10 +1838,12 @@ fn make_body(
) -> ast::BlockExpr {
let ret_ty = fun.return_type(ctx);
let handler = FlowHandler::from_ret_ty(fun, &ret_ty);
+ let to_this_param = fun.self_param.clone().filter(|_| fun.make_this_param().is_some());
let block = match &fun.body {
FunctionBody::Expr(expr) => {
- let expr = rewrite_body_segment(ctx, &fun.params, &handler, expr.syntax());
+ let expr =
+ rewrite_body_segment(ctx, to_this_param, &fun.params, &handler, expr.syntax());
let expr = ast::Expr::cast(expr).expect("Body segment should be an expr");
match expr {
ast::Expr::BlockExpr(block) => {
@@ -1847,7 +1881,7 @@ fn make_body(
.filter(|it| text_range.contains_range(it.text_range()))
.map(|it| match &it {
syntax::NodeOrToken::Node(n) => syntax::NodeOrToken::Node(
- rewrite_body_segment(ctx, &fun.params, &handler, n),
+ rewrite_body_segment(ctx, to_this_param.clone(), &fun.params, &handler, n),
),
_ => it,
})
@@ -1997,11 +2031,13 @@ fn make_ty(ty: &hir::Type<'_>, ctx: &AssistContext<'_>, module: hir::Module) ->
fn rewrite_body_segment(
ctx: &AssistContext<'_>,
+ to_this_param: Option<ast::SelfParam>,
params: &[Param<'_>],
handler: &FlowHandler<'_>,
syntax: &SyntaxNode,
) -> SyntaxNode {
- let syntax = fix_param_usages(ctx, params, syntax);
+ let to_this_param = to_this_param.and_then(|it| ctx.sema.to_def(&it));
+ let syntax = fix_param_usages(ctx, to_this_param, params, syntax);
update_external_control_flow(handler, &syntax);
syntax
}
@@ -2009,30 +2045,46 @@ fn rewrite_body_segment(
/// change all usages to account for added `&`/`&mut` for some params
fn fix_param_usages(
ctx: &AssistContext<'_>,
+ to_this_param: Option<Local>,
params: &[Param<'_>],
syntax: &SyntaxNode,
) -> SyntaxNode {
let mut usages_for_param: Vec<(&Param<'_>, Vec<ast::Expr>)> = Vec::new();
+ let mut usages_for_self_param: Vec<ast::Expr> = Vec::new();
let tm = TreeMutator::new(syntax);
+ let reference_filter = |reference: &FileReference| {
+ syntax
+ .text_range()
+ .contains_range(reference.range)
+ .then_some(())
+ .and_then(|_| path_element_of_reference(syntax, reference))
+ .map(|expr| tm.make_mut(&expr))
+ };
+ if let Some(self_param) = to_this_param {
+ usages_for_self_param = LocalUsages::find_local_usages(ctx, self_param)
+ .iter()
+ .filter_map(reference_filter)
+ .collect();
+ }
for param in params {
if !param.kind().is_ref() {
continue;
}
let usages = LocalUsages::find_local_usages(ctx, param.var);
- let usages = usages
- .iter()
- .filter(|reference| syntax.text_range().contains_range(reference.range))
- .filter_map(|reference| path_element_of_reference(syntax, reference))
- .map(|expr| tm.make_mut(&expr));
+ let usages = usages.iter().filter_map(reference_filter);
usages_for_param.push((param, usages.unique().collect()));
}
let res = tm.make_syntax_mut(syntax);
+ for self_usage in usages_for_self_param {
+ let this_expr = make::expr_path(make::ext::ident_path("this")).clone_for_update();
+ ted::replace(self_usage.syntax(), this_expr.syntax());
+ }
for (param, usages) in usages_for_param {
for usage in usages {
match usage.syntax().ancestors().skip(1).find_map(ast::Expr::cast) {
@@ -2940,6 +2992,35 @@ impl S {
}
#[test]
+ fn method_in_trait() {
+ check_assist(
+ extract_function,
+ r#"
+trait Foo {
+ fn f(&self) -> i32;
+
+ fn foo(&self) -> i32 {
+ $0self.f()+self.f()$0
+ }
+}
+"#,
+ r#"
+trait Foo {
+ fn f(&self) -> i32;
+
+ fn foo(&self) -> i32 {
+ fun_name(self)
+ }
+}
+
+fn $0fun_name(this: &impl Foo) -> i32 {
+ this.f()+this.f()
+}
+"#,
+ );
+ }
+
+ #[test]
fn variable_defined_inside_and_used_after_no_ret() {
check_assist(
extract_function,