Unnamed repository; edit this file 'description' to name the repository.
-rw-r--r--crates/ide-assists/src/handlers/convert_to_guarded_return.rs175
1 files changed, 173 insertions, 2 deletions
diff --git a/crates/ide-assists/src/handlers/convert_to_guarded_return.rs b/crates/ide-assists/src/handlers/convert_to_guarded_return.rs
index 6f30ffa622..e1966d476c 100644
--- a/crates/ide-assists/src/handlers/convert_to_guarded_return.rs
+++ b/crates/ide-assists/src/handlers/convert_to_guarded_return.rs
@@ -1,6 +1,9 @@
use std::iter::once;
-use ide_db::syntax_helpers::node_ext::{is_pattern_cond, single_let};
+use ide_db::{
+ syntax_helpers::node_ext::{is_pattern_cond, single_let},
+ ty_filter::TryEnum,
+};
use syntax::{
ast::{
self,
@@ -41,13 +44,35 @@ use crate::{
// }
// ```
pub(crate) fn convert_to_guarded_return(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> {
- let if_expr: ast::IfExpr = ctx.find_node_at_offset()?;
+ if let Some(let_stmt) = ctx.find_node_at_offset() {
+ let_stmt_to_guarded_return(let_stmt, acc, ctx)
+ } else if let Some(if_expr) = ctx.find_node_at_offset() {
+ if_expr_to_guarded_return(if_expr, acc, ctx)
+ } else {
+ None
+ }
+}
+
+fn if_expr_to_guarded_return(
+ if_expr: ast::IfExpr,
+ acc: &mut Assists,
+ ctx: &AssistContext<'_>,
+) -> Option<()> {
if if_expr.else_branch().is_some() {
return None;
}
let cond = if_expr.condition()?;
+ let if_token_range = if_expr.if_token()?.text_range();
+ let if_cond_range = cond.syntax().text_range();
+
+ let cursor_in_range =
+ if_token_range.cover(if_cond_range).contains_range(ctx.selection_trimmed());
+ if !cursor_in_range {
+ return None;
+ }
+
// Check if there is an IfLet that we can handle.
let (if_let_pat, cond_expr) = if is_pattern_cond(cond.clone()) {
let let_ = single_let(cond)?;
@@ -148,6 +173,65 @@ pub(crate) fn convert_to_guarded_return(acc: &mut Assists, ctx: &AssistContext<'
)
}
+fn let_stmt_to_guarded_return(
+ let_stmt: ast::LetStmt,
+ acc: &mut Assists,
+ ctx: &AssistContext<'_>,
+) -> Option<()> {
+ let pat = let_stmt.pat()?;
+ let expr = let_stmt.initializer()?;
+
+ let let_token_range = let_stmt.let_token()?.text_range();
+ let let_pattern_range = pat.syntax().text_range();
+ let cursor_in_range =
+ let_token_range.cover(let_pattern_range).contains_range(ctx.selection_trimmed());
+
+ if !cursor_in_range {
+ return None;
+ }
+
+ let try_enum =
+ ctx.sema.type_of_expr(&expr).and_then(|ty| TryEnum::from_ty(&ctx.sema, &ty.adjusted()))?;
+
+ let happy_pattern = try_enum.happy_pattern(pat);
+ let target = let_stmt.syntax().text_range();
+
+ let early_expression: ast::Expr = {
+ let parent_block =
+ let_stmt.syntax().parent()?.ancestors().find_map(ast::BlockExpr::cast)?;
+ let parent_container = parent_block.syntax().parent()?;
+
+ match parent_container.kind() {
+ WHILE_EXPR | LOOP_EXPR | FOR_EXPR => make::expr_continue(None),
+ FN => make::expr_return(None),
+ _ => return None,
+ }
+ };
+
+ acc.add(
+ AssistId("convert_to_guarded_return", AssistKind::RefactorRewrite),
+ "Convert to guarded return",
+ target,
+ |edit| {
+ let let_stmt = edit.make_mut(let_stmt);
+ let let_indent_level = IndentLevel::from_node(let_stmt.syntax());
+
+ let replacement = {
+ let let_else_stmt = make::let_else_stmt(
+ happy_pattern,
+ let_stmt.ty(),
+ expr,
+ ast::make::tail_only_block_expr(early_expression),
+ );
+ let let_else_stmt = let_else_stmt.indent(let_indent_level);
+ let_else_stmt.syntax().clone_for_update()
+ };
+
+ ted::replace(let_stmt.syntax(), replacement)
+ },
+ )
+}
+
#[cfg(test)]
mod tests {
use crate::tests::{check_assist, check_assist_not_applicable};
@@ -451,6 +535,62 @@ fn main() {
}
#[test]
+ fn convert_let_stmt_inside_fn() {
+ check_assist(
+ convert_to_guarded_return,
+ r#"
+//- minicore: option
+fn foo() -> Option<i32> {
+ None
+}
+
+fn main() {
+ let x$0 = foo();
+}
+"#,
+ r#"
+fn foo() -> Option<i32> {
+ None
+}
+
+fn main() {
+ let Some(x) = foo() else { return };
+}
+"#,
+ );
+ }
+
+ #[test]
+ fn convert_let_stmt_inside_loop() {
+ check_assist(
+ convert_to_guarded_return,
+ r#"
+//- minicore: option
+fn foo() -> Option<i32> {
+ None
+}
+
+fn main() {
+ loop {
+ let x$0 = foo();
+ }
+}
+"#,
+ r#"
+fn foo() -> Option<i32> {
+ None
+}
+
+fn main() {
+ loop {
+ let Some(x) = foo() else { continue };
+ }
+}
+"#,
+ );
+ }
+
+ #[test]
fn convert_arbitrary_if_let_patterns() {
check_assist(
convert_to_guarded_return,
@@ -594,4 +734,35 @@ fn main() {
"#,
);
}
+
+ #[test]
+ fn ignore_inside_if_stmt() {
+ check_assist_not_applicable(
+ convert_to_guarded_return,
+ r#"
+fn main() {
+ if false {
+ foo()$0;
+ }
+}
+"#,
+ );
+ }
+
+ #[test]
+ fn ignore_inside_let_initializer() {
+ check_assist_not_applicable(
+ convert_to_guarded_return,
+ r#"
+//- minicore: option
+fn foo() -> Option<i32> {
+ None
+}
+
+fn main() {
+ let x = foo()$0;
+}
+"#,
+ );
+ }
}