Unnamed repository; edit this file 'description' to name the repository.
Port closure kind deduction logic from rustc
Shoyu Vanilla 2024-02-26
parent 7637141 · commit e36a31b
-rw-r--r--crates/hir-ty/src/infer/closure.rs96
-rw-r--r--crates/hir-ty/src/infer/unify.rs73
-rw-r--r--crates/hir-ty/src/utils.rs46
3 files changed, 205 insertions, 10 deletions
diff --git a/crates/hir-ty/src/infer/closure.rs b/crates/hir-ty/src/infer/closure.rs
index 000248c2ce..62740f354c 100644
--- a/crates/hir-ty/src/infer/closure.rs
+++ b/crates/hir-ty/src/infer/closure.rs
@@ -5,7 +5,7 @@ use std::{cmp, convert::Infallible, mem};
use chalk_ir::{
cast::Cast,
fold::{FallibleTypeFolder, TypeFoldable},
- AliasEq, AliasTy, BoundVar, DebruijnIndex, FnSubst, Mutability, TyKind, WhereClause,
+ BoundVar, DebruijnIndex, FnSubst, Mutability, TyKind,
};
use either::Either;
use hir_def::{
@@ -22,13 +22,14 @@ use stdx::never;
use crate::{
db::{HirDatabase, InternedClosure},
- from_placeholder_idx, make_binders,
+ from_chalk_trait_id, from_placeholder_idx, make_binders,
mir::{BorrowKind, MirSpan, MutBorrowKind, ProjectionElem},
static_lifetime, to_chalk_trait_id,
traits::FnTrait,
- utils::{self, generics, Generics},
- Adjust, Adjustment, Binders, BindingMode, ChalkTraitId, ClosureId, DynTy, FnAbi, FnPointer,
- FnSig, Interner, Substitution, Ty, TyExt,
+ utils::{self, elaborate_clause_supertraits, generics, Generics},
+ Adjust, Adjustment, AliasEq, AliasTy, Binders, BindingMode, ChalkTraitId, ClosureId, DynTy,
+ DynTyExt, FnAbi, FnPointer, FnSig, Interner, OpaqueTy, ProjectionTyExt, Substitution, Ty,
+ TyExt, WhereClause,
};
use super::{Expectation, InferenceContext};
@@ -47,6 +48,15 @@ impl InferenceContext<'_> {
None => return,
};
+ if let TyKind::Closure(closure_id, _) = closure_ty.kind(Interner) {
+ if let Some(closure_kind) = self.deduce_closure_kind_from_expectations(&expected_ty) {
+ self.result
+ .closure_info
+ .entry(*closure_id)
+ .or_insert_with(|| (Vec::new(), closure_kind));
+ }
+ }
+
// Deduction from where-clauses in scope, as well as fn-pointer coercion are handled here.
let _ = self.coerce(Some(closure_expr), closure_ty, &expected_ty);
@@ -65,6 +75,60 @@ impl InferenceContext<'_> {
}
}
+ // Closure kind deductions are mostly from `rustc_hir_typeck/src/closure.rs`.
+ // Might need to port closure sig deductions too.
+ fn deduce_closure_kind_from_expectations(&mut self, expected_ty: &Ty) -> Option<FnTrait> {
+ match expected_ty.kind(Interner) {
+ TyKind::Alias(AliasTy::Opaque(OpaqueTy { .. })) | TyKind::OpaqueType(..) => {
+ let clauses = expected_ty
+ .impl_trait_bounds(self.db)
+ .into_iter()
+ .flatten()
+ .map(|b| b.into_value_and_skipped_binders().0);
+ self.deduce_closure_kind_from_predicate_clauses(clauses)
+ }
+ TyKind::Dyn(dyn_ty) => dyn_ty.principal().and_then(|trait_ref| {
+ self.fn_trait_kind_from_trait_id(from_chalk_trait_id(trait_ref.trait_id))
+ }),
+ TyKind::InferenceVar(ty, chalk_ir::TyVariableKind::General) => {
+ let clauses = self.clauses_for_self_ty(*ty);
+ self.deduce_closure_kind_from_predicate_clauses(clauses.into_iter())
+ }
+ TyKind::Function(_) => Some(FnTrait::Fn),
+ _ => None,
+ }
+ }
+
+ fn deduce_closure_kind_from_predicate_clauses(
+ &self,
+ clauses: impl DoubleEndedIterator<Item = WhereClause>,
+ ) -> Option<FnTrait> {
+ let mut expected_kind = None;
+
+ for clause in elaborate_clause_supertraits(self.db, clauses.rev()) {
+ let trait_id = match clause {
+ WhereClause::AliasEq(AliasEq {
+ alias: AliasTy::Projection(projection), ..
+ }) => Some(projection.trait_(self.db)),
+ WhereClause::Implemented(trait_ref) => {
+ Some(from_chalk_trait_id(trait_ref.trait_id))
+ }
+ _ => None,
+ };
+ if let Some(closure_kind) =
+ trait_id.and_then(|trait_id| self.fn_trait_kind_from_trait_id(trait_id))
+ {
+ // `FnX`'s variants order is opposite from rustc, so use `cmp::max` instead of `cmp::min`
+ expected_kind = Some(
+ expected_kind
+ .map_or_else(|| closure_kind, |current| cmp::max(current, closure_kind)),
+ );
+ }
+ }
+
+ expected_kind
+ }
+
fn deduce_sig_from_dyn_ty(&self, dyn_ty: &DynTy) -> Option<FnPointer> {
// Search for a predicate like `<$self as FnX<Args>>::Output == Ret`
@@ -111,6 +175,18 @@ impl InferenceContext<'_> {
None
}
+
+ fn fn_trait_kind_from_trait_id(&self, trait_id: hir_def::TraitId) -> Option<FnTrait> {
+ utils::fn_traits(self.db.upcast(), self.owner.module(self.db.upcast()).krate())
+ .enumerate()
+ .find_map(|(i, t)| (t == trait_id).then_some(i))
+ .map(|i| match i {
+ 0 => FnTrait::Fn,
+ 1 => FnTrait::FnMut,
+ 2 => FnTrait::FnOnce,
+ _ => unreachable!(),
+ })
+ }
}
// The below functions handle capture and closure kind (Fn, FnMut, ..)
@@ -962,8 +1038,14 @@ impl InferenceContext<'_> {
}
}
self.restrict_precision_for_unsafe();
- // closure_kind should be done before adjust_for_move_closure
- let closure_kind = self.closure_kind();
+ // `closure_kind` should be done before adjust_for_move_closure
+ // If there exists pre-deduced kind of a closure, use it instead of one determined by capture, as rustc does.
+ // rustc also does diagnostics here if the latter is not a subtype of the former.
+ let closure_kind = self
+ .result
+ .closure_info
+ .get(&closure)
+ .map_or_else(|| self.closure_kind(), |info| info.1);
match capture_by {
CaptureBy::Value => self.adjust_for_move_closure(),
CaptureBy::Ref => (),
diff --git a/crates/hir-ty/src/infer/unify.rs b/crates/hir-ty/src/infer/unify.rs
index 709760b64f..d24e938a54 100644
--- a/crates/hir-ty/src/infer/unify.rs
+++ b/crates/hir-ty/src/infer/unify.rs
@@ -10,15 +10,16 @@ use chalk_solve::infer::ParameterEnaVariableExt;
use either::Either;
use ena::unify::UnifyKey;
use hir_expand::name;
+use smallvec::SmallVec;
use triomphe::Arc;
use super::{InferOk, InferResult, InferenceContext, TypeError};
use crate::{
consteval::unknown_const, db::HirDatabase, fold_tys_and_consts, static_lifetime,
to_chalk_trait_id, traits::FnTrait, AliasEq, AliasTy, BoundVar, Canonical, Const, ConstValue,
- DebruijnIndex, GenericArg, GenericArgData, Goal, Guidance, InEnvironment, InferenceVar,
- Interner, Lifetime, ParamKind, ProjectionTy, ProjectionTyExt, Scalar, Solution, Substitution,
- TraitEnvironment, Ty, TyBuilder, TyExt, TyKind, VariableKind,
+ DebruijnIndex, DomainGoal, GenericArg, GenericArgData, Goal, GoalData, Guidance, InEnvironment,
+ InferenceVar, Interner, Lifetime, ParamKind, ProjectionTy, ProjectionTyExt, Scalar, Solution,
+ Substitution, TraitEnvironment, Ty, TyBuilder, TyExt, TyKind, VariableKind, WhereClause,
};
impl InferenceContext<'_> {
@@ -31,6 +32,72 @@ impl InferenceContext<'_> {
{
self.table.canonicalize(t)
}
+
+ pub(super) fn clauses_for_self_ty(
+ &mut self,
+ self_ty: InferenceVar,
+ ) -> SmallVec<[WhereClause; 4]> {
+ self.table.resolve_obligations_as_possible();
+
+ let root = self.table.var_unification_table.inference_var_root(self_ty);
+ let pending_obligations = mem::take(&mut self.table.pending_obligations);
+ let obligations = pending_obligations
+ .iter()
+ .filter_map(|obligation| match obligation.value.value.goal.data(Interner) {
+ GoalData::DomainGoal(DomainGoal::Holds(
+ clause @ WhereClause::AliasEq(AliasEq {
+ alias: AliasTy::Projection(projection),
+ ..
+ }),
+ )) => {
+ let projection_self = projection.self_type_parameter(self.db);
+ let uncanonical = chalk_ir::Substitute::apply(
+ &obligation.free_vars,
+ projection_self,
+ Interner,
+ );
+ if matches!(
+ self.resolve_ty_shallow(&uncanonical).kind(Interner),
+ TyKind::InferenceVar(iv, TyVariableKind::General) if *iv == root,
+ ) {
+ Some(chalk_ir::Substitute::apply(
+ &obligation.free_vars,
+ clause.clone(),
+ Interner,
+ ))
+ } else {
+ None
+ }
+ }
+ GoalData::DomainGoal(DomainGoal::Holds(
+ clause @ WhereClause::Implemented(trait_ref),
+ )) => {
+ let trait_ref_self = trait_ref.self_type_parameter(Interner);
+ let uncanonical = chalk_ir::Substitute::apply(
+ &obligation.free_vars,
+ trait_ref_self,
+ Interner,
+ );
+ if matches!(
+ self.resolve_ty_shallow(&uncanonical).kind(Interner),
+ TyKind::InferenceVar(iv, TyVariableKind::General) if *iv == root,
+ ) {
+ Some(chalk_ir::Substitute::apply(
+ &obligation.free_vars,
+ clause.clone(),
+ Interner,
+ ))
+ } else {
+ None
+ }
+ }
+ _ => None,
+ })
+ .collect();
+ self.table.pending_obligations = pending_obligations;
+
+ obligations
+ }
}
#[derive(Debug, Clone)]
diff --git a/crates/hir-ty/src/utils.rs b/crates/hir-ty/src/utils.rs
index c150314138..8bd57820d2 100644
--- a/crates/hir-ty/src/utils.rs
+++ b/crates/hir-ty/src/utils.rs
@@ -112,6 +112,52 @@ impl Iterator for SuperTraits<'_> {
}
}
+pub(super) fn elaborate_clause_supertraits(
+ db: &dyn HirDatabase,
+ clauses: impl Iterator<Item = WhereClause>,
+) -> ClauseElaborator<'_> {
+ let mut elaborator = ClauseElaborator { db, stack: Vec::new(), seen: FxHashSet::default() };
+ elaborator.extend_deduped(clauses);
+
+ elaborator
+}
+
+pub(super) struct ClauseElaborator<'a> {
+ db: &'a dyn HirDatabase,
+ stack: Vec<WhereClause>,
+ seen: FxHashSet<WhereClause>,
+}
+
+impl<'a> ClauseElaborator<'a> {
+ fn extend_deduped(&mut self, clauses: impl IntoIterator<Item = WhereClause>) {
+ self.stack.extend(clauses.into_iter().filter(|c| self.seen.insert(c.clone())))
+ }
+
+ fn elaborate_supertrait(&mut self, clause: &WhereClause) {
+ if let WhereClause::Implemented(trait_ref) = clause {
+ direct_super_trait_refs(self.db, trait_ref, |t| {
+ let clause = WhereClause::Implemented(t);
+ if self.seen.insert(clause.clone()) {
+ self.stack.push(clause);
+ }
+ });
+ }
+ }
+}
+
+impl Iterator for ClauseElaborator<'_> {
+ type Item = WhereClause;
+
+ fn next(&mut self) -> Option<Self::Item> {
+ if let Some(next) = self.stack.pop() {
+ self.elaborate_supertrait(&next);
+ Some(next)
+ } else {
+ None
+ }
+ }
+}
+
fn direct_super_traits(db: &dyn DefDatabase, trait_: TraitId, cb: impl FnMut(TraitId)) {
let resolver = trait_.resolver(db);
let generic_params = db.generic_params(trait_.into());