Unnamed repository; edit this file 'description' to name the repository.
Diffstat (limited to 'crates/ide-assists/src/handlers/bool_to_enum.rs')
-rw-r--r--crates/ide-assists/src/handlers/bool_to_enum.rs190
1 files changed, 167 insertions, 23 deletions
diff --git a/crates/ide-assists/src/handlers/bool_to_enum.rs b/crates/ide-assists/src/handlers/bool_to_enum.rs
index c95e24693d..ab25e0167b 100644
--- a/crates/ide-assists/src/handlers/bool_to_enum.rs
+++ b/crates/ide-assists/src/handlers/bool_to_enum.rs
@@ -1,3 +1,4 @@
+use either::Either;
use hir::{ImportPathConfig, ModuleDef};
use ide_db::{
assists::{AssistId, AssistKind},
@@ -76,7 +77,11 @@ pub(crate) fn bool_to_enum(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option
let usages = definition.usages(&ctx.sema).all();
add_enum_def(edit, ctx, &usages, target_node, &target_module);
- replace_usages(edit, ctx, usages, definition, &target_module);
+ let mut delayed_mutations = Vec::new();
+ replace_usages(edit, ctx, usages, definition, &target_module, &mut delayed_mutations);
+ for (scope, path) in delayed_mutations {
+ insert_use(&scope, path, &ctx.config.insert_use);
+ }
},
)
}
@@ -91,29 +96,32 @@ struct BoolNodeData {
/// Attempts to find an appropriate node to apply the action to.
fn find_bool_node(ctx: &AssistContext<'_>) -> Option<BoolNodeData> {
- let name: ast::Name = ctx.find_node_at_offset()?;
-
- if let Some(let_stmt) = name.syntax().ancestors().find_map(ast::LetStmt::cast) {
- let bind_pat = match let_stmt.pat()? {
- ast::Pat::IdentPat(pat) => pat,
- _ => {
- cov_mark::hit!(not_applicable_in_non_ident_pat);
- return None;
- }
- };
- let def = ctx.sema.to_def(&bind_pat)?;
+ let name = ctx.find_node_at_offset::<ast::Name>()?;
+
+ if let Some(ident_pat) = name.syntax().parent().and_then(ast::IdentPat::cast) {
+ let def = ctx.sema.to_def(&ident_pat)?;
if !def.ty(ctx.db()).is_bool() {
cov_mark::hit!(not_applicable_non_bool_local);
return None;
}
- Some(BoolNodeData {
- target_node: let_stmt.syntax().clone(),
- name,
- ty_annotation: let_stmt.ty(),
- initializer: let_stmt.initializer(),
- definition: Definition::Local(def),
- })
+ let local_definition = Definition::Local(def);
+ match ident_pat.syntax().parent().and_then(Either::<ast::Param, ast::LetStmt>::cast)? {
+ Either::Left(param) => Some(BoolNodeData {
+ target_node: param.syntax().clone(),
+ name,
+ ty_annotation: param.ty(),
+ initializer: None,
+ definition: local_definition,
+ }),
+ Either::Right(let_stmt) => Some(BoolNodeData {
+ target_node: let_stmt.syntax().clone(),
+ name,
+ ty_annotation: let_stmt.ty(),
+ initializer: let_stmt.initializer(),
+ definition: local_definition,
+ }),
+ }
} else if let Some(const_) = name.syntax().parent().and_then(ast::Const::cast) {
let def = ctx.sema.to_def(&const_)?;
if !def.ty(ctx.db()).is_bool() {
@@ -197,6 +205,7 @@ fn replace_usages(
usages: UsageSearchResult,
target_definition: Definition,
target_module: &hir::Module,
+ delayed_mutations: &mut Vec<(ImportScope, ast::Path)>,
) {
for (file_id, references) in usages {
edit.edit_file(file_id);
@@ -217,6 +226,7 @@ fn replace_usages(
def.usages(&ctx.sema).all(),
target_definition,
target_module,
+ delayed_mutations,
)
}
} else if let Some(initializer) = find_assignment_usage(&name) {
@@ -255,6 +265,7 @@ fn replace_usages(
def.usages(&ctx.sema).all(),
target_definition,
target_module,
+ delayed_mutations,
)
}
}
@@ -306,7 +317,7 @@ fn replace_usages(
ImportScope::Module(it) => ImportScope::Module(edit.make_mut(it)),
ImportScope::Block(it) => ImportScope::Block(edit.make_mut(it)),
};
- insert_use(&scope, path, &ctx.config.insert_use);
+ delayed_mutations.push((scope, path));
}
},
)
@@ -329,6 +340,7 @@ fn augment_references_with_imports(
let cfg = ImportPathConfig {
prefer_no_std: ctx.config.prefer_no_std,
prefer_prelude: ctx.config.prefer_prelude,
+ prefer_absolute: ctx.config.prefer_absolute,
};
references
@@ -449,7 +461,20 @@ fn add_enum_def(
usages: &UsageSearchResult,
target_node: SyntaxNode,
target_module: &hir::Module,
-) {
+) -> Option<()> {
+ let insert_before = node_to_insert_before(target_node);
+
+ if ctx
+ .sema
+ .scope(&insert_before)?
+ .module()
+ .scope(ctx.db(), Some(*target_module))
+ .iter()
+ .any(|(name, _)| name.as_str() == Some("Bool"))
+ {
+ return None;
+ }
+
let make_enum_pub = usages
.iter()
.flat_map(|(_, refs)| refs)
@@ -460,7 +485,6 @@ fn add_enum_def(
.any(|module| module.nearest_non_block_module(ctx.db()) != *target_module);
let enum_def = make_bool_enum(make_enum_pub);
- let insert_before = node_to_insert_before(target_node);
let indent = IndentLevel::from_node(&insert_before);
enum_def.reindent_to(indent);
@@ -468,6 +492,8 @@ fn add_enum_def(
insert_before.text_range().start(),
format!("{}\n\n{indent}", enum_def.syntax().text()),
);
+
+ Some(())
}
/// Finds where to put the new enum definition.
@@ -518,6 +544,125 @@ mod tests {
use crate::tests::{check_assist, check_assist_not_applicable};
#[test]
+ fn parameter_with_first_param_usage() {
+ check_assist(
+ bool_to_enum,
+ r#"
+fn function($0foo: bool, bar: bool) {
+ if foo {
+ println!("foo");
+ }
+}
+"#,
+ r#"
+#[derive(PartialEq, Eq)]
+enum Bool { True, False }
+
+fn function(foo: Bool, bar: bool) {
+ if foo == Bool::True {
+ println!("foo");
+ }
+}
+"#,
+ )
+ }
+
+ #[test]
+ fn no_duplicate_enums() {
+ check_assist(
+ bool_to_enum,
+ r#"
+#[derive(PartialEq, Eq)]
+enum Bool { True, False }
+
+fn function(foo: bool, $0bar: bool) {
+ if bar {
+ println!("bar");
+ }
+}
+"#,
+ r#"
+#[derive(PartialEq, Eq)]
+enum Bool { True, False }
+
+fn function(foo: bool, bar: Bool) {
+ if bar == Bool::True {
+ println!("bar");
+ }
+}
+"#,
+ )
+ }
+
+ #[test]
+ fn parameter_with_last_param_usage() {
+ check_assist(
+ bool_to_enum,
+ r#"
+fn function(foo: bool, $0bar: bool) {
+ if bar {
+ println!("bar");
+ }
+}
+"#,
+ r#"
+#[derive(PartialEq, Eq)]
+enum Bool { True, False }
+
+fn function(foo: bool, bar: Bool) {
+ if bar == Bool::True {
+ println!("bar");
+ }
+}
+"#,
+ )
+ }
+
+ #[test]
+ fn parameter_with_middle_param_usage() {
+ check_assist(
+ bool_to_enum,
+ r#"
+fn function(foo: bool, $0bar: bool, baz: bool) {
+ if bar {
+ println!("bar");
+ }
+}
+"#,
+ r#"
+#[derive(PartialEq, Eq)]
+enum Bool { True, False }
+
+fn function(foo: bool, bar: Bool, baz: bool) {
+ if bar == Bool::True {
+ println!("bar");
+ }
+}
+"#,
+ )
+ }
+
+ #[test]
+ fn parameter_with_closure_usage() {
+ check_assist(
+ bool_to_enum,
+ r#"
+fn main() {
+ let foo = |$0bar: bool| bar;
+}
+"#,
+ r#"
+#[derive(PartialEq, Eq)]
+enum Bool { True, False }
+
+fn main() {
+ let foo = |bar: Bool| bar == Bool::True;
+}
+"#,
+ )
+ }
+
+ #[test]
fn local_variable_with_usage() {
check_assist(
bool_to_enum,
@@ -784,7 +929,6 @@ fn main() {
#[test]
fn local_variable_non_ident_pat() {
- cov_mark::check!(not_applicable_in_non_ident_pat);
check_assist_not_applicable(
bool_to_enum,
r#"