Unnamed repository; edit this file 'description' to name the repository.
Diffstat (limited to 'crates/hir_ty/src/infer.rs')
-rw-r--r--crates/hir_ty/src/infer.rs130
1 files changed, 68 insertions, 62 deletions
diff --git a/crates/hir_ty/src/infer.rs b/crates/hir_ty/src/infer.rs
index eca6b3a076..1892e30a42 100644
--- a/crates/hir_ty/src/infer.rs
+++ b/crates/hir_ty/src/infer.rs
@@ -59,7 +59,8 @@ mod closure;
pub(crate) fn infer_query(db: &dyn HirDatabase, def: DefWithBodyId) -> Arc<InferenceResult> {
let _p = profile::span("infer_query");
let resolver = def.resolver(db.upcast());
- let mut ctx = InferenceContext::new(db, def, resolver);
+ let body = db.body(def);
+ let mut ctx = InferenceContext::new(db, def, &body, resolver);
match def {
DefWithBodyId::ConstId(c) => ctx.collect_const(&db.const_data(c)),
@@ -360,7 +361,7 @@ impl Index<PatId> for InferenceResult {
pub(crate) struct InferenceContext<'a> {
pub(crate) db: &'a dyn HirDatabase,
pub(crate) owner: DefWithBodyId,
- pub(crate) body: Arc<Body>,
+ pub(crate) body: &'a Body,
pub(crate) resolver: Resolver,
table: unify::InferenceTable<'a>,
trait_env: Arc<TraitEnvironment>,
@@ -394,7 +395,12 @@ fn find_breakable<'c>(
}
impl<'a> InferenceContext<'a> {
- fn new(db: &'a dyn HirDatabase, owner: DefWithBodyId, resolver: Resolver) -> Self {
+ fn new(
+ db: &'a dyn HirDatabase,
+ owner: DefWithBodyId,
+ body: &'a Body,
+ resolver: Resolver,
+ ) -> Self {
let krate = owner.module(db.upcast()).krate();
let trait_env = owner
.as_generic_def_id()
@@ -406,46 +412,76 @@ impl<'a> InferenceContext<'a> {
return_ty: TyKind::Error.intern(Interner), // set in collect_fn_signature
db,
owner,
- body: db.body(owner),
+ body,
resolver,
diverges: Diverges::Maybe,
breakables: Vec::new(),
}
}
- fn err_ty(&self) -> Ty {
- self.result.standard_types.unknown.clone()
- }
+ fn resolve_all(self) -> InferenceResult {
+ let InferenceContext { mut table, mut result, .. } = self;
- fn resolve_all(mut self) -> InferenceResult {
// FIXME resolve obligations as well (use Guidance if necessary)
- self.table.resolve_obligations_as_possible();
+ table.resolve_obligations_as_possible();
// make sure diverging type variables are marked as such
- self.table.propagate_diverging_flag();
- let mut result = std::mem::take(&mut self.result);
+ table.propagate_diverging_flag();
for ty in result.type_of_expr.values_mut() {
- *ty = self.table.resolve_completely(ty.clone());
+ *ty = table.resolve_completely(ty.clone());
}
for ty in result.type_of_pat.values_mut() {
- *ty = self.table.resolve_completely(ty.clone());
+ *ty = table.resolve_completely(ty.clone());
}
for mismatch in result.type_mismatches.values_mut() {
- mismatch.expected = self.table.resolve_completely(mismatch.expected.clone());
- mismatch.actual = self.table.resolve_completely(mismatch.actual.clone());
+ mismatch.expected = table.resolve_completely(mismatch.expected.clone());
+ mismatch.actual = table.resolve_completely(mismatch.actual.clone());
}
for (_, subst) in result.method_resolutions.values_mut() {
- *subst = self.table.resolve_completely(subst.clone());
+ *subst = table.resolve_completely(subst.clone());
}
for adjustment in result.expr_adjustments.values_mut().flatten() {
- adjustment.target = self.table.resolve_completely(adjustment.target.clone());
+ adjustment.target = table.resolve_completely(adjustment.target.clone());
}
for adjustment in result.pat_adjustments.values_mut().flatten() {
- adjustment.target = self.table.resolve_completely(adjustment.target.clone());
+ adjustment.target = table.resolve_completely(adjustment.target.clone());
}
result
}
+ fn collect_const(&mut self, data: &ConstData) {
+ self.return_ty = self.make_ty(&data.type_ref);
+ }
+
+ fn collect_static(&mut self, data: &StaticData) {
+ self.return_ty = self.make_ty(&data.type_ref);
+ }
+
+ fn collect_fn(&mut self, data: &FunctionData) {
+ let ctx = crate::lower::TyLoweringContext::new(self.db, &self.resolver)
+ .with_impl_trait_mode(ImplTraitLoweringMode::Param);
+ let param_tys =
+ data.params.iter().map(|(_, type_ref)| ctx.lower_ty(type_ref)).collect::<Vec<_>>();
+ for (ty, pat) in param_tys.into_iter().zip(self.body.params.iter()) {
+ let ty = self.insert_type_vars(ty);
+ let ty = self.normalize_associated_types_in(ty);
+
+ self.infer_pat(*pat, &ty, BindingMode::default());
+ }
+ let error_ty = &TypeRef::Error;
+ let return_ty = if data.has_async_kw() {
+ data.async_ret_type.as_deref().unwrap_or(error_ty)
+ } else {
+ &*data.ret_type
+ };
+ let return_ty = self.make_ty_with_mode(return_ty, ImplTraitLoweringMode::Disallowed); // FIXME implement RPIT
+ self.return_ty = return_ty;
+ }
+
+ fn infer_body(&mut self) {
+ self.infer_expr_coerce(self.body.body_expr, &Expectation::has_type(self.return_ty.clone()));
+ }
+
fn write_expr_ty(&mut self, expr: ExprId, ty: Ty) {
self.result.type_of_expr.insert(expr, ty);
}
@@ -491,6 +527,10 @@ impl<'a> InferenceContext<'a> {
self.make_ty_with_mode(type_ref, ImplTraitLoweringMode::Disallowed)
}
+ fn err_ty(&self) -> Ty {
+ self.result.standard_types.unknown.clone()
+ }
+
/// Replaces ConstScalar::Unknown by a new type var, so we can maybe still infer it.
fn insert_const_vars_shallow(&mut self, c: Const) -> Const {
let data = c.data(Interner);
@@ -544,6 +584,16 @@ impl<'a> InferenceContext<'a> {
self.table.unify(ty1, ty2)
}
+ /// Recurses through the given type, normalizing associated types mentioned
+ /// in it by replacing them by type variables and registering obligations to
+ /// resolve later. This should be done once for every type we get from some
+ /// type annotation (e.g. from a let type annotation, field type or function
+ /// call). `make_ty` handles this already, but e.g. for field types we need
+ /// to do it as well.
+ fn normalize_associated_types_in(&mut self, ty: Ty) -> Ty {
+ self.table.normalize_associated_types_in(ty)
+ }
+
fn resolve_ty_shallow(&mut self, ty: &Ty) -> Ty {
self.resolve_obligations_as_possible();
self.table.resolve_ty_shallow(ty)
@@ -586,16 +636,6 @@ impl<'a> InferenceContext<'a> {
}
}
- /// Recurses through the given type, normalizing associated types mentioned
- /// in it by replacing them by type variables and registering obligations to
- /// resolve later. This should be done once for every type we get from some
- /// type annotation (e.g. from a let type annotation, field type or function
- /// call). `make_ty` handles this already, but e.g. for field types we need
- /// to do it as well.
- fn normalize_associated_types_in(&mut self, ty: Ty) -> Ty {
- self.table.normalize_associated_types_in(ty)
- }
-
fn resolve_variant(&mut self, path: Option<&Path>, value_ns: bool) -> (Ty, Option<VariantId>) {
let path = match path {
Some(path) => path,
@@ -727,40 +767,6 @@ impl<'a> InferenceContext<'a> {
}
}
- fn collect_const(&mut self, data: &ConstData) {
- self.return_ty = self.make_ty(&data.type_ref);
- }
-
- fn collect_static(&mut self, data: &StaticData) {
- self.return_ty = self.make_ty(&data.type_ref);
- }
-
- fn collect_fn(&mut self, data: &FunctionData) {
- let body = Arc::clone(&self.body); // avoid borrow checker problem
- let ctx = crate::lower::TyLoweringContext::new(self.db, &self.resolver)
- .with_impl_trait_mode(ImplTraitLoweringMode::Param);
- let param_tys =
- data.params.iter().map(|(_, type_ref)| ctx.lower_ty(type_ref)).collect::<Vec<_>>();
- for (ty, pat) in param_tys.into_iter().zip(body.params.iter()) {
- let ty = self.insert_type_vars(ty);
- let ty = self.normalize_associated_types_in(ty);
-
- self.infer_pat(*pat, &ty, BindingMode::default());
- }
- let error_ty = &TypeRef::Error;
- let return_ty = if data.has_async_kw() {
- data.async_ret_type.as_deref().unwrap_or(error_ty)
- } else {
- &*data.ret_type
- };
- let return_ty = self.make_ty_with_mode(return_ty, ImplTraitLoweringMode::Disallowed); // FIXME implement RPIT
- self.return_ty = return_ty;
- }
-
- fn infer_body(&mut self) {
- self.infer_expr_coerce(self.body.body_expr, &Expectation::has_type(self.return_ty.clone()));
- }
-
fn resolve_lang_item(&self, name: Name) -> Option<LangItemTarget> {
let krate = self.resolver.krate();
self.db.lang_item(krate, name.to_smol_str())