Unnamed repository; edit this file 'description' to name the repository.
feat: add support for other ADT types and destructuring patterns
| -rw-r--r-- | crates/ide-assists/src/handlers/bool_to_enum.rs | 494 |
1 files changed, 430 insertions, 64 deletions
diff --git a/crates/ide-assists/src/handlers/bool_to_enum.rs b/crates/ide-assists/src/handlers/bool_to_enum.rs index 784a0d3559..b9dbd6e98f 100644 --- a/crates/ide-assists/src/handlers/bool_to_enum.rs +++ b/crates/ide-assists/src/handlers/bool_to_enum.rs @@ -6,6 +6,7 @@ use ide_db::{ imports::insert_use::{insert_use, ImportScope}, search::{FileReference, UsageSearchResult}, source_change::SourceChangeBuilder, + FxHashSet, }; use itertools::Itertools; use syntax::{ @@ -17,6 +18,7 @@ use syntax::{ }, ted, AstNode, NodeOrToken, SyntaxNode, T, }; +use text_edit::TextRange; use crate::assist_context::{AssistContext, Assists}; @@ -52,7 +54,7 @@ use crate::assist_context::{AssistContext, Assists}; pub(crate) fn bool_to_enum(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { let BoolNodeData { target_node, name, ty_annotation, initializer, definition } = find_bool_node(ctx)?; - let target_module = ctx.sema.scope(&target_node)?.module(); + let target_module = ctx.sema.scope(&target_node)?.module().nearest_non_block_module(ctx.db()); let target = name.syntax().text_range(); acc.add( @@ -70,9 +72,8 @@ pub(crate) fn bool_to_enum(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option } let usages = definition.usages(&ctx.sema).all(); - add_enum_def(edit, ctx, &usages, target_node, &target_module); - replace_usages(edit, ctx, &usages, &target_module); + replace_usages(edit, ctx, &usages, definition, &target_module); }, ) } @@ -144,14 +145,14 @@ fn find_bool_node(ctx: &AssistContext<'_>) -> Option<BoolNodeData> { return None; } - let strukt = field.syntax().ancestors().find_map(ast::Struct::cast)?; + let adt = field.syntax().ancestors().find_map(ast::Adt::cast)?; let def = ctx.sema.to_def(&field)?; if !def.ty(ctx.db()).is_bool() { cov_mark::hit!(not_applicable_non_bool_field); return None; } Some(BoolNodeData { - target_node: strukt.syntax().clone(), + target_node: adt.syntax().clone(), name, ty_annotation: field.ty(), initializer: None, @@ -192,78 +193,171 @@ fn replace_usages( edit: &mut SourceChangeBuilder, ctx: &AssistContext<'_>, usages: &UsageSearchResult, + target_definition: Definition, target_module: &hir::Module, ) { for (file_id, references) in usages.iter() { edit.edit_file(*file_id); - // add imports across modules where needed - references - .iter() - .filter_map(|FileReference { name, .. }| { - ctx.sema.scope(name.syntax()).map(|scope| (name, scope.module())) - }) - .unique_by(|name_and_module| name_and_module.1) - .filter(|(_, module)| module != target_module) - .filter_map(|(name, module)| { - let import_scope = ImportScope::find_insert_use_container(name.syntax(), &ctx.sema); - let mod_path = module.find_use_path_prefixed( - ctx.sema.db, - ModuleDef::Module(*target_module), - ctx.config.insert_use.prefix_kind, - ctx.config.prefer_no_std, - ); - import_scope.zip(mod_path) - }) - .for_each(|(import_scope, mod_path)| { - let import_scope = match import_scope { - ImportScope::File(it) => ImportScope::File(edit.make_mut(it)), - ImportScope::Module(it) => ImportScope::Module(edit.make_mut(it)), - ImportScope::Block(it) => ImportScope::Block(edit.make_mut(it)), - }; - let path = - make::path_concat(mod_path_to_ast(&mod_path), make::path_from_text("Bool")); - insert_use(&import_scope, path, &ctx.config.insert_use); - }); - - // replace the usages in expressions - references - .into_iter() - .filter_map(|FileReference { range, name, .. }| match name { - ast::NameLike::NameRef(name) => Some((*range, name)), - _ => None, - }) - .rev() - .for_each(|(range, name_ref)| { - if let Some(initializer) = find_assignment_usage(name_ref) { + let refs_with_imports = + augment_references_with_imports(edit, ctx, references, target_module); + + refs_with_imports.into_iter().rev().for_each( + |FileReferenceWithImport { range, old_name, new_name, import_data }| { + // replace the usages in patterns and expressions + if let Some(ident_pat) = old_name.syntax().ancestors().find_map(ast::IdentPat::cast) + { + cov_mark::hit!(replaces_record_pat_shorthand); + + let definition = ctx.sema.to_def(&ident_pat).map(Definition::Local); + if let Some(def) = definition { + replace_usages( + edit, + ctx, + &def.usages(&ctx.sema).all(), + target_definition, + target_module, + ) + } + } else if let Some(initializer) = find_assignment_usage(&new_name) { cov_mark::hit!(replaces_assignment); replace_bool_expr(edit, initializer); - } else if let Some((prefix_expr, inner_expr)) = find_negated_usage(name_ref) { + } else if let Some((prefix_expr, inner_expr)) = find_negated_usage(&new_name) { cov_mark::hit!(replaces_negation); edit.replace( prefix_expr.syntax().text_range(), format!("{} == Bool::False", inner_expr), ); - } else if let Some((record_field, initializer)) = find_record_expr_usage(name_ref) { + } else if let Some((record_field, initializer)) = old_name + .as_name_ref() + .and_then(ast::RecordExprField::for_field_name) + .and_then(|record_field| ctx.sema.resolve_record_field(&record_field)) + .and_then(|(got_field, _, _)| { + find_record_expr_usage(&new_name, got_field, target_definition) + }) + { cov_mark::hit!(replaces_record_expr); let record_field = edit.make_mut(record_field); let enum_expr = bool_expr_to_enum_expr(initializer); record_field.replace_expr(enum_expr); - } else if name_ref.syntax().ancestors().find_map(ast::UseTree::cast).is_none() { + } else if let Some(pat) = find_record_pat_field_usage(&old_name) { + match pat { + ast::Pat::IdentPat(ident_pat) => { + cov_mark::hit!(replaces_record_pat); + + let definition = ctx.sema.to_def(&ident_pat).map(Definition::Local); + if let Some(def) = definition { + replace_usages( + edit, + ctx, + &def.usages(&ctx.sema).all(), + target_definition, + target_module, + ) + } + } + ast::Pat::LiteralPat(literal_pat) => { + cov_mark::hit!(replaces_literal_pat); + + if let Some(expr) = literal_pat.literal().and_then(|literal| { + literal.syntax().ancestors().find_map(ast::Expr::cast) + }) { + replace_bool_expr(edit, expr); + } + } + _ => (), + } + } else if new_name.syntax().ancestors().find_map(ast::UseTree::cast).is_none() { // for any other usage in an expression, replace it with a check that it is the true variant - edit.replace(range, format!("{} == Bool::True", name_ref.text())); + if let Some((record_field, expr)) = new_name + .as_name_ref() + .and_then(ast::RecordExprField::for_field_name) + .and_then(|record_field| { + record_field.expr().map(|expr| (record_field, expr)) + }) + { + record_field.replace_expr( + make::expr_bin_op( + expr, + ast::BinaryOp::CmpOp(ast::CmpOp::Eq { negated: false }), + make::expr_path(make::path_from_text("Bool::True")), + ) + .clone_for_update(), + ); + } else { + edit.replace(range, format!("{} == Bool::True", new_name.text())); + } + } + + // add imports across modules where needed + if let Some((import_scope, path)) = import_data { + insert_use(&import_scope, path, &ctx.config.insert_use); } - }) + }, + ) } } -fn find_assignment_usage(name_ref: &ast::NameRef) -> Option<ast::Expr> { - let bin_expr = name_ref.syntax().ancestors().find_map(ast::BinExpr::cast)?; +struct FileReferenceWithImport { + range: TextRange, + old_name: ast::NameLike, + new_name: ast::NameLike, + import_data: Option<(ImportScope, ast::Path)>, +} - if !bin_expr.lhs()?.syntax().descendants().contains(name_ref.syntax()) { +fn augment_references_with_imports( + edit: &mut SourceChangeBuilder, + ctx: &AssistContext<'_>, + references: &[FileReference], + target_module: &hir::Module, +) -> Vec<FileReferenceWithImport> { + let mut visited_modules = FxHashSet::default(); + + references + .iter() + .filter_map(|FileReference { range, name, .. }| { + ctx.sema.scope(name.syntax()).map(|scope| (*range, name, scope.module())) + }) + .map(|(range, name, ref_module)| { + let old_name = name.clone(); + let new_name = edit.make_mut(name.clone()); + + // if the referenced module is not the same as the target one and has not been seen before, add an import + let import_data = if ref_module.nearest_non_block_module(ctx.db()) != *target_module + && !visited_modules.contains(&ref_module) + { + visited_modules.insert(ref_module); + + let import_scope = + ImportScope::find_insert_use_container(new_name.syntax(), &ctx.sema); + let path = ref_module + .find_use_path_prefixed( + ctx.sema.db, + ModuleDef::Module(*target_module), + ctx.config.insert_use.prefix_kind, + ctx.config.prefer_no_std, + ) + .map(|mod_path| { + make::path_concat(mod_path_to_ast(&mod_path), make::path_from_text("Bool")) + }); + + import_scope.zip(path) + } else { + None + }; + + FileReferenceWithImport { range, old_name, new_name, import_data } + }) + .collect() +} + +fn find_assignment_usage(name: &ast::NameLike) -> Option<ast::Expr> { + let bin_expr = name.syntax().ancestors().find_map(ast::BinExpr::cast)?; + + if !bin_expr.lhs()?.syntax().descendants().contains(name.syntax()) { cov_mark::hit!(dont_assign_incorrect_ref); return None; } @@ -275,8 +369,8 @@ fn find_assignment_usage(name_ref: &ast::NameRef) -> Option<ast::Expr> { } } -fn find_negated_usage(name_ref: &ast::NameRef) -> Option<(ast::PrefixExpr, ast::Expr)> { - let prefix_expr = name_ref.syntax().ancestors().find_map(ast::PrefixExpr::cast)?; +fn find_negated_usage(name: &ast::NameLike) -> Option<(ast::PrefixExpr, ast::Expr)> { + let prefix_expr = name.syntax().ancestors().find_map(ast::PrefixExpr::cast)?; if !matches!(prefix_expr.expr()?, ast::Expr::PathExpr(_) | ast::Expr::FieldExpr(_)) { cov_mark::hit!(dont_overwrite_expression_inside_negation); @@ -291,15 +385,31 @@ fn find_negated_usage(name_ref: &ast::NameRef) -> Option<(ast::PrefixExpr, ast:: } } -fn find_record_expr_usage(name_ref: &ast::NameRef) -> Option<(ast::RecordExprField, ast::Expr)> { - let record_field = name_ref.syntax().ancestors().find_map(ast::RecordExprField::cast)?; +fn find_record_expr_usage( + name: &ast::NameLike, + got_field: hir::Field, + target_definition: Definition, +) -> Option<(ast::RecordExprField, ast::Expr)> { + let name_ref = name.as_name_ref()?; + let record_field = ast::RecordExprField::for_field_name(name_ref)?; let initializer = record_field.expr()?; - if record_field.field_name()?.syntax().descendants().contains(name_ref.syntax()) { - Some((record_field, initializer)) - } else { - cov_mark::hit!(dont_overwrite_wrong_record_field); - None + if let Definition::Field(expected_field) = target_definition { + if got_field != expected_field { + return None; + } + } + + Some((record_field, initializer)) +} + +fn find_record_pat_field_usage(name: &ast::NameLike) -> Option<ast::Pat> { + let record_pat_field = name.syntax().parent().and_then(ast::RecordPatField::cast)?; + let pat = record_pat_field.pat()?; + + match pat { + ast::Pat::IdentPat(_) | ast::Pat::LiteralPat(_) | ast::Pat::WildcardPat(_) => Some(pat), + _ => None, } } @@ -317,7 +427,7 @@ fn add_enum_def( .filter_map(|FileReference { name, .. }| { ctx.sema.scope(name.syntax()).map(|scope| scope.module()) }) - .any(|module| &module != target_module); + .any(|module| module.nearest_non_block_module(ctx.db()) != *target_module); let enum_def = make_bool_enum(make_enum_pub); let indent = IndentLevel::from_node(&target_node); @@ -646,7 +756,7 @@ fn main() { } #[test] - fn field_basic() { + fn field_struct_basic() { cov_mark::check!(replaces_record_expr); check_assist( bool_to_enum, @@ -685,6 +795,263 @@ fn main() { } #[test] + fn field_enum_basic() { + cov_mark::check!(replaces_record_pat); + check_assist( + bool_to_enum, + r#" +enum Foo { + Foo, + Bar { $0bar: bool }, +} + +fn main() { + let foo = Foo::Bar { bar: true }; + + if let Foo::Bar { bar: baz } = foo { + if baz { + println!("foo"); + } + } +} +"#, + r#" +#[derive(PartialEq, Eq)] +enum Bool { True, False } + +enum Foo { + Foo, + Bar { bar: Bool }, +} + +fn main() { + let foo = Foo::Bar { bar: Bool::True }; + + if let Foo::Bar { bar: baz } = foo { + if baz == Bool::True { + println!("foo"); + } + } +} +"#, + ) + } + + #[test] + fn field_enum_cross_file() { + check_assist( + bool_to_enum, + r#" +//- /foo.rs +pub enum Foo { + Foo, + Bar { $0bar: bool }, +} + +fn foo() { + let foo = Foo::Bar { bar: true }; +} + +//- /main.rs +use foo::Foo; + +mod foo; + +fn main() { + let foo = Foo::Bar { bar: false }; +} +"#, + r#" +//- /foo.rs +#[derive(PartialEq, Eq)] +pub enum Bool { True, False } + +pub enum Foo { + Foo, + Bar { bar: Bool }, +} + +fn foo() { + let foo = Foo::Bar { bar: Bool::True }; +} + +//- /main.rs +use foo::{Foo, Bool}; + +mod foo; + +fn main() { + let foo = Foo::Bar { bar: Bool::False }; +} +"#, + ) + } + + #[test] + fn field_enum_shorthand() { + cov_mark::check!(replaces_record_pat_shorthand); + check_assist( + bool_to_enum, + r#" +enum Foo { + Foo, + Bar { $0bar: bool }, +} + +fn main() { + let foo = Foo::Bar { bar: true }; + + match foo { + Foo::Bar { bar } => { + if bar { + println!("foo"); + } + } + _ => (), + } +} +"#, + r#" +#[derive(PartialEq, Eq)] +enum Bool { True, False } + +enum Foo { + Foo, + Bar { bar: Bool }, +} + +fn main() { + let foo = Foo::Bar { bar: Bool::True }; + + match foo { + Foo::Bar { bar } => { + if bar == Bool::True { + println!("foo"); + } + } + _ => (), + } +} +"#, + ) + } + + #[test] + fn field_enum_replaces_literal_patterns() { + cov_mark::check!(replaces_literal_pat); + check_assist( + bool_to_enum, + r#" +enum Foo { + Foo, + Bar { $0bar: bool }, +} + +fn main() { + let foo = Foo::Bar { bar: true }; + + if let Foo::Bar { bar: true } = foo { + println!("foo"); + } +} +"#, + r#" +#[derive(PartialEq, Eq)] +enum Bool { True, False } + +enum Foo { + Foo, + Bar { bar: Bool }, +} + +fn main() { + let foo = Foo::Bar { bar: Bool::True }; + + if let Foo::Bar { bar: Bool::True } = foo { + println!("foo"); + } +} +"#, + ) + } + + #[test] + fn field_enum_keeps_wildcard_patterns() { + check_assist( + bool_to_enum, + r#" +enum Foo { + Foo, + Bar { $0bar: bool }, +} + +fn main() { + let foo = Foo::Bar { bar: true }; + + if let Foo::Bar { bar: _ } = foo { + println!("foo"); + } +} +"#, + r#" +#[derive(PartialEq, Eq)] +enum Bool { True, False } + +enum Foo { + Foo, + Bar { bar: Bool }, +} + +fn main() { + let foo = Foo::Bar { bar: Bool::True }; + + if let Foo::Bar { bar: _ } = foo { + println!("foo"); + } +} +"#, + ) + } + + #[test] + fn field_union_basic() { + check_assist( + bool_to_enum, + r#" +union Foo { + $0foo: bool, + bar: usize, +} + +fn main() { + let foo = Foo { foo: true }; + + if unsafe { foo.foo } { + println!("foo"); + } +} +"#, + r#" +#[derive(PartialEq, Eq)] +enum Bool { True, False } + +union Foo { + foo: Bool, + bar: usize, +} + +fn main() { + let foo = Foo { foo: Bool::True }; + + if unsafe { foo.foo == Bool::True } { + println!("foo"); + } +} +"#, + ) + } + + #[test] fn field_negated() { check_assist( bool_to_enum, @@ -841,7 +1208,6 @@ fn main() { #[test] fn field_initialized_with_other() { - cov_mark::check!(dont_overwrite_wrong_record_field); check_assist( bool_to_enum, r#" |