Unnamed repository; edit this file 'description' to name the repository.
Auto merge of #17467 - winstxnhdw:bool-to-enum, r=Veykril
feat: add bool_to_enum assist for parameters ## Summary This PR adds parameter support for `bool_to_enum` assists. Essentially, the assist can now transform this: ```rs fn function($0foo: bool) { if foo { println!("foo"); } } ``` To this, ```rs #[derive(PartialEq, Eq)] enum Bool { True, False } fn function(foo: Bool) { if foo == Bool::True { println!("foo"); } } ``` Thanks to `@/davidbarsky` for the test skeleton (: Closes #17400
bors 2024-06-30
parent 098d699 · parent d468746 · commit 7e8f9c8
-rw-r--r--crates/ide-assists/src/handlers/bool_to_enum.rs129
1 files changed, 112 insertions, 17 deletions
diff --git a/crates/ide-assists/src/handlers/bool_to_enum.rs b/crates/ide-assists/src/handlers/bool_to_enum.rs
index 2aeca0bae0..0aa23ccc84 100644
--- a/crates/ide-assists/src/handlers/bool_to_enum.rs
+++ b/crates/ide-assists/src/handlers/bool_to_enum.rs
@@ -1,3 +1,4 @@
+use either::Either;
use hir::{ImportPathConfig, ModuleDef};
use ide_db::{
assists::{AssistId, AssistKind},
@@ -97,27 +98,30 @@ struct BoolNodeData {
fn find_bool_node(ctx: &AssistContext<'_>) -> Option<BoolNodeData> {
let name: ast::Name = ctx.find_node_at_offset()?;
- if let Some(let_stmt) = name.syntax().ancestors().find_map(ast::LetStmt::cast) {
- let bind_pat = match let_stmt.pat()? {
- ast::Pat::IdentPat(pat) => pat,
- _ => {
- cov_mark::hit!(not_applicable_in_non_ident_pat);
- return None;
- }
- };
- let def = ctx.sema.to_def(&bind_pat)?;
+ if let Some(ident_pat) = name.syntax().parent().and_then(ast::IdentPat::cast) {
+ let def = ctx.sema.to_def(&ident_pat)?;
if !def.ty(ctx.db()).is_bool() {
cov_mark::hit!(not_applicable_non_bool_local);
return None;
}
- Some(BoolNodeData {
- target_node: let_stmt.syntax().clone(),
- name,
- ty_annotation: let_stmt.ty(),
- initializer: let_stmt.initializer(),
- definition: Definition::Local(def),
- })
+ let local_definition = Definition::Local(def);
+ match ident_pat.syntax().parent().and_then(Either::<ast::Param, ast::LetStmt>::cast)? {
+ Either::Left(param) => Some(BoolNodeData {
+ target_node: param.syntax().clone(),
+ name,
+ ty_annotation: param.ty(),
+ initializer: None,
+ definition: local_definition,
+ }),
+ Either::Right(let_stmt) => Some(BoolNodeData {
+ target_node: let_stmt.syntax().clone(),
+ name,
+ ty_annotation: let_stmt.ty(),
+ initializer: let_stmt.initializer(),
+ definition: local_definition,
+ }),
+ }
} else if let Some(const_) = name.syntax().parent().and_then(ast::Const::cast) {
let def = ctx.sema.to_def(&const_)?;
if !def.ty(ctx.db()).is_bool() {
@@ -525,6 +529,98 @@ mod tests {
use crate::tests::{check_assist, check_assist_not_applicable};
#[test]
+ fn parameter_with_first_param_usage() {
+ check_assist(
+ bool_to_enum,
+ r#"
+fn function($0foo: bool, bar: bool) {
+ if foo {
+ println!("foo");
+ }
+}
+"#,
+ r#"
+#[derive(PartialEq, Eq)]
+enum Bool { True, False }
+
+fn function(foo: Bool, bar: bool) {
+ if foo == Bool::True {
+ println!("foo");
+ }
+}
+"#,
+ )
+ }
+
+ #[test]
+ fn parameter_with_last_param_usage() {
+ check_assist(
+ bool_to_enum,
+ r#"
+fn function(foo: bool, $0bar: bool) {
+ if bar {
+ println!("bar");
+ }
+}
+"#,
+ r#"
+#[derive(PartialEq, Eq)]
+enum Bool { True, False }
+
+fn function(foo: bool, bar: Bool) {
+ if bar == Bool::True {
+ println!("bar");
+ }
+}
+"#,
+ )
+ }
+
+ #[test]
+ fn parameter_with_middle_param_usage() {
+ check_assist(
+ bool_to_enum,
+ r#"
+fn function(foo: bool, $0bar: bool, baz: bool) {
+ if bar {
+ println!("bar");
+ }
+}
+"#,
+ r#"
+#[derive(PartialEq, Eq)]
+enum Bool { True, False }
+
+fn function(foo: bool, bar: Bool, baz: bool) {
+ if bar == Bool::True {
+ println!("bar");
+ }
+}
+"#,
+ )
+ }
+
+ #[test]
+ fn parameter_with_closure_usage() {
+ check_assist(
+ bool_to_enum,
+ r#"
+fn main() {
+ let foo = |$0bar: bool| bar;
+}
+"#,
+ r#"
+#[derive(PartialEq, Eq)]
+enum Bool { True, False }
+
+fn main() {
+ let foo = |bar: Bool| bar == Bool::True;
+}
+"#,
+ )
+ }
+
+ #[test]
fn local_variable_with_usage() {
check_assist(
bool_to_enum,
@@ -791,7 +887,6 @@ fn main() {
#[test]
fn local_variable_non_ident_pat() {
- cov_mark::check!(not_applicable_in_non_ident_pat);
check_assist_not_applicable(
bool_to_enum,
r#"