use std::collections::VecDeque;
use ide_db::{
assists::GroupLabel,
famous_defs::FamousDefs,
syntax_helpers::node_ext::{for_each_tail_expr, is_pattern_cond, walk_expr},
};
use syntax::{
NodeOrToken, SyntaxKind, T,
ast::{
self, AstNode,
Expr::BinExpr,
HasArgList,
prec::{ExprPrecedence, precedence},
syntax_factory::SyntaxFactory,
},
syntax_editor::{Position, SyntaxEditor},
};
use crate::{AssistContext, AssistId, Assists, utils::invert_boolean_expression};
// Assist: apply_demorgan
//
// Apply [De Morgan's law](https://en.wikipedia.org/wiki/De_Morgan%27s_laws).
// This transforms expressions of the form `!l || !r` into `!(l && r)`.
// This also works with `&&`. This assist can only be applied with the cursor
// on either `||` or `&&`.
//
// ```
// fn main() {
// if x != 4 ||$0 y < 3.14 {}
// }
// ```
// ->
// ```
// fn main() {
// if !(x == 4 && y >= 3.14) {}
// }
// ```
pub(crate) fn apply_demorgan(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> {
let mut bin_expr = if let Some(not) = ctx.find_token_syntax_at_offset(T![!])
&& let Some(NodeOrToken::Node(next)) = not.next_sibling_or_token()
&& let Some(paren) = ast::ParenExpr::cast(next)
&& let Some(ast::Expr::BinExpr(bin_expr)) = paren.expr()
{
bin_expr
} else {
let bin_expr = ctx.find_node_at_offset::<ast::BinExpr>()?;
let op_range = bin_expr.op_token()?.text_range();
// Is the cursor on the expression's logical operator?
if !op_range.contains_range(ctx.selection_trimmed()) {
return None;
}
bin_expr
};
let op = bin_expr.op_kind()?;
let op_range = bin_expr.op_token()?.text_range();
// Walk up the tree while we have the same binary operator
while let Some(parent_expr) = bin_expr.syntax().parent().and_then(ast::BinExpr::cast) {
match parent_expr.op_kind() {
Some(parent_op) if parent_op == op => {
bin_expr = parent_expr;
}
_ => break,
}
}
if is_pattern_cond(bin_expr.clone().into()) {
return None;
}
let op = bin_expr.op_kind()?;
let (inv_token, prec) = match op {
ast::BinaryOp::LogicOp(ast::LogicOp::And) => (SyntaxKind::PIPE2, ExprPrecedence::LOr),
ast::BinaryOp::LogicOp(ast::LogicOp::Or) => (SyntaxKind::AMP2, ExprPrecedence::LAnd),
_ => return None,
};
let make = SyntaxFactory::with_mappings();
let demorganed = bin_expr.clone_subtree();
let mut editor = SyntaxEditor::new(demorganed.syntax().clone());
editor.replace(demorganed.op_token()?, make.token(inv_token));
let mut exprs = VecDeque::from([
(bin_expr.lhs()?, demorganed.lhs()?, prec),
(bin_expr.rhs()?, demorganed.rhs()?, prec),
]);
while let Some((expr, demorganed, prec)) = exprs.pop_front() {
if let BinExpr(bin_expr) = &expr {
if let BinExpr(cbin_expr) = &demorganed {
if op == bin_expr.op_kind()? {
editor.replace(cbin_expr.op_token()?, make.token(inv_token));
exprs.push_back((bin_expr.lhs()?, cbin_expr.lhs()?, prec));
exprs.push_back((bin_expr.rhs()?, cbin_expr.rhs()?, prec));
} else {
let mut inv = invert_boolean_expression(&make, expr);
if precedence(&inv).needs_parentheses_in(prec) {
inv = make.expr_paren(inv).into();
}
editor.replace(demorganed.syntax(), inv.syntax());
}
} else {
return None;
}
} else {
let mut inv = invert_boolean_expression(&make, demorganed.clone());
if precedence(&inv).needs_parentheses_in(prec) {
inv = make.expr_paren(inv).into();
}
editor.replace(demorganed.syntax(), inv.syntax());
}
}
editor.add_mappings(make.finish_with_mappings());
let edit = editor.finish();
let demorganed = ast::Expr::cast(edit.new_root().clone())?;
acc.add_group(
&GroupLabel("Apply De Morgan's law".to_owned()),
AssistId::refactor_rewrite("apply_demorgan"),
"Apply De Morgan's law",
op_range,
|builder| {
let make = SyntaxFactory::with_mappings();
let (target_node, result_expr) = if let Some(neg_expr) = bin_expr
.syntax()
.parent()
.and_then(ast::ParenExpr::cast)
.and_then(|paren_expr| paren_expr.syntax().parent())
.and_then(ast::PrefixExpr::cast)
.filter(|prefix_expr| matches!(prefix_expr.op_kind(), Some(ast::UnaryOp::Not)))
{
cov_mark::hit!(demorgan_double_negation);
(ast::Expr::from(neg_expr).syntax().clone(), demorganed)
} else if let Some(paren_expr) =
bin_expr.syntax().parent().and_then(ast::ParenExpr::cast)
{
cov_mark::hit!(demorgan_double_parens);
(paren_expr.syntax().clone(), add_bang_paren(&make, demorganed))
} else {
(bin_expr.syntax().clone(), add_bang_paren(&make, demorganed))
};
let final_expr = if target_node
.parent()
.is_some_and(|p| result_expr.needs_parens_in_place_of(&p, &target_node))
{
cov_mark::hit!(demorgan_keep_parens_for_op_precedence2);
make.expr_paren(result_expr).into()
} else {
result_expr
};
let mut editor = builder.make_editor(&target_node);
editor.replace(&target_node, final_expr.syntax());
editor.add_mappings(make.finish_with_mappings());
builder.add_file_edits(ctx.vfs_file_id(), editor);
},
)
}
// Assist: apply_demorgan_iterator
//
// Apply [De Morgan's law](https://en.wikipedia.org/wiki/De_Morgan%27s_laws) to
// `Iterator::all` and `Iterator::any`.
//
// This transforms expressions of the form `!iter.any(|x| predicate(x))` into
// `iter.all(|x| !predicate(x))` and vice versa. This also works the other way for
// `Iterator::all` into `Iterator::any`.
//
// ```
// # //- minicore: iterator
// fn main() {
// let arr = [1, 2, 3];
// if !arr.into_iter().$0any(|num| num == 4) {
// println!("foo");
// }
// }
// ```
// ->
// ```
// fn main() {
// let arr = [1, 2, 3];
// if arr.into_iter().all(|num| num != 4) {
// println!("foo");
// }
// }
// ```
pub(crate) fn apply_demorgan_iterator(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> {
    let method_call: ast::MethodCallExpr = ctx.find_node_at_offset()?;
    let (name, arg_expr) = validate_method_call_expr(ctx, &method_call)?;

    // The assist only applies when the argument is a closure expression with a body.
    let closure_body = match arg_expr {
        ast::Expr::ClosureExpr(closure) => closure.body()?.clone_for_update(),
        _ => return None,
    };

    let target = method_call.syntax().text_range();
    let label = format!("Apply De Morgan's law to `Iterator::{}`", name.text().as_str());

    acc.add_group(
        &GroupLabel("Apply De Morgan's law".to_owned()),
        AssistId::refactor_rewrite("apply_demorgan_iterator"),
        label,
        target,
        |builder| {
            let make = SyntaxFactory::with_mappings();
            let mut editor = builder.make_editor(method_call.syntax());

            // Swap the method name: `all` <-> `any`. The validation step has
            // already rejected every other name.
            let flipped = match name.text().as_str() {
                "all" => "any",
                "any" => "all",
                _ => unreachable!(),
            };
            let new_name = make.name_ref(flipped);
            editor.replace(name.syntax(), new_name.syntax());

            // Invert every value the closure can produce: the operands of
            // `return` expressions found while walking the body, then the
            // body's own tail expressions.
            let negate_tail = &mut |tail: &_| tail_cb_impl(&mut editor, &make, tail);
            walk_expr(&closure_body, &mut |expr| {
                if let ast::Expr::ReturnExpr(ret_expr) = expr
                    && let Some(ret_value) = &ret_expr.expr()
                {
                    for_each_tail_expr(ret_value, negate_tail);
                }
            });
            for_each_tail_expr(&closure_body, negate_tail);

            // Negate the call as a whole: drop an existing leading `!`,
            // otherwise insert one in front of the call.
            let enclosing_not = method_call
                .syntax()
                .parent()
                .and_then(ast::PrefixExpr::cast)
                .filter(|prefix| matches!(prefix.op_kind(), Some(ast::UnaryOp::Not)));
            match enclosing_not {
                Some(not_expr) => {
                    let bang =
                        not_expr.op_token().expect("prefix expression always has an operator");
                    editor.delete(bang);
                }
                None => {
                    editor.insert(
                        Position::before(method_call.syntax()),
                        make.token(SyntaxKind::BANG),
                    );
                }
            }

            editor.add_mappings(make.finish_with_mappings());
            builder.add_file_edits(ctx.vfs_file_id(), editor);
        },
    )
}
/// Checks that `method_call` is a call named `all` or `any` whose receiver type
/// implements the `Iterator` trait.
///
/// Returns the method's name reference together with the call's first
/// argument, or `None` as soon as any lookup fails or a check does not hold.
fn validate_method_call_expr(
    ctx: &AssistContext<'_>,
    method_call: &ast::MethodCallExpr,
) -> Option<(ast::NameRef, ast::Expr)> {
    let name_ref = method_call.name_ref()?;
    if !matches!(name_ref.text().as_str(), "all" | "any") {
        return None;
    }
    let first_arg = method_call.arg_list()?.args().next()?;
    let receiver = method_call.receiver()?;

    let sema = &ctx.sema;
    let receiver_ty = sema.type_of_expr(&receiver)?.adjusted();
    // The `Iterator` trait is looked up from the crate the receiver lives in.
    let krate = sema.scope(receiver.syntax())?.module().krate(ctx.db());
    let iterator_trait = FamousDefs(sema, krate).core_iter_Iterator()?;

    if receiver_ty.impls_trait(sema.db, iterator_trait, &[]) {
        Some((name_ref, first_arg))
    } else {
        None
    }
}
/// Tail-expression callback: records in `editor` the replacement of `e` by its
/// boolean inverse.
///
/// A `break` with a value is handled by recursing into the tail expressions of
/// that value; a `break` without a value and any `return` are left untouched.
fn tail_cb_impl(editor: &mut SyntaxEditor, make: &SyntaxFactory, e: &ast::Expr) {
    if let ast::Expr::BreakExpr(break_expr) = e {
        if let Some(value) = break_expr.expr() {
            for_each_tail_expr(&value, &mut |tail| tail_cb_impl(editor, make, tail));
        }
        return;
    }

    // all return expressions have already been handled by the walk loop
    if matches!(e, ast::Expr::ReturnExpr(_)) {
        return;
    }

    let negated = invert_boolean_expression(make, e.clone());
    editor.replace(e.syntax(), negated.syntax());
}
/// Wraps `expr` in parentheses and prefixes the result with `!`.
fn add_bang_paren(make: &SyntaxFactory, expr: ast::Expr) -> ast::Expr {
    let parenthesized: ast::Expr = make.expr_paren(expr).into();
    let negated = make.expr_prefix(T![!], parenthesized);
    negated.into()
}
#[cfg(test)]
mod tests {
use super::*;
use crate::tests::{check_assist, check_assist_not_applicable};
#[test]
fn demorgan_handles_leq() {
check_assist(
apply_demorgan,
r#"
struct S;
fn f() { S < S &&$0 S <= S }
"#,
r#"
struct S;
fn f() { !(S >= S || S > S) }
"#,
);
}
#[test]
fn demorgan_handles_geq() {
check_assist(
apply_demorgan,
r#"
struct S;
fn f() { S > S &&$0 S >= S }
"#,
r#"
struct S;
fn f() { !(S <= S || S < S) }
"#,
);
}
#[test]
fn demorgan_turns_and_into_or() {
check_assist(apply_demorgan, "fn f() { !x &&$0 !x }", "fn f() { !(x || x) }")
}
#[test]
fn demorgan_turns_or_into_and() {
check_assist(apply_demorgan, "fn f() { !x ||$0 !x }", "fn f() { !(x && x) }")
}
#[test]
fn demorgan_removes_inequality() {
check_assist(apply_demorgan, "fn f() { x != x ||$0 !x }", "fn f() { !(x == x && x) }")
}
#[test]
fn demorgan_general_case() {
check_assist(apply_demorgan, "fn f() { x ||$0 x }", "fn f() { !(!x && !x) }")
}
#[test]
fn demorgan_multiple_terms() {
check_assist(apply_demorgan, "fn f() { x ||$0 y || z }", "fn f() { !(!x && !y && !z) }");
check_assist(apply_demorgan, "fn f() { x || y ||$0 z }", "fn f() { !(!x && !y && !z) }");
}
#[test]
fn demorgan_doesnt_apply_with_cursor_not_on_op() {
check_assist_not_applicable(apply_demorgan, "fn f() { $0 !x || !x }")
}
#[test]
fn demorgan_doesnt_double_negation() {
cov_mark::check!(demorgan_double_negation);
check_assist(apply_demorgan, "fn f() { !(x ||$0 x) }", "fn f() { !x && !x }")
}
#[test]
fn demorgan_doesnt_double_parens() {
cov_mark::check!(demorgan_double_parens);
check_assist(apply_demorgan, "fn f() { (x ||$0 x) }", "fn f() { !(!x && !x) }")
}
#[test]
fn demorgan_doesnt_hang() {
check_assist(
apply_demorgan,
"fn f() { 1 || 3 &&$0 4 || 5 }",
"fn f() { 1 || !(!3 || !4) || 5 }",
)
}
#[test]
fn demorgan_doesnt_handles_pattern() {
check_assist_not_applicable(
apply_demorgan,
r#"
fn f() { if let 1 = 1 &&$0 true { } }
"#,
);
}
#[test]
fn demorgan_on_not() {
check_assist(
apply_demorgan,
"fn f() { $0!(1 || 3 && 4 || 5) }",
"fn f() { !1 && !(3 && 4) && !5 }",
)
}
#[test]
fn demorgan_keep_pars_for_op_precedence() {
check_assist(
apply_demorgan,
"fn main() {
let _ = !(!a ||$0 !(b || c));
}
",
"fn main() {
let _ = a && (b || c);
}
",
);
}
#[test]
fn demorgan_keep_pars_for_op_precedence2() {
cov_mark::check!(demorgan_keep_parens_for_op_precedence2);
check_assist(
apply_demorgan,
"fn f() { (a && !(b &&$0 c); }",
"fn f() { (a && (!b || !c); }",
);
}
#[test]
fn demorgan_keep_pars_for_op_precedence3() {
check_assist(
apply_demorgan,
"fn f() { (a || !(b &&$0 c); }",
"fn f() { (a || (!b || !c); }",
);
}
#[test]
fn demorgan_keeps_pars_in_eq_precedence() {
check_assist(
apply_demorgan,
"fn() { let x = a && !(!b |$0| !c); }",
"fn() { let x = a && (b && c); }",
)
}
#[test]
fn demorgan_removes_pars_for_op_precedence2() {
check_assist(apply_demorgan, "fn f() { (a || !(b ||$0 c); }", "fn f() { (a || !b && !c; }");
}
#[test]
fn demorgan_iterator_any_all_reverse() {
check_assist(
apply_demorgan_iterator,
r#"
//- minicore: iterator
fn main() {
let arr = [1, 2, 3];
if arr.into_iter().all(|num| num $0!= 4) {
println!("foo");
}
}
"#,
r#"
fn main() {
let arr = [1, 2, 3];
if !arr.into_iter().any(|num| num == 4) {
println!("foo");
}
}
"#,
);
}
#[test]
fn demorgan_iterator_all_any() {
check_assist(
apply_demorgan_iterator,
r#"
//- minicore: iterator
fn main() {
let arr = [1, 2, 3];
if !arr.into_iter().$0all(|num| num > 3) {
println!("foo");
}
}
"#,
r#"
fn main() {
let arr = [1, 2, 3];
if arr.into_iter().any(|num| num <= 3) {
println!("foo");
}
}
"#,
);
}
#[test]
fn demorgan_iterator_multiple_terms() {
check_assist(
apply_demorgan_iterator,
r#"
//- minicore: iterator
fn main() {
let arr = [1, 2, 3];
if !arr.into_iter().$0any(|num| num > 3 && num == 23 && num <= 30) {
println!("foo");
}
}
"#,
r#"
fn main() {
let arr = [1, 2, 3];
if arr.into_iter().all(|num| !(num > 3 && num == 23 && num <= 30)) {
println!("foo");
}
}
"#,
);
}
#[test]
fn demorgan_iterator_double_negation() {
check_assist(
apply_demorgan_iterator,
r#"
//- minicore: iterator
fn main() {
let arr = [1, 2, 3];
if !arr.into_iter().$0all(|num| !(num > 3)) {
println!("foo");
}
}
"#,
r#"
fn main() {
let arr = [1, 2, 3];
if arr.into_iter().any(|num| num > 3) {
println!("foo");
}
}
"#,
);
}
#[test]
fn demorgan_iterator_double_parens() {
check_assist(
apply_demorgan_iterator,
r#"
//- minicore: iterator
fn main() {
let arr = [1, 2, 3];
if !arr.into_iter().$0any(|num| (num > 3 && (num == 1 || num == 2))) {
println!("foo");
}
}
"#,
r#"
fn main() {
let arr = [1, 2, 3];
if arr.into_iter().all(|num| !(num > 3 && (num == 1 || num == 2))) {
println!("foo");
}
}
"#,
);
}
#[test]
fn demorgan_iterator_multiline() {
check_assist(
apply_demorgan_iterator,
r#"
//- minicore: iterator
fn main() {
let arr = [1, 2, 3];
if arr
.into_iter()
.all$0(|num| !num.is_negative())
{
println!("foo");
}
}
"#,
r#"
fn main() {
let arr = [1, 2, 3];
if !arr
.into_iter()
.any(|num| num.is_negative())
{
println!("foo");
}
}
"#,
);
}
#[test]
fn demorgan_iterator_block_closure() {
check_assist(
apply_demorgan_iterator,
r#"
//- minicore: iterator
fn main() {
let arr = [-1, 1, 2, 3];
if arr.into_iter().all(|num: i32| {
$0if num.is_positive() {
num <= 3
} else {
num >= -1
}
}) {
println!("foo");
}
}
"#,
r#"
fn main() {
let arr = [-1, 1, 2, 3];
if !arr.into_iter().any(|num: i32| {
if num.is_positive() {
num > 3
} else {
num < -1
}
}) {
println!("foo");
}
}
"#,
);
}
#[test]
fn demorgan_iterator_wrong_method() {
check_assist_not_applicable(
apply_demorgan_iterator,
r#"
//- minicore: iterator
fn main() {
let arr = [1, 2, 3];
if !arr.into_iter().$0map(|num| num > 3) {
println!("foo");
}
}
"#,
);
}
#[test]
fn demorgan_method_call_receiver() {
check_assist(
apply_demorgan,
"fn f() { (x ||$0 !y).then_some(42) }",
"fn f() { (!(!x && y)).then_some(42) }",
);
}
#[test]
fn demorgan_method_call_receiver_complex() {
check_assist(
apply_demorgan,
"fn f() { (a && b ||$0 c && d).then_some(42) }",
"fn f() { (!(!(a && b) && !(c && d))).then_some(42) }",
);
}
#[test]
fn demorgan_method_call_receiver_chained() {
check_assist(
apply_demorgan,
"fn f() { (a ||$0 b).then_some(42).or(Some(0)) }",
"fn f() { (!(!a && !b)).then_some(42).or(Some(0)) }",
);
}
}