Unnamed repository; edit this file 'description' to name the repository.
Diffstat (limited to 'crates/hir-ty/src/infer.rs')
-rw-r--r--crates/hir-ty/src/infer.rs132
1 files changed, 98 insertions, 34 deletions
diff --git a/crates/hir-ty/src/infer.rs b/crates/hir-ty/src/infer.rs
index 991acda14b..9d78f5de9e 100644
--- a/crates/hir-ty/src/infer.rs
+++ b/crates/hir-ty/src/infer.rs
@@ -33,9 +33,10 @@ use std::{cell::OnceCell, convert::identity, iter};
use base_db::Crate;
use either::Either;
use hir_def::{
- AdtId, AssocItemId, ConstId, DefWithBodyId, FieldId, FunctionId, GenericDefId, GenericParamId,
- ItemContainerId, LocalFieldId, Lookup, TraitId, TupleFieldId, TupleId, TypeAliasId, VariantId,
- expr_store::{Body, ExpressionStore, HygieneId, path::Path},
+ AdtId, AssocItemId, ConstId, ConstParamId, DefWithBodyId, ExpressionStoreOwner, FieldId,
+ FunctionId, GenericDefId, GenericParamId, ItemContainerId, LocalFieldId, Lookup, TraitId,
+ TupleFieldId, TupleId, TypeAliasId, TypeOrConstParamId, VariantId,
+ expr_store::{ConstExprOrigin, ExpressionStore, HygieneId, path::Path},
hir::{BindingAnnotation, BindingId, ExprId, ExprOrPatId, LabelId, PatId},
lang_item::LangItems,
layout::Integer,
@@ -105,16 +106,14 @@ pub fn infer_query_with_inspect<'db>(
let _p = tracing::info_span!("infer_query").entered();
let resolver = def.resolver(db);
let body = db.body(def);
- let mut ctx = InferenceContext::new(db, def, &body, resolver);
+ let mut ctx = InferenceContext::new(db, ExpressionStoreOwner::Body(def), &body.store, resolver);
if let Some(inspect) = inspect {
ctx.table.infer_ctxt.attach_obligation_inspector(inspect);
}
match def {
- DefWithBodyId::FunctionId(f) => {
- ctx.collect_fn(f);
- }
+ DefWithBodyId::FunctionId(f) => ctx.collect_fn(f, body.self_param, &body.params),
DefWithBodyId::ConstId(c) => ctx.collect_const(c, &db.const_signature(c)),
DefWithBodyId::StaticId(s) => ctx.collect_static(&db.static_signature(s)),
DefWithBodyId::VariantId(v) => {
@@ -143,10 +142,14 @@ pub fn infer_query_with_inspect<'db>(
}
}
- ctx.infer_body();
+ ctx.infer_body(body.body_expr);
+
+ ctx.infer_mut_body(body.body_expr);
- ctx.infer_mut_body();
+ finalize_infer(ctx)
+}
+fn finalize_infer(mut ctx: InferenceContext<'_, '_>) -> InferenceResult {
ctx.handle_opaque_type_uses();
ctx.type_inference_fallback();
@@ -179,6 +182,55 @@ fn infer_cycle_result(db: &dyn HirDatabase, _: salsa::Id, _: DefWithBodyId) -> I
}
}
+/// Infer types for all const expressions in an item's signature.
+///
+/// This handles const expressions that appear in type positions within a generic
+/// item's signature, such as array lengths (`[T; N]`) and const generic arguments
+/// (`Foo<{ expr }>`). Each root expression is inferred independently within
+/// a shared `InferenceContext`, accumulating results into a single `InferenceResult`.
+fn infer_signature_query(db: &dyn HirDatabase, def: GenericDefId) -> InferenceResult {
+ let _p = tracing::info_span!("infer_signature_query").entered();
+ let (_, store) = db.generic_params_and_store(def);
+ let mut roots = store.signature_const_expr_roots_with_origins().peekable();
+ let Some(_) = roots.peek() else {
+ return InferenceResult::new(crate::next_solver::default_types(db).types.error);
+ };
+
+ let resolver = def.resolver(db);
+ let owner = ExpressionStoreOwner::Signature(def);
+
+ let mut ctx = InferenceContext::new(db, owner, &store, resolver);
+
+ for (root_expr, origin) in roots {
+ let expected = match origin {
+ // Array lengths are always `usize`.
+ ConstExprOrigin::ArrayLength => Expectation::has_type(ctx.types.types.usize),
+ // Const parameter default: look up the param's declared type.
+ ConstExprOrigin::ConstParam(local_id) => Expectation::has_type(db.const_param_ty_ns(
+ ConstParamId::from_unchecked(TypeOrConstParamId { parent: def, local_id }),
+ )),
+ // Path const generic args: determining the expected type requires
+ // path resolution.
+ // FIXME
+ ConstExprOrigin::GenericArgsPath => Expectation::None,
+ };
+ ctx.infer_expr(root_expr, &expected, ExprIsRead::Yes);
+ }
+
+ finalize_infer(ctx)
+}
+
+fn infer_signature_cycle_result(
+ db: &dyn HirDatabase,
+ _: salsa::Id,
+ _: GenericDefId,
+) -> InferenceResult {
+ InferenceResult {
+ has_errors: true,
+ ..InferenceResult::new(Ty::new_error(DbInterner::new_no_crate(db), ErrorGuaranteed))
+ }
+}
+
/// Binding modes inferred for patterns.
/// <https://doc.rust-lang.org/reference/patterns.html#binding-modes>
#[derive(Copy, Clone, Debug, Eq, PartialEq, Default)]
@@ -555,6 +607,16 @@ impl InferenceResult {
pub fn for_body(db: &dyn HirDatabase, def: DefWithBodyId) -> InferenceResult {
infer_query(db, def)
}
+
+ /// Infer types for all const expressions in an item's signature.
+ ///
+ /// Returns an `InferenceResult` containing type information for array lengths,
+ /// const generic arguments, and other const expressions appearing in type
+ /// positions within the item's signature.
+ #[salsa::tracked(returns(ref), cycle_result = infer_signature_cycle_result)]
+ pub fn for_signature(db: &dyn HirDatabase, def: GenericDefId) -> InferenceResult {
+ infer_signature_query(db, def)
+ }
}
impl InferenceResult {
@@ -754,8 +816,8 @@ impl InferenceResult {
#[derive(Clone, Debug)]
pub(crate) struct InferenceContext<'body, 'db> {
pub(crate) db: &'db dyn HirDatabase,
- pub(crate) owner: DefWithBodyId,
- pub(crate) body: &'body Body,
+ pub(crate) owner: ExpressionStoreOwner,
+ pub(crate) store: &'body ExpressionStore,
/// Generally you should not resolve things via this resolver. Instead create a TyLoweringContext
/// and resolve the path via its methods. This will ensure proper error reporting.
pub(crate) resolver: Resolver<'db>,
@@ -855,11 +917,16 @@ fn find_continuable<'a, 'db>(
impl<'body, 'db> InferenceContext<'body, 'db> {
fn new(
db: &'db dyn HirDatabase,
- owner: DefWithBodyId,
- body: &'body Body,
+ owner: ExpressionStoreOwner,
+ store: &'body ExpressionStore,
resolver: Resolver<'db>,
) -> Self {
- let trait_env = db.trait_environment_for_body(owner);
+ let trait_env = match owner {
+ ExpressionStoreOwner::Signature(generic_def_id) => db.trait_environment(generic_def_id),
+ ExpressionStoreOwner::Body(def_with_body_id) => {
+ db.trait_environment_for_body(def_with_body_id)
+ }
+ };
let table = unify::InferenceTable::new(db, trait_env, resolver.krate(), Some(owner));
let types = crate::next_solver::default_types(db);
InferenceContext {
@@ -878,13 +945,8 @@ impl<'body, 'db> InferenceContext<'body, 'db> {
return_coercion: None,
db,
owner,
- generic_def: match owner {
- DefWithBodyId::FunctionId(it) => it.into(),
- DefWithBodyId::StaticId(it) => it.into(),
- DefWithBodyId::ConstId(it) => it.into(),
- DefWithBodyId::VariantId(it) => it.lookup(db).parent.into(),
- },
- body,
+ generic_def: owner.generic_def(db),
+ store,
traits_in_scope: resolver.traits_in_scope(db),
resolver,
diverges: Diverges::Maybe,
@@ -908,7 +970,9 @@ impl<'body, 'db> InferenceContext<'body, 'db> {
fn target_features(&self) -> (&TargetFeatures<'db>, TargetFeatureIsSafeInTarget) {
let (target_features, target_feature_is_safe) = self.target_features.get_or_init(|| {
let target_features = match self.owner {
- DefWithBodyId::FunctionId(id) => TargetFeatures::from_fn(self.db, id),
+ ExpressionStoreOwner::Body(DefWithBodyId::FunctionId(id)) => {
+ TargetFeatures::from_fn(self.db, id)
+ }
_ => TargetFeatures::default(),
};
let target_feature_is_safe = match &self.krate().workspace_data(self.db).target {
@@ -1102,7 +1166,7 @@ impl<'body, 'db> InferenceContext<'body, 'db> {
self.return_ty = return_ty;
}
- fn collect_fn(&mut self, func: FunctionId) {
+ fn collect_fn(&mut self, func: FunctionId, self_param: Option<BindingId>, params: &[PatId]) {
let data = self.db.function_signature(func);
let mut param_tys = self.with_ty_lowering(
&data.store,
@@ -1130,13 +1194,13 @@ impl<'body, 'db> InferenceContext<'body, 'db> {
param_tys.push(va_list_ty);
}
let mut param_tys = param_tys.into_iter().chain(iter::repeat(self.table.next_ty_var()));
- if let Some(self_param) = self.body.self_param
+ if let Some(self_param) = self_param
&& let Some(ty) = param_tys.next()
{
let ty = self.process_user_written_ty(ty);
self.write_binding_ty(self_param, ty);
}
- for (ty, pat) in param_tys.zip(&*self.body.params) {
+ for (ty, pat) in param_tys.zip(params) {
let ty = self.process_user_written_ty(ty);
self.infer_top_pat(*pat, ty, None);
@@ -1170,12 +1234,12 @@ impl<'body, 'db> InferenceContext<'body, 'db> {
&self.table.infer_ctxt
}
- fn infer_body(&mut self) {
+ fn infer_body(&mut self, body_expr: ExprId) {
match self.return_coercion {
- Some(_) => self.infer_return(self.body.body_expr),
+ Some(_) => self.infer_return(body_expr),
None => {
_ = self.infer_expr_coerce(
- self.body.body_expr,
+ body_expr,
&Expectation::has_type(self.return_ty),
ExprIsRead::Yes,
)
@@ -1282,7 +1346,7 @@ impl<'body, 'db> InferenceContext<'body, 'db> {
f: impl FnOnce(&mut TyLoweringContext<'db, '_>) -> R,
) -> R {
self.with_ty_lowering(
- self.body,
+ self.store,
InferenceTyDiagnosticSource::Body,
LifetimeElisionKind::Infer,
f,
@@ -1324,7 +1388,7 @@ impl<'body, 'db> InferenceContext<'body, 'db> {
pub(crate) fn make_body_ty(&mut self, type_ref: TypeRefId) -> Ty<'db> {
self.make_ty(
type_ref,
- self.body,
+ self.store,
InferenceTyDiagnosticSource::Body,
LifetimeElisionKind::Infer,
)
@@ -1332,7 +1396,7 @@ impl<'body, 'db> InferenceContext<'body, 'db> {
pub(crate) fn make_body_const(&mut self, const_ref: ConstRef, ty: Ty<'db>) -> Const<'db> {
let const_ = self.with_ty_lowering(
- self.body,
+ self.store,
InferenceTyDiagnosticSource::Body,
LifetimeElisionKind::Infer,
|ctx| ctx.lower_const(const_ref, ty),
@@ -1342,7 +1406,7 @@ impl<'body, 'db> InferenceContext<'body, 'db> {
pub(crate) fn make_path_as_body_const(&mut self, path: &Path, ty: Ty<'db>) -> Const<'db> {
let const_ = self.with_ty_lowering(
- self.body,
+ self.store,
InferenceTyDiagnosticSource::Body,
LifetimeElisionKind::Infer,
|ctx| ctx.lower_path_as_const(path, ty),
@@ -1356,7 +1420,7 @@ impl<'body, 'db> InferenceContext<'body, 'db> {
pub(crate) fn make_body_lifetime(&mut self, lifetime_ref: LifetimeRefId) -> Region<'db> {
let lt = self.with_ty_lowering(
- self.body,
+ self.store,
InferenceTyDiagnosticSource::Body,
LifetimeElisionKind::Infer,
|ctx| ctx.lower_lifetime(lifetime_ref),
@@ -1571,7 +1635,7 @@ impl<'body, 'db> InferenceContext<'body, 'db> {
let mut ctx = TyLoweringContext::new(
self.db,
&self.resolver,
- &self.body.store,
+ self.store,
&self.diagnostics,
InferenceTyDiagnosticSource::Body,
self.generic_def,