Unnamed repository; edit this file 'description' to name the repository.
Diffstat (limited to 'crates/hir-ty/src/next_solver/generic_arg.rs')
| -rw-r--r-- | crates/hir-ty/src/next_solver/generic_arg.rs | 577 |
1 files changed, 577 insertions, 0 deletions
diff --git a/crates/hir-ty/src/next_solver/generic_arg.rs b/crates/hir-ty/src/next_solver/generic_arg.rs new file mode 100644 index 0000000000..90bd44aee8 --- /dev/null +++ b/crates/hir-ty/src/next_solver/generic_arg.rs @@ -0,0 +1,577 @@ +//! Things related to generic args in the next-trait-solver. + +use hir_def::{GenericDefId, GenericParamId}; +use macros::{TypeFoldable, TypeVisitable}; +use rustc_type_ir::{ + ClosureArgs, CollectAndApply, ConstVid, CoroutineArgs, CoroutineClosureArgs, FnSigTys, + GenericArgKind, Interner, TermKind, TyKind, TyVid, Variance, + inherent::{GenericArg as _, GenericsOf, IntoKind, SliceLike, Term as _, Ty as _}, + relate::{Relate, VarianceDiagInfo}, +}; +use smallvec::SmallVec; + +use crate::next_solver::{PolyFnSig, interned_vec_db}; + +use super::{ + Const, DbInterner, EarlyParamRegion, ErrorGuaranteed, ParamConst, Region, SolverDefId, Ty, Tys, + generics::Generics, +}; + +#[derive(Copy, Clone, PartialEq, Eq, Hash, TypeVisitable, TypeFoldable)] +pub enum GenericArg<'db> { + Ty(Ty<'db>), + Lifetime(Region<'db>), + Const(Const<'db>), +} + +impl<'db> std::fmt::Debug for GenericArg<'db> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Ty(t) => std::fmt::Debug::fmt(t, f), + Self::Lifetime(r) => std::fmt::Debug::fmt(r, f), + Self::Const(c) => std::fmt::Debug::fmt(c, f), + } + } +} + +impl<'db> GenericArg<'db> { + pub fn ty(self) -> Option<Ty<'db>> { + match self.kind() { + GenericArgKind::Type(ty) => Some(ty), + _ => None, + } + } + + pub fn expect_ty(self) -> Ty<'db> { + match self.kind() { + GenericArgKind::Type(ty) => ty, + _ => panic!("Expected ty, got {self:?}"), + } + } + + pub fn konst(self) -> Option<Const<'db>> { + match self.kind() { + GenericArgKind::Const(konst) => Some(konst), + _ => None, + } + } + + pub fn region(self) -> Option<Region<'db>> { + match self.kind() { + GenericArgKind::Lifetime(r) => Some(r), + _ => None, + } + } + + pub fn error_from_id(interner: DbInterner<'db>, id: GenericParamId) -> GenericArg<'db> { + match id { + GenericParamId::TypeParamId(_) => Ty::new_error(interner, ErrorGuaranteed).into(), + GenericParamId::ConstParamId(_) => Const::error(interner).into(), + GenericParamId::LifetimeParamId(_) => Region::error(interner).into(), + } + } +} + +impl<'db> From<Term<'db>> for GenericArg<'db> { + fn from(value: Term<'db>) -> Self { + match value { + Term::Ty(ty) => GenericArg::Ty(ty), + Term::Const(c) => GenericArg::Const(c), + } + } +} + +#[derive(Copy, Clone, PartialEq, Eq, Hash, TypeVisitable, TypeFoldable)] +pub enum Term<'db> { + Ty(Ty<'db>), + Const(Const<'db>), +} + +impl<'db> std::fmt::Debug for Term<'db> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Ty(t) => std::fmt::Debug::fmt(t, f), + Self::Const(c) => std::fmt::Debug::fmt(c, f), + } + } +} + +impl<'db> Term<'db> { + pub fn expect_type(&self) -> Ty<'db> { + self.as_type().expect("expected a type, but found a const") + } + + pub fn is_trivially_wf(&self, tcx: DbInterner<'db>) -> bool { + match self.kind() { + TermKind::Ty(ty) => ty.is_trivially_wf(tcx), + TermKind::Const(ct) => ct.is_trivially_wf(), + } + } +} + +impl<'db> From<Ty<'db>> for GenericArg<'db> { + fn from(value: Ty<'db>) -> Self { + Self::Ty(value) + } +} + +impl<'db> From<Region<'db>> for GenericArg<'db> { + fn from(value: Region<'db>) -> Self { + Self::Lifetime(value) + } +} + +impl<'db> From<Const<'db>> for GenericArg<'db> { + fn from(value: Const<'db>) -> Self { + Self::Const(value) + } +} + +impl<'db> IntoKind for GenericArg<'db> { + type Kind = GenericArgKind<DbInterner<'db>>; + + fn kind(self) -> Self::Kind { + match self { + GenericArg::Ty(ty) => GenericArgKind::Type(ty), + GenericArg::Lifetime(region) => GenericArgKind::Lifetime(region), + GenericArg::Const(c) => GenericArgKind::Const(c), + } + } +} + +impl<'db> Relate<DbInterner<'db>> for GenericArg<'db> { + fn relate<R: rustc_type_ir::relate::TypeRelation<DbInterner<'db>>>( + relation: &mut R, + a: Self, + b: Self, + ) -> rustc_type_ir::relate::RelateResult<DbInterner<'db>, Self> { + match (a.kind(), b.kind()) { + (GenericArgKind::Lifetime(a_lt), GenericArgKind::Lifetime(b_lt)) => { + Ok(relation.relate(a_lt, b_lt)?.into()) + } + (GenericArgKind::Type(a_ty), GenericArgKind::Type(b_ty)) => { + Ok(relation.relate(a_ty, b_ty)?.into()) + } + (GenericArgKind::Const(a_ct), GenericArgKind::Const(b_ct)) => { + Ok(relation.relate(a_ct, b_ct)?.into()) + } + (GenericArgKind::Lifetime(unpacked), x) => { + unreachable!("impossible case reached: can't relate: {:?} with {:?}", unpacked, x) + } + (GenericArgKind::Type(unpacked), x) => { + unreachable!("impossible case reached: can't relate: {:?} with {:?}", unpacked, x) + } + (GenericArgKind::Const(unpacked), x) => { + unreachable!("impossible case reached: can't relate: {:?} with {:?}", unpacked, x) + } + } + } +} + +interned_vec_db!(GenericArgs, GenericArg); + +impl<'db> rustc_type_ir::inherent::GenericArg<DbInterner<'db>> for GenericArg<'db> {} + +impl<'db> GenericArgs<'db> { + /// Creates an `GenericArgs` for generic parameter definitions, + /// by calling closures to obtain each kind. + /// The closures get to observe the `GenericArgs` as they're + /// being built, which can be used to correctly + /// replace defaults of generic parameters. + pub fn for_item<F>( + interner: DbInterner<'db>, + def_id: SolverDefId, + mut mk_kind: F, + ) -> GenericArgs<'db> + where + F: FnMut(u32, GenericParamId, &[GenericArg<'db>]) -> GenericArg<'db>, + { + let defs = interner.generics_of(def_id); + let count = defs.count(); + let mut args = SmallVec::with_capacity(count); + Self::fill_item(&mut args, interner, defs, &mut mk_kind); + interner.mk_args(&args) + } + + /// Creates an all-error `GenericArgs`. + pub fn error_for_item(interner: DbInterner<'db>, def_id: SolverDefId) -> GenericArgs<'db> { + GenericArgs::for_item(interner, def_id, |_, id, _| GenericArg::error_from_id(interner, id)) + } + + /// Like `for_item`, but prefers the default of a parameter if it has any. + pub fn for_item_with_defaults<F>( + interner: DbInterner<'db>, + def_id: GenericDefId, + mut fallback: F, + ) -> GenericArgs<'db> + where + F: FnMut(u32, GenericParamId, &[GenericArg<'db>]) -> GenericArg<'db>, + { + let defaults = interner.db.generic_defaults(def_id); + Self::for_item(interner, def_id.into(), |idx, id, prev| match defaults.get(idx as usize) { + Some(default) => default.instantiate(interner, prev), + None => fallback(idx, id, prev), + }) + } + + /// Like `for_item()`, but calls first uses the args from `first`. + pub fn fill_rest<F>( + interner: DbInterner<'db>, + def_id: SolverDefId, + first: impl IntoIterator<Item = GenericArg<'db>>, + mut fallback: F, + ) -> GenericArgs<'db> + where + F: FnMut(u32, GenericParamId, &[GenericArg<'db>]) -> GenericArg<'db>, + { + let mut iter = first.into_iter(); + Self::for_item(interner, def_id, |idx, id, prev| { + iter.next().unwrap_or_else(|| fallback(idx, id, prev)) + }) + } + + /// Appends default param values to `first` if needed. Params without default will call `fallback()`. + pub fn fill_with_defaults<F>( + interner: DbInterner<'db>, + def_id: GenericDefId, + first: impl IntoIterator<Item = GenericArg<'db>>, + mut fallback: F, + ) -> GenericArgs<'db> + where + F: FnMut(u32, GenericParamId, &[GenericArg<'db>]) -> GenericArg<'db>, + { + let defaults = interner.db.generic_defaults(def_id); + Self::fill_rest(interner, def_id.into(), first, |idx, id, prev| { + defaults + .get(idx as usize) + .map(|default| default.instantiate(interner, prev)) + .unwrap_or_else(|| fallback(idx, id, prev)) + }) + } + + fn fill_item<F>( + args: &mut SmallVec<[GenericArg<'db>; 8]>, + interner: DbInterner<'_>, + defs: Generics, + mk_kind: &mut F, + ) where + F: FnMut(u32, GenericParamId, &[GenericArg<'db>]) -> GenericArg<'db>, + { + if let Some(def_id) = defs.parent { + let parent_defs = interner.generics_of(def_id.into()); + Self::fill_item(args, interner, parent_defs, mk_kind); + } + Self::fill_single(args, &defs, mk_kind); + } + + fn fill_single<F>(args: &mut SmallVec<[GenericArg<'db>; 8]>, defs: &Generics, mk_kind: &mut F) + where + F: FnMut(u32, GenericParamId, &[GenericArg<'db>]) -> GenericArg<'db>, + { + args.reserve(defs.own_params.len()); + for param in &defs.own_params { + let kind = mk_kind(args.len() as u32, param.id, args); + args.push(kind); + } + } + + pub fn closure_sig_untupled(self) -> PolyFnSig<'db> { + let TyKind::FnPtr(inputs_and_output, hdr) = + self.split_closure_args_untupled().closure_sig_as_fn_ptr_ty.kind() + else { + unreachable!("not a function pointer") + }; + inputs_and_output.with(hdr) + } + + /// A "sensible" `.split_closure_args()`, where the arguments are not in a tuple. + pub fn split_closure_args_untupled(self) -> rustc_type_ir::ClosureArgsParts<DbInterner<'db>> { + // FIXME: should use `ClosureSubst` when possible + match self.inner().as_slice() { + [parent_args @ .., closure_kind_ty, sig_ty, tupled_upvars_ty] => { + let interner = DbInterner::conjure(); + rustc_type_ir::ClosureArgsParts { + parent_args: GenericArgs::new_from_iter(interner, parent_args.iter().cloned()), + closure_sig_as_fn_ptr_ty: sig_ty.expect_ty(), + closure_kind_ty: closure_kind_ty.expect_ty(), + tupled_upvars_ty: tupled_upvars_ty.expect_ty(), + } + } + _ => { + unreachable!("unexpected closure sig"); + } + } + } + + pub fn types(self) -> impl Iterator<Item = Ty<'db>> { + self.iter().filter_map(|it| it.as_type()) + } + + pub fn consts(self) -> impl Iterator<Item = Const<'db>> { + self.iter().filter_map(|it| it.as_const()) + } + + pub fn regions(self) -> impl Iterator<Item = Region<'db>> { + self.iter().filter_map(|it| it.as_region()) + } +} + +impl<'db> rustc_type_ir::relate::Relate<DbInterner<'db>> for GenericArgs<'db> { + fn relate<R: rustc_type_ir::relate::TypeRelation<DbInterner<'db>>>( + relation: &mut R, + a: Self, + b: Self, + ) -> rustc_type_ir::relate::RelateResult<DbInterner<'db>, Self> { + let interner = relation.cx(); + CollectAndApply::collect_and_apply( + std::iter::zip(a.iter(), b.iter()).map(|(a, b)| { + relation.relate_with_variance( + Variance::Invariant, + VarianceDiagInfo::default(), + a, + b, + ) + }), + |g| GenericArgs::new_from_iter(interner, g.iter().cloned()), + ) + } +} + +impl<'db> rustc_type_ir::inherent::GenericArgs<DbInterner<'db>> for GenericArgs<'db> { + fn as_closure(self) -> ClosureArgs<DbInterner<'db>> { + ClosureArgs { args: self } + } + fn as_coroutine(self) -> CoroutineArgs<DbInterner<'db>> { + CoroutineArgs { args: self } + } + fn as_coroutine_closure(self) -> CoroutineClosureArgs<DbInterner<'db>> { + CoroutineClosureArgs { args: self } + } + fn rebase_onto( + self, + interner: DbInterner<'db>, + source_def_id: <DbInterner<'db> as rustc_type_ir::Interner>::DefId, + target: <DbInterner<'db> as rustc_type_ir::Interner>::GenericArgs, + ) -> <DbInterner<'db> as rustc_type_ir::Interner>::GenericArgs { + let defs = interner.generics_of(source_def_id); + interner.mk_args_from_iter(target.iter().chain(self.iter().skip(defs.count()))) + } + + fn identity_for_item( + interner: DbInterner<'db>, + def_id: <DbInterner<'db> as rustc_type_ir::Interner>::DefId, + ) -> <DbInterner<'db> as rustc_type_ir::Interner>::GenericArgs { + Self::for_item(interner, def_id, |index, kind, _| mk_param(interner, index, kind)) + } + + fn extend_with_error( + interner: DbInterner<'db>, + def_id: <DbInterner<'db> as rustc_type_ir::Interner>::DefId, + original_args: &[<DbInterner<'db> as rustc_type_ir::Interner>::GenericArg], + ) -> <DbInterner<'db> as rustc_type_ir::Interner>::GenericArgs { + Self::for_item(interner, def_id, |index, kind, _| { + if let Some(arg) = original_args.get(index as usize) { + *arg + } else { + error_for_param_kind(kind, interner) + } + }) + } + fn type_at(self, i: usize) -> <DbInterner<'db> as rustc_type_ir::Interner>::Ty { + self.inner() + .get(i) + .and_then(|g| g.as_type()) + .unwrap_or_else(|| Ty::new_error(DbInterner::conjure(), ErrorGuaranteed)) + } + + fn region_at(self, i: usize) -> <DbInterner<'db> as rustc_type_ir::Interner>::Region { + self.inner() + .get(i) + .and_then(|g| g.as_region()) + .unwrap_or_else(|| Region::error(DbInterner::conjure())) + } + + fn const_at(self, i: usize) -> <DbInterner<'db> as rustc_type_ir::Interner>::Const { + self.inner() + .get(i) + .and_then(|g| g.as_const()) + .unwrap_or_else(|| Const::error(DbInterner::conjure())) + } + + fn split_closure_args(self) -> rustc_type_ir::ClosureArgsParts<DbInterner<'db>> { + // FIXME: should use `ClosureSubst` when possible + match self.inner().as_slice() { + [parent_args @ .., closure_kind_ty, sig_ty, tupled_upvars_ty] => { + let interner = DbInterner::conjure(); + // This is stupid, but the next solver expects the first input to actually be a tuple + let sig_ty = match sig_ty.expect_ty().kind() { + TyKind::FnPtr(sig_tys, header) => Ty::new( + interner, + TyKind::FnPtr( + sig_tys.map_bound(|s| { + let inputs = Ty::new_tup_from_iter(interner, s.inputs().iter()); + let output = s.output(); + FnSigTys { + inputs_and_output: Tys::new_from_iter( + interner, + [inputs, output], + ), + } + }), + header, + ), + ), + _ => unreachable!("sig_ty should be last"), + }; + rustc_type_ir::ClosureArgsParts { + parent_args: GenericArgs::new_from_iter(interner, parent_args.iter().cloned()), + closure_sig_as_fn_ptr_ty: sig_ty, + closure_kind_ty: closure_kind_ty.expect_ty(), + tupled_upvars_ty: tupled_upvars_ty.expect_ty(), + } + } + _ => { + unreachable!("unexpected closure sig"); + } + } + } + + fn split_coroutine_closure_args( + self, + ) -> rustc_type_ir::CoroutineClosureArgsParts<DbInterner<'db>> { + match self.inner().as_slice() { + [ + parent_args @ .., + closure_kind_ty, + signature_parts_ty, + tupled_upvars_ty, + coroutine_captures_by_ref_ty, + ] => rustc_type_ir::CoroutineClosureArgsParts { + parent_args: GenericArgs::new_from_iter( + DbInterner::conjure(), + parent_args.iter().cloned(), + ), + closure_kind_ty: closure_kind_ty.expect_ty(), + signature_parts_ty: signature_parts_ty.expect_ty(), + tupled_upvars_ty: tupled_upvars_ty.expect_ty(), + coroutine_captures_by_ref_ty: coroutine_captures_by_ref_ty.expect_ty(), + }, + _ => panic!("GenericArgs were likely not for a CoroutineClosure."), + } + } + + fn split_coroutine_args(self) -> rustc_type_ir::CoroutineArgsParts<DbInterner<'db>> { + let interner = DbInterner::conjure(); + match self.inner().as_slice() { + [parent_args @ .., kind_ty, resume_ty, yield_ty, return_ty, tupled_upvars_ty] => { + rustc_type_ir::CoroutineArgsParts { + parent_args: GenericArgs::new_from_iter(interner, parent_args.iter().cloned()), + kind_ty: kind_ty.expect_ty(), + resume_ty: resume_ty.expect_ty(), + yield_ty: yield_ty.expect_ty(), + return_ty: return_ty.expect_ty(), + tupled_upvars_ty: tupled_upvars_ty.expect_ty(), + } + } + _ => panic!("GenericArgs were likely not for a Coroutine."), + } + } +} + +pub fn mk_param<'db>(interner: DbInterner<'db>, index: u32, id: GenericParamId) -> GenericArg<'db> { + match id { + GenericParamId::LifetimeParamId(id) => { + Region::new_early_param(interner, EarlyParamRegion { index, id }).into() + } + GenericParamId::TypeParamId(id) => Ty::new_param(interner, id, index).into(), + GenericParamId::ConstParamId(id) => { + Const::new_param(interner, ParamConst { index, id }).into() + } + } +} + +pub fn error_for_param_kind<'db>(id: GenericParamId, interner: DbInterner<'db>) -> GenericArg<'db> { + match id { + GenericParamId::LifetimeParamId(_) => Region::error(interner).into(), + GenericParamId::TypeParamId(_) => Ty::new_error(interner, ErrorGuaranteed).into(), + GenericParamId::ConstParamId(_) => Const::error(interner).into(), + } +} + +impl<'db> IntoKind for Term<'db> { + type Kind = TermKind<DbInterner<'db>>; + + fn kind(self) -> Self::Kind { + match self { + Term::Ty(ty) => TermKind::Ty(ty), + Term::Const(c) => TermKind::Const(c), + } + } +} + +impl<'db> From<Ty<'db>> for Term<'db> { + fn from(value: Ty<'db>) -> Self { + Self::Ty(value) + } +} + +impl<'db> From<Const<'db>> for Term<'db> { + fn from(value: Const<'db>) -> Self { + Self::Const(value) + } +} + +impl<'db> Relate<DbInterner<'db>> for Term<'db> { + fn relate<R: rustc_type_ir::relate::TypeRelation<DbInterner<'db>>>( + relation: &mut R, + a: Self, + b: Self, + ) -> rustc_type_ir::relate::RelateResult<DbInterner<'db>, Self> { + match (a.kind(), b.kind()) { + (TermKind::Ty(a_ty), TermKind::Ty(b_ty)) => Ok(relation.relate(a_ty, b_ty)?.into()), + (TermKind::Const(a_ct), TermKind::Const(b_ct)) => { + Ok(relation.relate(a_ct, b_ct)?.into()) + } + (TermKind::Ty(unpacked), x) => { + unreachable!("impossible case reached: can't relate: {:?} with {:?}", unpacked, x) + } + (TermKind::Const(unpacked), x) => { + unreachable!("impossible case reached: can't relate: {:?} with {:?}", unpacked, x) + } + } + } +} + +impl<'db> rustc_type_ir::inherent::Term<DbInterner<'db>> for Term<'db> {} + +#[derive(Clone, Eq, PartialEq, Debug)] +pub enum TermVid { + Ty(TyVid), + Const(ConstVid), +} + +impl From<TyVid> for TermVid { + fn from(value: TyVid) -> Self { + TermVid::Ty(value) + } +} + +impl From<ConstVid> for TermVid { + fn from(value: ConstVid) -> Self { + TermVid::Const(value) + } +} + +impl<'db> DbInterner<'db> { + pub(super) fn mk_args(self, args: &[GenericArg<'db>]) -> GenericArgs<'db> { + GenericArgs::new_from_iter(self, args.iter().cloned()) + } + + pub(super) fn mk_args_from_iter<I, T>(self, iter: I) -> T::Output + where + I: Iterator<Item = T>, + T: rustc_type_ir::CollectAndApply<GenericArg<'db>, GenericArgs<'db>>, + { + T::collect_and_apply(iter, |xs| self.mk_args(xs)) + } +} |