Unnamed repository; edit this file 'description' to name the repository.
Diffstat (limited to 'crates/hir-ty/src/infer.rs')
-rw-r--r--crates/hir-ty/src/infer.rs231
1 files changed, 180 insertions, 51 deletions
diff --git a/crates/hir-ty/src/infer.rs b/crates/hir-ty/src/infer.rs
index 35d744e7d1..d14e9d6526 100644
--- a/crates/hir-ty/src/infer.rs
+++ b/crates/hir-ty/src/infer.rs
@@ -33,14 +33,15 @@ use std::{cell::OnceCell, convert::identity, iter};
use base_db::Crate;
use either::Either;
use hir_def::{
- AdtId, AssocItemId, ConstId, DefWithBodyId, FieldId, FunctionId, GenericDefId, GenericParamId,
- ItemContainerId, LocalFieldId, Lookup, TraitId, TupleFieldId, TupleId, TypeAliasId, VariantId,
- expr_store::{Body, ExpressionStore, HygieneId, path::Path},
+ AdtId, AssocItemId, ConstId, ConstParamId, DefWithBodyId, ExpressionStoreOwnerId, FieldId,
+ FunctionId, GenericDefId, GenericParamId, ItemContainerId, LocalFieldId, Lookup, TraitId,
+ TupleFieldId, TupleId, TypeAliasId, TypeOrConstParamId, VariantId,
+ expr_store::{Body, ExpressionStore, HygieneId, RootExprOrigin, path::Path},
hir::{BindingAnnotation, BindingId, ExprId, ExprOrPatId, LabelId, PatId},
lang_item::LangItems,
layout::Integer,
resolver::{HasResolver, ResolveValueResult, Resolver, TypeNs, ValueNs},
- signatures::{ConstSignature, EnumSignature, StaticSignature},
+ signatures::{ConstSignature, EnumSignature, FunctionSignature, StaticSignature},
type_ref::{ConstRef, LifetimeRefId, TypeRef, TypeRefId},
};
use hir_expand::{mod_path::ModPath, name::Name};
@@ -104,19 +105,18 @@ pub fn infer_query_with_inspect<'db>(
) -> InferenceResult {
let _p = tracing::info_span!("infer_query").entered();
let resolver = def.resolver(db);
- let body = db.body(def);
- let mut ctx = InferenceContext::new(db, def, &body, resolver);
+ let body = Body::of(db, def);
+ let mut ctx =
+ InferenceContext::new(db, ExpressionStoreOwnerId::Body(def), &body.store, resolver);
if let Some(inspect) = inspect {
ctx.table.infer_ctxt.attach_obligation_inspector(inspect);
}
match def {
- DefWithBodyId::FunctionId(f) => {
- ctx.collect_fn(f);
- }
- DefWithBodyId::ConstId(c) => ctx.collect_const(c, &db.const_signature(c)),
- DefWithBodyId::StaticId(s) => ctx.collect_static(&db.static_signature(s)),
+ DefWithBodyId::FunctionId(f) => ctx.collect_fn(f, body.self_param, &body.params),
+ DefWithBodyId::ConstId(c) => ctx.collect_const(c, ConstSignature::of(db, c)),
+ DefWithBodyId::StaticId(s) => ctx.collect_static(StaticSignature::of(db, s)),
DefWithBodyId::VariantId(v) => {
ctx.return_ty = match EnumSignature::variant_body_type(db, v.lookup(db).parent) {
hir_def::layout::IntegerType::Pointer(signed) => match signed {
@@ -143,10 +143,113 @@ pub fn infer_query_with_inspect<'db>(
}
}
- ctx.infer_body();
+ ctx.infer_body(body.root_expr());
+
+ ctx.infer_mut_body(body.root_expr());
+
+ infer_finalize(ctx)
+}
+
+fn infer_cycle_result(db: &dyn HirDatabase, _: salsa::Id, _: DefWithBodyId) -> InferenceResult {
+ InferenceResult {
+ has_errors: true,
+ ..InferenceResult::new(Ty::new_error(DbInterner::new_no_crate(db), ErrorGuaranteed))
+ }
+}
+
+/// Infer types for all const expressions in an item's signature.
+///
+/// This handles const expressions that appear in type positions within a generic
+/// item's signature, such as array lengths (`[T; N]`) and const generic arguments
+/// (`Foo<{ expr }>`). Each root expression is inferred independently within
+/// a shared `InferenceContext`, accumulating results into a single `InferenceResult`.
+fn infer_signature_query(db: &dyn HirDatabase, def: GenericDefId) -> InferenceResult {
+ let _p = tracing::info_span!("infer_signature_query").entered();
+ let store = ExpressionStore::of(db, def.into());
+ let mut roots = store.expr_roots_with_origins().peekable();
+ let Some(_) = roots.peek() else {
+ return InferenceResult::new(crate::next_solver::default_types(db).types.error);
+ };
+
+ let resolver = def.resolver(db);
+ let owner = ExpressionStoreOwnerId::Signature(def);
+
+ let mut ctx = InferenceContext::new(db, owner, store, resolver);
+
+ for (root_expr, origin) in roots {
+ let expected = match origin {
+ // Array lengths are always `usize`.
+ RootExprOrigin::ArrayLength => Expectation::has_type(ctx.types.types.usize),
+ // Const parameter default: look up the param's declared type.
+ RootExprOrigin::ConstParam(local_id) => Expectation::has_type(db.const_param_ty_ns(
+ ConstParamId::from_unchecked(TypeOrConstParamId { parent: def, local_id }),
+ )),
+ // Path const generic args: determining the expected type requires
+ // path resolution.
+ // FIXME
+ RootExprOrigin::GenericArgsPath => Expectation::None,
+ RootExprOrigin::BodyRoot => Expectation::None,
+ };
+ ctx.infer_expr(root_expr, &expected, ExprIsRead::Yes);
+ }
+
+ infer_finalize(ctx)
+}
+
+fn infer_variant_fields_query(db: &dyn HirDatabase, def: VariantId) -> InferenceResult {
+ let _p = tracing::info_span!("infer_variant_fields_query").entered();
+ let store = ExpressionStore::of(db, def.into());
+ let mut roots = store.expr_roots_with_origins().peekable();
+ let Some(_) = roots.peek() else {
+ return InferenceResult::new(crate::next_solver::default_types(db).types.error);
+ };
+
+ let resolver = def.resolver(db);
+ let owner = ExpressionStoreOwnerId::VariantFields(def);
+
+ let mut ctx = InferenceContext::new(db, owner, store, resolver);
+
+ for (root_expr, origin) in roots {
+ let expected = match origin {
+ // Array lengths are always `usize`.
+ RootExprOrigin::ArrayLength => Expectation::has_type(ctx.types.types.usize),
+ // unreachable
+ RootExprOrigin::ConstParam(_) => Expectation::None,
+ // Path const generic args: determining the expected type requires
+ // path resolution.
+ // FIXME
+ RootExprOrigin::GenericArgsPath => Expectation::None,
+ RootExprOrigin::BodyRoot => Expectation::None,
+ };
+ ctx.infer_expr(root_expr, &expected, ExprIsRead::Yes);
+ }
+
+ infer_finalize(ctx)
+}
+
+fn infer_signature_cycle_result(
+ db: &dyn HirDatabase,
+ _: salsa::Id,
+ _: GenericDefId,
+) -> InferenceResult {
+ InferenceResult {
+ has_errors: true,
+ ..InferenceResult::new(Ty::new_error(DbInterner::new_no_crate(db), ErrorGuaranteed))
+ }
+}
- ctx.infer_mut_body();
+fn infer_variant_fields_cycle_result(
+ db: &dyn HirDatabase,
+ _: salsa::Id,
+ _: VariantId,
+) -> InferenceResult {
+ InferenceResult {
+ has_errors: true,
+ ..InferenceResult::new(Ty::new_error(DbInterner::new_no_crate(db), ErrorGuaranteed))
+ }
+}
+fn infer_finalize(mut ctx: InferenceContext<'_, '_>) -> InferenceResult {
ctx.handle_opaque_type_uses();
ctx.type_inference_fallback();
@@ -171,14 +274,6 @@ pub fn infer_query_with_inspect<'db>(
ctx.resolve_all()
}
-
-fn infer_cycle_result(db: &dyn HirDatabase, _: salsa::Id, _: DefWithBodyId) -> InferenceResult {
- InferenceResult {
- has_errors: true,
- ..InferenceResult::new(Ty::new_error(DbInterner::new_no_crate(db), ErrorGuaranteed))
- }
-}
-
/// Binding modes inferred for patterns.
/// <https://doc.rust-lang.org/reference/patterns.html#binding-modes>
#[derive(Copy, Clone, Debug, Eq, PartialEq, Default)]
@@ -552,12 +647,39 @@ pub struct InferenceResult {
#[salsa::tracked]
impl InferenceResult {
#[salsa::tracked(returns(ref), cycle_result = infer_cycle_result)]
- pub fn for_body(db: &dyn HirDatabase, def: DefWithBodyId) -> InferenceResult {
+ fn for_body(db: &dyn HirDatabase, def: DefWithBodyId) -> InferenceResult {
infer_query(db, def)
}
+
+ /// Infer types for all const expressions in an item's signature.
+ ///
+ /// Returns an `InferenceResult` containing type information for array lengths,
+ /// const generic arguments, and other const expressions appearing in type
+ /// positions within the item's signature.
+ #[salsa::tracked(returns(ref), cycle_result = infer_signature_cycle_result)]
+ fn for_signature(db: &dyn HirDatabase, def: GenericDefId) -> InferenceResult {
+ infer_signature_query(db, def)
+ }
+
+ #[salsa::tracked(returns(ref), cycle_result = infer_variant_fields_cycle_result)]
+ fn for_variant_fields(db: &dyn HirDatabase, def: VariantId) -> InferenceResult {
+ infer_variant_fields_query(db, def)
+ }
}
impl InferenceResult {
+ pub fn of(db: &dyn HirDatabase, def: impl Into<ExpressionStoreOwnerId>) -> &InferenceResult {
+ match def.into() {
+ ExpressionStoreOwnerId::Signature(generic_def_id) => {
+ Self::for_signature(db, generic_def_id)
+ }
+ ExpressionStoreOwnerId::Body(def_with_body_id) => Self::for_body(db, def_with_body_id),
+ ExpressionStoreOwnerId::VariantFields(variant_id) => {
+ Self::for_variant_fields(db, variant_id)
+ }
+ }
+ }
+
fn new(error_ty: Ty<'_>) -> Self {
Self {
method_resolutions: Default::default(),
@@ -754,8 +876,8 @@ impl InferenceResult {
#[derive(Clone, Debug)]
pub(crate) struct InferenceContext<'body, 'db> {
pub(crate) db: &'db dyn HirDatabase,
- pub(crate) owner: DefWithBodyId,
- pub(crate) body: &'body Body,
+ pub(crate) owner: ExpressionStoreOwnerId,
+ pub(crate) store: &'body ExpressionStore,
/// Generally you should not resolve things via this resolver. Instead create a TyLoweringContext
/// and resolve the path via its methods. This will ensure proper error reporting.
pub(crate) resolver: Resolver<'db>,
@@ -855,11 +977,21 @@ fn find_continuable<'a, 'db>(
impl<'body, 'db> InferenceContext<'body, 'db> {
fn new(
db: &'db dyn HirDatabase,
- owner: DefWithBodyId,
- body: &'body Body,
+ owner: ExpressionStoreOwnerId,
+ store: &'body ExpressionStore,
resolver: Resolver<'db>,
) -> Self {
- let trait_env = db.trait_environment_for_body(owner);
+ let trait_env = match owner {
+ ExpressionStoreOwnerId::Signature(generic_def_id) => {
+ db.trait_environment(ExpressionStoreOwnerId::from(generic_def_id))
+ }
+ ExpressionStoreOwnerId::Body(def_with_body_id) => {
+ db.trait_environment(ExpressionStoreOwnerId::Body(def_with_body_id))
+ }
+ ExpressionStoreOwnerId::VariantFields(variant_id) => {
+ db.trait_environment(ExpressionStoreOwnerId::VariantFields(variant_id))
+ }
+ };
let table = unify::InferenceTable::new(db, trait_env, resolver.krate(), Some(owner));
let types = crate::next_solver::default_types(db);
InferenceContext {
@@ -878,13 +1010,8 @@ impl<'body, 'db> InferenceContext<'body, 'db> {
return_coercion: None,
db,
owner,
- generic_def: match owner {
- DefWithBodyId::FunctionId(it) => it.into(),
- DefWithBodyId::StaticId(it) => it.into(),
- DefWithBodyId::ConstId(it) => it.into(),
- DefWithBodyId::VariantId(it) => it.lookup(db).parent.into(),
- },
- body,
+ generic_def: owner.generic_def(db),
+ store,
traits_in_scope: resolver.traits_in_scope(db),
resolver,
diverges: Diverges::Maybe,
@@ -908,7 +1035,9 @@ impl<'body, 'db> InferenceContext<'body, 'db> {
fn target_features(&self) -> (&TargetFeatures<'db>, TargetFeatureIsSafeInTarget) {
let (target_features, target_feature_is_safe) = self.target_features.get_or_init(|| {
let target_features = match self.owner {
- DefWithBodyId::FunctionId(id) => TargetFeatures::from_fn(self.db, id),
+ ExpressionStoreOwnerId::Body(DefWithBodyId::FunctionId(id)) => {
+ TargetFeatures::from_fn(self.db, id)
+ }
_ => TargetFeatures::default(),
};
let target_feature_is_safe = match &self.krate().workspace_data(self.db).target {
@@ -1102,12 +1231,12 @@ impl<'body, 'db> InferenceContext<'body, 'db> {
self.return_ty = return_ty;
}
- fn collect_fn(&mut self, func: FunctionId) {
- let data = self.db.function_signature(func);
+ fn collect_fn(&mut self, func: FunctionId, self_param: Option<BindingId>, params: &[PatId]) {
+ let data = FunctionSignature::of(self.db, func);
let mut param_tys = self.with_ty_lowering(
&data.store,
InferenceTyDiagnosticSource::Signature,
- LifetimeElisionKind::for_fn_params(&data),
+ LifetimeElisionKind::for_fn_params(data),
|ctx| data.params.iter().map(|&type_ref| ctx.lower_ty(type_ref)).collect::<Vec<_>>(),
);
@@ -1130,13 +1259,13 @@ impl<'body, 'db> InferenceContext<'body, 'db> {
param_tys.push(va_list_ty);
}
let mut param_tys = param_tys.into_iter().chain(iter::repeat(self.table.next_ty_var()));
- if let Some(self_param) = self.body.self_param
+ if let Some(self_param) = self_param
&& let Some(ty) = param_tys.next()
{
let ty = self.process_user_written_ty(ty);
self.write_binding_ty(self_param, ty);
}
- for (ty, pat) in param_tys.zip(&*self.body.params) {
+ for (ty, pat) in param_tys.zip(params) {
let ty = self.process_user_written_ty(ty);
self.infer_top_pat(*pat, ty, None);
@@ -1170,12 +1299,12 @@ impl<'body, 'db> InferenceContext<'body, 'db> {
&self.table.infer_ctxt
}
- fn infer_body(&mut self) {
+ fn infer_body(&mut self, body_expr: ExprId) {
match self.return_coercion {
- Some(_) => self.infer_return(self.body.body_expr),
+ Some(_) => self.infer_return(body_expr),
None => {
_ = self.infer_expr_coerce(
- self.body.body_expr,
+ body_expr,
&Expectation::has_type(self.return_ty),
ExprIsRead::Yes,
)
@@ -1282,7 +1411,7 @@ impl<'body, 'db> InferenceContext<'body, 'db> {
f: impl FnOnce(&mut TyLoweringContext<'db, '_>) -> R,
) -> R {
self.with_ty_lowering(
- self.body,
+ self.store,
InferenceTyDiagnosticSource::Body,
LifetimeElisionKind::Infer,
f,
@@ -1324,7 +1453,7 @@ impl<'body, 'db> InferenceContext<'body, 'db> {
pub(crate) fn make_body_ty(&mut self, type_ref: TypeRefId) -> Ty<'db> {
self.make_ty(
type_ref,
- self.body,
+ self.store,
InferenceTyDiagnosticSource::Body,
LifetimeElisionKind::Infer,
)
@@ -1332,7 +1461,7 @@ impl<'body, 'db> InferenceContext<'body, 'db> {
pub(crate) fn make_body_const(&mut self, const_ref: ConstRef, ty: Ty<'db>) -> Const<'db> {
let const_ = self.with_ty_lowering(
- self.body,
+ self.store,
InferenceTyDiagnosticSource::Body,
LifetimeElisionKind::Infer,
|ctx| ctx.lower_const(const_ref, ty),
@@ -1342,7 +1471,7 @@ impl<'body, 'db> InferenceContext<'body, 'db> {
pub(crate) fn make_path_as_body_const(&mut self, path: &Path, ty: Ty<'db>) -> Const<'db> {
let const_ = self.with_ty_lowering(
- self.body,
+ self.store,
InferenceTyDiagnosticSource::Body,
LifetimeElisionKind::Infer,
|ctx| ctx.lower_path_as_const(path, ty),
@@ -1356,7 +1485,7 @@ impl<'body, 'db> InferenceContext<'body, 'db> {
pub(crate) fn make_body_lifetime(&mut self, lifetime_ref: LifetimeRefId) -> Region<'db> {
let lt = self.with_ty_lowering(
- self.body,
+ self.store,
InferenceTyDiagnosticSource::Body,
LifetimeElisionKind::Infer,
|ctx| ctx.lower_lifetime(lifetime_ref),
@@ -1571,7 +1700,7 @@ impl<'body, 'db> InferenceContext<'body, 'db> {
let mut ctx = TyLoweringContext::new(
self.db,
&self.resolver,
- &self.body.store,
+ self.store,
&self.diagnostics,
InferenceTyDiagnosticSource::Body,
self.generic_def,
@@ -1584,7 +1713,7 @@ impl<'body, 'db> InferenceContext<'body, 'db> {
return (self.err_ty(), None);
};
match res {
- ResolveValueResult::ValueNs(value, _) => match value {
+ ResolveValueResult::ValueNs(value) => match value {
ValueNs::EnumVariantId(var) => {
let args = path_ctx.substs_from_path(var.into(), true, false);
drop(ctx);
@@ -1608,7 +1737,7 @@ impl<'body, 'db> InferenceContext<'body, 'db> {
return (self.err_ty(), None);
}
},
- ResolveValueResult::Partial(typens, unresolved, _) => (typens, Some(unresolved)),
+ ResolveValueResult::Partial(typens, unresolved) => (typens, Some(unresolved)),
}
} else {
match path_ctx.resolve_path_in_type_ns() {