Unnamed repository; edit this file 'description' to name the repository.
Diffstat (limited to 'crates/hir-ty/src/infer.rs')
| -rw-r--r-- | crates/hir-ty/src/infer.rs | 107 |
1 files changed, 59 insertions, 48 deletions
diff --git a/crates/hir-ty/src/infer.rs b/crates/hir-ty/src/infer.rs index 9d78f5de9e..52edbc899f 100644 --- a/crates/hir-ty/src/infer.rs +++ b/crates/hir-ty/src/infer.rs @@ -33,15 +33,15 @@ use std::{cell::OnceCell, convert::identity, iter}; use base_db::Crate; use either::Either; use hir_def::{ - AdtId, AssocItemId, ConstId, ConstParamId, DefWithBodyId, ExpressionStoreOwner, FieldId, + AdtId, AssocItemId, ConstId, ConstParamId, DefWithBodyId, ExpressionStoreOwnerId, FieldId, FunctionId, GenericDefId, GenericParamId, ItemContainerId, LocalFieldId, Lookup, TraitId, TupleFieldId, TupleId, TypeAliasId, TypeOrConstParamId, VariantId, - expr_store::{ConstExprOrigin, ExpressionStore, HygieneId, path::Path}, + expr_store::{Body, ConstExprOrigin, ExpressionStore, HygieneId, path::Path}, hir::{BindingAnnotation, BindingId, ExprId, ExprOrPatId, LabelId, PatId}, lang_item::LangItems, layout::Integer, resolver::{HasResolver, ResolveValueResult, Resolver, TypeNs, ValueNs}, - signatures::{ConstSignature, EnumSignature, StaticSignature}, + signatures::{ConstSignature, EnumSignature, FunctionSignature, StaticSignature}, type_ref::{ConstRef, LifetimeRefId, TypeRef, TypeRefId}, }; use hir_expand::{mod_path::ModPath, name::Name}; @@ -105,8 +105,9 @@ pub fn infer_query_with_inspect<'db>( ) -> InferenceResult { let _p = tracing::info_span!("infer_query").entered(); let resolver = def.resolver(db); - let body = db.body(def); - let mut ctx = InferenceContext::new(db, ExpressionStoreOwner::Body(def), &body.store, resolver); + let body = Body::of(db, def); + let mut ctx = + InferenceContext::new(db, ExpressionStoreOwnerId::Body(def), &body.store, resolver); if let Some(inspect) = inspect { ctx.table.infer_ctxt.attach_obligation_inspector(inspect); @@ -114,8 +115,8 @@ pub fn infer_query_with_inspect<'db>( match def { DefWithBodyId::FunctionId(f) => ctx.collect_fn(f, body.self_param, &body.params), - DefWithBodyId::ConstId(c) => ctx.collect_const(c, &db.const_signature(c)), - DefWithBodyId::StaticId(s) => ctx.collect_static(&db.static_signature(s)), + DefWithBodyId::ConstId(c) => ctx.collect_const(c, ConstSignature::of(db, c)), + DefWithBodyId::StaticId(s) => ctx.collect_static(StaticSignature::of(db, s)), DefWithBodyId::VariantId(v) => { ctx.return_ty = match EnumSignature::variant_body_type(db, v.lookup(db).parent) { hir_def::layout::IntegerType::Pointer(signed) => match signed { @@ -146,33 +147,7 @@ pub fn infer_query_with_inspect<'db>( ctx.infer_mut_body(body.body_expr); - finalize_infer(ctx) -} - -fn finalize_infer(mut ctx: InferenceContext<'_, '_>) -> InferenceResult { - ctx.handle_opaque_type_uses(); - - ctx.type_inference_fallback(); - - // Comment from rustc: - // Even though coercion casts provide type hints, we check casts after fallback for - // backwards compatibility. This makes fallback a stronger type hint than a cast coercion. - let cast_checks = std::mem::take(&mut ctx.deferred_cast_checks); - for mut cast in cast_checks.into_iter() { - if let Err(diag) = cast.check(&mut ctx) { - ctx.diagnostics.push(diag); - } - } - - ctx.table.select_obligations_where_possible(); - - ctx.infer_closures(); - - ctx.table.select_obligations_where_possible(); - - ctx.handle_opaque_type_uses(); - - ctx.resolve_all() + infer_finalize(ctx) } fn infer_cycle_result(db: &dyn HirDatabase, _: salsa::Id, _: DefWithBodyId) -> InferenceResult { @@ -190,16 +165,16 @@ fn infer_cycle_result(db: &dyn HirDatabase, _: salsa::Id, _: DefWithBodyId) -> I /// a shared `InferenceContext`, accumulating results into a single `InferenceResult`. fn infer_signature_query(db: &dyn HirDatabase, def: GenericDefId) -> InferenceResult { let _p = tracing::info_span!("infer_signature_query").entered(); - let (_, store) = db.generic_params_and_store(def); + let store = ExpressionStore::of(db, def.into()); let mut roots = store.signature_const_expr_roots_with_origins().peekable(); let Some(_) = roots.peek() else { return InferenceResult::new(crate::next_solver::default_types(db).types.error); }; let resolver = def.resolver(db); - let owner = ExpressionStoreOwner::Signature(def); + let owner = ExpressionStoreOwnerId::Signature(def); - let mut ctx = InferenceContext::new(db, owner, &store, resolver); + let mut ctx = InferenceContext::new(db, owner, store, resolver); for (root_expr, origin) in roots { let expected = match origin { @@ -217,7 +192,7 @@ fn infer_signature_query(db: &dyn HirDatabase, def: GenericDefId) -> InferenceRe ctx.infer_expr(root_expr, &expected, ExprIsRead::Yes); } - finalize_infer(ctx) + infer_finalize(ctx) } fn infer_signature_cycle_result( @@ -231,6 +206,31 @@ fn infer_signature_cycle_result( } } +fn infer_finalize(mut ctx: InferenceContext<'_, '_>) -> InferenceResult { + ctx.handle_opaque_type_uses(); + + ctx.type_inference_fallback(); + + // Comment from rustc: + // Even though coercion casts provide type hints, we check casts after fallback for + // backwards compatibility. This makes fallback a stronger type hint than a cast coercion. + let cast_checks = std::mem::take(&mut ctx.deferred_cast_checks); + for mut cast in cast_checks.into_iter() { + if let Err(diag) = cast.check(&mut ctx) { + ctx.diagnostics.push(diag); + } + } + + ctx.table.select_obligations_where_possible(); + + ctx.infer_closures(); + + ctx.table.select_obligations_where_possible(); + + ctx.handle_opaque_type_uses(); + + ctx.resolve_all() +} /// Binding modes inferred for patterns. /// <https://doc.rust-lang.org/reference/patterns.html#binding-modes> #[derive(Copy, Clone, Debug, Eq, PartialEq, Default)] @@ -604,7 +604,7 @@ pub struct InferenceResult { #[salsa::tracked] impl InferenceResult { #[salsa::tracked(returns(ref), cycle_result = infer_cycle_result)] - pub fn for_body(db: &dyn HirDatabase, def: DefWithBodyId) -> InferenceResult { + fn for_body(db: &dyn HirDatabase, def: DefWithBodyId) -> InferenceResult { infer_query(db, def) } @@ -614,12 +614,21 @@ impl InferenceResult { /// const generic arguments, and other const expressions appearing in type /// positions within the item's signature. #[salsa::tracked(returns(ref), cycle_result = infer_signature_cycle_result)] - pub fn for_signature(db: &dyn HirDatabase, def: GenericDefId) -> InferenceResult { + fn for_signature(db: &dyn HirDatabase, def: GenericDefId) -> InferenceResult { infer_signature_query(db, def) } } impl InferenceResult { + pub fn of(db: &dyn HirDatabase, def: impl Into<ExpressionStoreOwnerId>) -> &InferenceResult { + match def.into() { + ExpressionStoreOwnerId::Signature(generic_def_id) => { + Self::for_signature(db, generic_def_id) + } + ExpressionStoreOwnerId::Body(def_with_body_id) => Self::for_body(db, def_with_body_id), + } + } + fn new(error_ty: Ty<'_>) -> Self { Self { method_resolutions: Default::default(), @@ -816,7 +825,7 @@ impl InferenceResult { #[derive(Clone, Debug)] pub(crate) struct InferenceContext<'body, 'db> { pub(crate) db: &'db dyn HirDatabase, - pub(crate) owner: ExpressionStoreOwner, + pub(crate) owner: ExpressionStoreOwnerId, pub(crate) store: &'body ExpressionStore, /// Generally you should not resolve things via this resolver. Instead create a TyLoweringContext /// and resolve the path via its methods. This will ensure proper error reporting. @@ -917,14 +926,16 @@ fn find_continuable<'a, 'db>( impl<'body, 'db> InferenceContext<'body, 'db> { fn new( db: &'db dyn HirDatabase, - owner: ExpressionStoreOwner, + owner: ExpressionStoreOwnerId, store: &'body ExpressionStore, resolver: Resolver<'db>, ) -> Self { let trait_env = match owner { - ExpressionStoreOwner::Signature(generic_def_id) => db.trait_environment(generic_def_id), - ExpressionStoreOwner::Body(def_with_body_id) => { - db.trait_environment_for_body(def_with_body_id) + ExpressionStoreOwnerId::Signature(generic_def_id) => { + db.trait_environment(ExpressionStoreOwnerId::from(generic_def_id)) + } + ExpressionStoreOwnerId::Body(def_with_body_id) => { + db.trait_environment(ExpressionStoreOwnerId::Body(def_with_body_id)) } }; let table = unify::InferenceTable::new(db, trait_env, resolver.krate(), Some(owner)); @@ -970,7 +981,7 @@ impl<'body, 'db> InferenceContext<'body, 'db> { fn target_features(&self) -> (&TargetFeatures<'db>, TargetFeatureIsSafeInTarget) { let (target_features, target_feature_is_safe) = self.target_features.get_or_init(|| { let target_features = match self.owner { - ExpressionStoreOwner::Body(DefWithBodyId::FunctionId(id)) => { + ExpressionStoreOwnerId::Body(DefWithBodyId::FunctionId(id)) => { TargetFeatures::from_fn(self.db, id) } _ => TargetFeatures::default(), @@ -1167,11 +1178,11 @@ impl<'body, 'db> InferenceContext<'body, 'db> { } fn collect_fn(&mut self, func: FunctionId, self_param: Option<BindingId>, params: &[PatId]) { - let data = self.db.function_signature(func); + let data = FunctionSignature::of(self.db, func); let mut param_tys = self.with_ty_lowering( &data.store, InferenceTyDiagnosticSource::Signature, - LifetimeElisionKind::for_fn_params(&data), + LifetimeElisionKind::for_fn_params(data), |ctx| data.params.iter().map(|&type_ref| ctx.lower_ty(type_ref)).collect::<Vec<_>>(), ); |