Unnamed repository; edit this file 'description' to name the repository.
Merge pull request #21912 from A4-Tacks/conv-guarded-tail-expr-else
feat: offer on tail-expr with else-branch for if_let_to_guarded assist
| -rw-r--r-- | crates/ide-assists/src/handlers/convert_to_guarded_return.rs | 282 | ||||
| -rw-r--r-- | crates/ide-assists/src/tests.rs | 2 |
2 files changed, 232 insertions, 52 deletions
diff --git a/crates/ide-assists/src/handlers/convert_to_guarded_return.rs b/crates/ide-assists/src/handlers/convert_to_guarded_return.rs index db45916792..e59527b0e0 100644 --- a/crates/ide-assists/src/handlers/convert_to_guarded_return.rs +++ b/crates/ide-assists/src/handlers/convert_to_guarded_return.rs @@ -1,17 +1,19 @@ use std::iter::once; use either::Either; -use hir::{Semantics, TypeInfo}; +use hir::Semantics; use ide_db::{RootDatabase, ty_filter::TryEnum}; use syntax::{ AstNode, - SyntaxKind::{CLOSURE_EXPR, FN, FOR_EXPR, LOOP_EXPR, WHILE_EXPR, WHITESPACE}, + SyntaxKind::WHITESPACE, SyntaxNode, T, ast::{ self, edit::{AstNodeEdit, IndentLevel}, syntax_factory::SyntaxFactory, }, + match_ast, + syntax_editor::SyntaxEditor, }; use crate::{ @@ -71,9 +73,7 @@ fn if_expr_to_guarded_return( ) -> Option<()> { let make = SyntaxFactory::without_mappings(); let else_block = match if_expr.else_branch() { - Some(ast::ElseBranch::Block(block_expr)) if is_never_block(&ctx.sema, &block_expr) => { - Some(block_expr) - } + Some(ast::ElseBranch::Block(block_expr)) => Some(block_expr), Some(_) => return None, _ => None, }; @@ -96,25 +96,20 @@ fn if_expr_to_guarded_return( let parent_block = if_expr.syntax().parent()?.ancestors().find_map(ast::BlockExpr::cast)?; - if parent_block.tail_expr() != Some(if_expr.clone().into()) - && !(else_block.is_some() && ast::ExprStmt::can_cast(if_expr.syntax().parent()?.kind())) - { - return None; - } - // check for early return and continue if is_early_block(&then_block) || is_never_block(&ctx.sema, &then_branch) { return None; } let parent_container = parent_block.syntax().parent()?; + let else_block = ElseBlock::new(&ctx.sema, else_block, &parent_container)?; - let early_expression = else_block - .or_else(|| { - early_expression(parent_container, &ctx.sema, &make) - .map(ast::make::tail_only_block_expr) - })? - .reset_indent(); + if parent_block.tail_expr() != Some(if_expr.clone().into()) + && !(else_block.is_never_block + && ast::ExprStmt::can_cast(if_expr.syntax().parent()?.kind())) + { + return None; + } then_block.syntax().first_child_or_token().map(|t| t.kind() == T!['{'])?; @@ -137,6 +132,7 @@ fn if_expr_to_guarded_return( |edit| { let make = SyntaxFactory::without_mappings(); let if_indent_level = IndentLevel::from_node(if_expr.syntax()); + let early_expression = else_block.make_early_block(&ctx.sema, &make); let replacement = let_chains.into_iter().map(|expr| { if let ast::Expr::LetExpr(let_expr) = &expr && let (Some(pat), Some(expr)) = (let_expr.pat(), let_expr.expr()) @@ -204,14 +200,9 @@ fn let_stmt_to_guarded_return( let happy_pattern = try_enum.happy_pattern(pat); let target = let_stmt.syntax().text_range(); - let make = SyntaxFactory::without_mappings(); - let early_expression: ast::Expr = { - let parent_block = - let_stmt.syntax().parent()?.ancestors().find_map(ast::BlockExpr::cast)?; - let parent_container = parent_block.syntax().parent()?; - - early_expression(parent_container, &ctx.sema, &make)? - }; + let parent_block = let_stmt.syntax().parent()?.ancestors().find_map(ast::BlockExpr::cast)?; + let parent_container = parent_block.syntax().parent()?; + let else_block = ElseBlock::new(&ctx.sema, None, &parent_container)?; acc.add( AssistId::refactor_rewrite("convert_to_guarded_return"), @@ -226,7 +217,7 @@ fn let_stmt_to_guarded_return( happy_pattern, let_stmt.ty(), expr.reset_indent(), - ast::make::tail_only_block_expr(early_expression), + else_block.make_early_block(&ctx.sema, &make), ); let let_else_stmt = let_else_stmt.indent(let_indent_level); let_else_stmt.syntax().clone() @@ -239,33 +230,119 @@ fn let_stmt_to_guarded_return( ) } -fn early_expression( - parent_container: SyntaxNode, - sema: &Semantics<'_, RootDatabase>, - make: &SyntaxFactory, -) -> Option<ast::Expr> { - let return_none_expr = || { - let none_expr = make.expr_path(make.ident_path("None")); - make.expr_return(Some(none_expr)) - }; - if let Some(fn_) = ast::Fn::cast(parent_container.clone()) - && let Some(fn_def) = sema.to_def(&fn_) - && let Some(TryEnum::Option) = TryEnum::from_ty(sema, &fn_def.ret_type(sema.db)) - { - return Some(return_none_expr().into()); +struct ElseBlock<'db> { + exist_else_block: Option<ast::BlockExpr>, + is_never_block: bool, + kind: EarlyKind<'db>, +} + +impl<'db> ElseBlock<'db> { + fn new( + sema: &Semantics<'db, RootDatabase>, + exist_else_block: Option<ast::BlockExpr>, + parent_container: &SyntaxNode, + ) -> Option<Self> { + let is_never_block = exist_else_block.as_ref().is_some_and(|it| is_never_block(sema, it)); + let kind = EarlyKind::from_node(parent_container, sema)?; + + Some(Self { exist_else_block, is_never_block, kind }) } - if let Some(body) = ast::ClosureExpr::cast(parent_container.clone()).and_then(|it| it.body()) - && let Some(ret_ty) = sema.type_of_expr(&body).map(TypeInfo::original) - && let Some(TryEnum::Option) = TryEnum::from_ty(sema, &ret_ty) - { - return Some(return_none_expr().into()); + + fn make_early_block( + self, + sema: &Semantics<'_, RootDatabase>, + make: &SyntaxFactory, + ) -> ast::BlockExpr { + let Some(block_expr) = self.exist_else_block else { + return make.tail_only_block_expr(self.kind.make_early_expr(sema, make, None)); + }; + + if self.is_never_block { + return block_expr.reset_indent(); + } + + let block_expr = block_expr.reset_indent().clone_subtree(); + let last_stmt = block_expr.statements().last().map(|it| it.syntax().clone()); + let tail_expr = block_expr.tail_expr().map(|it| it.syntax().clone()); + let Some(last_element) = tail_expr.clone().or(last_stmt.clone()) else { + return make.tail_only_block_expr(self.kind.make_early_expr(sema, make, None)); + }; + let whitespace = last_element.prev_sibling_or_token().filter(|it| it.kind() == WHITESPACE); + + let make = SyntaxFactory::without_mappings(); + let mut edit = SyntaxEditor::new(block_expr.syntax().clone()); + + if let Some(tail_expr) = block_expr.tail_expr() + && !self.kind.is_unit() + { + let early_expr = self.kind.make_early_expr(sema, &make, Some(tail_expr.clone())); + edit.replace(tail_expr.syntax(), early_expr.syntax()); + } else { + let last_stmt = match block_expr.tail_expr() { + Some(expr) => make.expr_stmt(expr).syntax().clone(), + None => last_element.clone_for_update(), + }; + let whitespace = + make.whitespace(&whitespace.map_or(String::new(), |it| it.to_string())); + let early_expr = self.kind.make_early_expr(sema, &make, None).syntax().clone().into(); + edit.replace_with_many( + last_element, + vec![last_stmt.into(), whitespace.into(), early_expr], + ); + } + + ast::BlockExpr::cast(edit.finish().new_root().clone()).unwrap() } +} - Some(match parent_container.kind() { - WHILE_EXPR | LOOP_EXPR | FOR_EXPR => make.expr_continue(None).into(), - FN | CLOSURE_EXPR => make.expr_return(None).into(), - _ => return None, - }) +enum EarlyKind<'db> { + Continue, + Return(hir::Type<'db>), +} + +impl<'db> EarlyKind<'db> { + fn from_node( + parent_container: &SyntaxNode, + sema: &Semantics<'db, RootDatabase>, + ) -> Option<Self> { + match_ast! { + match parent_container { + ast::Fn(it) => Some(Self::Return(sema.to_def(&it)?.ret_type(sema.db))), + ast::ClosureExpr(it) => Some(Self::Return(sema.type_of_expr(&it.body()?)?.original)), + ast::WhileExpr(_) => Some(Self::Continue), + ast::LoopExpr(_) => Some(Self::Continue), + ast::ForExpr(_) => Some(Self::Continue), + _ => None + } + } + } + + fn make_early_expr( + &self, + sema: &Semantics<'_, RootDatabase>, + make: &SyntaxFactory, + ret: Option<ast::Expr>, + ) -> ast::Expr { + match self { + EarlyKind::Continue => make.expr_continue(None).into(), + EarlyKind::Return(ty) => { + let expr = match TryEnum::from_ty(sema, ty) { + Some(TryEnum::Option) => { + ret.or_else(|| Some(make.expr_path(make.ident_path("None")))) + } + _ => ret, + }; + make.expr_return(expr).into() + } + } + } + + fn is_unit(&self) -> bool { + match self { + EarlyKind::Continue => true, + EarlyKind::Return(ty) => ty.is_unit(), + } + } } fn flat_let_chain(mut expr: ast::Expr, make: &SyntaxFactory) -> Vec<ast::Expr> { @@ -465,6 +542,74 @@ fn main() { } #[test] + fn convert_if_let_has_else_block() { + check_assist( + convert_to_guarded_return, + r#" +fn main() -> i32 { + if$0 true { + foo(); + } else { + bar() + } +} +"#, + r#" +fn main() -> i32 { + if false { + return bar(); + } + foo(); +} +"#, + ); + + check_assist( + convert_to_guarded_return, + r#" +fn main() { + if$0 true { + foo(); + } else { + bar() + } +} +"#, + r#" +fn main() { + if false { + bar(); + return + } + foo(); +} +"#, + ); + + check_assist( + convert_to_guarded_return, + r#" +fn main() { + if$0 true { + foo(); + } else { + bar(); + } +} +"#, + r#" +fn main() { + if false { + bar(); + return + } + foo(); +} +"#, + ); + } + + #[test] fn convert_if_let_has_never_type_else_block() { check_assist( convert_to_guarded_return, @@ -512,7 +657,7 @@ fn main() { } #[test] - fn convert_if_let_has_else_block_in_statement() { + fn convert_if_let_has_never_type_else_block_in_statement() { check_assist( convert_to_guarded_return, r#" @@ -923,6 +1068,37 @@ fn main() { } #[test] + fn convert_let_inside_for_with_else() { + check_assist( + convert_to_guarded_return, + r#" +fn main() { + for n in ns { + if$0 let Some(n) = n { + foo(n); + bar(); + } else { + baz() + } + } +} +"#, + r#" +fn main() { + for n in ns { + let Some(n) = n else { + baz(); + continue + }; + foo(n); + bar(); + } +} +"#, + ); + } + + #[test] fn convert_let_stmt_inside_fn() { check_assist( convert_to_guarded_return, @@ -1186,16 +1362,18 @@ fn main() { } #[test] - fn ignore_else_branch() { + fn ignore_else_branch_has_non_never_types_in_statement() { check_assist_not_applicable( convert_to_guarded_return, r#" fn main() { + some_statements(); if$0 true { foo(); } else { bar() } + some_statements(); } "#, ); diff --git a/crates/ide-assists/src/tests.rs b/crates/ide-assists/src/tests.rs index a52bd74d14..1c90c95fe1 100644 --- a/crates/ide-assists/src/tests.rs +++ b/crates/ide-assists/src/tests.rs @@ -479,6 +479,7 @@ pub fn test_some_range(a: int) -> bool { expect![[r#" Extract into... Replace if let with match + Convert to guarded return "#]] .assert_eq(&expected); } @@ -511,6 +512,7 @@ pub fn test_some_range(a: int) -> bool { expect![[r#" Extract into... Replace if let with match + Convert to guarded return "#]] .assert_eq(&expected); } |