Unnamed repository; edit this file 'description' to name the repository.
Merge pull request #21744 from A4-Tacks/add-match-arms-inc-edit
fix: keep comments for 'Fill match arms'
6 files changed, 153 insertions, 47 deletions
diff --git a/crates/ide-assists/src/handlers/add_missing_match_arms.rs b/crates/ide-assists/src/handlers/add_missing_match_arms.rs index 9571270758..1749db1e61 100644 --- a/crates/ide-assists/src/handlers/add_missing_match_arms.rs +++ b/crates/ide-assists/src/handlers/add_missing_match_arms.rs @@ -6,10 +6,11 @@ use ide_db::RootDatabase; use ide_db::syntax_helpers::suggest_name; use ide_db::{famous_defs::FamousDefs, helpers::mod_path_to_ast}; use itertools::Itertools; -use syntax::ToSmolStr; -use syntax::ast::edit::{AstNodeEdit, IndentLevel}; +use syntax::ast::edit::IndentLevel; use syntax::ast::syntax_factory::SyntaxFactory; use syntax::ast::{self, AstNode, MatchArmList, MatchExpr, Pat, make}; +use syntax::syntax_editor::{Position, SyntaxEditor}; +use syntax::{SyntaxKind, SyntaxNode, ToSmolStr}; use crate::{AssistContext, AssistId, Assists, utils}; @@ -223,32 +224,13 @@ pub(crate) fn add_missing_match_arms(acc: &mut Assists, ctx: &AssistContext<'_>) // having any hidden variants means that we need a catch-all arm needs_catch_all_arm |= has_hidden_variants; - let missing_arms = missing_pats + let mut missing_arms = missing_pats .filter(|(_, hidden)| { // filter out hidden patterns because they're handled by the catch-all arm !hidden }) - .map(|(pat, _)| make.match_arm(pat, None, utils::expr_fill_default(ctx.config))); - - let mut arms: Vec<_> = match_arm_list - .arms() - .filter(|arm| { - if matches!(arm.pat(), Some(ast::Pat::WildcardPat(_))) { - if arm.expr().is_none_or(is_empty_expr) { - false - } else { - cov_mark::hit!(add_missing_match_arms_empty_expr); - true - } - } else { - true - } - }) - .map(|arm| arm.reset_indent().indent(IndentLevel(1))) - .collect(); - - let first_new_arm_idx = arms.len(); - arms.extend(missing_arms); + .map(|(pat, _)| make.match_arm(pat, None, utils::expr_fill_default(ctx.config))) + .collect::<Vec<_>>(); if needs_catch_all_arm && !has_catch_all_arm { cov_mark::hit!(added_wildcard_pattern); @@ -257,13 +239,11 @@ pub(crate) fn add_missing_match_arms(acc: &mut Assists, ctx: &AssistContext<'_>) None, utils::expr_fill_default(ctx.config), ); - arms.push(arm); + missing_arms.push(arm); } - let new_match_arm_list = make.match_arm_list(arms); - // FIXME: Hack for syntax trees not having great support for macros - // Just replace the element that the original range came from + // Just edit the element that the original range came from let old_place = { // Find the original element let file = ctx.sema.parse(arm_list_range.file_id); @@ -280,25 +260,27 @@ pub(crate) fn add_missing_match_arms(acc: &mut Assists, ctx: &AssistContext<'_>) }; let mut editor = builder.make_editor(&old_place); - let new_match_arm_list = new_match_arm_list.indent(IndentLevel::from_node(&old_place)); - editor.replace(old_place, new_match_arm_list.syntax()); + let mut arms_edit = ArmsEdit { match_arm_list, place: old_place, last_arm: None }; + + arms_edit.remove_wildcard_arms(ctx, &mut editor); + arms_edit.add_comma_after_last_arm(ctx, &make, &mut editor); + arms_edit.append_arms(&missing_arms, &make, &mut editor); if let Some(cap) = ctx.config.snippet_cap { - if let Some(it) = new_match_arm_list - .arms() - .nth(first_new_arm_idx) + if let Some(it) = missing_arms + .first() .and_then(|arm| arm.syntax().descendants().find_map(ast::WildcardPat::cast)) { editor.add_annotation(it.syntax(), builder.make_placeholder_snippet(cap)); } - for arm in new_match_arm_list.arms().skip(first_new_arm_idx) { + for arm in &missing_arms { if let Some(expr) = arm.expr() { editor.add_annotation(expr.syntax(), builder.make_placeholder_snippet(cap)); } } - if let Some(arm) = new_match_arm_list.arms().skip(first_new_arm_idx).last() { + if let Some(arm) = missing_arms.last() { editor.add_annotation(arm.syntax(), builder.make_tabstop_after(cap)); } } @@ -357,6 +339,101 @@ fn cursor_at_trivial_match_arm_list( None } +struct ArmsEdit { + match_arm_list: MatchArmList, + place: SyntaxNode, + last_arm: Option<ast::MatchArm>, +} + +impl ArmsEdit { + fn remove_wildcard_arms(&mut self, ctx: &AssistContext<'_>, editor: &mut SyntaxEditor) { + for arm in self.match_arm_list.arms() { + if !matches!(arm.pat(), Some(Pat::WildcardPat(_))) { + self.last_arm = Some(arm); + continue; + } + if !arm.expr().is_none_or(is_empty_expr) { + cov_mark::hit!(add_missing_match_arms_empty_expr); + self.last_arm = Some(arm); + continue; + } + let Some(range) = self.cover_edit_range(ctx, &arm) else { continue }; + + let prev = match range.start() { + syntax::NodeOrToken::Node(node) => { + node.first_token().and_then(|it| it.prev_token()) + } + syntax::NodeOrToken::Token(tok) => tok.prev_token(), + }; + if let Some(prev) = prev + && prev.kind() == SyntaxKind::WHITESPACE + { + editor.delete(prev); + } + + editor.delete_all(range); + } + } + + fn append_arms(&self, arms: &[ast::MatchArm], make: &SyntaxFactory, editor: &mut SyntaxEditor) { + let Some(mut before) = self.place.last_token() else { + stdx::never!("match arm list not contain any token"); + return; + }; + if let Some(prev) = before.prev_token() + && prev.kind() == SyntaxKind::WHITESPACE + { + before = prev; + } + let open_curly = + !self.place.text().contains_char('\n') || before.kind() == SyntaxKind::WHITESPACE; + let indent = IndentLevel::from_node(&self.place); + let arm_indent = indent + 1; + let indent = make.whitespace(&format!("\n{indent}")); + let arm_indent = make.whitespace(&format!("\n{arm_indent}")); + let elements = arms + .iter() + .flat_map(|arm| [arm_indent.clone().into(), arm.syntax().clone().into()]) + .chain(open_curly.then(|| indent.clone().into())) + .collect(); + + if before.kind() == SyntaxKind::WHITESPACE { + editor.replace_with_many(before, elements); + } else { + editor.insert_all(Position::before(before), elements); + } + } + + fn add_comma_after_last_arm( + &self, + ctx: &AssistContext<'_>, + make: &SyntaxFactory, + editor: &mut SyntaxEditor, + ) { + if let Some(last_arm) = &self.last_arm + && last_arm.comma_token().is_none() + && last_arm.expr().is_none_or(|it| !it.is_block_like()) + && let Some(range) = self.cover_edit_range(ctx, last_arm) + { + editor.insert(Position::after(range.end()), make.token(syntax::T![,])); + } + } + + fn cover_edit_range( + &self, + ctx: &AssistContext<'_>, + node: &impl AstNode, + ) -> Option<std::ops::RangeInclusive<syntax::SyntaxElement>> { + let range = ctx.sema.original_range_opt(node.syntax())?; + + if !self.place.text_range().contains_range(range.range) { + return None; + } + + Some(utils::cover_edit_range(&self.place, range.range)) + } +} + fn is_variant_missing(existing_pats: &[Pat], var: &Pat) -> bool { !existing_pats.iter().any(|pat| does_pat_match_variant(pat, var)) } @@ -1734,7 +1811,7 @@ enum Test { fn foo(t: Test) { m!(match t { - Test::A=>(), + Test::A => (), Test::B => ${1:todo!()}, Test::C => ${2:todo!()},$0 }); @@ -2173,6 +2250,35 @@ fn foo(t: E) { } #[test] + fn keep_comments() { + check_assist( + add_missing_match_arms, + r#" +enum E { A, B, C } + +fn foo(t: E) -> i32 { + match $0t { + // variant a + E::A => 2 + // comment on end + } +}"#, + r#" +enum E { A, B, C } + +fn foo(t: E) -> i32 { + match t { + // variant a + E::A => 2, + // comment on end + E::B => ${1:todo!()}, + E::C => ${2:todo!()},$0 + } +}"#, + ); + } + + #[test] fn not_applicable_when_match_arm_list_cannot_be_upmapped() { check_assist_not_applicable( add_missing_match_arms, diff --git a/crates/ide-assists/src/handlers/convert_named_struct_to_tuple_struct.rs b/crates/ide-assists/src/handlers/convert_named_struct_to_tuple_struct.rs index 42fceb8533..4dd2036c02 100644 --- a/crates/ide-assists/src/handlers/convert_named_struct_to_tuple_struct.rs +++ b/crates/ide-assists/src/handlers/convert_named_struct_to_tuple_struct.rs @@ -242,7 +242,7 @@ where { let make = SyntaxFactory::without_mappings(); let orig = ctx.sema.original_range_opt(field_list.syntax())?; - let list_range = cover_edit_range(source, orig.range); + let list_range = cover_edit_range(source.syntax(), orig.range); let l_curly = match list_range.start() { NodeOrToken::Node(node) => node.first_token()?, @@ -265,7 +265,7 @@ where for name_ref in fields(&field_list) { let Some(orig) = ctx.sema.original_range_opt(name_ref.syntax()) else { continue }; - let name_range = cover_edit_range(source, orig.range); + let name_range = cover_edit_range(source.syntax(), orig.range); if let Some(colon) = next_non_trivia_token(name_range.end().clone()) && colon.kind() == T![:] @@ -306,7 +306,7 @@ fn edit_field_references( // Only edit the field reference if it's part of a `.field` access if name_ref.syntax().parent().and_then(ast::FieldExpr::cast).is_some() { edit.replace_all( - cover_edit_range(&source, r.range), + cover_edit_range(source.syntax(), r.range), vec![make.name_ref(&index.to_string()).syntax().clone().into()], ); } diff --git a/crates/ide-assists/src/handlers/convert_tuple_struct_to_named_struct.rs b/crates/ide-assists/src/handlers/convert_tuple_struct_to_named_struct.rs index f1eae83866..270467b14f 100644 --- a/crates/ide-assists/src/handlers/convert_tuple_struct_to_named_struct.rs +++ b/crates/ide-assists/src/handlers/convert_tuple_struct_to_named_struct.rs @@ -191,7 +191,7 @@ fn process_struct_name_reference( full_path, generate_record_pat_list(&tuple_struct_pat, names), ); - editor.replace_all(cover_edit_range(source, range), vec![new.syntax().clone().into()]); + editor.replace_all(cover_edit_range(source.syntax(), range), vec![new.syntax().clone().into()]); }, ast::PathExpr(path_expr) => { let call_expr = path_expr.syntax().parent().and_then(ast::CallExpr::cast)?; @@ -207,7 +207,7 @@ fn process_struct_name_reference( let mut first_insert = vec![]; for (expr, name) in arg_list.args().zip(names) { let range = ctx.sema.original_range_opt(expr.syntax())?.range; - let place = cover_edit_range(source, range); + let place = cover_edit_range(source.syntax(), range); let elements = vec![ make.name_ref(&name.text()).syntax().clone().into(), make.token(T![:]).into(), @@ -236,7 +236,7 @@ fn process_delimiter( first_insert: Vec<syntax::SyntaxElement>, ) { let Some(range) = ctx.sema.original_range_opt(list.syntax()) else { return }; - let place = cover_edit_range(source, range.range); + let place = cover_edit_range(source.syntax(), range.range); let l_paren = match place.start() { syntax::NodeOrToken::Node(node) => node.first_token(), @@ -290,7 +290,7 @@ fn edit_field_references( && let Some(original) = ctx.sema.original_range_opt(name_ref.syntax()) { editor.replace_all( - cover_edit_range(&source, original.range), + cover_edit_range(source.syntax(), original.range), vec![name.syntax().clone().into()], ); } diff --git a/crates/ide-assists/src/handlers/destructure_struct_binding.rs b/crates/ide-assists/src/handlers/destructure_struct_binding.rs index 3f42696fa3..ec4a83b642 100644 --- a/crates/ide-assists/src/handlers/destructure_struct_binding.rs +++ b/crates/ide-assists/src/handlers/destructure_struct_binding.rs @@ -358,7 +358,7 @@ fn update_usages( data: &StructEditData, field_names: &FxHashMap<SmolStr, SmolStr>, ) { - let source = ctx.source_file(); + let source = ctx.source_file().syntax(); let make = SyntaxFactory::with_mappings(); let edits = data .usages diff --git a/crates/ide-assists/src/handlers/destructure_tuple_binding.rs b/crates/ide-assists/src/handlers/destructure_tuple_binding.rs index 583ba42bf5..23c11b258c 100644 --- a/crates/ide-assists/src/handlers/destructure_tuple_binding.rs +++ b/crates/ide-assists/src/handlers/destructure_tuple_binding.rs @@ -326,7 +326,7 @@ impl EditTupleUsage { } EditTupleUsage::ReplaceExpr(target_expr, replace_with) => { if let Some(range) = ctx.sema.original_range_opt(target_expr.syntax()) { - let source = ctx.source_file(); + let source = ctx.source_file().syntax(); syntax_editor.replace_all( cover_edit_range(source, range.range), vec![replace_with.syntax().clone().into()], diff --git a/crates/ide-assists/src/utils.rs b/crates/ide-assists/src/utils.rs index 07811fb6f0..a85a89efb4 100644 --- a/crates/ide-assists/src/utils.rs +++ b/crates/ide-assists/src/utils.rs @@ -1416,10 +1416,10 @@ pub(crate) fn cover_let_chain(mut expr: ast::Expr, range: TextRange) -> Option<a } pub(crate) fn cover_edit_range( - source: &impl AstNode, + source: &SyntaxNode, range: TextRange, ) -> std::ops::RangeInclusive<syntax::SyntaxElement> { - let node = match source.syntax().covering_element(range) { + let node = match source.covering_element(range) { NodeOrToken::Node(node) => node, NodeOrToken::Token(t) => t.parent().unwrap(), }; |