Unnamed repository; edit this file 'description' to name the repository.
Diffstat (limited to 'crates/hir-ty/src/next_solver/generic_arg.rs')
| -rw-r--r-- | crates/hir-ty/src/next_solver/generic_arg.rs | 453 |
1 files changed, 339 insertions, 114 deletions
diff --git a/crates/hir-ty/src/next_solver/generic_arg.rs b/crates/hir-ty/src/next_solver/generic_arg.rs index dedd6a1a6d..72cf2f9f07 100644 --- a/crates/hir-ty/src/next_solver/generic_arg.rs +++ b/crates/hir-ty/src/next_solver/generic_arg.rs @@ -1,40 +1,225 @@ -//! Things related to generic args in the next-trait-solver. +//! Things related to generic args in the next-trait-solver (`GenericArg`, `GenericArgs`, `Term`). +//! +//! Implementations of `GenericArg` and `Term` are pointer-tagged instead of an enum (rustc does +//! the same). This is done to save memory (which also helps speed) - one `GenericArg` is a machine +//! word instead of two, while matching on it is basically as cheap. The implementation for both +//! `GenericArg` and `Term` is shared in [`GenericArgImpl`]. This both simplifies the implementation, +//! as well as enables a noop conversion from `Term` to `GenericArg`. + +use std::{hint::unreachable_unchecked, marker::PhantomData, ptr::NonNull}; use hir_def::{GenericDefId, GenericParamId}; -use macros::{TypeFoldable, TypeVisitable}; +use intern::InternedRef; use rustc_type_ir::{ - ClosureArgs, CollectAndApply, ConstVid, CoroutineArgs, CoroutineClosureArgs, FnSigTys, - GenericArgKind, Interner, TermKind, TyKind, TyVid, Variance, + ClosureArgs, ConstVid, CoroutineArgs, CoroutineClosureArgs, FallibleTypeFolder, + GenericTypeVisitable, Interner, TyVid, TypeFoldable, TypeFolder, TypeVisitable, TypeVisitor, + Variance, inherent::{GenericArg as _, GenericsOf, IntoKind, SliceLike, Term as _, Ty as _}, relate::{Relate, VarianceDiagInfo}, + walk::TypeWalker, }; use smallvec::SmallVec; -use crate::next_solver::{PolyFnSig, interned_vec_db}; +use crate::next_solver::{ + ConstInterned, RegionInterned, TyInterned, impl_foldable_for_interned_slice, interned_slice, +}; use super::{ - Const, DbInterner, EarlyParamRegion, ErrorGuaranteed, ParamConst, Region, SolverDefId, Ty, Tys, + Const, DbInterner, EarlyParamRegion, ErrorGuaranteed, ParamConst, Region, SolverDefId, Ty, generics::Generics, }; -#[derive(Copy, Clone, PartialEq, Eq, Hash, TypeVisitable, TypeFoldable)] -pub enum GenericArg<'db> { - Ty(Ty<'db>), - Lifetime(Region<'db>), - Const(Const<'db>), +pub type GenericArgKind<'db> = rustc_type_ir::GenericArgKind<DbInterner<'db>>; +pub type TermKind<'db> = rustc_type_ir::TermKind<DbInterner<'db>>; + +#[derive(Clone, Copy, PartialEq, Eq, Hash)] +struct GenericArgImpl<'db> { + /// # Invariant + /// + /// Contains an [`InternedRef`] of a [`Ty`], [`Const`] or [`Region`], bit-tagged as per the consts below. + ptr: NonNull<()>, + _marker: PhantomData<(Ty<'db>, Const<'db>, Region<'db>)>, +} + +// SAFETY: We essentially own the `Ty`, `Const` or `Region`, and they are `Send + Sync`. +unsafe impl Send for GenericArgImpl<'_> {} +unsafe impl Sync for GenericArgImpl<'_> {} + +impl<'db> GenericArgImpl<'db> { + const KIND_MASK: usize = 0b11; + const PTR_MASK: usize = !Self::KIND_MASK; + const TY_TAG: usize = 0b00; + const CONST_TAG: usize = 0b01; + const REGION_TAG: usize = 0b10; + + #[inline] + fn new_ty(ty: Ty<'db>) -> Self { + Self { + // SAFETY: We create it from an `InternedRef`, and it's never null. + ptr: unsafe { + NonNull::new_unchecked( + ty.interned + .as_raw() + .cast::<()>() + .cast_mut() + .map_addr(|addr| addr | Self::TY_TAG), + ) + }, + _marker: PhantomData, + } + } + + #[inline] + fn new_const(ty: Const<'db>) -> Self { + Self { + // SAFETY: We create it from an `InternedRef`, and it's never null. + ptr: unsafe { + NonNull::new_unchecked( + ty.interned + .as_raw() + .cast::<()>() + .cast_mut() + .map_addr(|addr| addr | Self::CONST_TAG), + ) + }, + _marker: PhantomData, + } + } + + #[inline] + fn new_region(ty: Region<'db>) -> Self { + Self { + // SAFETY: We create it from an `InternedRef`, and it's never null. + ptr: unsafe { + NonNull::new_unchecked( + ty.interned + .as_raw() + .cast::<()>() + .cast_mut() + .map_addr(|addr| addr | Self::REGION_TAG), + ) + }, + _marker: PhantomData, + } + } + + #[inline] + fn kind(self) -> GenericArgKind<'db> { + let ptr = self.ptr.as_ptr().map_addr(|addr| addr & Self::PTR_MASK); + // SAFETY: We can only be created from a `Ty`, a `Const` or a `Region`, and the tag will match. + unsafe { + match self.ptr.addr().get() & Self::KIND_MASK { + Self::TY_TAG => GenericArgKind::Type(Ty { + interned: InternedRef::from_raw(ptr.cast::<TyInterned>()), + }), + Self::CONST_TAG => GenericArgKind::Const(Const { + interned: InternedRef::from_raw(ptr.cast::<ConstInterned>()), + }), + Self::REGION_TAG => GenericArgKind::Lifetime(Region { + interned: InternedRef::from_raw(ptr.cast::<RegionInterned>()), + }), + _ => unreachable_unchecked(), + } + } + } + + #[inline] + fn term_kind(self) -> TermKind<'db> { + let ptr = self.ptr.as_ptr().map_addr(|addr| addr & Self::PTR_MASK); + // SAFETY: We can only be created from a `Ty`, a `Const` or a `Region`, and the tag will match. + // It is the caller's responsibility (encapsulated within this module) to only call this with + // `Term`, which cannot be constructed from a `Region`. + unsafe { + match self.ptr.addr().get() & Self::KIND_MASK { + Self::TY_TAG => { + TermKind::Ty(Ty { interned: InternedRef::from_raw(ptr.cast::<TyInterned>()) }) + } + Self::CONST_TAG => TermKind::Const(Const { + interned: InternedRef::from_raw(ptr.cast::<ConstInterned>()), + }), + _ => unreachable_unchecked(), + } + } + } +} + +#[derive(PartialEq, Eq, Hash)] +pub struct StoredGenericArg { + ptr: GenericArgImpl<'static>, +} + +impl Clone for StoredGenericArg { + #[inline] + fn clone(&self) -> Self { + match self.ptr.kind() { + GenericArgKind::Lifetime(it) => std::mem::forget(it.interned.to_owned()), + GenericArgKind::Type(it) => std::mem::forget(it.interned.to_owned()), + GenericArgKind::Const(it) => std::mem::forget(it.interned.to_owned()), + } + Self { ptr: self.ptr } + } +} + +impl Drop for StoredGenericArg { + #[inline] + fn drop(&mut self) { + unsafe { + match self.ptr.kind() { + GenericArgKind::Lifetime(it) => it.interned.decrement_refcount(), + GenericArgKind::Type(it) => it.interned.decrement_refcount(), + GenericArgKind::Const(it) => it.interned.decrement_refcount(), + } + } + } +} + +impl StoredGenericArg { + #[inline] + fn new(value: GenericArg<'_>) -> Self { + let result = Self { ptr: GenericArgImpl { ptr: value.ptr.ptr, _marker: PhantomData } }; + // Increase refcount. + std::mem::forget(result.clone()); + result + } + + #[inline] + pub fn as_ref<'db>(&self) -> GenericArg<'db> { + GenericArg { ptr: self.ptr } + } +} + +impl std::fmt::Debug for StoredGenericArg { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.as_ref().fmt(f) + } +} + +#[derive(Copy, Clone, PartialEq, Eq, Hash)] +pub struct GenericArg<'db> { + ptr: GenericArgImpl<'db>, } impl<'db> std::fmt::Debug for GenericArg<'db> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Self::Ty(t) => std::fmt::Debug::fmt(t, f), - Self::Lifetime(r) => std::fmt::Debug::fmt(r, f), - Self::Const(c) => std::fmt::Debug::fmt(c, f), + match self.kind() { + GenericArgKind::Type(t) => std::fmt::Debug::fmt(&t, f), + GenericArgKind::Lifetime(r) => std::fmt::Debug::fmt(&r, f), + GenericArgKind::Const(c) => std::fmt::Debug::fmt(&c, f), } } } impl<'db> GenericArg<'db> { + #[inline] + pub fn store(self) -> StoredGenericArg { + StoredGenericArg::new(self) + } + + #[inline] + pub fn kind(self) -> GenericArgKind<'db> { + self.ptr.kind() + } + pub fn ty(self) -> Option<Ty<'db>> { match self.kind() { GenericArgKind::Type(ty) => Some(ty), @@ -65,8 +250,8 @@ impl<'db> GenericArg<'db> { #[inline] pub(crate) fn expect_region(self) -> Region<'db> { - match self { - GenericArg::Lifetime(region) => region, + match self.kind() { + GenericArgKind::Lifetime(region) => region, _ => panic!("expected a region, got {self:?}"), } } @@ -78,33 +263,40 @@ impl<'db> GenericArg<'db> { GenericParamId::LifetimeParamId(_) => Region::error(interner).into(), } } + + #[inline] + pub fn walk(self) -> TypeWalker<DbInterner<'db>> { + TypeWalker::new(self) + } } impl<'db> From<Term<'db>> for GenericArg<'db> { + #[inline] fn from(value: Term<'db>) -> Self { - match value { - Term::Ty(ty) => GenericArg::Ty(ty), - Term::Const(c) => GenericArg::Const(c), - } + GenericArg { ptr: value.ptr } } } -#[derive(Copy, Clone, PartialEq, Eq, Hash, TypeVisitable, TypeFoldable)] -pub enum Term<'db> { - Ty(Ty<'db>), - Const(Const<'db>), +#[derive(Copy, Clone, PartialEq, Eq, Hash)] +pub struct Term<'db> { + ptr: GenericArgImpl<'db>, } impl<'db> std::fmt::Debug for Term<'db> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Self::Ty(t) => std::fmt::Debug::fmt(t, f), - Self::Const(c) => std::fmt::Debug::fmt(c, f), + match self.kind() { + TermKind::Ty(t) => std::fmt::Debug::fmt(&t, f), + TermKind::Const(c) => std::fmt::Debug::fmt(&c, f), } } } impl<'db> Term<'db> { + #[inline] + pub fn kind(self) -> TermKind<'db> { + self.ptr.term_kind() + } + pub fn expect_type(&self) -> Ty<'db> { self.as_type().expect("expected a type, but found a const") } @@ -118,31 +310,108 @@ impl<'db> Term<'db> { } impl<'db> From<Ty<'db>> for GenericArg<'db> { + #[inline] fn from(value: Ty<'db>) -> Self { - Self::Ty(value) + GenericArg { ptr: GenericArgImpl::new_ty(value) } } } impl<'db> From<Region<'db>> for GenericArg<'db> { + #[inline] fn from(value: Region<'db>) -> Self { - Self::Lifetime(value) + GenericArg { ptr: GenericArgImpl::new_region(value) } } } impl<'db> From<Const<'db>> for GenericArg<'db> { + #[inline] fn from(value: Const<'db>) -> Self { - Self::Const(value) + GenericArg { ptr: GenericArgImpl::new_const(value) } } } impl<'db> IntoKind for GenericArg<'db> { - type Kind = GenericArgKind<DbInterner<'db>>; + type Kind = GenericArgKind<'db>; + #[inline] fn kind(self) -> Self::Kind { - match self { - GenericArg::Ty(ty) => GenericArgKind::Type(ty), - GenericArg::Lifetime(region) => GenericArgKind::Lifetime(region), - GenericArg::Const(c) => GenericArgKind::Const(c), + self.ptr.kind() + } +} + +impl<'db, V> GenericTypeVisitable<V> for GenericArg<'db> +where + GenericArgKind<'db>: GenericTypeVisitable<V>, +{ + fn generic_visit_with(&self, visitor: &mut V) { + self.kind().generic_visit_with(visitor); + } +} + +impl<'db, V> GenericTypeVisitable<V> for Term<'db> +where + TermKind<'db>: GenericTypeVisitable<V>, +{ + fn generic_visit_with(&self, visitor: &mut V) { + self.kind().generic_visit_with(visitor); + } +} + +impl<'db> TypeVisitable<DbInterner<'db>> for GenericArg<'db> { + fn visit_with<V: TypeVisitor<DbInterner<'db>>>(&self, visitor: &mut V) -> V::Result { + match self.kind() { + GenericArgKind::Lifetime(it) => it.visit_with(visitor), + GenericArgKind::Type(it) => it.visit_with(visitor), + GenericArgKind::Const(it) => it.visit_with(visitor), + } + } +} + +impl<'db> TypeVisitable<DbInterner<'db>> for Term<'db> { + fn visit_with<V: TypeVisitor<DbInterner<'db>>>(&self, visitor: &mut V) -> V::Result { + match self.kind() { + TermKind::Ty(it) => it.visit_with(visitor), + TermKind::Const(it) => it.visit_with(visitor), + } + } +} + +impl<'db> TypeFoldable<DbInterner<'db>> for GenericArg<'db> { + fn try_fold_with<F: FallibleTypeFolder<DbInterner<'db>>>( + self, + folder: &mut F, + ) -> Result<Self, F::Error> { + Ok(match self.kind() { + GenericArgKind::Lifetime(it) => it.try_fold_with(folder)?.into(), + GenericArgKind::Type(it) => it.try_fold_with(folder)?.into(), + GenericArgKind::Const(it) => it.try_fold_with(folder)?.into(), + }) + } + + fn fold_with<F: TypeFolder<DbInterner<'db>>>(self, folder: &mut F) -> Self { + match self.kind() { + GenericArgKind::Lifetime(it) => it.fold_with(folder).into(), + GenericArgKind::Type(it) => it.fold_with(folder).into(), + GenericArgKind::Const(it) => it.fold_with(folder).into(), + } + } +} + +impl<'db> TypeFoldable<DbInterner<'db>> for Term<'db> { + fn try_fold_with<F: FallibleTypeFolder<DbInterner<'db>>>( + self, + folder: &mut F, + ) -> Result<Self, F::Error> { + Ok(match self.kind() { + TermKind::Ty(it) => it.try_fold_with(folder)?.into(), + TermKind::Const(it) => it.try_fold_with(folder)?.into(), + }) + } + + fn fold_with<F: TypeFolder<DbInterner<'db>>>(self, folder: &mut F) -> Self { + match self.kind() { + TermKind::Ty(it) => it.fold_with(folder).into(), + TermKind::Const(it) => it.fold_with(folder).into(), } } } @@ -176,7 +445,15 @@ impl<'db> Relate<DbInterner<'db>> for GenericArg<'db> { } } -interned_vec_db!(GenericArgs, GenericArg); +interned_slice!( + GenericArgsStorage, + GenericArgs, + StoredGenericArgs, + generic_args, + GenericArg<'db>, + GenericArg<'static>, +); +impl_foldable_for_interned_slice!(GenericArgs); impl<'db> rustc_type_ir::inherent::GenericArg<DbInterner<'db>> for GenericArg<'db> {} @@ -196,6 +473,11 @@ impl<'db> GenericArgs<'db> { { let defs = interner.generics_of(def_id); let count = defs.count(); + + if count == 0 { + return Default::default(); + } + let mut args = SmallVec::with_capacity(count); Self::fill_item(&mut args, interner, defs, &mut mk_kind); interner.mk_args(&args) @@ -283,34 +565,6 @@ impl<'db> GenericArgs<'db> { } } - pub fn closure_sig_untupled(self) -> PolyFnSig<'db> { - let TyKind::FnPtr(inputs_and_output, hdr) = - self.split_closure_args_untupled().closure_sig_as_fn_ptr_ty.kind() - else { - unreachable!("not a function pointer") - }; - inputs_and_output.with(hdr) - } - - /// A "sensible" `.split_closure_args()`, where the arguments are not in a tuple. - pub fn split_closure_args_untupled(self) -> rustc_type_ir::ClosureArgsParts<DbInterner<'db>> { - // FIXME: should use `ClosureSubst` when possible - match self.inner().as_slice() { - [parent_args @ .., closure_kind_ty, sig_ty, tupled_upvars_ty] => { - let interner = DbInterner::conjure(); - rustc_type_ir::ClosureArgsParts { - parent_args: GenericArgs::new_from_iter(interner, parent_args.iter().cloned()), - closure_sig_as_fn_ptr_ty: sig_ty.expect_ty(), - closure_kind_ty: closure_kind_ty.expect_ty(), - tupled_upvars_ty: tupled_upvars_ty.expect_ty(), - } - } - _ => { - unreachable!("unexpected closure sig"); - } - } - } - pub fn types(self) -> impl Iterator<Item = Ty<'db>> { self.iter().filter_map(|it| it.as_type()) } @@ -330,8 +584,8 @@ impl<'db> rustc_type_ir::relate::Relate<DbInterner<'db>> for GenericArgs<'db> { a: Self, b: Self, ) -> rustc_type_ir::relate::RelateResult<DbInterner<'db>, Self> { - let interner = relation.cx(); - CollectAndApply::collect_and_apply( + GenericArgs::new_from_iter( + relation.cx(), std::iter::zip(a.iter(), b.iter()).map(|(a, b)| { relation.relate_with_variance( Variance::Invariant, @@ -340,7 +594,6 @@ impl<'db> rustc_type_ir::relate::Relate<DbInterner<'db>> for GenericArgs<'db> { b, ) }), - |g| GenericArgs::new_from_iter(interner, g.iter().cloned()), ) } } @@ -386,54 +639,30 @@ impl<'db> rustc_type_ir::inherent::GenericArgs<DbInterner<'db>> for GenericArgs< }) } fn type_at(self, i: usize) -> <DbInterner<'db> as rustc_type_ir::Interner>::Ty { - self.inner() - .get(i) + self.get(i) .and_then(|g| g.as_type()) .unwrap_or_else(|| Ty::new_error(DbInterner::conjure(), ErrorGuaranteed)) } fn region_at(self, i: usize) -> <DbInterner<'db> as rustc_type_ir::Interner>::Region { - self.inner() - .get(i) + self.get(i) .and_then(|g| g.as_region()) .unwrap_or_else(|| Region::error(DbInterner::conjure())) } fn const_at(self, i: usize) -> <DbInterner<'db> as rustc_type_ir::Interner>::Const { - self.inner() - .get(i) + self.get(i) .and_then(|g| g.as_const()) .unwrap_or_else(|| Const::error(DbInterner::conjure())) } fn split_closure_args(self) -> rustc_type_ir::ClosureArgsParts<DbInterner<'db>> { // FIXME: should use `ClosureSubst` when possible - match self.inner().as_slice() { + match self.as_slice() { [parent_args @ .., closure_kind_ty, sig_ty, tupled_upvars_ty] => { - let interner = DbInterner::conjure(); - // This is stupid, but the next solver expects the first input to actually be a tuple - let sig_ty = match sig_ty.expect_ty().kind() { - TyKind::FnPtr(sig_tys, header) => Ty::new( - interner, - TyKind::FnPtr( - sig_tys.map_bound(|s| { - let inputs = Ty::new_tup_from_iter(interner, s.inputs().iter()); - let output = s.output(); - FnSigTys { - inputs_and_output: Tys::new_from_iter( - interner, - [inputs, output], - ), - } - }), - header, - ), - ), - _ => unreachable!("sig_ty should be last"), - }; rustc_type_ir::ClosureArgsParts { - parent_args: GenericArgs::new_from_iter(interner, parent_args.iter().cloned()), - closure_sig_as_fn_ptr_ty: sig_ty, + parent_args, + closure_sig_as_fn_ptr_ty: sig_ty.expect_ty(), closure_kind_ty: closure_kind_ty.expect_ty(), tupled_upvars_ty: tupled_upvars_ty.expect_ty(), } @@ -447,7 +676,7 @@ impl<'db> rustc_type_ir::inherent::GenericArgs<DbInterner<'db>> for GenericArgs< fn split_coroutine_closure_args( self, ) -> rustc_type_ir::CoroutineClosureArgsParts<DbInterner<'db>> { - match self.inner().as_slice() { + match self.as_slice() { [ parent_args @ .., closure_kind_ty, @@ -455,10 +684,7 @@ impl<'db> rustc_type_ir::inherent::GenericArgs<DbInterner<'db>> for GenericArgs< tupled_upvars_ty, coroutine_captures_by_ref_ty, ] => rustc_type_ir::CoroutineClosureArgsParts { - parent_args: GenericArgs::new_from_iter( - DbInterner::conjure(), - parent_args.iter().cloned(), - ), + parent_args, closure_kind_ty: closure_kind_ty.expect_ty(), signature_parts_ty: signature_parts_ty.expect_ty(), tupled_upvars_ty: tupled_upvars_ty.expect_ty(), @@ -469,11 +695,10 @@ impl<'db> rustc_type_ir::inherent::GenericArgs<DbInterner<'db>> for GenericArgs< } fn split_coroutine_args(self) -> rustc_type_ir::CoroutineArgsParts<DbInterner<'db>> { - let interner = DbInterner::conjure(); - match self.inner().as_slice() { + match self.as_slice() { [parent_args @ .., kind_ty, resume_ty, yield_ty, return_ty, tupled_upvars_ty] => { rustc_type_ir::CoroutineArgsParts { - parent_args: GenericArgs::new_from_iter(interner, parent_args.iter().cloned()), + parent_args, kind_ty: kind_ty.expect_ty(), resume_ty: resume_ty.expect_ty(), yield_ty: yield_ty.expect_ty(), @@ -507,25 +732,25 @@ pub fn error_for_param_kind<'db>(id: GenericParamId, interner: DbInterner<'db>) } impl<'db> IntoKind for Term<'db> { - type Kind = TermKind<DbInterner<'db>>; + type Kind = TermKind<'db>; + #[inline] fn kind(self) -> Self::Kind { - match self { - Term::Ty(ty) => TermKind::Ty(ty), - Term::Const(c) => TermKind::Const(c), - } + self.ptr.term_kind() } } impl<'db> From<Ty<'db>> for Term<'db> { + #[inline] fn from(value: Ty<'db>) -> Self { - Self::Ty(value) + Term { ptr: GenericArgImpl::new_ty(value) } } } impl<'db> From<Const<'db>> for Term<'db> { + #[inline] fn from(value: Const<'db>) -> Self { - Self::Const(value) + Term { ptr: GenericArgImpl::new_const(value) } } } @@ -572,7 +797,7 @@ impl From<ConstVid> for TermVid { impl<'db> DbInterner<'db> { pub(super) fn mk_args(self, args: &[GenericArg<'db>]) -> GenericArgs<'db> { - GenericArgs::new_from_iter(self, args.iter().cloned()) + GenericArgs::new_from_slice(args) } pub(super) fn mk_args_from_iter<I, T>(self, iter: I) -> T::Output |