Unnamed repository; edit this file 'description' to name the repository.
-rw-r--r--crates/syntax/src/ast/edit_in_place.rs171
-rw-r--r--crates/syntax/src/ast/make.rs5
2 files changed, 148 insertions, 28 deletions
diff --git a/crates/syntax/src/ast/edit_in_place.rs b/crates/syntax/src/ast/edit_in_place.rs
index e3e928aecd..8efd58e2c3 100644
--- a/crates/syntax/src/ast/edit_in_place.rs
+++ b/crates/syntax/src/ast/edit_in_place.rs
@@ -11,7 +11,7 @@ use crate::{
ted::{self, Position},
AstNode, AstToken, Direction,
SyntaxKind::{ATTR, COMMENT, WHITESPACE},
- SyntaxNode,
+ SyntaxNode, SyntaxToken,
};
use super::HasName;
@@ -506,19 +506,7 @@ impl ast::RecordExprFieldList {
let position = match self.fields().last() {
Some(last_field) => {
- let comma = match last_field
- .syntax()
- .siblings_with_tokens(Direction::Next)
- .filter_map(|it| it.into_token())
- .find(|it| it.kind() == T![,])
- {
- Some(it) => it,
- None => {
- let comma = ast::make::token(T![,]);
- ted::insert(Position::after(last_field.syntax()), &comma);
- comma
- }
- };
+ let comma = get_or_insert_comma_after(last_field.syntax());
Position::after(comma)
}
None => match self.l_curly_token() {
@@ -579,19 +567,8 @@ impl ast::RecordPatFieldList {
let position = match self.fields().last() {
Some(last_field) => {
- let comma = match last_field
- .syntax()
- .siblings_with_tokens(Direction::Next)
- .filter_map(|it| it.into_token())
- .find(|it| it.kind() == T![,])
- {
- Some(it) => it,
- None => {
- let comma = ast::make::token(T![,]);
- ted::insert(Position::after(last_field.syntax()), &comma);
- comma
- }
- };
+ let syntax = last_field.syntax();
+ let comma = get_or_insert_comma_after(syntax);
Position::after(comma)
}
None => match self.l_curly_token() {
@@ -606,12 +583,53 @@ impl ast::RecordPatFieldList {
}
}
}
+
+fn get_or_insert_comma_after(syntax: &SyntaxNode) -> SyntaxToken {
+ let comma = match syntax
+ .siblings_with_tokens(Direction::Next)
+ .filter_map(|it| it.into_token())
+ .find(|it| it.kind() == T![,])
+ {
+ Some(it) => it,
+ None => {
+ let comma = ast::make::token(T![,]);
+ ted::insert(Position::after(syntax), &comma);
+ comma
+ }
+ };
+ comma
+}
+
impl ast::StmtList {
pub fn push_front(&self, statement: ast::Stmt) {
ted::insert(Position::after(self.l_curly_token().unwrap()), statement.syntax());
}
}
+impl ast::VariantList {
+ pub fn add_variant(&self, variant: ast::Variant) {
+ let (indent, position) = match self.variants().last() {
+ Some(last_item) => (
+ IndentLevel::from_node(last_item.syntax()),
+ Position::after(get_or_insert_comma_after(last_item.syntax())),
+ ),
+ None => match self.l_curly_token() {
+ Some(l_curly) => {
+ normalize_ws_between_braces(self.syntax());
+ (IndentLevel::from_token(&l_curly) + 1, Position::after(&l_curly))
+ }
+ None => (IndentLevel::single(), Position::last_child_of(self.syntax())),
+ },
+ };
+ let elements: Vec<SyntaxElement<_>> = vec![
+ make::tokens::whitespace(&format!("{}{}", "\n", indent)).into(),
+ variant.syntax().clone().into(),
+ ast::make::token(T![,]).into(),
+ ];
+ ted::insert_all(position, elements);
+ }
+}
+
fn normalize_ws_between_braces(node: &SyntaxNode) -> Option<()> {
let l = node
.children_with_tokens()
@@ -661,6 +679,9 @@ impl<N: AstNode + Clone> Indent for N {}
mod tests {
use std::fmt;
+ use stdx::trim_indent;
+ use test_utils::assert_eq_text;
+
use crate::SourceFile;
use super::*;
@@ -714,4 +735,100 @@ mod tests {
}",
);
}
+
+ #[test]
+ fn add_variant_to_empty_enum() {
+ let variant = make::variant(make::name("Bar"), None).clone_for_update();
+
+ check_add_variant(
+ r#"
+enum Foo {}
+"#,
+ r#"
+enum Foo {
+ Bar,
+}
+"#,
+ variant,
+ );
+ }
+
+ #[test]
+ fn add_variant_to_non_empty_enum() {
+ let variant = make::variant(make::name("Baz"), None).clone_for_update();
+
+ check_add_variant(
+ r#"
+enum Foo {
+ Bar,
+}
+"#,
+ r#"
+enum Foo {
+ Bar,
+ Baz,
+}
+"#,
+ variant,
+ );
+ }
+
+ #[test]
+ fn add_variant_with_tuple_field_list() {
+ let variant = make::variant(
+ make::name("Baz"),
+ Some(ast::FieldList::TupleFieldList(make::tuple_field_list(std::iter::once(
+ make::tuple_field(None, make::ty("bool")),
+ )))),
+ )
+ .clone_for_update();
+
+ check_add_variant(
+ r#"
+enum Foo {
+ Bar,
+}
+"#,
+ r#"
+enum Foo {
+ Bar,
+ Baz(bool),
+}
+"#,
+ variant,
+ );
+ }
+
+ #[test]
+ fn add_variant_with_record_field_list() {
+ let variant = make::variant(
+ make::name("Baz"),
+ Some(ast::FieldList::RecordFieldList(make::record_field_list(std::iter::once(
+ make::record_field(None, make::name("x"), make::ty("bool")),
+ )))),
+ )
+ .clone_for_update();
+
+ check_add_variant(
+ r#"
+enum Foo {
+ Bar,
+}
+"#,
+ r#"
+enum Foo {
+ Bar,
+ Baz { x: bool },
+}
+"#,
+ variant,
+ );
+ }
+
+ fn check_add_variant(before: &str, expected: &str, variant: ast::Variant) {
+ let enum_ = ast_mut_from_text::<ast::Enum>(before);
+ enum_.variant_list().map(|it| it.add_variant(variant));
+ let after = enum_.to_string();
+ assert_eq_text!(&trim_indent(expected.trim()), &trim_indent(&after.trim()));
+ }
}
diff --git a/crates/syntax/src/ast/make.rs b/crates/syntax/src/ast/make.rs
index 5908dda8e6..037de876d4 100644
--- a/crates/syntax/src/ast/make.rs
+++ b/crates/syntax/src/ast/make.rs
@@ -745,7 +745,10 @@ pub fn tuple_field(visibility: Option<ast::Visibility>, ty: ast::Type) -> ast::T
pub fn variant(name: ast::Name, field_list: Option<ast::FieldList>) -> ast::Variant {
let field_list = match field_list {
None => String::new(),
- Some(it) => format!("{}", it),
+ Some(it) => match it {
+ ast::FieldList::RecordFieldList(record) => format!(" {}", record),
+ ast::FieldList::TupleFieldList(tuple) => format!("{}", tuple),
+ },
};
ast_from_text(&format!("enum f {{ {}{} }}", name, field_list))
}