Unnamed repository; edit this file 'description' to name the repository.
Diffstat (limited to 'crates/hir-ty/src/next_solver/infer/traits.rs')
| -rw-r--r-- | crates/hir-ty/src/next_solver/infer/traits.rs | 108 |
1 files changed, 86 insertions, 22 deletions
diff --git a/crates/hir-ty/src/next_solver/infer/traits.rs b/crates/hir-ty/src/next_solver/infer/traits.rs index f1df806ab3..68aa12d7bb 100644 --- a/crates/hir-ty/src/next_solver/infer/traits.rs +++ b/crates/hir-ty/src/next_solver/infer/traits.rs @@ -7,14 +7,16 @@ use std::{ hash::{Hash, Hasher}, }; +use rustc_type_ir::elaborate::Elaboratable; use rustc_type_ir::{ PredicatePolarity, Upcast, solve::{Certainty, NoSolution}, }; +use rustc_type_ir::{TypeFoldable, TypeVisitable}; use crate::next_solver::{ - Binder, DbInterner, Goal, ParamEnv, PolyTraitPredicate, Predicate, SolverDefId, TraitPredicate, - Ty, + Binder, Clause, DbInterner, Goal, ParamEnv, PolyTraitPredicate, Predicate, SolverDefId, Span, + TraitPredicate, Ty, }; use super::InferCtxt; @@ -29,24 +31,29 @@ use super::InferCtxt; /// only live for a short period of time. #[derive(Clone, Debug, PartialEq, Eq)] pub struct ObligationCause { - /// The ID of the fn body that triggered this obligation. This is - /// used for region obligations to determine the precise - /// environment in which the region obligation should be evaluated - /// (in particular, closures can add new assumptions). See the - /// field `region_obligations` of the `FulfillmentContext` for more - /// information. - pub body_id: Option<SolverDefId>, + // FIXME: This should contain an `ExprId`/`PatId` etc., and a cause code. But for now we + // don't report trait solving diagnostics, so this is irrelevant. + _private: (), } impl ObligationCause { + #[expect( + clippy::new_without_default, + reason = "`new` is temporary, eventually we will provide span etc. here" + )] #[inline] - pub fn new(body_id: SolverDefId) -> ObligationCause { - ObligationCause { body_id: Some(body_id) } + pub fn new() -> ObligationCause { + ObligationCause { _private: () } } - #[inline(always)] + #[inline] pub fn dummy() -> ObligationCause { - ObligationCause { body_id: None } + ObligationCause::new() + } + + #[inline] + pub fn misc() -> ObligationCause { + ObligationCause::new() } } @@ -75,6 +82,72 @@ pub struct Obligation<'db, T> { pub recursion_depth: usize, } +/// For [`Obligation`], a sub-obligation is combined with the current obligation's +/// param-env and cause code. +impl<'db> Elaboratable<DbInterner<'db>> for PredicateObligation<'db> { + fn predicate(&self) -> Predicate<'db> { + self.predicate + } + + fn child(&self, clause: Clause<'db>) -> Self { + Obligation { + cause: self.cause.clone(), + param_env: self.param_env, + recursion_depth: 0, + predicate: clause.as_predicate(), + } + } + + fn child_with_derived_cause( + &self, + clause: Clause<'db>, + span: Span, + parent_trait_pred: PolyTraitPredicate<'db>, + index: usize, + ) -> Self { + let cause = ObligationCause::new(); + Obligation { + cause, + param_env: self.param_env, + recursion_depth: 0, + predicate: clause.as_predicate(), + } + } +} + +impl<'db, T: TypeVisitable<DbInterner<'db>>> TypeVisitable<DbInterner<'db>> for Obligation<'db, T> { + fn visit_with<V: rustc_type_ir::TypeVisitor<DbInterner<'db>>>( + &self, + visitor: &mut V, + ) -> V::Result { + rustc_ast_ir::try_visit!(self.param_env.visit_with(visitor)); + self.predicate.visit_with(visitor) + } +} + +impl<'db, T: TypeFoldable<DbInterner<'db>>> TypeFoldable<DbInterner<'db>> for Obligation<'db, T> { + fn try_fold_with<F: rustc_type_ir::FallibleTypeFolder<DbInterner<'db>>>( + self, + folder: &mut F, + ) -> Result<Self, F::Error> { + Ok(Obligation { + cause: self.cause.clone(), + param_env: self.param_env.try_fold_with(folder)?, + predicate: self.predicate.try_fold_with(folder)?, + recursion_depth: self.recursion_depth, + }) + } + + fn fold_with<F: rustc_type_ir::TypeFolder<DbInterner<'db>>>(self, folder: &mut F) -> Self { + Obligation { + cause: self.cause.clone(), + param_env: self.param_env.fold_with(folder), + predicate: self.predicate.fold_with(folder), + recursion_depth: self.recursion_depth, + } + } +} + impl<'db, T: Copy> Obligation<'db, T> { pub fn as_goal(&self) -> Goal<'db, T> { Goal { param_env: self.param_env, predicate: self.predicate } @@ -156,15 +229,6 @@ impl<'db, O> Obligation<'db, O> { Obligation { cause, param_env, recursion_depth, predicate } } - pub fn misc( - tcx: DbInterner<'db>, - body_id: SolverDefId, - param_env: ParamEnv<'db>, - trait_ref: impl Upcast<DbInterner<'db>, O>, - ) -> Obligation<'db, O> { - Obligation::new(tcx, ObligationCause::new(body_id), param_env, trait_ref) - } - pub fn with<P>( &self, tcx: DbInterner<'db>, |