Unnamed repository; edit this file 'description' to name the repository.
| -rw-r--r-- | crates/hir-expand/src/builtin/derive_macro.rs | 83 | ||||
| -rw-r--r-- | crates/syntax/src/ast/edit_in_place.rs | 239 |
2 files changed, 41 insertions, 281 deletions
diff --git a/crates/hir-expand/src/builtin/derive_macro.rs b/crates/hir-expand/src/builtin/derive_macro.rs index f208203c93..bfd7dffb05 100644 --- a/crates/hir-expand/src/builtin/derive_macro.rs +++ b/crates/hir-expand/src/builtin/derive_macro.rs @@ -22,8 +22,9 @@ use crate::{ use syntax::{ ast::{ self, AstNode, FieldList, HasAttrs, HasGenericArgs, HasGenericParams, HasModuleItem, - HasName, HasTypeBounds, edit_in_place::GenericParamsOwnerEdit, make, + HasName, HasTypeBounds, make, syntax_factory::SyntaxFactory, }, + syntax_editor::{GetOrCreateWhereClause, SyntaxEditor}, ted, }; @@ -1150,11 +1151,9 @@ fn coerce_pointee_expand( const ADDED_PARAM: &str = "__S"; - let where_clause = strukt.get_or_create_where_clause(); + let mut new_predicates: Vec<ast::WherePred> = Vec::new(); { - let mut new_predicates = Vec::new(); - // # Rewrite generic parameter bounds // For each bound `U: ..` in `struct<U: ..>`, make a new bound with `__S` in place of `#[pointee]` // Example: @@ -1196,16 +1195,13 @@ fn coerce_pointee_expand( } else { make::name_ref(¶m_name.text()) }; - new_predicates.push( - make::where_pred( - Either::Right(make::ty_path(make::path_from_segments( - [make::path_segment(new_bounds_target)], - false, - ))), - new_bounds, - ) - .clone_for_update(), - ); + new_predicates.push(make::where_pred( + Either::Right(make::ty_path(make::path_from_segments( + [make::path_segment(new_bounds_target)], + false, + ))), + new_bounds, + )); } } @@ -1235,7 +1231,7 @@ fn coerce_pointee_expand( // // We should also write a few new `where` bounds from `#[pointee] T` to `__S` // as well as any bound that indirectly involves the `#[pointee] T` type. - for predicate in where_clause.predicates() { + for predicate in strukt.where_clause().into_iter().flat_map(|wc| wc.predicates()) { let predicate = predicate.clone_subtree().clone_for_update(); let Some(pred_target) = predicate.ty() else { continue }; @@ -1269,42 +1265,43 @@ fn coerce_pointee_expand( ); } } - - for new_predicate in new_predicates { - where_clause.add_predicate(new_predicate); - } } { // # Add `Unsize<__S>` bound to `#[pointee]` at the generic parameter location // // Find the `#[pointee]` parameter and add an `Unsize<__S>` bound to it. - where_clause.add_predicate( - make::where_pred( - Either::Right(make::ty_path(make::path_from_segments( - [make::path_segment(make::name_ref(&pointee_param_name.text()))], - false, - ))), - [make::type_bound(make::ty_path(make::path_from_segments( - [ - make::path_segment(make::name_ref("core")), - make::path_segment(make::name_ref("marker")), - make::generic_ty_path_segment( - make::name_ref("Unsize"), - [make::type_arg(make::ty_path(make::path_from_segments( - [make::path_segment(make::name_ref(ADDED_PARAM))], - false, - ))) - .into()], - ), - ], - true, - )))], - ) - .clone_for_update(), - ); + new_predicates.push(make::where_pred( + Either::Right(make::ty_path(make::path_from_segments( + [make::path_segment(make::name_ref(&pointee_param_name.text()))], + false, + ))), + [make::type_bound(make::ty_path(make::path_from_segments( + [ + make::path_segment(make::name_ref("core")), + make::path_segment(make::name_ref("marker")), + make::generic_ty_path_segment( + make::name_ref("Unsize"), + [make::type_arg(make::ty_path(make::path_from_segments( + [make::path_segment(make::name_ref(ADDED_PARAM))], + false, + ))) + .into()], + ), + ], + true, + )))], + )); } + let (mut editor, strukt) = SyntaxEditor::with_ast_node(strukt); + let make = SyntaxFactory::with_mappings(); + strukt.get_or_create_where_clause(&mut editor, &make, new_predicates.into_iter()); + editor.add_mappings(make.finish_with_mappings()); + let edit = editor.finish(); + let strukt = ast::Struct::cast(edit.new_root().clone()).unwrap(); + let adt = ast::Adt::Struct(strukt.clone()); + let self_for_traits = { // Replace the `#[pointee]` with `__S`. let mut type_param_idx = 0; diff --git a/crates/syntax/src/ast/edit_in_place.rs b/crates/syntax/src/ast/edit_in_place.rs index 7f59ae4213..c6affec427 100644 --- a/crates/syntax/src/ast/edit_in_place.rs +++ b/crates/syntax/src/ast/edit_in_place.rs @@ -9,221 +9,13 @@ use crate::{ SyntaxKind::{ATTR, COMMENT, WHITESPACE}, SyntaxNode, SyntaxToken, algo::{self, neighbor}, - ast::{self, HasGenericParams, edit::IndentLevel, make, syntax_factory::SyntaxFactory}, + ast::{self, edit::IndentLevel, make, syntax_factory::SyntaxFactory}, syntax_editor::{Position, SyntaxEditor}, ted, }; use super::{GenericParam, HasName}; -pub trait GenericParamsOwnerEdit: ast::HasGenericParams { - fn get_or_create_generic_param_list(&self) -> ast::GenericParamList; - fn get_or_create_where_clause(&self) -> ast::WhereClause; -} - -impl GenericParamsOwnerEdit for ast::Fn { - fn get_or_create_generic_param_list(&self) -> ast::GenericParamList { - match self.generic_param_list() { - Some(it) => it, - None => { - let position = if let Some(name) = self.name() { - ted::Position::after(name.syntax) - } else if let Some(fn_token) = self.fn_token() { - ted::Position::after(fn_token) - } else if let Some(param_list) = self.param_list() { - ted::Position::before(param_list.syntax) - } else { - ted::Position::last_child_of(self.syntax()) - }; - create_generic_param_list(position) - } - } - } - - fn get_or_create_where_clause(&self) -> ast::WhereClause { - if self.where_clause().is_none() { - let position = if let Some(ty) = self.ret_type() { - ted::Position::after(ty.syntax()) - } else if let Some(param_list) = self.param_list() { - ted::Position::after(param_list.syntax()) - } else { - ted::Position::last_child_of(self.syntax()) - }; - create_where_clause(position); - } - self.where_clause().unwrap() - } -} - -impl GenericParamsOwnerEdit for ast::Impl { - fn get_or_create_generic_param_list(&self) -> ast::GenericParamList { - match self.generic_param_list() { - Some(it) => it, - None => { - let position = match self.impl_token() { - Some(imp_token) => ted::Position::after(imp_token), - None => ted::Position::last_child_of(self.syntax()), - }; - create_generic_param_list(position) - } - } - } - - fn get_or_create_where_clause(&self) -> ast::WhereClause { - if self.where_clause().is_none() { - let position = match self.assoc_item_list() { - Some(items) => ted::Position::before(items.syntax()), - None => ted::Position::last_child_of(self.syntax()), - }; - create_where_clause(position); - } - self.where_clause().unwrap() - } -} - -impl GenericParamsOwnerEdit for ast::Trait { - fn get_or_create_generic_param_list(&self) -> ast::GenericParamList { - match self.generic_param_list() { - Some(it) => it, - None => { - let position = if let Some(name) = self.name() { - ted::Position::after(name.syntax) - } else if let Some(trait_token) = self.trait_token() { - ted::Position::after(trait_token) - } else { - ted::Position::last_child_of(self.syntax()) - }; - create_generic_param_list(position) - } - } - } - - fn get_or_create_where_clause(&self) -> ast::WhereClause { - if self.where_clause().is_none() { - let position = match (self.assoc_item_list(), self.semicolon_token()) { - (Some(items), _) => ted::Position::before(items.syntax()), - (_, Some(tok)) => ted::Position::before(tok), - (None, None) => ted::Position::last_child_of(self.syntax()), - }; - create_where_clause(position); - } - self.where_clause().unwrap() - } -} - -impl GenericParamsOwnerEdit for ast::TypeAlias { - fn get_or_create_generic_param_list(&self) -> ast::GenericParamList { - match self.generic_param_list() { - Some(it) => it, - None => { - let position = if let Some(name) = self.name() { - ted::Position::after(name.syntax) - } else if let Some(trait_token) = self.type_token() { - ted::Position::after(trait_token) - } else { - ted::Position::last_child_of(self.syntax()) - }; - create_generic_param_list(position) - } - } - } - - fn get_or_create_where_clause(&self) -> ast::WhereClause { - if self.where_clause().is_none() { - let position = match self.eq_token() { - Some(tok) => ted::Position::before(tok), - None => match self.semicolon_token() { - Some(tok) => ted::Position::before(tok), - None => ted::Position::last_child_of(self.syntax()), - }, - }; - create_where_clause(position); - } - self.where_clause().unwrap() - } -} - -impl GenericParamsOwnerEdit for ast::Struct { - fn get_or_create_generic_param_list(&self) -> ast::GenericParamList { - match self.generic_param_list() { - Some(it) => it, - None => { - let position = if let Some(name) = self.name() { - ted::Position::after(name.syntax) - } else if let Some(struct_token) = self.struct_token() { - ted::Position::after(struct_token) - } else { - ted::Position::last_child_of(self.syntax()) - }; - create_generic_param_list(position) - } - } - } - - fn get_or_create_where_clause(&self) -> ast::WhereClause { - if self.where_clause().is_none() { - let tfl = self.field_list().and_then(|fl| match fl { - ast::FieldList::RecordFieldList(_) => None, - ast::FieldList::TupleFieldList(it) => Some(it), - }); - let position = if let Some(tfl) = tfl { - ted::Position::after(tfl.syntax()) - } else if let Some(gpl) = self.generic_param_list() { - ted::Position::after(gpl.syntax()) - } else if let Some(name) = self.name() { - ted::Position::after(name.syntax()) - } else { - ted::Position::last_child_of(self.syntax()) - }; - create_where_clause(position); - } - self.where_clause().unwrap() - } -} - -impl GenericParamsOwnerEdit for ast::Enum { - fn get_or_create_generic_param_list(&self) -> ast::GenericParamList { - match self.generic_param_list() { - Some(it) => it, - None => { - let position = if let Some(name) = self.name() { - ted::Position::after(name.syntax) - } else if let Some(enum_token) = self.enum_token() { - ted::Position::after(enum_token) - } else { - ted::Position::last_child_of(self.syntax()) - }; - create_generic_param_list(position) - } - } - } - - fn get_or_create_where_clause(&self) -> ast::WhereClause { - if self.where_clause().is_none() { - let position = if let Some(gpl) = self.generic_param_list() { - ted::Position::after(gpl.syntax()) - } else if let Some(name) = self.name() { - ted::Position::after(name.syntax()) - } else { - ted::Position::last_child_of(self.syntax()) - }; - create_where_clause(position); - } - self.where_clause().unwrap() - } -} - -fn create_where_clause(position: ted::Position) { - let where_clause = make::where_clause(empty()).clone_for_update(); - ted::insert(position, where_clause.syntax()); -} - -fn create_generic_param_list(position: ted::Position) -> ast::GenericParamList { - let gpl = make::generic_param_list(empty()).clone_for_update(); - ted::insert_raw(position, gpl.syntax()); - gpl -} - pub trait AttrsOwnerEdit: ast::HasAttrs { fn remove_attrs_and_docs(&self) { remove_attrs_and_docs(self.syntax()); @@ -879,8 +671,6 @@ impl<N: AstNode + Clone> Indent for N {} #[cfg(test)] mod tests { - use std::fmt; - use parser::Edition; use crate::SourceFile; @@ -893,33 +683,6 @@ mod tests { } #[test] - fn test_create_generic_param_list() { - fn check_create_gpl<N: GenericParamsOwnerEdit + fmt::Display>(before: &str, after: &str) { - let gpl_owner = ast_mut_from_text::<N>(before); - gpl_owner.get_or_create_generic_param_list(); - assert_eq!(gpl_owner.to_string(), after); - } - - check_create_gpl::<ast::Fn>("fn foo", "fn foo<>"); - check_create_gpl::<ast::Fn>("fn foo() {}", "fn foo<>() {}"); - - check_create_gpl::<ast::Impl>("impl", "impl<>"); - check_create_gpl::<ast::Impl>("impl Struct {}", "impl<> Struct {}"); - check_create_gpl::<ast::Impl>("impl Trait for Struct {}", "impl<> Trait for Struct {}"); - - check_create_gpl::<ast::Trait>("trait Trait<>", "trait Trait<>"); - check_create_gpl::<ast::Trait>("trait Trait<> {}", "trait Trait<> {}"); - - check_create_gpl::<ast::Struct>("struct A", "struct A<>"); - check_create_gpl::<ast::Struct>("struct A;", "struct A<>;"); - check_create_gpl::<ast::Struct>("struct A();", "struct A<>();"); - check_create_gpl::<ast::Struct>("struct A {}", "struct A<> {}"); - - check_create_gpl::<ast::Enum>("enum E", "enum E<>"); - check_create_gpl::<ast::Enum>("enum E {", "enum E<> {"); - } - - #[test] fn test_increase_indent() { let arm_list = ast_mut_from_text::<ast::Fn>( "fn foo() { |