Unnamed repository; edit this file 'description' to name the repository.
Diffstat (limited to 'crates/syntax/src/syntax_editor/edits.rs')
-rw-r--r--crates/syntax/src/syntax_editor/edits.rs101
1 files changed, 101 insertions, 0 deletions
diff --git a/crates/syntax/src/syntax_editor/edits.rs b/crates/syntax/src/syntax_editor/edits.rs
index 9090f7c9eb..ad08928923 100644
--- a/crates/syntax/src/syntax_editor/edits.rs
+++ b/crates/syntax/src/syntax_editor/edits.rs
@@ -10,6 +10,107 @@ use crate::{
syntax_editor::{Position, SyntaxEditor},
};
+pub trait GetOrCreateWhereClause: ast::HasGenericParams {
+ fn where_clause_position(&self) -> Option<Position>;
+
+ fn get_or_create_where_clause(
+ &self,
+ editor: &mut SyntaxEditor,
+ make: &SyntaxFactory,
+ new_preds: impl Iterator<Item = ast::WherePred>,
+ ) {
+ let existing = self.where_clause();
+ let all_preds: Vec<_> =
+ existing.iter().flat_map(|wc| wc.predicates()).chain(new_preds).collect();
+ let new_where = make.where_clause(all_preds);
+
+ if let Some(existing) = &existing {
+ editor.replace(existing.syntax(), new_where.syntax());
+ } else if let Some(pos) = self.where_clause_position() {
+ editor.insert_all(
+ pos,
+ vec![make.whitespace(" ").into(), new_where.syntax().clone().into()],
+ );
+ }
+ }
+}
+
+impl GetOrCreateWhereClause for ast::Fn {
+ fn where_clause_position(&self) -> Option<Position> {
+ if let Some(ty) = self.ret_type() {
+ Some(Position::after(ty.syntax()))
+ } else if let Some(param_list) = self.param_list() {
+ Some(Position::after(param_list.syntax()))
+ } else {
+ Some(Position::last_child_of(self.syntax()))
+ }
+ }
+}
+
+impl GetOrCreateWhereClause for ast::Impl {
+ fn where_clause_position(&self) -> Option<Position> {
+ if let Some(ty) = self.self_ty() {
+ Some(Position::after(ty.syntax()))
+ } else {
+ Some(Position::last_child_of(self.syntax()))
+ }
+ }
+}
+
+impl GetOrCreateWhereClause for ast::Trait {
+ fn where_clause_position(&self) -> Option<Position> {
+ if let Some(gpl) = self.generic_param_list() {
+ Some(Position::after(gpl.syntax()))
+ } else if let Some(name) = self.name() {
+ Some(Position::after(name.syntax()))
+ } else {
+ Some(Position::last_child_of(self.syntax()))
+ }
+ }
+}
+
+impl GetOrCreateWhereClause for ast::TypeAlias {
+ fn where_clause_position(&self) -> Option<Position> {
+ if let Some(gpl) = self.generic_param_list() {
+ Some(Position::after(gpl.syntax()))
+ } else if let Some(name) = self.name() {
+ Some(Position::after(name.syntax()))
+ } else {
+ Some(Position::last_child_of(self.syntax()))
+ }
+ }
+}
+
+impl GetOrCreateWhereClause for ast::Struct {
+ fn where_clause_position(&self) -> Option<Position> {
+ let tfl = self.field_list().and_then(|fl| match fl {
+ ast::FieldList::RecordFieldList(_) => None,
+ ast::FieldList::TupleFieldList(it) => Some(it),
+ });
+ if let Some(tfl) = tfl {
+ Some(Position::after(tfl.syntax()))
+ } else if let Some(gpl) = self.generic_param_list() {
+ Some(Position::after(gpl.syntax()))
+ } else if let Some(name) = self.name() {
+ Some(Position::after(name.syntax()))
+ } else {
+ Some(Position::last_child_of(self.syntax()))
+ }
+ }
+}
+
+impl GetOrCreateWhereClause for ast::Enum {
+ fn where_clause_position(&self) -> Option<Position> {
+ if let Some(gpl) = self.generic_param_list() {
+ Some(Position::after(gpl.syntax()))
+ } else if let Some(name) = self.name() {
+ Some(Position::after(name.syntax()))
+ } else {
+ Some(Position::last_child_of(self.syntax()))
+ }
+ }
+}
+
impl SyntaxEditor {
/// Adds a new generic param to the function using `SyntaxEditor`
pub fn add_generic_param(&mut self, function: &Fn, new_param: GenericParam) {