Unnamed repository; edit this file 'description' to name the repository.
Diffstat (limited to 'crates/hir-ty/src/next_solver/fulfill.rs')
| -rw-r--r-- | crates/hir-ty/src/next_solver/fulfill.rs | 151 |
1 files changed, 133 insertions, 18 deletions
diff --git a/crates/hir-ty/src/next_solver/fulfill.rs b/crates/hir-ty/src/next_solver/fulfill.rs index 4258f4c7ac..a8183ab422 100644 --- a/crates/hir-ty/src/next_solver/fulfill.rs +++ b/crates/hir-ty/src/next_solver/fulfill.rs @@ -1,21 +1,28 @@ //! Fulfill loop for next-solver. -use std::marker::PhantomData; -use std::mem; -use std::ops::ControlFlow; -use std::vec::ExtractIf; - -use rustc_next_trait_solver::delegate::SolverDelegate; -use rustc_next_trait_solver::solve::{ - GoalEvaluation, GoalStalledOn, HasChanged, SolverDelegateEvalExt, +mod errors; + +use std::{marker::PhantomData, mem, ops::ControlFlow, vec::ExtractIf}; + +use rustc_hash::FxHashSet; +use rustc_next_trait_solver::{ + delegate::SolverDelegate, + solve::{GoalEvaluation, GoalStalledOn, HasChanged, SolverDelegateEvalExt}, +}; +use rustc_type_ir::{ + Interner, TypeSuperVisitable, TypeVisitable, TypeVisitableExt, TypeVisitor, + inherent::{IntoKind, Span as _}, + solve::{Certainty, NoSolution}, }; -use rustc_type_ir::Interner; -use rustc_type_ir::inherent::Span as _; -use rustc_type_ir::solve::{Certainty, NoSolution}; -use crate::next_solver::infer::InferCtxt; -use crate::next_solver::infer::traits::{PredicateObligation, PredicateObligations}; -use crate::next_solver::{DbInterner, SolverContext, Span, TypingMode}; +use crate::next_solver::{ + DbInterner, SolverContext, SolverDefId, Span, Ty, TyKind, TypingMode, + infer::{ + InferCtxt, + traits::{PredicateObligation, PredicateObligations}, + }, + inspect::ProofTreeVisitor, +}; type PendingObligations<'db> = Vec<(PredicateObligation<'db>, Option<GoalStalledOn<DbInterner<'db>>>)>; @@ -31,6 +38,7 @@ type PendingObligations<'db> = /// /// It is also likely that we want to use slightly different datastructures /// here as this will have to deal with far more root goals than `evaluate_all`. +#[derive(Debug, Clone)] pub struct FulfillmentCtxt<'db> { obligations: ObligationStorage<'db>, @@ -41,7 +49,7 @@ pub struct FulfillmentCtxt<'db> { usable_in_snapshot: usize, } -#[derive(Default, Debug)] +#[derive(Default, Debug, Clone)] struct ObligationStorage<'db> { /// Obligations which resulted in an overflow in fulfillment itself. /// @@ -123,10 +131,21 @@ impl<'db> FulfillmentCtxt<'db> { infcx: &InferCtxt<'db>, obligation: PredicateObligation<'db>, ) { - assert_eq!(self.usable_in_snapshot, infcx.num_open_snapshots()); + // FIXME: See the comment in `select_where_possible()`. + // assert_eq!(self.usable_in_snapshot, infcx.num_open_snapshots()); self.obligations.register(obligation, None); } + pub(crate) fn register_predicate_obligations( + &mut self, + infcx: &InferCtxt<'db>, + obligations: impl IntoIterator<Item = PredicateObligation<'db>>, + ) { + // FIXME: See the comment in `select_where_possible()`. + // assert_eq!(self.usable_in_snapshot, infcx.num_open_snapshots()); + obligations.into_iter().for_each(|obligation| self.obligations.register(obligation, None)); + } + pub(crate) fn collect_remaining_errors( &mut self, infcx: &InferCtxt<'db>, @@ -143,7 +162,11 @@ impl<'db> FulfillmentCtxt<'db> { &mut self, infcx: &InferCtxt<'db>, ) -> Vec<NextSolverError<'db>> { - assert_eq!(self.usable_in_snapshot, infcx.num_open_snapshots()); + // FIXME(next-solver): We should bring this assertion back. Currently it panics because + // there are places which use `InferenceTable` and open a snapshot and register obligations + // and select. They should use a different `ObligationCtxt` instead. Then we'll be also able + // to not put the obligations queue in `InferenceTable`'s snapshots. + // assert_eq!(self.usable_in_snapshot, infcx.num_open_snapshots()); let mut errors = Vec::new(); loop { let mut any_changed = false; @@ -216,9 +239,94 @@ impl<'db> FulfillmentCtxt<'db> { self.obligations.has_pending_obligations() } - fn pending_obligations(&self) -> PredicateObligations<'db> { + pub(crate) fn pending_obligations(&self) -> PredicateObligations<'db> { self.obligations.clone_pending() } + + pub(crate) fn drain_stalled_obligations_for_coroutines( + &mut self, + infcx: &InferCtxt<'db>, + ) -> PredicateObligations<'db> { + let stalled_coroutines = match infcx.typing_mode() { + TypingMode::Analysis { defining_opaque_types_and_generators } => { + defining_opaque_types_and_generators + } + TypingMode::Coherence + | TypingMode::Borrowck { defining_opaque_types: _ } + | TypingMode::PostBorrowckAnalysis { defined_opaque_types: _ } + | TypingMode::PostAnalysis => return Default::default(), + }; + let stalled_coroutines = stalled_coroutines.inner(); + + if stalled_coroutines.is_empty() { + return Default::default(); + } + + self.obligations + .drain_pending(|obl| { + infcx.probe(|_| { + infcx + .visit_proof_tree( + obl.as_goal(), + &mut StalledOnCoroutines { + stalled_coroutines, + cache: Default::default(), + }, + ) + .is_break() + }) + }) + .into_iter() + .map(|(o, _)| o) + .collect() + } +} + +/// Detect if a goal is stalled on a coroutine that is owned by the current typeck root. +/// +/// This function can (erroneously) fail to detect a predicate, i.e. it doesn't need to +/// be complete. However, this will lead to ambiguity errors, so we want to make it +/// accurate. +/// +/// This function can be also return false positives, which will lead to poor diagnostics +/// so we want to keep this visitor *precise* too. +pub struct StalledOnCoroutines<'a, 'db> { + pub stalled_coroutines: &'a [SolverDefId], + pub cache: FxHashSet<Ty<'db>>, +} + +impl<'db> ProofTreeVisitor<'db> for StalledOnCoroutines<'_, 'db> { + type Result = ControlFlow<()>; + + fn visit_goal(&mut self, inspect_goal: &super::inspect::InspectGoal<'_, 'db>) -> Self::Result { + inspect_goal.goal().predicate.visit_with(self)?; + + if let Some(candidate) = inspect_goal.unique_applicable_candidate() { + candidate.visit_nested_no_probe(self) + } else { + ControlFlow::Continue(()) + } + } +} + +impl<'db> TypeVisitor<DbInterner<'db>> for StalledOnCoroutines<'_, 'db> { + type Result = ControlFlow<()>; + + fn visit_ty(&mut self, ty: Ty<'db>) -> Self::Result { + if !self.cache.insert(ty) { + return ControlFlow::Continue(()); + } + + if let TyKind::Coroutine(def_id, _) = ty.kind() + && self.stalled_coroutines.contains(&def_id.into()) + { + ControlFlow::Break(()) + } else if ty.has_coroutines() { + ty.super_visit_with(self) + } else { + ControlFlow::Continue(()) + } + } } #[derive(Debug)] @@ -227,3 +335,10 @@ pub enum NextSolverError<'db> { Ambiguity(PredicateObligation<'db>), Overflow(PredicateObligation<'db>), } + +impl NextSolverError<'_> { + #[inline] + pub fn is_true_error(&self) -> bool { + matches!(self, NextSolverError::TrueError(_)) + } +} |