Unnamed repository; edit this file 'description' to name the repository.
Diffstat (limited to 'crates/hir-expand/src/builtin_derive_macro.rs')
-rw-r--r--crates/hir-expand/src/builtin_derive_macro.rs154
1 files changed, 81 insertions, 73 deletions
diff --git a/crates/hir-expand/src/builtin_derive_macro.rs b/crates/hir-expand/src/builtin_derive_macro.rs
index 54706943ac..3d1e272b90 100644
--- a/crates/hir-expand/src/builtin_derive_macro.rs
+++ b/crates/hir-expand/src/builtin_derive_macro.rs
@@ -4,17 +4,16 @@ use ::tt::Ident;
use base_db::{CrateOrigin, LangCrateOrigin};
use itertools::izip;
use mbe::TokenMap;
-use std::collections::HashSet;
+use rustc_hash::FxHashSet;
use stdx::never;
use tracing::debug;
-use crate::tt::{self, TokenId};
-use syntax::{
- ast::{
- self, AstNode, FieldList, HasAttrs, HasGenericParams, HasModuleItem, HasName,
- HasTypeBounds, PathType,
- },
- match_ast,
+use crate::{
+ name::{AsName, Name},
+ tt::{self, TokenId},
+};
+use syntax::ast::{
+ self, AstNode, FieldList, HasAttrs, HasGenericParams, HasModuleItem, HasName, HasTypeBounds,
};
use crate::{db::ExpandDatabase, name, quote, ExpandError, ExpandResult, MacroCallId};
@@ -195,39 +194,52 @@ fn parse_adt(tt: &tt::Subtree) -> Result<BasicAdtInfo, ExpandError> {
let (parsed, token_map) = mbe::token_tree_to_syntax_node(tt, mbe::TopEntryPoint::MacroItems);
let macro_items = ast::MacroItems::cast(parsed.syntax_node()).ok_or_else(|| {
debug!("derive node didn't parse");
- ExpandError::Other("invalid item definition".into())
+ ExpandError::other("invalid item definition")
})?;
let item = macro_items.items().next().ok_or_else(|| {
debug!("no module item parsed");
- ExpandError::Other("no item found".into())
+ ExpandError::other("no item found")
})?;
- let node = item.syntax();
- let (name, params, shape) = match_ast! {
- match node {
- ast::Struct(it) => (it.name(), it.generic_param_list(), AdtShape::Struct(VariantShape::from(it.field_list(), &token_map)?)),
- ast::Enum(it) => {
- let default_variant = it.variant_list().into_iter().flat_map(|x| x.variants()).position(|x| x.attrs().any(|x| x.simple_name() == Some("default".into())));
- (
- it.name(),
- it.generic_param_list(),
- AdtShape::Enum {
- default_variant,
- variants: it.variant_list()
- .into_iter()
- .flat_map(|x| x.variants())
- .map(|x| Ok((name_to_token(&token_map,x.name())?, VariantShape::from(x.field_list(), &token_map)?))).collect::<Result<_, ExpandError>>()?
- }
- )
- },
- ast::Union(it) => (it.name(), it.generic_param_list(), AdtShape::Union),
- _ => {
- debug!("unexpected node is {:?}", node);
- return Err(ExpandError::Other("expected struct, enum or union".into()))
- },
+ let adt = ast::Adt::cast(item.syntax().clone()).ok_or_else(|| {
+ debug!("expected adt, found: {:?}", item);
+ ExpandError::other("expected struct, enum or union")
+ })?;
+ let (name, generic_param_list, shape) = match &adt {
+ ast::Adt::Struct(it) => (
+ it.name(),
+ it.generic_param_list(),
+ AdtShape::Struct(VariantShape::from(it.field_list(), &token_map)?),
+ ),
+ ast::Adt::Enum(it) => {
+ let default_variant = it
+ .variant_list()
+ .into_iter()
+ .flat_map(|x| x.variants())
+ .position(|x| x.attrs().any(|x| x.simple_name() == Some("default".into())));
+ (
+ it.name(),
+ it.generic_param_list(),
+ AdtShape::Enum {
+ default_variant,
+ variants: it
+ .variant_list()
+ .into_iter()
+ .flat_map(|x| x.variants())
+ .map(|x| {
+ Ok((
+ name_to_token(&token_map, x.name())?,
+ VariantShape::from(x.field_list(), &token_map)?,
+ ))
+ })
+ .collect::<Result<_, ExpandError>>()?,
+ },
+ )
}
+ ast::Adt::Union(it) => (it.name(), it.generic_param_list(), AdtShape::Union),
};
- let mut param_type_set: HashSet<String> = HashSet::new();
- let param_types = params
+
+ let mut param_type_set: FxHashSet<Name> = FxHashSet::default();
+ let param_types = generic_param_list
.into_iter()
.flat_map(|param_list| param_list.type_or_const_params())
.map(|param| {
@@ -235,7 +247,7 @@ fn parse_adt(tt: &tt::Subtree) -> Result<BasicAdtInfo, ExpandError> {
let this = param.name();
match this {
Some(x) => {
- param_type_set.insert(x.to_string());
+ param_type_set.insert(x.as_name());
mbe::syntax_node_to_token_tree(x.syntax()).0
}
None => tt::Subtree::empty(),
@@ -259,37 +271,33 @@ fn parse_adt(tt: &tt::Subtree) -> Result<BasicAdtInfo, ExpandError> {
(name, ty, bounds)
})
.collect();
- let is_associated_type = |p: &PathType| {
- if let Some(p) = p.path() {
- if let Some(parent) = p.qualifier() {
- if let Some(x) = parent.segment() {
- if let Some(x) = x.path_type() {
- if let Some(x) = x.path() {
- if let Some(pname) = x.as_single_name_ref() {
- if param_type_set.contains(&pname.to_string()) {
- // <T as Trait>::Assoc
- return true;
- }
- }
- }
- }
- }
- if let Some(pname) = parent.as_single_name_ref() {
- if param_type_set.contains(&pname.to_string()) {
- // T::Assoc
- return true;
- }
- }
- }
- }
- false
+
+ // For a generic parameter `T`, when shorthand associated type `T::Assoc` appears in field
+ // types (of any variant for enums), we generate trait bound for it. It sounds reasonable to
+ // also generate trait bound for qualified associated type `<T as Trait>::Assoc`, but rustc
+ // does not do that for some unknown reason.
+ //
+ // See the analogous function in rustc [find_type_parameters()] and rust-lang/rust#50730.
+ // [find_type_parameters()]: https://github.com/rust-lang/rust/blob/1.70.0/compiler/rustc_builtin_macros/src/deriving/generic/mod.rs#L378
+
+ // It's cumbersome to deal with the distinct structures of ADTs, so let's just get untyped
+ // `SyntaxNode` that contains fields and look for descendant `ast::PathType`s. Of note is that
+ // we should not inspect `ast::PathType`s in parameter bounds and where clauses.
+ let field_list = match adt {
+ ast::Adt::Enum(it) => it.variant_list().map(|list| list.syntax().clone()),
+ ast::Adt::Struct(it) => it.field_list().map(|list| list.syntax().clone()),
+ ast::Adt::Union(it) => it.record_field_list().map(|list| list.syntax().clone()),
};
- let associated_types = node
- .descendants()
- .filter_map(PathType::cast)
- .filter(is_associated_type)
+ let associated_types = field_list
+ .into_iter()
+ .flat_map(|it| it.descendants())
+ .filter_map(ast::PathType::cast)
+ .filter_map(|p| {
+ let name = p.path()?.qualifier()?.as_single_name_ref()?.as_name();
+ param_type_set.contains(&name).then_some(p)
+ })
.map(|x| mbe::syntax_node_to_token_tree(x.syntax()).0)
- .collect::<Vec<_>>();
+ .collect();
let name_token = name_to_token(&token_map, name)?;
Ok(BasicAdtInfo { name: name_token, shape, param_types, associated_types })
}
@@ -297,7 +305,7 @@ fn parse_adt(tt: &tt::Subtree) -> Result<BasicAdtInfo, ExpandError> {
fn name_to_token(token_map: &TokenMap, name: Option<ast::Name>) -> Result<tt::Ident, ExpandError> {
let name = name.ok_or_else(|| {
debug!("parsed item has no name");
- ExpandError::Other("missing name".into())
+ ExpandError::other("missing name")
})?;
let name_token_id =
token_map.token_by_range(name.syntax().text_range()).unwrap_or_else(TokenId::unspecified);
@@ -334,18 +342,18 @@ fn name_to_token(token_map: &TokenMap, name: Option<ast::Name>) -> Result<tt::Id
/// }
/// ```
///
-/// where B1, ..., BN are the bounds given by `bounds_paths`.'. Z is a phantom type, and
+/// where B1, ..., BN are the bounds given by `bounds_paths`. Z is a phantom type, and
/// therefore does not get bound by the derived trait.
fn expand_simple_derive(
tt: &tt::Subtree,
trait_path: tt::Subtree,
- trait_body: impl FnOnce(&BasicAdtInfo) -> tt::Subtree,
+ make_trait_body: impl FnOnce(&BasicAdtInfo) -> tt::Subtree,
) -> ExpandResult<tt::Subtree> {
let info = match parse_adt(tt) {
Ok(info) => info,
Err(e) => return ExpandResult::new(tt::Subtree::empty(), e),
};
- let trait_body = trait_body(&info);
+ let trait_body = make_trait_body(&info);
let mut where_block = vec![];
let (params, args): (Vec<_>, Vec<_>) = info
.param_types
@@ -605,7 +613,7 @@ fn hash_expand(
span: tt::TokenId::unspecified(),
};
return quote! {
- fn hash<H: #krate::hash::Hasher>(&self, state: &mut H) {
+ fn hash<H: #krate::hash::Hasher>(&self, ra_expand_state: &mut H) {
match #star self {}
}
};
@@ -613,7 +621,7 @@ fn hash_expand(
let arms = adt.shape.as_pattern(&adt.name).into_iter().zip(adt.shape.field_names()).map(
|(pat, names)| {
let expr = {
- let it = names.iter().map(|x| quote! { #x . hash(state); });
+ let it = names.iter().map(|x| quote! { #x . hash(ra_expand_state); });
quote! { {
##it
} }
@@ -625,8 +633,8 @@ fn hash_expand(
},
);
quote! {
- fn hash<H: #krate::hash::Hasher>(&self, state: &mut H) {
- #krate::mem::discriminant(self).hash(state);
+ fn hash<H: #krate::hash::Hasher>(&self, ra_expand_state: &mut H) {
+ #krate::mem::discriminant(self).hash(ra_expand_state);
match self {
##arms
}