Unnamed repository; edit this file 'description' to name the repository.
-rw-r--r--crates/ide-assists/src/handlers/convert_to_guarded_return.rs297
1 files changed, 252 insertions, 45 deletions
diff --git a/crates/ide-assists/src/handlers/convert_to_guarded_return.rs b/crates/ide-assists/src/handlers/convert_to_guarded_return.rs
index 2ea032fb62..82213ae321 100644
--- a/crates/ide-assists/src/handlers/convert_to_guarded_return.rs
+++ b/crates/ide-assists/src/handlers/convert_to_guarded_return.rs
@@ -1,13 +1,12 @@
use std::iter::once;
-use ide_db::{
- syntax_helpers::node_ext::{is_pattern_cond, single_let},
- ty_filter::TryEnum,
-};
+use either::Either;
+use hir::{Semantics, TypeInfo};
+use ide_db::{RootDatabase, ty_filter::TryEnum};
use syntax::{
AstNode,
- SyntaxKind::{FN, FOR_EXPR, LOOP_EXPR, WHILE_EXPR, WHITESPACE},
- T,
+ SyntaxKind::{CLOSURE_EXPR, FN, FOR_EXPR, LOOP_EXPR, WHILE_EXPR, WHITESPACE},
+ SyntaxNode, T,
ast::{
self,
edit::{AstNodeEdit, IndentLevel},
@@ -44,12 +43,9 @@ use crate::{
// }
// ```
pub(crate) fn convert_to_guarded_return(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> {
- if let Some(let_stmt) = ctx.find_node_at_offset() {
- let_stmt_to_guarded_return(let_stmt, acc, ctx)
- } else if let Some(if_expr) = ctx.find_node_at_offset() {
- if_expr_to_guarded_return(if_expr, acc, ctx)
- } else {
- None
+ match ctx.find_node_at_offset::<Either<ast::LetStmt, ast::IfExpr>>()? {
+ Either::Left(let_stmt) => let_stmt_to_guarded_return(let_stmt, acc, ctx),
+ Either::Right(if_expr) => if_expr_to_guarded_return(if_expr, acc, ctx),
}
}
@@ -73,13 +69,7 @@ fn if_expr_to_guarded_return(
return None;
}
- // Check if there is an IfLet that we can handle.
- let (if_let_pat, cond_expr) = if is_pattern_cond(cond.clone()) {
- let let_ = single_let(cond)?;
- (Some(let_.pat()?), let_.expr()?)
- } else {
- (None, cond)
- };
+ let let_chains = flat_let_chain(cond);
let then_block = if_expr.then_branch()?;
let then_block = then_block.stmt_list()?;
@@ -106,11 +96,7 @@ fn if_expr_to_guarded_return(
let parent_container = parent_block.syntax().parent()?;
- let early_expression: ast::Expr = match parent_container.kind() {
- WHILE_EXPR | LOOP_EXPR | FOR_EXPR => make::expr_continue(None),
- FN => make::expr_return(None),
- _ => return None,
- };
+ let early_expression: ast::Expr = early_expression(parent_container, &ctx.sema)?;
then_block.syntax().first_child_or_token().map(|t| t.kind() == T!['{'])?;
@@ -132,32 +118,42 @@ fn if_expr_to_guarded_return(
target,
|edit| {
let if_indent_level = IndentLevel::from_node(if_expr.syntax());
- let replacement = match if_let_pat {
- None => {
- // If.
- let new_expr = {
- let then_branch =
- make::block_expr(once(make::expr_stmt(early_expression).into()), None);
- let cond = invert_boolean_expression_legacy(cond_expr);
- make::expr_if(cond, then_branch, None).indent(if_indent_level)
- };
- new_expr.syntax().clone()
- }
- Some(pat) => {
+ let replacement = let_chains.into_iter().map(|expr| {
+ if let ast::Expr::LetExpr(let_expr) = &expr
+ && let (Some(pat), Some(expr)) = (let_expr.pat(), let_expr.expr())
+ {
// If-let.
let let_else_stmt = make::let_else_stmt(
pat,
None,
- cond_expr,
- ast::make::tail_only_block_expr(early_expression),
+ expr,
+ ast::make::tail_only_block_expr(early_expression.clone()),
);
let let_else_stmt = let_else_stmt.indent(if_indent_level);
let_else_stmt.syntax().clone()
+ } else {
+ // If.
+ let new_expr = {
+ let then_branch = make::block_expr(
+ once(make::expr_stmt(early_expression.clone()).into()),
+ None,
+ );
+ let cond = invert_boolean_expression_legacy(expr);
+ make::expr_if(cond, then_branch, None).indent(if_indent_level)
+ };
+ new_expr.syntax().clone()
}
- };
+ });
+ let newline = &format!("\n{if_indent_level}");
let then_statements = replacement
- .children_with_tokens()
+ .enumerate()
+ .flat_map(|(i, node)| {
+ (i != 0)
+ .then(|| make::tokens::whitespace(newline).into())
+ .into_iter()
+ .chain(node.children_with_tokens())
+ })
.chain(
then_block_items
.syntax()
@@ -201,11 +197,7 @@ fn let_stmt_to_guarded_return(
let_stmt.syntax().parent()?.ancestors().find_map(ast::BlockExpr::cast)?;
let parent_container = parent_block.syntax().parent()?;
- match parent_container.kind() {
- WHILE_EXPR | LOOP_EXPR | FOR_EXPR => make::expr_continue(None),
- FN => make::expr_return(None),
- _ => return None,
- }
+ early_expression(parent_container, &ctx.sema)?
};
acc.add(
@@ -232,6 +224,54 @@ fn let_stmt_to_guarded_return(
)
}
+fn early_expression(
+ parent_container: SyntaxNode,
+ sema: &Semantics<'_, RootDatabase>,
+) -> Option<ast::Expr> {
+ let return_none_expr = || {
+ let none_expr = make::expr_path(make::ext::ident_path("None"));
+ make::expr_return(Some(none_expr))
+ };
+ if let Some(fn_) = ast::Fn::cast(parent_container.clone())
+ && let Some(fn_def) = sema.to_def(&fn_)
+ && let Some(TryEnum::Option) = TryEnum::from_ty(sema, &fn_def.ret_type(sema.db))
+ {
+ return Some(return_none_expr());
+ }
+ if let Some(body) = ast::ClosureExpr::cast(parent_container.clone()).and_then(|it| it.body())
+ && let Some(ret_ty) = sema.type_of_expr(&body).map(TypeInfo::original)
+ && let Some(TryEnum::Option) = TryEnum::from_ty(sema, &ret_ty)
+ {
+ return Some(return_none_expr());
+ }
+
+ Some(match parent_container.kind() {
+ WHILE_EXPR | LOOP_EXPR | FOR_EXPR => make::expr_continue(None),
+ FN | CLOSURE_EXPR => make::expr_return(None),
+ _ => return None,
+ })
+}
+
+fn flat_let_chain(mut expr: ast::Expr) -> Vec<ast::Expr> {
+ let mut chains = vec![];
+
+ while let ast::Expr::BinExpr(bin_expr) = &expr
+ && bin_expr.op_kind() == Some(ast::BinaryOp::LogicOp(ast::LogicOp::And))
+ && let (Some(lhs), Some(rhs)) = (bin_expr.lhs(), bin_expr.rhs())
+ {
+ if let Some(last) = chains.pop_if(|last| !matches!(last, ast::Expr::LetExpr(_))) {
+ chains.push(make::expr_bin_op(rhs, ast::BinaryOp::LogicOp(ast::LogicOp::And), last));
+ } else {
+ chains.push(rhs);
+ }
+ expr = lhs;
+ }
+
+ chains.push(expr);
+ chains.reverse();
+ chains
+}
+
#[cfg(test)]
mod tests {
use crate::tests::{check_assist, check_assist_not_applicable};
@@ -269,6 +309,71 @@ fn main() {
}
#[test]
+ fn convert_inside_fn_return_option() {
+ check_assist(
+ convert_to_guarded_return,
+ r#"
+//- minicore: option
+fn ret_option() -> Option<()> {
+ bar();
+ if$0 true {
+ foo();
+
+ // comment
+ bar();
+ }
+}
+"#,
+ r#"
+fn ret_option() -> Option<()> {
+ bar();
+ if false {
+ return None;
+ }
+ foo();
+
+ // comment
+ bar();
+}
+"#,
+ );
+ }
+
+ #[test]
+ fn convert_inside_closure() {
+ check_assist(
+ convert_to_guarded_return,
+ r#"
+fn main() {
+ let _f = || {
+ bar();
+ if$0 true {
+ foo();
+
+ // comment
+ bar();
+ }
+ }
+}
+"#,
+ r#"
+fn main() {
+ let _f = || {
+ bar();
+ if false {
+ return;
+ }
+ foo();
+
+ // comment
+ bar();
+ }
+}
+"#,
+ );
+ }
+
+ #[test]
fn convert_let_inside_fn() {
check_assist(
convert_to_guarded_return,
@@ -317,6 +422,82 @@ fn main() {
}
#[test]
+ fn convert_if_let_result_inside_let() {
+ check_assist(
+ convert_to_guarded_return,
+ r#"
+fn main() {
+ let _x = loop {
+ if$0 let Ok(x) = Err(92) {
+ foo(x);
+ }
+ };
+}
+"#,
+ r#"
+fn main() {
+ let _x = loop {
+ let Ok(x) = Err(92) else { continue };
+ foo(x);
+ };
+}
+"#,
+ );
+ }
+
+ #[test]
+ fn convert_if_let_chain_result() {
+ check_assist(
+ convert_to_guarded_return,
+ r#"
+fn main() {
+ if$0 let Ok(x) = Err(92)
+ && x < 30
+ && let Some(y) = Some(8)
+ {
+ foo(x, y);
+ }
+}
+"#,
+ r#"
+fn main() {
+ let Ok(x) = Err(92) else { return };
+ if x >= 30 {
+ return;
+ }
+ let Some(y) = Some(8) else { return };
+ foo(x, y);
+}
+"#,
+ );
+
+ check_assist(
+ convert_to_guarded_return,
+ r#"
+fn main() {
+ if$0 let Ok(x) = Err(92)
+ && x < 30
+ && y < 20
+ && let Some(y) = Some(8)
+ {
+ foo(x, y);
+ }
+}
+"#,
+ r#"
+fn main() {
+ let Ok(x) = Err(92) else { return };
+ if !(x < 30 && y < 20) {
+ return;
+ }
+ let Some(y) = Some(8) else { return };
+ foo(x, y);
+}
+"#,
+ );
+ }
+
+ #[test]
fn convert_let_ok_inside_fn() {
check_assist(
convert_to_guarded_return,
@@ -561,6 +742,32 @@ fn main() {
}
#[test]
+ fn convert_let_stmt_inside_fn_return_option() {
+ check_assist(
+ convert_to_guarded_return,
+ r#"
+//- minicore: option
+fn foo() -> Option<i32> {
+ None
+}
+
+fn ret_option() -> Option<i32> {
+ let x$0 = foo();
+}
+"#,
+ r#"
+fn foo() -> Option<i32> {
+ None
+}
+
+fn ret_option() -> Option<i32> {
+ let Some(x) = foo() else { return None };
+}
+"#,
+ );
+ }
+
+ #[test]
fn convert_let_stmt_inside_loop() {
check_assist(
convert_to_guarded_return,