Unnamed repository; edit this file 'description' to name the repository.
Diffstat (limited to 'crates/ide-assists/src/handlers/convert_bool_to_enum.rs')
-rw-r--r--crates/ide-assists/src/handlers/convert_bool_to_enum.rs1930
1 files changed, 1930 insertions, 0 deletions
diff --git a/crates/ide-assists/src/handlers/convert_bool_to_enum.rs b/crates/ide-assists/src/handlers/convert_bool_to_enum.rs
new file mode 100644
index 0000000000..7716e99e60
--- /dev/null
+++ b/crates/ide-assists/src/handlers/convert_bool_to_enum.rs
@@ -0,0 +1,1930 @@
+use either::Either;
+use hir::ModuleDef;
+use ide_db::text_edit::TextRange;
+use ide_db::{
+ assists::{AssistId, AssistKind},
+ defs::Definition,
+ helpers::mod_path_to_ast,
+ imports::insert_use::{insert_use, ImportScope},
+ search::{FileReference, UsageSearchResult},
+ source_change::SourceChangeBuilder,
+ FxHashSet,
+};
+use itertools::Itertools;
+use syntax::{
+ ast::{
+ self,
+ edit::IndentLevel,
+ edit_in_place::{AttrsOwnerEdit, Indent},
+ make, HasName,
+ },
+ AstNode, NodeOrToken, SyntaxKind, SyntaxNode, T,
+};
+
+use crate::{
+ assist_context::{AssistContext, Assists},
+ utils,
+};
+
+// Assist: convert_bool_to_enum
+//
+// This converts boolean local variables, fields, constants, and statics into a new
+// enum with two variants `Bool::True` and `Bool::False`, as well as replacing
+// all assignments with the variants and replacing all usages with `== Bool::True` or
+// `== Bool::False`.
+//
+// ```
+// fn main() {
+// let $0bool = true;
+//
+// if bool {
+// println!("foo");
+// }
+// }
+// ```
+// ->
+// ```
+// #[derive(PartialEq, Eq)]
+// enum Bool { True, False }
+//
+// fn main() {
+// let bool = Bool::True;
+//
+// if bool == Bool::True {
+// println!("foo");
+// }
+// }
+// ```
+pub(crate) fn convert_bool_to_enum(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> {
+ let BoolNodeData { target_node, name, ty_annotation, initializer, definition } =
+ find_bool_node(ctx)?;
+ let target_module = ctx.sema.scope(&target_node)?.module().nearest_non_block_module(ctx.db());
+
+ let target = name.syntax().text_range();
+ acc.add(
+ AssistId("convert_bool_to_enum", AssistKind::RefactorRewrite),
+ "Convert boolean to enum",
+ target,
+ |edit| {
+ if let Some(ty) = &ty_annotation {
+ cov_mark::hit!(replaces_ty_annotation);
+ edit.replace(ty.syntax().text_range(), "Bool");
+ }
+
+ if let Some(initializer) = initializer {
+ replace_bool_expr(edit, initializer);
+ }
+
+ let usages = definition.usages(&ctx.sema).all();
+ add_enum_def(edit, ctx, &usages, target_node, &target_module);
+ let mut delayed_mutations = Vec::new();
+ replace_usages(edit, ctx, usages, definition, &target_module, &mut delayed_mutations);
+ for (scope, path) in delayed_mutations {
+ insert_use(&scope, path, &ctx.config.insert_use);
+ }
+ },
+ )
+}
+
+struct BoolNodeData {
+ target_node: SyntaxNode,
+ name: ast::Name,
+ ty_annotation: Option<ast::Type>,
+ initializer: Option<ast::Expr>,
+ definition: Definition,
+}
+
+/// Attempts to find an appropriate node to apply the action to.
+fn find_bool_node(ctx: &AssistContext<'_>) -> Option<BoolNodeData> {
+ let name = ctx.find_node_at_offset::<ast::Name>()?;
+
+ if let Some(ident_pat) = name.syntax().parent().and_then(ast::IdentPat::cast) {
+ let def = ctx.sema.to_def(&ident_pat)?;
+ if !def.ty(ctx.db()).is_bool() {
+ cov_mark::hit!(not_applicable_non_bool_local);
+ return None;
+ }
+
+ let local_definition = Definition::Local(def);
+ match ident_pat.syntax().parent().and_then(Either::<ast::Param, ast::LetStmt>::cast)? {
+ Either::Left(param) => Some(BoolNodeData {
+ target_node: param.syntax().clone(),
+ name,
+ ty_annotation: param.ty(),
+ initializer: None,
+ definition: local_definition,
+ }),
+ Either::Right(let_stmt) => Some(BoolNodeData {
+ target_node: let_stmt.syntax().clone(),
+ name,
+ ty_annotation: let_stmt.ty(),
+ initializer: let_stmt.initializer(),
+ definition: local_definition,
+ }),
+ }
+ } else if let Some(const_) = name.syntax().parent().and_then(ast::Const::cast) {
+ let def = ctx.sema.to_def(&const_)?;
+ if !def.ty(ctx.db()).is_bool() {
+ cov_mark::hit!(not_applicable_non_bool_const);
+ return None;
+ }
+
+ Some(BoolNodeData {
+ target_node: const_.syntax().clone(),
+ name,
+ ty_annotation: const_.ty(),
+ initializer: const_.body(),
+ definition: Definition::Const(def),
+ })
+ } else if let Some(static_) = name.syntax().parent().and_then(ast::Static::cast) {
+ let def = ctx.sema.to_def(&static_)?;
+ if !def.ty(ctx.db()).is_bool() {
+ cov_mark::hit!(not_applicable_non_bool_static);
+ return None;
+ }
+
+ Some(BoolNodeData {
+ target_node: static_.syntax().clone(),
+ name,
+ ty_annotation: static_.ty(),
+ initializer: static_.body(),
+ definition: Definition::Static(def),
+ })
+ } else {
+ let field = name.syntax().parent().and_then(ast::RecordField::cast)?;
+ if field.name()? != name {
+ return None;
+ }
+
+ let adt = field.syntax().ancestors().find_map(ast::Adt::cast)?;
+ let def = ctx.sema.to_def(&field)?;
+ if !def.ty(ctx.db()).is_bool() {
+ cov_mark::hit!(not_applicable_non_bool_field);
+ return None;
+ }
+ Some(BoolNodeData {
+ target_node: adt.syntax().clone(),
+ name,
+ ty_annotation: field.ty(),
+ initializer: None,
+ definition: Definition::Field(def),
+ })
+ }
+}
+
+fn replace_bool_expr(edit: &mut SourceChangeBuilder, expr: ast::Expr) {
+ let expr_range = expr.syntax().text_range();
+ let enum_expr = bool_expr_to_enum_expr(expr);
+ edit.replace(expr_range, enum_expr.syntax().text())
+}
+
+/// Converts an expression of type `bool` to one of the new enum type.
+fn bool_expr_to_enum_expr(expr: ast::Expr) -> ast::Expr {
+ let true_expr = make::expr_path(make::path_from_text("Bool::True"));
+ let false_expr = make::expr_path(make::path_from_text("Bool::False"));
+
+ if let ast::Expr::Literal(literal) = &expr {
+ match literal.kind() {
+ ast::LiteralKind::Bool(true) => true_expr,
+ ast::LiteralKind::Bool(false) => false_expr,
+ _ => expr,
+ }
+ } else {
+ make::expr_if(
+ expr,
+ make::tail_only_block_expr(true_expr),
+ Some(ast::ElseBranch::Block(make::tail_only_block_expr(false_expr))),
+ )
+ .into()
+ }
+}
+
+/// Replaces all usages of the target identifier, both when read and written to.
+fn replace_usages(
+ edit: &mut SourceChangeBuilder,
+ ctx: &AssistContext<'_>,
+ usages: UsageSearchResult,
+ target_definition: Definition,
+ target_module: &hir::Module,
+ delayed_mutations: &mut Vec<(ImportScope, ast::Path)>,
+) {
+ for (file_id, references) in usages {
+ edit.edit_file(file_id.file_id());
+
+ let refs_with_imports = augment_references_with_imports(ctx, references, target_module);
+
+ refs_with_imports.into_iter().rev().for_each(
+ |FileReferenceWithImport { range, name, import_data }| {
+ // replace the usages in patterns and expressions
+ if let Some(ident_pat) = name.syntax().ancestors().find_map(ast::IdentPat::cast) {
+ cov_mark::hit!(replaces_record_pat_shorthand);
+
+ let definition = ctx.sema.to_def(&ident_pat).map(Definition::Local);
+ if let Some(def) = definition {
+ replace_usages(
+ edit,
+ ctx,
+ def.usages(&ctx.sema).all(),
+ target_definition,
+ target_module,
+ delayed_mutations,
+ )
+ }
+ } else if let Some(initializer) = find_assignment_usage(&name) {
+ cov_mark::hit!(replaces_assignment);
+
+ replace_bool_expr(edit, initializer);
+ } else if let Some((prefix_expr, inner_expr)) = find_negated_usage(&name) {
+ cov_mark::hit!(replaces_negation);
+
+ edit.replace(
+ prefix_expr.syntax().text_range(),
+ format!("{inner_expr} == Bool::False"),
+ );
+ } else if let Some((record_field, initializer)) = name
+ .as_name_ref()
+ .and_then(ast::RecordExprField::for_field_name)
+ .and_then(|record_field| ctx.sema.resolve_record_field(&record_field))
+ .and_then(|(got_field, _, _)| {
+ find_record_expr_usage(&name, got_field, target_definition)
+ })
+ {
+ cov_mark::hit!(replaces_record_expr);
+
+ let enum_expr = bool_expr_to_enum_expr(initializer);
+ utils::replace_record_field_expr(ctx, edit, record_field, enum_expr);
+ } else if let Some(pat) = find_record_pat_field_usage(&name) {
+ match pat {
+ ast::Pat::IdentPat(ident_pat) => {
+ cov_mark::hit!(replaces_record_pat);
+
+ let definition = ctx.sema.to_def(&ident_pat).map(Definition::Local);
+ if let Some(def) = definition {
+ replace_usages(
+ edit,
+ ctx,
+ def.usages(&ctx.sema).all(),
+ target_definition,
+ target_module,
+ delayed_mutations,
+ )
+ }
+ }
+ ast::Pat::LiteralPat(literal_pat) => {
+ cov_mark::hit!(replaces_literal_pat);
+
+ if let Some(expr) = literal_pat.literal().and_then(|literal| {
+ literal.syntax().ancestors().find_map(ast::Expr::cast)
+ }) {
+ replace_bool_expr(edit, expr);
+ }
+ }
+ _ => (),
+ }
+ } else if let Some((ty_annotation, initializer)) = find_assoc_const_usage(&name) {
+ edit.replace(ty_annotation.syntax().text_range(), "Bool");
+ replace_bool_expr(edit, initializer);
+ } else if let Some(receiver) = find_method_call_expr_usage(&name) {
+ edit.replace(
+ receiver.syntax().text_range(),
+ format!("({receiver} == Bool::True)"),
+ );
+ } else if name.syntax().ancestors().find_map(ast::UseTree::cast).is_none() {
+ // for any other usage in an expression, replace it with a check that it is the true variant
+ if let Some((record_field, expr)) =
+ name.as_name_ref().and_then(ast::RecordExprField::for_field_name).and_then(
+ |record_field| record_field.expr().map(|expr| (record_field, expr)),
+ )
+ {
+ utils::replace_record_field_expr(
+ ctx,
+ edit,
+ record_field,
+ make::expr_bin_op(
+ expr,
+ ast::BinaryOp::CmpOp(ast::CmpOp::Eq { negated: false }),
+ make::expr_path(make::path_from_text("Bool::True")),
+ ),
+ );
+ } else {
+ edit.replace(range, format!("{} == Bool::True", name.text()));
+ }
+ }
+
+ // add imports across modules where needed
+ if let Some((import_scope, path)) = import_data {
+ let scope = match import_scope {
+ ImportScope::File(it) => ImportScope::File(edit.make_mut(it)),
+ ImportScope::Module(it) => ImportScope::Module(edit.make_mut(it)),
+ ImportScope::Block(it) => ImportScope::Block(edit.make_mut(it)),
+ };
+ delayed_mutations.push((scope, path));
+ }
+ },
+ )
+ }
+}
+
+struct FileReferenceWithImport {
+ range: TextRange,
+ name: ast::NameLike,
+ import_data: Option<(ImportScope, ast::Path)>,
+}
+
+fn augment_references_with_imports(
+ ctx: &AssistContext<'_>,
+ references: Vec<FileReference>,
+ target_module: &hir::Module,
+) -> Vec<FileReferenceWithImport> {
+ let mut visited_modules = FxHashSet::default();
+
+ let cfg = ctx.config.import_path_config();
+
+ let edition = target_module.krate().edition(ctx.db());
+ references
+ .into_iter()
+ .filter_map(|FileReference { range, name, .. }| {
+ let name = name.into_name_like()?;
+ ctx.sema.scope(name.syntax()).map(|scope| (range, name, scope.module()))
+ })
+ .map(|(range, name, ref_module)| {
+ // if the referenced module is not the same as the target one and has not been seen before, add an import
+ let import_data = if ref_module.nearest_non_block_module(ctx.db()) != *target_module
+ && !visited_modules.contains(&ref_module)
+ {
+ visited_modules.insert(ref_module);
+
+ let import_scope = ImportScope::find_insert_use_container(name.syntax(), &ctx.sema);
+ let path = ref_module
+ .find_use_path(
+ ctx.sema.db,
+ ModuleDef::Module(*target_module),
+ ctx.config.insert_use.prefix_kind,
+ cfg,
+ )
+ .map(|mod_path| {
+ make::path_concat(
+ mod_path_to_ast(&mod_path, edition),
+ make::path_from_text("Bool"),
+ )
+ });
+
+ import_scope.zip(path)
+ } else {
+ None
+ };
+
+ FileReferenceWithImport { range, name, import_data }
+ })
+ .collect()
+}
+
+fn find_assignment_usage(name: &ast::NameLike) -> Option<ast::Expr> {
+ let bin_expr = name.syntax().ancestors().find_map(ast::BinExpr::cast)?;
+
+ if !bin_expr.lhs()?.syntax().descendants().contains(name.syntax()) {
+ cov_mark::hit!(dont_assign_incorrect_ref);
+ return None;
+ }
+
+ if let Some(ast::BinaryOp::Assignment { op: None }) = bin_expr.op_kind() {
+ bin_expr.rhs()
+ } else {
+ None
+ }
+}
+
+fn find_negated_usage(name: &ast::NameLike) -> Option<(ast::PrefixExpr, ast::Expr)> {
+ let prefix_expr = name.syntax().ancestors().find_map(ast::PrefixExpr::cast)?;
+
+ if !matches!(prefix_expr.expr()?, ast::Expr::PathExpr(_) | ast::Expr::FieldExpr(_)) {
+ cov_mark::hit!(dont_overwrite_expression_inside_negation);
+ return None;
+ }
+
+ if let Some(ast::UnaryOp::Not) = prefix_expr.op_kind() {
+ let inner_expr = prefix_expr.expr()?;
+ Some((prefix_expr, inner_expr))
+ } else {
+ None
+ }
+}
+
+fn find_record_expr_usage(
+ name: &ast::NameLike,
+ got_field: hir::Field,
+ target_definition: Definition,
+) -> Option<(ast::RecordExprField, ast::Expr)> {
+ let name_ref = name.as_name_ref()?;
+ let record_field = ast::RecordExprField::for_field_name(name_ref)?;
+ let initializer = record_field.expr()?;
+
+ match target_definition {
+ Definition::Field(expected_field) if got_field == expected_field => {
+ Some((record_field, initializer))
+ }
+ _ => None,
+ }
+}
+
+fn find_record_pat_field_usage(name: &ast::NameLike) -> Option<ast::Pat> {
+ let record_pat_field = name.syntax().parent().and_then(ast::RecordPatField::cast)?;
+ let pat = record_pat_field.pat()?;
+
+ match pat {
+ ast::Pat::IdentPat(_) | ast::Pat::LiteralPat(_) | ast::Pat::WildcardPat(_) => Some(pat),
+ _ => None,
+ }
+}
+
+fn find_assoc_const_usage(name: &ast::NameLike) -> Option<(ast::Type, ast::Expr)> {
+ let const_ = name.syntax().parent().and_then(ast::Const::cast)?;
+ const_.syntax().parent().and_then(ast::AssocItemList::cast)?;
+
+ Some((const_.ty()?, const_.body()?))
+}
+
+fn find_method_call_expr_usage(name: &ast::NameLike) -> Option<ast::Expr> {
+ let method_call = name.syntax().ancestors().find_map(ast::MethodCallExpr::cast)?;
+ let receiver = method_call.receiver()?;
+
+ if !receiver.syntax().descendants().contains(name.syntax()) {
+ return None;
+ }
+
+ Some(receiver)
+}
+
+/// Adds the definition of the new enum before the target node.
+fn add_enum_def(
+ edit: &mut SourceChangeBuilder,
+ ctx: &AssistContext<'_>,
+ usages: &UsageSearchResult,
+ target_node: SyntaxNode,
+ target_module: &hir::Module,
+) -> Option<()> {
+ let insert_before = node_to_insert_before(target_node);
+
+ if ctx
+ .sema
+ .scope(&insert_before)?
+ .module()
+ .scope(ctx.db(), Some(*target_module))
+ .iter()
+ .any(|(name, _)| name.as_str() == "Bool")
+ {
+ return None;
+ }
+
+ let make_enum_pub = usages
+ .iter()
+ .flat_map(|(_, refs)| refs)
+ .filter_map(|FileReference { name, .. }| {
+ let name = name.clone().into_name_like()?;
+ ctx.sema.scope(name.syntax()).map(|scope| scope.module())
+ })
+ .any(|module| module.nearest_non_block_module(ctx.db()) != *target_module);
+ let enum_def = make_bool_enum(make_enum_pub);
+
+ let indent = IndentLevel::from_node(&insert_before);
+ enum_def.reindent_to(indent);
+
+ edit.insert(
+ insert_before.text_range().start(),
+ format!("{}\n\n{indent}", enum_def.syntax().text()),
+ );
+
+ Some(())
+}
+
+/// Finds where to put the new enum definition.
+/// Tries to find the ast node at the nearest module or at top-level, otherwise just
+/// returns the input node.
+fn node_to_insert_before(target_node: SyntaxNode) -> SyntaxNode {
+ target_node
+ .ancestors()
+ .take_while(|it| !matches!(it.kind(), SyntaxKind::MODULE | SyntaxKind::SOURCE_FILE))
+ .filter(|it| ast::Item::can_cast(it.kind()))
+ .last()
+ .unwrap_or(target_node)
+}
+
+fn make_bool_enum(make_pub: bool) -> ast::Enum {
+ let enum_def = make::enum_(
+ if make_pub { Some(make::visibility_pub()) } else { None },
+ make::name("Bool"),
+ None,
+ None,
+ make::variant_list(vec![
+ make::variant(None, make::name("True"), None, None),
+ make::variant(None, make::name("False"), None, None),
+ ]),
+ )
+ .clone_for_update();
+
+ let derive_eq = make::attr_outer(make::meta_token_tree(
+ make::ext::ident_path("derive"),
+ make::token_tree(
+ T!['('],
+ vec![
+ NodeOrToken::Token(make::tokens::ident("PartialEq")),
+ NodeOrToken::Token(make::token(T![,])),
+ NodeOrToken::Token(make::tokens::single_space()),
+ NodeOrToken::Token(make::tokens::ident("Eq")),
+ ],
+ ),
+ ))
+ .clone_for_update();
+ enum_def.add_attr(derive_eq);
+
+ enum_def
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ use crate::tests::{check_assist, check_assist_not_applicable};
+
+ #[test]
+ fn parameter_with_first_param_usage() {
+ check_assist(
+ convert_bool_to_enum,
+ r#"
+fn function($0foo: bool, bar: bool) {
+ if foo {
+ println!("foo");
+ }
+}
+"#,
+ r#"
+#[derive(PartialEq, Eq)]
+enum Bool { True, False }
+
+fn function(foo: Bool, bar: bool) {
+ if foo == Bool::True {
+ println!("foo");
+ }
+}
+"#,
+ )
+ }
+
+ #[test]
+ fn no_duplicate_enums() {
+ check_assist(
+ convert_bool_to_enum,
+ r#"
+#[derive(PartialEq, Eq)]
+enum Bool { True, False }
+
+fn function(foo: bool, $0bar: bool) {
+ if bar {
+ println!("bar");
+ }
+}
+"#,
+ r#"
+#[derive(PartialEq, Eq)]
+enum Bool { True, False }
+
+fn function(foo: bool, bar: Bool) {
+ if bar == Bool::True {
+ println!("bar");
+ }
+}
+"#,
+ )
+ }
+
+ #[test]
+ fn parameter_with_last_param_usage() {
+ check_assist(
+ convert_bool_to_enum,
+ r#"
+fn function(foo: bool, $0bar: bool) {
+ if bar {
+ println!("bar");
+ }
+}
+"#,
+ r#"
+#[derive(PartialEq, Eq)]
+enum Bool { True, False }
+
+fn function(foo: bool, bar: Bool) {
+ if bar == Bool::True {
+ println!("bar");
+ }
+}
+"#,
+ )
+ }
+
+ #[test]
+ fn parameter_with_middle_param_usage() {
+ check_assist(
+ convert_bool_to_enum,
+ r#"
+fn function(foo: bool, $0bar: bool, baz: bool) {
+ if bar {
+ println!("bar");
+ }
+}
+"#,
+ r#"
+#[derive(PartialEq, Eq)]
+enum Bool { True, False }
+
+fn function(foo: bool, bar: Bool, baz: bool) {
+ if bar == Bool::True {
+ println!("bar");
+ }
+}
+"#,
+ )
+ }
+
+ #[test]
+ fn parameter_with_closure_usage() {
+ check_assist(
+ convert_bool_to_enum,
+ r#"
+fn main() {
+ let foo = |$0bar: bool| bar;
+}
+"#,
+ r#"
+#[derive(PartialEq, Eq)]
+enum Bool { True, False }
+
+fn main() {
+ let foo = |bar: Bool| bar == Bool::True;
+}
+"#,
+ )
+ }
+
+ #[test]
+ fn local_variable_with_usage() {
+ check_assist(
+ convert_bool_to_enum,
+ r#"
+fn main() {
+ let $0foo = true;
+
+ if foo {
+ println!("foo");
+ }
+}
+"#,
+ r#"
+#[derive(PartialEq, Eq)]
+enum Bool { True, False }
+
+fn main() {
+ let foo = Bool::True;
+
+ if foo == Bool::True {
+ println!("foo");
+ }
+}
+"#,
+ )
+ }
+
+ #[test]
+ fn local_variable_with_usage_negated() {
+ cov_mark::check!(replaces_negation);
+ check_assist(
+ convert_bool_to_enum,
+ r#"
+fn main() {
+ let $0foo = true;
+
+ if !foo {
+ println!("foo");
+ }
+}
+"#,
+ r#"
+#[derive(PartialEq, Eq)]
+enum Bool { True, False }
+
+fn main() {
+ let foo = Bool::True;
+
+ if foo == Bool::False {
+ println!("foo");
+ }
+}
+"#,
+ )
+ }
+
+ #[test]
+ fn local_variable_with_type_annotation() {
+ cov_mark::check!(replaces_ty_annotation);
+ check_assist(
+ convert_bool_to_enum,
+ r#"
+fn main() {
+ let $0foo: bool = false;
+}
+"#,
+ r#"
+#[derive(PartialEq, Eq)]
+enum Bool { True, False }
+
+fn main() {
+ let foo: Bool = Bool::False;
+}
+"#,
+ )
+ }
+
+ #[test]
+ fn local_variable_with_non_literal_initializer() {
+ check_assist(
+ convert_bool_to_enum,
+ r#"
+fn main() {
+ let $0foo = 1 == 2;
+}
+"#,
+ r#"
+#[derive(PartialEq, Eq)]
+enum Bool { True, False }
+
+fn main() {
+ let foo = if 1 == 2 { Bool::True } else { Bool::False };
+}
+"#,
+ )
+ }
+
+ #[test]
+ fn local_variable_binexpr_usage() {
+ check_assist(
+ convert_bool_to_enum,
+ r#"
+fn main() {
+ let $0foo = false;
+ let bar = true;
+
+ if !foo && bar {
+ println!("foobar");
+ }
+}
+"#,
+ r#"
+#[derive(PartialEq, Eq)]
+enum Bool { True, False }
+
+fn main() {
+ let foo = Bool::False;
+ let bar = true;
+
+ if foo == Bool::False && bar {
+ println!("foobar");
+ }
+}
+"#,
+ )
+ }
+
+ #[test]
+ fn local_variable_unop_usage() {
+ check_assist(
+ convert_bool_to_enum,
+ r#"
+fn main() {
+ let $0foo = true;
+
+ if *&foo {
+ println!("foobar");
+ }
+}
+"#,
+ r#"
+#[derive(PartialEq, Eq)]
+enum Bool { True, False }
+
+fn main() {
+ let foo = Bool::True;
+
+ if *&foo == Bool::True {
+ println!("foobar");
+ }
+}
+"#,
+ )
+ }
+
+ #[test]
+ fn local_variable_assigned_later() {
+ cov_mark::check!(replaces_assignment);
+ check_assist(
+ convert_bool_to_enum,
+ r#"
+fn main() {
+ let $0foo: bool;
+ foo = true;
+}
+"#,
+ r#"
+#[derive(PartialEq, Eq)]
+enum Bool { True, False }
+
+fn main() {
+ let foo: Bool;
+ foo = Bool::True;
+}
+"#,
+ )
+ }
+
+ #[test]
+ fn local_variable_does_not_apply_recursively() {
+ check_assist(
+ convert_bool_to_enum,
+ r#"
+fn main() {
+ let $0foo = true;
+ let bar = !foo;
+
+ if bar {
+ println!("bar");
+ }
+}
+"#,
+ r#"
+#[derive(PartialEq, Eq)]
+enum Bool { True, False }
+
+fn main() {
+ let foo = Bool::True;
+ let bar = foo == Bool::False;
+
+ if bar {
+ println!("bar");
+ }
+}
+"#,
+ )
+ }
+
+ #[test]
+ fn local_variable_nested_in_negation() {
+ cov_mark::check!(dont_overwrite_expression_inside_negation);
+ check_assist(
+ convert_bool_to_enum,
+ r#"
+fn main() {
+ if !"foo".chars().any(|c| {
+ let $0foo = true;
+ foo
+ }) {
+ println!("foo");
+ }
+}
+"#,
+ r#"
+#[derive(PartialEq, Eq)]
+enum Bool { True, False }
+
+fn main() {
+ if !"foo".chars().any(|c| {
+ let foo = Bool::True;
+ foo == Bool::True
+ }) {
+ println!("foo");
+ }
+}
+"#,
+ )
+ }
+
+ #[test]
+ fn local_variable_non_bool() {
+ cov_mark::check!(not_applicable_non_bool_local);
+ check_assist_not_applicable(
+ convert_bool_to_enum,
+ r#"
+fn main() {
+ let $0foo = 1;
+}
+"#,
+ )
+ }
+
+ #[test]
+ fn local_variable_cursor_not_on_ident() {
+ check_assist_not_applicable(
+ convert_bool_to_enum,
+ r#"
+fn main() {
+ let foo = $0true;
+}
+"#,
+ )
+ }
+
+ #[test]
+ fn local_variable_non_ident_pat() {
+ check_assist_not_applicable(
+ convert_bool_to_enum,
+ r#"
+fn main() {
+ let ($0foo, bar) = (true, false);
+}
+"#,
+ )
+ }
+
+ #[test]
+ fn local_var_init_struct_usage() {
+ check_assist(
+ convert_bool_to_enum,
+ r#"
+struct Foo {
+ foo: bool,
+}
+
+fn main() {
+ let $0foo = true;
+ let s = Foo { foo };
+}
+"#,
+ r#"
+struct Foo {
+ foo: bool,
+}
+
+#[derive(PartialEq, Eq)]
+enum Bool { True, False }
+
+fn main() {
+ let foo = Bool::True;
+ let s = Foo { foo: foo == Bool::True };
+}
+"#,
+ )
+ }
+
+ #[test]
+ fn local_var_init_struct_usage_in_macro() {
+ check_assist(
+ convert_bool_to_enum,
+ r#"
+struct Struct {
+ boolean: bool,
+}
+
+macro_rules! identity {
+ ($body:expr) => {
+ $body
+ }
+}
+
+fn new() -> Struct {
+ let $0boolean = true;
+ identity![Struct { boolean }]
+}
+"#,
+ r#"
+struct Struct {
+ boolean: bool,
+}
+
+macro_rules! identity {
+ ($body:expr) => {
+ $body
+ }
+}
+
+#[derive(PartialEq, Eq)]
+enum Bool { True, False }
+
+fn new() -> Struct {
+ let boolean = Bool::True;
+ identity![Struct { boolean: boolean == Bool::True }]
+}
+"#,
+ )
+ }
+
+ #[test]
+ fn field_struct_basic() {
+ cov_mark::check!(replaces_record_expr);
+ check_assist(
+ convert_bool_to_enum,
+ r#"
+struct Foo {
+ $0bar: bool,
+ baz: bool,
+}
+
+fn main() {
+ let foo = Foo { bar: true, baz: false };
+
+ if foo.bar {
+ println!("foo");
+ }
+}
+"#,
+ r#"
+#[derive(PartialEq, Eq)]
+enum Bool { True, False }
+
+struct Foo {
+ bar: Bool,
+ baz: bool,
+}
+
+fn main() {
+ let foo = Foo { bar: Bool::True, baz: false };
+
+ if foo.bar == Bool::True {
+ println!("foo");
+ }
+}
+"#,
+ )
+ }
+
+ #[test]
+ fn field_enum_basic() {
+ cov_mark::check!(replaces_record_pat);
+ check_assist(
+ convert_bool_to_enum,
+ r#"
+enum Foo {
+ Foo,
+ Bar { $0bar: bool },
+}
+
+fn main() {
+ let foo = Foo::Bar { bar: true };
+
+ if let Foo::Bar { bar: baz } = foo {
+ if baz {
+ println!("foo");
+ }
+ }
+}
+"#,
+ r#"
+#[derive(PartialEq, Eq)]
+enum Bool { True, False }
+
+enum Foo {
+ Foo,
+ Bar { bar: Bool },
+}
+
+fn main() {
+ let foo = Foo::Bar { bar: Bool::True };
+
+ if let Foo::Bar { bar: baz } = foo {
+ if baz == Bool::True {
+ println!("foo");
+ }
+ }
+}
+"#,
+ )
+ }
+
+ #[test]
+ fn field_enum_cross_file() {
+ // FIXME: The import is missing
+ check_assist(
+ convert_bool_to_enum,
+ r#"
+//- /foo.rs
+pub enum Foo {
+ Foo,
+ Bar { $0bar: bool },
+}
+
+fn foo() {
+ let foo = Foo::Bar { bar: true };
+}
+
+//- /main.rs
+use foo::Foo;
+
+mod foo;
+
+fn main() {
+ let foo = Foo::Bar { bar: false };
+}
+"#,
+ r#"
+//- /foo.rs
+#[derive(PartialEq, Eq)]
+pub enum Bool { True, False }
+
+pub enum Foo {
+ Foo,
+ Bar { bar: Bool },
+}
+
+fn foo() {
+ let foo = Foo::Bar { bar: Bool::True };
+}
+
+//- /main.rs
+use foo::Foo;
+
+mod foo;
+
+fn main() {
+ let foo = Foo::Bar { bar: Bool::False };
+}
+"#,
+ )
+ }
+
+ #[test]
+ fn field_enum_shorthand() {
+ cov_mark::check!(replaces_record_pat_shorthand);
+ check_assist(
+ convert_bool_to_enum,
+ r#"
+enum Foo {
+ Foo,
+ Bar { $0bar: bool },
+}
+
+fn main() {
+ let foo = Foo::Bar { bar: true };
+
+ match foo {
+ Foo::Bar { bar } => {
+ if bar {
+ println!("foo");
+ }
+ }
+ _ => (),
+ }
+}
+"#,
+ r#"
+#[derive(PartialEq, Eq)]
+enum Bool { True, False }
+
+enum Foo {
+ Foo,
+ Bar { bar: Bool },
+}
+
+fn main() {
+ let foo = Foo::Bar { bar: Bool::True };
+
+ match foo {
+ Foo::Bar { bar } => {
+ if bar == Bool::True {
+ println!("foo");
+ }
+ }
+ _ => (),
+ }
+}
+"#,
+ )
+ }
+
+ #[test]
+ fn field_enum_replaces_literal_patterns() {
+ cov_mark::check!(replaces_literal_pat);
+ check_assist(
+ convert_bool_to_enum,
+ r#"
+enum Foo {
+ Foo,
+ Bar { $0bar: bool },
+}
+
+fn main() {
+ let foo = Foo::Bar { bar: true };
+
+ if let Foo::Bar { bar: true } = foo {
+ println!("foo");
+ }
+}
+"#,
+ r#"
+#[derive(PartialEq, Eq)]
+enum Bool { True, False }
+
+enum Foo {
+ Foo,
+ Bar { bar: Bool },
+}
+
+fn main() {
+ let foo = Foo::Bar { bar: Bool::True };
+
+ if let Foo::Bar { bar: Bool::True } = foo {
+ println!("foo");
+ }
+}
+"#,
+ )
+ }
+
+ #[test]
+ fn field_enum_keeps_wildcard_patterns() {
+ check_assist(
+ convert_bool_to_enum,
+ r#"
+enum Foo {
+ Foo,
+ Bar { $0bar: bool },
+}
+
+fn main() {
+ let foo = Foo::Bar { bar: true };
+
+ if let Foo::Bar { bar: _ } = foo {
+ println!("foo");
+ }
+}
+"#,
+ r#"
+#[derive(PartialEq, Eq)]
+enum Bool { True, False }
+
+enum Foo {
+ Foo,
+ Bar { bar: Bool },
+}
+
+fn main() {
+ let foo = Foo::Bar { bar: Bool::True };
+
+ if let Foo::Bar { bar: _ } = foo {
+ println!("foo");
+ }
+}
+"#,
+ )
+ }
+
+ #[test]
+ fn field_union_basic() {
+ check_assist(
+ convert_bool_to_enum,
+ r#"
+union Foo {
+ $0foo: bool,
+ bar: usize,
+}
+
+fn main() {
+ let foo = Foo { foo: true };
+
+ if unsafe { foo.foo } {
+ println!("foo");
+ }
+}
+"#,
+ r#"
+#[derive(PartialEq, Eq)]
+enum Bool { True, False }
+
+union Foo {
+ foo: Bool,
+ bar: usize,
+}
+
+fn main() {
+ let foo = Foo { foo: Bool::True };
+
+ if unsafe { foo.foo == Bool::True } {
+ println!("foo");
+ }
+}
+"#,
+ )
+ }
+
+ #[test]
+ fn field_negated() {
+ check_assist(
+ convert_bool_to_enum,
+ r#"
+struct Foo {
+ $0bar: bool,
+}
+
+fn main() {
+ let foo = Foo { bar: false };
+
+ if !foo.bar {
+ println!("foo");
+ }
+}
+"#,
+ r#"
+#[derive(PartialEq, Eq)]
+enum Bool { True, False }
+
+struct Foo {
+ bar: Bool,
+}
+
+fn main() {
+ let foo = Foo { bar: Bool::False };
+
+ if foo.bar == Bool::False {
+ println!("foo");
+ }
+}
+"#,
+ )
+ }
+
+ #[test]
+ fn field_in_mod_properly_indented() {
+ check_assist(
+ convert_bool_to_enum,
+ r#"
+mod foo {
+ struct Bar {
+ $0baz: bool,
+ }
+
+ impl Bar {
+ fn new(baz: bool) -> Self {
+ Self { baz }
+ }
+ }
+}
+"#,
+ r#"
+mod foo {
+ #[derive(PartialEq, Eq)]
+ enum Bool { True, False }
+
+ struct Bar {
+ baz: Bool,
+ }
+
+ impl Bar {
+ fn new(baz: bool) -> Self {
+ Self { baz: if baz { Bool::True } else { Bool::False } }
+ }
+ }
+}
+"#,
+ )
+ }
+
+ #[test]
+ fn field_multiple_initializations() {
+ check_assist(
+ convert_bool_to_enum,
+ r#"
+struct Foo {
+ $0bar: bool,
+ baz: bool,
+}
+
+fn main() {
+ let foo1 = Foo { bar: true, baz: false };
+ let foo2 = Foo { bar: false, baz: false };
+
+ if foo1.bar && foo2.bar {
+ println!("foo");
+ }
+}
+"#,
+ r#"
+#[derive(PartialEq, Eq)]
+enum Bool { True, False }
+
+struct Foo {
+ bar: Bool,
+ baz: bool,
+}
+
+fn main() {
+ let foo1 = Foo { bar: Bool::True, baz: false };
+ let foo2 = Foo { bar: Bool::False, baz: false };
+
+ if foo1.bar == Bool::True && foo2.bar == Bool::True {
+ println!("foo");
+ }
+}
+"#,
+ )
+ }
+
+ #[test]
+ fn field_assigned_to_another() {
+ cov_mark::check!(dont_assign_incorrect_ref);
+ check_assist(
+ convert_bool_to_enum,
+ r#"
+struct Foo {
+ $0foo: bool,
+}
+
+struct Bar {
+ bar: bool,
+}
+
+fn main() {
+ let foo = Foo { foo: true };
+ let mut bar = Bar { bar: true };
+
+ bar.bar = foo.foo;
+}
+"#,
+ r#"
+#[derive(PartialEq, Eq)]
+enum Bool { True, False }
+
+struct Foo {
+ foo: Bool,
+}
+
+struct Bar {
+ bar: bool,
+}
+
+fn main() {
+ let foo = Foo { foo: Bool::True };
+ let mut bar = Bar { bar: true };
+
+ bar.bar = foo.foo == Bool::True;
+}
+"#,
+ )
+ }
+
+ #[test]
+ fn field_initialized_with_other() {
+ check_assist(
+ convert_bool_to_enum,
+ r#"
+struct Foo {
+ $0foo: bool,
+}
+
+struct Bar {
+ bar: bool,
+}
+
+fn main() {
+ let foo = Foo { foo: true };
+ let bar = Bar { bar: foo.foo };
+}
+"#,
+ r#"
+#[derive(PartialEq, Eq)]
+enum Bool { True, False }
+
+struct Foo {
+ foo: Bool,
+}
+
+struct Bar {
+ bar: bool,
+}
+
+fn main() {
+ let foo = Foo { foo: Bool::True };
+ let bar = Bar { bar: foo.foo == Bool::True };
+}
+"#,
+ )
+ }
+
+ #[test]
+ fn field_method_chain_usage() {
+ check_assist(
+ convert_bool_to_enum,
+ r#"
+struct Foo {
+ $0bool: bool,
+}
+
+fn main() {
+ let foo = Foo { bool: true };
+
+ foo.bool.then(|| 2);
+}
+"#,
+ r#"
+#[derive(PartialEq, Eq)]
+enum Bool { True, False }
+
+struct Foo {
+ bool: Bool,
+}
+
+fn main() {
+ let foo = Foo { bool: Bool::True };
+
+ (foo.bool == Bool::True).then(|| 2);
+}
+"#,
+ )
+ }
+
+ #[test]
+ fn field_in_macro() {
+ check_assist(
+ convert_bool_to_enum,
+ r#"
+struct Struct {
+ $0boolean: bool,
+}
+
+fn boolean(x: Struct) {
+ let Struct { boolean } = x;
+}
+
+macro_rules! identity { ($body:expr) => { $body } }
+
+fn new() -> Struct {
+ identity!(Struct { boolean: true })
+}
+"#,
+ r#"
+#[derive(PartialEq, Eq)]
+enum Bool { True, False }
+
+struct Struct {
+ boolean: Bool,
+}
+
+fn boolean(x: Struct) {
+ let Struct { boolean } = x;
+}
+
+macro_rules! identity { ($body:expr) => { $body } }
+
+fn new() -> Struct {
+ identity!(Struct { boolean: Bool::True })
+}
+"#,
+ )
+ }
+
+ #[test]
+ fn field_non_bool() {
+ cov_mark::check!(not_applicable_non_bool_field);
+ check_assist_not_applicable(
+ convert_bool_to_enum,
+ r#"
+struct Foo {
+ $0bar: usize,
+}
+
+fn main() {
+ let foo = Foo { bar: 1 };
+}
+"#,
+ )
+ }
+
+ #[test]
+ fn const_basic() {
+ check_assist(
+ convert_bool_to_enum,
+ r#"
+const $0FOO: bool = false;
+
+fn main() {
+ if FOO {
+ println!("foo");
+ }
+}
+"#,
+ r#"
+#[derive(PartialEq, Eq)]
+enum Bool { True, False }
+
+const FOO: Bool = Bool::False;
+
+fn main() {
+ if FOO == Bool::True {
+ println!("foo");
+ }
+}
+"#,
+ )
+ }
+
+ #[test]
+ fn const_in_module() {
+ check_assist(
+ convert_bool_to_enum,
+ r#"
+fn main() {
+ if foo::FOO {
+ println!("foo");
+ }
+}
+
+mod foo {
+ pub const $0FOO: bool = true;
+}
+"#,
+ r#"
+use foo::Bool;
+
+fn main() {
+ if foo::FOO == Bool::True {
+ println!("foo");
+ }
+}
+
+mod foo {
+ #[derive(PartialEq, Eq)]
+ pub enum Bool { True, False }
+
+ pub const FOO: Bool = Bool::True;
+}
+"#,
+ )
+ }
+
+ #[test]
+ fn const_in_module_with_import() {
+ check_assist(
+ convert_bool_to_enum,
+ r#"
+fn main() {
+ use foo::FOO;
+
+ if FOO {
+ println!("foo");
+ }
+}
+
+mod foo {
+ pub const $0FOO: bool = true;
+}
+"#,
+ r#"
+use foo::Bool;
+
+fn main() {
+ use foo::FOO;
+
+ if FOO == Bool::True {
+ println!("foo");
+ }
+}
+
+mod foo {
+ #[derive(PartialEq, Eq)]
+ pub enum Bool { True, False }
+
+ pub const FOO: Bool = Bool::True;
+}
+"#,
+ )
+ }
+
+ #[test]
+ fn const_cross_file() {
+ check_assist(
+ convert_bool_to_enum,
+ r#"
+//- /main.rs
+mod foo;
+
+fn main() {
+ if foo::FOO {
+ println!("foo");
+ }
+}
+
+//- /foo.rs
+pub const $0FOO: bool = true;
+"#,
+ r#"
+//- /main.rs
+use foo::Bool;
+
+mod foo;
+
+fn main() {
+ if foo::FOO == Bool::True {
+ println!("foo");
+ }
+}
+
+//- /foo.rs
+#[derive(PartialEq, Eq)]
+pub enum Bool { True, False }
+
+pub const FOO: Bool = Bool::True;
+"#,
+ )
+ }
+
+ #[test]
+ fn const_cross_file_and_module() {
+ check_assist(
+ convert_bool_to_enum,
+ r#"
+//- /main.rs
+mod foo;
+
+fn main() {
+ use foo::bar;
+
+ if bar::BAR {
+ println!("foo");
+ }
+}
+
+//- /foo.rs
+pub mod bar {
+ pub const $0BAR: bool = false;
+}
+"#,
+ r#"
+//- /main.rs
+use foo::bar::Bool;
+
+mod foo;
+
+fn main() {
+ use foo::bar;
+
+ if bar::BAR == Bool::True {
+ println!("foo");
+ }
+}
+
+//- /foo.rs
+pub mod bar {
+ #[derive(PartialEq, Eq)]
+ pub enum Bool { True, False }
+
+ pub const BAR: Bool = Bool::False;
+}
+"#,
+ )
+ }
+
+ #[test]
+ fn const_in_impl_cross_file() {
+ check_assist(
+ convert_bool_to_enum,
+ r#"
+//- /main.rs
+mod foo;
+
+struct Foo;
+
+impl Foo {
+ pub const $0BOOL: bool = true;
+}
+
+//- /foo.rs
+use crate::Foo;
+
+fn foo() -> bool {
+ Foo::BOOL
+}
+"#,
+ r#"
+//- /main.rs
+mod foo;
+
+struct Foo;
+
+#[derive(PartialEq, Eq)]
+pub enum Bool { True, False }
+
+impl Foo {
+ pub const BOOL: Bool = Bool::True;
+}
+
+//- /foo.rs
+use crate::{Bool, Foo};
+
+fn foo() -> bool {
+ Foo::BOOL == Bool::True
+}
+"#,
+ )
+ }
+
+ #[test]
+ fn const_in_trait() {
+ check_assist(
+ convert_bool_to_enum,
+ r#"
+trait Foo {
+ const $0BOOL: bool;
+}
+
+impl Foo for usize {
+ const BOOL: bool = true;
+}
+
+fn main() {
+ if <usize as Foo>::BOOL {
+ println!("foo");
+ }
+}
+"#,
+ r#"
+#[derive(PartialEq, Eq)]
+enum Bool { True, False }
+
+trait Foo {
+ const BOOL: Bool;
+}
+
+impl Foo for usize {
+ const BOOL: Bool = Bool::True;
+}
+
+fn main() {
+ if <usize as Foo>::BOOL == Bool::True {
+ println!("foo");
+ }
+}
+"#,
+ )
+ }
+
+ #[test]
+ fn const_non_bool() {
+ cov_mark::check!(not_applicable_non_bool_const);
+ check_assist_not_applicable(
+ convert_bool_to_enum,
+ r#"
+const $0FOO: &str = "foo";
+
+fn main() {
+ println!("{FOO}");
+}
+"#,
+ )
+ }
+
+ #[test]
+ fn static_basic() {
+ check_assist(
+ convert_bool_to_enum,
+ r#"
+static mut $0BOOL: bool = true;
+
+fn main() {
+ unsafe { BOOL = false };
+ if unsafe { BOOL } {
+ println!("foo");
+ }
+}
+"#,
+ r#"
+#[derive(PartialEq, Eq)]
+enum Bool { True, False }
+
+static mut BOOL: Bool = Bool::True;
+
+fn main() {
+ unsafe { BOOL = Bool::False };
+ if unsafe { BOOL == Bool::True } {
+ println!("foo");
+ }
+}
+"#,
+ )
+ }
+
+ #[test]
+ fn static_non_bool() {
+ cov_mark::check!(not_applicable_non_bool_static);
+ check_assist_not_applicable(
+ convert_bool_to_enum,
+ r#"
+static mut $0FOO: usize = 0;
+
+fn main() {
+ if unsafe { FOO } == 0 {
+ println!("foo");
+ }
+}
+"#,
+ )
+ }
+
+ #[test]
+ fn not_applicable_to_other_names() {
+ check_assist_not_applicable(convert_bool_to_enum, "fn $0main() {}")
+ }
+}