use ide_db::defs::{Definition, NameRefClass}; use syntax::{ AstNode, SyntaxNode, ast::{self, HasName, Name, edit::AstNodeEdit, syntax_factory::SyntaxFactory}, syntax_editor::SyntaxEditor, }; use crate::{ AssistId, assist_context::{AssistContext, Assists}, }; // Assist: convert_match_to_let_else // // Converts let statement with match initializer to let-else statement. // // ``` // # //- minicore: option // fn foo(opt: Option<()>) { // let val$0 = match opt { // Some(it) => it, // None => return, // }; // } // ``` // -> // ``` // fn foo(opt: Option<()>) { // let Some(val) = opt else { return }; // } // ``` pub(crate) fn convert_match_to_let_else(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { let let_stmt: ast::LetStmt = ctx.find_node_at_offset()?; let pat = let_stmt.pat()?; if ctx.offset() > pat.syntax().text_range().end() { return None; } let Some(ast::Expr::MatchExpr(initializer)) = let_stmt.initializer() else { return None }; let initializer_expr = initializer.expr()?; let (extracting_arm, diverging_arm) = find_arms(ctx, &initializer)?; if extracting_arm.guard().is_some() { cov_mark::hit!(extracting_arm_has_guard); return None; } let diverging_arm_expr = match diverging_arm.expr()?.dedent(1.into()) { ast::Expr::BlockExpr(block) if block.modifier().is_none() && block.label().is_none() => { block.to_string() } other => format!("{{ {other} }}"), }; let extracting_arm_pat = extracting_arm.pat()?; let extracted_variable_positions = find_extracted_variable(ctx, &extracting_arm)?; acc.add( AssistId::refactor_rewrite("convert_match_to_let_else"), "Convert match to let-else", let_stmt.syntax().text_range(), |builder| { let extracting_arm_pat = rename_variable(&extracting_arm_pat, &extracted_variable_positions, pat); builder.replace( let_stmt.syntax().text_range(), format!("let {extracting_arm_pat} = {initializer_expr} else {diverging_arm_expr};"), ) }, ) } // Given a match expression, find extracting and diverging arms. fn find_arms( ctx: &AssistContext<'_>, match_expr: &ast::MatchExpr, ) -> Option<(ast::MatchArm, ast::MatchArm)> { let arms = match_expr.match_arm_list()?.arms().collect::>(); if arms.len() != 2 { return None; } let mut extracting = None; let mut diverging = None; for arm in arms { if ctx.sema.type_of_expr(&arm.expr()?)?.original().is_never() { diverging = Some(arm); } else { extracting = Some(arm); } } match (extracting, diverging) { (Some(extracting), Some(diverging)) => Some((extracting, diverging)), _ => { cov_mark::hit!(non_diverging_match); None } } } // Given an extracting arm, find the extracted variable. fn find_extracted_variable(ctx: &AssistContext<'_>, arm: &ast::MatchArm) -> Option> { match arm.expr()? { ast::Expr::PathExpr(path) => { let name_ref = path.syntax().descendants().find_map(ast::NameRef::cast)?; match NameRefClass::classify(&ctx.sema, &name_ref)? { NameRefClass::Definition(Definition::Local(local), _) => { let source = local.sources(ctx.db()).into_iter().map(|x| x.into_ident_pat()?.name()); source.collect() } _ => None, } } _ => { cov_mark::hit!(extracting_arm_is_not_an_identity_expr); None } } } // Rename `extracted` with `binding` in `pat`. fn rename_variable(pat: &ast::Pat, extracted: &[Name], binding: ast::Pat) -> SyntaxNode { let syntax = pat.syntax().clone_subtree(); let mut editor = SyntaxEditor::new(syntax.clone()); let make = SyntaxFactory::with_mappings(); let extracted = extracted .iter() .map(|e| e.syntax().text_range() - pat.syntax().text_range().start()) .map(|r| syntax.covering_element(r)) .collect::>(); for extracted_syntax in extracted { // If `extracted` variable is a record field, we should rename it to `binding`, // otherwise we just need to replace `extracted` with `binding`. if let Some(record_pat_field) = extracted_syntax.ancestors().find_map(ast::RecordPatField::cast) { if let Some(name_ref) = record_pat_field.field_name() { editor.replace( record_pat_field.syntax(), make.record_pat_field( make.name_ref(&name_ref.text()), binding.clone_for_update(), ) .syntax(), ); } } else { editor.replace(extracted_syntax, binding.syntax().clone_for_update()); } } editor.add_mappings(make.finish_with_mappings()); let new_node = editor.finish().new_root().clone(); if let Some(pat) = ast::Pat::cast(new_node.clone()) { pat.dedent(1.into()).syntax().clone() } else { new_node } } #[cfg(test)] mod tests { use crate::tests::{check_assist, check_assist_not_applicable}; use super::*; #[test] fn should_not_be_applicable_for_non_diverging_match() { cov_mark::check!(non_diverging_match); check_assist_not_applicable( convert_match_to_let_else, r#" //- minicore: option fn foo(opt: Option<()>) { let val$0 = match opt { Some(it) => it, None => (), }; } "#, ); } #[test] fn or_pattern_multiple_binding() { check_assist( convert_match_to_let_else, r#" //- minicore: option enum Foo { A(u32), B(u32), C(String), } fn foo(opt: Option) -> Result { let va$0lue = match opt { Some(Foo::A(it) | Foo::B(it)) => it, _ => return Err(()), }; } "#, r#" enum Foo { A(u32), B(u32), C(String), } fn foo(opt: Option) -> Result { let Some(Foo::A(value) | Foo::B(value)) = opt else { return Err(()) }; } "#, ); } #[test] fn indent_level() { check_assist( convert_match_to_let_else, r#" //- minicore: option enum Foo { A(u32), B(u32), C(String), } fn foo(opt: Option) -> Result { let mut state = 2; let va$0lue = match opt { Some( Foo::A(it) | Foo::B(it) ) => it, _ => { state = 3; return Err(()) }, }; } "#, r#" enum Foo { A(u32), B(u32), C(String), } fn foo(opt: Option) -> Result { let mut state = 2; let Some( Foo::A(value) | Foo::B(value) ) = opt else { state = 3; return Err(()) }; } "#, ); } #[test] fn should_not_be_applicable_if_extracting_arm_is_not_an_identity_expr() { cov_mark::check_count!(extracting_arm_is_not_an_identity_expr, 2); check_assist_not_applicable( convert_match_to_let_else, r#" //- minicore: option fn foo(opt: Option) { let val$0 = match opt { Some(it) => it + 1, None => return, }; } "#, ); check_assist_not_applicable( convert_match_to_let_else, r#" //- minicore: option fn foo(opt: Option<()>) { let val$0 = match opt { Some(it) => { let _ = 1 + 1; it }, None => return, }; } "#, ); } #[test] fn should_not_be_applicable_if_extracting_arm_has_guard() { cov_mark::check!(extracting_arm_has_guard); check_assist_not_applicable( convert_match_to_let_else, r#" //- minicore: option fn foo(opt: Option<()>) { let val$0 = match opt { Some(it) if 2 > 1 => it, None => return, }; } "#, ); } #[test] fn basic_pattern() { check_assist( convert_match_to_let_else, r#" //- minicore: option fn foo(opt: Option<()>) { let val$0 = match opt { Some(it) => it, None => return, }; } "#, r#" fn foo(opt: Option<()>) { let Some(val) = opt else { return }; } "#, ); } #[test] fn keeps_modifiers() { check_assist( convert_match_to_let_else, r#" //- minicore: option fn foo(opt: Option<()>) { let ref mut val$0 = match opt { Some(it) => it, None => return, }; } "#, r#" fn foo(opt: Option<()>) { let Some(ref mut val) = opt else { return }; } "#, ); } #[test] fn nested_pattern() { check_assist( convert_match_to_let_else, r#" //- minicore: option, result fn foo(opt: Option>) { let val$0 = match opt { Some(Ok(it)) => it, _ => return, }; } "#, r#" fn foo(opt: Option>) { let Some(Ok(val)) = opt else { return }; } "#, ); } #[test] fn works_with_any_diverging_block() { check_assist( convert_match_to_let_else, r#" //- minicore: option fn foo(opt: Option<()>) { loop { let val$0 = match opt { Some(it) => it, None => break, }; } } "#, r#" fn foo(opt: Option<()>) { loop { let Some(val) = opt else { break }; } } "#, ); check_assist( convert_match_to_let_else, r#" //- minicore: option fn foo(opt: Option<()>) { loop { let val$0 = match opt { Some(it) => it, None => continue, }; } } "#, r#" fn foo(opt: Option<()>) { loop { let Some(val) = opt else { continue }; } } "#, ); check_assist( convert_match_to_let_else, r#" //- minicore: option fn panic() -> ! {} fn foo(opt: Option<()>) { loop { let val$0 = match opt { Some(it) => it, None => panic(), }; } } "#, r#" fn panic() -> ! {} fn foo(opt: Option<()>) { loop { let Some(val) = opt else { panic() }; } } "#, ); } #[test] fn struct_pattern() { check_assist( convert_match_to_let_else, r#" //- minicore: option struct Point { x: i32, y: i32, } fn foo(opt: Option) { let val$0 = match opt { Some(Point { x: 0, y }) => y, _ => return, }; } "#, r#" struct Point { x: i32, y: i32, } fn foo(opt: Option) { let Some(Point { x: 0, y: val }) = opt else { return }; } "#, ); } #[test] fn renames_whole_binding() { check_assist( convert_match_to_let_else, r#" //- minicore: option fn foo(opt: Option) -> Option { let val$0 = match opt { it @ Some(42) => it, _ => return None, }; val } "#, r#" fn foo(opt: Option) -> Option { let val @ Some(42) = opt else { return None }; val } "#, ); } #[test] fn complex_pattern() { check_assist( convert_match_to_let_else, r#" //- minicore: option fn f() { let (x, y)$0 = match Some((0, 1)) { Some(it) => it, None => return, }; } "#, r#" fn f() { let Some((x, y)) = Some((0, 1)) else { return }; } "#, ); } #[test] fn diverging_block() { check_assist( convert_match_to_let_else, r#" //- minicore: option fn f() { let x$0 = match Some(()) { Some(it) => it, None => {//comment println!("nope"); return }, }; } "#, r#" fn f() { let Some(x) = Some(()) else {//comment println!("nope"); return }; } "#, ); } }