Unnamed repository; edit this file 'description' to name the repository.
-rw-r--r--crates/ide-assists/src/handlers/replace_derive_with_manual_impl.rs29
-rw-r--r--crates/ide-assists/src/utils/gen_trait_fn_body.rs17
2 files changed, 43 insertions, 3 deletions
diff --git a/crates/ide-assists/src/handlers/replace_derive_with_manual_impl.rs b/crates/ide-assists/src/handlers/replace_derive_with_manual_impl.rs
index a1ca286121..2854701c08 100644
--- a/crates/ide-assists/src/handlers/replace_derive_with_manual_impl.rs
+++ b/crates/ide-assists/src/handlers/replace_derive_with_manual_impl.rs
@@ -907,6 +907,33 @@ impl PartialEq for Foo {
}
#[test]
+ fn add_custom_impl_partial_eq_single_variant_tuple_enum() {
+ check_assist(
+ replace_derive_with_manual_impl,
+ r#"
+//- minicore: eq, derive
+#[derive(Partial$0Eq)]
+enum Foo {
+ Bar(String),
+}
+"#,
+ r#"
+enum Foo {
+ Bar(String),
+}
+
+impl PartialEq for Foo {
+ $0fn eq(&self, other: &Self) -> bool {
+ match (self, other) {
+ (Self::Bar(l0), Self::Bar(r0)) => l0 == r0,
+ }
+ }
+}
+"#,
+ )
+ }
+
+ #[test]
fn add_custom_impl_partial_eq_partial_tuple_enum() {
check_assist(
replace_derive_with_manual_impl,
@@ -959,7 +986,7 @@ impl PartialEq for Foo {
match (self, other) {
(Self::Bar(l0), Self::Bar(r0)) => l0 == r0,
(Self::Baz(l0), Self::Baz(r0)) => l0 == r0,
- _ => core::mem::discriminant(self) == core::mem::discriminant(other),
+ _ => false,
}
}
}
diff --git a/crates/ide-assists/src/utils/gen_trait_fn_body.rs b/crates/ide-assists/src/utils/gen_trait_fn_body.rs
index 287001af84..f32e5ce97d 100644
--- a/crates/ide-assists/src/utils/gen_trait_fn_body.rs
+++ b/crates/ide-assists/src/utils/gen_trait_fn_body.rs
@@ -439,8 +439,10 @@ fn gen_partial_eq(adt: &ast::Adt, func: &ast::Fn) -> Option<()> {
let eq_check =
make::expr_bin_op(lhs, BinaryOp::CmpOp(CmpOp::Eq { negated: false }), rhs);
+ let mut n_cases = 0;
let mut arms = vec![];
for variant in enum_.variant_list()?.variants() {
+ n_cases += 1;
match variant.field_list() {
// => (Self::Bar { bin: l_bin }, Self::Bar { bin: r_bin }) => l_bin == r_bin,
Some(ast::FieldList::RecordFieldList(list)) => {
@@ -514,8 +516,19 @@ fn gen_partial_eq(adt: &ast::Adt, func: &ast::Fn) -> Option<()> {
let expr = match arms.len() {
0 => eq_check,
- _ => {
- arms.push(make::match_arm(Some(make::wildcard_pat().into()), None, eq_check));
+ arms_len => {
+ // Generate the fallback arm when this enum has >1 variants.
+ // The fallback arm will be `_ => false,` if we've already gone through every case where the variants of self and other match,
+ // and `_ => std::mem::discriminant(self) == std::mem::discriminant(other),` otherwise.
+ if n_cases > 1 {
+ let lhs = make::wildcard_pat().into();
+ let rhs = if arms_len == n_cases {
+ make::expr_literal("false").into()
+ } else {
+ eq_check
+ };
+ arms.push(make::match_arm(Some(lhs), None, rhs));
+ }
let match_target = make::expr_tuple(vec![lhs_name, rhs_name]);
let list = make::match_arm_list(arms).indent(ast::edit::IndentLevel(1));