Unnamed repository; edit this file 'description' to name the repository.
Diffstat (limited to 'crates/ide-assists/src/handlers/bool_to_enum.rs')
-rw-r--r--crates/ide-assists/src/handlers/bool_to_enum.rs193
1 files changed, 160 insertions, 33 deletions
diff --git a/crates/ide-assists/src/handlers/bool_to_enum.rs b/crates/ide-assists/src/handlers/bool_to_enum.rs
index 85b0b87d0c..3303a2dd3c 100644
--- a/crates/ide-assists/src/handlers/bool_to_enum.rs
+++ b/crates/ide-assists/src/handlers/bool_to_enum.rs
@@ -16,7 +16,7 @@ use syntax::{
edit_in_place::{AttrsOwnerEdit, Indent},
make, HasName,
},
- ted, AstNode, NodeOrToken, SyntaxNode, T,
+ match_ast, ted, AstNode, NodeOrToken, SyntaxNode, T,
};
use text_edit::TextRange;
@@ -40,10 +40,10 @@ use crate::assist_context::{AssistContext, Assists};
// ```
// ->
// ```
-// fn main() {
-// #[derive(PartialEq, Eq)]
-// enum Bool { True, False }
+// #[derive(PartialEq, Eq)]
+// enum Bool { True, False }
//
+// fn main() {
// let bool = Bool::True;
//
// if bool == Bool::True {
@@ -270,6 +270,10 @@ fn replace_usages(
}
_ => (),
}
+ } else if let Some((ty_annotation, initializer)) = find_assoc_const_usage(&new_name)
+ {
+ edit.replace(ty_annotation.syntax().text_range(), "Bool");
+ replace_bool_expr(edit, initializer);
} else if new_name.syntax().ancestors().find_map(ast::UseTree::cast).is_none() {
// for any other usage in an expression, replace it with a check that it is the true variant
if let Some((record_field, expr)) = new_name
@@ -413,6 +417,15 @@ fn find_record_pat_field_usage(name: &ast::NameLike) -> Option<ast::Pat> {
}
}
+fn find_assoc_const_usage(name: &ast::NameLike) -> Option<(ast::Type, ast::Expr)> {
+ let const_ = name.syntax().parent().and_then(ast::Const::cast)?;
+ if const_.syntax().parent().and_then(ast::AssocItemList::cast).is_none() {
+ return None;
+ }
+
+ Some((const_.ty()?, const_.body()?))
+}
+
/// Adds the definition of the new enum before the target node.
fn add_enum_def(
edit: &mut SourceChangeBuilder,
@@ -430,11 +443,12 @@ fn add_enum_def(
.any(|module| module.nearest_non_block_module(ctx.db()) != *target_module);
let enum_def = make_bool_enum(make_enum_pub);
- let indent = IndentLevel::from_node(&target_node);
+ let insert_before = node_to_insert_before(target_node);
+ let indent = IndentLevel::from_node(&insert_before);
enum_def.reindent_to(indent);
ted::insert_all(
- ted::Position::before(&edit.make_syntax_mut(target_node)),
+ ted::Position::before(&edit.make_syntax_mut(insert_before)),
vec![
enum_def.syntax().clone().into(),
make::tokens::whitespace(&format!("\n\n{indent}")).into(),
@@ -442,6 +456,35 @@ fn add_enum_def(
);
}
+/// Finds where to put the new enum definition, at the nearest module or at top-level.
+fn node_to_insert_before(mut target_node: SyntaxNode) -> SyntaxNode {
+ let mut ancestors = target_node.ancestors();
+
+ while let Some(ancestor) = ancestors.next() {
+ match_ast! {
+ match ancestor {
+ ast::Item(item) => {
+ if item
+ .syntax()
+ .parent()
+ .and_then(|item_list| item_list.parent())
+ .and_then(ast::Module::cast)
+ .is_some()
+ {
+ return ancestor;
+ }
+ },
+ ast::SourceFile(_) => break,
+ _ => (),
+ }
+ }
+
+ target_node = ancestor;
+ }
+
+ target_node
+}
+
fn make_bool_enum(make_pub: bool) -> ast::Enum {
let enum_def = make::enum_(
if make_pub { Some(make::visibility_pub()) } else { None },
@@ -491,10 +534,10 @@ fn main() {
}
"#,
r#"
-fn main() {
- #[derive(PartialEq, Eq)]
- enum Bool { True, False }
+#[derive(PartialEq, Eq)]
+enum Bool { True, False }
+fn main() {
let foo = Bool::True;
if foo == Bool::True {
@@ -520,10 +563,10 @@ fn main() {
}
"#,
r#"
-fn main() {
- #[derive(PartialEq, Eq)]
- enum Bool { True, False }
+#[derive(PartialEq, Eq)]
+enum Bool { True, False }
+fn main() {
let foo = Bool::True;
if foo == Bool::False {
@@ -545,10 +588,10 @@ fn main() {
}
"#,
r#"
-fn main() {
- #[derive(PartialEq, Eq)]
- enum Bool { True, False }
+#[derive(PartialEq, Eq)]
+enum Bool { True, False }
+fn main() {
let foo: Bool = Bool::False;
}
"#,
@@ -565,10 +608,10 @@ fn main() {
}
"#,
r#"
-fn main() {
- #[derive(PartialEq, Eq)]
- enum Bool { True, False }
+#[derive(PartialEq, Eq)]
+enum Bool { True, False }
+fn main() {
let foo = if 1 == 2 { Bool::True } else { Bool::False };
}
"#,
@@ -590,10 +633,10 @@ fn main() {
}
"#,
r#"
-fn main() {
- #[derive(PartialEq, Eq)]
- enum Bool { True, False }
+#[derive(PartialEq, Eq)]
+enum Bool { True, False }
+fn main() {
let foo = Bool::False;
let bar = true;
@@ -619,10 +662,10 @@ fn main() {
}
"#,
r#"
-fn main() {
- #[derive(PartialEq, Eq)]
- enum Bool { True, False }
+#[derive(PartialEq, Eq)]
+enum Bool { True, False }
+fn main() {
let foo = Bool::True;
if *&foo == Bool::True {
@@ -645,10 +688,10 @@ fn main() {
}
"#,
r#"
-fn main() {
- #[derive(PartialEq, Eq)]
- enum Bool { True, False }
+#[derive(PartialEq, Eq)]
+enum Bool { True, False }
+fn main() {
let foo: Bool;
foo = Bool::True;
}
@@ -671,10 +714,10 @@ fn main() {
}
"#,
r#"
-fn main() {
- #[derive(PartialEq, Eq)]
- enum Bool { True, False }
+#[derive(PartialEq, Eq)]
+enum Bool { True, False }
+fn main() {
let foo = Bool::True;
let bar = foo == Bool::False;
@@ -702,11 +745,11 @@ fn main() {
}
"#,
r#"
+#[derive(PartialEq, Eq)]
+enum Bool { True, False }
+
fn main() {
if !"foo".chars().any(|c| {
- #[derive(PartialEq, Eq)]
- enum Bool { True, False }
-
let foo = Bool::True;
foo == Bool::True
}) {
@@ -1446,6 +1489,90 @@ pub mod bar {
}
#[test]
+ fn const_in_impl_cross_file() {
+ check_assist(
+ bool_to_enum,
+ r#"
+//- /main.rs
+mod foo;
+
+struct Foo;
+
+impl Foo {
+ pub const $0BOOL: bool = true;
+}
+
+//- /foo.rs
+use crate::Foo;
+
+fn foo() -> bool {
+ Foo::BOOL
+}
+"#,
+ r#"
+//- /main.rs
+mod foo;
+
+struct Foo;
+
+#[derive(PartialEq, Eq)]
+pub enum Bool { True, False }
+
+impl Foo {
+ pub const BOOL: Bool = Bool::True;
+}
+
+//- /foo.rs
+use crate::{Foo, Bool};
+
+fn foo() -> bool {
+ Foo::BOOL == Bool::True
+}
+"#,
+ )
+ }
+
+ #[test]
+ fn const_in_trait() {
+ check_assist(
+ bool_to_enum,
+ r#"
+trait Foo {
+ const $0BOOL: bool;
+}
+
+impl Foo for usize {
+ const BOOL: bool = true;
+}
+
+fn main() {
+ if <usize as Foo>::BOOL {
+ println!("foo");
+ }
+}
+"#,
+ r#"
+#[derive(PartialEq, Eq)]
+enum Bool { True, False }
+
+trait Foo {
+ const BOOL: Bool;
+}
+
+impl Foo for usize {
+ const BOOL: Bool = Bool::True;
+}
+
+fn main() {
+ if <usize as Foo>::BOOL == Bool::True {
+ println!("foo");
+ }
+}
+"#,
+ )
+ }
+
+ #[test]
fn const_non_bool() {
cov_mark::check!(not_applicable_non_bool_const);
check_assist_not_applicable(