Unnamed repository; edit this file 'description' to name the repository.
Diffstat (limited to 'crates/hir-ty/src/builtin_derive.rs')
| -rw-r--r-- | crates/hir-ty/src/builtin_derive.rs | 78 |
1 files changed, 51 insertions, 27 deletions
diff --git a/crates/hir-ty/src/builtin_derive.rs b/crates/hir-ty/src/builtin_derive.rs index 15d9634cfa..f3e67d01e5 100644 --- a/crates/hir-ty/src/builtin_derive.rs +++ b/crates/hir-ty/src/builtin_derive.rs @@ -20,16 +20,18 @@ use crate::{ GenericPredicates, db::HirDatabase, next_solver::{ - Clause, Clauses, DbInterner, EarlyBinder, GenericArgs, StoredEarlyBinder, StoredTy, - TraitRef, Ty, TyKind, fold::fold_tys, generics::Generics, + Clause, Clauses, DbInterner, EarlyBinder, GenericArgs, ParamEnv, StoredEarlyBinder, + StoredTy, TraitRef, Ty, TyKind, fold::fold_tys, generics::Generics, }, }; -fn fake_type_param(adt: AdtId) -> TypeParamId { +fn coerce_pointee_new_type_param(trait_id: TraitId) -> TypeParamId { // HACK: Fake the param. + // We cannot use a dummy param here, because it can leak into the IDE layer and that'll cause panics + // when e.g. trying to display it. So we use an existing param. TypeParamId::from_unchecked(TypeOrConstParamId { - parent: adt.into(), - local_id: la_arena::Idx::from_raw(la_arena::RawIdx::from_u32(u32::MAX)), + parent: trait_id.into(), + local_id: la_arena::Idx::from_raw(la_arena::RawIdx::from_u32(1)), }) } @@ -48,13 +50,35 @@ pub(crate) fn generics_of<'db>(interner: DbInterner<'db>, id: BuiltinDeriveImplI | BuiltinDeriveImplTrait::PartialEq => interner.generics_of(loc.adt.into()), BuiltinDeriveImplTrait::CoerceUnsized | BuiltinDeriveImplTrait::DispatchFromDyn => { let mut generics = interner.generics_of(loc.adt.into()); - generics.push_param(fake_type_param(loc.adt).into()); + let trait_id = loc + .trait_ + .get_id(interner.lang_items()) + .expect("we don't pass the impl to the solver if we can't resolve the trait"); + generics.push_param(coerce_pointee_new_type_param(trait_id).into()); generics } } } -pub(crate) fn impl_trait<'db>( +pub fn generic_params_count(db: &dyn HirDatabase, id: BuiltinDeriveImplId) -> usize { + let loc = id.loc(db); + let adt_params = GenericParams::new(db, loc.adt.into()); + let extra_params_count = match loc.trait_ { + BuiltinDeriveImplTrait::Copy + | BuiltinDeriveImplTrait::Clone + | BuiltinDeriveImplTrait::Default + | BuiltinDeriveImplTrait::Debug + | BuiltinDeriveImplTrait::Hash + | BuiltinDeriveImplTrait::Ord + | BuiltinDeriveImplTrait::PartialOrd + | BuiltinDeriveImplTrait::Eq + | BuiltinDeriveImplTrait::PartialEq => 0, + BuiltinDeriveImplTrait::CoerceUnsized | BuiltinDeriveImplTrait::DispatchFromDyn => 1, + }; + adt_params.len() + extra_params_count +} + +pub fn impl_trait<'db>( interner: DbInterner<'db>, id: BuiltinDeriveImplId, ) -> EarlyBinder<'db, TraitRef<'db>> { @@ -93,7 +117,7 @@ pub(crate) fn impl_trait<'db>( let args = GenericArgs::identity_for_item(interner, loc.adt.into()); let self_ty = Ty::new_adt(interner, loc.adt, args); let Some((pointee_param_idx, _, new_param_ty)) = - coerce_pointee_params(interner, loc, &generic_params) + coerce_pointee_params(interner, loc, &generic_params, trait_id) else { // Malformed derive. return EarlyBinder::bind(TraitRef::new( @@ -110,14 +134,15 @@ pub(crate) fn impl_trait<'db>( } #[salsa::tracked(returns(ref), unsafe(non_update_types))] -pub(crate) fn builtin_derive_predicates<'db>( - db: &'db dyn HirDatabase, - impl_: BuiltinDeriveImplId, -) -> GenericPredicates { +pub fn predicates<'db>(db: &'db dyn HirDatabase, impl_: BuiltinDeriveImplId) -> GenericPredicates { let loc = impl_.loc(db); let generic_params = GenericParams::new(db, loc.adt.into()); let interner = DbInterner::new_with(db, loc.module(db).krate(db)); let adt_predicates = GenericPredicates::query(db, loc.adt.into()); + let trait_id = loc + .trait_ + .get_id(interner.lang_items()) + .expect("we don't pass the impl to the solver if we can't resolve the trait"); match loc.trait_ { BuiltinDeriveImplTrait::Copy | BuiltinDeriveImplTrait::Clone @@ -127,7 +152,7 @@ pub(crate) fn builtin_derive_predicates<'db>( | BuiltinDeriveImplTrait::PartialOrd | BuiltinDeriveImplTrait::Eq | BuiltinDeriveImplTrait::PartialEq => { - simple_trait_predicates(interner, loc, &generic_params, adt_predicates) + simple_trait_predicates(interner, loc, &generic_params, adt_predicates, trait_id) } BuiltinDeriveImplTrait::Default => { if matches!(loc.adt, AdtId::EnumId(_)) { @@ -137,12 +162,12 @@ pub(crate) fn builtin_derive_predicates<'db>( .store(), )) } else { - simple_trait_predicates(interner, loc, &generic_params, adt_predicates) + simple_trait_predicates(interner, loc, &generic_params, adt_predicates, trait_id) } } BuiltinDeriveImplTrait::CoerceUnsized | BuiltinDeriveImplTrait::DispatchFromDyn => { let Some((pointee_param_idx, pointee_param_id, new_param_ty)) = - coerce_pointee_params(interner, loc, &generic_params) + coerce_pointee_params(interner, loc, &generic_params, trait_id) else { // Malformed derive. return GenericPredicates::from_explicit_own_predicates(StoredEarlyBinder::bind( @@ -181,6 +206,12 @@ pub(crate) fn builtin_derive_predicates<'db>( } } +/// Not cached in a query, currently used in `hir` only. If you need this in `hir-ty` consider introducing a query. +pub fn param_env<'db>(interner: DbInterner<'db>, id: BuiltinDeriveImplId) -> ParamEnv<'db> { + let predicates = predicates(interner.db, id); + crate::lower::param_env_from_predicates(interner, predicates) +} + struct MentionsPointee { pointee_param_idx: u32, } @@ -216,11 +247,8 @@ fn simple_trait_predicates<'db>( loc: &BuiltinDeriveImplLoc, generic_params: &GenericParams, adt_predicates: &GenericPredicates, + trait_id: TraitId, ) -> GenericPredicates { - let trait_id = loc - .trait_ - .get_id(interner.lang_items()) - .expect("we don't pass the impl to the solver if we can't resolve the trait"); let extra_predicates = generic_params .iter_type_or_consts() .filter(|(_, data)| matches!(data, TypeOrConstParamData::TypeParamData(_))) @@ -309,6 +337,7 @@ fn coerce_pointee_params<'db>( interner: DbInterner<'db>, loc: &BuiltinDeriveImplLoc, generic_params: &GenericParams, + trait_id: TraitId, ) -> Option<(u32, TypeParamId, Ty<'db>)> { let pointee_param = { if let Ok((pointee_param, _)) = generic_params @@ -339,7 +368,7 @@ fn coerce_pointee_params<'db>( let pointee_param_idx = pointee_param.into_raw().into_u32() + (generic_params.len_lifetimes() as u32); let new_param_idx = generic_params.len() as u32; - let new_param_id = fake_type_param(loc.adt); + let new_param_id = coerce_pointee_new_type_param(trait_id); let new_param_ty = Ty::new_param(interner, new_param_id, new_param_idx); Some((pointee_param_idx, pointee_param_id, new_param_ty)) } @@ -352,11 +381,7 @@ mod tests { use stdx::format_to; use test_fixture::WithFixture; - use crate::{ - builtin_derive::{builtin_derive_predicates, impl_trait}, - next_solver::DbInterner, - test_db::TestDB, - }; + use crate::{builtin_derive::impl_trait, next_solver::DbInterner, test_db::TestDB}; fn check_trait_refs(#[rust_analyzer::rust_fixture] ra_fixture: &str, expectation: Expect) { let db = TestDB::with_files(ra_fixture); @@ -384,8 +409,7 @@ mod tests { let mut predicates = String::new(); for (_, module) in def_map.modules() { for derive in module.scope.builtin_derive_impls() { - let preds = - builtin_derive_predicates(&db, derive).all_predicates().skip_binder(); + let preds = super::predicates(&db, derive).all_predicates().skip_binder(); format_to!( predicates, "{}\n\n", |