Unnamed repository; edit this file 'description' to name the repository.
Diffstat (limited to 'crates/ide-assists/src/handlers/convert_bool_then.rs')
-rw-r--r--crates/ide-assists/src/handlers/convert_bool_then.rs575
1 files changed, 575 insertions, 0 deletions
diff --git a/crates/ide-assists/src/handlers/convert_bool_then.rs b/crates/ide-assists/src/handlers/convert_bool_then.rs
new file mode 100644
index 0000000000..f9ec9326b6
--- /dev/null
+++ b/crates/ide-assists/src/handlers/convert_bool_then.rs
@@ -0,0 +1,575 @@
+use hir::{known, AsAssocItem, Semantics};
+use ide_db::{
+ famous_defs::FamousDefs,
+ syntax_helpers::node_ext::{
+ block_as_lone_tail, for_each_tail_expr, is_pattern_cond, preorder_expr,
+ },
+ RootDatabase,
+};
+use itertools::Itertools;
+use syntax::{
+ ast::{self, edit::AstNodeEdit, make, HasArgList},
+ ted, AstNode, SyntaxNode,
+};
+
+use crate::{
+ utils::{invert_boolean_expression, unwrap_trivial_block},
+ AssistContext, AssistId, AssistKind, Assists,
+};
+
+// Assist: convert_if_to_bool_then
+//
+// Converts an if expression into a corresponding `bool::then` call.
+//
+// ```
+// # //- minicore: option
+// fn main() {
+// if$0 cond {
+// Some(val)
+// } else {
+// None
+// }
+// }
+// ```
+// ->
+// ```
+// fn main() {
+// cond.then(|| val)
+// }
+// ```
+pub(crate) fn convert_if_to_bool_then(acc: &mut Assists, ctx: &AssistContext) -> Option<()> {
+ // FIXME applies to match as well
+ let expr = ctx.find_node_at_offset::<ast::IfExpr>()?;
+ if !expr.if_token()?.text_range().contains_inclusive(ctx.offset()) {
+ return None;
+ }
+
+ let cond = expr.condition().filter(|cond| !is_pattern_cond(cond.clone()))?;
+ let then = expr.then_branch()?;
+ let else_ = match expr.else_branch()? {
+ ast::ElseBranch::Block(b) => b,
+ ast::ElseBranch::IfExpr(_) => {
+ cov_mark::hit!(convert_if_to_bool_then_chain);
+ return None;
+ }
+ };
+
+ let (none_variant, some_variant) = option_variants(&ctx.sema, expr.syntax())?;
+
+ let (invert_cond, closure_body) = match (
+ block_is_none_variant(&ctx.sema, &then, none_variant),
+ block_is_none_variant(&ctx.sema, &else_, none_variant),
+ ) {
+ (invert @ true, false) => (invert, ast::Expr::BlockExpr(else_)),
+ (invert @ false, true) => (invert, ast::Expr::BlockExpr(then)),
+ _ => return None,
+ };
+
+ if is_invalid_body(&ctx.sema, some_variant, &closure_body) {
+ cov_mark::hit!(convert_if_to_bool_then_pattern_invalid_body);
+ return None;
+ }
+
+ let target = expr.syntax().text_range();
+ acc.add(
+ AssistId("convert_if_to_bool_then", AssistKind::RefactorRewrite),
+ "Convert `if` expression to `bool::then` call",
+ target,
+ |builder| {
+ let closure_body = closure_body.clone_for_update();
+ // Rewrite all `Some(e)` in tail position to `e`
+ let mut replacements = Vec::new();
+ for_each_tail_expr(&closure_body, &mut |e| {
+ let e = match e {
+ ast::Expr::BreakExpr(e) => e.expr(),
+ e @ ast::Expr::CallExpr(_) => Some(e.clone()),
+ _ => None,
+ };
+ if let Some(ast::Expr::CallExpr(call)) = e {
+ if let Some(arg_list) = call.arg_list() {
+ if let Some(arg) = arg_list.args().next() {
+ replacements.push((call.syntax().clone(), arg.syntax().clone()));
+ }
+ }
+ }
+ });
+ replacements.into_iter().for_each(|(old, new)| ted::replace(old, new));
+ let closure_body = match closure_body {
+ ast::Expr::BlockExpr(block) => unwrap_trivial_block(block),
+ e => e,
+ };
+
+ let parenthesize = matches!(
+ cond,
+ ast::Expr::BinExpr(_)
+ | ast::Expr::BlockExpr(_)
+ | ast::Expr::BoxExpr(_)
+ | ast::Expr::BreakExpr(_)
+ | ast::Expr::CastExpr(_)
+ | ast::Expr::ClosureExpr(_)
+ | ast::Expr::ContinueExpr(_)
+ | ast::Expr::ForExpr(_)
+ | ast::Expr::IfExpr(_)
+ | ast::Expr::LoopExpr(_)
+ | ast::Expr::MacroExpr(_)
+ | ast::Expr::MatchExpr(_)
+ | ast::Expr::PrefixExpr(_)
+ | ast::Expr::RangeExpr(_)
+ | ast::Expr::RefExpr(_)
+ | ast::Expr::ReturnExpr(_)
+ | ast::Expr::WhileExpr(_)
+ | ast::Expr::YieldExpr(_)
+ );
+ let cond = if invert_cond { invert_boolean_expression(cond) } else { cond };
+ let cond = if parenthesize { make::expr_paren(cond) } else { cond };
+ let arg_list = make::arg_list(Some(make::expr_closure(None, closure_body)));
+ let mcall = make::expr_method_call(cond, make::name_ref("then"), arg_list);
+ builder.replace(target, mcall.to_string());
+ },
+ )
+}
+
+// Assist: convert_bool_then_to_if
+//
+// Converts a `bool::then` method call to an equivalent if expression.
+//
+// ```
+// # //- minicore: bool_impl
+// fn main() {
+// (0 == 0).then$0(|| val)
+// }
+// ```
+// ->
+// ```
+// fn main() {
+// if 0 == 0 {
+// Some(val)
+// } else {
+// None
+// }
+// }
+// ```
+pub(crate) fn convert_bool_then_to_if(acc: &mut Assists, ctx: &AssistContext) -> Option<()> {
+ let name_ref = ctx.find_node_at_offset::<ast::NameRef>()?;
+ let mcall = name_ref.syntax().parent().and_then(ast::MethodCallExpr::cast)?;
+ let receiver = mcall.receiver()?;
+ let closure_body = mcall.arg_list()?.args().exactly_one().ok()?;
+ let closure_body = match closure_body {
+ ast::Expr::ClosureExpr(expr) => expr.body()?,
+ _ => return None,
+ };
+ // Verify this is `bool::then` that is being called.
+ let func = ctx.sema.resolve_method_call(&mcall)?;
+ if func.name(ctx.sema.db).to_string() != "then" {
+ return None;
+ }
+ let assoc = func.as_assoc_item(ctx.sema.db)?;
+ match assoc.container(ctx.sema.db) {
+ hir::AssocItemContainer::Impl(impl_) if impl_.self_ty(ctx.sema.db).is_bool() => {}
+ _ => return None,
+ }
+
+ let target = mcall.syntax().text_range();
+ acc.add(
+ AssistId("convert_bool_then_to_if", AssistKind::RefactorRewrite),
+ "Convert `bool::then` call to `if`",
+ target,
+ |builder| {
+ let closure_body = match closure_body {
+ ast::Expr::BlockExpr(block) => block,
+ e => make::block_expr(None, Some(e)),
+ };
+
+ let closure_body = closure_body.clone_for_update();
+ // Wrap all tails in `Some(...)`
+ let none_path = make::expr_path(make::ext::ident_path("None"));
+ let some_path = make::expr_path(make::ext::ident_path("Some"));
+ let mut replacements = Vec::new();
+ for_each_tail_expr(&ast::Expr::BlockExpr(closure_body.clone()), &mut |e| {
+ let e = match e {
+ ast::Expr::BreakExpr(e) => e.expr(),
+ ast::Expr::ReturnExpr(e) => e.expr(),
+ _ => Some(e.clone()),
+ };
+ if let Some(expr) = e {
+ replacements.push((
+ expr.syntax().clone(),
+ make::expr_call(some_path.clone(), make::arg_list(Some(expr)))
+ .syntax()
+ .clone_for_update(),
+ ));
+ }
+ });
+ replacements.into_iter().for_each(|(old, new)| ted::replace(old, new));
+
+ let cond = match &receiver {
+ ast::Expr::ParenExpr(expr) => expr.expr().unwrap_or(receiver),
+ _ => receiver,
+ };
+ let if_expr = make::expr_if(
+ cond,
+ closure_body.reset_indent(),
+ Some(ast::ElseBranch::Block(make::block_expr(None, Some(none_path)))),
+ )
+ .indent(mcall.indent_level());
+
+ builder.replace(target, if_expr.to_string());
+ },
+ )
+}
+
+fn option_variants(
+ sema: &Semantics<RootDatabase>,
+ expr: &SyntaxNode,
+) -> Option<(hir::Variant, hir::Variant)> {
+ let fam = FamousDefs(sema, sema.scope(expr)?.krate());
+ let option_variants = fam.core_option_Option()?.variants(sema.db);
+ match &*option_variants {
+ &[variant0, variant1] => Some(if variant0.name(sema.db) == known::None {
+ (variant0, variant1)
+ } else {
+ (variant1, variant0)
+ }),
+ _ => None,
+ }
+}
+
+/// Traverses the expression checking if it contains `return` or `?` expressions or if any tail is not a `Some(expr)` expression.
+/// If any of these conditions are met it is impossible to rewrite this as a `bool::then` call.
+fn is_invalid_body(
+ sema: &Semantics<RootDatabase>,
+ some_variant: hir::Variant,
+ expr: &ast::Expr,
+) -> bool {
+ let mut invalid = false;
+ preorder_expr(expr, &mut |e| {
+ invalid |=
+ matches!(e, syntax::WalkEvent::Enter(ast::Expr::TryExpr(_) | ast::Expr::ReturnExpr(_)));
+ invalid
+ });
+ if !invalid {
+ for_each_tail_expr(expr, &mut |e| {
+ if invalid {
+ return;
+ }
+ let e = match e {
+ ast::Expr::BreakExpr(e) => e.expr(),
+ e @ ast::Expr::CallExpr(_) => Some(e.clone()),
+ _ => None,
+ };
+ if let Some(ast::Expr::CallExpr(call)) = e {
+ if let Some(ast::Expr::PathExpr(p)) = call.expr() {
+ let res = p.path().and_then(|p| sema.resolve_path(&p));
+ if let Some(hir::PathResolution::Def(hir::ModuleDef::Variant(v))) = res {
+ return invalid |= v != some_variant;
+ }
+ }
+ }
+ invalid = true
+ });
+ }
+ invalid
+}
+
+fn block_is_none_variant(
+ sema: &Semantics<RootDatabase>,
+ block: &ast::BlockExpr,
+ none_variant: hir::Variant,
+) -> bool {
+ block_as_lone_tail(block).and_then(|e| match e {
+ ast::Expr::PathExpr(pat) => match sema.resolve_path(&pat.path()?)? {
+ hir::PathResolution::Def(hir::ModuleDef::Variant(v)) => Some(v),
+ _ => None,
+ },
+ _ => None,
+ }) == Some(none_variant)
+}
+
+#[cfg(test)]
+mod tests {
+ use crate::tests::{check_assist, check_assist_not_applicable};
+
+ use super::*;
+
+ #[test]
+ fn convert_if_to_bool_then_simple() {
+ check_assist(
+ convert_if_to_bool_then,
+ r"
+//- minicore:option
+fn main() {
+ if$0 true {
+ Some(15)
+ } else {
+ None
+ }
+}
+",
+ r"
+fn main() {
+ true.then(|| 15)
+}
+",
+ );
+ }
+
+ #[test]
+ fn convert_if_to_bool_then_invert() {
+ check_assist(
+ convert_if_to_bool_then,
+ r"
+//- minicore:option
+fn main() {
+ if$0 true {
+ None
+ } else {
+ Some(15)
+ }
+}
+",
+ r"
+fn main() {
+ false.then(|| 15)
+}
+",
+ );
+ }
+
+ #[test]
+ fn convert_if_to_bool_then_none_none() {
+ check_assist_not_applicable(
+ convert_if_to_bool_then,
+ r"
+//- minicore:option
+fn main() {
+ if$0 true {
+ None
+ } else {
+ None
+ }
+}
+",
+ );
+ }
+
+ #[test]
+ fn convert_if_to_bool_then_some_some() {
+ check_assist_not_applicable(
+ convert_if_to_bool_then,
+ r"
+//- minicore:option
+fn main() {
+ if$0 true {
+ Some(15)
+ } else {
+ Some(15)
+ }
+}
+",
+ );
+ }
+
+ #[test]
+ fn convert_if_to_bool_then_mixed() {
+ check_assist_not_applicable(
+ convert_if_to_bool_then,
+ r"
+//- minicore:option
+fn main() {
+ if$0 true {
+ if true {
+ Some(15)
+ } else {
+ None
+ }
+ } else {
+ None
+ }
+}
+",
+ );
+ }
+
+ #[test]
+ fn convert_if_to_bool_then_chain() {
+ cov_mark::check!(convert_if_to_bool_then_chain);
+ check_assist_not_applicable(
+ convert_if_to_bool_then,
+ r"
+//- minicore:option
+fn main() {
+ if$0 true {
+ Some(15)
+ } else if true {
+ None
+ } else {
+ None
+ }
+}
+",
+ );
+ }
+
+ #[test]
+ fn convert_if_to_bool_then_pattern_cond() {
+ check_assist_not_applicable(
+ convert_if_to_bool_then,
+ r"
+//- minicore:option
+fn main() {
+ if$0 let true = true {
+ Some(15)
+ } else {
+ None
+ }
+}
+",
+ );
+ }
+
+ #[test]
+ fn convert_if_to_bool_then_pattern_invalid_body() {
+ cov_mark::check_count!(convert_if_to_bool_then_pattern_invalid_body, 2);
+ check_assist_not_applicable(
+ convert_if_to_bool_then,
+ r"
+//- minicore:option
+fn make_me_an_option() -> Option<i32> { None }
+fn main() {
+ if$0 true {
+ if true {
+ make_me_an_option()
+ } else {
+ Some(15)
+ }
+ } else {
+ None
+ }
+}
+",
+ );
+ check_assist_not_applicable(
+ convert_if_to_bool_then,
+ r"
+//- minicore:option
+fn main() {
+ if$0 true {
+ if true {
+ return;
+ }
+ Some(15)
+ } else {
+ None
+ }
+}
+",
+ );
+ }
+
+ #[test]
+ fn convert_bool_then_to_if_inapplicable() {
+ check_assist_not_applicable(
+ convert_bool_then_to_if,
+ r"
+//- minicore:bool_impl
+fn main() {
+ 0.t$0hen(|| 15);
+}
+",
+ );
+ check_assist_not_applicable(
+ convert_bool_then_to_if,
+ r"
+//- minicore:bool_impl
+fn main() {
+ true.t$0hen(15);
+}
+",
+ );
+ check_assist_not_applicable(
+ convert_bool_then_to_if,
+ r"
+//- minicore:bool_impl
+fn main() {
+ true.t$0hen(|| 15, 15);
+}
+",
+ );
+ }
+
+ #[test]
+ fn convert_bool_then_to_if_simple() {
+ check_assist(
+ convert_bool_then_to_if,
+ r"
+//- minicore:bool_impl
+fn main() {
+ true.t$0hen(|| 15)
+}
+",
+ r"
+fn main() {
+ if true {
+ Some(15)
+ } else {
+ None
+ }
+}
+",
+ );
+ check_assist(
+ convert_bool_then_to_if,
+ r"
+//- minicore:bool_impl
+fn main() {
+ true.t$0hen(|| {
+ 15
+ })
+}
+",
+ r"
+fn main() {
+ if true {
+ Some(15)
+ } else {
+ None
+ }
+}
+",
+ );
+ }
+
+ #[test]
+ fn convert_bool_then_to_if_tails() {
+ check_assist(
+ convert_bool_then_to_if,
+ r"
+//- minicore:bool_impl
+fn main() {
+ true.t$0hen(|| {
+ loop {
+ if false {
+ break 0;
+ }
+ break 15;
+ }
+ })
+}
+",
+ r"
+fn main() {
+ if true {
+ loop {
+ if false {
+ break Some(0);
+ }
+ break Some(15);
+ }
+ } else {
+ None
+ }
+}
+",
+ );
+ }
+}