Unnamed repository; edit this file 'description' to name the repository.
feat: add support for other ADT types and destructuring patterns
Ryan Mehri 2023-09-11
parent 7ba2e13 · commit 25b1b3e
-rw-r--r--crates/ide-assists/src/handlers/bool_to_enum.rs494
1 files changed, 430 insertions, 64 deletions
diff --git a/crates/ide-assists/src/handlers/bool_to_enum.rs b/crates/ide-assists/src/handlers/bool_to_enum.rs
index 784a0d3559..b9dbd6e98f 100644
--- a/crates/ide-assists/src/handlers/bool_to_enum.rs
+++ b/crates/ide-assists/src/handlers/bool_to_enum.rs
@@ -6,6 +6,7 @@ use ide_db::{
imports::insert_use::{insert_use, ImportScope},
search::{FileReference, UsageSearchResult},
source_change::SourceChangeBuilder,
+ FxHashSet,
};
use itertools::Itertools;
use syntax::{
@@ -17,6 +18,7 @@ use syntax::{
},
ted, AstNode, NodeOrToken, SyntaxNode, T,
};
+use text_edit::TextRange;
use crate::assist_context::{AssistContext, Assists};
@@ -52,7 +54,7 @@ use crate::assist_context::{AssistContext, Assists};
pub(crate) fn bool_to_enum(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> {
let BoolNodeData { target_node, name, ty_annotation, initializer, definition } =
find_bool_node(ctx)?;
- let target_module = ctx.sema.scope(&target_node)?.module();
+ let target_module = ctx.sema.scope(&target_node)?.module().nearest_non_block_module(ctx.db());
let target = name.syntax().text_range();
acc.add(
@@ -70,9 +72,8 @@ pub(crate) fn bool_to_enum(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option
}
let usages = definition.usages(&ctx.sema).all();
-
add_enum_def(edit, ctx, &usages, target_node, &target_module);
- replace_usages(edit, ctx, &usages, &target_module);
+ replace_usages(edit, ctx, &usages, definition, &target_module);
},
)
}
@@ -144,14 +145,14 @@ fn find_bool_node(ctx: &AssistContext<'_>) -> Option<BoolNodeData> {
return None;
}
- let strukt = field.syntax().ancestors().find_map(ast::Struct::cast)?;
+ let adt = field.syntax().ancestors().find_map(ast::Adt::cast)?;
let def = ctx.sema.to_def(&field)?;
if !def.ty(ctx.db()).is_bool() {
cov_mark::hit!(not_applicable_non_bool_field);
return None;
}
Some(BoolNodeData {
- target_node: strukt.syntax().clone(),
+ target_node: adt.syntax().clone(),
name,
ty_annotation: field.ty(),
initializer: None,
@@ -192,78 +193,171 @@ fn replace_usages(
edit: &mut SourceChangeBuilder,
ctx: &AssistContext<'_>,
usages: &UsageSearchResult,
+ target_definition: Definition,
target_module: &hir::Module,
) {
for (file_id, references) in usages.iter() {
edit.edit_file(*file_id);
- // add imports across modules where needed
- references
- .iter()
- .filter_map(|FileReference { name, .. }| {
- ctx.sema.scope(name.syntax()).map(|scope| (name, scope.module()))
- })
- .unique_by(|name_and_module| name_and_module.1)
- .filter(|(_, module)| module != target_module)
- .filter_map(|(name, module)| {
- let import_scope = ImportScope::find_insert_use_container(name.syntax(), &ctx.sema);
- let mod_path = module.find_use_path_prefixed(
- ctx.sema.db,
- ModuleDef::Module(*target_module),
- ctx.config.insert_use.prefix_kind,
- ctx.config.prefer_no_std,
- );
- import_scope.zip(mod_path)
- })
- .for_each(|(import_scope, mod_path)| {
- let import_scope = match import_scope {
- ImportScope::File(it) => ImportScope::File(edit.make_mut(it)),
- ImportScope::Module(it) => ImportScope::Module(edit.make_mut(it)),
- ImportScope::Block(it) => ImportScope::Block(edit.make_mut(it)),
- };
- let path =
- make::path_concat(mod_path_to_ast(&mod_path), make::path_from_text("Bool"));
- insert_use(&import_scope, path, &ctx.config.insert_use);
- });
-
- // replace the usages in expressions
- references
- .into_iter()
- .filter_map(|FileReference { range, name, .. }| match name {
- ast::NameLike::NameRef(name) => Some((*range, name)),
- _ => None,
- })
- .rev()
- .for_each(|(range, name_ref)| {
- if let Some(initializer) = find_assignment_usage(name_ref) {
+ let refs_with_imports =
+ augment_references_with_imports(edit, ctx, references, target_module);
+
+ refs_with_imports.into_iter().rev().for_each(
+ |FileReferenceWithImport { range, old_name, new_name, import_data }| {
+ // replace the usages in patterns and expressions
+ if let Some(ident_pat) = old_name.syntax().ancestors().find_map(ast::IdentPat::cast)
+ {
+ cov_mark::hit!(replaces_record_pat_shorthand);
+
+ let definition = ctx.sema.to_def(&ident_pat).map(Definition::Local);
+ if let Some(def) = definition {
+ replace_usages(
+ edit,
+ ctx,
+ &def.usages(&ctx.sema).all(),
+ target_definition,
+ target_module,
+ )
+ }
+ } else if let Some(initializer) = find_assignment_usage(&new_name) {
cov_mark::hit!(replaces_assignment);
replace_bool_expr(edit, initializer);
- } else if let Some((prefix_expr, inner_expr)) = find_negated_usage(name_ref) {
+ } else if let Some((prefix_expr, inner_expr)) = find_negated_usage(&new_name) {
cov_mark::hit!(replaces_negation);
edit.replace(
prefix_expr.syntax().text_range(),
format!("{} == Bool::False", inner_expr),
);
- } else if let Some((record_field, initializer)) = find_record_expr_usage(name_ref) {
+ } else if let Some((record_field, initializer)) = old_name
+ .as_name_ref()
+ .and_then(ast::RecordExprField::for_field_name)
+ .and_then(|record_field| ctx.sema.resolve_record_field(&record_field))
+ .and_then(|(got_field, _, _)| {
+ find_record_expr_usage(&new_name, got_field, target_definition)
+ })
+ {
cov_mark::hit!(replaces_record_expr);
let record_field = edit.make_mut(record_field);
let enum_expr = bool_expr_to_enum_expr(initializer);
record_field.replace_expr(enum_expr);
- } else if name_ref.syntax().ancestors().find_map(ast::UseTree::cast).is_none() {
+ } else if let Some(pat) = find_record_pat_field_usage(&old_name) {
+ match pat {
+ ast::Pat::IdentPat(ident_pat) => {
+ cov_mark::hit!(replaces_record_pat);
+
+ let definition = ctx.sema.to_def(&ident_pat).map(Definition::Local);
+ if let Some(def) = definition {
+ replace_usages(
+ edit,
+ ctx,
+ &def.usages(&ctx.sema).all(),
+ target_definition,
+ target_module,
+ )
+ }
+ }
+ ast::Pat::LiteralPat(literal_pat) => {
+ cov_mark::hit!(replaces_literal_pat);
+
+ if let Some(expr) = literal_pat.literal().and_then(|literal| {
+ literal.syntax().ancestors().find_map(ast::Expr::cast)
+ }) {
+ replace_bool_expr(edit, expr);
+ }
+ }
+ _ => (),
+ }
+ } else if new_name.syntax().ancestors().find_map(ast::UseTree::cast).is_none() {
// for any other usage in an expression, replace it with a check that it is the true variant
- edit.replace(range, format!("{} == Bool::True", name_ref.text()));
+ if let Some((record_field, expr)) = new_name
+ .as_name_ref()
+ .and_then(ast::RecordExprField::for_field_name)
+ .and_then(|record_field| {
+ record_field.expr().map(|expr| (record_field, expr))
+ })
+ {
+ record_field.replace_expr(
+ make::expr_bin_op(
+ expr,
+ ast::BinaryOp::CmpOp(ast::CmpOp::Eq { negated: false }),
+ make::expr_path(make::path_from_text("Bool::True")),
+ )
+ .clone_for_update(),
+ );
+ } else {
+ edit.replace(range, format!("{} == Bool::True", new_name.text()));
+ }
+ }
+
+ // add imports across modules where needed
+ if let Some((import_scope, path)) = import_data {
+ insert_use(&import_scope, path, &ctx.config.insert_use);
}
- })
+ },
+ )
}
}
-fn find_assignment_usage(name_ref: &ast::NameRef) -> Option<ast::Expr> {
- let bin_expr = name_ref.syntax().ancestors().find_map(ast::BinExpr::cast)?;
+struct FileReferenceWithImport {
+ range: TextRange,
+ old_name: ast::NameLike,
+ new_name: ast::NameLike,
+ import_data: Option<(ImportScope, ast::Path)>,
+}
- if !bin_expr.lhs()?.syntax().descendants().contains(name_ref.syntax()) {
+fn augment_references_with_imports(
+ edit: &mut SourceChangeBuilder,
+ ctx: &AssistContext<'_>,
+ references: &[FileReference],
+ target_module: &hir::Module,
+) -> Vec<FileReferenceWithImport> {
+ let mut visited_modules = FxHashSet::default();
+
+ references
+ .iter()
+ .filter_map(|FileReference { range, name, .. }| {
+ ctx.sema.scope(name.syntax()).map(|scope| (*range, name, scope.module()))
+ })
+ .map(|(range, name, ref_module)| {
+ let old_name = name.clone();
+ let new_name = edit.make_mut(name.clone());
+
+ // if the referenced module is not the same as the target one and has not been seen before, add an import
+ let import_data = if ref_module.nearest_non_block_module(ctx.db()) != *target_module
+ && !visited_modules.contains(&ref_module)
+ {
+ visited_modules.insert(ref_module);
+
+ let import_scope =
+ ImportScope::find_insert_use_container(new_name.syntax(), &ctx.sema);
+ let path = ref_module
+ .find_use_path_prefixed(
+ ctx.sema.db,
+ ModuleDef::Module(*target_module),
+ ctx.config.insert_use.prefix_kind,
+ ctx.config.prefer_no_std,
+ )
+ .map(|mod_path| {
+ make::path_concat(mod_path_to_ast(&mod_path), make::path_from_text("Bool"))
+ });
+
+ import_scope.zip(path)
+ } else {
+ None
+ };
+
+ FileReferenceWithImport { range, old_name, new_name, import_data }
+ })
+ .collect()
+}
+
+fn find_assignment_usage(name: &ast::NameLike) -> Option<ast::Expr> {
+ let bin_expr = name.syntax().ancestors().find_map(ast::BinExpr::cast)?;
+
+ if !bin_expr.lhs()?.syntax().descendants().contains(name.syntax()) {
cov_mark::hit!(dont_assign_incorrect_ref);
return None;
}
@@ -275,8 +369,8 @@ fn find_assignment_usage(name_ref: &ast::NameRef) -> Option<ast::Expr> {
}
}
-fn find_negated_usage(name_ref: &ast::NameRef) -> Option<(ast::PrefixExpr, ast::Expr)> {
- let prefix_expr = name_ref.syntax().ancestors().find_map(ast::PrefixExpr::cast)?;
+fn find_negated_usage(name: &ast::NameLike) -> Option<(ast::PrefixExpr, ast::Expr)> {
+ let prefix_expr = name.syntax().ancestors().find_map(ast::PrefixExpr::cast)?;
if !matches!(prefix_expr.expr()?, ast::Expr::PathExpr(_) | ast::Expr::FieldExpr(_)) {
cov_mark::hit!(dont_overwrite_expression_inside_negation);
@@ -291,15 +385,31 @@ fn find_negated_usage(name_ref: &ast::NameRef) -> Option<(ast::PrefixExpr, ast::
}
}
-fn find_record_expr_usage(name_ref: &ast::NameRef) -> Option<(ast::RecordExprField, ast::Expr)> {
- let record_field = name_ref.syntax().ancestors().find_map(ast::RecordExprField::cast)?;
+fn find_record_expr_usage(
+ name: &ast::NameLike,
+ got_field: hir::Field,
+ target_definition: Definition,
+) -> Option<(ast::RecordExprField, ast::Expr)> {
+ let name_ref = name.as_name_ref()?;
+ let record_field = ast::RecordExprField::for_field_name(name_ref)?;
let initializer = record_field.expr()?;
- if record_field.field_name()?.syntax().descendants().contains(name_ref.syntax()) {
- Some((record_field, initializer))
- } else {
- cov_mark::hit!(dont_overwrite_wrong_record_field);
- None
+ if let Definition::Field(expected_field) = target_definition {
+ if got_field != expected_field {
+ return None;
+ }
+ }
+
+ Some((record_field, initializer))
+}
+
+fn find_record_pat_field_usage(name: &ast::NameLike) -> Option<ast::Pat> {
+ let record_pat_field = name.syntax().parent().and_then(ast::RecordPatField::cast)?;
+ let pat = record_pat_field.pat()?;
+
+ match pat {
+ ast::Pat::IdentPat(_) | ast::Pat::LiteralPat(_) | ast::Pat::WildcardPat(_) => Some(pat),
+ _ => None,
}
}
@@ -317,7 +427,7 @@ fn add_enum_def(
.filter_map(|FileReference { name, .. }| {
ctx.sema.scope(name.syntax()).map(|scope| scope.module())
})
- .any(|module| &module != target_module);
+ .any(|module| module.nearest_non_block_module(ctx.db()) != *target_module);
let enum_def = make_bool_enum(make_enum_pub);
let indent = IndentLevel::from_node(&target_node);
@@ -646,7 +756,7 @@ fn main() {
}
#[test]
- fn field_basic() {
+ fn field_struct_basic() {
cov_mark::check!(replaces_record_expr);
check_assist(
bool_to_enum,
@@ -685,6 +795,263 @@ fn main() {
}
#[test]
+ fn field_enum_basic() {
+ cov_mark::check!(replaces_record_pat);
+ check_assist(
+ bool_to_enum,
+ r#"
+enum Foo {
+ Foo,
+ Bar { $0bar: bool },
+}
+
+fn main() {
+ let foo = Foo::Bar { bar: true };
+
+ if let Foo::Bar { bar: baz } = foo {
+ if baz {
+ println!("foo");
+ }
+ }
+}
+"#,
+ r#"
+#[derive(PartialEq, Eq)]
+enum Bool { True, False }
+
+enum Foo {
+ Foo,
+ Bar { bar: Bool },
+}
+
+fn main() {
+ let foo = Foo::Bar { bar: Bool::True };
+
+ if let Foo::Bar { bar: baz } = foo {
+ if baz == Bool::True {
+ println!("foo");
+ }
+ }
+}
+"#,
+ )
+ }
+
+ #[test]
+ fn field_enum_cross_file() {
+ check_assist(
+ bool_to_enum,
+ r#"
+//- /foo.rs
+pub enum Foo {
+ Foo,
+ Bar { $0bar: bool },
+}
+
+fn foo() {
+ let foo = Foo::Bar { bar: true };
+}
+
+//- /main.rs
+use foo::Foo;
+
+mod foo;
+
+fn main() {
+ let foo = Foo::Bar { bar: false };
+}
+"#,
+ r#"
+//- /foo.rs
+#[derive(PartialEq, Eq)]
+pub enum Bool { True, False }
+
+pub enum Foo {
+ Foo,
+ Bar { bar: Bool },
+}
+
+fn foo() {
+ let foo = Foo::Bar { bar: Bool::True };
+}
+
+//- /main.rs
+use foo::{Foo, Bool};
+
+mod foo;
+
+fn main() {
+ let foo = Foo::Bar { bar: Bool::False };
+}
+"#,
+ )
+ }
+
+ #[test]
+ fn field_enum_shorthand() {
+ cov_mark::check!(replaces_record_pat_shorthand);
+ check_assist(
+ bool_to_enum,
+ r#"
+enum Foo {
+ Foo,
+ Bar { $0bar: bool },
+}
+
+fn main() {
+ let foo = Foo::Bar { bar: true };
+
+ match foo {
+ Foo::Bar { bar } => {
+ if bar {
+ println!("foo");
+ }
+ }
+ _ => (),
+ }
+}
+"#,
+ r#"
+#[derive(PartialEq, Eq)]
+enum Bool { True, False }
+
+enum Foo {
+ Foo,
+ Bar { bar: Bool },
+}
+
+fn main() {
+ let foo = Foo::Bar { bar: Bool::True };
+
+ match foo {
+ Foo::Bar { bar } => {
+ if bar == Bool::True {
+ println!("foo");
+ }
+ }
+ _ => (),
+ }
+}
+"#,
+ )
+ }
+
+ #[test]
+ fn field_enum_replaces_literal_patterns() {
+ cov_mark::check!(replaces_literal_pat);
+ check_assist(
+ bool_to_enum,
+ r#"
+enum Foo {
+ Foo,
+ Bar { $0bar: bool },
+}
+
+fn main() {
+ let foo = Foo::Bar { bar: true };
+
+ if let Foo::Bar { bar: true } = foo {
+ println!("foo");
+ }
+}
+"#,
+ r#"
+#[derive(PartialEq, Eq)]
+enum Bool { True, False }
+
+enum Foo {
+ Foo,
+ Bar { bar: Bool },
+}
+
+fn main() {
+ let foo = Foo::Bar { bar: Bool::True };
+
+ if let Foo::Bar { bar: Bool::True } = foo {
+ println!("foo");
+ }
+}
+"#,
+ )
+ }
+
+ #[test]
+ fn field_enum_keeps_wildcard_patterns() {
+ check_assist(
+ bool_to_enum,
+ r#"
+enum Foo {
+ Foo,
+ Bar { $0bar: bool },
+}
+
+fn main() {
+ let foo = Foo::Bar { bar: true };
+
+ if let Foo::Bar { bar: _ } = foo {
+ println!("foo");
+ }
+}
+"#,
+ r#"
+#[derive(PartialEq, Eq)]
+enum Bool { True, False }
+
+enum Foo {
+ Foo,
+ Bar { bar: Bool },
+}
+
+fn main() {
+ let foo = Foo::Bar { bar: Bool::True };
+
+ if let Foo::Bar { bar: _ } = foo {
+ println!("foo");
+ }
+}
+"#,
+ )
+ }
+
+ #[test]
+ fn field_union_basic() {
+ check_assist(
+ bool_to_enum,
+ r#"
+union Foo {
+ $0foo: bool,
+ bar: usize,
+}
+
+fn main() {
+ let foo = Foo { foo: true };
+
+ if unsafe { foo.foo } {
+ println!("foo");
+ }
+}
+"#,
+ r#"
+#[derive(PartialEq, Eq)]
+enum Bool { True, False }
+
+union Foo {
+ foo: Bool,
+ bar: usize,
+}
+
+fn main() {
+ let foo = Foo { foo: Bool::True };
+
+ if unsafe { foo.foo == Bool::True } {
+ println!("foo");
+ }
+}
+"#,
+ )
+ }
+
+ #[test]
fn field_negated() {
check_assist(
bool_to_enum,
@@ -841,7 +1208,6 @@ fn main() {
#[test]
fn field_initialized_with_other() {
- cov_mark::check!(dont_overwrite_wrong_record_field);
check_assist(
bool_to_enum,
r#"