Unnamed repository; edit this file 'description' to name the repository.
Diffstat (limited to 'crates/ide-assists/src/handlers/bool_to_enum.rs')
| -rw-r--r-- | crates/ide-assists/src/handlers/bool_to_enum.rs | 190 |
1 files changed, 167 insertions, 23 deletions
diff --git a/crates/ide-assists/src/handlers/bool_to_enum.rs b/crates/ide-assists/src/handlers/bool_to_enum.rs index c95e24693d..ab25e0167b 100644 --- a/crates/ide-assists/src/handlers/bool_to_enum.rs +++ b/crates/ide-assists/src/handlers/bool_to_enum.rs @@ -1,3 +1,4 @@ +use either::Either; use hir::{ImportPathConfig, ModuleDef}; use ide_db::{ assists::{AssistId, AssistKind}, @@ -76,7 +77,11 @@ pub(crate) fn bool_to_enum(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option let usages = definition.usages(&ctx.sema).all(); add_enum_def(edit, ctx, &usages, target_node, &target_module); - replace_usages(edit, ctx, usages, definition, &target_module); + let mut delayed_mutations = Vec::new(); + replace_usages(edit, ctx, usages, definition, &target_module, &mut delayed_mutations); + for (scope, path) in delayed_mutations { + insert_use(&scope, path, &ctx.config.insert_use); + } }, ) } @@ -91,29 +96,32 @@ struct BoolNodeData { /// Attempts to find an appropriate node to apply the action to. fn find_bool_node(ctx: &AssistContext<'_>) -> Option<BoolNodeData> { - let name: ast::Name = ctx.find_node_at_offset()?; - - if let Some(let_stmt) = name.syntax().ancestors().find_map(ast::LetStmt::cast) { - let bind_pat = match let_stmt.pat()? { - ast::Pat::IdentPat(pat) => pat, - _ => { - cov_mark::hit!(not_applicable_in_non_ident_pat); - return None; - } - }; - let def = ctx.sema.to_def(&bind_pat)?; + let name = ctx.find_node_at_offset::<ast::Name>()?; + + if let Some(ident_pat) = name.syntax().parent().and_then(ast::IdentPat::cast) { + let def = ctx.sema.to_def(&ident_pat)?; if !def.ty(ctx.db()).is_bool() { cov_mark::hit!(not_applicable_non_bool_local); return None; } - Some(BoolNodeData { - target_node: let_stmt.syntax().clone(), - name, - ty_annotation: let_stmt.ty(), - initializer: let_stmt.initializer(), - definition: Definition::Local(def), - }) + let local_definition = Definition::Local(def); + match ident_pat.syntax().parent().and_then(Either::<ast::Param, ast::LetStmt>::cast)? { + Either::Left(param) => Some(BoolNodeData { + target_node: param.syntax().clone(), + name, + ty_annotation: param.ty(), + initializer: None, + definition: local_definition, + }), + Either::Right(let_stmt) => Some(BoolNodeData { + target_node: let_stmt.syntax().clone(), + name, + ty_annotation: let_stmt.ty(), + initializer: let_stmt.initializer(), + definition: local_definition, + }), + } } else if let Some(const_) = name.syntax().parent().and_then(ast::Const::cast) { let def = ctx.sema.to_def(&const_)?; if !def.ty(ctx.db()).is_bool() { @@ -197,6 +205,7 @@ fn replace_usages( usages: UsageSearchResult, target_definition: Definition, target_module: &hir::Module, + delayed_mutations: &mut Vec<(ImportScope, ast::Path)>, ) { for (file_id, references) in usages { edit.edit_file(file_id); @@ -217,6 +226,7 @@ fn replace_usages( def.usages(&ctx.sema).all(), target_definition, target_module, + delayed_mutations, ) } } else if let Some(initializer) = find_assignment_usage(&name) { @@ -255,6 +265,7 @@ fn replace_usages( def.usages(&ctx.sema).all(), target_definition, target_module, + delayed_mutations, ) } } @@ -306,7 +317,7 @@ fn replace_usages( ImportScope::Module(it) => ImportScope::Module(edit.make_mut(it)), ImportScope::Block(it) => ImportScope::Block(edit.make_mut(it)), }; - insert_use(&scope, path, &ctx.config.insert_use); + delayed_mutations.push((scope, path)); } }, ) @@ -329,6 +340,7 @@ fn augment_references_with_imports( let cfg = ImportPathConfig { prefer_no_std: ctx.config.prefer_no_std, prefer_prelude: ctx.config.prefer_prelude, + prefer_absolute: ctx.config.prefer_absolute, }; references @@ -449,7 +461,20 @@ fn add_enum_def( usages: &UsageSearchResult, target_node: SyntaxNode, target_module: &hir::Module, -) { +) -> Option<()> { + let insert_before = node_to_insert_before(target_node); + + if ctx + .sema + .scope(&insert_before)? + .module() + .scope(ctx.db(), Some(*target_module)) + .iter() + .any(|(name, _)| name.as_str() == Some("Bool")) + { + return None; + } + let make_enum_pub = usages .iter() .flat_map(|(_, refs)| refs) @@ -460,7 +485,6 @@ fn add_enum_def( .any(|module| module.nearest_non_block_module(ctx.db()) != *target_module); let enum_def = make_bool_enum(make_enum_pub); - let insert_before = node_to_insert_before(target_node); let indent = IndentLevel::from_node(&insert_before); enum_def.reindent_to(indent); @@ -468,6 +492,8 @@ fn add_enum_def( insert_before.text_range().start(), format!("{}\n\n{indent}", enum_def.syntax().text()), ); + + Some(()) } /// Finds where to put the new enum definition. @@ -518,6 +544,125 @@ mod tests { use crate::tests::{check_assist, check_assist_not_applicable}; #[test] + fn parameter_with_first_param_usage() { + check_assist( + bool_to_enum, + r#" +fn function($0foo: bool, bar: bool) { + if foo { + println!("foo"); + } +} +"#, + r#" +#[derive(PartialEq, Eq)] +enum Bool { True, False } + +fn function(foo: Bool, bar: bool) { + if foo == Bool::True { + println!("foo"); + } +} +"#, + ) + } + + #[test] + fn no_duplicate_enums() { + check_assist( + bool_to_enum, + r#" +#[derive(PartialEq, Eq)] +enum Bool { True, False } + +fn function(foo: bool, $0bar: bool) { + if bar { + println!("bar"); + } +} +"#, + r#" +#[derive(PartialEq, Eq)] +enum Bool { True, False } + +fn function(foo: bool, bar: Bool) { + if bar == Bool::True { + println!("bar"); + } +} +"#, + ) + } + + #[test] + fn parameter_with_last_param_usage() { + check_assist( + bool_to_enum, + r#" +fn function(foo: bool, $0bar: bool) { + if bar { + println!("bar"); + } +} +"#, + r#" +#[derive(PartialEq, Eq)] +enum Bool { True, False } + +fn function(foo: bool, bar: Bool) { + if bar == Bool::True { + println!("bar"); + } +} +"#, + ) + } + + #[test] + fn parameter_with_middle_param_usage() { + check_assist( + bool_to_enum, + r#" +fn function(foo: bool, $0bar: bool, baz: bool) { + if bar { + println!("bar"); + } +} +"#, + r#" +#[derive(PartialEq, Eq)] +enum Bool { True, False } + +fn function(foo: bool, bar: Bool, baz: bool) { + if bar == Bool::True { + println!("bar"); + } +} +"#, + ) + } + + #[test] + fn parameter_with_closure_usage() { + check_assist( + bool_to_enum, + r#" +fn main() { + let foo = |$0bar: bool| bar; +} +"#, + r#" +#[derive(PartialEq, Eq)] +enum Bool { True, False } + +fn main() { + let foo = |bar: Bool| bar == Bool::True; +} +"#, + ) + } + + #[test] fn local_variable_with_usage() { check_assist( bool_to_enum, @@ -784,7 +929,6 @@ fn main() { #[test] fn local_variable_non_ident_pat() { - cov_mark::check!(not_applicable_in_non_ident_pat); check_assist_not_applicable( bool_to_enum, r#" |