Unnamed repository; edit this file 'description' to name the repository.
Diffstat (limited to 'crates/hir-ty/src/infer/unify.rs')
-rw-r--r--crates/hir-ty/src/infer/unify.rs319
1 files changed, 261 insertions, 58 deletions
diff --git a/crates/hir-ty/src/infer/unify.rs b/crates/hir-ty/src/infer/unify.rs
index be4d370c24..4342375621 100644
--- a/crates/hir-ty/src/infer/unify.rs
+++ b/crates/hir-ty/src/infer/unify.rs
@@ -7,12 +7,13 @@ use hir_def::{ExpressionStoreOwnerId, GenericParamId, TraitId};
use rustc_hash::FxHashSet;
use rustc_type_ir::{
TyVid, TypeFoldable, TypeVisitableExt,
- inherent::{Const as _, GenericArg as _, IntoKind, Ty as _},
+ inherent::{GenericArg as _, IntoKind, Ty as _},
solve::Certainty,
};
use smallvec::SmallVec;
use crate::{
+ Span,
db::HirDatabase,
next_solver::{
Canonical, ClauseKind, Const, DbInterner, ErrorGuaranteed, GenericArg, GenericArgs,
@@ -41,6 +42,10 @@ struct NestedObligationsForSelfTy<'a, 'db> {
impl<'a, 'db> ProofTreeVisitor<'db> for NestedObligationsForSelfTy<'a, 'db> {
type Result = ();
+ fn span(&self) -> Span {
+ self.root_cause.span()
+ }
+
fn config(&self) -> InspectConfig {
// Using an intentionally low depth to minimize the chance of future
// breaking changes in case we adapt the approach later on. This also
@@ -112,7 +117,7 @@ fn could_unify_impl<'db>(
let infcx = interner.infer_ctxt().build(TypingMode::PostAnalysis);
let cause = ObligationCause::dummy();
let at = infcx.at(&cause, env.param_env);
- let ((ty1_with_vars, ty2_with_vars), _) = infcx.instantiate_canonical(tys);
+ let ((ty1_with_vars, ty2_with_vars), _) = infcx.instantiate_canonical(Span::Dummy, tys);
let mut ctxt = ObligationCtxt::new(&infcx);
let can_unify = at
.eq(ty1_with_vars, ty2_with_vars)
@@ -245,12 +250,12 @@ impl<'db> InferenceTable<'db> {
self.diverging_type_vars.insert(ty);
}
- pub(crate) fn next_ty_var(&self) -> Ty<'db> {
- self.infer_ctxt.next_ty_var()
+ pub(crate) fn next_ty_var(&self, span: Span) -> Ty<'db> {
+ self.infer_ctxt.next_ty_var(span)
}
- pub(crate) fn next_const_var(&self) -> Const<'db> {
- self.infer_ctxt.next_const_var()
+ pub(crate) fn next_const_var(&self, span: Span) -> Const<'db> {
+ self.infer_ctxt.next_const_var(span)
}
pub(crate) fn next_int_var(&self) -> Ty<'db> {
@@ -261,31 +266,18 @@ impl<'db> InferenceTable<'db> {
self.infer_ctxt.next_float_var()
}
- pub(crate) fn new_maybe_never_var(&mut self) -> Ty<'db> {
- let var = self.next_ty_var();
+ pub(crate) fn new_maybe_never_var(&mut self, span: Span) -> Ty<'db> {
+ let var = self.next_ty_var(span);
self.set_diverging(var);
var
}
- pub(crate) fn next_region_var(&self) -> Region<'db> {
- self.infer_ctxt.next_region_var()
+ pub(crate) fn next_region_var(&self, span: Span) -> Region<'db> {
+ self.infer_ctxt.next_region_var(span)
}
- pub(crate) fn next_var_for_param(&self, id: GenericParamId) -> GenericArg<'db> {
- self.infer_ctxt.next_var_for_param(id)
- }
-
- pub(crate) fn resolve_completely<T>(&mut self, value: T) -> T
- where
- T: TypeFoldable<DbInterner<'db>>,
- {
- let value = self.infer_ctxt.resolve_vars_if_possible(value);
-
- let mut goals = vec![];
-
- // FIXME(next-solver): Handle `goals`.
-
- value.fold_with(&mut resolve_completely::Resolver::new(self, true, &mut goals))
+ pub(crate) fn var_for_def(&self, id: GenericParamId, span: Span) -> GenericArg<'db> {
+ self.infer_ctxt.var_for_def(id, span)
}
pub(crate) fn at<'a>(&'a self, cause: &'a ObligationCause) -> At<'a, 'db> {
@@ -319,8 +311,8 @@ impl<'db> InferenceTable<'db> {
}
/// Create a `GenericArgs` full of infer vars for `def`.
- pub(crate) fn fresh_args_for_item(&self, def: SolverDefId) -> GenericArgs<'db> {
- self.infer_ctxt.fresh_args_for_item(def)
+ pub(crate) fn fresh_args_for_item(&self, span: Span, def: SolverDefId) -> GenericArgs<'db> {
+ self.infer_ctxt.fresh_args_for_item(span, def)
}
/// Try to resolve `ty` to a structural type, normalizing aliases.
@@ -399,19 +391,21 @@ impl<'db> InferenceTable<'db> {
where
I: IntoIterator<Item = PredicateObligation<'db>>,
{
- obligations.into_iter().for_each(|obligation| {
- self.register_predicate(obligation);
- });
+ self.fulfillment_cx.register_predicate_obligations(&self.infer_ctxt, obligations);
}
/// checking later, during regionck, that `arg` is well-formed.
pub(crate) fn register_wf_obligation(&mut self, term: Term<'db>, cause: ObligationCause) {
- self.register_predicate(Obligation::new(
- self.interner(),
- cause,
- self.param_env,
- ClauseKind::WellFormed(term),
- ));
+ let _ = (term, cause);
+ // FIXME: We don't currently register an obligation here because we don't implement
+ // wf checking anyway and this function is currently often passed dummy spans, which could
+ // prevent reporting "type annotation needed" errors.
+ // self.register_predicate(Obligation::new(
+ // self.interner(),
+ // cause,
+ // self.param_env,
+ // ClauseKind::WellFormed(term),
+ // ));
}
/// Registers obligations that all `args` are well-formed.
@@ -421,34 +415,29 @@ impl<'db> InferenceTable<'db> {
}
}
- pub(super) fn insert_type_vars<T>(&mut self, ty: T) -> T
+ pub(super) fn insert_type_vars<T>(&mut self, ty: T, span: Span) -> T
where
T: TypeFoldable<DbInterner<'db>>,
{
- self.infer_ctxt.insert_type_vars(ty)
+ self.infer_ctxt.insert_type_vars(ty, span)
}
/// Whenever you lower a user-written type, you should call this.
- pub(crate) fn process_user_written_ty(&mut self, ty: Ty<'db>) -> Ty<'db> {
- self.process_remote_user_written_ty(ty)
- // FIXME: Register a well-formed obligation.
+ pub(crate) fn process_user_written_ty(&mut self, span: Span, ty: Ty<'db>) -> Ty<'db> {
+ let ty = self.insert_type_vars(ty, span);
+ self.try_structurally_resolve_type(ty)
}
/// The difference of this method from `process_user_written_ty()` is that this method doesn't register a well-formed obligation,
/// while `process_user_written_ty()` should (but doesn't currently).
pub(crate) fn process_remote_user_written_ty(&mut self, ty: Ty<'db>) -> Ty<'db> {
- let ty = self.insert_type_vars(ty);
+ let ty = self.insert_type_vars(ty, Span::Dummy);
// See https://github.com/rust-lang/rust/blob/cdb45c87e2cd43495379f7e867e3cc15dcee9f93/compiler/rustc_hir_typeck/src/fn_ctxt/mod.rs#L487-L495:
// Even though the new solver only lazily normalizes usually, here we eagerly normalize so that not everything needs
// to normalize before inspecting the `TyKind`.
// FIXME(next-solver): We should not deeply normalize here, only shallowly.
self.try_structurally_resolve_type(ty)
}
-
- /// Replaces ConstScalar::Unknown by a new type var, so we can maybe still infer it.
- pub(super) fn insert_const_vars_shallow(&mut self, c: Const<'db>) -> Const<'db> {
- if c.is_ct_error() { self.next_const_var() } else { c }
- }
}
impl fmt::Debug for InferenceTable<'_> {
@@ -460,20 +449,234 @@ impl fmt::Debug for InferenceTable<'_> {
}
}
-mod resolve_completely {
- use rustc_type_ir::{DebruijnIndex, Flags, TypeFolder, TypeSuperFoldable};
+pub(super) mod resolve_completely {
+ use rustc_hash::FxHashSet;
+ use rustc_type_ir::{
+ DebruijnIndex, Flags, InferConst, InferTy, TypeFlags, TypeFoldable, TypeFolder,
+ TypeSuperFoldable, TypeVisitableExt, inherent::IntoKind,
+ };
+ use stdx::never;
+ use thin_vec::ThinVec;
use crate::{
- infer::unify::InferenceTable,
+ InferenceDiagnostic, Span,
+ infer::{TypeMismatch, unify::InferenceTable},
next_solver::{
- Const, DbInterner, Goal, Predicate, Region, Term, Ty,
+ Const, ConstKind, DbInterner, DefaultAny, GenericArg, Goal, Predicate, Region, Term,
+ TermKind, Ty, TyKind,
infer::{resolve::ReplaceInferWithError, traits::ObligationCause},
normalize::deeply_normalize_with_skipped_universes_and_ambiguous_coroutine_goals,
},
};
+ pub(crate) struct WriteBackCtxt<'db> {
+ table: InferenceTable<'db>,
+ diagnostics: ThinVec<InferenceDiagnostic>,
+ has_errors: bool,
+ spans_emitted_type_must_be_known_for: FxHashSet<Span>,
+ types: &'db DefaultAny<'db>,
+ }
+
+ impl<'db> WriteBackCtxt<'db> {
+ pub(crate) fn new(
+ table: InferenceTable<'db>,
+ diagnostics: ThinVec<InferenceDiagnostic>,
+ vars_emitted_type_must_be_known_for: FxHashSet<Term<'db>>,
+ ) -> Self {
+ let spans_emitted_type_must_be_known_for = vars_emitted_type_must_be_known_for
+ .into_iter()
+ .filter_map(|term| match term.kind() {
+ TermKind::Ty(ty) => match ty.kind() {
+ TyKind::Infer(InferTy::TyVar(vid)) => {
+ Some(table.infer_ctxt.type_var_span(vid))
+ }
+ _ => None,
+ },
+ TermKind::Const(ct) => match ct.kind() {
+ ConstKind::Infer(InferConst::Var(vid)) => {
+ table.infer_ctxt.const_var_span(vid)
+ }
+ _ => None,
+ },
+ })
+ .collect();
+
+ Self {
+ types: table.interner().default_types(),
+ table,
+ diagnostics,
+ has_errors: false,
+ spans_emitted_type_must_be_known_for,
+ }
+ }
+
+ pub(crate) fn resolve_type_mismatch(&mut self, value_ref: &mut TypeMismatch) {
+ // Ignore diagnostics from type mismatches, which are diagnostics themselves.
+ // FIXME: We should make type mismatches just regular diagnostics.
+ let prev_diagnostics_len = self.diagnostics.len();
+ self.resolve_completely(value_ref);
+ self.diagnostics.truncate(prev_diagnostics_len);
+ }
+
+ pub(crate) fn resolve_completely<T>(&mut self, value_ref: &mut T)
+ where
+ T: TypeFoldable<DbInterner<'db>>,
+ {
+ self.resolve_completely_with_default(value_ref, value_ref.clone());
+ }
+
+ pub(crate) fn resolve_completely_with_default<T>(&mut self, value_ref: &mut T, default: T)
+ where
+ T: TypeFoldable<DbInterner<'db>>,
+ {
+ let value = std::mem::replace(value_ref, default);
+
+ let value = self.table.resolve_vars_if_possible(value);
+
+ let mut goals = vec![];
+
+ // FIXME(next-solver): Handle `goals`.
+
+ *value_ref = value.fold_with(&mut Resolver::new(self, true, &mut goals));
+ }
+
+ pub(crate) fn resolve_diagnostics(mut self) -> (ThinVec<InferenceDiagnostic>, bool) {
+ let has_errors = self.has_errors;
+
+ // Ignore diagnostics made from resolving diagnostics.
+ let mut diagnostics = std::mem::take(&mut self.diagnostics);
+ diagnostics.retain_mut(|diagnostic| {
+ self.resolve_completely(diagnostic);
+
+ if let InferenceDiagnostic::ExpectedFunction { found: ty, .. }
+ | InferenceDiagnostic::UnresolvedField { receiver: ty, .. }
+ | InferenceDiagnostic::UnresolvedMethodCall { receiver: ty, .. } = diagnostic
+ && ty.as_ref().references_non_lt_error()
+ {
+ false
+ } else {
+ true
+ }
+ });
+ diagnostics.shrink_to_fit();
+
+ (diagnostics, has_errors)
+ }
+ }
+
+ struct DiagnoseInferVars<'a, 'db> {
+ ctx: &'a mut WriteBackCtxt<'db>,
+ top_term: Term<'db>,
+ }
+
+ impl<'db> DiagnoseInferVars<'_, 'db> {
+ const TYPE_FLAGS: TypeFlags = TypeFlags::HAS_INFER.union(TypeFlags::HAS_NON_REGION_ERROR);
+
+ fn err_on_span(&mut self, span: Span) {
+ if !self.ctx.spans_emitted_type_must_be_known_for.insert(span) {
+ // Suppress duplicate diagnostics.
+ return;
+ }
+
+ if span.is_dummy() {
+ return;
+ }
+
+ // We have to be careful not to insert infer vars here, as we won't resolve this new diagnostic.
+ let top_term = self.top_term.fold_with(&mut ReplaceInferWithError::new(self.cx()));
+ self.ctx.diagnostics.push(InferenceDiagnostic::TypeMustBeKnown {
+ at_point: span,
+ top_term: Some(GenericArg::from(top_term).store()),
+ });
+ }
+ }
+
+ impl<'db> TypeFolder<DbInterner<'db>> for DiagnoseInferVars<'_, 'db> {
+ fn cx(&self) -> DbInterner<'db> {
+ self.ctx.table.interner()
+ }
+
+ fn fold_ty(&mut self, t: Ty<'db>) -> Ty<'db> {
+ if !t.has_type_flags(Self::TYPE_FLAGS) {
+ return t;
+ }
+
+ match t.kind() {
+ TyKind::Error(_) => {
+ self.ctx.has_errors = true;
+ t
+ }
+ TyKind::Infer(infer_ty) => match infer_ty {
+ InferTy::TyVar(vid) => {
+ self.err_on_span(self.ctx.table.infer_ctxt.type_var_span(vid));
+ self.ctx.has_errors = true;
+ self.ctx.types.types.error
+ }
+ InferTy::IntVar(_) => {
+ never!("fallback should have resolved all int vars");
+ self.ctx.types.types.i32
+ }
+ InferTy::FloatVar(_) => {
+ never!("fallback should have resolved all float vars");
+ self.ctx.types.types.f64
+ }
+ InferTy::FreshTy(_) | InferTy::FreshIntTy(_) | InferTy::FreshFloatTy(_) => {
+ never!("should not have fresh infer vars outside of caching");
+ self.ctx.has_errors = true;
+ self.ctx.types.types.error
+ }
+ },
+ _ => t.super_fold_with(self),
+ }
+ }
+
+ fn fold_const(&mut self, c: Const<'db>) -> Const<'db> {
+ if !c.has_type_flags(Self::TYPE_FLAGS) {
+ return c;
+ }
+
+ match c.kind() {
+ ConstKind::Error(_) => {
+ self.ctx.has_errors = true;
+ c
+ }
+ ConstKind::Infer(infer_ct) => match infer_ct {
+ InferConst::Var(vid) => {
+ if let Some(span) = self.ctx.table.infer_ctxt.const_var_span(vid) {
+ self.err_on_span(span);
+ }
+ self.ctx.has_errors = true;
+ self.ctx.types.consts.error
+ }
+ InferConst::Fresh(_) => {
+ never!("should not have fresh infer vars outside of caching");
+ self.ctx.has_errors = true;
+ self.ctx.types.consts.error
+ }
+ },
+ _ => c.super_fold_with(self),
+ }
+ }
+
+ fn fold_predicate(&mut self, p: Predicate<'db>) -> Predicate<'db> {
+ if !p.has_type_flags(Self::TYPE_FLAGS) {
+ return p;
+ }
+ p.super_fold_with(self)
+ }
+
+ fn fold_region(&mut self, r: Region<'db>) -> Region<'db> {
+ if r.is_var() {
+ // For now, we don't error on regions.
+ self.ctx.types.regions.error
+ } else {
+ r
+ }
+ }
+ }
+
pub(super) struct Resolver<'a, 'db> {
- ctx: &'a mut InferenceTable<'db>,
+ ctx: &'a mut WriteBackCtxt<'db>,
/// Whether we should normalize, disabled when resolving predicates.
should_normalize: bool,
nested_goals: &'a mut Vec<Goal<'db, Predicate<'db>>>,
@@ -481,7 +684,7 @@ mod resolve_completely {
impl<'a, 'db> Resolver<'a, 'db> {
pub(super) fn new(
- ctx: &'a mut InferenceTable<'db>,
+ ctx: &'a mut WriteBackCtxt<'db>,
should_normalize: bool,
nested_goals: &'a mut Vec<Goal<'db, Predicate<'db>>>,
) -> Resolver<'a, 'db> {
@@ -498,7 +701,7 @@ mod resolve_completely {
{
let value = if self.should_normalize {
let cause = ObligationCause::new();
- let at = self.ctx.at(&cause);
+ let at = self.ctx.table.at(&cause);
let universes = vec![None; outer_exclusive_binder(value).as_usize()];
match deeply_normalize_with_skipped_universes_and_ambiguous_coroutine_goals(
at, value, universes,
@@ -516,17 +719,17 @@ mod resolve_completely {
value
};
- value.fold_with(&mut ReplaceInferWithError::new(self.ctx.interner()))
+ value.fold_with(&mut DiagnoseInferVars { ctx: self.ctx, top_term: value.into() })
}
}
- impl<'cx, 'db> TypeFolder<DbInterner<'db>> for Resolver<'cx, 'db> {
+ impl<'db> TypeFolder<DbInterner<'db>> for Resolver<'_, 'db> {
fn cx(&self) -> DbInterner<'db> {
- self.ctx.interner()
+ self.ctx.table.interner()
}
fn fold_region(&mut self, r: Region<'db>) -> Region<'db> {
- if r.is_var() { Region::error(self.ctx.interner()) } else { r }
+ if r.is_var() { self.ctx.types.regions.error } else { r }
}
fn fold_ty(&mut self, ty: Ty<'db>) -> Ty<'db> {