Unnamed repository; edit this file 'description' to name the repository.
Diffstat (limited to 'crates/hir-ty/src/infer.rs')
| -rw-r--r-- | crates/hir-ty/src/infer.rs | 231 |
1 files changed, 180 insertions, 51 deletions
diff --git a/crates/hir-ty/src/infer.rs b/crates/hir-ty/src/infer.rs index 35d744e7d1..d14e9d6526 100644 --- a/crates/hir-ty/src/infer.rs +++ b/crates/hir-ty/src/infer.rs @@ -33,14 +33,15 @@ use std::{cell::OnceCell, convert::identity, iter}; use base_db::Crate; use either::Either; use hir_def::{ - AdtId, AssocItemId, ConstId, DefWithBodyId, FieldId, FunctionId, GenericDefId, GenericParamId, - ItemContainerId, LocalFieldId, Lookup, TraitId, TupleFieldId, TupleId, TypeAliasId, VariantId, - expr_store::{Body, ExpressionStore, HygieneId, path::Path}, + AdtId, AssocItemId, ConstId, ConstParamId, DefWithBodyId, ExpressionStoreOwnerId, FieldId, + FunctionId, GenericDefId, GenericParamId, ItemContainerId, LocalFieldId, Lookup, TraitId, + TupleFieldId, TupleId, TypeAliasId, TypeOrConstParamId, VariantId, + expr_store::{Body, ExpressionStore, HygieneId, RootExprOrigin, path::Path}, hir::{BindingAnnotation, BindingId, ExprId, ExprOrPatId, LabelId, PatId}, lang_item::LangItems, layout::Integer, resolver::{HasResolver, ResolveValueResult, Resolver, TypeNs, ValueNs}, - signatures::{ConstSignature, EnumSignature, StaticSignature}, + signatures::{ConstSignature, EnumSignature, FunctionSignature, StaticSignature}, type_ref::{ConstRef, LifetimeRefId, TypeRef, TypeRefId}, }; use hir_expand::{mod_path::ModPath, name::Name}; @@ -104,19 +105,18 @@ pub fn infer_query_with_inspect<'db>( ) -> InferenceResult { let _p = tracing::info_span!("infer_query").entered(); let resolver = def.resolver(db); - let body = db.body(def); - let mut ctx = InferenceContext::new(db, def, &body, resolver); + let body = Body::of(db, def); + let mut ctx = + InferenceContext::new(db, ExpressionStoreOwnerId::Body(def), &body.store, resolver); if let Some(inspect) = inspect { ctx.table.infer_ctxt.attach_obligation_inspector(inspect); } match def { - DefWithBodyId::FunctionId(f) => { - ctx.collect_fn(f); - } - DefWithBodyId::ConstId(c) => ctx.collect_const(c, &db.const_signature(c)), - DefWithBodyId::StaticId(s) => ctx.collect_static(&db.static_signature(s)), + DefWithBodyId::FunctionId(f) => ctx.collect_fn(f, body.self_param, &body.params), + DefWithBodyId::ConstId(c) => ctx.collect_const(c, ConstSignature::of(db, c)), + DefWithBodyId::StaticId(s) => ctx.collect_static(StaticSignature::of(db, s)), DefWithBodyId::VariantId(v) => { ctx.return_ty = match EnumSignature::variant_body_type(db, v.lookup(db).parent) { hir_def::layout::IntegerType::Pointer(signed) => match signed { @@ -143,10 +143,113 @@ pub fn infer_query_with_inspect<'db>( } } - ctx.infer_body(); + ctx.infer_body(body.root_expr()); + + ctx.infer_mut_body(body.root_expr()); + + infer_finalize(ctx) +} + +fn infer_cycle_result(db: &dyn HirDatabase, _: salsa::Id, _: DefWithBodyId) -> InferenceResult { + InferenceResult { + has_errors: true, + ..InferenceResult::new(Ty::new_error(DbInterner::new_no_crate(db), ErrorGuaranteed)) + } +} + +/// Infer types for all const expressions in an item's signature. +/// +/// This handles const expressions that appear in type positions within a generic +/// item's signature, such as array lengths (`[T; N]`) and const generic arguments +/// (`Foo<{ expr }>`). Each root expression is inferred independently within +/// a shared `InferenceContext`, accumulating results into a single `InferenceResult`. +fn infer_signature_query(db: &dyn HirDatabase, def: GenericDefId) -> InferenceResult { + let _p = tracing::info_span!("infer_signature_query").entered(); + let store = ExpressionStore::of(db, def.into()); + let mut roots = store.expr_roots_with_origins().peekable(); + let Some(_) = roots.peek() else { + return InferenceResult::new(crate::next_solver::default_types(db).types.error); + }; + + let resolver = def.resolver(db); + let owner = ExpressionStoreOwnerId::Signature(def); + + let mut ctx = InferenceContext::new(db, owner, store, resolver); + + for (root_expr, origin) in roots { + let expected = match origin { + // Array lengths are always `usize`. + RootExprOrigin::ArrayLength => Expectation::has_type(ctx.types.types.usize), + // Const parameter default: look up the param's declared type. + RootExprOrigin::ConstParam(local_id) => Expectation::has_type(db.const_param_ty_ns( + ConstParamId::from_unchecked(TypeOrConstParamId { parent: def, local_id }), + )), + // Path const generic args: determining the expected type requires + // path resolution. + // FIXME + RootExprOrigin::GenericArgsPath => Expectation::None, + RootExprOrigin::BodyRoot => Expectation::None, + }; + ctx.infer_expr(root_expr, &expected, ExprIsRead::Yes); + } + + infer_finalize(ctx) +} + +fn infer_variant_fields_query(db: &dyn HirDatabase, def: VariantId) -> InferenceResult { + let _p = tracing::info_span!("infer_variant_fields_query").entered(); + let store = ExpressionStore::of(db, def.into()); + let mut roots = store.expr_roots_with_origins().peekable(); + let Some(_) = roots.peek() else { + return InferenceResult::new(crate::next_solver::default_types(db).types.error); + }; + + let resolver = def.resolver(db); + let owner = ExpressionStoreOwnerId::VariantFields(def); + + let mut ctx = InferenceContext::new(db, owner, store, resolver); + + for (root_expr, origin) in roots { + let expected = match origin { + // Array lengths are always `usize`. + RootExprOrigin::ArrayLength => Expectation::has_type(ctx.types.types.usize), + // unreachable + RootExprOrigin::ConstParam(_) => Expectation::None, + // Path const generic args: determining the expected type requires + // path resolution. + // FIXME + RootExprOrigin::GenericArgsPath => Expectation::None, + RootExprOrigin::BodyRoot => Expectation::None, + }; + ctx.infer_expr(root_expr, &expected, ExprIsRead::Yes); + } + + infer_finalize(ctx) +} + +fn infer_signature_cycle_result( + db: &dyn HirDatabase, + _: salsa::Id, + _: GenericDefId, +) -> InferenceResult { + InferenceResult { + has_errors: true, + ..InferenceResult::new(Ty::new_error(DbInterner::new_no_crate(db), ErrorGuaranteed)) + } +} - ctx.infer_mut_body(); +fn infer_variant_fields_cycle_result( + db: &dyn HirDatabase, + _: salsa::Id, + _: VariantId, +) -> InferenceResult { + InferenceResult { + has_errors: true, + ..InferenceResult::new(Ty::new_error(DbInterner::new_no_crate(db), ErrorGuaranteed)) + } +} +fn infer_finalize(mut ctx: InferenceContext<'_, '_>) -> InferenceResult { ctx.handle_opaque_type_uses(); ctx.type_inference_fallback(); @@ -171,14 +274,6 @@ pub fn infer_query_with_inspect<'db>( ctx.resolve_all() } - -fn infer_cycle_result(db: &dyn HirDatabase, _: salsa::Id, _: DefWithBodyId) -> InferenceResult { - InferenceResult { - has_errors: true, - ..InferenceResult::new(Ty::new_error(DbInterner::new_no_crate(db), ErrorGuaranteed)) - } -} - /// Binding modes inferred for patterns. /// <https://doc.rust-lang.org/reference/patterns.html#binding-modes> #[derive(Copy, Clone, Debug, Eq, PartialEq, Default)] @@ -552,12 +647,39 @@ pub struct InferenceResult { #[salsa::tracked] impl InferenceResult { #[salsa::tracked(returns(ref), cycle_result = infer_cycle_result)] - pub fn for_body(db: &dyn HirDatabase, def: DefWithBodyId) -> InferenceResult { + fn for_body(db: &dyn HirDatabase, def: DefWithBodyId) -> InferenceResult { infer_query(db, def) } + + /// Infer types for all const expressions in an item's signature. + /// + /// Returns an `InferenceResult` containing type information for array lengths, + /// const generic arguments, and other const expressions appearing in type + /// positions within the item's signature. + #[salsa::tracked(returns(ref), cycle_result = infer_signature_cycle_result)] + fn for_signature(db: &dyn HirDatabase, def: GenericDefId) -> InferenceResult { + infer_signature_query(db, def) + } + + #[salsa::tracked(returns(ref), cycle_result = infer_variant_fields_cycle_result)] + fn for_variant_fields(db: &dyn HirDatabase, def: VariantId) -> InferenceResult { + infer_variant_fields_query(db, def) + } } impl InferenceResult { + pub fn of(db: &dyn HirDatabase, def: impl Into<ExpressionStoreOwnerId>) -> &InferenceResult { + match def.into() { + ExpressionStoreOwnerId::Signature(generic_def_id) => { + Self::for_signature(db, generic_def_id) + } + ExpressionStoreOwnerId::Body(def_with_body_id) => Self::for_body(db, def_with_body_id), + ExpressionStoreOwnerId::VariantFields(variant_id) => { + Self::for_variant_fields(db, variant_id) + } + } + } + fn new(error_ty: Ty<'_>) -> Self { Self { method_resolutions: Default::default(), @@ -754,8 +876,8 @@ impl InferenceResult { #[derive(Clone, Debug)] pub(crate) struct InferenceContext<'body, 'db> { pub(crate) db: &'db dyn HirDatabase, - pub(crate) owner: DefWithBodyId, - pub(crate) body: &'body Body, + pub(crate) owner: ExpressionStoreOwnerId, + pub(crate) store: &'body ExpressionStore, /// Generally you should not resolve things via this resolver. Instead create a TyLoweringContext /// and resolve the path via its methods. This will ensure proper error reporting. pub(crate) resolver: Resolver<'db>, @@ -855,11 +977,21 @@ fn find_continuable<'a, 'db>( impl<'body, 'db> InferenceContext<'body, 'db> { fn new( db: &'db dyn HirDatabase, - owner: DefWithBodyId, - body: &'body Body, + owner: ExpressionStoreOwnerId, + store: &'body ExpressionStore, resolver: Resolver<'db>, ) -> Self { - let trait_env = db.trait_environment_for_body(owner); + let trait_env = match owner { + ExpressionStoreOwnerId::Signature(generic_def_id) => { + db.trait_environment(ExpressionStoreOwnerId::from(generic_def_id)) + } + ExpressionStoreOwnerId::Body(def_with_body_id) => { + db.trait_environment(ExpressionStoreOwnerId::Body(def_with_body_id)) + } + ExpressionStoreOwnerId::VariantFields(variant_id) => { + db.trait_environment(ExpressionStoreOwnerId::VariantFields(variant_id)) + } + }; let table = unify::InferenceTable::new(db, trait_env, resolver.krate(), Some(owner)); let types = crate::next_solver::default_types(db); InferenceContext { @@ -878,13 +1010,8 @@ impl<'body, 'db> InferenceContext<'body, 'db> { return_coercion: None, db, owner, - generic_def: match owner { - DefWithBodyId::FunctionId(it) => it.into(), - DefWithBodyId::StaticId(it) => it.into(), - DefWithBodyId::ConstId(it) => it.into(), - DefWithBodyId::VariantId(it) => it.lookup(db).parent.into(), - }, - body, + generic_def: owner.generic_def(db), + store, traits_in_scope: resolver.traits_in_scope(db), resolver, diverges: Diverges::Maybe, @@ -908,7 +1035,9 @@ impl<'body, 'db> InferenceContext<'body, 'db> { fn target_features(&self) -> (&TargetFeatures<'db>, TargetFeatureIsSafeInTarget) { let (target_features, target_feature_is_safe) = self.target_features.get_or_init(|| { let target_features = match self.owner { - DefWithBodyId::FunctionId(id) => TargetFeatures::from_fn(self.db, id), + ExpressionStoreOwnerId::Body(DefWithBodyId::FunctionId(id)) => { + TargetFeatures::from_fn(self.db, id) + } _ => TargetFeatures::default(), }; let target_feature_is_safe = match &self.krate().workspace_data(self.db).target { @@ -1102,12 +1231,12 @@ impl<'body, 'db> InferenceContext<'body, 'db> { self.return_ty = return_ty; } - fn collect_fn(&mut self, func: FunctionId) { - let data = self.db.function_signature(func); + fn collect_fn(&mut self, func: FunctionId, self_param: Option<BindingId>, params: &[PatId]) { + let data = FunctionSignature::of(self.db, func); let mut param_tys = self.with_ty_lowering( &data.store, InferenceTyDiagnosticSource::Signature, - LifetimeElisionKind::for_fn_params(&data), + LifetimeElisionKind::for_fn_params(data), |ctx| data.params.iter().map(|&type_ref| ctx.lower_ty(type_ref)).collect::<Vec<_>>(), ); @@ -1130,13 +1259,13 @@ impl<'body, 'db> InferenceContext<'body, 'db> { param_tys.push(va_list_ty); } let mut param_tys = param_tys.into_iter().chain(iter::repeat(self.table.next_ty_var())); - if let Some(self_param) = self.body.self_param + if let Some(self_param) = self_param && let Some(ty) = param_tys.next() { let ty = self.process_user_written_ty(ty); self.write_binding_ty(self_param, ty); } - for (ty, pat) in param_tys.zip(&*self.body.params) { + for (ty, pat) in param_tys.zip(params) { let ty = self.process_user_written_ty(ty); self.infer_top_pat(*pat, ty, None); @@ -1170,12 +1299,12 @@ impl<'body, 'db> InferenceContext<'body, 'db> { &self.table.infer_ctxt } - fn infer_body(&mut self) { + fn infer_body(&mut self, body_expr: ExprId) { match self.return_coercion { - Some(_) => self.infer_return(self.body.body_expr), + Some(_) => self.infer_return(body_expr), None => { _ = self.infer_expr_coerce( - self.body.body_expr, + body_expr, &Expectation::has_type(self.return_ty), ExprIsRead::Yes, ) @@ -1282,7 +1411,7 @@ impl<'body, 'db> InferenceContext<'body, 'db> { f: impl FnOnce(&mut TyLoweringContext<'db, '_>) -> R, ) -> R { self.with_ty_lowering( - self.body, + self.store, InferenceTyDiagnosticSource::Body, LifetimeElisionKind::Infer, f, @@ -1324,7 +1453,7 @@ impl<'body, 'db> InferenceContext<'body, 'db> { pub(crate) fn make_body_ty(&mut self, type_ref: TypeRefId) -> Ty<'db> { self.make_ty( type_ref, - self.body, + self.store, InferenceTyDiagnosticSource::Body, LifetimeElisionKind::Infer, ) @@ -1332,7 +1461,7 @@ impl<'body, 'db> InferenceContext<'body, 'db> { pub(crate) fn make_body_const(&mut self, const_ref: ConstRef, ty: Ty<'db>) -> Const<'db> { let const_ = self.with_ty_lowering( - self.body, + self.store, InferenceTyDiagnosticSource::Body, LifetimeElisionKind::Infer, |ctx| ctx.lower_const(const_ref, ty), @@ -1342,7 +1471,7 @@ impl<'body, 'db> InferenceContext<'body, 'db> { pub(crate) fn make_path_as_body_const(&mut self, path: &Path, ty: Ty<'db>) -> Const<'db> { let const_ = self.with_ty_lowering( - self.body, + self.store, InferenceTyDiagnosticSource::Body, LifetimeElisionKind::Infer, |ctx| ctx.lower_path_as_const(path, ty), @@ -1356,7 +1485,7 @@ impl<'body, 'db> InferenceContext<'body, 'db> { pub(crate) fn make_body_lifetime(&mut self, lifetime_ref: LifetimeRefId) -> Region<'db> { let lt = self.with_ty_lowering( - self.body, + self.store, InferenceTyDiagnosticSource::Body, LifetimeElisionKind::Infer, |ctx| ctx.lower_lifetime(lifetime_ref), @@ -1571,7 +1700,7 @@ impl<'body, 'db> InferenceContext<'body, 'db> { let mut ctx = TyLoweringContext::new( self.db, &self.resolver, - &self.body.store, + self.store, &self.diagnostics, InferenceTyDiagnosticSource::Body, self.generic_def, @@ -1584,7 +1713,7 @@ impl<'body, 'db> InferenceContext<'body, 'db> { return (self.err_ty(), None); }; match res { - ResolveValueResult::ValueNs(value, _) => match value { + ResolveValueResult::ValueNs(value) => match value { ValueNs::EnumVariantId(var) => { let args = path_ctx.substs_from_path(var.into(), true, false); drop(ctx); @@ -1608,7 +1737,7 @@ impl<'body, 'db> InferenceContext<'body, 'db> { return (self.err_ty(), None); } }, - ResolveValueResult::Partial(typens, unresolved, _) => (typens, Some(unresolved)), + ResolveValueResult::Partial(typens, unresolved) => (typens, Some(unresolved)), } } else { match path_ctx.resolve_path_in_type_ns() { |