Unnamed repository; edit this file 'description' to name the repository.
Diffstat (limited to 'crates/hir-expand/src/builtin/derive_macro.rs')
| -rw-r--r-- | crates/hir-expand/src/builtin/derive_macro.rs | 86 |
1 files changed, 78 insertions, 8 deletions
diff --git a/crates/hir-expand/src/builtin/derive_macro.rs b/crates/hir-expand/src/builtin/derive_macro.rs index 28b6812139..f8fb700d55 100644 --- a/crates/hir-expand/src/builtin/derive_macro.rs +++ b/crates/hir-expand/src/builtin/derive_macro.rs @@ -80,9 +80,15 @@ pub fn find_builtin_derive(ident: &name::Name) -> Option<BuiltinDeriveExpander> BuiltinDeriveExpander::find_by_name(ident) } +#[derive(Clone, Copy)] +enum HasDefault { + Yes, + No, +} + #[derive(Clone)] enum VariantShape { - Struct(Vec<tt::Ident>), + Struct(Vec<(tt::Ident, HasDefault)>), Tuple(usize), Unit, } @@ -98,7 +104,7 @@ impl VariantShape { fn field_names(&self, span: Span) -> Vec<tt::Ident> { match self { - VariantShape::Struct(s) => s.clone(), + VariantShape::Struct(s) => s.iter().map(|(ident, _)| ident.clone()).collect(), VariantShape::Tuple(n) => tuple_field_iterator(span, *n).collect(), VariantShape::Unit => vec![], } @@ -112,7 +118,7 @@ impl VariantShape { ) -> tt::TopSubtree { match self { VariantShape::Struct(fields) => { - let fields = fields.iter().map(|it| { + let fields = fields.iter().map(|(it, _)| { let mapped = field_map(it); quote! {span => #it : #mapped , } }); @@ -135,6 +141,63 @@ impl VariantShape { } } + fn default_expand( + &self, + path: tt::TopSubtree, + span: Span, + field_map: impl Fn(&tt::Ident) -> tt::TopSubtree, + ) -> tt::TopSubtree { + match self { + VariantShape::Struct(fields) => { + let contains_default = fields.iter().any(|it| matches!(it.1, HasDefault::Yes)); + let fields = fields + .iter() + .filter_map(|(it, has_default)| match has_default { + HasDefault::Yes => None, + HasDefault::No => Some(it), + }) + .map(|it| { + let mapped = field_map(it); + quote! {span => #it : #mapped , } + }); + if contains_default { + let mut double_dots = + tt::TopSubtreeBuilder::new(tt::Delimiter::invisible_spanned(span)); + double_dots.push(tt::Leaf::Punct(tt::Punct { + char: '.', + spacing: tt::Spacing::Joint, + span, + })); + double_dots.push(tt::Leaf::Punct(tt::Punct { + char: '.', + spacing: tt::Spacing::Alone, + span, + })); + let double_dots = double_dots.build(); + quote! {span => + #path { ##fields #double_dots } + } + } else { + quote! {span => + #path { ##fields } + } + } + } + &VariantShape::Tuple(n) => { + let fields = tuple_field_iterator(span, n).map(|it| { + let mapped = field_map(&it); + quote! {span => + #mapped , + } + }); + quote! {span => + #path ( ##fields ) + } + } + VariantShape::Unit => path, + } + } + fn from( call_site: Span, tm: &ExpansionSpanMap, @@ -144,8 +207,15 @@ impl VariantShape { None => VariantShape::Unit, Some(FieldList::RecordFieldList(it)) => VariantShape::Struct( it.fields() - .map(|it| it.name()) - .map(|it| name_to_token(call_site, tm, it)) + .map(|it| { + ( + it.name(), + if it.expr().is_some() { HasDefault::Yes } else { HasDefault::No }, + ) + }) + .map(|(it, has_default)| { + name_to_token(call_site, tm, it).map(|ident| (ident, has_default)) + }) .collect::<Result<_, _>>()?, ), Some(FieldList::TupleFieldList(it)) => VariantShape::Tuple(it.fields().count()), @@ -601,7 +671,7 @@ fn default_expand( let body = match &adt.shape { AdtShape::Struct(fields) => { let name = &adt.name; - fields.as_pattern_map( + fields.default_expand( quote!(span =>#name), span, |_| quote!(span =>#krate::default::Default::default()), @@ -611,7 +681,7 @@ fn default_expand( if let Some(d) = default_variant { let (name, fields) = &variants[*d]; let adt_name = &adt.name; - fields.as_pattern_map( + fields.default_expand( quote!(span =>#adt_name :: #name), span, |_| quote!(span =>#krate::default::Default::default()), @@ -643,7 +713,7 @@ fn debug_expand( expand_simple_derive(db, span, tt, quote! {span => #krate::fmt::Debug }, |adt| { let for_variant = |name: String, v: &VariantShape| match v { VariantShape::Struct(fields) => { - let for_fields = fields.iter().map(|it| { + let for_fields = fields.iter().map(|(it, _)| { let x_string = it.to_string(); quote! {span => .field(#x_string, & #it) |