use either::Either; use syntax::{ AstNode, algo::find_node_at_range, ast::{self, syntax_factory::SyntaxFactory}, syntax_editor::SyntaxEditor, }; use crate::{ AssistId, assist_context::{AssistContext, Assists}, }; // Assist: pull_assignment_up // // Extracts variable assignment to outside an if or match statement. // // ``` // fn main() { // let mut foo = 6; // // if true { // $0foo = 5; // } else { // foo = 4; // } // } // ``` // -> // ``` // fn main() { // let mut foo = 6; // // foo = if true { // 5 // } else { // 4 // }; // } // ``` pub(crate) fn pull_assignment_up(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { let assign_expr = ctx.find_node_at_offset::()?; let op_kind = assign_expr.op_kind()?; if op_kind != (ast::BinaryOp::Assignment { op: None }) { cov_mark::hit!(test_cant_pull_non_assignments); return None; } let mut collector = AssignmentsCollector { sema: &ctx.sema, common_lhs: assign_expr.lhs()?, assignments: Vec::new(), }; let node: Either = ctx.find_node_at_offset()?; let tgt: ast::Expr = if let Either::Left(if_expr) = node { let if_expr = std::iter::successors(Some(if_expr), |it| { it.syntax().parent().and_then(ast::IfExpr::cast) }) .last()?; collector.collect_if(&if_expr)?; if_expr.into() } else if let Either::Right(match_expr) = node { collector.collect_match(&match_expr)?; match_expr.into() } else { return None; }; if let Some(parent) = tgt.syntax().parent() && matches!(parent.kind(), syntax::SyntaxKind::BIN_EXPR | syntax::SyntaxKind::LET_STMT) { return None; } let target = tgt.syntax().text_range(); let edit_tgt = tgt.syntax().clone_subtree(); let assignments: Vec<_> = collector .assignments .into_iter() .filter_map(|(stmt, rhs)| { Some(( find_node_at_range::( &edit_tgt, stmt.syntax().text_range() - target.start(), )?, find_node_at_range::( &edit_tgt, rhs.syntax().text_range() - target.start(), )?, )) }) .collect(); let mut editor = SyntaxEditor::new(edit_tgt); for (stmt, rhs) in assignments { let mut stmt = stmt.syntax().clone(); if let Some(parent) = stmt.parent() && ast::ExprStmt::cast(parent.clone()).is_some() { stmt = parent.clone(); } editor.replace(stmt, rhs.syntax()); } let new_tgt_root = editor.finish().new_root().clone(); let new_tgt = ast::Expr::cast(new_tgt_root)?; acc.add( AssistId::refactor_extract("pull_assignment_up"), "Pull assignment up", target, move |edit| { let make = SyntaxFactory::with_mappings(); let mut editor = edit.make_editor(tgt.syntax()); let assign_expr = make.expr_assignment(collector.common_lhs, new_tgt.clone()); let assign_stmt = make.expr_stmt(assign_expr.into()); editor.replace(tgt.syntax(), assign_stmt.syntax()); editor.add_mappings(make.finish_with_mappings()); edit.add_file_edits(ctx.vfs_file_id(), editor); }, ) } struct AssignmentsCollector<'a> { sema: &'a hir::Semantics<'a, ide_db::RootDatabase>, common_lhs: ast::Expr, assignments: Vec<(ast::BinExpr, ast::Expr)>, } impl AssignmentsCollector<'_> { fn collect_match(&mut self, match_expr: &ast::MatchExpr) -> Option<()> { for arm in match_expr.match_arm_list()?.arms() { match arm.expr()? { ast::Expr::BlockExpr(block) => self.collect_block(&block)?, ast::Expr::BinExpr(expr) => self.collect_expr(&expr)?, _ => return None, } } Some(()) } fn collect_if(&mut self, if_expr: &ast::IfExpr) -> Option<()> { let then_branch = if_expr.then_branch()?; self.collect_block(&then_branch)?; match if_expr.else_branch()? { ast::ElseBranch::Block(block) => self.collect_block(&block), ast::ElseBranch::IfExpr(expr) => { cov_mark::hit!(test_pull_assignment_up_chained_if); self.collect_if(&expr) } } } fn collect_block(&mut self, block: &ast::BlockExpr) -> Option<()> { let last_expr = block.tail_expr().or_else(|| match block.statements().last()? { ast::Stmt::ExprStmt(stmt) => stmt.expr(), ast::Stmt::Item(_) | ast::Stmt::LetStmt(_) => None, })?; if let ast::Expr::BinExpr(expr) = last_expr { return self.collect_expr(&expr); } None } fn collect_expr(&mut self, expr: &ast::BinExpr) -> Option<()> { if expr.op_kind()? == (ast::BinaryOp::Assignment { op: None }) && is_equivalent(self.sema, &expr.lhs()?, &self.common_lhs) { self.assignments.push((expr.clone(), expr.rhs()?)); return Some(()); } None } } fn is_equivalent( sema: &hir::Semantics<'_, ide_db::RootDatabase>, expr0: &ast::Expr, expr1: &ast::Expr, ) -> bool { match (expr0, expr1) { (ast::Expr::FieldExpr(field_expr0), ast::Expr::FieldExpr(field_expr1)) => { cov_mark::hit!(test_pull_assignment_up_field_assignment); sema.resolve_field(field_expr0) == sema.resolve_field(field_expr1) } (ast::Expr::PathExpr(path0), ast::Expr::PathExpr(path1)) => { let path0 = path0.path(); let path1 = path1.path(); if let (Some(path0), Some(path1)) = (path0, path1) { sema.resolve_path(&path0) == sema.resolve_path(&path1) } else { false } } (ast::Expr::PrefixExpr(prefix0), ast::Expr::PrefixExpr(prefix1)) if prefix0.op_kind() == Some(ast::UnaryOp::Deref) && prefix1.op_kind() == Some(ast::UnaryOp::Deref) => { cov_mark::hit!(test_pull_assignment_up_deref); if let (Some(prefix0), Some(prefix1)) = (prefix0.expr(), prefix1.expr()) { is_equivalent(sema, &prefix0, &prefix1) } else { false } } _ => false, } } #[cfg(test)] mod tests { use super::*; use crate::tests::{check_assist, check_assist_not_applicable}; #[test] fn test_pull_assignment_up_if() { check_assist( pull_assignment_up, r#" fn foo() { let mut a = 1; if true { $0a = 2; } else { a = 3; } }"#, r#" fn foo() { let mut a = 1; a = if true { 2 } else { 3 }; }"#, ); } #[test] fn test_pull_assignment_up_inner_if() { check_assist( pull_assignment_up, r#" fn foo() { let mut a = 1; if true { a = 2; } else if true { $0a = 3; } else { a = 4; } }"#, r#" fn foo() { let mut a = 1; a = if true { 2 } else if true { 3 } else { 4 }; }"#, ); } #[test] fn test_pull_assignment_up_match() { check_assist( pull_assignment_up, r#" fn foo() { let mut a = 1; match 1 { 1 => { $0a = 2; }, 2 => { a = 3; }, 3 => { a = 4; } } }"#, r#" fn foo() { let mut a = 1; a = match 1 { 1 => { 2 }, 2 => { 3 }, 3 => { 4 } }; }"#, ); } #[test] fn test_pull_assignment_up_match_in_if_expr() { check_assist( pull_assignment_up, r#" fn foo() { let x; if true { match true { true => $0x = 2, false => x = 3, } } }"#, r#" fn foo() { let x; if true { x = match true { true => 2, false => 3, }; } }"#, ); } #[test] fn test_pull_assignment_up_assignment_expressions() { check_assist( pull_assignment_up, r#" fn foo() { let mut a = 1; match 1 { 1 => { $0a = 2; }, 2 => a = 3, 3 => { a = 4 } } }"#, r#" fn foo() { let mut a = 1; a = match 1 { 1 => { 2 }, 2 => 3, 3 => { 4 } }; }"#, ); } #[test] fn test_pull_assignment_up_not_last_not_applicable() { check_assist_not_applicable( pull_assignment_up, r#" fn foo() { let mut a = 1; if true { $0a = 2; b = a; } else { a = 3; } }"#, ) } #[test] fn test_pull_assignment_up_chained_if() { cov_mark::check!(test_pull_assignment_up_chained_if); check_assist( pull_assignment_up, r#" fn foo() { let mut a = 1; if true { $0a = 2; } else if false { a = 3; } else { a = 4; } }"#, r#" fn foo() { let mut a = 1; a = if true { 2 } else if false { 3 } else { 4 }; }"#, ); } #[test] fn test_pull_assignment_up_retains_stmts() { check_assist( pull_assignment_up, r#" fn foo() { let mut a = 1; if true { let b = 2; $0a = 2; } else { let b = 3; a = 3; } }"#, r#" fn foo() { let mut a = 1; a = if true { let b = 2; 2 } else { let b = 3; 3 }; }"#, ) } #[test] fn pull_assignment_up_let_stmt_not_applicable() { check_assist_not_applicable( pull_assignment_up, r#" fn foo() { let mut a = 1; let b = if true { $0a = 2 } else { a = 3 }; }"#, ) } #[test] fn pull_assignment_up_if_missing_assignment_not_applicable() { check_assist_not_applicable( pull_assignment_up, r#" fn foo() { let mut a = 1; if true { $0a = 2; } else {} }"#, ) } #[test] fn pull_assignment_up_match_missing_assignment_not_applicable() { check_assist_not_applicable( pull_assignment_up, r#" fn foo() { let mut a = 1; match 1 { 1 => { $0a = 2; }, 2 => { a = 3; }, 3 => {}, } }"#, ) } #[test] fn test_pull_assignment_up_field_assignment() { cov_mark::check!(test_pull_assignment_up_field_assignment); check_assist( pull_assignment_up, r#" struct A(usize); fn foo() { let mut a = A(1); if true { $0a.0 = 2; } else { a.0 = 3; } }"#, r#" struct A(usize); fn foo() { let mut a = A(1); a.0 = if true { 2 } else { 3 }; }"#, ) } #[test] fn test_pull_assignment_up_deref() { cov_mark::check!(test_pull_assignment_up_deref); check_assist( pull_assignment_up, r#" fn foo() { let mut a = 1; let b = &mut a; if true { $0*b = 2; } else { *b = 3; } } "#, r#" fn foo() { let mut a = 1; let b = &mut a; *b = if true { 2 } else { 3 }; } "#, ) } #[test] fn test_cant_pull_non_assignments() { cov_mark::check!(test_cant_pull_non_assignments); check_assist_not_applicable( pull_assignment_up, r#" fn foo() { let mut a = 1; let b = &mut a; if true { $0*b + 2; } else { *b + 3; } } "#, ) } }