Unnamed repository; edit this file 'description' to name the repository.
Auto merge of #14291 - HKalbasi:master, r=HKalbasi
fix multiple definition binding in match to let-else fix #14290
bors 2023-03-09
parent 38e9a11 · parent 811190b · commit 8e404f4
-rw-r--r--crates/ide-assists/src/handlers/convert_match_to_let_else.rs85
1 files changed, 63 insertions, 22 deletions
diff --git a/crates/ide-assists/src/handlers/convert_match_to_let_else.rs b/crates/ide-assists/src/handlers/convert_match_to_let_else.rs
index 745a870ab6..7f2c01772b 100644
--- a/crates/ide-assists/src/handlers/convert_match_to_let_else.rs
+++ b/crates/ide-assists/src/handlers/convert_match_to_let_else.rs
@@ -1,6 +1,6 @@
use ide_db::defs::{Definition, NameRefClass};
use syntax::{
- ast::{self, HasName},
+ ast::{self, HasName, Name},
ted, AstNode, SyntaxNode,
};
@@ -48,7 +48,7 @@ pub(crate) fn convert_match_to_let_else(acc: &mut Assists, ctx: &AssistContext<'
other => format!("{{ {other} }}"),
};
let extracting_arm_pat = extracting_arm.pat()?;
- let extracted_variable = find_extracted_variable(ctx, &extracting_arm)?;
+ let extracted_variable_positions = find_extracted_variable(ctx, &extracting_arm)?;
acc.add(
AssistId("convert_match_to_let_else", AssistKind::RefactorRewrite),
@@ -56,7 +56,7 @@ pub(crate) fn convert_match_to_let_else(acc: &mut Assists, ctx: &AssistContext<'
let_stmt.syntax().text_range(),
|builder| {
let extracting_arm_pat =
- rename_variable(&extracting_arm_pat, extracted_variable, binding);
+ rename_variable(&extracting_arm_pat, &extracted_variable_positions, binding);
builder.replace(
let_stmt.syntax().text_range(),
format!("let {extracting_arm_pat} = {initializer_expr} else {diverging_arm_expr};"),
@@ -95,14 +95,15 @@ fn find_arms(
}
// Given an extracting arm, find the extracted variable.
-fn find_extracted_variable(ctx: &AssistContext<'_>, arm: &ast::MatchArm) -> Option<ast::Name> {
+fn find_extracted_variable(ctx: &AssistContext<'_>, arm: &ast::MatchArm) -> Option<Vec<Name>> {
match arm.expr()? {
ast::Expr::PathExpr(path) => {
let name_ref = path.syntax().descendants().find_map(ast::NameRef::cast)?;
match NameRefClass::classify(&ctx.sema, &name_ref)? {
NameRefClass::Definition(Definition::Local(local)) => {
- let source = local.primary_source(ctx.db()).into_ident_pat()?;
- Some(source.name()?)
+ let source =
+ local.sources(ctx.db()).into_iter().map(|x| x.into_ident_pat()?.name());
+ source.collect()
}
_ => None,
}
@@ -115,27 +116,34 @@ fn find_extracted_variable(ctx: &AssistContext<'_>, arm: &ast::MatchArm) -> Opti
}
// Rename `extracted` with `binding` in `pat`.
-fn rename_variable(pat: &ast::Pat, extracted: ast::Name, binding: ast::Pat) -> SyntaxNode {
+fn rename_variable(pat: &ast::Pat, extracted: &[Name], binding: ast::Pat) -> SyntaxNode {
let syntax = pat.syntax().clone_for_update();
- let extracted_syntax = syntax.covering_element(extracted.syntax().text_range());
-
- // If `extracted` variable is a record field, we should rename it to `binding`,
- // otherwise we just need to replace `extracted` with `binding`.
-
- if let Some(record_pat_field) = extracted_syntax.ancestors().find_map(ast::RecordPatField::cast)
- {
- if let Some(name_ref) = record_pat_field.field_name() {
- ted::replace(
- record_pat_field.syntax(),
- ast::make::record_pat_field(ast::make::name_ref(&name_ref.text()), binding)
+ let extracted = extracted
+ .iter()
+ .map(|e| syntax.covering_element(e.syntax().text_range()))
+ .collect::<Vec<_>>();
+ for extracted_syntax in extracted {
+ // If `extracted` variable is a record field, we should rename it to `binding`,
+ // otherwise we just need to replace `extracted` with `binding`.
+
+ if let Some(record_pat_field) =
+ extracted_syntax.ancestors().find_map(ast::RecordPatField::cast)
+ {
+ if let Some(name_ref) = record_pat_field.field_name() {
+ ted::replace(
+ record_pat_field.syntax(),
+ ast::make::record_pat_field(
+ ast::make::name_ref(&name_ref.text()),
+ binding.clone(),
+ )
.syntax()
.clone_for_update(),
- );
+ );
+ }
+ } else {
+ ted::replace(extracted_syntax, binding.clone().syntax().clone_for_update());
}
- } else {
- ted::replace(extracted_syntax, binding.syntax().clone_for_update());
}
-
syntax
}
@@ -163,6 +171,39 @@ fn foo(opt: Option<()>) {
}
#[test]
+ fn or_pattern_multiple_binding() {
+ check_assist(
+ convert_match_to_let_else,
+ r#"
+//- minicore: option
+enum Foo {
+ A(u32),
+ B(u32),
+ C(String),
+}
+
+fn foo(opt: Option<Foo>) -> Result<u32, ()> {
+ let va$0lue = match opt {
+ Some(Foo::A(it) | Foo::B(it)) => it,
+ _ => return Err(()),
+ };
+}
+ "#,
+ r#"
+enum Foo {
+ A(u32),
+ B(u32),
+ C(String),
+}
+
+fn foo(opt: Option<Foo>) -> Result<u32, ()> {
+ let Some(Foo::A(value) | Foo::B(value)) = opt else { return Err(()) };
+}
+ "#,
+ );
+ }
+
+ #[test]
fn should_not_be_applicable_if_extracting_arm_is_not_an_identity_expr() {
cov_mark::check_count!(extracting_arm_is_not_an_identity_expr, 2);
check_assist_not_applicable(