Unnamed repository; edit this file 'description' to name the repository.
-rw-r--r--crates/ide-assists/src/handlers/convert_to_guarded_return.rs282
-rw-r--r--crates/ide-assists/src/tests.rs2
2 files changed, 232 insertions, 52 deletions
diff --git a/crates/ide-assists/src/handlers/convert_to_guarded_return.rs b/crates/ide-assists/src/handlers/convert_to_guarded_return.rs
index db45916792..e59527b0e0 100644
--- a/crates/ide-assists/src/handlers/convert_to_guarded_return.rs
+++ b/crates/ide-assists/src/handlers/convert_to_guarded_return.rs
@@ -1,17 +1,19 @@
use std::iter::once;
use either::Either;
-use hir::{Semantics, TypeInfo};
+use hir::Semantics;
use ide_db::{RootDatabase, ty_filter::TryEnum};
use syntax::{
AstNode,
- SyntaxKind::{CLOSURE_EXPR, FN, FOR_EXPR, LOOP_EXPR, WHILE_EXPR, WHITESPACE},
+ SyntaxKind::WHITESPACE,
SyntaxNode, T,
ast::{
self,
edit::{AstNodeEdit, IndentLevel},
syntax_factory::SyntaxFactory,
},
+ match_ast,
+ syntax_editor::SyntaxEditor,
};
use crate::{
@@ -71,9 +73,7 @@ fn if_expr_to_guarded_return(
) -> Option<()> {
let make = SyntaxFactory::without_mappings();
let else_block = match if_expr.else_branch() {
- Some(ast::ElseBranch::Block(block_expr)) if is_never_block(&ctx.sema, &block_expr) => {
- Some(block_expr)
- }
+ Some(ast::ElseBranch::Block(block_expr)) => Some(block_expr),
Some(_) => return None,
_ => None,
};
@@ -96,25 +96,20 @@ fn if_expr_to_guarded_return(
let parent_block = if_expr.syntax().parent()?.ancestors().find_map(ast::BlockExpr::cast)?;
- if parent_block.tail_expr() != Some(if_expr.clone().into())
- && !(else_block.is_some() && ast::ExprStmt::can_cast(if_expr.syntax().parent()?.kind()))
- {
- return None;
- }
-
// check for early return and continue
if is_early_block(&then_block) || is_never_block(&ctx.sema, &then_branch) {
return None;
}
let parent_container = parent_block.syntax().parent()?;
+ let else_block = ElseBlock::new(&ctx.sema, else_block, &parent_container)?;
- let early_expression = else_block
- .or_else(|| {
- early_expression(parent_container, &ctx.sema, &make)
- .map(ast::make::tail_only_block_expr)
- })?
- .reset_indent();
+ if parent_block.tail_expr() != Some(if_expr.clone().into())
+ && !(else_block.is_never_block
+ && ast::ExprStmt::can_cast(if_expr.syntax().parent()?.kind()))
+ {
+ return None;
+ }
then_block.syntax().first_child_or_token().map(|t| t.kind() == T!['{'])?;
@@ -137,6 +132,7 @@ fn if_expr_to_guarded_return(
|edit| {
let make = SyntaxFactory::without_mappings();
let if_indent_level = IndentLevel::from_node(if_expr.syntax());
+ let early_expression = else_block.make_early_block(&ctx.sema, &make);
let replacement = let_chains.into_iter().map(|expr| {
if let ast::Expr::LetExpr(let_expr) = &expr
&& let (Some(pat), Some(expr)) = (let_expr.pat(), let_expr.expr())
@@ -204,14 +200,9 @@ fn let_stmt_to_guarded_return(
let happy_pattern = try_enum.happy_pattern(pat);
let target = let_stmt.syntax().text_range();
- let make = SyntaxFactory::without_mappings();
- let early_expression: ast::Expr = {
- let parent_block =
- let_stmt.syntax().parent()?.ancestors().find_map(ast::BlockExpr::cast)?;
- let parent_container = parent_block.syntax().parent()?;
-
- early_expression(parent_container, &ctx.sema, &make)?
- };
+ let parent_block = let_stmt.syntax().parent()?.ancestors().find_map(ast::BlockExpr::cast)?;
+ let parent_container = parent_block.syntax().parent()?;
+ let else_block = ElseBlock::new(&ctx.sema, None, &parent_container)?;
acc.add(
AssistId::refactor_rewrite("convert_to_guarded_return"),
@@ -226,7 +217,7 @@ fn let_stmt_to_guarded_return(
happy_pattern,
let_stmt.ty(),
expr.reset_indent(),
- ast::make::tail_only_block_expr(early_expression),
+ else_block.make_early_block(&ctx.sema, &make),
);
let let_else_stmt = let_else_stmt.indent(let_indent_level);
let_else_stmt.syntax().clone()
@@ -239,33 +230,119 @@ fn let_stmt_to_guarded_return(
)
}
-fn early_expression(
- parent_container: SyntaxNode,
- sema: &Semantics<'_, RootDatabase>,
- make: &SyntaxFactory,
-) -> Option<ast::Expr> {
- let return_none_expr = || {
- let none_expr = make.expr_path(make.ident_path("None"));
- make.expr_return(Some(none_expr))
- };
- if let Some(fn_) = ast::Fn::cast(parent_container.clone())
- && let Some(fn_def) = sema.to_def(&fn_)
- && let Some(TryEnum::Option) = TryEnum::from_ty(sema, &fn_def.ret_type(sema.db))
- {
- return Some(return_none_expr().into());
+struct ElseBlock<'db> {
+ exist_else_block: Option<ast::BlockExpr>,
+ is_never_block: bool,
+ kind: EarlyKind<'db>,
+}
+
+impl<'db> ElseBlock<'db> {
+ fn new(
+ sema: &Semantics<'db, RootDatabase>,
+ exist_else_block: Option<ast::BlockExpr>,
+ parent_container: &SyntaxNode,
+ ) -> Option<Self> {
+ let is_never_block = exist_else_block.as_ref().is_some_and(|it| is_never_block(sema, it));
+ let kind = EarlyKind::from_node(parent_container, sema)?;
+
+ Some(Self { exist_else_block, is_never_block, kind })
}
- if let Some(body) = ast::ClosureExpr::cast(parent_container.clone()).and_then(|it| it.body())
- && let Some(ret_ty) = sema.type_of_expr(&body).map(TypeInfo::original)
- && let Some(TryEnum::Option) = TryEnum::from_ty(sema, &ret_ty)
- {
- return Some(return_none_expr().into());
+
+ fn make_early_block(
+ self,
+ sema: &Semantics<'_, RootDatabase>,
+ make: &SyntaxFactory,
+ ) -> ast::BlockExpr {
+ let Some(block_expr) = self.exist_else_block else {
+ return make.tail_only_block_expr(self.kind.make_early_expr(sema, make, None));
+ };
+
+ if self.is_never_block {
+ return block_expr.reset_indent();
+ }
+
+ let block_expr = block_expr.reset_indent().clone_subtree();
+ let last_stmt = block_expr.statements().last().map(|it| it.syntax().clone());
+ let tail_expr = block_expr.tail_expr().map(|it| it.syntax().clone());
+ let Some(last_element) = tail_expr.clone().or(last_stmt.clone()) else {
+ return make.tail_only_block_expr(self.kind.make_early_expr(sema, make, None));
+ };
+ let whitespace = last_element.prev_sibling_or_token().filter(|it| it.kind() == WHITESPACE);
+
+ let make = SyntaxFactory::without_mappings();
+ let mut edit = SyntaxEditor::new(block_expr.syntax().clone());
+
+ if let Some(tail_expr) = block_expr.tail_expr()
+ && !self.kind.is_unit()
+ {
+ let early_expr = self.kind.make_early_expr(sema, &make, Some(tail_expr.clone()));
+ edit.replace(tail_expr.syntax(), early_expr.syntax());
+ } else {
+ let last_stmt = match block_expr.tail_expr() {
+ Some(expr) => make.expr_stmt(expr).syntax().clone(),
+ None => last_element.clone_for_update(),
+ };
+ let whitespace =
+ make.whitespace(&whitespace.map_or(String::new(), |it| it.to_string()));
+ let early_expr = self.kind.make_early_expr(sema, &make, None).syntax().clone().into();
+ edit.replace_with_many(
+ last_element,
+ vec![last_stmt.into(), whitespace.into(), early_expr],
+ );
+ }
+
+ ast::BlockExpr::cast(edit.finish().new_root().clone()).unwrap()
}
+}
- Some(match parent_container.kind() {
- WHILE_EXPR | LOOP_EXPR | FOR_EXPR => make.expr_continue(None).into(),
- FN | CLOSURE_EXPR => make.expr_return(None).into(),
- _ => return None,
- })
+enum EarlyKind<'db> {
+ Continue,
+ Return(hir::Type<'db>),
+}
+
+impl<'db> EarlyKind<'db> {
+ fn from_node(
+ parent_container: &SyntaxNode,
+ sema: &Semantics<'db, RootDatabase>,
+ ) -> Option<Self> {
+ match_ast! {
+ match parent_container {
+ ast::Fn(it) => Some(Self::Return(sema.to_def(&it)?.ret_type(sema.db))),
+ ast::ClosureExpr(it) => Some(Self::Return(sema.type_of_expr(&it.body()?)?.original)),
+ ast::WhileExpr(_) => Some(Self::Continue),
+ ast::LoopExpr(_) => Some(Self::Continue),
+ ast::ForExpr(_) => Some(Self::Continue),
+ _ => None
+ }
+ }
+ }
+
+ fn make_early_expr(
+ &self,
+ sema: &Semantics<'_, RootDatabase>,
+ make: &SyntaxFactory,
+ ret: Option<ast::Expr>,
+ ) -> ast::Expr {
+ match self {
+ EarlyKind::Continue => make.expr_continue(None).into(),
+ EarlyKind::Return(ty) => {
+ let expr = match TryEnum::from_ty(sema, ty) {
+ Some(TryEnum::Option) => {
+ ret.or_else(|| Some(make.expr_path(make.ident_path("None"))))
+ }
+ _ => ret,
+ };
+ make.expr_return(expr).into()
+ }
+ }
+ }
+
+ fn is_unit(&self) -> bool {
+ match self {
+ EarlyKind::Continue => true,
+ EarlyKind::Return(ty) => ty.is_unit(),
+ }
+ }
}
fn flat_let_chain(mut expr: ast::Expr, make: &SyntaxFactory) -> Vec<ast::Expr> {
@@ -465,6 +542,74 @@ fn main() {
}
#[test]
+ fn convert_if_let_has_else_block() {
+ check_assist(
+ convert_to_guarded_return,
+ r#"
+fn main() -> i32 {
+ if$0 true {
+ foo();
+ } else {
+ bar()
+ }
+}
+"#,
+ r#"
+fn main() -> i32 {
+ if false {
+ return bar();
+ }
+ foo();
+}
+"#,
+ );
+
+ check_assist(
+ convert_to_guarded_return,
+ r#"
+fn main() {
+ if$0 true {
+ foo();
+ } else {
+ bar()
+ }
+}
+"#,
+ r#"
+fn main() {
+ if false {
+ bar();
+ return
+ }
+ foo();
+}
+"#,
+ );
+
+ check_assist(
+ convert_to_guarded_return,
+ r#"
+fn main() {
+ if$0 true {
+ foo();
+ } else {
+ bar();
+ }
+}
+"#,
+ r#"
+fn main() {
+ if false {
+ bar();
+ return
+ }
+ foo();
+}
+"#,
+ );
+ }
+
+ #[test]
fn convert_if_let_has_never_type_else_block() {
check_assist(
convert_to_guarded_return,
@@ -512,7 +657,7 @@ fn main() {
}
#[test]
- fn convert_if_let_has_else_block_in_statement() {
+ fn convert_if_let_has_never_type_else_block_in_statement() {
check_assist(
convert_to_guarded_return,
r#"
@@ -923,6 +1068,37 @@ fn main() {
}
#[test]
+ fn convert_let_inside_for_with_else() {
+ check_assist(
+ convert_to_guarded_return,
+ r#"
+fn main() {
+ for n in ns {
+ if$0 let Some(n) = n {
+ foo(n);
+ bar();
+ } else {
+ baz()
+ }
+ }
+}
+"#,
+ r#"
+fn main() {
+ for n in ns {
+ let Some(n) = n else {
+ baz();
+ continue
+ };
+ foo(n);
+ bar();
+ }
+}
+"#,
+ );
+ }
+
+ #[test]
fn convert_let_stmt_inside_fn() {
check_assist(
convert_to_guarded_return,
@@ -1186,16 +1362,18 @@ fn main() {
}
#[test]
- fn ignore_else_branch() {
+ fn ignore_else_branch_has_non_never_types_in_statement() {
check_assist_not_applicable(
convert_to_guarded_return,
r#"
fn main() {
+ some_statements();
if$0 true {
foo();
} else {
bar()
}
+ some_statements();
}
"#,
);
diff --git a/crates/ide-assists/src/tests.rs b/crates/ide-assists/src/tests.rs
index a52bd74d14..1c90c95fe1 100644
--- a/crates/ide-assists/src/tests.rs
+++ b/crates/ide-assists/src/tests.rs
@@ -479,6 +479,7 @@ pub fn test_some_range(a: int) -> bool {
expect![[r#"
Extract into...
Replace if let with match
+ Convert to guarded return
"#]]
.assert_eq(&expected);
}
@@ -511,6 +512,7 @@ pub fn test_some_range(a: int) -> bool {
expect![[r#"
Extract into...
Replace if let with match
+ Convert to guarded return
"#]]
.assert_eq(&expected);
}