Unnamed repository; edit this file 'description' to name the repository.
Diffstat (limited to 'crates/hir-ty/src/infer/unify.rs')
| -rw-r--r-- | crates/hir-ty/src/infer/unify.rs | 849 |
1 files changed, 446 insertions, 403 deletions
diff --git a/crates/hir-ty/src/infer/unify.rs b/crates/hir-ty/src/infer/unify.rs index bb4782bd41..ec4b7ee85d 100644 --- a/crates/hir-ty/src/infer/unify.rs +++ b/crates/hir-ty/src/infer/unify.rs @@ -3,16 +3,21 @@ use std::{fmt, mem}; use chalk_ir::{ - CanonicalVarKind, FloatTy, IntTy, TyVariableKind, UniverseIndex, cast::Cast, - fold::TypeFoldable, interner::HasInterner, zip::Zip, + CanonicalVarKind, FloatTy, IntTy, TyVariableKind, cast::Cast, fold::TypeFoldable, + interner::HasInterner, }; -use chalk_solve::infer::ParameterEnaVariableExt; use either::Either; -use ena::unify::UnifyKey; use hir_def::{AdtId, lang_item::LangItem}; use hir_expand::name::Name; use intern::sym; -use rustc_hash::FxHashMap; +use rustc_hash::{FxHashMap, FxHashSet}; +use rustc_next_trait_solver::solve::HasChanged; +use rustc_type_ir::{ + AliasRelationDirection, FloatVid, IntVid, TyVid, + inherent::{Span, Term as _}, + relate::{Relate, solver_relating::RelateExt}, + solve::{Certainty, NoSolution}, +}; use smallvec::SmallVec; use triomphe::Arc; @@ -24,14 +29,25 @@ use crate::{ TraitRef, Ty, TyBuilder, TyExt, TyKind, VariableKind, WhereClause, consteval::unknown_const, db::HirDatabase, - fold_generic_args, fold_tys_and_consts, to_chalk_trait_id, - traits::{FnTrait, NextTraitSolveResult}, + fold_generic_args, fold_tys_and_consts, + next_solver::{ + self, Binder, DbInterner, ParamEnvAnd, Predicate, PredicateKind, SolverDefIds, Term, + infer::{ + DbInternerInferExt, InferCtxt, canonical::canonicalizer::OriginalQueryValues, + snapshot::CombinedSnapshot, + }, + mapping::{ChalkToNextSolver, InferenceVarExt, NextSolverToChalk}, + }, + to_chalk_trait_id, + traits::{ + FnTrait, NextTraitSolveResult, next_trait_solve_canonical_in_ctxt, next_trait_solve_in_ctxt, + }, }; -impl InferenceContext<'_> { - pub(super) fn canonicalize<T>(&mut self, t: T) -> Canonical<T> +impl<'db> InferenceContext<'db> { + pub(super) fn canonicalize<T>(&mut self, t: T) -> rustc_type_ir::Canonical<DbInterner<'db>, T> where - T: TypeFoldable<Interner> + HasInterner<Interner = Interner>, + T: rustc_type_ir::TypeFoldable<DbInterner<'db>>, { self.table.canonicalize(t) } @@ -42,11 +58,11 @@ impl InferenceContext<'_> { ) -> SmallVec<[WhereClause; 4]> { self.table.resolve_obligations_as_possible(); - let root = self.table.var_unification_table.inference_var_root(self_ty); + let root = InferenceVar::from_vid(self.table.infer_ctxt.root_var(self_ty.to_vid())); let pending_obligations = mem::take(&mut self.table.pending_obligations); let obligations = pending_obligations .iter() - .filter_map(|obligation| match obligation.value.value.goal.data(Interner) { + .filter_map(|obligation| match obligation.to_chalk(self.table.interner).goal.data(Interner) { GoalData::DomainGoal(DomainGoal::Holds(clause)) => { let ty = match clause { WhereClause::AliasEq(AliasEq { @@ -59,18 +75,9 @@ impl InferenceContext<'_> { WhereClause::TypeOutlives(to) => to.ty.clone(), _ => return None, }; - - let uncanonical = - chalk_ir::Substitute::apply(&obligation.free_vars, ty, Interner); - if matches!( - self.resolve_ty_shallow(&uncanonical).kind(Interner), - TyKind::InferenceVar(iv, TyVariableKind::General) if *iv == root, - ) { - Some(chalk_ir::Substitute::apply( - &obligation.free_vars, - clause.clone(), - Interner, - )) + let ty = self.resolve_ty_shallow(&ty); + if matches!(ty.kind(Interner), TyKind::InferenceVar(iv, TyVariableKind::General) if *iv == root) { + Some(clause.clone()) } else { None } @@ -84,51 +91,6 @@ impl InferenceContext<'_> { } } -#[derive(Debug, Clone)] -pub(crate) struct Canonicalized<T> -where - T: HasInterner<Interner = Interner>, -{ - pub(crate) value: Canonical<T>, - free_vars: Vec<GenericArg>, -} - -impl<T: HasInterner<Interner = Interner>> Canonicalized<T> { - pub(crate) fn apply_solution( - &self, - ctx: &mut InferenceTable<'_>, - solution: Canonical<Substitution>, - ) { - // the solution may contain new variables, which we need to convert to new inference vars - let new_vars = Substitution::from_iter( - Interner, - solution.binders.iter(Interner).map(|k| match &k.kind { - VariableKind::Ty(TyVariableKind::General) => ctx.new_type_var().cast(Interner), - VariableKind::Ty(TyVariableKind::Integer) => ctx.new_integer_var().cast(Interner), - VariableKind::Ty(TyVariableKind::Float) => ctx.new_float_var().cast(Interner), - // Chalk can sometimes return new lifetime variables. We just replace them by errors - // for now. - VariableKind::Lifetime => ctx.new_lifetime_var().cast(Interner), - VariableKind::Const(ty) => ctx.new_const_var(ty.clone()).cast(Interner), - }), - ); - for (i, v) in solution.value.iter(Interner).enumerate() { - let var = &self.free_vars[i]; - if let Some(ty) = v.ty(Interner) { - // eagerly replace projections in the type; we may be getting types - // e.g. from where clauses where this hasn't happened yet - let ty = ctx.normalize_associated_types_in(new_vars.apply(ty.clone(), Interner)); - tracing::debug!("unifying {:?} {:?}", var, ty); - ctx.unify(var.assert_ty_ref(Interner), &ty); - } else { - let v = new_vars.apply(v.clone(), Interner); - tracing::debug!("try_unifying {:?} {:?}", var, v); - let _ = ctx.try_unify(var, &v); - } - } - } -} - /// Check if types unify. /// /// Note that we consider placeholder types to unify with everything. @@ -224,37 +186,36 @@ bitflags::bitflags! { } } -type ChalkInferenceTable = chalk_solve::infer::InferenceTable<Interner>; - #[derive(Clone)] pub(crate) struct InferenceTable<'a> { pub(crate) db: &'a dyn HirDatabase, + pub(crate) interner: DbInterner<'a>, pub(crate) trait_env: Arc<TraitEnvironment>, pub(crate) tait_coercion_table: Option<FxHashMap<OpaqueTyId, Ty>>, - var_unification_table: ChalkInferenceTable, - type_variable_table: SmallVec<[TypeVariableFlags; 16]>, - pending_obligations: Vec<Canonicalized<InEnvironment<Goal>>>, - /// Double buffer used in [`Self::resolve_obligations_as_possible`] to cut down on - /// temporary allocations. - resolve_obligations_buffer: Vec<Canonicalized<InEnvironment<Goal>>>, + pub(crate) infer_ctxt: InferCtxt<'a>, + diverging_tys: FxHashSet<Ty>, + pending_obligations: Vec<next_solver::Goal<'a, next_solver::Predicate<'a>>>, } -pub(crate) struct InferenceTableSnapshot { - var_table_snapshot: chalk_solve::infer::InferenceSnapshot<Interner>, - type_variable_table: SmallVec<[TypeVariableFlags; 16]>, - pending_obligations: Vec<Canonicalized<InEnvironment<Goal>>>, +pub(crate) struct InferenceTableSnapshot<'a> { + ctxt_snapshot: CombinedSnapshot, + diverging_tys: FxHashSet<Ty>, + pending_obligations: Vec<next_solver::Goal<'a, next_solver::Predicate<'a>>>, } impl<'a> InferenceTable<'a> { pub(crate) fn new(db: &'a dyn HirDatabase, trait_env: Arc<TraitEnvironment>) -> Self { + let interner = DbInterner::new_with(db, Some(trait_env.krate), trait_env.block); InferenceTable { db, + interner, trait_env, tait_coercion_table: None, - var_unification_table: ChalkInferenceTable::new(), - type_variable_table: SmallVec::new(), + infer_ctxt: interner.infer_ctxt().build(rustc_type_ir::TypingMode::Analysis { + defining_opaque_types_and_generators: SolverDefIds::new_from_iter(interner, []), + }), + diverging_tys: FxHashSet::default(), pending_obligations: Vec::new(), - resolve_obligations_buffer: Vec::new(), } } @@ -265,29 +226,58 @@ impl<'a> InferenceTable<'a> { /// marked as diverging if necessary, so that resolving them gives the right /// result. pub(super) fn propagate_diverging_flag(&mut self) { - for i in 0..self.type_variable_table.len() { - if !self.type_variable_table[i].contains(TypeVariableFlags::DIVERGING) { - continue; + let mut new_tys = FxHashSet::default(); + for ty in self.diverging_tys.iter() { + match ty.kind(Interner) { + TyKind::InferenceVar(var, kind) => match kind { + TyVariableKind::General => { + let root = InferenceVar::from( + self.infer_ctxt.root_var(TyVid::from_u32(var.index())).as_u32(), + ); + if root.index() != var.index() { + new_tys.insert(TyKind::InferenceVar(root, *kind).intern(Interner)); + } + } + TyVariableKind::Integer => { + let root = InferenceVar::from( + self.infer_ctxt + .inner + .borrow_mut() + .int_unification_table() + .find(IntVid::from_usize(var.index() as usize)) + .as_u32(), + ); + if root.index() != var.index() { + new_tys.insert(TyKind::InferenceVar(root, *kind).intern(Interner)); + } + } + TyVariableKind::Float => { + let root = InferenceVar::from( + self.infer_ctxt + .inner + .borrow_mut() + .float_unification_table() + .find(FloatVid::from_usize(var.index() as usize)) + .as_u32(), + ); + if root.index() != var.index() { + new_tys.insert(TyKind::InferenceVar(root, *kind).intern(Interner)); + } + } + }, + _ => {} } - let v = InferenceVar::from(i as u32); - let root = self.var_unification_table.inference_var_root(v); - self.modify_type_variable_flag(root, |f| { - *f |= TypeVariableFlags::DIVERGING; - }); } + self.diverging_tys.extend(new_tys); } - pub(super) fn set_diverging(&mut self, iv: InferenceVar, diverging: bool) { - self.modify_type_variable_flag(iv, |f| { - f.set(TypeVariableFlags::DIVERGING, diverging); - }); + pub(super) fn set_diverging(&mut self, iv: InferenceVar, kind: TyVariableKind) { + self.diverging_tys.insert(TyKind::InferenceVar(iv, kind).intern(Interner)); } fn fallback_value(&self, iv: InferenceVar, kind: TyVariableKind) -> Ty { - let is_diverging = self - .type_variable_table - .get(iv.index() as usize) - .is_some_and(|data| data.contains(TypeVariableFlags::DIVERGING)); + let is_diverging = + self.diverging_tys.contains(&TyKind::InferenceVar(iv, kind).intern(Interner)); if is_diverging { return TyKind::Never.intern(Interner); } @@ -299,30 +289,14 @@ impl<'a> InferenceTable<'a> { .intern(Interner) } - pub(crate) fn canonicalize_with_free_vars<T>(&mut self, t: T) -> Canonicalized<T> - where - T: TypeFoldable<Interner> + HasInterner<Interner = Interner>, - { - // try to resolve obligations before canonicalizing, since this might - // result in new knowledge about variables - self.resolve_obligations_as_possible(); - let result = self.var_unification_table.canonicalize(Interner, t); - let free_vars = result - .free_vars - .into_iter() - .map(|free_var| free_var.to_generic_arg(Interner)) - .collect(); - Canonicalized { value: result.quantified, free_vars } - } - - pub(crate) fn canonicalize<T>(&mut self, t: T) -> Canonical<T> + pub(crate) fn canonicalize<T>(&mut self, t: T) -> rustc_type_ir::Canonical<DbInterner<'a>, T> where - T: TypeFoldable<Interner> + HasInterner<Interner = Interner>, + T: rustc_type_ir::TypeFoldable<DbInterner<'a>>, { // try to resolve obligations before canonicalizing, since this might // result in new knowledge about variables self.resolve_obligations_as_possible(); - self.var_unification_table.canonicalize(Interner, t).quantified + self.infer_ctxt.canonicalize_response(t) } /// Recurses through the given type, normalizing associated types mentioned @@ -348,6 +322,7 @@ impl<'a> InferenceTable<'a> { self.resolve_ty_shallow(&ty) } TyKind::AssociatedType(id, subst) => { + // return Either::Left(self.resolve_ty_shallow(&ty)); if ty.data(Interner).flags.intersects( chalk_ir::TypeFlags::HAS_TY_INFER | chalk_ir::TypeFlags::HAS_CT_INFER, @@ -370,49 +345,45 @@ impl<'a> InferenceTable<'a> { )), ); let in_env = InEnvironment::new(&self.trait_env.env, goal); + let goal = in_env.to_nextsolver(self.interner); + let goal = + ParamEnvAnd { param_env: goal.param_env, value: goal.predicate }; - let canonicalized = { + let (canonical_goal, orig_values) = { + let mut orig_values = OriginalQueryValues::default(); let result = - self.var_unification_table.canonicalize(Interner, in_env); - let free_vars = result - .free_vars - .into_iter() - .map(|free_var| free_var.to_generic_arg(Interner)) - .collect(); - Canonicalized { value: result.quantified, free_vars } + self.infer_ctxt.canonicalize_query(goal, &mut orig_values); + (result.canonical, orig_values) + }; + let canonical_goal = rustc_type_ir::Canonical { + max_universe: canonical_goal.max_universe, + variables: canonical_goal.variables, + value: crate::next_solver::Goal { + param_env: canonical_goal.value.param_env, + predicate: canonical_goal.value.value, + }, }; - let solution = self.db.trait_solve( - self.trait_env.krate, - self.trait_env.block, - canonicalized.value.clone(), + let solution = next_trait_solve_canonical_in_ctxt( + &self.infer_ctxt, + canonical_goal, ); if let NextTraitSolveResult::Certain(canonical_subst) = solution { - // This is not great :) But let's just assert this for now and come back to it later. - if canonical_subst.value.subst.len(Interner) != 1 { + let subst = self.instantiate_canonical(canonical_subst).subst; + if subst.len(Interner) != orig_values.var_values.len() { ty } else { - let normalized = canonical_subst.value.subst.as_slice(Interner) - [0] - .assert_ty_ref(Interner); - match normalized.kind(Interner) { - TyKind::Alias(AliasTy::Projection(proj_ty)) => { - if id == &proj_ty.associated_ty_id - && subst == &proj_ty.substitution - { - ty - } else { - normalized.clone() - } - } - TyKind::AssociatedType(new_id, new_subst) => { - if new_id == id && new_subst == subst { - ty + let target_ty = var.to_nextsolver(self.interner); + subst + .iter(Interner) + .zip(orig_values.var_values.iter()) + .find_map(|(new, orig)| { + if orig.ty() == Some(target_ty) { + Some(new.assert_ty_ref(Interner).clone()) } else { - normalized.clone() + None } - } - _ => normalized.clone(), - } + }) + .unwrap_or(ty) } } else { ty @@ -507,43 +478,32 @@ impl<'a> InferenceTable<'a> { pub(crate) fn normalize_projection_ty(&mut self, proj_ty: ProjectionTy) -> Ty { let var = self.new_type_var(); let alias_eq = AliasEq { alias: AliasTy::Projection(proj_ty), ty: var.clone() }; - let obligation = alias_eq.cast(Interner); - self.register_obligation(obligation); + let obligation: Goal = alias_eq.cast(Interner); + self.register_obligation(obligation.to_nextsolver(self.interner)); var } - fn modify_type_variable_flag<F>(&mut self, var: InferenceVar, cb: F) - where - F: FnOnce(&mut TypeVariableFlags), - { - let idx = var.index() as usize; - if self.type_variable_table.len() <= idx { - self.extend_type_variable_table(idx); - } - if let Some(f) = self.type_variable_table.get_mut(idx) { - cb(f); - } - } - fn extend_type_variable_table(&mut self, to_index: usize) { - let count = to_index - self.type_variable_table.len() + 1; - self.type_variable_table.extend(std::iter::repeat_n(TypeVariableFlags::default(), count)); - } - fn new_var(&mut self, kind: TyVariableKind, diverging: bool) -> Ty { - let var = self.var_unification_table.new_variable(UniverseIndex::ROOT); - // Chalk might have created some type variables for its own purposes that we don't know about... - self.extend_type_variable_table(var.index() as usize); - assert_eq!(var.index() as usize, self.type_variable_table.len() - 1); - let flags = self.type_variable_table.get_mut(var.index() as usize).unwrap(); + let var = match kind { + TyVariableKind::General => { + let var = self.infer_ctxt.next_ty_vid(); + InferenceVar::from(var.as_u32()) + } + TyVariableKind::Integer => { + let var = self.infer_ctxt.next_int_vid(); + InferenceVar::from(var.as_u32()) + } + TyVariableKind::Float => { + let var = self.infer_ctxt.next_float_vid(); + InferenceVar::from(var.as_u32()) + } + }; + + let ty = var.to_ty(Interner, kind); if diverging { - *flags |= TypeVariableFlags::DIVERGING; - } - if matches!(kind, TyVariableKind::Integer) { - *flags |= TypeVariableFlags::INTEGER; - } else if matches!(kind, TyVariableKind::Float) { - *flags |= TypeVariableFlags::FLOAT; + self.diverging_tys.insert(ty.clone()); } - var.to_ty_with_kind(Interner, kind) + ty } pub(crate) fn new_type_var(&mut self) -> Ty { @@ -563,12 +523,14 @@ impl<'a> InferenceTable<'a> { } pub(crate) fn new_const_var(&mut self, ty: Ty) -> Const { - let var = self.var_unification_table.new_variable(UniverseIndex::ROOT); + let var = self.infer_ctxt.next_const_vid(); + let var = InferenceVar::from(var.as_u32()); var.to_const(Interner, ty) } pub(crate) fn new_lifetime_var(&mut self) -> Lifetime { - let var = self.var_unification_table.new_variable(UniverseIndex::ROOT); + let var = self.infer_ctxt.next_region_vid(); + let var = InferenceVar::from(var.as_u32()); var.to_lifetime(Interner) } @@ -580,16 +542,18 @@ impl<'a> InferenceTable<'a> { where T: HasInterner<Interner = Interner> + TypeFoldable<Interner>, { - self.resolve_with_fallback_inner(&mut Vec::new(), t, &fallback) + self.resolve_with_fallback_inner(t, &fallback) } pub(crate) fn fresh_subst(&mut self, binders: &[CanonicalVarKind<Interner>]) -> Substitution { Substitution::from_iter( Interner, - binders.iter().map(|kind| { - let param_infer_var = - kind.map_ref(|&ui| self.var_unification_table.new_variable(ui)); - param_infer_var.to_generic_arg(Interner) + binders.iter().map(|kind| match &kind.kind { + chalk_ir::VariableKind::Ty(ty_variable_kind) => { + self.new_var(*ty_variable_kind, false).cast(Interner) + } + chalk_ir::VariableKind::Lifetime => self.new_lifetime_var().cast(Interner), + chalk_ir::VariableKind::Const(ty) => self.new_const_var(ty.clone()).cast(Interner), }), ) } @@ -602,15 +566,25 @@ impl<'a> InferenceTable<'a> { subst.apply(canonical.value, Interner) } + pub(crate) fn instantiate_canonical_ns<T>( + &mut self, + canonical: rustc_type_ir::Canonical<DbInterner<'a>, T>, + ) -> T + where + T: rustc_type_ir::TypeFoldable<DbInterner<'a>>, + { + self.infer_ctxt.instantiate_canonical(&canonical).0 + } + fn resolve_with_fallback_inner<T>( &mut self, - var_stack: &mut Vec<InferenceVar>, t: T, fallback: &dyn Fn(InferenceVar, VariableKind, GenericArg, DebruijnIndex) -> GenericArg, ) -> T where T: HasInterner<Interner = Interner> + TypeFoldable<Interner>, { + let var_stack = &mut vec![]; t.fold_with( &mut resolve::Resolver { table: self, var_stack, fallback }, DebruijnIndex::INNERMOST, @@ -623,6 +597,7 @@ impl<'a> InferenceTable<'a> { { let t = self.resolve_with_fallback(t, &|_, _, d, _| d); let t = self.normalize_associated_types_in(t); + // let t = self.resolve_opaque_tys_in(t); // Resolve again, because maybe normalization inserted infer vars. self.resolve_with_fallback(t, &|_, _, d, _| d) } @@ -639,29 +614,26 @@ impl<'a> InferenceTable<'a> { let int_fallback = TyKind::Scalar(Scalar::Int(IntTy::I32)).intern(Interner); let float_fallback = TyKind::Scalar(Scalar::Float(FloatTy::F64)).intern(Interner); - let scalar_vars: Vec<_> = self - .type_variable_table - .iter() - .enumerate() - .filter_map(|(index, flags)| { - let kind = if flags.contains(TypeVariableFlags::INTEGER) { - TyVariableKind::Integer - } else if flags.contains(TypeVariableFlags::FLOAT) { - TyVariableKind::Float - } else { - return None; + let int_vars = self.infer_ctxt.inner.borrow_mut().int_unification_table().len(); + for v in 0..int_vars { + let var = InferenceVar::from(v as u32).to_ty(Interner, TyVariableKind::Integer); + let maybe_resolved = self.resolve_ty_shallow(&var); + if let TyKind::InferenceVar(_, kind) = maybe_resolved.kind(Interner) { + // I don't think we can ever unify these vars with float vars, but keep this here for now + let fallback = match kind { + TyVariableKind::Integer => &int_fallback, + TyVariableKind::Float => &float_fallback, + TyVariableKind::General => unreachable!(), }; - - // FIXME: This is not really the nicest way to get `InferenceVar`s. Can we get them - // without directly constructing them from `index`? - let var = InferenceVar::from(index as u32).to_ty(Interner, kind); - Some(var) - }) - .collect(); - - for var in scalar_vars { + self.unify(&var, fallback); + } + } + let float_vars = self.infer_ctxt.inner.borrow_mut().float_unification_table().len(); + for v in 0..float_vars { + let var = InferenceVar::from(v as u32).to_ty(Interner, TyVariableKind::Float); let maybe_resolved = self.resolve_ty_shallow(&var); if let TyKind::InferenceVar(_, kind) = maybe_resolved.kind(Interner) { + // I don't think we can ever unify these vars with float vars, but keep this here for now let fallback = match kind { TyVariableKind::Integer => &int_fallback, TyVariableKind::Float => &float_fallback, @@ -673,7 +645,11 @@ impl<'a> InferenceTable<'a> { } /// Unify two relatable values (e.g. `Ty`) and register new trait goals that arise from that. - pub(crate) fn unify<T: ?Sized + Zip<Interner>>(&mut self, ty1: &T, ty2: &T) -> bool { + pub(crate) fn unify<T: ChalkToNextSolver<'a, U>, U: Relate<DbInterner<'a>>>( + &mut self, + ty1: &T, + ty2: &T, + ) -> bool { let result = match self.try_unify(ty1, ty2) { Ok(r) => r, Err(_) => return false, @@ -683,58 +659,116 @@ impl<'a> InferenceTable<'a> { } /// Unify two relatable values (e.g. `Ty`) and check whether trait goals which arise from that could be fulfilled - pub(crate) fn unify_deeply<T: ?Sized + Zip<Interner>>(&mut self, ty1: &T, ty2: &T) -> bool { + pub(crate) fn unify_deeply<T: ChalkToNextSolver<'a, U>, U: Relate<DbInterner<'a>>>( + &mut self, + ty1: &T, + ty2: &T, + ) -> bool { let result = match self.try_unify(ty1, ty2) { Ok(r) => r, Err(_) => return false, }; - result.goals.iter().all(|goal| { - let canonicalized = self.canonicalize_with_free_vars(goal.clone()); - self.try_resolve_obligation(&canonicalized).certain() + result.goals.into_iter().all(|goal| { + matches!(next_trait_solve_in_ctxt(&self.infer_ctxt, goal), Ok((_, Certainty::Yes))) }) } /// Unify two relatable values (e.g. `Ty`) and return new trait goals arising from it, so the /// caller needs to deal with them. - pub(crate) fn try_unify<T: ?Sized + Zip<Interner>>( + pub(crate) fn try_unify<T: ChalkToNextSolver<'a, U>, U: Relate<DbInterner<'a>>>( &mut self, t1: &T, t2: &T, - ) -> InferResult<()> { - match self.var_unification_table.relate( - Interner, - &self.db, - &self.trait_env.env, - chalk_ir::Variance::Invariant, - t1, - t2, - ) { - Ok(result) => Ok(InferOk { goals: result.goals, value: () }), - Err(chalk_ir::NoSolution) => Err(TypeError), + ) -> InferResult<'a, ()> { + let param_env = self.trait_env.env.to_nextsolver(self.interner); + let lhs = t1.to_nextsolver(self.interner); + let rhs = t2.to_nextsolver(self.interner); + let variance = rustc_type_ir::Variance::Invariant; + let span = crate::next_solver::Span::dummy(); + match self.infer_ctxt.relate(param_env, lhs, variance, rhs, span) { + Ok(goals) => Ok(InferOk { goals, value: () }), + Err(_) => Err(TypeError), } } /// If `ty` is a type variable with known type, returns that type; /// otherwise, return ty. + #[tracing::instrument(skip(self))] pub(crate) fn resolve_ty_shallow(&mut self, ty: &Ty) -> Ty { if !ty.data(Interner).flags.intersects(chalk_ir::TypeFlags::HAS_FREE_LOCAL_NAMES) { return ty.clone(); } + self.infer_ctxt + .resolve_vars_if_possible(ty.to_nextsolver(self.interner)) + .to_chalk(self.interner) + } + + pub(crate) fn resolve_vars_with_obligations<T>(&mut self, t: T) -> T + where + T: rustc_type_ir::TypeFoldable<DbInterner<'a>>, + { + use rustc_type_ir::TypeVisitableExt; + + if !t.has_non_region_infer() { + return t; + } + + let t = self.infer_ctxt.resolve_vars_if_possible(t); + + if !t.has_non_region_infer() { + return t; + } + + self.resolve_obligations_as_possible(); + self.infer_ctxt.resolve_vars_if_possible(t) + } + + pub(crate) fn structurally_resolve_type(&mut self, ty: &Ty) -> Ty { + if let TyKind::Alias(..) = ty.kind(Interner) { + self.structurally_normalize_ty(ty) + } else { + self.resolve_vars_with_obligations(ty.to_nextsolver(self.interner)) + .to_chalk(self.interner) + } + } + + fn structurally_normalize_ty(&mut self, ty: &Ty) -> Ty { + self.structurally_normalize_term(ty.to_nextsolver(self.interner).into()) + .expect_ty() + .to_chalk(self.interner) + } + + fn structurally_normalize_term(&mut self, term: Term<'a>) -> Term<'a> { + if term.to_alias_term().is_none() { + return term; + } + + let new_infer = self.infer_ctxt.next_term_var_of_kind(term); + + self.register_obligation(Predicate::new( + self.interner, + Binder::dummy(PredicateKind::AliasRelate( + term, + new_infer, + AliasRelationDirection::Equate, + )), + )); self.resolve_obligations_as_possible(); - self.var_unification_table.normalize_ty_shallow(Interner, ty).unwrap_or_else(|| ty.clone()) + let res = self.infer_ctxt.resolve_vars_if_possible(new_infer); + if res == new_infer { term } else { res } } - pub(crate) fn snapshot(&mut self) -> InferenceTableSnapshot { - let var_table_snapshot = self.var_unification_table.snapshot(); - let type_variable_table = self.type_variable_table.clone(); + pub(crate) fn snapshot(&mut self) -> InferenceTableSnapshot<'a> { + let ctxt_snapshot = self.infer_ctxt.start_snapshot(); + let diverging_tys = self.diverging_tys.clone(); let pending_obligations = self.pending_obligations.clone(); - InferenceTableSnapshot { var_table_snapshot, pending_obligations, type_variable_table } + InferenceTableSnapshot { ctxt_snapshot, pending_obligations, diverging_tys } } #[tracing::instrument(skip_all)] - pub(crate) fn rollback_to(&mut self, snapshot: InferenceTableSnapshot) { - self.var_unification_table.rollback_to(snapshot.var_table_snapshot); - self.type_variable_table = snapshot.type_variable_table; + pub(crate) fn rollback_to(&mut self, snapshot: InferenceTableSnapshot<'a>) { + self.infer_ctxt.rollback_to(snapshot.ctxt_snapshot); + self.diverging_tys = snapshot.diverging_tys; self.pending_obligations = snapshot.pending_obligations; } @@ -746,94 +780,95 @@ impl<'a> InferenceTable<'a> { result } + pub(crate) fn commit_if_ok<T, E>( + &mut self, + f: impl FnOnce(&mut InferenceTable<'_>) -> Result<T, E>, + ) -> Result<T, E> { + let snapshot = self.snapshot(); + let result = f(self); + match result { + Ok(_) => {} + Err(_) => { + self.rollback_to(snapshot); + } + } + result + } + /// Checks an obligation without registering it. Useful mostly to check /// whether a trait *might* be implemented before deciding to 'lock in' the /// choice (during e.g. method resolution or deref). #[tracing::instrument(level = "debug", skip(self))] pub(crate) fn try_obligation(&mut self, goal: Goal) -> NextTraitSolveResult { let in_env = InEnvironment::new(&self.trait_env.env, goal); - let canonicalized = self.canonicalize(in_env); + let canonicalized = self.canonicalize(in_env.to_nextsolver(self.interner)); - self.db.trait_solve(self.trait_env.krate, self.trait_env.block, canonicalized) + next_trait_solve_canonical_in_ctxt(&self.infer_ctxt, canonicalized) } - pub(crate) fn register_obligation(&mut self, goal: Goal) { - let in_env = InEnvironment::new(&self.trait_env.env, goal); - self.register_obligation_in_env(in_env) + #[tracing::instrument(level = "debug", skip(self))] + pub(crate) fn solve_obligation(&mut self, goal: Goal) -> Result<Certainty, NoSolution> { + let goal = InEnvironment::new(&self.trait_env.env, goal); + let goal = goal.to_nextsolver(self.interner); + let result = next_trait_solve_in_ctxt(&self.infer_ctxt, goal); + result.map(|m| m.1) + } + + pub(crate) fn register_obligation(&mut self, predicate: Predicate<'a>) { + let goal = next_solver::Goal { + param_env: self.trait_env.env.to_nextsolver(self.interner), + predicate, + }; + self.register_obligation_in_env(goal) } #[tracing::instrument(level = "debug", skip(self))] - fn register_obligation_in_env(&mut self, goal: InEnvironment<Goal>) { - match goal.goal.data(Interner) { - chalk_ir::GoalData::DomainGoal(chalk_ir::DomainGoal::Holds( - chalk_ir::WhereClause::AliasEq(chalk_ir::AliasEq { alias, ty }), - )) => { - if ty.inference_var(Interner).is_some() { - match alias { - chalk_ir::AliasTy::Opaque(opaque) => { - if self.unify( - &chalk_ir::TyKind::OpaqueType( - opaque.opaque_ty_id, - opaque.substitution.clone(), - ) - .intern(Interner), - ty, - ) { - return; - } - } - _ => {} - } - } + fn register_obligation_in_env( + &mut self, + goal: next_solver::Goal<'a, next_solver::Predicate<'a>>, + ) { + let result = next_trait_solve_in_ctxt(&self.infer_ctxt, goal); + tracing::debug!(?result); + match result { + Ok((_, Certainty::Yes)) => {} + Err(rustc_type_ir::solve::NoSolution) => {} + Ok((_, Certainty::Maybe(_))) => { + self.pending_obligations.push(goal); } - _ => {} - } - let canonicalized = { - let result = self.var_unification_table.canonicalize(Interner, goal); - let free_vars = result - .free_vars - .into_iter() - .map(|free_var| free_var.to_generic_arg(Interner)) - .collect(); - Canonicalized { value: result.quantified, free_vars } - }; - tracing::debug!(?canonicalized); - let solution = self.try_resolve_obligation(&canonicalized); - tracing::debug!(?solution); - if solution.uncertain() { - self.pending_obligations.push(canonicalized); } } - pub(crate) fn register_infer_ok<T>(&mut self, infer_ok: InferOk<T>) { + pub(crate) fn register_infer_ok<T>(&mut self, infer_ok: InferOk<'a, T>) { infer_ok.goals.into_iter().for_each(|goal| self.register_obligation_in_env(goal)); } pub(crate) fn resolve_obligations_as_possible(&mut self) { let _span = tracing::info_span!("resolve_obligations_as_possible").entered(); let mut changed = true; - let mut obligations = mem::take(&mut self.resolve_obligations_buffer); while mem::take(&mut changed) { - mem::swap(&mut self.pending_obligations, &mut obligations); - - for canonicalized in obligations.drain(..) { - tracing::debug!(obligation = ?canonicalized); - if !self.check_changed(&canonicalized) { - tracing::debug!("not changed"); - self.pending_obligations.push(canonicalized); - continue; + let mut obligations = mem::take(&mut self.pending_obligations); + + for goal in obligations.drain(..) { + tracing::debug!(obligation = ?goal); + + let result = next_trait_solve_in_ctxt(&self.infer_ctxt, goal); + let (has_changed, certainty) = match result { + Ok(result) => result, + Err(_) => { + continue; + } + }; + + if matches!(has_changed, HasChanged::Yes) { + changed = true; + } + + match certainty { + Certainty::Yes => {} + Certainty::Maybe(_) => self.pending_obligations.push(goal), } - changed = true; - let uncanonical = chalk_ir::Substitute::apply( - &canonicalized.free_vars, - canonicalized.value.value, - Interner, - ); - self.register_obligation_in_env(uncanonical); } } - self.resolve_obligations_buffer = obligations; - self.resolve_obligations_buffer.clear(); } pub(crate) fn fudge_inference<T: TypeFoldable<Interner>>( @@ -904,59 +939,6 @@ impl<'a> InferenceTable<'a> { .fold_with(&mut VarFudger { table: self, highest_known_var }, DebruijnIndex::INNERMOST) } - /// This checks whether any of the free variables in the `canonicalized` - /// have changed (either been unified with another variable, or with a - /// value). If this is not the case, we don't need to try to solve the goal - /// again -- it'll give the same result as last time. - fn check_changed(&mut self, canonicalized: &Canonicalized<InEnvironment<Goal>>) -> bool { - canonicalized.free_vars.iter().any(|var| { - let iv = match var.data(Interner) { - GenericArgData::Ty(ty) => ty.inference_var(Interner), - GenericArgData::Lifetime(lt) => lt.inference_var(Interner), - GenericArgData::Const(c) => c.inference_var(Interner), - } - .expect("free var is not inference var"); - if self.var_unification_table.probe_var(iv).is_some() { - return true; - } - let root = self.var_unification_table.inference_var_root(iv); - iv != root - }) - } - - #[tracing::instrument(level = "debug", skip(self))] - fn try_resolve_obligation( - &mut self, - canonicalized: &Canonicalized<InEnvironment<Goal>>, - ) -> NextTraitSolveResult { - let solution = self.db.trait_solve( - self.trait_env.krate, - self.trait_env.block, - canonicalized.value.clone(), - ); - - tracing::debug!(?solution, ?canonicalized); - match &solution { - NextTraitSolveResult::Certain(v) => { - canonicalized.apply_solution( - self, - Canonical { - binders: v.binders.clone(), - // FIXME handle constraints - value: v.value.subst.clone(), - }, - ); - } - // ...so, should think about how to get some actually get some guidance here - NextTraitSolveResult::Uncertain(v) => { - canonicalized.apply_solution(self, v.clone()); - } - NextTraitSolveResult::NoSolution => {} - } - - solution - } - pub(crate) fn callable_sig( &mut self, ty: &Ty, @@ -1014,33 +996,15 @@ impl<'a> InferenceTable<'a> { .fill_with_unknown() .build(); - let trait_env = self.trait_env.env.clone(); - let obligation = InEnvironment { - goal: trait_ref.clone().cast(Interner), - environment: trait_env.clone(), - }; - let canonical = self.canonicalize(obligation.clone()); - if !self - .db - .trait_solve(krate, self.trait_env.block, canonical.cast(Interner)) - .no_solution() - { - self.register_obligation(obligation.goal); + let goal: Goal = trait_ref.clone().cast(Interner); + if !self.try_obligation(goal.clone()).no_solution() { + self.register_obligation(goal.to_nextsolver(self.interner)); let return_ty = self.normalize_projection_ty(projection); for &fn_x in subtraits { let fn_x_trait = fn_x.get_id(self.db, krate)?; trait_ref.trait_id = to_chalk_trait_id(fn_x_trait); - let obligation: chalk_ir::InEnvironment<chalk_ir::Goal<Interner>> = - InEnvironment { - goal: trait_ref.clone().cast(Interner), - environment: trait_env.clone(), - }; - let canonical = self.canonicalize(obligation.clone()); - if !self - .db - .trait_solve(krate, self.trait_env.block, canonical.cast(Interner)) - .no_solution() - { + let goal = trait_ref.clone().cast(Interner); + if !self.try_obligation(goal).no_solution() { return Some((fn_x, arg_tys, return_ty)); } } @@ -1074,7 +1038,7 @@ impl<'a> InferenceTable<'a> { match ty.kind(Interner) { TyKind::Error => self.new_type_var(), TyKind::InferenceVar(..) => { - let ty_resolved = self.resolve_ty_shallow(&ty); + let ty_resolved = self.structurally_resolve_type(&ty); if ty_resolved.is_unknown() { self.new_type_var() } else { ty } } _ => ty, @@ -1165,7 +1129,7 @@ impl<'a> InferenceTable<'a> { impl fmt::Debug for InferenceTable<'_> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("InferenceTable").field("num_vars", &self.type_variable_table.len()).finish() + f.debug_struct("InferenceTable").finish() } } @@ -1174,11 +1138,19 @@ mod resolve { use crate::{ ConcreteConst, Const, ConstData, ConstScalar, ConstValue, DebruijnIndex, GenericArg, InferenceVar, Interner, Lifetime, Ty, TyVariableKind, VariableKind, + next_solver::mapping::NextSolverToChalk, }; use chalk_ir::{ cast::Cast, fold::{TypeFoldable, TypeFolder}, }; + use rustc_type_ir::{FloatVid, IntVid, TyVid}; + + #[derive(Copy, Clone, PartialEq, Eq)] + pub(super) enum VarKind { + Ty(TyVariableKind), + Const, + } #[derive(chalk_derive::FallibleTypeFolder)] #[has_interner(Interner)] @@ -1188,7 +1160,7 @@ mod resolve { F: Fn(InferenceVar, VariableKind, GenericArg, DebruijnIndex) -> GenericArg, > { pub(super) table: &'a mut InferenceTable<'b>, - pub(super) var_stack: &'a mut Vec<InferenceVar>, + pub(super) var_stack: &'a mut Vec<(InferenceVar, VarKind)>, pub(super) fallback: F, } impl<F> TypeFolder<Interner> for Resolver<'_, '_, F> @@ -1209,25 +1181,91 @@ mod resolve { kind: TyVariableKind, outer_binder: DebruijnIndex, ) -> Ty { - let var = self.table.var_unification_table.inference_var_root(var); - if self.var_stack.contains(&var) { - // recursive type - let default = self.table.fallback_value(var, kind).cast(Interner); - return (self.fallback)(var, VariableKind::Ty(kind), default, outer_binder) - .assert_ty_ref(Interner) - .clone(); - } - if let Some(known_ty) = self.table.var_unification_table.probe_var(var) { - // known_ty may contain other variables that are known by now - self.var_stack.push(var); - let result = known_ty.fold_with(self, outer_binder); - self.var_stack.pop(); - result.assert_ty_ref(Interner).clone() - } else { - let default = self.table.fallback_value(var, kind).cast(Interner); - (self.fallback)(var, VariableKind::Ty(kind), default, outer_binder) - .assert_ty_ref(Interner) - .clone() + match kind { + TyVariableKind::General => { + let vid = self.table.infer_ctxt.root_var(TyVid::from(var.index())); + let var = InferenceVar::from(vid.as_u32()); + if self.var_stack.contains(&(var, VarKind::Ty(kind))) { + // recursive type + let default = self.table.fallback_value(var, kind).cast(Interner); + return (self.fallback)(var, VariableKind::Ty(kind), default, outer_binder) + .assert_ty_ref(Interner) + .clone(); + } + if let Ok(known_ty) = self.table.infer_ctxt.probe_ty_var(vid) { + let known_ty: Ty = known_ty.to_chalk(self.table.interner); + // known_ty may contain other variables that are known by now + self.var_stack.push((var, VarKind::Ty(kind))); + let result = known_ty.fold_with(self, outer_binder); + self.var_stack.pop(); + result + } else { + let default = self.table.fallback_value(var, kind).cast(Interner); + (self.fallback)(var, VariableKind::Ty(kind), default, outer_binder) + .assert_ty_ref(Interner) + .clone() + } + } + TyVariableKind::Integer => { + let vid = self + .table + .infer_ctxt + .inner + .borrow_mut() + .int_unification_table() + .find(IntVid::from(var.index())); + let var = InferenceVar::from(vid.as_u32()); + if self.var_stack.contains(&(var, VarKind::Ty(kind))) { + // recursive type + let default = self.table.fallback_value(var, kind).cast(Interner); + return (self.fallback)(var, VariableKind::Ty(kind), default, outer_binder) + .assert_ty_ref(Interner) + .clone(); + } + if let Some(known_ty) = self.table.infer_ctxt.resolve_int_var(vid) { + let known_ty: Ty = known_ty.to_chalk(self.table.interner); + // known_ty may contain other variables that are known by now + self.var_stack.push((var, VarKind::Ty(kind))); + let result = known_ty.fold_with(self, outer_binder); + self.var_stack.pop(); + result + } else { + let default = self.table.fallback_value(var, kind).cast(Interner); + (self.fallback)(var, VariableKind::Ty(kind), default, outer_binder) + .assert_ty_ref(Interner) + .clone() + } + } + TyVariableKind::Float => { + let vid = self + .table + .infer_ctxt + .inner + .borrow_mut() + .float_unification_table() + .find(FloatVid::from(var.index())); + let var = InferenceVar::from(vid.as_u32()); + if self.var_stack.contains(&(var, VarKind::Ty(kind))) { + // recursive type + let default = self.table.fallback_value(var, kind).cast(Interner); + return (self.fallback)(var, VariableKind::Ty(kind), default, outer_binder) + .assert_ty_ref(Interner) + .clone(); + } + if let Some(known_ty) = self.table.infer_ctxt.resolve_float_var(vid) { + let known_ty: Ty = known_ty.to_chalk(self.table.interner); + // known_ty may contain other variables that are known by now + self.var_stack.push((var, VarKind::Ty(kind))); + let result = known_ty.fold_with(self, outer_binder); + self.var_stack.pop(); + result + } else { + let default = self.table.fallback_value(var, kind).cast(Interner); + (self.fallback)(var, VariableKind::Ty(kind), default, outer_binder) + .assert_ty_ref(Interner) + .clone() + } + } } } @@ -1237,25 +1275,30 @@ mod resolve { var: InferenceVar, outer_binder: DebruijnIndex, ) -> Const { - let var = self.table.var_unification_table.inference_var_root(var); + let vid = self + .table + .infer_ctxt + .root_const_var(rustc_type_ir::ConstVid::from_u32(var.index())); + let var = InferenceVar::from(vid.as_u32()); let default = ConstData { ty: ty.clone(), value: ConstValue::Concrete(ConcreteConst { interned: ConstScalar::Unknown }), } .intern(Interner) .cast(Interner); - if self.var_stack.contains(&var) { + if self.var_stack.contains(&(var, VarKind::Const)) { // recursive return (self.fallback)(var, VariableKind::Const(ty), default, outer_binder) .assert_const_ref(Interner) .clone(); } - if let Some(known_ty) = self.table.var_unification_table.probe_var(var) { + if let Ok(known_const) = self.table.infer_ctxt.probe_const_var(vid) { + let known_const: Const = known_const.to_chalk(self.table.interner); // known_ty may contain other variables that are known by now - self.var_stack.push(var); - let result = known_ty.fold_with(self, outer_binder); + self.var_stack.push((var, VarKind::Const)); + let result = known_const.fold_with(self, outer_binder); self.var_stack.pop(); - result.assert_const_ref(Interner).clone() + result } else { (self.fallback)(var, VariableKind::Const(ty), default, outer_binder) .assert_const_ref(Interner) |