Unnamed repository; edit this file 'description' to name the repository.
Diffstat (limited to 'crates/ide-assists/src/handlers/extract_variable.rs')
| -rw-r--r-- | crates/ide-assists/src/handlers/extract_variable.rs | 66 |
1 files changed, 58 insertions, 8 deletions
diff --git a/crates/ide-assists/src/handlers/extract_variable.rs b/crates/ide-assists/src/handlers/extract_variable.rs index 7c60184142..e5ce02cf53 100644 --- a/crates/ide-assists/src/handlers/extract_variable.rs +++ b/crates/ide-assists/src/handlers/extract_variable.rs @@ -9,7 +9,6 @@ use syntax::{ ast::{ self, AstNode, edit::{AstNodeEdit, IndentLevel}, - make, syntax_factory::SyntaxFactory, }, syntax_editor::Position, @@ -75,7 +74,7 @@ pub(crate) fn extract_variable(acc: &mut Assists, ctx: &AssistContext<'_>) -> Op .next() .and_then(ast::Expr::cast) { - expr.syntax().ancestors().find_map(valid_target_expr)?.syntax().clone() + expr.syntax().ancestors().find_map(valid_target_expr(ctx))?.syntax().clone() } else { return None; } @@ -96,7 +95,7 @@ pub(crate) fn extract_variable(acc: &mut Assists, ctx: &AssistContext<'_>) -> Op let to_extract = node .descendants() .take_while(|it| range.contains_range(it.text_range())) - .find_map(valid_target_expr)?; + .find_map(valid_target_expr(ctx))?; let ty = ctx.sema.type_of_expr(&to_extract).map(TypeInfo::adjusted); if matches!(&ty, Some(ty_info) if ty_info.is_unit()) { @@ -176,7 +175,7 @@ pub(crate) fn extract_variable(acc: &mut Assists, ctx: &AssistContext<'_>) -> Op let mut editor = edit.make_editor(&expr_replace); let pat_name = make.name(&var_name); - let name_expr = make.expr_path(make::ext::ident_path(&var_name)); + let name_expr = make.expr_path(make.ident_path(&var_name)); if let Some(cap) = ctx.config.snippet_cap { let tabstop = edit.make_tabstop_before(cap); @@ -233,7 +232,7 @@ pub(crate) fn extract_variable(acc: &mut Assists, ctx: &AssistContext<'_>) -> Op Position::before(place), vec![ new_stmt.syntax().clone().into(), - make::tokens::whitespace(&trailing_ws).into(), + make.whitespace(&trailing_ws).into(), ], ); @@ -283,14 +282,19 @@ fn peel_parens(mut expr: ast::Expr) -> ast::Expr { /// Check whether the node is a valid expression which can be extracted to a variable. /// In general that's true for any expression, but in some cases that would produce invalid code. -fn valid_target_expr(node: SyntaxNode) -> Option<ast::Expr> { - match node.kind() { - SyntaxKind::PATH_EXPR | SyntaxKind::LOOP_EXPR | SyntaxKind::LET_EXPR => None, +fn valid_target_expr(ctx: &AssistContext<'_>) -> impl Fn(SyntaxNode) -> Option<ast::Expr> { + |node| match node.kind() { + SyntaxKind::LOOP_EXPR | SyntaxKind::LET_EXPR => None, SyntaxKind::BREAK_EXPR => ast::BreakExpr::cast(node).and_then(|e| e.expr()), SyntaxKind::RETURN_EXPR => ast::ReturnExpr::cast(node).and_then(|e| e.expr()), SyntaxKind::BLOCK_EXPR => { ast::BlockExpr::cast(node).filter(|it| it.is_standalone()).map(ast::Expr::from) } + SyntaxKind::PATH_EXPR => { + let path_expr = ast::PathExpr::cast(node)?; + let path_resolution = ctx.sema.resolve_path(&path_expr.path()?)?; + like_const_value(ctx, path_resolution).then_some(path_expr.into()) + } _ => ast::Expr::cast(node), } } @@ -455,6 +459,31 @@ impl Anchor { } } +fn like_const_value(ctx: &AssistContext<'_>, path_resolution: hir::PathResolution) -> bool { + let db = ctx.db(); + let adt_like_const_value = |adt: Option<hir::Adt>| matches!(adt, Some(hir::Adt::Struct(s)) if s.kind(db) == hir::StructKind::Unit); + match path_resolution { + hir::PathResolution::Def(def) => match def { + hir::ModuleDef::Adt(adt) => adt_like_const_value(Some(adt)), + hir::ModuleDef::EnumVariant(variant) => variant.kind(db) == hir::StructKind::Unit, + hir::ModuleDef::TypeAlias(ty) => adt_like_const_value(ty.ty(db).as_adt()), + hir::ModuleDef::Const(_) | hir::ModuleDef::Static(_) => true, + hir::ModuleDef::Trait(_) + | hir::ModuleDef::BuiltinType(_) + | hir::ModuleDef::Macro(_) + | hir::ModuleDef::Module(_) => false, + hir::ModuleDef::Function(_) => false, // no extract named function + }, + hir::PathResolution::SelfType(ty) => adt_like_const_value(ty.self_ty(db).as_adt()), + hir::PathResolution::ConstParam(_) => true, + hir::PathResolution::Local(_) + | hir::PathResolution::TypeParam(_) + | hir::PathResolution::BuiltinAttr(_) + | hir::PathResolution::ToolModule(_) + | hir::PathResolution::DeriveHelper(_) => false, + } +} + #[cfg(test)] mod tests { // NOTE: We use check_assist_by_label, but not check_assist_not_applicable_by_label @@ -1748,6 +1777,27 @@ fn main() { } #[test] + fn extract_non_local_path_expr() { + check_assist_by_label( + extract_variable, + r#" +struct Foo; +fn foo() -> Foo { + $0Foo$0 +} +"#, + r#" +struct Foo; +fn foo() -> Foo { + let $0foo = Foo; + foo +} +"#, + "Extract into variable", + ); + } + + #[test] fn extract_var_for_return_not_applicable() { check_assist_not_applicable(extract_variable, "fn foo() { $0return$0; } "); } |