Unnamed repository; edit this file 'description' to name the repository.
| -rw-r--r-- | crates/ide-assists/src/handlers/extract_function.rs | 617 |
1 files changed, 333 insertions, 284 deletions
diff --git a/crates/ide-assists/src/handlers/extract_function.rs b/crates/ide-assists/src/handlers/extract_function.rs index 4219e6845f..99dd2ea237 100644 --- a/crates/ide-assists/src/handlers/extract_function.rs +++ b/crates/ide-assists/src/handlers/extract_function.rs @@ -1,6 +1,5 @@ use std::{iter, ops::RangeInclusive}; -use ast::make; use either::Either; use hir::{ HasSource, HirDisplay, InFile, Local, LocalSource, ModuleDef, PathResolution, Semantics, @@ -11,10 +10,9 @@ use ide_db::{ assists::GroupLabel, defs::Definition, famous_defs::FamousDefs, - helpers::mod_path_to_ast, - imports::insert_use::{ImportScope, insert_use}, + helpers::mod_path_to_ast_with_factory, + imports::insert_use::{ImportScope, insert_use_with_editor}, search::{FileReference, ReferenceCategory, SearchScope}, - source_change::SourceChangeBuilder, syntax_helpers::node_ext::{ for_each_tail_expr, preorder_expr, walk_pat, walk_patterns_in_expr, }, @@ -25,15 +23,17 @@ use syntax::{ SyntaxKind::{self, COMMENT}, SyntaxNode, SyntaxToken, T, TextRange, TextSize, TokenAtOffset, WalkEvent, ast::{ - self, AstNode, AstToken, HasAttrs, HasGenericParams, HasName, edit::IndentLevel, - edit_in_place::Indent, + self, AstNode, AstToken, HasAttrs, HasGenericParams, HasName, + edit::{AstNodeEdit, IndentLevel}, + syntax_factory::SyntaxFactory, }, - match_ast, ted, + match_ast, + syntax_editor::{Position, SyntaxEditor}, }; use crate::{ AssistId, - assist_context::{AssistContext, Assists, TreeMutator}, + assist_context::{AssistContext, Assists}, utils::generate_impl, }; @@ -97,8 +97,9 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext<'_>) -> Op let module = semantics_scope.module(); let edition = semantics_scope.krate().edition(ctx.db()); + let (editor, _) = SyntaxEditor::new(ctx.source_file().syntax().clone()); let (container_info, contains_tail_expr) = - body.analyze_container(&ctx.sema, edition, trait_name)?; + body.analyze_container(editor.make(), &ctx.sema, edition, trait_name)?; let ret_ty = body.return_ty(ctx)?; let control_flow = body.external_control_flow(ctx, &container_info)?; @@ -114,6 +115,7 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext<'_>) -> Op "Extract into function", target_range, move |builder| { + let make = editor.make(); let outliving_locals: Vec<_> = ret_values.collect(); if stdx::never!(!outliving_locals.is_empty() && !ret_ty.is_unit()) { // We should not have variables that outlive body if we have expression block @@ -122,7 +124,7 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext<'_>) -> Op let params = body.extracted_function_params(ctx, &container_info, locals_used); - let name = make_function_name(&semantics_scope, &body); + let name = make_function_name(make, &semantics_scope, &body); let fun = Function { name, @@ -139,67 +141,47 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext<'_>) -> Op let new_indent = IndentLevel::from_node(&insert_after); let old_indent = fun.body.indent_level(); - let insert_after = builder.make_syntax_mut(insert_after); + let call_expr = make_call(make, ctx, &fun, old_indent); - let call_expr = make_call(ctx, &fun, old_indent); - - // Map the element range to replace into the mutable version let elements = match &fun.body { FunctionBody::Expr(expr) => { - // expr itself becomes the replacement target - let expr = &builder.make_mut(expr.clone()); let node = SyntaxElement::Node(expr.syntax().clone()); - node.clone()..=node } - FunctionBody::Span { parent, elements, .. } => { - // Map the element range into the mutable versions - let parent = builder.make_mut(parent.clone()); - - let start = parent - .syntax() - .children_with_tokens() - .nth(elements.start().index()) - .expect("should be able to find mutable start element"); - - let end = parent - .syntax() - .children_with_tokens() - .nth(elements.end().index()) - .expect("should be able to find mutable end element"); - - start..=end - } + FunctionBody::Span { elements, .. } => elements.clone(), }; let has_impl_wrapper = insert_after.ancestors().any(|a| a.kind() == SyntaxKind::IMPL && a != insert_after); - let fn_def = format_function(ctx, module, &fun, old_indent).clone_for_update(); - - if let Some(cap) = ctx.config.snippet_cap - && let Some(name) = fn_def.name() - { - builder.add_tabstop_before(cap, name); - } + let fn_def = format_function(ctx, module, &fun, old_indent, make); // FIXME: wrap non-adt types let fn_def = match fun.self_param_adt(ctx) { Some(adt) if anchor == Anchor::Method && !has_impl_wrapper => { - fn_def.indent(1.into()); + let fn_def = fn_def.indent_with_mapping(1.into(), make); - let impl_ = generate_impl(&adt); - impl_.indent(new_indent); + let impl_ = generate_impl(make, &adt).indent(new_indent); impl_.get_or_create_assoc_item_list().add_item(fn_def.into()); impl_.syntax().clone() } - _ => { - fn_def.indent(new_indent); - - fn_def.syntax().clone() - } + _ => fn_def.indent_with_mapping(new_indent, make).syntax().clone(), }; + if let Some(cap) = ctx.config.snippet_cap { + let extracted_fn = fn_def.descendants().find_map(ast::Fn::cast); + if let Some(fn_) = extracted_fn { + if let Some(ws) = fn_ + .fn_token() + .and_then(|tok| tok.next_token()) + .filter(|tok| tok.kind() == SyntaxKind::WHITESPACE) + { + editor.add_annotation(ws, builder.make_tabstop_after(cap)); + } else if let Some(name) = fn_.name() { + editor.add_annotation(name.syntax(), builder.make_tabstop_before(cap)); + } + } + } // There are external control flows if fun @@ -207,7 +189,6 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext<'_>) -> Op .kind .is_some_and(|kind| matches!(kind, FlowKind::Break(_, _) | FlowKind::Continue(_))) { - let scope = builder.make_import_scope_mut(scope); let control_flow_enum = FamousDefs(&ctx.sema, module.krate(ctx.db())).core_ops_ControlFlow(); @@ -222,29 +203,45 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext<'_>) -> Op ); if let Some(mod_path) = mod_path { - insert_use( + insert_use_with_editor( &scope, - mod_path_to_ast(&mod_path, edition), + mod_path_to_ast_with_factory(make, &mod_path, edition), &ctx.config.insert_use, + &editor, ); } } } // Replace the call site with the call to the new function - fixup_call_site(builder, &fun.body); - ted::replace_all(elements, vec![call_expr.into()]); + let needs_match_arm_comma = fun + .body + .parent() + .and_then(ast::MatchArm::cast) + .is_some_and(|arm| arm.comma_token().is_none()); + match &fun.body { + FunctionBody::Expr(expr) => { + let mut replacement = vec![call_expr.clone().into()]; + if needs_match_arm_comma { + replacement.push(make.token(T![,]).into()); + } + editor.replace_with_many(expr.syntax(), replacement); + } + FunctionBody::Span { .. } => editor.replace_all(elements, vec![call_expr.into()]), + } // Insert the newly extracted function (or impl) - ted::insert_all_raw( - ted::Position::after(insert_after), - vec![make::tokens::whitespace(&format!("\n\n{new_indent}")).into(), fn_def.into()], + editor.insert_all( + Position::after(insert_after), + vec![make.whitespace(&format!("\n\n{new_indent}")).into(), fn_def.into()], ); + builder.add_file_edits(ctx.vfs_file_id(), editor); }, ) } fn make_function_name( + make: &SyntaxFactory, semantics_scope: &hir::SemanticsScope<'_>, body: &FunctionBody, ) -> ast::NameRef { @@ -267,7 +264,7 @@ fn make_function_name( counter += 1; name = format!("{default_name}{counter}") } - make::name_ref(&name) + make.name_ref(&name) } /// Try to guess what user wants to extract @@ -510,38 +507,44 @@ impl<'db> Param<'db> { } } - fn to_arg(&self, ctx: &AssistContext<'db>, edition: Edition) -> ast::Expr { - let var = path_expr_from_local(ctx, self.var, edition); + fn to_arg( + &self, + make: &SyntaxFactory, + ctx: &AssistContext<'db>, + edition: Edition, + ) -> ast::Expr { + let var = path_expr_from_local(make, ctx, self.var, edition); match self.kind() { ParamKind::Value | ParamKind::MutValue => var, - ParamKind::SharedRef => make::expr_ref(var, false), - ParamKind::MutRef => make::expr_ref(var, true), + ParamKind::SharedRef => make.expr_ref(var, false), + ParamKind::MutRef => make.expr_ref(var, true), } } fn to_param( &self, + make: &SyntaxFactory, ctx: &AssistContext<'_>, module: hir::Module, edition: Edition, ) -> ast::Param { let var = self.var.name(ctx.db()).display(ctx.db(), edition).to_string(); - let var_name = make::name(&var); + let var_name = make.name(&var); let pat = match self.kind() { - ParamKind::MutValue => make::ident_pat(false, true, var_name), + ParamKind::MutValue => make.ident_pat(false, true, var_name), ParamKind::Value | ParamKind::SharedRef | ParamKind::MutRef => { - make::ext::simple_ident_pat(var_name) + make.simple_ident_pat(var_name) } }; - let ty = make_ty(&self.ty, ctx, module); + let ty = make_ty(make, &self.ty, ctx, module); let ty = match self.kind() { ParamKind::Value | ParamKind::MutValue => ty, - ParamKind::SharedRef => make::ty_ref(ty, false), - ParamKind::MutRef => make::ty_ref(ty, true), + ParamKind::SharedRef => make.ty_ref(ty, false), + ParamKind::MutRef => make.ty_ref(ty, true), }; - make::param(pat.into(), ty) + make.param(pat.into(), ty) } } @@ -569,17 +572,17 @@ impl<'db> TryKind<'db> { } impl<'db> FlowKind<'db> { - fn make_result_handler(&self, expr: Option<ast::Expr>) -> ast::Expr { + fn make_result_handler(&self, make: &SyntaxFactory, expr: Option<ast::Expr>) -> ast::Expr { match self { - FlowKind::Return(_) => make::expr_return(expr), - FlowKind::Break(label, _) => make::expr_break(label.clone(), expr), + FlowKind::Return(_) => make.expr_return(expr).into(), + FlowKind::Break(label, _) => make.expr_break(label.clone(), expr).into(), FlowKind::Try { .. } => { stdx::never!("cannot have result handler with try"); - expr.unwrap_or_else(|| make::expr_return(None)) + expr.unwrap_or_else(|| make.expr_return(None).into()) } FlowKind::Continue(label) => { stdx::always!(expr.is_none(), "continue with value is not possible"); - make::expr_continue(label.clone()) + make.expr_continue(label.clone()).into() } } } @@ -840,6 +843,7 @@ impl FunctionBody { fn analyze_container<'db>( &self, + make: &SyntaxFactory, sema: &Semantics<'db, RootDatabase>, edition: Edition, trait_name: Option<ast::Name>, @@ -930,7 +934,7 @@ impl FunctionBody { }; // FIXME: make trait arguments - let trait_name = trait_name.map(|name| make::ty_path(make::ext::ident_path(&name.text()))); + let trait_name = trait_name.map(|name| make.ty_path(make.ident_path(&name.text())).into()); let parent = self.parent()?; let parents = generic_parents(&parent); @@ -1191,7 +1195,7 @@ fn reference_is_exclusive( } // we take `&mut` reference to variable: `&mut v` - let path = match path_element_of_reference(node, reference) { + let path = match path_element_of_reference(node, reference.range) { Some(path) => path, None => return false, }; @@ -1282,10 +1286,10 @@ impl HasTokenAtOffset for FunctionBody { /// `node` must cover `reference`, that is `node.text_range().contains_range(reference.range)` fn path_element_of_reference( node: &dyn HasTokenAtOffset, - reference: &FileReference, + reference_range: TextRange, ) -> Option<ast::Expr> { - let token = node.token_at_offset(reference.range.start()).right_biased().or_else(|| { - stdx::never!(false, "cannot find token at variable usage: {:?}", reference); + let token = node.token_at_offset(reference_range.start()).right_biased().or_else(|| { + stdx::never!(false, "cannot find token at variable usage: {:?}", reference_range); None })?; let path = token.parent_ancestors().find_map(ast::Expr::cast).or_else(|| { @@ -1413,58 +1417,50 @@ fn impl_type_name(impl_node: &ast::Impl) -> Option<String> { Some(impl_node.self_ty()?.to_string()) } -/// Fixes up the call site before the target expressions are replaced with the call expression -fn fixup_call_site(builder: &mut SourceChangeBuilder, body: &FunctionBody) { - let parent_match_arm = body.parent().and_then(ast::MatchArm::cast); - - if let Some(parent_match_arm) = parent_match_arm - && parent_match_arm.comma_token().is_none() - { - let parent_match_arm = builder.make_mut(parent_match_arm); - ted::append_child_raw(parent_match_arm.syntax(), make::token(T![,])); - } -} - -fn make_call(ctx: &AssistContext<'_>, fun: &Function<'_>, indent: IndentLevel) -> SyntaxNode { +fn make_call( + make: &SyntaxFactory, + ctx: &AssistContext<'_>, + fun: &Function<'_>, + indent: IndentLevel, +) -> SyntaxNode { let ret_ty = fun.return_type(ctx); let name = fun.name.clone(); - let args = fun.params.iter().map(|param| param.to_arg(ctx, fun.mods.edition)); + let args = fun.params.iter().map(|param| param.to_arg(make, ctx, fun.mods.edition)); let mut call_expr = if fun.make_this_param().is_some() { - let self_arg = make::expr_path(make::ext::ident_path("self")); - let func = make::expr_path(make::path_unqualified(make::path_segment(name))); - make::expr_call(func, make::arg_list(Some(self_arg).into_iter().chain(args))).into() + let self_arg = make.expr_path(make.ident_path("self")); + let func = make.expr_path(make.path_unqualified(make.path_segment(name))); + make.expr_call(func, make.arg_list(Some(self_arg).into_iter().chain(args))).into() } else if fun.self_param.is_some() { - let self_arg = make::expr_path(make::ext::ident_path("self")); - make::expr_method_call(self_arg, name, make::arg_list(args)).into() + let self_arg = make.expr_path(make.ident_path("self")); + make.expr_method_call(self_arg, name, make.arg_list(args)).into() } else { - let func = make::expr_path(make::path_unqualified(make::path_segment(name))); - make::expr_call(func, make::arg_list(args)).into() + let func = make.expr_path(make.path_unqualified(make.path_segment(name))); + make.expr_call(func, make.arg_list(args)).into() }; let handler = FlowHandler::from_ret_ty(fun, &ret_ty); if fun.control_flow.is_async { - call_expr = make::expr_await(call_expr); + call_expr = make.expr_await(call_expr).into(); } - let expr = handler.make_call_expr(call_expr).clone_for_update(); - expr.indent(indent); + let expr = handler.make_call_expr(make, call_expr).indent_with_mapping(indent, make); let outliving_bindings = match fun.outliving_locals.as_slice() { [] => None, [var] => { let name = var.local.name(ctx.db()); - let name = make::name(&name.display(ctx.db(), fun.mods.edition).to_string()); - Some(ast::Pat::IdentPat(make::ident_pat(false, var.mut_usage_outside_body, name))) + let name = make.name(&name.display(ctx.db(), fun.mods.edition).to_string()); + Some(ast::Pat::IdentPat(make.ident_pat(false, var.mut_usage_outside_body, name))) } vars => { let binding_pats = vars.iter().map(|var| { let name = var.local.name(ctx.db()); - let name = make::name(&name.display(ctx.db(), fun.mods.edition).to_string()); - make::ident_pat(false, var.mut_usage_outside_body, name).into() + let name = make.name(&name.display(ctx.db(), fun.mods.edition).to_string()); + make.ident_pat(false, var.mut_usage_outside_body, name).into() }); - Some(ast::Pat::TuplePat(make::tuple_pat(binding_pats))) + Some(ast::Pat::TuplePat(make.tuple_pat(binding_pats))) } }; @@ -1472,7 +1468,7 @@ fn make_call(ctx: &AssistContext<'_>, fun: &Function<'_>, indent: IndentLevel) - if let Some(bindings) = outliving_bindings { // with bindings that outlive it - make::let_stmt(bindings, None, Some(expr)).syntax().clone_for_update() + make.let_stmt(bindings, None, Some(expr)).syntax().clone() } else if parent_match_arm.as_ref().is_some() { // as a tail expr for a match arm expr.syntax().clone() @@ -1481,7 +1477,7 @@ fn make_call(ctx: &AssistContext<'_>, fun: &Function<'_>, indent: IndentLevel) - && (!fun.outliving_locals.is_empty() || !expr.is_block_like()) { // as an expr stmt - make::expr_stmt(expr).syntax().clone_for_update() + make.expr_stmt(expr).syntax().clone() } else { // as a tail expr, or a block expr.syntax().clone() @@ -1527,82 +1523,87 @@ impl<'db> FlowHandler<'db> { } } - fn make_call_expr(&self, call_expr: ast::Expr) -> ast::Expr { + fn make_call_expr(&self, make: &SyntaxFactory, call_expr: ast::Expr) -> ast::Expr { match self { FlowHandler::None => call_expr, - FlowHandler::Try { kind: _ } => make::expr_try(call_expr), + FlowHandler::Try { kind: _ } => make.expr_try(call_expr), FlowHandler::If { action } => { - let action = action.make_result_handler(None); - let stmt = make::expr_stmt(action); - let block = make::block_expr(iter::once(stmt.into()), None); - let controlflow_break_path = make::path_from_text("ControlFlow::Break"); - let condition = make::expr_let( - make::tuple_struct_pat( + let action = action.make_result_handler(make, None); + let stmt = make.expr_stmt(action); + let block = make.block_expr(iter::once(stmt.into()), None); + let controlflow_break_path = make.path_from_text("ControlFlow::Break"); + let condition = make.expr_let( + make.tuple_struct_pat( controlflow_break_path, - iter::once(make::wildcard_pat().into()), + iter::once(make.wildcard_pat().into()), ) .into(), call_expr, ); - make::expr_if(condition.into(), block, None).into() + make.expr_if(condition.into(), block, None).into() } FlowHandler::IfOption { action } => { - let path = make::ext::ident_path("Some"); - let value_pat = make::ext::simple_ident_pat(make::name("value")); - let pattern = make::tuple_struct_pat(path, iter::once(value_pat.into())); - let cond = make::expr_let(pattern.into(), call_expr); - let value = make::expr_path(make::ext::ident_path("value")); - let action_expr = action.make_result_handler(Some(value)); - let action_stmt = make::expr_stmt(action_expr); - let then = make::block_expr(iter::once(action_stmt.into()), None); - make::expr_if(cond.into(), then, None).into() + let path = make.ident_path("Some"); + let value_pat = make.simple_ident_pat(make.name("value")); + let pattern = make.tuple_struct_pat(path, iter::once(value_pat.into())); + let cond = make.expr_let(pattern.into(), call_expr); + let value = make.expr_path(make.ident_path("value")); + let action_expr = action.make_result_handler(make, Some(value)); + let action_stmt = make.expr_stmt(action_expr); + let then = make.block_expr(iter::once(action_stmt.into()), None); + make.expr_if(cond.into(), then, None).into() } FlowHandler::MatchOption { none } => { let some_name = "value"; let some_arm = { - let path = make::ext::ident_path("Some"); - let value_pat = make::ext::simple_ident_pat(make::name(some_name)); - let pat = make::tuple_struct_pat(path, iter::once(value_pat.into())); - let value = make::expr_path(make::ext::ident_path(some_name)); - make::match_arm(pat.into(), None, value) + let path = make.ident_path("Some"); + let value_pat = make.simple_ident_pat(make.name(some_name)); + let pat = make.tuple_struct_pat(path, iter::once(value_pat.into())); + let value = make.expr_path(make.ident_path(some_name)); + make.match_arm(pat.into(), None, value) }; let none_arm = { - let path = make::ext::ident_path("None"); - let pat = make::path_pat(path); - make::match_arm(pat, None, none.make_result_handler(None)) + let path = make.ident_path("None"); + let pat = make.path_pat(path); + make.match_arm(pat, None, none.make_result_handler(make, None)) }; - let arms = make::match_arm_list(vec![some_arm, none_arm]); - make::expr_match(call_expr, arms).into() + let arms = make.match_arm_list(vec![some_arm, none_arm]); + make.expr_match(call_expr, arms).into() } FlowHandler::MatchResult { err } => { let ok_name = "value"; let err_name = "value"; let ok_arm = { - let path = make::ext::ident_path("Ok"); - let value_pat = make::ext::simple_ident_pat(make::name(ok_name)); - let pat = make::tuple_struct_pat(path, iter::once(value_pat.into())); - let value = make::expr_path(make::ext::ident_path(ok_name)); - make::match_arm(pat.into(), None, value) + let path = make.ident_path("Ok"); + let value_pat = make.simple_ident_pat(make.name(ok_name)); + let pat = make.tuple_struct_pat(path, iter::once(value_pat.into())); + let value = make.expr_path(make.ident_path(ok_name)); + make.match_arm(pat.into(), None, value) }; let err_arm = { - let path = make::ext::ident_path("Err"); - let value_pat = make::ext::simple_ident_pat(make::name(err_name)); - let pat = make::tuple_struct_pat(path, iter::once(value_pat.into())); - let value = make::expr_path(make::ext::ident_path(err_name)); - make::match_arm(pat.into(), None, err.make_result_handler(Some(value))) + let path = make.ident_path("Err"); + let value_pat = make.simple_ident_pat(make.name(err_name)); + let pat = make.tuple_struct_pat(path, iter::once(value_pat.into())); + let value = make.expr_path(make.ident_path(err_name)); + make.match_arm(pat.into(), None, err.make_result_handler(make, Some(value))) }; - let arms = make::match_arm_list(vec![ok_arm, err_arm]); - make::expr_match(call_expr, arms).into() + let arms = make.match_arm_list(vec![ok_arm, err_arm]); + make.expr_match(call_expr, arms).into() } } } } -fn path_expr_from_local(ctx: &AssistContext<'_>, var: Local, edition: Edition) -> ast::Expr { +fn path_expr_from_local( + make: &SyntaxFactory, + ctx: &AssistContext<'_>, + var: Local, + edition: Edition, +) -> ast::Expr { let name = var.name(ctx.db()).display(ctx.db(), edition).to_string(); - make::expr_path(make::ext::ident_path(&name)) + make.expr_path(make.ident_path(&name)) } fn format_function( @@ -1610,17 +1611,18 @@ fn format_function( module: hir::Module, fun: &Function<'_>, old_indent: IndentLevel, + make: &SyntaxFactory, ) -> ast::Fn { - let fun_name = make::name(&fun.name.text()); - let params = fun.make_param_list(ctx, module, fun.mods.edition); - let ret_ty = fun.make_ret_ty(ctx, module); - let body = make_body(ctx, old_indent, fun); - let (generic_params, where_clause) = make_generic_params_and_where_clause(ctx, fun); + let fun_name = make.name(&fun.name.text()); + let params = fun.make_param_list(make, ctx, module, fun.mods.edition); + let ret_ty = fun.make_ret_ty(make, ctx, module); + let body = make_body(make, ctx, old_indent, fun); + let (generic_params, where_clause) = make_generic_params_and_where_clause(ctx, make, fun); - make::fn_( + make.fn_( fun.mods.attrs.clone(), None, - fun_name, + fun_name.clone(), generic_params, where_clause, params, @@ -1635,18 +1637,20 @@ fn format_function( fn make_generic_params_and_where_clause( ctx: &AssistContext<'_>, + make: &SyntaxFactory, fun: &Function<'_>, ) -> (Option<ast::GenericParamList>, Option<ast::WhereClause>) { let used_type_params = fun.type_params(ctx); - let generic_param_list = make_generic_param_list(ctx, fun, &used_type_params); - let where_clause = make_where_clause(ctx, fun, &used_type_params); + let generic_param_list = make_generic_param_list(ctx, make, fun, &used_type_params); + let where_clause = make_where_clause(ctx, make, fun, &used_type_params); (generic_param_list, where_clause) } fn make_generic_param_list( ctx: &AssistContext<'_>, + make: &SyntaxFactory, fun: &Function<'_>, used_type_params: &[TypeParam], ) -> Option<ast::GenericParamList> { @@ -1662,7 +1666,7 @@ fn make_generic_param_list( .peekable(); if generic_params.peek().is_some() { - Some(make::generic_param_list(generic_params)) + Some(make.generic_param_list(generic_params)) } else { None } @@ -1684,6 +1688,7 @@ fn param_is_required( fn make_where_clause( ctx: &AssistContext<'_>, + make: &SyntaxFactory, fun: &Function<'_>, used_type_params: &[TypeParam], ) -> Option<ast::WhereClause> { @@ -1698,7 +1703,7 @@ fn make_where_clause( }) .peekable(); - if predicates.peek().is_some() { Some(make::where_clause(predicates)) } else { None } + if predicates.peek().is_some() { Some(make.where_clause(predicates)) } else { None } } fn pred_is_required( @@ -1738,35 +1743,41 @@ impl<'db> Function<'db> { fn make_param_list( &self, + make: &SyntaxFactory, ctx: &AssistContext<'_>, module: hir::Module, edition: Edition, ) -> ast::ParamList { - let this_param = self.make_this_param().map(|f| f()); + let this_param = self.make_this_param().map(|f| f(make)); let self_param = self.self_param.clone().filter(|_| this_param.is_none()); - let params = self.params.iter().map(|param| param.to_param(ctx, module, edition)); - make::param_list(self_param, this_param.into_iter().chain(params)) + let params = self.params.iter().map(|param| param.to_param(make, ctx, module, edition)); + make.param_list(self_param, this_param.into_iter().chain(params)) } - fn make_this_param(&self) -> Option<impl FnOnce() -> ast::Param> { + fn make_this_param(&self) -> Option<impl FnOnce(&SyntaxFactory) -> ast::Param> { if let Some(name) = self.mods.trait_name.clone() && let Some(self_param) = &self.self_param { - Some(|| { - let bounds = make::type_bound_list([make::type_bound(name)]); - let pat = make::path_pat(make::ext::ident_path("this")); - let mut ty = make::impl_trait_type(bounds.unwrap()).into(); + Some(move |make: &SyntaxFactory| { + let bounds = make.type_bound_list([make.type_bound(name)]); + let pat = make.path_pat(make.ident_path("this")); + let mut ty = make.impl_trait_type(bounds.unwrap()).into(); if self_param.amp_token().is_some() { - ty = make::ty_ref(ty, self_param.mut_token().is_some()); + ty = make.ty_ref(ty, self_param.mut_token().is_some()); } - make::param(pat, ty) + make.param(pat, ty) }) } else { None } } - fn make_ret_ty(&self, ctx: &AssistContext<'_>, module: hir::Module) -> Option<ast::RetType> { + fn make_ret_ty( + &self, + make: &SyntaxFactory, + ctx: &AssistContext<'_>, + module: hir::Module, + ) -> Option<ast::RetType> { let fun_ty = self.return_type(ctx); let handler = FlowHandler::from_ret_ty(self, &fun_ty); let ret_ty = match &handler { @@ -1774,57 +1785,64 @@ impl<'db> Function<'db> { if matches!(fun_ty, FunType::Unit) { return None; } - fun_ty.make_ty(ctx, module) + fun_ty.make_ty(make, ctx, module) } FlowHandler::Try { kind: TryKind::Option } => { - make::ext::ty_option(fun_ty.make_ty(ctx, module)) + make.ty_option(fun_ty.make_ty(make, ctx, module)).into() } FlowHandler::Try { kind: TryKind::Result { ty: parent_ret_ty } } => { let handler_ty = parent_ret_ty .type_arguments() .nth(1) - .map(|ty| make_ty(&ty, ctx, module)) - .unwrap_or_else(make::ty_placeholder); - make::ext::ty_result(fun_ty.make_ty(ctx, module), handler_ty) + .map(|ty| make_ty(make, &ty, ctx, module)) + .unwrap_or_else(|| make.ty_placeholder()); + make.ty_result(fun_ty.make_ty(make, ctx, module), handler_ty).into() } - FlowHandler::If { .. } => make::ty("ControlFlow<()>"), + FlowHandler::If { .. } => make.ty("ControlFlow<()>"), FlowHandler::IfOption { action } => { let handler_ty = action .expr_ty(ctx) - .map(|ty| make_ty(&ty, ctx, module)) - .unwrap_or_else(make::ty_placeholder); - make::ext::ty_option(handler_ty) + .map(|ty| make_ty(make, &ty, ctx, module)) + .unwrap_or_else(|| make.ty_placeholder()); + make.ty_option(handler_ty).into() + } + FlowHandler::MatchOption { .. } => { + make.ty_option(fun_ty.make_ty(make, ctx, module)).into() } - FlowHandler::MatchOption { .. } => make::ext::ty_option(fun_ty.make_ty(ctx, module)), FlowHandler::MatchResult { err } => { let handler_ty = err .expr_ty(ctx) - .map(|ty| make_ty(&ty, ctx, module)) - .unwrap_or_else(make::ty_placeholder); - make::ext::ty_result(fun_ty.make_ty(ctx, module), handler_ty) + .map(|ty| make_ty(make, &ty, ctx, module)) + .unwrap_or_else(|| make.ty_placeholder()); + make.ty_result(fun_ty.make_ty(make, ctx, module), handler_ty).into() } }; - Some(make::ret_type(ret_ty)) + Some(make.ret_type(ret_ty)) } } impl<'db> FunType<'db> { - fn make_ty(&self, ctx: &AssistContext<'db>, module: hir::Module) -> ast::Type { + fn make_ty( + &self, + make: &SyntaxFactory, + ctx: &AssistContext<'db>, + module: hir::Module, + ) -> ast::Type { match self { - FunType::Unit => make::ty_unit(), - FunType::Single(ty) => make_ty(ty, ctx, module), + FunType::Unit => make.ty_unit(), + FunType::Single(ty) => make_ty(make, ty, ctx, module), FunType::Tuple(types) => match types.as_slice() { [] => { stdx::never!("tuple type with 0 elements"); - make::ty_unit() + make.ty_unit() } [ty] => { stdx::never!("tuple type with 1 element"); - make_ty(ty, ctx, module) + make_ty(make, ty, ctx, module) } types => { - let types = types.iter().map(|ty| make_ty(ty, ctx, module)); - make::ty_tuple(types) + let types = types.iter().map(|ty| make_ty(make, ty, ctx, module)); + make.ty_tuple(types) } }, } @@ -1832,6 +1850,7 @@ impl<'db> FunType<'db> { } fn make_body( + make: &SyntaxFactory, ctx: &AssistContext<'_>, old_indent: IndentLevel, fun: &Function<'_>, @@ -1848,7 +1867,7 @@ fn make_body( match expr { ast::Expr::BlockExpr(block) => { // If the extracted expression is itself a block, there is no need to wrap it inside another block. - block.dedent(old_indent); + let block = block.dedent(old_indent); let elements = block.stmt_list().map_or_else( || Either::Left(iter::empty()), |stmt_list| { @@ -1865,12 +1884,12 @@ fn make_body( Either::Right(elements) }, ); - make::hacky_block_expr(elements, block.tail_expr()) + make.hacky_block_expr(elements, block.tail_expr()) } _ => { - expr.reindent_to(1.into()); + let expr = expr.dedent(old_indent).indent(1.into()); - make::block_expr(Vec::new(), Some(expr)) + make.block_expr(Vec::new(), Some(expr)) } } } @@ -1901,13 +1920,14 @@ fn make_body( None => match fun.outliving_locals.as_slice() { [] => {} [var] => { - tail_expr = Some(path_expr_from_local(ctx, var.local, fun.mods.edition)); + tail_expr = + Some(path_expr_from_local(make, ctx, var.local, fun.mods.edition)); } vars => { - let exprs = vars - .iter() - .map(|var| path_expr_from_local(ctx, var.local, fun.mods.edition)); - let expr = make::expr_tuple(exprs); + let exprs = vars.iter().map(|var| { + path_expr_from_local(make, ctx, var.local, fun.mods.edition) + }); + let expr = make.expr_tuple(exprs); tail_expr = Some(expr.into()); } }, @@ -1919,80 +1939,90 @@ fn make_body( .map(|node_or_token| match &node_or_token { syntax::NodeOrToken::Node(node) => match ast::Stmt::cast(node.clone()) { Some(stmt) => { - stmt.reindent_to(body_indent); - let ast_node = stmt.syntax().clone_subtree(); - syntax::NodeOrToken::Node(ast_node) + let stmt = stmt.dedent(old_indent).indent(body_indent); + syntax::NodeOrToken::Node(stmt.syntax().clone()) } _ => node_or_token, }, _ => node_or_token, }) .collect::<Vec<SyntaxElement>>(); - if let Some(tail_expr) = &mut tail_expr { - tail_expr.reindent_to(body_indent); - } + tail_expr = tail_expr.map(|expr| expr.dedent(old_indent).indent(body_indent)); - make::hacky_block_expr(elements, tail_expr) + make.hacky_block_expr(elements, tail_expr) } }; match &handler { FlowHandler::None => block, FlowHandler::Try { kind } => { - let block = with_default_tail_expr(block, make::ext::expr_unit()); - map_tail_expr(block, |tail_expr| { + let block = with_default_tail_expr(make, block, make.expr_unit()); + map_tail_expr(make, block, |tail_expr| { let constructor = match kind { TryKind::Option => "Some", TryKind::Result { .. } => "Ok", }; - let func = make::expr_path(make::ext::ident_path(constructor)); - let args = make::arg_list(iter::once(tail_expr)); - make::expr_call(func, args).into() + let func = make.expr_path(make.ident_path(constructor)); + let args = make.arg_list(iter::once(tail_expr)); + make.expr_call(func, args).into() }) } FlowHandler::If { .. } => { - let controlflow_continue = make::expr_call( - make::expr_path(make::path_from_text("ControlFlow::Continue")), - make::arg_list([make::ext::expr_unit()]), - ) - .into(); - with_tail_expr(block, controlflow_continue) + let controlflow_continue = make + .expr_call( + make.expr_path(make.path_from_text("ControlFlow::Continue")), + make.arg_list([make.expr_unit()]), + ) + .into(); + with_tail_expr(make, block, controlflow_continue) } FlowHandler::IfOption { .. } => { - let none = make::expr_path(make::ext::ident_path("None")); - with_tail_expr(block, none) + let none = make.expr_path(make.ident_path("None")); + with_tail_expr(make, block, none) } - FlowHandler::MatchOption { .. } => map_tail_expr(block, |tail_expr| { - let some = make::expr_path(make::ext::ident_path("Some")); - let args = make::arg_list(iter::once(tail_expr)); - make::expr_call(some, args).into() + FlowHandler::MatchOption { .. } => map_tail_expr(make, block, |tail_expr| { + let some = make.expr_path(make.ident_path("Some")); + let args = make.arg_list(iter::once(tail_expr)); + make.expr_call(some, args).into() }), - FlowHandler::MatchResult { .. } => map_tail_expr(block, |tail_expr| { - let ok = make::expr_path(make::ext::ident_path("Ok")); - let args = make::arg_list(iter::once(tail_expr)); - make::expr_call(ok, args).into() + FlowHandler::MatchResult { .. } => map_tail_expr(make, block, |tail_expr| { + let ok = make.expr_path(make.ident_path("Ok")); + let args = make.arg_list(iter::once(tail_expr)); + make.expr_call(ok, args).into() }), } } -fn map_tail_expr(block: ast::BlockExpr, f: impl FnOnce(ast::Expr) -> ast::Expr) -> ast::BlockExpr { +fn map_tail_expr( + make: &SyntaxFactory, + block: ast::BlockExpr, + f: impl FnOnce(ast::Expr) -> ast::Expr, +) -> ast::BlockExpr { let tail_expr = match block.tail_expr() { Some(tail_expr) => tail_expr, None => return block, }; - make::block_expr(block.statements(), Some(f(tail_expr))) + make.block_expr(block.statements(), Some(f(tail_expr))) } -fn with_default_tail_expr(block: ast::BlockExpr, tail_expr: ast::Expr) -> ast::BlockExpr { +fn with_default_tail_expr( + make: &SyntaxFactory, + block: ast::BlockExpr, + tail_expr: ast::Expr, +) -> ast::BlockExpr { match block.tail_expr() { Some(_) => block, - None => make::block_expr(block.statements(), Some(tail_expr)), + None => make.block_expr(block.statements(), Some(tail_expr)), } } -fn with_tail_expr(block: ast::BlockExpr, tail_expr: ast::Expr) -> ast::BlockExpr { +fn with_tail_expr( + make: &SyntaxFactory, + block: ast::BlockExpr, + tail_expr: ast::Expr, +) -> ast::BlockExpr { let stmt_tail_opt: Option<ast::Stmt> = - block.tail_expr().map(|expr| make::expr_stmt(expr).into()); + block.tail_expr().map(|expr| make.expr_stmt(expr).into()); let mut elements: Vec<SyntaxElement> = vec![]; @@ -2012,7 +2042,7 @@ fn with_tail_expr(block: ast::BlockExpr, tail_expr: ast::Expr) -> ast::BlockExpr elements.push(syntax::NodeOrToken::Node(stmt_tail.syntax().clone())); } - make::hacky_block_expr(elements, Some(tail_expr)) + make.hacky_block_expr(elements, Some(tail_expr)) } fn format_type(ty: &hir::Type<'_>, ctx: &AssistContext<'_>, module: hir::Module) -> String { @@ -2024,9 +2054,14 @@ fn is_inherit_attr(attr: &ast::Attr) -> bool { matches!(name.as_str(), "track_caller" | "cfg") } -fn make_ty(ty: &hir::Type<'_>, ctx: &AssistContext<'_>, module: hir::Module) -> ast::Type { +fn make_ty( + make: &SyntaxFactory, + ty: &hir::Type<'_>, + ctx: &AssistContext<'_>, + module: hir::Module, +) -> ast::Type { let ty_str = format_type(ty, ctx, module); - make::ty(&ty_str) + make.ty(&ty_str) } fn rewrite_body_segment( @@ -2037,29 +2072,33 @@ fn rewrite_body_segment( syntax: &SyntaxNode, ) -> SyntaxNode { let to_this_param = to_this_param.and_then(|it| ctx.sema.to_def(&it)); - let syntax = fix_param_usages(ctx, to_this_param, params, syntax); - update_external_control_flow(handler, &syntax); - syntax + let (param_editor, param_root) = SyntaxEditor::new(syntax.clone()); + fix_param_usages(¶m_editor, syntax, ¶m_root, ctx, to_this_param, params); + let syntax = param_editor.finish().new_root().clone(); + + let (flow_editor, flow_root) = SyntaxEditor::new(syntax); + update_external_control_flow(&flow_editor, &flow_root, handler); + flow_editor.finish().new_root().clone() } /// change all usages to account for added `&`/`&mut` for some params fn fix_param_usages( + editor: &SyntaxEditor, + source_syntax: &SyntaxNode, + syntax: &SyntaxNode, ctx: &AssistContext<'_>, to_this_param: Option<Local>, params: &[Param<'_>], - syntax: &SyntaxNode, -) -> SyntaxNode { +) { let mut usages_for_param: Vec<(&Param<'_>, Vec<ast::Expr>)> = Vec::new(); let mut usages_for_self_param: Vec<ast::Expr> = Vec::new(); + let source_range = source_syntax.text_range(); + let source_start = source_range.start(); - let tm = TreeMutator::new(syntax); let reference_filter = |reference: &FileReference| { - syntax - .text_range() - .contains_range(reference.range) - .then_some(()) - .and_then(|_| path_element_of_reference(syntax, reference)) - .map(|expr| tm.make_mut(&expr)) + source_range.contains_range(reference.range).then_some(())?; + let local_range = reference.range - source_start; + path_element_of_reference(syntax, local_range) }; if let Some(self_param) = to_this_param { @@ -2079,11 +2118,11 @@ fn fix_param_usages( usages_for_param.push((param, usages.unique().collect())); } - let res = tm.make_syntax_mut(syntax); + let make = editor.make(); for self_usage in usages_for_self_param { - let this_expr = make::expr_path(make::ext::ident_path("this")).clone_for_update(); - ted::replace(self_usage.syntax(), this_expr.syntax()); + let this_expr = make.expr_path(make.ident_path("this")); + editor.replace(self_usage.syntax(), this_expr.syntax()); } for (param, usages) in usages_for_param { for usage in usages { @@ -2098,7 +2137,7 @@ fn fix_param_usages( Some(ast::Expr::RefExpr(node)) if param.kind() == ParamKind::MutRef && node.mut_token().is_some() => { - ted::replace( + editor.replace( node.syntax(), node.expr().expect("RefExpr::expr() cannot be None").syntax(), ); @@ -2106,23 +2145,25 @@ fn fix_param_usages( Some(ast::Expr::RefExpr(node)) if param.kind() == ParamKind::SharedRef && node.mut_token().is_none() => { - ted::replace( + editor.replace( node.syntax(), node.expr().expect("RefExpr::expr() cannot be None").syntax(), ); } Some(_) | None => { - let p = &make::expr_prefix(T![*], usage.clone()).clone_for_update(); - ted::replace(usage.syntax(), p.syntax()) + let p = make.expr_prefix(T![*], usage.clone()); + editor.replace(usage.syntax(), p.syntax()) } } } } - - res } -fn update_external_control_flow(handler: &FlowHandler<'_>, syntax: &SyntaxNode) { +fn update_external_control_flow( + editor: &SyntaxEditor, + syntax: &SyntaxNode, + handler: &FlowHandler<'_>, +) { let mut nested_loop = None; let mut nested_scope = None; for event in syntax.preorder() { @@ -2151,19 +2192,25 @@ fn update_external_control_flow(handler: &FlowHandler<'_>, syntax: &SyntaxNode) match expr { ast::Expr::ReturnExpr(return_expr) => { let expr = return_expr.expr(); - if let Some(replacement) = make_rewritten_flow(handler, expr) { - ted::replace(return_expr.syntax(), replacement.syntax()) + if let Some(replacement) = + make_rewritten_flow(handler, expr, editor.make()) + { + editor.replace(return_expr.syntax(), replacement.syntax()) } } ast::Expr::BreakExpr(break_expr) if nested_loop.is_none() => { let expr = break_expr.expr(); - if let Some(replacement) = make_rewritten_flow(handler, expr) { - ted::replace(break_expr.syntax(), replacement.syntax()) + if let Some(replacement) = + make_rewritten_flow(handler, expr, editor.make()) + { + editor.replace(break_expr.syntax(), replacement.syntax()) } } ast::Expr::ContinueExpr(continue_expr) if nested_loop.is_none() => { - if let Some(replacement) = make_rewritten_flow(handler, None) { - ted::replace(continue_expr.syntax(), replacement.syntax()) + if let Some(replacement) = + make_rewritten_flow(handler, None, editor.make()) + { + editor.replace(continue_expr.syntax(), replacement.syntax()) } } _ => { @@ -2186,27 +2233,29 @@ fn update_external_control_flow(handler: &FlowHandler<'_>, syntax: &SyntaxNode) fn make_rewritten_flow( handler: &FlowHandler<'_>, arg_expr: Option<ast::Expr>, + make: &SyntaxFactory, ) -> Option<ast::Expr> { let value = match handler { FlowHandler::None | FlowHandler::Try { .. } => return None, - FlowHandler::If { .. } => make::expr_call( - make::expr_path(make::path_from_text("ControlFlow::Break")), - make::arg_list([make::ext::expr_unit()]), - ) - .into(), + FlowHandler::If { .. } => make + .expr_call( + make.expr_path(make.path_from_text("ControlFlow::Break")), + make.arg_list([make.expr_unit()]), + ) + .into(), FlowHandler::IfOption { .. } => { - let expr = arg_expr.unwrap_or_else(make::ext::expr_unit); - let args = make::arg_list([expr]); - make::expr_call(make::expr_path(make::ext::ident_path("Some")), args).into() + let expr = arg_expr.unwrap_or_else(|| make.expr_unit()); + let args = make.arg_list([expr]); + make.expr_call(make.expr_path(make.ident_path("Some")), args).into() } - FlowHandler::MatchOption { .. } => make::expr_path(make::ext::ident_path("None")), + FlowHandler::MatchOption { .. } => make.expr_path(make.ident_path("None")), FlowHandler::MatchResult { .. } => { - let expr = arg_expr.unwrap_or_else(make::ext::expr_unit); - let args = make::arg_list([expr]); - make::expr_call(make::expr_path(make::ext::ident_path("Err")), args).into() + let expr = arg_expr.unwrap_or_else(|| make.expr_unit()); + let args = make.arg_list([expr]); + make.expr_call(make.expr_path(make.ident_path("Err")), args).into() } }; - Some(make::expr_return(Some(value)).clone_for_update()) + Some(make.expr_return(Some(value)).into()) } #[cfg(test)] |