Unnamed repository; edit this file 'description' to name the repository.
Diffstat (limited to 'crates/hir-expand/src/builtin/derive_macro.rs')
-rw-r--r--crates/hir-expand/src/builtin/derive_macro.rs706
1 files changed, 628 insertions, 78 deletions
diff --git a/crates/hir-expand/src/builtin/derive_macro.rs b/crates/hir-expand/src/builtin/derive_macro.rs
index 7d3e8deaf0..4510a593af 100644
--- a/crates/hir-expand/src/builtin/derive_macro.rs
+++ b/crates/hir-expand/src/builtin/derive_macro.rs
@@ -1,9 +1,10 @@
//! Builtin derives.
use intern::sym;
-use itertools::izip;
+use itertools::{izip, Itertools};
+use parser::SyntaxKind;
use rustc_hash::FxHashSet;
-use span::{MacroCallId, Span};
+use span::{MacroCallId, Span, SyntaxContextId};
use stdx::never;
use syntax_bridge::DocCommentDesugarMode;
use tracing::debug;
@@ -16,8 +17,12 @@ use crate::{
span_map::ExpansionSpanMap,
tt, ExpandError, ExpandResult,
};
-use syntax::ast::{
- self, AstNode, FieldList, HasAttrs, HasGenericParams, HasModuleItem, HasName, HasTypeBounds,
+use syntax::{
+ ast::{
+ self, edit_in_place::GenericParamsOwnerEdit, make, AstNode, FieldList, HasAttrs,
+ HasGenericArgs, HasGenericParams, HasModuleItem, HasName, HasTypeBounds,
+ },
+ ted,
};
macro_rules! register_builtin {
@@ -28,7 +33,7 @@ macro_rules! register_builtin {
}
impl BuiltinDeriveExpander {
- pub fn expander(&self) -> fn(Span, &tt::Subtree) -> ExpandResult<tt::Subtree> {
+ pub fn expander(&self) -> fn(Span, &tt::TopSubtree) -> ExpandResult<tt::TopSubtree> {
match *self {
$( BuiltinDeriveExpander::$trait => $expand, )*
}
@@ -50,9 +55,9 @@ impl BuiltinDeriveExpander {
&self,
db: &dyn ExpandDatabase,
id: MacroCallId,
- tt: &tt::Subtree,
+ tt: &tt::TopSubtree,
span: Span,
- ) -> ExpandResult<tt::Subtree> {
+ ) -> ExpandResult<tt::TopSubtree> {
let span = span_with_def_site_ctxt(db, span, id);
self.expander()(span, tt)
}
@@ -67,13 +72,15 @@ register_builtin! {
Ord => ord_expand,
PartialOrd => partial_ord_expand,
Eq => eq_expand,
- PartialEq => partial_eq_expand
+ PartialEq => partial_eq_expand,
+ CoercePointee => coerce_pointee_expand
}
pub fn find_builtin_derive(ident: &name::Name) -> Option<BuiltinDeriveExpander> {
BuiltinDeriveExpander::find_by_name(ident)
}
+#[derive(Clone)]
enum VariantShape {
Struct(Vec<tt::Ident>),
Tuple(usize),
@@ -85,7 +92,7 @@ fn tuple_field_iterator(span: Span, n: usize) -> impl Iterator<Item = tt::Ident>
}
impl VariantShape {
- fn as_pattern(&self, path: tt::Subtree, span: Span) -> tt::Subtree {
+ fn as_pattern(&self, path: tt::TopSubtree, span: Span) -> tt::TopSubtree {
self.as_pattern_map(path, span, |it| quote!(span => #it))
}
@@ -99,10 +106,10 @@ impl VariantShape {
fn as_pattern_map(
&self,
- path: tt::Subtree,
+ path: tt::TopSubtree,
span: Span,
- field_map: impl Fn(&tt::Ident) -> tt::Subtree,
- ) -> tt::Subtree {
+ field_map: impl Fn(&tt::Ident) -> tt::TopSubtree,
+ ) -> tt::TopSubtree {
match self {
VariantShape::Struct(fields) => {
let fields = fields.iter().map(|it| {
@@ -147,6 +154,7 @@ impl VariantShape {
}
}
+#[derive(Clone)]
enum AdtShape {
Struct(VariantShape),
Enum { variants: Vec<(tt::Ident, VariantShape)>, default_variant: Option<usize> },
@@ -154,7 +162,7 @@ enum AdtShape {
}
impl AdtShape {
- fn as_pattern(&self, span: Span, name: &tt::Ident) -> Vec<tt::Subtree> {
+ fn as_pattern(&self, span: Span, name: &tt::Ident) -> Vec<tt::TopSubtree> {
self.as_pattern_map(name, |it| quote!(span =>#it), span)
}
@@ -176,9 +184,9 @@ impl AdtShape {
fn as_pattern_map(
&self,
name: &tt::Ident,
- field_map: impl Fn(&tt::Ident) -> tt::Subtree,
+ field_map: impl Fn(&tt::Ident) -> tt::TopSubtree,
span: Span,
- ) -> Vec<tt::Subtree> {
+ ) -> Vec<tt::TopSubtree> {
match self {
AdtShape::Struct(s) => {
vec![s.as_pattern_map(quote! {span => #name }, span, field_map)]
@@ -197,30 +205,38 @@ impl AdtShape {
}
}
+#[derive(Clone)]
struct BasicAdtInfo {
name: tt::Ident,
shape: AdtShape,
/// first field is the name, and
/// second field is `Some(ty)` if it's a const param of type `ty`, `None` if it's a type param.
/// third fields is where bounds, if any
- param_types: Vec<(tt::Subtree, Option<tt::Subtree>, Option<tt::Subtree>)>,
- where_clause: Vec<tt::Subtree>,
- associated_types: Vec<tt::Subtree>,
+ param_types: Vec<AdtParam>,
+ where_clause: Vec<tt::TopSubtree>,
+ associated_types: Vec<tt::TopSubtree>,
}
-fn parse_adt(tt: &tt::Subtree, call_site: Span) -> Result<BasicAdtInfo, ExpandError> {
- let (parsed, tm) = &syntax_bridge::token_tree_to_syntax_node(
- tt,
- syntax_bridge::TopEntryPoint::MacroItems,
- parser::Edition::CURRENT_FIXME,
- );
- let macro_items = ast::MacroItems::cast(parsed.syntax_node())
- .ok_or_else(|| ExpandError::other(call_site, "invalid item definition"))?;
- let item =
- macro_items.items().next().ok_or_else(|| ExpandError::other(call_site, "no item found"))?;
- let adt = &ast::Adt::cast(item.syntax().clone())
- .ok_or_else(|| ExpandError::other(call_site, "expected struct, enum or union"))?;
- let (name, generic_param_list, where_clause, shape) = match adt {
+#[derive(Clone)]
+struct AdtParam {
+ name: tt::TopSubtree,
+ /// `None` if this is a type parameter.
+ const_ty: Option<tt::TopSubtree>,
+ bounds: Option<tt::TopSubtree>,
+}
+
+// FIXME: This whole thing needs a refactor. Each derive requires its special values, and the result is a mess.
+fn parse_adt(tt: &tt::TopSubtree, call_site: Span) -> Result<BasicAdtInfo, ExpandError> {
+ let (adt, tm) = to_adt_syntax(tt, call_site)?;
+ parse_adt_from_syntax(&adt, &tm, call_site)
+}
+
+fn parse_adt_from_syntax(
+ adt: &ast::Adt,
+ tm: &span::SpanMap<SyntaxContextId>,
+ call_site: Span,
+) -> Result<BasicAdtInfo, ExpandError> {
+ let (name, generic_param_list, where_clause, shape) = match &adt {
ast::Adt::Struct(it) => (
it.name(),
it.generic_param_list(),
@@ -276,7 +292,7 @@ fn parse_adt(tt: &tt::Subtree, call_site: Span) -> Result<BasicAdtInfo, ExpandEr
)
}
None => {
- tt::Subtree::empty(::tt::DelimSpan { open: call_site, close: call_site })
+ tt::TopSubtree::empty(::tt::DelimSpan { open: call_site, close: call_site })
}
}
};
@@ -291,7 +307,7 @@ fn parse_adt(tt: &tt::Subtree, call_site: Span) -> Result<BasicAdtInfo, ExpandEr
}),
ast::TypeOrConstParam::Const(_) => None,
};
- let ty = if let ast::TypeOrConstParam::Const(param) = param {
+ let const_ty = if let ast::TypeOrConstParam::Const(param) = param {
let ty = param
.ty()
.map(|ty| {
@@ -303,13 +319,13 @@ fn parse_adt(tt: &tt::Subtree, call_site: Span) -> Result<BasicAdtInfo, ExpandEr
)
})
.unwrap_or_else(|| {
- tt::Subtree::empty(::tt::DelimSpan { open: call_site, close: call_site })
+ tt::TopSubtree::empty(::tt::DelimSpan { open: call_site, close: call_site })
});
Some(ty)
} else {
None
};
- (name, ty, bounds)
+ AdtParam { name, const_ty, bounds }
})
.collect();
@@ -365,6 +381,24 @@ fn parse_adt(tt: &tt::Subtree, call_site: Span) -> Result<BasicAdtInfo, ExpandEr
Ok(BasicAdtInfo { name: name_token, shape, param_types, where_clause, associated_types })
}
+fn to_adt_syntax(
+ tt: &tt::TopSubtree,
+ call_site: Span,
+) -> Result<(ast::Adt, span::SpanMap<SyntaxContextId>), ExpandError> {
+ let (parsed, tm) = syntax_bridge::token_tree_to_syntax_node(
+ tt,
+ syntax_bridge::TopEntryPoint::MacroItems,
+ parser::Edition::CURRENT_FIXME,
+ );
+ let macro_items = ast::MacroItems::cast(parsed.syntax_node())
+ .ok_or_else(|| ExpandError::other(call_site, "invalid item definition"))?;
+ let item =
+ macro_items.items().next().ok_or_else(|| ExpandError::other(call_site, "no item found"))?;
+ let adt = ast::Adt::cast(item.syntax().clone())
+ .ok_or_else(|| ExpandError::other(call_site, "expected struct, enum or union"))?;
+ Ok((adt, tm))
+}
+
fn name_to_token(
call_site: Span,
token_map: &ExpansionSpanMap,
@@ -413,59 +447,85 @@ fn name_to_token(
/// therefore does not get bound by the derived trait.
fn expand_simple_derive(
invoc_span: Span,
- tt: &tt::Subtree,
- trait_path: tt::Subtree,
- make_trait_body: impl FnOnce(&BasicAdtInfo) -> tt::Subtree,
-) -> ExpandResult<tt::Subtree> {
+ tt: &tt::TopSubtree,
+ trait_path: tt::TopSubtree,
+ make_trait_body: impl FnOnce(&BasicAdtInfo) -> tt::TopSubtree,
+) -> ExpandResult<tt::TopSubtree> {
let info = match parse_adt(tt, invoc_span) {
Ok(info) => info,
Err(e) => {
return ExpandResult::new(
- tt::Subtree::empty(tt::DelimSpan { open: invoc_span, close: invoc_span }),
+ tt::TopSubtree::empty(tt::DelimSpan { open: invoc_span, close: invoc_span }),
e,
)
}
};
+ ExpandResult::ok(expand_simple_derive_with_parsed(
+ invoc_span,
+ info,
+ trait_path,
+ make_trait_body,
+ true,
+ tt::TopSubtree::empty(tt::DelimSpan::from_single(invoc_span)),
+ ))
+}
+
+fn expand_simple_derive_with_parsed(
+ invoc_span: Span,
+ info: BasicAdtInfo,
+ trait_path: tt::TopSubtree,
+ make_trait_body: impl FnOnce(&BasicAdtInfo) -> tt::TopSubtree,
+ constrain_to_trait: bool,
+ extra_impl_params: tt::TopSubtree,
+) -> tt::TopSubtree {
let trait_body = make_trait_body(&info);
let mut where_block: Vec<_> =
info.where_clause.into_iter().map(|w| quote! {invoc_span => #w , }).collect();
let (params, args): (Vec<_>, Vec<_>) = info
.param_types
.into_iter()
- .map(|(ident, param_ty, bound)| {
- let ident_ = ident.clone();
- if let Some(b) = bound {
- let ident = ident.clone();
- where_block.push(quote! {invoc_span => #ident : #b , });
- }
- if let Some(ty) = param_ty {
- (quote! {invoc_span => const #ident : #ty , }, quote! {invoc_span => #ident_ , })
+ .map(|param| {
+ let ident = param.name;
+ if let Some(b) = param.bounds {
+ let ident2 = ident.clone();
+ where_block.push(quote! {invoc_span => #ident2 : #b , });
+ }
+ if let Some(ty) = param.const_ty {
+ let ident2 = ident.clone();
+ (quote! {invoc_span => const #ident : #ty , }, quote! {invoc_span => #ident2 , })
} else {
let bound = trait_path.clone();
- (quote! {invoc_span => #ident : #bound , }, quote! {invoc_span => #ident_ , })
+ let ident2 = ident.clone();
+ let param = if constrain_to_trait {
+ quote! {invoc_span => #ident : #bound , }
+ } else {
+ quote! {invoc_span => #ident , }
+ };
+ (param, quote! {invoc_span => #ident2 , })
}
})
.unzip();
- where_block.extend(info.associated_types.iter().map(|it| {
- let it = it.clone();
- let bound = trait_path.clone();
- quote! {invoc_span => #it : #bound , }
- }));
+ if constrain_to_trait {
+ where_block.extend(info.associated_types.iter().map(|it| {
+ let it = it.clone();
+ let bound = trait_path.clone();
+ quote! {invoc_span => #it : #bound , }
+ }));
+ }
let name = info.name;
- let expanded = quote! {invoc_span =>
- impl < ##params > #trait_path for #name < ##args > where ##where_block { #trait_body }
- };
- ExpandResult::ok(expanded)
+ quote! {invoc_span =>
+ impl < ##params #extra_impl_params > #trait_path for #name < ##args > where ##where_block { #trait_body }
+ }
}
-fn copy_expand(span: Span, tt: &tt::Subtree) -> ExpandResult<tt::Subtree> {
+fn copy_expand(span: Span, tt: &tt::TopSubtree) -> ExpandResult<tt::TopSubtree> {
let krate = dollar_crate(span);
expand_simple_derive(span, tt, quote! {span => #krate::marker::Copy }, |_| quote! {span =>})
}
-fn clone_expand(span: Span, tt: &tt::Subtree) -> ExpandResult<tt::Subtree> {
+fn clone_expand(span: Span, tt: &tt::TopSubtree) -> ExpandResult<tt::TopSubtree> {
let krate = dollar_crate(span);
expand_simple_derive(span, tt, quote! {span => #krate::clone::Clone }, |adt| {
if matches!(adt.shape, AdtShape::Union) {
@@ -505,18 +565,18 @@ fn clone_expand(span: Span, tt: &tt::Subtree) -> ExpandResult<tt::Subtree> {
}
/// This function exists since `quote! {span => => }` doesn't work.
-fn fat_arrow(span: Span) -> tt::Subtree {
+fn fat_arrow(span: Span) -> tt::TopSubtree {
let eq = tt::Punct { char: '=', spacing: ::tt::Spacing::Joint, span };
quote! {span => #eq> }
}
/// This function exists since `quote! {span => && }` doesn't work.
-fn and_and(span: Span) -> tt::Subtree {
+fn and_and(span: Span) -> tt::TopSubtree {
let and = tt::Punct { char: '&', spacing: ::tt::Spacing::Joint, span };
quote! {span => #and& }
}
-fn default_expand(span: Span, tt: &tt::Subtree) -> ExpandResult<tt::Subtree> {
+fn default_expand(span: Span, tt: &tt::TopSubtree) -> ExpandResult<tt::TopSubtree> {
let krate = &dollar_crate(span);
expand_simple_derive(span, tt, quote! {span => #krate::default::Default }, |adt| {
let body = match &adt.shape {
@@ -555,7 +615,7 @@ fn default_expand(span: Span, tt: &tt::Subtree) -> ExpandResult<tt::Subtree> {
})
}
-fn debug_expand(span: Span, tt: &tt::Subtree) -> ExpandResult<tt::Subtree> {
+fn debug_expand(span: Span, tt: &tt::TopSubtree) -> ExpandResult<tt::TopSubtree> {
let krate = &dollar_crate(span);
expand_simple_derive(span, tt, quote! {span => #krate::fmt::Debug }, |adt| {
let for_variant = |name: String, v: &VariantShape| match v {
@@ -627,7 +687,7 @@ fn debug_expand(span: Span, tt: &tt::Subtree) -> ExpandResult<tt::Subtree> {
})
}
-fn hash_expand(span: Span, tt: &tt::Subtree) -> ExpandResult<tt::Subtree> {
+fn hash_expand(span: Span, tt: &tt::TopSubtree) -> ExpandResult<tt::TopSubtree> {
let krate = &dollar_crate(span);
expand_simple_derive(span, tt, quote! {span => #krate::hash::Hash }, |adt| {
if matches!(adt.shape, AdtShape::Union) {
@@ -674,12 +734,12 @@ fn hash_expand(span: Span, tt: &tt::Subtree) -> ExpandResult<tt::Subtree> {
})
}
-fn eq_expand(span: Span, tt: &tt::Subtree) -> ExpandResult<tt::Subtree> {
+fn eq_expand(span: Span, tt: &tt::TopSubtree) -> ExpandResult<tt::TopSubtree> {
let krate = dollar_crate(span);
expand_simple_derive(span, tt, quote! {span => #krate::cmp::Eq }, |_| quote! {span =>})
}
-fn partial_eq_expand(span: Span, tt: &tt::Subtree) -> ExpandResult<tt::Subtree> {
+fn partial_eq_expand(span: Span, tt: &tt::TopSubtree) -> ExpandResult<tt::TopSubtree> {
let krate = dollar_crate(span);
expand_simple_derive(span, tt, quote! {span => #krate::cmp::PartialEq }, |adt| {
if matches!(adt.shape, AdtShape::Union) {
@@ -731,7 +791,7 @@ fn self_and_other_patterns(
adt: &BasicAdtInfo,
name: &tt::Ident,
span: Span,
-) -> (Vec<tt::Subtree>, Vec<tt::Subtree>) {
+) -> (Vec<tt::TopSubtree>, Vec<tt::TopSubtree>) {
let self_patterns = adt.shape.as_pattern_map(
name,
|it| {
@@ -751,16 +811,16 @@ fn self_and_other_patterns(
(self_patterns, other_patterns)
}
-fn ord_expand(span: Span, tt: &tt::Subtree) -> ExpandResult<tt::Subtree> {
+fn ord_expand(span: Span, tt: &tt::TopSubtree) -> ExpandResult<tt::TopSubtree> {
let krate = &dollar_crate(span);
expand_simple_derive(span, tt, quote! {span => #krate::cmp::Ord }, |adt| {
fn compare(
krate: &tt::Ident,
- left: tt::Subtree,
- right: tt::Subtree,
- rest: tt::Subtree,
+ left: tt::TopSubtree,
+ right: tt::TopSubtree,
+ rest: tt::TopSubtree,
span: Span,
- ) -> tt::Subtree {
+ ) -> tt::TopSubtree {
let fat_arrow1 = fat_arrow(span);
let fat_arrow2 = fat_arrow(span);
quote! {span =>
@@ -809,16 +869,16 @@ fn ord_expand(span: Span, tt: &tt::Subtree) -> ExpandResult<tt::Subtree> {
})
}
-fn partial_ord_expand(span: Span, tt: &tt::Subtree) -> ExpandResult<tt::Subtree> {
+fn partial_ord_expand(span: Span, tt: &tt::TopSubtree) -> ExpandResult<tt::TopSubtree> {
let krate = &dollar_crate(span);
expand_simple_derive(span, tt, quote! {span => #krate::cmp::PartialOrd }, |adt| {
fn compare(
krate: &tt::Ident,
- left: tt::Subtree,
- right: tt::Subtree,
- rest: tt::Subtree,
+ left: tt::TopSubtree,
+ right: tt::TopSubtree,
+ rest: tt::TopSubtree,
span: Span,
- ) -> tt::Subtree {
+ ) -> tt::TopSubtree {
let fat_arrow1 = fat_arrow(span);
let fat_arrow2 = fat_arrow(span);
quote! {span =>
@@ -871,3 +931,493 @@ fn partial_ord_expand(span: Span, tt: &tt::Subtree) -> ExpandResult<tt::Subtree>
}
})
}
+
+fn coerce_pointee_expand(span: Span, tt: &tt::TopSubtree) -> ExpandResult<tt::TopSubtree> {
+ let (adt, _span_map) = match to_adt_syntax(tt, span) {
+ Ok(it) => it,
+ Err(err) => {
+ return ExpandResult::new(tt::TopSubtree::empty(tt::DelimSpan::from_single(span)), err);
+ }
+ };
+ let adt = adt.clone_for_update();
+ let ast::Adt::Struct(strukt) = &adt else {
+ return ExpandResult::new(
+ tt::TopSubtree::empty(tt::DelimSpan::from_single(span)),
+ ExpandError::other(span, "`CoercePointee` can only be derived on `struct`s"),
+ );
+ };
+ let has_at_least_one_field = strukt
+ .field_list()
+ .map(|it| match it {
+ ast::FieldList::RecordFieldList(it) => it.fields().next().is_some(),
+ ast::FieldList::TupleFieldList(it) => it.fields().next().is_some(),
+ })
+ .unwrap_or(false);
+ if !has_at_least_one_field {
+ return ExpandResult::new(
+ tt::TopSubtree::empty(tt::DelimSpan::from_single(span)),
+ ExpandError::other(
+ span,
+ "`CoercePointee` can only be derived on `struct`s with at least one field",
+ ),
+ );
+ }
+ let is_repr_transparent = strukt.attrs().any(|attr| {
+ attr.as_simple_call().is_some_and(|(name, tt)| {
+ name == "repr"
+ && tt.syntax().children_with_tokens().any(|it| {
+ it.into_token().is_some_and(|it| {
+ it.kind() == SyntaxKind::IDENT && it.text() == "transparent"
+ })
+ })
+ })
+ });
+ if !is_repr_transparent {
+ return ExpandResult::new(
+ tt::TopSubtree::empty(tt::DelimSpan::from_single(span)),
+ ExpandError::other(
+ span,
+ "`CoercePointee` can only be derived on `struct`s with `#[repr(transparent)]`",
+ ),
+ );
+ }
+ let type_params = strukt
+ .generic_param_list()
+ .into_iter()
+ .flat_map(|generics| {
+ generics.generic_params().filter_map(|param| match param {
+ ast::GenericParam::TypeParam(param) => Some(param),
+ _ => None,
+ })
+ })
+ .collect_vec();
+ if type_params.is_empty() {
+ return ExpandResult::new(
+ tt::TopSubtree::empty(tt::DelimSpan::from_single(span)),
+ ExpandError::other(
+ span,
+ "`CoercePointee` can only be derived on `struct`s that are generic over at least one type",
+ ),
+ );
+ }
+ let (pointee_param, pointee_param_idx) = if type_params.len() == 1 {
+ // Regardless of the only type param being designed as `#[pointee]` or not, we can just use it as such.
+ (type_params[0].clone(), 0)
+ } else {
+ let mut pointees = type_params.iter().cloned().enumerate().filter(|(_, param)| {
+ param.attrs().any(|attr| {
+ let is_pointee = attr.as_simple_atom().is_some_and(|name| name == "pointee");
+ if is_pointee {
+ // Remove the `#[pointee]` attribute so it won't be present in the generated
+ // impls (where we cannot resolve it).
+ ted::remove(attr.syntax());
+ }
+ is_pointee
+ })
+ });
+ match (pointees.next(), pointees.next()) {
+ (Some((pointee_idx, pointee)), None) => (pointee, pointee_idx),
+ (None, _) => {
+ return ExpandResult::new(
+ tt::TopSubtree::empty(tt::DelimSpan::from_single(span)),
+ ExpandError::other(
+ span,
+ "exactly one generic type parameter must be marked \
+ as `#[pointee]` to derive `CoercePointee` traits",
+ ),
+ )
+ }
+ (Some(_), Some(_)) => {
+ return ExpandResult::new(
+ tt::TopSubtree::empty(tt::DelimSpan::from_single(span)),
+ ExpandError::other(
+ span,
+ "only one type parameter can be marked as `#[pointee]` \
+ when deriving `CoercePointee` traits",
+ ),
+ )
+ }
+ }
+ };
+ let (Some(struct_name), Some(pointee_param_name)) = (strukt.name(), pointee_param.name())
+ else {
+ return ExpandResult::new(
+ tt::TopSubtree::empty(tt::DelimSpan::from_single(span)),
+ ExpandError::other(span, "invalid item"),
+ );
+ };
+
+ {
+ let mut pointee_has_maybe_sized_bound = false;
+ if let Some(bounds) = pointee_param.type_bound_list() {
+ pointee_has_maybe_sized_bound |= bounds.bounds().any(is_maybe_sized_bound);
+ }
+ if let Some(where_clause) = strukt.where_clause() {
+ pointee_has_maybe_sized_bound |= where_clause.predicates().any(|pred| {
+ let Some(ast::Type::PathType(ty)) = pred.ty() else { return false };
+ let is_not_pointee = ty.path().is_none_or(|path| {
+ let is_pointee = path
+ .as_single_name_ref()
+ .is_some_and(|name| name.text() == pointee_param_name.text());
+ !is_pointee
+ });
+ if is_not_pointee {
+ return false;
+ }
+ pred.type_bound_list()
+ .is_some_and(|bounds| bounds.bounds().any(is_maybe_sized_bound))
+ })
+ }
+ if !pointee_has_maybe_sized_bound {
+ return ExpandResult::new(
+ tt::TopSubtree::empty(tt::DelimSpan::from_single(span)),
+ ExpandError::other(
+ span,
+ format!("`derive(CoercePointee)` requires `{pointee_param_name}` to be marked `?Sized`"),
+ ),
+ );
+ }
+ }
+
+ const ADDED_PARAM: &str = "__S";
+
+ let where_clause = strukt.get_or_create_where_clause();
+
+ {
+ let mut new_predicates = Vec::new();
+
+ // # Rewrite generic parameter bounds
+ // For each bound `U: ..` in `struct<U: ..>`, make a new bound with `__S` in place of `#[pointee]`
+ // Example:
+ // ```
+ // struct<
+ // U: Trait<T>,
+ // #[pointee] T: Trait<T> + ?Sized,
+ // V: Trait<T>> ...
+ // ```
+ // ... generates this `impl` generic parameters
+ // ```
+ // impl<
+ // U: Trait<T>,
+ // T: Trait<T> + ?Sized,
+ // V: Trait<T>
+ // >
+ // where
+ // U: Trait<__S>,
+ // __S: Trait<__S> + ?Sized,
+ // V: Trait<__S> ...
+ // ```
+ for param in &type_params {
+ let Some(param_name) = param.name() else { continue };
+ if let Some(bounds) = param.type_bound_list() {
+ // If the target type is the pointee, duplicate the bound as whole.
+ // Otherwise, duplicate only bounds that mention the pointee.
+ let is_pointee = param_name.text() == pointee_param_name.text();
+ let new_bounds = bounds
+ .bounds()
+ .map(|bound| bound.clone_subtree().clone_for_update())
+ .filter(|bound| {
+ bound.ty().is_some_and(|ty| {
+ substitute_type_in_bound(ty, &pointee_param_name.text(), ADDED_PARAM)
+ || is_pointee
+ })
+ });
+ let new_bounds_target = if is_pointee {
+ make::name_ref(ADDED_PARAM)
+ } else {
+ make::name_ref(&param_name.text())
+ };
+ new_predicates.push(
+ make::where_pred(
+ make::ty_path(make::path_from_segments(
+ [make::path_segment(new_bounds_target)],
+ false,
+ )),
+ new_bounds,
+ )
+ .clone_for_update(),
+ );
+ }
+ }
+
+ // # Rewrite `where` clauses
+ //
+ // Move on to `where` clauses.
+ // Example:
+ // ```
+ // struct MyPointer<#[pointee] T, ..>
+ // where
+ // U: Trait<V> + Trait<T>,
+ // Companion<T>: Trait<T>,
+ // T: Trait<T> + ?Sized,
+ // { .. }
+ // ```
+ // ... will have a impl prelude like so
+ // ```
+ // impl<..> ..
+ // where
+ // U: Trait<V> + Trait<T>,
+ // U: Trait<__S>,
+ // Companion<T>: Trait<T>,
+ // Companion<__S>: Trait<__S>,
+ // T: Trait<T> + ?Sized,
+ // __S: Trait<__S> + ?Sized,
+ // ```
+ //
+ // We should also write a few new `where` bounds from `#[pointee] T` to `__S`
+ // as well as any bound that indirectly involves the `#[pointee] T` type.
+ for predicate in where_clause.predicates() {
+ let predicate = predicate.clone_subtree().clone_for_update();
+ let Some(pred_target) = predicate.ty() else { continue };
+
+ // If the target type references the pointee, duplicate the bound as whole.
+ // Otherwise, duplicate only bounds that mention the pointee.
+ if substitute_type_in_bound(
+ pred_target.clone(),
+ &pointee_param_name.text(),
+ ADDED_PARAM,
+ ) {
+ if let Some(bounds) = predicate.type_bound_list() {
+ for bound in bounds.bounds() {
+ if let Some(ty) = bound.ty() {
+ substitute_type_in_bound(ty, &pointee_param_name.text(), ADDED_PARAM);
+ }
+ }
+ }
+
+ new_predicates.push(predicate);
+ } else if let Some(bounds) = predicate.type_bound_list() {
+ let new_bounds = bounds
+ .bounds()
+ .map(|bound| bound.clone_subtree().clone_for_update())
+ .filter(|bound| {
+ bound.ty().is_some_and(|ty| {
+ substitute_type_in_bound(ty, &pointee_param_name.text(), ADDED_PARAM)
+ })
+ });
+ new_predicates.push(make::where_pred(pred_target, new_bounds).clone_for_update());
+ }
+ }
+
+ for new_predicate in new_predicates {
+ where_clause.add_predicate(new_predicate);
+ }
+ }
+
+ {
+ // # Add `Unsize<__S>` bound to `#[pointee]` at the generic parameter location
+ //
+ // Find the `#[pointee]` parameter and add an `Unsize<__S>` bound to it.
+ where_clause.add_predicate(
+ make::where_pred(
+ make::ty_path(make::path_from_segments(
+ [make::path_segment(make::name_ref(&pointee_param_name.text()))],
+ false,
+ )),
+ [make::type_bound(make::ty_path(make::path_from_segments(
+ [
+ make::path_segment(make::name_ref("core")),
+ make::path_segment(make::name_ref("marker")),
+ make::generic_ty_path_segment(
+ make::name_ref("Unsize"),
+ [make::type_arg(make::ty_path(make::path_from_segments(
+ [make::path_segment(make::name_ref(ADDED_PARAM))],
+ false,
+ )))
+ .into()],
+ ),
+ ],
+ true,
+ )))],
+ )
+ .clone_for_update(),
+ );
+ }
+
+ let self_for_traits = {
+ // Replace the `#[pointee]` with `__S`.
+ let mut type_param_idx = 0;
+ let self_params_for_traits = strukt
+ .generic_param_list()
+ .into_iter()
+ .flat_map(|params| params.generic_params())
+ .filter_map(|param| {
+ Some(match param {
+ ast::GenericParam::ConstParam(param) => {
+ ast::GenericArg::ConstArg(make::expr_const_value(&param.name()?.text()))
+ }
+ ast::GenericParam::LifetimeParam(param) => {
+ make::lifetime_arg(param.lifetime()?).into()
+ }
+ ast::GenericParam::TypeParam(param) => {
+ let name = if pointee_param_idx == type_param_idx {
+ make::name_ref(ADDED_PARAM)
+ } else {
+ make::name_ref(&param.name()?.text())
+ };
+ type_param_idx += 1;
+ make::type_arg(make::ty_path(make::path_from_segments(
+ [make::path_segment(name)],
+ false,
+ )))
+ .into()
+ }
+ })
+ });
+ let self_for_traits = make::path_from_segments(
+ [make::generic_ty_path_segment(
+ make::name_ref(&struct_name.text()),
+ self_params_for_traits,
+ )],
+ false,
+ )
+ .clone_for_update();
+ self_for_traits
+ };
+
+ let mut span_map = span::SpanMap::empty();
+ // One span for them all.
+ span_map.push(adt.syntax().text_range().end(), span);
+
+ let self_for_traits = syntax_bridge::syntax_node_to_token_tree(
+ self_for_traits.syntax(),
+ &span_map,
+ span,
+ DocCommentDesugarMode::ProcMacro,
+ );
+ let info = match parse_adt_from_syntax(&adt, &span_map, span) {
+ Ok(it) => it,
+ Err(err) => {
+ return ExpandResult::new(tt::TopSubtree::empty(tt::DelimSpan::from_single(span)), err)
+ }
+ };
+
+ let self_for_traits2 = self_for_traits.clone();
+ let krate = dollar_crate(span);
+ let krate2 = krate.clone();
+ let dispatch_from_dyn = expand_simple_derive_with_parsed(
+ span,
+ info.clone(),
+ quote! {span => #krate2::ops::DispatchFromDyn<#self_for_traits2> },
+ |_adt| quote! {span => },
+ false,
+ quote! {span => __S },
+ );
+ let coerce_unsized = expand_simple_derive_with_parsed(
+ span,
+ info,
+ quote! {span => #krate::ops::CoerceUnsized<#self_for_traits> },
+ |_adt| quote! {span => },
+ false,
+ quote! {span => __S },
+ );
+ return ExpandResult::ok(quote! {span => #dispatch_from_dyn #coerce_unsized });
+
+ fn is_maybe_sized_bound(bound: ast::TypeBound) -> bool {
+ if bound.question_mark_token().is_none() {
+ return false;
+ }
+ let Some(ast::Type::PathType(ty)) = bound.ty() else {
+ return false;
+ };
+ let Some(path) = ty.path() else {
+ return false;
+ };
+ return segments_eq(&path, &["Sized"])
+ || segments_eq(&path, &["core", "marker", "Sized"])
+ || segments_eq(&path, &["std", "marker", "Sized"]);
+
+ fn segments_eq(path: &ast::Path, expected: &[&str]) -> bool {
+ path.segments().zip_longest(expected.iter().copied()).all(|value| {
+ value.both().is_some_and(|(segment, expected)| {
+ segment.name_ref().is_some_and(|name| name.text() == expected)
+ })
+ })
+ }
+ }
+
+ /// Returns true if any substitution was performed.
+ fn substitute_type_in_bound(ty: ast::Type, param_name: &str, replacement: &str) -> bool {
+ return match ty {
+ ast::Type::ArrayType(ty) => {
+ ty.ty().is_some_and(|ty| substitute_type_in_bound(ty, param_name, replacement))
+ }
+ ast::Type::DynTraitType(ty) => go_bounds(ty.type_bound_list(), param_name, replacement),
+ ast::Type::FnPtrType(ty) => any_long(
+ ty.param_list()
+ .into_iter()
+ .flat_map(|params| params.params().filter_map(|param| param.ty()))
+ .chain(ty.ret_type().and_then(|it| it.ty())),
+ |ty| substitute_type_in_bound(ty, param_name, replacement),
+ ),
+ ast::Type::ForType(ty) => {
+ ty.ty().is_some_and(|ty| substitute_type_in_bound(ty, param_name, replacement))
+ }
+ ast::Type::ImplTraitType(ty) => {
+ go_bounds(ty.type_bound_list(), param_name, replacement)
+ }
+ ast::Type::ParenType(ty) => {
+ ty.ty().is_some_and(|ty| substitute_type_in_bound(ty, param_name, replacement))
+ }
+ ast::Type::PathType(ty) => ty.path().is_some_and(|path| {
+ if path.as_single_name_ref().is_some_and(|name| name.text() == param_name) {
+ ted::replace(
+ path.syntax(),
+ make::path_from_segments(
+ [make::path_segment(make::name_ref(replacement))],
+ false,
+ )
+ .clone_for_update()
+ .syntax(),
+ );
+ return true;
+ }
+
+ any_long(
+ path.segments()
+ .filter_map(|segment| segment.generic_arg_list())
+ .flat_map(|it| it.generic_args())
+ .filter_map(|generic_arg| match generic_arg {
+ ast::GenericArg::TypeArg(ty) => ty.ty(),
+ _ => None,
+ }),
+ |ty| substitute_type_in_bound(ty, param_name, replacement),
+ )
+ }),
+ ast::Type::PtrType(ty) => {
+ ty.ty().is_some_and(|ty| substitute_type_in_bound(ty, param_name, replacement))
+ }
+ ast::Type::RefType(ty) => {
+ ty.ty().is_some_and(|ty| substitute_type_in_bound(ty, param_name, replacement))
+ }
+ ast::Type::SliceType(ty) => {
+ ty.ty().is_some_and(|ty| substitute_type_in_bound(ty, param_name, replacement))
+ }
+ ast::Type::TupleType(ty) => {
+ any_long(ty.fields(), |ty| substitute_type_in_bound(ty, param_name, replacement))
+ }
+ ast::Type::InferType(_) | ast::Type::MacroType(_) | ast::Type::NeverType(_) => false,
+ };
+
+ fn go_bounds(
+ bounds: Option<ast::TypeBoundList>,
+ param_name: &str,
+ replacement: &str,
+ ) -> bool {
+ bounds.is_some_and(|bounds| {
+ any_long(bounds.bounds(), |bound| {
+ bound
+ .ty()
+ .is_some_and(|ty| substitute_type_in_bound(ty, param_name, replacement))
+ })
+ })
+ }
+
+ /// Like [`Iterator::any()`], but not short-circuiting.
+ fn any_long<I: Iterator, F: FnMut(I::Item) -> bool>(iter: I, mut f: F) -> bool {
+ let mut result = false;
+ iter.for_each(|item| result |= f(item));
+ result
+ }
+ }
+}