Unnamed repository; edit this file 'description' to name the repository.
Diffstat (limited to 'crates/syntax/src/syntax_editor/edits.rs')
-rw-r--r--crates/syntax/src/syntax_editor/edits.rs105
1 files changed, 103 insertions, 2 deletions
diff --git a/crates/syntax/src/syntax_editor/edits.rs b/crates/syntax/src/syntax_editor/edits.rs
index 9090f7c9eb..44f0a8038e 100644
--- a/crates/syntax/src/syntax_editor/edits.rs
+++ b/crates/syntax/src/syntax_editor/edits.rs
@@ -10,6 +10,107 @@ use crate::{
syntax_editor::{Position, SyntaxEditor},
};
+pub trait GetOrCreateWhereClause: ast::HasGenericParams {
+ fn where_clause_position(&self) -> Option<Position>;
+
+ fn get_or_create_where_clause(
+ &self,
+ editor: &mut SyntaxEditor,
+ make: &SyntaxFactory,
+ new_preds: impl Iterator<Item = ast::WherePred>,
+ ) {
+ let existing = self.where_clause();
+ let all_preds: Vec<_> =
+ existing.iter().flat_map(|wc| wc.predicates()).chain(new_preds).collect();
+ let new_where = make.where_clause(all_preds);
+
+ if let Some(existing) = &existing {
+ editor.replace(existing.syntax(), new_where.syntax());
+ } else if let Some(pos) = self.where_clause_position() {
+ editor.insert_all(
+ pos,
+ vec![make.whitespace(" ").into(), new_where.syntax().clone().into()],
+ );
+ }
+ }
+}
+
+impl GetOrCreateWhereClause for ast::Fn {
+ fn where_clause_position(&self) -> Option<Position> {
+ if let Some(ty) = self.ret_type() {
+ Some(Position::after(ty.syntax()))
+ } else if let Some(param_list) = self.param_list() {
+ Some(Position::after(param_list.syntax()))
+ } else {
+ Some(Position::last_child_of(self.syntax()))
+ }
+ }
+}
+
+impl GetOrCreateWhereClause for ast::Impl {
+ fn where_clause_position(&self) -> Option<Position> {
+ if let Some(ty) = self.self_ty() {
+ Some(Position::after(ty.syntax()))
+ } else {
+ Some(Position::last_child_of(self.syntax()))
+ }
+ }
+}
+
+impl GetOrCreateWhereClause for ast::Trait {
+ fn where_clause_position(&self) -> Option<Position> {
+ if let Some(gpl) = self.generic_param_list() {
+ Some(Position::after(gpl.syntax()))
+ } else if let Some(name) = self.name() {
+ Some(Position::after(name.syntax()))
+ } else {
+ Some(Position::last_child_of(self.syntax()))
+ }
+ }
+}
+
+impl GetOrCreateWhereClause for ast::TypeAlias {
+ fn where_clause_position(&self) -> Option<Position> {
+ if let Some(gpl) = self.generic_param_list() {
+ Some(Position::after(gpl.syntax()))
+ } else if let Some(name) = self.name() {
+ Some(Position::after(name.syntax()))
+ } else {
+ Some(Position::last_child_of(self.syntax()))
+ }
+ }
+}
+
+impl GetOrCreateWhereClause for ast::Struct {
+ fn where_clause_position(&self) -> Option<Position> {
+ let tfl = self.field_list().and_then(|fl| match fl {
+ ast::FieldList::RecordFieldList(_) => None,
+ ast::FieldList::TupleFieldList(it) => Some(it),
+ });
+ if let Some(tfl) = tfl {
+ Some(Position::after(tfl.syntax()))
+ } else if let Some(gpl) = self.generic_param_list() {
+ Some(Position::after(gpl.syntax()))
+ } else if let Some(name) = self.name() {
+ Some(Position::after(name.syntax()))
+ } else {
+ Some(Position::last_child_of(self.syntax()))
+ }
+ }
+}
+
+impl GetOrCreateWhereClause for ast::Enum {
+ fn where_clause_position(&self) -> Option<Position> {
+ if let Some(gpl) = self.generic_param_list() {
+ Some(Position::after(gpl.syntax()))
+ } else if let Some(name) = self.name() {
+ Some(Position::after(name.syntax()))
+ } else {
+ Some(Position::last_child_of(self.syntax()))
+ }
+ }
+}
+
impl SyntaxEditor {
/// Adds a new generic param to the function using `SyntaxEditor`
pub fn add_generic_param(&mut self, function: &Fn, new_param: GenericParam) {
@@ -109,7 +210,7 @@ impl ast::AssocItemList {
normalize_ws_between_braces(editor, self.syntax());
(IndentLevel::from_token(&l_curly) + 1, Position::after(&l_curly), "\n")
}
- None => (IndentLevel::single(), Position::last_child_of(self.syntax()), "\n"),
+ None => (IndentLevel::zero(), Position::last_child_of(self.syntax()), "\n"),
},
};
@@ -141,7 +242,7 @@ impl ast::VariantList {
normalize_ws_between_braces(editor, self.syntax());
(IndentLevel::from_token(&l_curly) + 1, Position::after(&l_curly))
}
- None => (IndentLevel::single(), Position::last_child_of(self.syntax())),
+ None => (IndentLevel::zero(), Position::last_child_of(self.syntax())),
},
};
let elements: Vec<SyntaxElement> = vec![