Unnamed repository; edit this file 'description' to name the repository.
| -rw-r--r-- | crates/ide-assists/src/handlers/move_bounds.rs | 63 | ||||
| -rw-r--r-- | crates/syntax/src/syntax_editor.rs | 2 | ||||
| -rw-r--r-- | crates/syntax/src/syntax_editor/edits.rs | 101 |
3 files changed, 112 insertions, 54 deletions
diff --git a/crates/ide-assists/src/handlers/move_bounds.rs b/crates/ide-assists/src/handlers/move_bounds.rs index 01a46c3334..79b8bd5d3d 100644 --- a/crates/ide-assists/src/handlers/move_bounds.rs +++ b/crates/ide-assists/src/handlers/move_bounds.rs @@ -1,8 +1,8 @@ use either::Either; use syntax::{ - ast::{self, AstNode, HasGenericParams, HasName, HasTypeBounds, syntax_factory::SyntaxFactory}, + ast::{self, AstNode, HasName, HasTypeBounds, syntax_factory::SyntaxFactory}, match_ast, - syntax_editor::{Position, Removable}, + syntax_editor::{GetOrCreateWhereClause, Removable}, }; use crate::{AssistContext, AssistId, Assists}; @@ -53,61 +53,18 @@ pub(crate) fn move_bounds_to_where_clause( .filter_map(|param| build_predicate(param, &make)) .collect(); - let existing_where: Option<ast::WhereClause> = match_ast! { + match_ast! { match (&parent) { - ast::Fn(it) => it.where_clause(), - ast::Trait(it) => it.where_clause(), - ast::Impl(it) => it.where_clause(), - ast::Enum(it) => it.where_clause(), - ast::Struct(it) => it.where_clause(), - ast::TypeAlias(it) => it.where_clause(), - _ => None, + ast::Fn(it) => it.get_or_create_where_clause(&mut edit, &make, new_preds.into_iter()), + ast::Trait(it) => it.get_or_create_where_clause(&mut edit, &make, new_preds.into_iter()), + ast::Impl(it) => it.get_or_create_where_clause(&mut edit, &make, new_preds.into_iter()), + ast::Enum(it) => it.get_or_create_where_clause(&mut edit, &make, new_preds.into_iter()), + ast::Struct(it) => it.get_or_create_where_clause(&mut edit, &make, new_preds.into_iter()), + ast::TypeAlias(it) => it.get_or_create_where_clause(&mut edit, &make, new_preds.into_iter()), + _ => return, } }; - let all_preds = existing_where.iter().flat_map(|wc| wc.predicates()).chain(new_preds); - let new_where = make.where_clause(all_preds); - - if let Some(existing) = &existing_where { - edit.replace(existing.syntax(), new_where.syntax()); - } else { - let pos: Option<Position> = match_ast! { - match (&parent) { - ast::Fn(it) => it.ret_type() - .map(|t| Position::after(t.syntax())) - .or_else(|| it.param_list().map(|t| Position::after(t.syntax()))), - ast::Trait(it) => it.generic_param_list() - .map(|t| Position::after(t.syntax())) - .or_else(|| it.name().map(|t| Position::after(t.syntax()))), - ast::Impl(it) => it.self_ty() - .map(|t| Position::after(t.syntax())), - ast::Enum(it) => it.generic_param_list() - .map(|t| Position::after(t.syntax())) - .or_else(|| it.name().map(|t| Position::after(t.syntax()))), - ast::Struct(it) => it.field_list() - .and_then(|fl| match fl { - ast::FieldList::TupleFieldList(it) => { - Some(Position::after(it.syntax())) - } - ast::FieldList::RecordFieldList(_) => None, - }) - .or_else(|| it.generic_param_list() - .map(|t| Position::after(t.syntax()))) - .or_else(|| it.name().map(|t| Position::after(t.syntax()))), - ast::TypeAlias(it) => it.generic_param_list() - .map(|t| Position::after(t.syntax())) - .or_else(|| it.name().map(|t| Position::after(t.syntax()))), - _ => None, - } - }; - if let Some(pos) = pos { - edit.insert_all( - pos, - vec![make.whitespace(" ").into(), new_where.syntax().clone().into()], - ); - } - } - for generic_param in type_param_list.generic_params() { let param: &dyn HasTypeBounds = match &generic_param { ast::GenericParam::TypeParam(t) => t, diff --git a/crates/syntax/src/syntax_editor.rs b/crates/syntax/src/syntax_editor.rs index 5683d891be..e6937e4d0f 100644 --- a/crates/syntax/src/syntax_editor.rs +++ b/crates/syntax/src/syntax_editor.rs @@ -20,7 +20,7 @@ mod edit_algo; mod edits; mod mapping; -pub use edits::Removable; +pub use edits::{GetOrCreateWhereClause, Removable}; pub use mapping::{SyntaxMapping, SyntaxMappingBuilder}; #[derive(Debug)] diff --git a/crates/syntax/src/syntax_editor/edits.rs b/crates/syntax/src/syntax_editor/edits.rs index 9090f7c9eb..ad08928923 100644 --- a/crates/syntax/src/syntax_editor/edits.rs +++ b/crates/syntax/src/syntax_editor/edits.rs @@ -10,6 +10,107 @@ use crate::{ syntax_editor::{Position, SyntaxEditor}, }; +pub trait GetOrCreateWhereClause: ast::HasGenericParams { + fn where_clause_position(&self) -> Option<Position>; + + fn get_or_create_where_clause( + &self, + editor: &mut SyntaxEditor, + make: &SyntaxFactory, + new_preds: impl Iterator<Item = ast::WherePred>, + ) { + let existing = self.where_clause(); + let all_preds: Vec<_> = + existing.iter().flat_map(|wc| wc.predicates()).chain(new_preds).collect(); + let new_where = make.where_clause(all_preds); + + if let Some(existing) = &existing { + editor.replace(existing.syntax(), new_where.syntax()); + } else if let Some(pos) = self.where_clause_position() { + editor.insert_all( + pos, + vec![make.whitespace(" ").into(), new_where.syntax().clone().into()], + ); + } + } +} + +impl GetOrCreateWhereClause for ast::Fn { + fn where_clause_position(&self) -> Option<Position> { + if let Some(ty) = self.ret_type() { + Some(Position::after(ty.syntax())) + } else if let Some(param_list) = self.param_list() { + Some(Position::after(param_list.syntax())) + } else { + Some(Position::last_child_of(self.syntax())) + } + } +} + +impl GetOrCreateWhereClause for ast::Impl { + fn where_clause_position(&self) -> Option<Position> { + if let Some(ty) = self.self_ty() { + Some(Position::after(ty.syntax())) + } else { + Some(Position::last_child_of(self.syntax())) + } + } +} + +impl GetOrCreateWhereClause for ast::Trait { + fn where_clause_position(&self) -> Option<Position> { + if let Some(gpl) = self.generic_param_list() { + Some(Position::after(gpl.syntax())) + } else if let Some(name) = self.name() { + Some(Position::after(name.syntax())) + } else { + Some(Position::last_child_of(self.syntax())) + } + } +} + +impl GetOrCreateWhereClause for ast::TypeAlias { + fn where_clause_position(&self) -> Option<Position> { + if let Some(gpl) = self.generic_param_list() { + Some(Position::after(gpl.syntax())) + } else if let Some(name) = self.name() { + Some(Position::after(name.syntax())) + } else { + Some(Position::last_child_of(self.syntax())) + } + } +} + +impl GetOrCreateWhereClause for ast::Struct { + fn where_clause_position(&self) -> Option<Position> { + let tfl = self.field_list().and_then(|fl| match fl { + ast::FieldList::RecordFieldList(_) => None, + ast::FieldList::TupleFieldList(it) => Some(it), + }); + if let Some(tfl) = tfl { + Some(Position::after(tfl.syntax())) + } else if let Some(gpl) = self.generic_param_list() { + Some(Position::after(gpl.syntax())) + } else if let Some(name) = self.name() { + Some(Position::after(name.syntax())) + } else { + Some(Position::last_child_of(self.syntax())) + } + } +} + +impl GetOrCreateWhereClause for ast::Enum { + fn where_clause_position(&self) -> Option<Position> { + if let Some(gpl) = self.generic_param_list() { + Some(Position::after(gpl.syntax())) + } else if let Some(name) = self.name() { + Some(Position::after(name.syntax())) + } else { + Some(Position::last_child_of(self.syntax())) + } + } +} + impl SyntaxEditor { /// Adds a new generic param to the function using `SyntaxEditor` pub fn add_generic_param(&mut self, function: &Fn, new_param: GenericParam) { |