Unnamed repository; edit this file 'description' to name the repository.
Diffstat (limited to 'crates/hir-ty/src/infer/unify.rs')
-rw-r--r--crates/hir-ty/src/infer/unify.rs332
1 files changed, 167 insertions, 165 deletions
diff --git a/crates/hir-ty/src/infer/unify.rs b/crates/hir-ty/src/infer/unify.rs
index dd7e77ba8c..108cf5b1a2 100644
--- a/crates/hir-ty/src/infer/unify.rs
+++ b/crates/hir-ty/src/infer/unify.rs
@@ -3,8 +3,7 @@
use std::fmt;
use chalk_ir::{
- CanonicalVarKind, FloatTy, IntTy, TyVariableKind, cast::Cast, fold::TypeFoldable,
- interner::HasInterner,
+ CanonicalVarKind, TyVariableKind, cast::Cast, fold::TypeFoldable, interner::HasInterner,
};
use either::Either;
use hir_def::{AdtId, lang_item::LangItem};
@@ -12,7 +11,7 @@ use hir_expand::name::Name;
use intern::sym;
use rustc_hash::{FxHashMap, FxHashSet};
use rustc_type_ir::{
- FloatVid, IntVid, TyVid, TypeVisitableExt, UpcastFrom,
+ TyVid, TypeVisitableExt, UpcastFrom,
inherent::{IntoKind, Span, Term as _, Ty as _},
relate::{Relate, solver_relating::RelateExt},
solve::{Certainty, GoalSource},
@@ -23,8 +22,8 @@ use triomphe::Arc;
use super::{InferResult, InferenceContext, TypeError};
use crate::{
AliasTy, BoundVar, Canonical, Const, ConstValue, DebruijnIndex, GenericArg, GenericArgData,
- InferenceVar, Interner, Lifetime, OpaqueTyId, ProjectionTy, Scalar, Substitution,
- TraitEnvironment, Ty, TyExt, TyKind, VariableKind,
+ InferenceVar, Interner, Lifetime, OpaqueTyId, ProjectionTy, Substitution, TraitEnvironment, Ty,
+ TyExt, TyKind, VariableKind,
consteval::unknown_const,
db::HirDatabase,
fold_generic_args, fold_tys_and_consts,
@@ -143,7 +142,6 @@ pub fn could_unify_deeply(
let ty1_with_vars = table.normalize_associated_types_in(ty1_with_vars);
let ty2_with_vars = table.normalize_associated_types_in(ty2_with_vars);
table.select_obligations_where_possible();
- table.propagate_diverging_flag();
let ty1_with_vars = table.resolve_completely(ty1_with_vars);
let ty2_with_vars = table.resolve_completely(ty2_with_vars);
table.unify_deeply(&ty1_with_vars, &ty2_with_vars)
@@ -170,13 +168,19 @@ pub(crate) fn unify(
GenericArgData::Const(c) => c.inference_var(Interner),
} == Some(iv))
};
- let fallback = |iv, kind, default, binder| match kind {
- chalk_ir::VariableKind::Ty(_ty_kind) => find_var(iv)
- .map_or(default, |i| BoundVar::new(binder, i).to_ty(Interner).cast(Interner)),
- chalk_ir::VariableKind::Lifetime => find_var(iv)
- .map_or(default, |i| BoundVar::new(binder, i).to_lifetime(Interner).cast(Interner)),
- chalk_ir::VariableKind::Const(ty) => find_var(iv)
- .map_or(default, |i| BoundVar::new(binder, i).to_const(Interner, ty).cast(Interner)),
+ let fallback = |iv, kind, binder| match kind {
+ chalk_ir::VariableKind::Ty(_ty_kind) => find_var(iv).map_or_else(
+ || TyKind::Error.intern(Interner).cast(Interner),
+ |i| BoundVar::new(binder, i).to_ty(Interner).cast(Interner),
+ ),
+ chalk_ir::VariableKind::Lifetime => find_var(iv).map_or_else(
+ || crate::error_lifetime().cast(Interner),
+ |i| BoundVar::new(binder, i).to_lifetime(Interner).cast(Interner),
+ ),
+ chalk_ir::VariableKind::Const(ty) => find_var(iv).map_or_else(
+ || crate::unknown_const(ty.clone()).cast(Interner),
+ |i| BoundVar::new(binder, i).to_const(Interner, ty.clone()).cast(Interner),
+ ),
};
Some(Substitution::from_iter(
Interner,
@@ -215,14 +219,13 @@ pub(crate) struct InferenceTable<'db> {
pub(crate) trait_env: Arc<TraitEnvironment<'db>>,
pub(crate) tait_coercion_table: Option<FxHashMap<OpaqueTyId, Ty>>,
pub(crate) infer_ctxt: InferCtxt<'db>,
- diverging_tys: FxHashSet<Ty>,
pub(super) fulfillment_cx: FulfillmentCtxt<'db>,
+ pub(super) diverging_type_vars: FxHashSet<crate::next_solver::Ty<'db>>,
}
pub(crate) struct InferenceTableSnapshot<'db> {
ctxt_snapshot: CombinedSnapshot,
obligations: FulfillmentCtxt<'db>,
- diverging_tys: FxHashSet<Ty>,
}
impl<'db> InferenceTable<'db> {
@@ -238,7 +241,7 @@ impl<'db> InferenceTable<'db> {
tait_coercion_table: None,
fulfillment_cx: FulfillmentCtxt::new(&infer_ctxt),
infer_ctxt,
- diverging_tys: FxHashSet::default(),
+ diverging_type_vars: FxHashSet::default(),
}
}
@@ -321,74 +324,8 @@ impl<'db> InferenceTable<'db> {
}
}
- /// Chalk doesn't know about the `diverging` flag, so when it unifies two
- /// type variables of which one is diverging, the chosen root might not be
- /// diverging and we have no way of marking it as such at that time. This
- /// function goes through all type variables and make sure their root is
- /// marked as diverging if necessary, so that resolving them gives the right
- /// result.
- pub(super) fn propagate_diverging_flag(&mut self) {
- let mut new_tys = FxHashSet::default();
- for ty in self.diverging_tys.iter() {
- match ty.kind(Interner) {
- TyKind::InferenceVar(var, kind) => match kind {
- TyVariableKind::General => {
- let root = InferenceVar::from(
- self.infer_ctxt.root_var(TyVid::from_u32(var.index())).as_u32(),
- );
- if root.index() != var.index() {
- new_tys.insert(TyKind::InferenceVar(root, *kind).intern(Interner));
- }
- }
- TyVariableKind::Integer => {
- let root = InferenceVar::from(
- self.infer_ctxt
- .inner
- .borrow_mut()
- .int_unification_table()
- .find(IntVid::from_usize(var.index() as usize))
- .as_u32(),
- );
- if root.index() != var.index() {
- new_tys.insert(TyKind::InferenceVar(root, *kind).intern(Interner));
- }
- }
- TyVariableKind::Float => {
- let root = InferenceVar::from(
- self.infer_ctxt
- .inner
- .borrow_mut()
- .float_unification_table()
- .find(FloatVid::from_usize(var.index() as usize))
- .as_u32(),
- );
- if root.index() != var.index() {
- new_tys.insert(TyKind::InferenceVar(root, *kind).intern(Interner));
- }
- }
- },
- _ => {}
- }
- }
- self.diverging_tys.extend(new_tys);
- }
-
- pub(super) fn set_diverging(&mut self, iv: InferenceVar, kind: TyVariableKind) {
- self.diverging_tys.insert(TyKind::InferenceVar(iv, kind).intern(Interner));
- }
-
- fn fallback_value(&self, iv: InferenceVar, kind: TyVariableKind) -> Ty {
- let is_diverging =
- self.diverging_tys.contains(&TyKind::InferenceVar(iv, kind).intern(Interner));
- if is_diverging {
- return TyKind::Never.intern(Interner);
- }
- match kind {
- TyVariableKind::General => TyKind::Error,
- TyVariableKind::Integer => TyKind::Scalar(Scalar::Int(IntTy::I32)),
- TyVariableKind::Float => TyKind::Scalar(Scalar::Float(FloatTy::F64)),
- }
- .intern(Interner)
+ pub(super) fn set_diverging(&mut self, ty: crate::next_solver::Ty<'db>) {
+ self.diverging_type_vars.insert(ty);
}
pub(crate) fn canonicalize<T>(&mut self, t: T) -> rustc_type_ir::Canonical<DbInterner<'db>, T>
@@ -529,7 +466,7 @@ impl<'db> InferenceTable<'db> {
let ty = var.to_ty(Interner, kind);
if diverging {
- self.diverging_tys.insert(ty.clone());
+ self.diverging_type_vars.insert(ty.to_nextsolver(self.interner));
}
ty
}
@@ -573,7 +510,7 @@ impl<'db> InferenceTable<'db> {
pub(crate) fn resolve_with_fallback<T>(
&mut self,
t: T,
- fallback: &dyn Fn(InferenceVar, VariableKind, GenericArg, DebruijnIndex) -> GenericArg,
+ fallback: &dyn Fn(InferenceVar, VariableKind, DebruijnIndex) -> GenericArg,
) -> T
where
T: HasInterner<Interner = Interner> + TypeFoldable<Interner>,
@@ -615,7 +552,7 @@ impl<'db> InferenceTable<'db> {
fn resolve_with_fallback_inner<T>(
&mut self,
t: T,
- fallback: &dyn Fn(InferenceVar, VariableKind, GenericArg, DebruijnIndex) -> GenericArg,
+ fallback: &dyn Fn(InferenceVar, VariableKind, DebruijnIndex) -> GenericArg,
) -> T
where
T: HasInterner<Interner = Interner> + TypeFoldable<Interner>,
@@ -632,53 +569,15 @@ impl<'db> InferenceTable<'db> {
T: HasInterner<Interner = Interner> + TypeFoldable<Interner> + ChalkToNextSolver<'db, U>,
U: NextSolverToChalk<'db, T> + rustc_type_ir::TypeFoldable<DbInterner<'db>>,
{
- let t = self.resolve_with_fallback(t, &|_, _, d, _| d);
- let t = self.normalize_associated_types_in(t);
- // let t = self.resolve_opaque_tys_in(t);
- // Resolve again, because maybe normalization inserted infer vars.
- self.resolve_with_fallback(t, &|_, _, d, _| d)
- }
+ let value = t.to_nextsolver(self.interner);
+ let value = self.infer_ctxt.resolve_vars_if_possible(value);
- /// Apply a fallback to unresolved scalar types. Integer type variables and float type
- /// variables are replaced with i32 and f64, respectively.
- ///
- /// This method is only intended to be called just before returning inference results (i.e. in
- /// `InferenceContext::resolve_all()`).
- ///
- /// FIXME: This method currently doesn't apply fallback to unconstrained general type variables
- /// whereas rustc replaces them with `()` or `!`.
- pub(super) fn fallback_if_possible(&mut self) {
- let int_fallback = TyKind::Scalar(Scalar::Int(IntTy::I32)).intern(Interner);
- let float_fallback = TyKind::Scalar(Scalar::Float(FloatTy::F64)).intern(Interner);
-
- let int_vars = self.infer_ctxt.inner.borrow_mut().int_unification_table().len();
- for v in 0..int_vars {
- let var = InferenceVar::from(v as u32).to_ty(Interner, TyVariableKind::Integer);
- let maybe_resolved = self.resolve_ty_shallow(&var);
- if let TyKind::InferenceVar(_, kind) = maybe_resolved.kind(Interner) {
- // I don't think we can ever unify these vars with float vars, but keep this here for now
- let fallback = match kind {
- TyVariableKind::Integer => &int_fallback,
- TyVariableKind::Float => &float_fallback,
- TyVariableKind::General => unreachable!(),
- };
- self.unify(&var, fallback);
- }
- }
- let float_vars = self.infer_ctxt.inner.borrow_mut().float_unification_table().len();
- for v in 0..float_vars {
- let var = InferenceVar::from(v as u32).to_ty(Interner, TyVariableKind::Float);
- let maybe_resolved = self.resolve_ty_shallow(&var);
- if let TyKind::InferenceVar(_, kind) = maybe_resolved.kind(Interner) {
- // I don't think we can ever unify these vars with float vars, but keep this here for now
- let fallback = match kind {
- TyVariableKind::Integer => &int_fallback,
- TyVariableKind::Float => &float_fallback,
- TyVariableKind::General => unreachable!(),
- };
- self.unify(&var, fallback);
- }
- }
+ let mut goals = vec![];
+ let value = value.fold_with(&mut resolve_completely::Resolver::new(self, true, &mut goals));
+
+ // FIXME(next-solver): Handle `goals`.
+
+ value.to_chalk(self.interner)
}
/// Unify two relatable values (e.g. `Ty`) and register new trait goals that arise from that.
@@ -829,15 +728,13 @@ impl<'db> InferenceTable<'db> {
pub(crate) fn snapshot(&mut self) -> InferenceTableSnapshot<'db> {
let ctxt_snapshot = self.infer_ctxt.start_snapshot();
- let diverging_tys = self.diverging_tys.clone();
let obligations = self.fulfillment_cx.clone();
- InferenceTableSnapshot { ctxt_snapshot, diverging_tys, obligations }
+ InferenceTableSnapshot { ctxt_snapshot, obligations }
}
#[tracing::instrument(skip_all)]
pub(crate) fn rollback_to(&mut self, snapshot: InferenceTableSnapshot<'db>) {
self.infer_ctxt.rollback_to(snapshot.ctxt_snapshot);
- self.diverging_tys = snapshot.diverging_tys;
self.fulfillment_cx = snapshot.obligations;
}
@@ -1166,14 +1063,10 @@ impl fmt::Debug for InferenceTable<'_> {
mod resolve {
use super::InferenceTable;
use crate::{
- ConcreteConst, Const, ConstData, ConstScalar, ConstValue, DebruijnIndex, GenericArg,
- InferenceVar, Interner, Lifetime, Ty, TyVariableKind, VariableKind,
- next_solver::mapping::NextSolverToChalk,
- };
- use chalk_ir::{
- cast::Cast,
- fold::{TypeFoldable, TypeFolder},
+ Const, DebruijnIndex, GenericArg, InferenceVar, Interner, Lifetime, Ty, TyVariableKind,
+ VariableKind, next_solver::mapping::NextSolverToChalk,
};
+ use chalk_ir::fold::{TypeFoldable, TypeFolder};
use rustc_type_ir::{FloatVid, IntVid, TyVid};
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
@@ -1187,7 +1080,7 @@ mod resolve {
pub(super) struct Resolver<
'a,
'b,
- F: Fn(InferenceVar, VariableKind, GenericArg, DebruijnIndex) -> GenericArg,
+ F: Fn(InferenceVar, VariableKind, DebruijnIndex) -> GenericArg,
> {
pub(super) table: &'a mut InferenceTable<'b>,
pub(super) var_stack: &'a mut Vec<(InferenceVar, VarKind)>,
@@ -1195,7 +1088,7 @@ mod resolve {
}
impl<F> TypeFolder<Interner> for Resolver<'_, '_, F>
where
- F: Fn(InferenceVar, VariableKind, GenericArg, DebruijnIndex) -> GenericArg,
+ F: Fn(InferenceVar, VariableKind, DebruijnIndex) -> GenericArg,
{
fn as_dyn(&mut self) -> &mut dyn TypeFolder<Interner> {
self
@@ -1217,8 +1110,7 @@ mod resolve {
let var = InferenceVar::from(vid.as_u32());
if self.var_stack.contains(&(var, VarKind::Ty(kind))) {
// recursive type
- let default = self.table.fallback_value(var, kind).cast(Interner);
- return (self.fallback)(var, VariableKind::Ty(kind), default, outer_binder)
+ return (self.fallback)(var, VariableKind::Ty(kind), outer_binder)
.assert_ty_ref(Interner)
.clone();
}
@@ -1230,8 +1122,7 @@ mod resolve {
self.var_stack.pop();
result
} else {
- let default = self.table.fallback_value(var, kind).cast(Interner);
- (self.fallback)(var, VariableKind::Ty(kind), default, outer_binder)
+ (self.fallback)(var, VariableKind::Ty(kind), outer_binder)
.assert_ty_ref(Interner)
.clone()
}
@@ -1247,8 +1138,7 @@ mod resolve {
let var = InferenceVar::from(vid.as_u32());
if self.var_stack.contains(&(var, VarKind::Ty(kind))) {
// recursive type
- let default = self.table.fallback_value(var, kind).cast(Interner);
- return (self.fallback)(var, VariableKind::Ty(kind), default, outer_binder)
+ return (self.fallback)(var, VariableKind::Ty(kind), outer_binder)
.assert_ty_ref(Interner)
.clone();
}
@@ -1260,8 +1150,7 @@ mod resolve {
self.var_stack.pop();
result
} else {
- let default = self.table.fallback_value(var, kind).cast(Interner);
- (self.fallback)(var, VariableKind::Ty(kind), default, outer_binder)
+ (self.fallback)(var, VariableKind::Ty(kind), outer_binder)
.assert_ty_ref(Interner)
.clone()
}
@@ -1277,8 +1166,7 @@ mod resolve {
let var = InferenceVar::from(vid.as_u32());
if self.var_stack.contains(&(var, VarKind::Ty(kind))) {
// recursive type
- let default = self.table.fallback_value(var, kind).cast(Interner);
- return (self.fallback)(var, VariableKind::Ty(kind), default, outer_binder)
+ return (self.fallback)(var, VariableKind::Ty(kind), outer_binder)
.assert_ty_ref(Interner)
.clone();
}
@@ -1290,8 +1178,7 @@ mod resolve {
self.var_stack.pop();
result
} else {
- let default = self.table.fallback_value(var, kind).cast(Interner);
- (self.fallback)(var, VariableKind::Ty(kind), default, outer_binder)
+ (self.fallback)(var, VariableKind::Ty(kind), outer_binder)
.assert_ty_ref(Interner)
.clone()
}
@@ -1310,15 +1197,9 @@ mod resolve {
.infer_ctxt
.root_const_var(rustc_type_ir::ConstVid::from_u32(var.index()));
let var = InferenceVar::from(vid.as_u32());
- let default = ConstData {
- ty: ty.clone(),
- value: ConstValue::Concrete(ConcreteConst { interned: ConstScalar::Unknown }),
- }
- .intern(Interner)
- .cast(Interner);
if self.var_stack.contains(&(var, VarKind::Const)) {
// recursive
- return (self.fallback)(var, VariableKind::Const(ty), default, outer_binder)
+ return (self.fallback)(var, VariableKind::Const(ty), outer_binder)
.assert_const_ref(Interner)
.clone();
}
@@ -1330,7 +1211,7 @@ mod resolve {
self.var_stack.pop();
result
} else {
- (self.fallback)(var, VariableKind::Const(ty), default, outer_binder)
+ (self.fallback)(var, VariableKind::Const(ty), outer_binder)
.assert_const_ref(Interner)
.clone()
}
@@ -1349,3 +1230,124 @@ mod resolve {
}
}
}
+
+mod resolve_completely {
+ use rustc_type_ir::{
+ DebruijnIndex, Flags, TypeFolder, TypeSuperFoldable,
+ inherent::{Const as _, Ty as _},
+ };
+
+ use crate::next_solver::Region;
+ use crate::{
+ infer::unify::InferenceTable,
+ next_solver::{
+ Const, DbInterner, ErrorGuaranteed, Goal, Predicate, Term, Ty,
+ infer::traits::ObligationCause,
+ normalize::deeply_normalize_with_skipped_universes_and_ambiguous_coroutine_goals,
+ },
+ };
+
+ pub(super) struct Resolver<'a, 'db> {
+ ctx: &'a mut InferenceTable<'db>,
+ /// Whether we should normalize, disabled when resolving predicates.
+ should_normalize: bool,
+ nested_goals: &'a mut Vec<Goal<'db, Predicate<'db>>>,
+ }
+
+ impl<'a, 'db> Resolver<'a, 'db> {
+ pub(super) fn new(
+ ctx: &'a mut InferenceTable<'db>,
+ should_normalize: bool,
+ nested_goals: &'a mut Vec<Goal<'db, Predicate<'db>>>,
+ ) -> Resolver<'a, 'db> {
+ Resolver { ctx, nested_goals, should_normalize }
+ }
+
+ fn handle_term<T>(
+ &mut self,
+ value: T,
+ outer_exclusive_binder: impl FnOnce(T) -> DebruijnIndex,
+ ) -> T
+ where
+ T: Into<Term<'db>> + TypeSuperFoldable<DbInterner<'db>> + Copy,
+ {
+ let value = if self.should_normalize {
+ let cause = ObligationCause::new();
+ let at = self.ctx.infer_ctxt.at(&cause, self.ctx.trait_env.env);
+ let universes = vec![None; outer_exclusive_binder(value).as_usize()];
+ match deeply_normalize_with_skipped_universes_and_ambiguous_coroutine_goals(
+ at, value, universes,
+ ) {
+ Ok((value, goals)) => {
+ self.nested_goals.extend(goals);
+ value
+ }
+ Err(_errors) => {
+ // FIXME: Report the error.
+ value
+ }
+ }
+ } else {
+ value
+ };
+
+ value.fold_with(&mut ReplaceInferWithError { interner: self.ctx.interner })
+ }
+ }
+
+ impl<'cx, 'db> TypeFolder<DbInterner<'db>> for Resolver<'cx, 'db> {
+ fn cx(&self) -> DbInterner<'db> {
+ self.ctx.interner
+ }
+
+ fn fold_region(&mut self, r: Region<'db>) -> Region<'db> {
+ if r.is_var() { Region::error(self.ctx.interner) } else { r }
+ }
+
+ fn fold_ty(&mut self, ty: Ty<'db>) -> Ty<'db> {
+ self.handle_term(ty, |it| it.outer_exclusive_binder())
+ }
+
+ fn fold_const(&mut self, ct: Const<'db>) -> Const<'db> {
+ self.handle_term(ct, |it| it.outer_exclusive_binder())
+ }
+
+ fn fold_predicate(&mut self, predicate: Predicate<'db>) -> Predicate<'db> {
+ assert!(
+ !self.should_normalize,
+ "normalizing predicates in writeback is not generally sound"
+ );
+ predicate.super_fold_with(self)
+ }
+ }
+
+ struct ReplaceInferWithError<'db> {
+ interner: DbInterner<'db>,
+ }
+
+ impl<'db> TypeFolder<DbInterner<'db>> for ReplaceInferWithError<'db> {
+ fn cx(&self) -> DbInterner<'db> {
+ self.interner
+ }
+
+ fn fold_ty(&mut self, t: Ty<'db>) -> Ty<'db> {
+ if t.is_infer() {
+ Ty::new_error(self.interner, ErrorGuaranteed)
+ } else {
+ t.super_fold_with(self)
+ }
+ }
+
+ fn fold_const(&mut self, c: Const<'db>) -> Const<'db> {
+ if c.is_ct_infer() {
+ Const::new_error(self.interner, ErrorGuaranteed)
+ } else {
+ c.super_fold_with(self)
+ }
+ }
+
+ fn fold_region(&mut self, r: Region<'db>) -> Region<'db> {
+ if r.is_var() { Region::error(self.interner) } else { r }
+ }
+ }
+}