Unnamed repository; edit this file 'description' to name the repository.
Diffstat (limited to 'crates/hir-expand/src/builtin_derive_macro.rs')
| -rw-r--r-- | crates/hir-expand/src/builtin_derive_macro.rs | 131 |
1 files changed, 111 insertions, 20 deletions
diff --git a/crates/hir-expand/src/builtin_derive_macro.rs b/crates/hir-expand/src/builtin_derive_macro.rs index 5c1a75132e..7e753663c0 100644 --- a/crates/hir-expand/src/builtin_derive_macro.rs +++ b/crates/hir-expand/src/builtin_derive_macro.rs @@ -1,11 +1,12 @@ //! Builtin derives. use base_db::{CrateOrigin, LangCrateOrigin}; +use std::collections::HashSet; use tracing::debug; use crate::tt::{self, TokenId}; use syntax::{ - ast::{self, AstNode, HasGenericParams, HasModuleItem, HasName}, + ast::{self, AstNode, HasGenericParams, HasModuleItem, HasName, HasTypeBounds, PathType}, match_ast, }; @@ -60,8 +61,11 @@ pub fn find_builtin_derive(ident: &name::Name) -> Option<BuiltinDeriveExpander> struct BasicAdtInfo { name: tt::Ident, - /// `Some(ty)` if it's a const param of type `ty`, `None` if it's a type param. - param_types: Vec<Option<tt::Subtree>>, + /// first field is the name, and + /// second field is `Some(ty)` if it's a const param of type `ty`, `None` if it's a type param. + /// third fields is where bounds, if any + param_types: Vec<(tt::Subtree, Option<tt::Subtree>, Option<tt::Subtree>)>, + associated_types: Vec<tt::Subtree>, } fn parse_adt(tt: &tt::Subtree) -> Result<BasicAdtInfo, ExpandError> { @@ -86,18 +90,28 @@ fn parse_adt(tt: &tt::Subtree) -> Result<BasicAdtInfo, ExpandError> { }, } }; - let name = name.ok_or_else(|| { - debug!("parsed item has no name"); - ExpandError::Other("missing name".into()) - })?; - let name_token_id = - token_map.token_by_range(name.syntax().text_range()).unwrap_or_else(TokenId::unspecified); - let name_token = tt::Ident { span: name_token_id, text: name.text().into() }; + let mut param_type_set: HashSet<String> = HashSet::new(); let param_types = params .into_iter() .flat_map(|param_list| param_list.type_or_const_params()) .map(|param| { - if let ast::TypeOrConstParam::Const(param) = param { + let name = { + let this = param.name(); + match this { + Some(x) => { + param_type_set.insert(x.to_string()); + mbe::syntax_node_to_token_tree(x.syntax()).0 + } + None => tt::Subtree::empty(), + } + }; + let bounds = match ¶m { + ast::TypeOrConstParam::Type(x) => { + x.type_bound_list().map(|x| mbe::syntax_node_to_token_tree(x.syntax()).0) + } + ast::TypeOrConstParam::Const(_) => None, + }; + let ty = if let ast::TypeOrConstParam::Const(param) = param { let ty = param .ty() .map(|ty| mbe::syntax_node_to_token_tree(ty.syntax()).0) @@ -105,27 +119,97 @@ fn parse_adt(tt: &tt::Subtree) -> Result<BasicAdtInfo, ExpandError> { Some(ty) } else { None - } + }; + (name, ty, bounds) }) .collect(); - Ok(BasicAdtInfo { name: name_token, param_types }) + let is_associated_type = |p: &PathType| { + if let Some(p) = p.path() { + if let Some(parent) = p.qualifier() { + if let Some(x) = parent.segment() { + if let Some(x) = x.path_type() { + if let Some(x) = x.path() { + if let Some(pname) = x.as_single_name_ref() { + if param_type_set.contains(&pname.to_string()) { + // <T as Trait>::Assoc + return true; + } + } + } + } + } + if let Some(pname) = parent.as_single_name_ref() { + if param_type_set.contains(&pname.to_string()) { + // T::Assoc + return true; + } + } + } + } + false + }; + let associated_types = node + .descendants() + .filter_map(PathType::cast) + .filter(is_associated_type) + .map(|x| mbe::syntax_node_to_token_tree(x.syntax()).0) + .collect::<Vec<_>>(); + let name = name.ok_or_else(|| { + debug!("parsed item has no name"); + ExpandError::Other("missing name".into()) + })?; + let name_token_id = + token_map.token_by_range(name.syntax().text_range()).unwrap_or_else(TokenId::unspecified); + let name_token = tt::Ident { span: name_token_id, text: name.text().into() }; + Ok(BasicAdtInfo { name: name_token, param_types, associated_types }) } +/// Given that we are deriving a trait `DerivedTrait` for a type like: +/// +/// ```ignore (only-for-syntax-highlight) +/// struct Struct<'a, ..., 'z, A, B: DeclaredTrait, C, ..., Z> where C: WhereTrait { +/// a: A, +/// b: B::Item, +/// b1: <B as DeclaredTrait>::Item, +/// c1: <C as WhereTrait>::Item, +/// c2: Option<<C as WhereTrait>::Item>, +/// ... +/// } +/// ``` +/// +/// create an impl like: +/// +/// ```ignore (only-for-syntax-highlight) +/// impl<'a, ..., 'z, A, B: DeclaredTrait, C, ... Z> where +/// C: WhereTrait, +/// A: DerivedTrait + B1 + ... + BN, +/// B: DerivedTrait + B1 + ... + BN, +/// C: DerivedTrait + B1 + ... + BN, +/// B::Item: DerivedTrait + B1 + ... + BN, +/// <C as WhereTrait>::Item: DerivedTrait + B1 + ... + BN, +/// ... +/// { +/// ... +/// } +/// ``` +/// +/// where B1, ..., BN are the bounds given by `bounds_paths`.'. Z is a phantom type, and +/// therefore does not get bound by the derived trait. fn expand_simple_derive(tt: &tt::Subtree, trait_path: tt::Subtree) -> ExpandResult<tt::Subtree> { let info = match parse_adt(tt) { Ok(info) => info, Err(e) => return ExpandResult::with_err(tt::Subtree::empty(), e), }; + let mut where_block = vec![]; let (params, args): (Vec<_>, Vec<_>) = info .param_types .into_iter() - .enumerate() - .map(|(idx, param_ty)| { - let ident = tt::Leaf::Ident(tt::Ident { - span: tt::TokenId::unspecified(), - text: format!("T{idx}").into(), - }); + .map(|(ident, param_ty, bound)| { let ident_ = ident.clone(); + if let Some(b) = bound { + let ident = ident.clone(); + where_block.push(quote! { #ident : #b , }); + } if let Some(ty) = param_ty { (quote! { const #ident : #ty , }, quote! { #ident_ , }) } else { @@ -134,9 +218,16 @@ fn expand_simple_derive(tt: &tt::Subtree, trait_path: tt::Subtree) -> ExpandResu } }) .unzip(); + + where_block.extend(info.associated_types.iter().map(|x| { + let x = x.clone(); + let bound = trait_path.clone(); + quote! { #x : #bound , } + })); + let name = info.name; let expanded = quote! { - impl < ##params > #trait_path for #name < ##args > {} + impl < ##params > #trait_path for #name < ##args > where ##where_block {} }; ExpandResult::ok(expanded) } |