Unnamed repository; edit this file 'description' to name the repository.
-rw-r--r--crates/ide-assists/src/handlers/move_bounds.rs63
-rw-r--r--crates/syntax/src/syntax_editor.rs2
-rw-r--r--crates/syntax/src/syntax_editor/edits.rs101
3 files changed, 112 insertions, 54 deletions
diff --git a/crates/ide-assists/src/handlers/move_bounds.rs b/crates/ide-assists/src/handlers/move_bounds.rs
index 01a46c3334..79b8bd5d3d 100644
--- a/crates/ide-assists/src/handlers/move_bounds.rs
+++ b/crates/ide-assists/src/handlers/move_bounds.rs
@@ -1,8 +1,8 @@
use either::Either;
use syntax::{
- ast::{self, AstNode, HasGenericParams, HasName, HasTypeBounds, syntax_factory::SyntaxFactory},
+ ast::{self, AstNode, HasName, HasTypeBounds, syntax_factory::SyntaxFactory},
match_ast,
- syntax_editor::{Position, Removable},
+ syntax_editor::{GetOrCreateWhereClause, Removable},
};
use crate::{AssistContext, AssistId, Assists};
@@ -53,61 +53,18 @@ pub(crate) fn move_bounds_to_where_clause(
.filter_map(|param| build_predicate(param, &make))
.collect();
- let existing_where: Option<ast::WhereClause> = match_ast! {
+ match_ast! {
match (&parent) {
- ast::Fn(it) => it.where_clause(),
- ast::Trait(it) => it.where_clause(),
- ast::Impl(it) => it.where_clause(),
- ast::Enum(it) => it.where_clause(),
- ast::Struct(it) => it.where_clause(),
- ast::TypeAlias(it) => it.where_clause(),
- _ => None,
+ ast::Fn(it) => it.get_or_create_where_clause(&mut edit, &make, new_preds.into_iter()),
+ ast::Trait(it) => it.get_or_create_where_clause(&mut edit, &make, new_preds.into_iter()),
+ ast::Impl(it) => it.get_or_create_where_clause(&mut edit, &make, new_preds.into_iter()),
+ ast::Enum(it) => it.get_or_create_where_clause(&mut edit, &make, new_preds.into_iter()),
+ ast::Struct(it) => it.get_or_create_where_clause(&mut edit, &make, new_preds.into_iter()),
+ ast::TypeAlias(it) => it.get_or_create_where_clause(&mut edit, &make, new_preds.into_iter()),
+ _ => return,
}
};
- let all_preds = existing_where.iter().flat_map(|wc| wc.predicates()).chain(new_preds);
- let new_where = make.where_clause(all_preds);
-
- if let Some(existing) = &existing_where {
- edit.replace(existing.syntax(), new_where.syntax());
- } else {
- let pos: Option<Position> = match_ast! {
- match (&parent) {
- ast::Fn(it) => it.ret_type()
- .map(|t| Position::after(t.syntax()))
- .or_else(|| it.param_list().map(|t| Position::after(t.syntax()))),
- ast::Trait(it) => it.generic_param_list()
- .map(|t| Position::after(t.syntax()))
- .or_else(|| it.name().map(|t| Position::after(t.syntax()))),
- ast::Impl(it) => it.self_ty()
- .map(|t| Position::after(t.syntax())),
- ast::Enum(it) => it.generic_param_list()
- .map(|t| Position::after(t.syntax()))
- .or_else(|| it.name().map(|t| Position::after(t.syntax()))),
- ast::Struct(it) => it.field_list()
- .and_then(|fl| match fl {
- ast::FieldList::TupleFieldList(it) => {
- Some(Position::after(it.syntax()))
- }
- ast::FieldList::RecordFieldList(_) => None,
- })
- .or_else(|| it.generic_param_list()
- .map(|t| Position::after(t.syntax())))
- .or_else(|| it.name().map(|t| Position::after(t.syntax()))),
- ast::TypeAlias(it) => it.generic_param_list()
- .map(|t| Position::after(t.syntax()))
- .or_else(|| it.name().map(|t| Position::after(t.syntax()))),
- _ => None,
- }
- };
- if let Some(pos) = pos {
- edit.insert_all(
- pos,
- vec![make.whitespace(" ").into(), new_where.syntax().clone().into()],
- );
- }
- }
-
for generic_param in type_param_list.generic_params() {
let param: &dyn HasTypeBounds = match &generic_param {
ast::GenericParam::TypeParam(t) => t,
diff --git a/crates/syntax/src/syntax_editor.rs b/crates/syntax/src/syntax_editor.rs
index 5683d891be..e6937e4d0f 100644
--- a/crates/syntax/src/syntax_editor.rs
+++ b/crates/syntax/src/syntax_editor.rs
@@ -20,7 +20,7 @@ mod edit_algo;
mod edits;
mod mapping;
-pub use edits::Removable;
+pub use edits::{GetOrCreateWhereClause, Removable};
pub use mapping::{SyntaxMapping, SyntaxMappingBuilder};
#[derive(Debug)]
diff --git a/crates/syntax/src/syntax_editor/edits.rs b/crates/syntax/src/syntax_editor/edits.rs
index 9090f7c9eb..ad08928923 100644
--- a/crates/syntax/src/syntax_editor/edits.rs
+++ b/crates/syntax/src/syntax_editor/edits.rs
@@ -10,6 +10,107 @@ use crate::{
syntax_editor::{Position, SyntaxEditor},
};
+pub trait GetOrCreateWhereClause: ast::HasGenericParams {
+ fn where_clause_position(&self) -> Option<Position>;
+
+ fn get_or_create_where_clause(
+ &self,
+ editor: &mut SyntaxEditor,
+ make: &SyntaxFactory,
+ new_preds: impl Iterator<Item = ast::WherePred>,
+ ) {
+ let existing = self.where_clause();
+ let all_preds: Vec<_> =
+ existing.iter().flat_map(|wc| wc.predicates()).chain(new_preds).collect();
+ let new_where = make.where_clause(all_preds);
+
+ if let Some(existing) = &existing {
+ editor.replace(existing.syntax(), new_where.syntax());
+ } else if let Some(pos) = self.where_clause_position() {
+ editor.insert_all(
+ pos,
+ vec![make.whitespace(" ").into(), new_where.syntax().clone().into()],
+ );
+ }
+ }
+}
+
+impl GetOrCreateWhereClause for ast::Fn {
+ fn where_clause_position(&self) -> Option<Position> {
+ if let Some(ty) = self.ret_type() {
+ Some(Position::after(ty.syntax()))
+ } else if let Some(param_list) = self.param_list() {
+ Some(Position::after(param_list.syntax()))
+ } else {
+ Some(Position::last_child_of(self.syntax()))
+ }
+ }
+}
+
+impl GetOrCreateWhereClause for ast::Impl {
+ fn where_clause_position(&self) -> Option<Position> {
+ if let Some(ty) = self.self_ty() {
+ Some(Position::after(ty.syntax()))
+ } else {
+ Some(Position::last_child_of(self.syntax()))
+ }
+ }
+}
+
+impl GetOrCreateWhereClause for ast::Trait {
+ fn where_clause_position(&self) -> Option<Position> {
+ if let Some(gpl) = self.generic_param_list() {
+ Some(Position::after(gpl.syntax()))
+ } else if let Some(name) = self.name() {
+ Some(Position::after(name.syntax()))
+ } else {
+ Some(Position::last_child_of(self.syntax()))
+ }
+ }
+}
+
+impl GetOrCreateWhereClause for ast::TypeAlias {
+ fn where_clause_position(&self) -> Option<Position> {
+ if let Some(gpl) = self.generic_param_list() {
+ Some(Position::after(gpl.syntax()))
+ } else if let Some(name) = self.name() {
+ Some(Position::after(name.syntax()))
+ } else {
+ Some(Position::last_child_of(self.syntax()))
+ }
+ }
+}
+
+impl GetOrCreateWhereClause for ast::Struct {
+ fn where_clause_position(&self) -> Option<Position> {
+ let tfl = self.field_list().and_then(|fl| match fl {
+ ast::FieldList::RecordFieldList(_) => None,
+ ast::FieldList::TupleFieldList(it) => Some(it),
+ });
+ if let Some(tfl) = tfl {
+ Some(Position::after(tfl.syntax()))
+ } else if let Some(gpl) = self.generic_param_list() {
+ Some(Position::after(gpl.syntax()))
+ } else if let Some(name) = self.name() {
+ Some(Position::after(name.syntax()))
+ } else {
+ Some(Position::last_child_of(self.syntax()))
+ }
+ }
+}
+
+impl GetOrCreateWhereClause for ast::Enum {
+ fn where_clause_position(&self) -> Option<Position> {
+ if let Some(gpl) = self.generic_param_list() {
+ Some(Position::after(gpl.syntax()))
+ } else if let Some(name) = self.name() {
+ Some(Position::after(name.syntax()))
+ } else {
+ Some(Position::last_child_of(self.syntax()))
+ }
+ }
+}
+
impl SyntaxEditor {
/// Adds a new generic param to the function using `SyntaxEditor`
pub fn add_generic_param(&mut self, function: &Fn, new_param: GenericParam) {