Unnamed repository; edit this file 'description' to name the repository.
fix: add checks for overwriting incorrect ancestor
Ryan Mehri 2023-09-10
parent 2e13aed · commit 7ba2e13
-rw-r--r--crates/ide-assists/src/handlers/bool_to_enum.rs166
1 files changed, 165 insertions, 1 deletions
diff --git a/crates/ide-assists/src/handlers/bool_to_enum.rs b/crates/ide-assists/src/handlers/bool_to_enum.rs
index f59b052813..784a0d3559 100644
--- a/crates/ide-assists/src/handlers/bool_to_enum.rs
+++ b/crates/ide-assists/src/handlers/bool_to_enum.rs
@@ -263,6 +263,11 @@ fn replace_usages(
fn find_assignment_usage(name_ref: &ast::NameRef) -> Option<ast::Expr> {
let bin_expr = name_ref.syntax().ancestors().find_map(ast::BinExpr::cast)?;
+ if !bin_expr.lhs()?.syntax().descendants().contains(name_ref.syntax()) {
+ cov_mark::hit!(dont_assign_incorrect_ref);
+ return None;
+ }
+
if let Some(ast::BinaryOp::Assignment { op: None }) = bin_expr.op_kind() {
bin_expr.rhs()
} else {
@@ -273,6 +278,11 @@ fn find_assignment_usage(name_ref: &ast::NameRef) -> Option<ast::Expr> {
fn find_negated_usage(name_ref: &ast::NameRef) -> Option<(ast::PrefixExpr, ast::Expr)> {
let prefix_expr = name_ref.syntax().ancestors().find_map(ast::PrefixExpr::cast)?;
+ if !matches!(prefix_expr.expr()?, ast::Expr::PathExpr(_) | ast::Expr::FieldExpr(_)) {
+ cov_mark::hit!(dont_overwrite_expression_inside_negation);
+ return None;
+ }
+
if let Some(ast::UnaryOp::Not) = prefix_expr.op_kind() {
let inner_expr = prefix_expr.expr()?;
Some((prefix_expr, inner_expr))
@@ -285,7 +295,12 @@ fn find_record_expr_usage(name_ref: &ast::NameRef) -> Option<(ast::RecordExprFie
let record_field = name_ref.syntax().ancestors().find_map(ast::RecordExprField::cast)?;
let initializer = record_field.expr()?;
- Some((record_field, initializer))
+ if record_field.field_name()?.syntax().descendants().contains(name_ref.syntax()) {
+ Some((record_field, initializer))
+ } else {
+ cov_mark::hit!(dont_overwrite_wrong_record_field);
+ None
+ }
}
/// Adds the definition of the new enum before the target node.
@@ -562,6 +577,37 @@ fn main() {
}
#[test]
+ fn local_variable_nested_in_negation() {
+ cov_mark::check!(dont_overwrite_expression_inside_negation);
+ check_assist(
+ bool_to_enum,
+ r#"
+fn main() {
+ if !"foo".chars().any(|c| {
+ let $0foo = true;
+ foo
+ }) {
+ println!("foo");
+ }
+}
+"#,
+ r#"
+fn main() {
+ if !"foo".chars().any(|c| {
+ #[derive(PartialEq, Eq)]
+ enum Bool { True, False }
+
+ let foo = Bool::True;
+ foo == Bool::True
+ }) {
+ println!("foo");
+ }
+}
+"#,
+ )
+ }
+
+ #[test]
fn local_variable_non_bool() {
cov_mark::check!(not_applicable_non_bool_local);
check_assist_not_applicable(
@@ -639,6 +685,42 @@ fn main() {
}
#[test]
+ fn field_negated() {
+ check_assist(
+ bool_to_enum,
+ r#"
+struct Foo {
+ $0bar: bool,
+}
+
+fn main() {
+ let foo = Foo { bar: false };
+
+ if !foo.bar {
+ println!("foo");
+ }
+}
+"#,
+ r#"
+#[derive(PartialEq, Eq)]
+enum Bool { True, False }
+
+struct Foo {
+ bar: Bool,
+}
+
+fn main() {
+ let foo = Foo { bar: Bool::False };
+
+ if foo.bar == Bool::False {
+ println!("foo");
+ }
+}
+"#,
+ )
+ }
+
+ #[test]
fn field_in_mod_properly_indented() {
check_assist(
bool_to_enum,
@@ -715,6 +797,88 @@ fn main() {
}
#[test]
+ fn field_assigned_to_another() {
+ cov_mark::check!(dont_assign_incorrect_ref);
+ check_assist(
+ bool_to_enum,
+ r#"
+struct Foo {
+ $0foo: bool,
+}
+
+struct Bar {
+ bar: bool,
+}
+
+fn main() {
+ let foo = Foo { foo: true };
+ let mut bar = Bar { bar: true };
+
+ bar.bar = foo.foo;
+}
+"#,
+ r#"
+#[derive(PartialEq, Eq)]
+enum Bool { True, False }
+
+struct Foo {
+ foo: Bool,
+}
+
+struct Bar {
+ bar: bool,
+}
+
+fn main() {
+ let foo = Foo { foo: Bool::True };
+ let mut bar = Bar { bar: true };
+
+ bar.bar = foo.foo == Bool::True;
+}
+"#,
+ )
+ }
+
+ #[test]
+ fn field_initialized_with_other() {
+ cov_mark::check!(dont_overwrite_wrong_record_field);
+ check_assist(
+ bool_to_enum,
+ r#"
+struct Foo {
+ $0foo: bool,
+}
+
+struct Bar {
+ bar: bool,
+}
+
+fn main() {
+ let foo = Foo { foo: true };
+ let bar = Bar { bar: foo.foo };
+}
+"#,
+ r#"
+#[derive(PartialEq, Eq)]
+enum Bool { True, False }
+
+struct Foo {
+ foo: Bool,
+}
+
+struct Bar {
+ bar: bool,
+}
+
+fn main() {
+ let foo = Foo { foo: Bool::True };
+ let bar = Bar { bar: foo.foo == Bool::True };
+}
+"#,
+ )
+ }
+
+ #[test]
fn field_non_bool() {
cov_mark::check!(not_applicable_non_bool_field);
check_assist_not_applicable(