Unnamed repository; edit this file 'description' to name the repository.
Diffstat (limited to 'crates/hir-ty/src/infer/closure.rs')
| -rw-r--r-- | crates/hir-ty/src/infer/closure.rs | 2353 |
1 files changed, 682 insertions, 1671 deletions
diff --git a/crates/hir-ty/src/infer/closure.rs b/crates/hir-ty/src/infer/closure.rs index c3029bf2b5..06f8307eb0 100644 --- a/crates/hir-ty/src/infer/closure.rs +++ b/crates/hir-ty/src/infer/closure.rs @@ -1,137 +1,214 @@ //! Inference of closure parameter types based on the closure's expected type. -use std::{cmp, convert::Infallible, mem, ops::ControlFlow}; +pub(crate) mod analysis; + +use std::{iter, mem, ops::ControlFlow}; -use chalk_ir::{ - BoundVar, DebruijnIndex, FnSubst, Mutability, TyKind, - cast::Cast, - fold::{FallibleTypeFolder, Shift, TypeFoldable}, - visit::{TypeSuperVisitable, TypeVisitable, TypeVisitor}, -}; -use either::Either; use hir_def::{ - DefWithBodyId, FieldId, HasModule, TupleFieldId, TupleId, VariantId, - expr_store::path::Path, - hir::{ - Array, AsmOperand, BinaryOp, BindingId, CaptureBy, ClosureKind, Expr, ExprId, ExprOrPatId, - Pat, PatId, Statement, UnaryOp, - }, - item_tree::FieldsShape, + TraitId, + hir::{ClosureKind, ExprId, PatId}, lang_item::LangItem, - resolver::ValueNs, + type_ref::TypeRefId, +}; +use rustc_type_ir::{ + ClosureArgs, ClosureArgsParts, CoroutineArgs, CoroutineArgsParts, CoroutineClosureArgs, + CoroutineClosureArgsParts, Interner, TypeSuperVisitable, TypeVisitable, TypeVisitableExt, + TypeVisitor, + inherent::{BoundExistentialPredicates, GenericArgs as _, IntoKind, SliceLike, Ty as _}, }; -use hir_def::{Lookup, type_ref::TypeRefId}; -use hir_expand::name::Name; -use intern::sym; -use rustc_hash::{FxHashMap, FxHashSet}; -use smallvec::{SmallVec, smallvec}; -use stdx::{format_to, never}; -use syntax::utils::is_raw_identifier; +use tracing::debug; use crate::{ - Adjust, Adjustment, AliasEq, AliasTy, Binders, BindingMode, ChalkTraitId, ClosureId, DynTy, - DynTyExt, FnAbi, FnPointer, FnSig, GenericArg, Interner, OpaqueTy, ProjectionTy, - ProjectionTyExt, Substitution, Ty, TyBuilder, TyExt, WhereClause, - db::{HirDatabase, InternedClosure, InternedCoroutine}, - error_lifetime, from_assoc_type_id, from_chalk_trait_id, from_placeholder_idx, - generics::Generics, - infer::{BreakableKind, CoerceMany, Diverges, coerce::CoerceNever}, - make_binders, - mir::{BorrowKind, MirSpan, MutBorrowKind, ProjectionElem}, - to_assoc_type_id, to_chalk_trait_id, + FnAbi, + db::{InternedClosure, InternedCoroutine}, + infer::{BreakableKind, Diverges, coerce::CoerceMany}, + next_solver::{ + AliasTy, Binder, BoundRegionKind, BoundVarKind, BoundVarKinds, ClauseKind, DbInterner, + ErrorGuaranteed, FnSig, GenericArgs, PolyFnSig, PolyProjectionPredicate, Predicate, + PredicateKind, SolverDefId, Ty, TyKind, + abi::Safety, + infer::{ + BoundRegionConversionTime, InferOk, InferResult, + traits::{ObligationCause, PredicateObligations}, + }, + util::explicit_item_bounds, + }, traits::FnTrait, - utils::{self, elaborate_clause_supertraits}, }; use super::{Expectation, InferenceContext}; #[derive(Debug)] -pub(super) struct ClosureSignature { - pub(super) ret_ty: Ty, - pub(super) expected_sig: FnPointer, +struct ClosureSignatures<'db> { + /// The signature users of the closure see. + bound_sig: PolyFnSig<'db>, + /// The signature within the function body. + /// This mostly differs in the sense that lifetimes are now early bound and any + /// opaque types from the signature expectation are overridden in case there are + /// explicit hidden types written by the user in the closure signature. + liberated_sig: FnSig<'db>, } -impl InferenceContext<'_> { +impl<'db> InferenceContext<'_, 'db> { pub(super) fn infer_closure( &mut self, - body: &ExprId, + body: ExprId, args: &[PatId], - ret_type: &Option<TypeRefId>, + ret_type: Option<TypeRefId>, arg_types: &[Option<TypeRefId>], closure_kind: ClosureKind, tgt_expr: ExprId, - expected: &Expectation, - ) -> Ty { + expected: &Expectation<'db>, + ) -> Ty<'db> { assert_eq!(args.len(), arg_types.len()); + let interner = self.interner(); let (expected_sig, expected_kind) = match expected.to_option(&mut self.table) { - Some(expected_ty) => self.deduce_closure_signature(&expected_ty, closure_kind), + Some(expected_ty) => self.deduce_closure_signature(expected_ty, closure_kind), None => (None, None), }; - let ClosureSignature { expected_sig: bound_sig, ret_ty: body_ret_ty } = - self.sig_of_closure(body, ret_type, arg_types, closure_kind, expected_sig); - let bound_sig = self.normalize_associated_types_in(bound_sig); - let sig_ty = TyKind::Function(bound_sig.clone()).intern(Interner); + let ClosureSignatures { bound_sig, liberated_sig } = + self.sig_of_closure(arg_types, ret_type, expected_sig); + let body_ret_ty = bound_sig.output().skip_binder(); + let sig_ty = Ty::new_fn_ptr(interner, bound_sig); + let parent_args = GenericArgs::identity_for_item(interner, self.generic_def.into()); + // FIXME: Make this an infer var and infer it later. + let tupled_upvars_ty = self.types.unit; let (id, ty, resume_yield_tys) = match closure_kind { ClosureKind::Coroutine(_) => { - let sig_tys = bound_sig.substitution.0.as_slice(Interner); - // FIXME: report error when there are more than 1 parameter. - let resume_ty = match sig_tys.first() { - // When `sig_tys.len() == 1` the first type is the return type, not the - // first parameter type. - Some(ty) if sig_tys.len() > 1 => ty.assert_ty_ref(Interner).clone(), - _ => self.result.standard_types.unit.clone(), + let yield_ty = self.table.next_ty_var(); + let resume_ty = liberated_sig.inputs().get(0).unwrap_or(self.types.unit); + + // FIXME: Infer the upvars later. + let parts = CoroutineArgsParts { + parent_args, + kind_ty: self.types.unit, + resume_ty, + yield_ty, + return_ty: body_ret_ty, + tupled_upvars_ty, }; - let yield_ty = self.table.new_type_var(); - - let subst = TyBuilder::subst_for_coroutine(self.db, self.owner) - .push(resume_ty.clone()) - .push(yield_ty.clone()) - .push(body_ret_ty.clone()) - .build(); let coroutine_id = self.db.intern_coroutine(InternedCoroutine(self.owner, tgt_expr)).into(); - let coroutine_ty = TyKind::Coroutine(coroutine_id, subst).intern(Interner); + let coroutine_ty = Ty::new_coroutine( + interner, + coroutine_id, + CoroutineArgs::new(interner, parts).args, + ); (None, coroutine_ty, Some((resume_ty, yield_ty))) } - ClosureKind::Closure | ClosureKind::Async => { - let closure_id = - self.db.intern_closure(InternedClosure(self.owner, tgt_expr)).into(); - let closure_ty = TyKind::Closure( - closure_id, - TyBuilder::subst_for_closure(self.db, self.owner, sig_ty.clone()), - ) - .intern(Interner); + ClosureKind::Closure => { + let closure_id = self.db.intern_closure(InternedClosure(self.owner, tgt_expr)); + match expected_kind { + Some(kind) => { + self.result.closure_info.insert( + closure_id, + ( + Vec::new(), + match kind { + rustc_type_ir::ClosureKind::Fn => FnTrait::Fn, + rustc_type_ir::ClosureKind::FnMut => FnTrait::FnMut, + rustc_type_ir::ClosureKind::FnOnce => FnTrait::FnOnce, + }, + ), + ); + } + None => {} + }; + // FIXME: Infer the kind later if needed. + let parts = ClosureArgsParts { + parent_args, + closure_kind_ty: Ty::from_closure_kind( + interner, + expected_kind.unwrap_or(rustc_type_ir::ClosureKind::Fn), + ), + closure_sig_as_fn_ptr_ty: sig_ty, + tupled_upvars_ty, + }; + let closure_ty = Ty::new_closure( + interner, + closure_id.into(), + ClosureArgs::new(interner, parts).args, + ); self.deferred_closures.entry(closure_id).or_default(); self.add_current_closure_dependency(closure_id); (Some(closure_id), closure_ty, None) } - }; + ClosureKind::Async => { + // async closures always return the type ascribed after the `->` (if present), + // and yield `()`. + let bound_return_ty = bound_sig.skip_binder().output(); + let bound_yield_ty = self.types.unit; + // rustc uses a special lang item type for the resume ty. I don't believe this can cause us problems. + let resume_ty = self.types.unit; + + // FIXME: Infer the kind later if needed. + let closure_kind_ty = Ty::from_closure_kind( + interner, + expected_kind.unwrap_or(rustc_type_ir::ClosureKind::Fn), + ); - // Eagerly try to relate the closure type with the expected - // type, otherwise we often won't have enough information to - // infer the body. - self.deduce_closure_type_from_expectations(tgt_expr, &ty, &sig_ty, expected, expected_kind); + // FIXME: Infer captures later. + // `for<'env> fn() -> ()`, for no captures. + let coroutine_captures_by_ref_ty = Ty::new_fn_ptr( + interner, + Binder::bind_with_vars( + interner.mk_fn_sig([], self.types.unit, false, Safety::Safe, FnAbi::Rust), + BoundVarKinds::new_from_iter( + interner, + [BoundVarKind::Region(BoundRegionKind::ClosureEnv)], + ), + ), + ); + let closure_args = CoroutineClosureArgs::new( + interner, + CoroutineClosureArgsParts { + parent_args, + closure_kind_ty, + signature_parts_ty: Ty::new_fn_ptr( + interner, + bound_sig.map_bound(|sig| { + interner.mk_fn_sig( + [ + resume_ty, + Ty::new_tup_from_iter(interner, sig.inputs().iter()), + ], + Ty::new_tup(interner, &[bound_yield_ty, bound_return_ty]), + sig.c_variadic, + sig.safety, + sig.abi, + ) + }), + ), + tupled_upvars_ty, + coroutine_captures_by_ref_ty, + }, + ); + + let coroutine_id = + self.db.intern_coroutine(InternedCoroutine(self.owner, tgt_expr)).into(); + (None, Ty::new_coroutine_closure(interner, coroutine_id, closure_args.args), None) + } + }; // Now go through the argument patterns - for (arg_pat, arg_ty) in args.iter().zip(bound_sig.substitution.0.as_slice(Interner).iter()) - { - self.infer_top_pat(*arg_pat, arg_ty.assert_ty_ref(Interner), None); + for (arg_pat, arg_ty) in args.iter().zip(bound_sig.skip_binder().inputs()) { + self.infer_top_pat(*arg_pat, arg_ty, None); } // FIXME: lift these out into a struct let prev_diverges = mem::replace(&mut self.diverges, Diverges::Maybe); let prev_closure = mem::replace(&mut self.current_closure, id); - let prev_ret_ty = mem::replace(&mut self.return_ty, body_ret_ty.clone()); + let prev_ret_ty = mem::replace(&mut self.return_ty, body_ret_ty); let prev_ret_coercion = self.return_coercion.replace(CoerceMany::new(body_ret_ty)); let prev_resume_yield_tys = mem::replace(&mut self.resume_yield_tys, resume_yield_tys); self.with_breakable_ctx(BreakableKind::Border, None, None, |this| { - this.infer_return(*body); + this.infer_return(body); }); self.diverges = prev_diverges; @@ -140,1707 +217,641 @@ impl InferenceContext<'_> { self.current_closure = prev_closure; self.resume_yield_tys = prev_resume_yield_tys; - self.table.normalize_associated_types_in(ty) + ty } - // This function handles both closures and coroutines. - pub(super) fn deduce_closure_type_from_expectations( - &mut self, - closure_expr: ExprId, - closure_ty: &Ty, - sig_ty: &Ty, - expectation: &Expectation, - expected_kind: Option<FnTrait>, - ) { - let expected_ty = match expectation.to_option(&mut self.table) { - Some(ty) => ty, - None => return, - }; - - match (closure_ty.kind(Interner), expected_kind) { - (TyKind::Closure(closure_id, _), Some(closure_kind)) => { - self.result - .closure_info - .entry(*closure_id) - .or_insert_with(|| (Vec::new(), closure_kind)); - } - _ => {} - } - - // Deduction from where-clauses in scope, as well as fn-pointer coercion are handled here. - let _ = self.coerce(Some(closure_expr), closure_ty, &expected_ty, CoerceNever::Yes); - - // Coroutines are not Fn* so return early. - if matches!(closure_ty.kind(Interner), TyKind::Coroutine(..)) { - return; + fn fn_trait_kind_from_def_id(&self, trait_id: TraitId) -> Option<rustc_type_ir::ClosureKind> { + let lang_item = self.db.lang_attr(trait_id.into())?; + match lang_item { + LangItem::Fn => Some(rustc_type_ir::ClosureKind::Fn), + LangItem::FnMut => Some(rustc_type_ir::ClosureKind::FnMut), + LangItem::FnOnce => Some(rustc_type_ir::ClosureKind::FnOnce), + _ => None, } + } - // Deduction based on the expected `dyn Fn` is done separately. - if let TyKind::Dyn(dyn_ty) = expected_ty.kind(Interner) { - if let Some(sig) = self.deduce_sig_from_dyn_ty(dyn_ty) { - let expected_sig_ty = TyKind::Function(sig).intern(Interner); - - self.unify(sig_ty, &expected_sig_ty); - } + fn async_fn_trait_kind_from_def_id( + &self, + trait_id: TraitId, + ) -> Option<rustc_type_ir::ClosureKind> { + let lang_item = self.db.lang_attr(trait_id.into())?; + match lang_item { + LangItem::AsyncFn => Some(rustc_type_ir::ClosureKind::Fn), + LangItem::AsyncFnMut => Some(rustc_type_ir::ClosureKind::FnMut), + LangItem::AsyncFnOnce => Some(rustc_type_ir::ClosureKind::FnOnce), + _ => None, } } - // Closure kind deductions are mostly from `rustc_hir_typeck/src/closure.rs`. - // Might need to port closure sig deductions too. - pub(super) fn deduce_closure_signature( + /// Given the expected type, figures out what it can about this closure we + /// are about to type check: + fn deduce_closure_signature( &mut self, - expected_ty: &Ty, + expected_ty: Ty<'db>, closure_kind: ClosureKind, - ) -> (Option<FnSubst<Interner>>, Option<FnTrait>) { - match expected_ty.kind(Interner) { - TyKind::Alias(AliasTy::Opaque(OpaqueTy { .. })) | TyKind::OpaqueType(..) => { - let clauses = expected_ty.impl_trait_bounds(self.db).into_iter().flatten().map( - |b: chalk_ir::Binders<chalk_ir::WhereClause<Interner>>| { - b.into_value_and_skipped_binders().0 - }, - ); - self.deduce_closure_kind_from_predicate_clauses(expected_ty, clauses, closure_kind) - } - TyKind::Dyn(dyn_ty) => { - let sig = - dyn_ty.bounds.skip_binders().as_slice(Interner).iter().find_map(|bound| { - if let WhereClause::AliasEq(AliasEq { - alias: AliasTy::Projection(projection_ty), - ty: projected_ty, - }) = bound.skip_binders() - { - if let Some(sig) = self.deduce_sig_from_projection( - closure_kind, - projection_ty, - projected_ty, - ) { - return Some(sig); - } - } - None - }); - - let kind = dyn_ty.principal().and_then(|principal_trait_ref| { - self.fn_trait_kind_from_trait_id(from_chalk_trait_id( - principal_trait_ref.skip_binders().skip_binders().trait_id, - )) + ) -> (Option<PolyFnSig<'db>>, Option<rustc_type_ir::ClosureKind>) { + match expected_ty.kind() { + TyKind::Alias(rustc_type_ir::Opaque, AliasTy { def_id, args, .. }) => self + .deduce_closure_signature_from_predicates( + expected_ty, + closure_kind, + explicit_item_bounds(self.interner(), def_id) + .iter_instantiated(self.interner(), args) + .map(|clause| clause.as_predicate()), + ), + TyKind::Dynamic(object_type, ..) => { + let sig = object_type.projection_bounds().into_iter().find_map(|pb| { + let pb = pb.with_self_ty(self.interner(), Ty::new_unit(self.interner())); + self.deduce_sig_from_projection(closure_kind, pb) }); - + let kind = object_type + .principal_def_id() + .and_then(|did| self.fn_trait_kind_from_def_id(did.0)); (sig, kind) } - TyKind::InferenceVar(ty, chalk_ir::TyVariableKind::General) => { - let clauses = self.clauses_for_self_ty(*ty); - self.deduce_closure_kind_from_predicate_clauses( - expected_ty, - clauses.into_iter(), + TyKind::Infer(rustc_type_ir::TyVar(vid)) => self + .deduce_closure_signature_from_predicates( + Ty::new_var(self.interner(), self.table.infer_ctxt.root_var(vid)), closure_kind, - ) - } - TyKind::Function(fn_ptr) => match closure_kind { - ClosureKind::Closure => (Some(fn_ptr.substitution.clone()), Some(FnTrait::Fn)), - ClosureKind::Async | ClosureKind::Coroutine(_) => (None, None), + self.table.obligations_for_self_ty(vid).into_iter().map(|obl| obl.predicate), + ), + TyKind::FnPtr(sig_tys, hdr) => match closure_kind { + ClosureKind::Closure => { + let expected_sig = sig_tys.with(hdr); + (Some(expected_sig), Some(rustc_type_ir::ClosureKind::Fn)) + } + ClosureKind::Coroutine(_) | ClosureKind::Async => (None, None), }, _ => (None, None), } } - fn deduce_closure_kind_from_predicate_clauses( + fn deduce_closure_signature_from_predicates( &mut self, - expected_ty: &Ty, - clauses: impl DoubleEndedIterator<Item = WhereClause>, + expected_ty: Ty<'db>, closure_kind: ClosureKind, - ) -> (Option<FnSubst<Interner>>, Option<FnTrait>) { + predicates: impl DoubleEndedIterator<Item = Predicate<'db>>, + ) -> (Option<PolyFnSig<'db>>, Option<rustc_type_ir::ClosureKind>) { let mut expected_sig = None; let mut expected_kind = None; - for clause in elaborate_clause_supertraits(self.db, clauses.rev()) { - if expected_sig.is_none() { - if let WhereClause::AliasEq(AliasEq { - alias: AliasTy::Projection(projection), - ty, - }) = &clause - { - let inferred_sig = - self.deduce_sig_from_projection(closure_kind, projection, ty); - // Make sure that we didn't infer a signature that mentions itself. - // This can happen when we elaborate certain supertrait bounds that - // mention projections containing the `Self` type. See rust-lang/rust#105401. - struct MentionsTy<'a> { - expected_ty: &'a Ty, - } - impl TypeVisitor<Interner> for MentionsTy<'_> { - type BreakTy = (); - - fn interner(&self) -> Interner { - Interner - } + for pred in rustc_type_ir::elaborate::elaborate( + self.interner(), + // Reverse the obligations here, since `elaborate_*` uses a stack, + // and we want to keep inference generally in the same order of + // the registered obligations. + predicates.rev(), + ) + // We only care about self bounds + .filter_only_self() + { + debug!(?pred); + let bound_predicate = pred.kind(); + + // Given a Projection predicate, we can potentially infer + // the complete signature. + if expected_sig.is_none() + && let PredicateKind::Clause(ClauseKind::Projection(proj_predicate)) = + bound_predicate.skip_binder() + { + let inferred_sig = self.deduce_sig_from_projection( + closure_kind, + bound_predicate.rebind(proj_predicate), + ); - fn as_dyn( - &mut self, - ) -> &mut dyn TypeVisitor<Interner, BreakTy = Self::BreakTy> - { - self - } + // Make sure that we didn't infer a signature that mentions itself. + // This can happen when we elaborate certain supertrait bounds that + // mention projections containing the `Self` type. See #105401. + struct MentionsTy<'db> { + expected_ty: Ty<'db>, + } + impl<'db> TypeVisitor<DbInterner<'db>> for MentionsTy<'db> { + type Result = ControlFlow<()>; - fn visit_ty( - &mut self, - t: &Ty, - db: chalk_ir::DebruijnIndex, - ) -> ControlFlow<()> { - if t == self.expected_ty { - ControlFlow::Break(()) - } else { - t.super_visit_with(self, db) - } + fn visit_ty(&mut self, t: Ty<'db>) -> Self::Result { + if t == self.expected_ty { + ControlFlow::Break(()) + } else { + t.super_visit_with(self) } } - if inferred_sig - .visit_with( - &mut MentionsTy { expected_ty }, - chalk_ir::DebruijnIndex::INNERMOST, - ) - .is_continue() - { - expected_sig = inferred_sig; - } } - } - let trait_id = match clause { - WhereClause::AliasEq(AliasEq { - alias: AliasTy::Projection(projection), .. - }) => projection.trait_(self.db), - WhereClause::Implemented(trait_ref) => from_chalk_trait_id(trait_ref.trait_id), - _ => continue, - }; - if let Some(closure_kind) = self.fn_trait_kind_from_trait_id(trait_id) { - // always use the closure kind that is more permissive. - match (expected_kind, closure_kind) { - (None, _) => expected_kind = Some(closure_kind), - (Some(FnTrait::FnMut), FnTrait::Fn) => expected_kind = Some(FnTrait::Fn), - (Some(FnTrait::FnOnce), FnTrait::Fn | FnTrait::FnMut) => { - expected_kind = Some(closure_kind) + // Don't infer a closure signature from a goal that names the closure type as this will + // (almost always) lead to occurs check errors later in type checking. + if let Some(inferred_sig) = inferred_sig { + // In the new solver it is difficult to explicitly normalize the inferred signature as we + // would have to manually handle universes and rewriting bound vars and placeholders back + // and forth. + // + // Instead we take advantage of the fact that we relating an inference variable with an alias + // will only instantiate the variable if the alias is rigid(*not quite). Concretely we: + // - Create some new variable `?sig` + // - Equate `?sig` with the unnormalized signature, e.g. `fn(<Foo<?x> as Trait>::Assoc)` + // - Depending on whether `<Foo<?x> as Trait>::Assoc` is rigid, ambiguous or normalizeable, + // we will either wind up with `?sig=<Foo<?x> as Trait>::Assoc/?y/ConcreteTy` respectively. + // + // *: In cases where there are ambiguous aliases in the signature that make use of bound vars + // they will wind up present in `?sig` even though they are non-rigid. + // + // This is a bit weird and means we may wind up discarding the goal due to it naming `expected_ty` + // even though the normalized form may not name `expected_ty`. However, this matches the existing + // behaviour of the old solver and would be technically a breaking change to fix. + let generalized_fnptr_sig = self.table.next_ty_var(); + let inferred_fnptr_sig = Ty::new_fn_ptr(self.interner(), inferred_sig); + // FIXME: Report diagnostics. + _ = self + .table + .infer_ctxt + .at(&ObligationCause::new(), self.table.trait_env.env) + .eq(inferred_fnptr_sig, generalized_fnptr_sig) + .map(|infer_ok| self.table.register_infer_ok(infer_ok)); + + let resolved_sig = + self.table.infer_ctxt.resolve_vars_if_possible(generalized_fnptr_sig); + + if resolved_sig.visit_with(&mut MentionsTy { expected_ty }).is_continue() { + expected_sig = Some(resolved_sig.fn_sig(self.interner())); } - _ => {} + } else if inferred_sig.visit_with(&mut MentionsTy { expected_ty }).is_continue() { + expected_sig = inferred_sig; } } - } - - (expected_sig, expected_kind) - } - - fn deduce_sig_from_dyn_ty(&self, dyn_ty: &DynTy) -> Option<FnPointer> { - // Search for a predicate like `<$self as FnX<Args>>::Output == Ret` - let fn_traits: SmallVec<[ChalkTraitId; 3]> = - utils::fn_traits(self.db, self.owner.module(self.db).krate()) - .map(to_chalk_trait_id) - .collect(); - - let self_ty = self.result.standard_types.unknown.clone(); - let bounds = dyn_ty.bounds.clone().substitute(Interner, &[self_ty.cast(Interner)]); - for bound in bounds.iter(Interner) { - // NOTE(skip_binders): the extracted types are rebound by the returned `FnPointer` - if let WhereClause::AliasEq(AliasEq { alias: AliasTy::Projection(projection), ty }) = - bound.skip_binders() - { - let assoc_data = - self.db.associated_ty_data(from_assoc_type_id(projection.associated_ty_id)); - if !fn_traits.contains(&assoc_data.trait_id) { - return None; + // Even if we can't infer the full signature, we may be able to + // infer the kind. This can occur when we elaborate a predicate + // like `F : Fn<A>`. Note that due to subtyping we could encounter + // many viable options, so pick the most restrictive. + let trait_def_id = match bound_predicate.skip_binder() { + PredicateKind::Clause(ClauseKind::Projection(data)) => { + Some(data.projection_term.trait_def_id(self.interner()).0) } + PredicateKind::Clause(ClauseKind::Trait(data)) => Some(data.def_id().0), + _ => None, + }; + + if let Some(trait_def_id) = trait_def_id { + let found_kind = match closure_kind { + ClosureKind::Closure => self.fn_trait_kind_from_def_id(trait_def_id), + ClosureKind::Async => self + .async_fn_trait_kind_from_def_id(trait_def_id) + .or_else(|| self.fn_trait_kind_from_def_id(trait_def_id)), + _ => None, + }; - // Skip `Self`, get the type argument. - let arg = projection.substitution.as_slice(Interner).get(1)?; - if let Some(subst) = arg.ty(Interner)?.as_tuple() { - let generic_args = subst.as_slice(Interner); - let mut sig_tys = Vec::with_capacity(generic_args.len() + 1); - for arg in generic_args { - sig_tys.push(arg.ty(Interner)?.clone()); + if let Some(found_kind) = found_kind { + // always use the closure kind that is more permissive. + match (expected_kind, found_kind) { + (None, _) => expected_kind = Some(found_kind), + ( + Some(rustc_type_ir::ClosureKind::FnMut), + rustc_type_ir::ClosureKind::Fn, + ) => expected_kind = Some(rustc_type_ir::ClosureKind::Fn), + ( + Some(rustc_type_ir::ClosureKind::FnOnce), + rustc_type_ir::ClosureKind::Fn | rustc_type_ir::ClosureKind::FnMut, + ) => expected_kind = Some(found_kind), + _ => {} } - sig_tys.push(ty.clone()); - - cov_mark::hit!(dyn_fn_param_informs_call_site_closure_signature); - return Some(FnPointer { - num_binders: bound.len(Interner), - sig: FnSig { - abi: FnAbi::RustCall, - safety: chalk_ir::Safety::Safe, - variadic: false, - }, - substitution: FnSubst(Substitution::from_iter(Interner, sig_tys)), - }); } } } - None + (expected_sig, expected_kind) } + /// Given a projection like "<F as Fn(X)>::Result == Y", we can deduce + /// everything we need to know about a closure or coroutine. + /// + /// The `cause_span` should be the span that caused us to + /// have this expected signature, or `None` if we can't readily + /// know that. fn deduce_sig_from_projection( &mut self, closure_kind: ClosureKind, - projection_ty: &ProjectionTy, - projected_ty: &Ty, - ) -> Option<FnSubst<Interner>> { - let container = - from_assoc_type_id(projection_ty.associated_ty_id).lookup(self.db).container; - let trait_ = match container { - hir_def::ItemContainerId::TraitId(trait_) => trait_, - _ => return None, - }; + projection: PolyProjectionPredicate<'db>, + ) -> Option<PolyFnSig<'db>> { + let SolverDefId::TypeAliasId(def_id) = projection.item_def_id() else { unreachable!() }; + let lang_item = self.db.lang_attr(def_id.into()); // For now, we only do signature deduction based off of the `Fn` and `AsyncFn` traits, // for closures and async closures, respectively. - let fn_trait_kind = self.fn_trait_kind_from_trait_id(trait_)?; - if !matches!(closure_kind, ClosureKind::Closure | ClosureKind::Async) { - return None; - } - if fn_trait_kind.is_async() { - // If the expected trait is `AsyncFn(...) -> X`, we don't know what the return type is, - // but we do know it must implement `Future<Output = X>`. - self.extract_async_fn_sig_from_projection(projection_ty, projected_ty) - } else { - self.extract_sig_from_projection(projection_ty, projected_ty) + match closure_kind { + ClosureKind::Closure if lang_item == Some(LangItem::FnOnceOutput) => { + self.extract_sig_from_projection(projection) + } + ClosureKind::Async if lang_item == Some(LangItem::AsyncFnOnceOutput) => { + self.extract_sig_from_projection(projection) + } + // It's possible we've passed the closure to a (somewhat out-of-fashion) + // `F: FnOnce() -> Fut, Fut: Future<Output = T>` style bound. Let's still + // guide inference here, since it's beneficial for the user. + ClosureKind::Async if lang_item == Some(LangItem::FnOnceOutput) => { + self.extract_sig_from_projection_and_future_bound(projection) + } + _ => None, } } + /// Given an `FnOnce::Output` or `AsyncFn::Output` projection, extract the args + /// and return type to infer a [`ty::PolyFnSig`] for the closure. fn extract_sig_from_projection( &self, - projection_ty: &ProjectionTy, - projected_ty: &Ty, - ) -> Option<FnSubst<Interner>> { - let arg_param_ty = projection_ty.substitution.as_slice(Interner)[1].assert_ty_ref(Interner); + projection: PolyProjectionPredicate<'db>, + ) -> Option<PolyFnSig<'db>> { + let projection = self.table.infer_ctxt.resolve_vars_if_possible(projection); - let TyKind::Tuple(_, input_tys) = arg_param_ty.kind(Interner) else { - return None; - }; - - let ret_param_ty = projected_ty; - - Some(FnSubst(Substitution::from_iter( - Interner, - input_tys.iter(Interner).map(|t| t.cast(Interner)).chain(Some(GenericArg::new( - Interner, - chalk_ir::GenericArgData::Ty(ret_param_ty.clone()), - ))), - ))) - } - - fn extract_async_fn_sig_from_projection( - &mut self, - projection_ty: &ProjectionTy, - projected_ty: &Ty, - ) -> Option<FnSubst<Interner>> { - let arg_param_ty = projection_ty.substitution.as_slice(Interner)[1].assert_ty_ref(Interner); + let arg_param_ty = projection.skip_binder().projection_term.args.type_at(1); + debug!(?arg_param_ty); - let TyKind::Tuple(_, input_tys) = arg_param_ty.kind(Interner) else { + let TyKind::Tuple(input_tys) = arg_param_ty.kind() else { return None; }; - let ret_param_future_output = projected_ty; - let ret_param_future = self.table.new_type_var(); - let future_output = - LangItem::FutureOutput.resolve_type_alias(self.db, self.resolver.krate())?; - let future_projection = crate::AliasTy::Projection(crate::ProjectionTy { - associated_ty_id: to_assoc_type_id(future_output), - substitution: Substitution::from1(Interner, ret_param_future.clone()), - }); - self.table.register_obligation( - crate::AliasEq { alias: future_projection, ty: ret_param_future_output.clone() } - .cast(Interner), - ); - - Some(FnSubst(Substitution::from_iter( - Interner, - input_tys.iter(Interner).map(|t| t.cast(Interner)).chain(Some(GenericArg::new( - Interner, - chalk_ir::GenericArgData::Ty(ret_param_future), - ))), - ))) - } + // Since this is a return parameter type it is safe to unwrap. + let ret_param_ty = projection.skip_binder().term.expect_type(); + debug!(?ret_param_ty); - fn fn_trait_kind_from_trait_id(&self, trait_id: hir_def::TraitId) -> Option<FnTrait> { - FnTrait::from_lang_item(self.db.lang_attr(trait_id.into())?) + let sig = projection.rebind(self.interner().mk_fn_sig( + input_tys, + ret_param_ty, + false, + Safety::Safe, + FnAbi::Rust, + )); + + Some(sig) } - fn supplied_sig_of_closure( + /// When an async closure is passed to a function that has a "two-part" `Fn` + /// and `Future` trait bound, like: + /// + /// ```rust + /// use std::future::Future; + /// + /// fn not_exactly_an_async_closure<F, Fut>(_f: F) + /// where + /// F: FnOnce(String, u32) -> Fut, + /// Fut: Future<Output = i32>, + /// {} + /// ``` + /// + /// The we want to be able to extract the signature to guide inference in the async + /// closure. We will have two projection predicates registered in this case. First, + /// we identify the `FnOnce<Args, Output = ?Fut>` bound, and if the output type is + /// an inference variable `?Fut`, we check if that is bounded by a `Future<Output = Ty>` + /// projection. + /// + /// This function is actually best-effort with the return type; if we don't find a + /// `Future` projection, we still will return arguments that we extracted from the `FnOnce` + /// projection, and the output will be an unconstrained type variable instead. + fn extract_sig_from_projection_and_future_bound( &mut self, - body: &ExprId, - ret_type: &Option<TypeRefId>, - arg_types: &[Option<TypeRefId>], - closure_kind: ClosureKind, - ) -> ClosureSignature { - let mut sig_tys = Vec::with_capacity(arg_types.len() + 1); - - // collect explicitly written argument types - for arg_type in arg_types.iter() { - let arg_ty = match arg_type { - // FIXME: I think rustc actually lowers closure params with `LifetimeElisionKind::AnonymousCreateParameter` - // (but the return type with infer). - Some(type_ref) => self.make_body_ty(*type_ref), - None => self.table.new_type_var(), - }; - sig_tys.push(arg_ty); - } + projection: PolyProjectionPredicate<'db>, + ) -> Option<PolyFnSig<'db>> { + let projection = self.table.infer_ctxt.resolve_vars_if_possible(projection); - // add return type - let ret_ty = match ret_type { - Some(type_ref) => self.make_body_ty(*type_ref), - None => self.table.new_type_var(), - }; - if let ClosureKind::Async = closure_kind { - sig_tys.push(self.lower_async_block_type_impl_trait(ret_ty.clone(), *body)); - } else { - sig_tys.push(ret_ty.clone()); - } + let arg_param_ty = projection.skip_binder().projection_term.args.type_at(1); + debug!(?arg_param_ty); - let expected_sig = FnPointer { - num_binders: 0, - sig: FnSig { abi: FnAbi::RustCall, safety: chalk_ir::Safety::Safe, variadic: false }, - substitution: FnSubst( - Substitution::from_iter(Interner, sig_tys.iter().cloned()).shifted_in(Interner), - ), + let TyKind::Tuple(input_tys) = arg_param_ty.kind() else { + return None; }; - ClosureSignature { ret_ty, expected_sig } - } + // If the return type is a type variable, look for bounds on it. + // We could theoretically support other kinds of return types here, + // but none of them would be useful, since async closures return + // concrete anonymous future types, and their futures are not coerced + // into any other type within the body of the async closure. + let TyKind::Infer(rustc_type_ir::TyVar(return_vid)) = + projection.skip_binder().term.expect_type().kind() + else { + return None; + }; - /// The return type is the signature of the closure, and the return type - /// *as represented inside the body* (so, for async closures, the `Output` ty) - pub(super) fn sig_of_closure( + // FIXME: We may want to elaborate here, though I assume this will be exceedingly rare. + let mut return_ty = None; + for bound in self.table.obligations_for_self_ty(return_vid) { + if let PredicateKind::Clause(ClauseKind::Projection(ret_projection)) = + bound.predicate.kind().skip_binder() + && let ret_projection = bound.predicate.kind().rebind(ret_projection) + && let Some(ret_projection) = ret_projection.no_bound_vars() + && let SolverDefId::TypeAliasId(assoc_type) = ret_projection.def_id() + && self.db.lang_attr(assoc_type.into()) == Some(LangItem::FutureOutput) + { + return_ty = Some(ret_projection.term.expect_type()); + break; + } + } + + // SUBTLE: If we didn't find a `Future<Output = ...>` bound for the return + // vid, we still want to attempt to provide inference guidance for the async + // closure's arguments. Instantiate a new vid to plug into the output type. + // + // You may be wondering, what if it's higher-ranked? Well, given that we + // found a type variable for the `FnOnce::Output` projection above, we know + // that the output can't mention any of the vars. + // + // Also note that we use a fresh var here for the signature since the signature + // records the output of the *future*, and `return_vid` above is the type + // variable of the future, not its output. + // + // FIXME: We probably should store this signature inference output in a way + // that does not misuse a `FnSig` type, but that can be done separately. + let return_ty = return_ty.unwrap_or_else(|| self.table.next_ty_var()); + + let sig = projection.rebind(self.interner().mk_fn_sig( + input_tys, + return_ty, + false, + Safety::Safe, + FnAbi::Rust, + )); + + Some(sig) + } + + fn sig_of_closure( &mut self, - body: &ExprId, - ret_type: &Option<TypeRefId>, - arg_types: &[Option<TypeRefId>], - closure_kind: ClosureKind, - expected_sig: Option<FnSubst<Interner>>, - ) -> ClosureSignature { + decl_inputs: &[Option<TypeRefId>], + decl_output: Option<TypeRefId>, + expected_sig: Option<PolyFnSig<'db>>, + ) -> ClosureSignatures<'db> { if let Some(e) = expected_sig { - self.sig_of_closure_with_expectation(body, ret_type, arg_types, closure_kind, e) + self.sig_of_closure_with_expectation(decl_inputs, decl_output, e) } else { - self.sig_of_closure_no_expectation(body, ret_type, arg_types, closure_kind) + self.sig_of_closure_no_expectation(decl_inputs, decl_output) } } + /// If there is no expected signature, then we will convert the + /// types that the user gave into a signature. fn sig_of_closure_no_expectation( &mut self, - body: &ExprId, - ret_type: &Option<TypeRefId>, - arg_types: &[Option<TypeRefId>], - closure_kind: ClosureKind, - ) -> ClosureSignature { - self.supplied_sig_of_closure(body, ret_type, arg_types, closure_kind) - } + decl_inputs: &[Option<TypeRefId>], + decl_output: Option<TypeRefId>, + ) -> ClosureSignatures<'db> { + let bound_sig = self.supplied_sig_of_closure(decl_inputs, decl_output); - fn sig_of_closure_with_expectation( - &mut self, - body: &ExprId, - ret_type: &Option<TypeRefId>, - arg_types: &[Option<TypeRefId>], - closure_kind: ClosureKind, - expected_sig: FnSubst<Interner>, - ) -> ClosureSignature { - let expected_sig = FnPointer { - num_binders: 0, - sig: FnSig { abi: FnAbi::RustCall, safety: chalk_ir::Safety::Safe, variadic: false }, - substitution: expected_sig, - }; - - // If the expected signature does not match the actual arg types, - // then just return the expected signature - if expected_sig.substitution.0.len(Interner) != arg_types.len() + 1 { - let ret_ty = match ret_type { - Some(type_ref) => self.make_body_ty(*type_ref), - None => self.table.new_type_var(), - }; - return ClosureSignature { expected_sig, ret_ty }; - } - - self.merge_supplied_sig_with_expectation( - body, - ret_type, - arg_types, - closure_kind, - expected_sig, - ) - } - - fn merge_supplied_sig_with_expectation( - &mut self, - body: &ExprId, - ret_type: &Option<TypeRefId>, - arg_types: &[Option<TypeRefId>], - closure_kind: ClosureKind, - expected_sig: FnPointer, - ) -> ClosureSignature { - let supplied_sig = self.supplied_sig_of_closure(body, ret_type, arg_types, closure_kind); - - let snapshot = self.table.snapshot(); - if !self.table.unify(&expected_sig.substitution, &supplied_sig.expected_sig.substitution) { - self.table.rollback_to(snapshot); - } - - supplied_sig - } -} - -// The below functions handle capture and closure kind (Fn, FnMut, ..) - -#[derive(Debug, Clone, PartialEq, Eq, Hash)] -pub(crate) struct HirPlace { - pub(crate) local: BindingId, - pub(crate) projections: Vec<ProjectionElem<Infallible, Ty>>, -} - -impl HirPlace { - fn ty(&self, ctx: &mut InferenceContext<'_>) -> Ty { - let mut ty = ctx.table.resolve_completely(ctx.result[self.local].clone()); - for p in &self.projections { - ty = p.projected_ty( - ty, - ctx.db, - |_, _, _| { - unreachable!("Closure field only happens in MIR"); - }, - ctx.owner.module(ctx.db).krate(), - ); - } - ty - } - - fn capture_kind_of_truncated_place( - &self, - mut current_capture: CaptureKind, - len: usize, - ) -> CaptureKind { - if let CaptureKind::ByRef(BorrowKind::Mut { - kind: MutBorrowKind::Default | MutBorrowKind::TwoPhasedBorrow, - }) = current_capture - { - if self.projections[len..].contains(&ProjectionElem::Deref) { - current_capture = - CaptureKind::ByRef(BorrowKind::Mut { kind: MutBorrowKind::ClosureCapture }); - } - } - current_capture + self.closure_sigs(bound_sig) } -} - -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] -pub enum CaptureKind { - ByRef(BorrowKind), - ByValue, -} -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct CapturedItem { - pub(crate) place: HirPlace, - pub(crate) kind: CaptureKind, - /// The inner vec is the stacks; the outer vec is for each capture reference. + /// Invoked to compute the signature of a closure expression. This + /// combines any user-provided type annotations (e.g., `|x: u32| + /// -> u32 { .. }`) with the expected signature. /// - /// Even though we always report only the last span (i.e. the most inclusive span), - /// we need to keep them all, since when a closure occurs inside a closure, we - /// copy all captures of the inner closure to the outer closure, and then we may - /// truncate them, and we want the correct span to be reported. - span_stacks: SmallVec<[SmallVec<[MirSpan; 3]>; 3]>, - pub(crate) ty: Binders<Ty>, -} - -impl CapturedItem { - pub fn local(&self) -> BindingId { - self.place.local - } - - /// Returns whether this place has any field (aka. non-deref) projections. - pub fn has_field_projections(&self) -> bool { - self.place.projections.iter().any(|it| !matches!(it, ProjectionElem::Deref)) - } - - pub fn ty(&self, subst: &Substitution) -> Ty { - self.ty.clone().substitute(Interner, utils::ClosureSubst(subst).parent_subst()) - } - - pub fn kind(&self) -> CaptureKind { - self.kind - } - - pub fn spans(&self) -> SmallVec<[MirSpan; 3]> { - self.span_stacks.iter().map(|stack| *stack.last().expect("empty span stack")).collect() - } - - /// Converts the place to a name that can be inserted into source code. - pub fn place_to_name(&self, owner: DefWithBodyId, db: &dyn HirDatabase) -> String { - let body = db.body(owner); - let mut result = body[self.place.local].name.as_str().to_owned(); - for proj in &self.place.projections { - match proj { - ProjectionElem::Deref => {} - ProjectionElem::Field(Either::Left(f)) => { - let variant_data = f.parent.fields(db); - match variant_data.shape { - FieldsShape::Record => { - result.push('_'); - result.push_str(variant_data.fields()[f.local_id].name.as_str()) - } - FieldsShape::Tuple => { - let index = - variant_data.fields().iter().position(|it| it.0 == f.local_id); - if let Some(index) = index { - format_to!(result, "_{index}"); - } - } - FieldsShape::Unit => {} - } - } - ProjectionElem::Field(Either::Right(f)) => format_to!(result, "_{}", f.index), - &ProjectionElem::ClosureField(field) => format_to!(result, "_{field}"), - ProjectionElem::Index(_) - | ProjectionElem::ConstantIndex { .. } - | ProjectionElem::Subslice { .. } - | ProjectionElem::OpaqueCast(_) => { - never!("Not happen in closure capture"); - continue; - } - } - } - if is_raw_identifier(&result, owner.module(db).krate().data(db).edition) { - result.insert_str(0, "r#"); - } - result - } - - pub fn display_place_source_code(&self, owner: DefWithBodyId, db: &dyn HirDatabase) -> String { - let body = db.body(owner); - let krate = owner.krate(db); - let edition = krate.data(db).edition; - let mut result = body[self.place.local].name.display(db, edition).to_string(); - for proj in &self.place.projections { - match proj { - // In source code autoderef kicks in. - ProjectionElem::Deref => {} - ProjectionElem::Field(Either::Left(f)) => { - let variant_data = f.parent.fields(db); - match variant_data.shape { - FieldsShape::Record => format_to!( - result, - ".{}", - variant_data.fields()[f.local_id].name.display(db, edition) - ), - FieldsShape::Tuple => format_to!( - result, - ".{}", - variant_data - .fields() - .iter() - .position(|it| it.0 == f.local_id) - .unwrap_or_default() - ), - FieldsShape::Unit => {} - } - } - ProjectionElem::Field(Either::Right(f)) => { - let field = f.index; - format_to!(result, ".{field}"); - } - &ProjectionElem::ClosureField(field) => { - format_to!(result, ".{field}"); - } - ProjectionElem::Index(_) - | ProjectionElem::ConstantIndex { .. } - | ProjectionElem::Subslice { .. } - | ProjectionElem::OpaqueCast(_) => { - never!("Not happen in closure capture"); - continue; - } - } - } - let final_derefs_count = self - .place - .projections - .iter() - .rev() - .take_while(|proj| matches!(proj, ProjectionElem::Deref)) - .count(); - result.insert_str(0, &"*".repeat(final_derefs_count)); - result - } - - pub fn display_place(&self, owner: DefWithBodyId, db: &dyn HirDatabase) -> String { - let body = db.body(owner); - let krate = owner.krate(db); - let edition = krate.data(db).edition; - let mut result = body[self.place.local].name.display(db, edition).to_string(); - let mut field_need_paren = false; - for proj in &self.place.projections { - match proj { - ProjectionElem::Deref => { - result = format!("*{result}"); - field_need_paren = true; - } - ProjectionElem::Field(Either::Left(f)) => { - if field_need_paren { - result = format!("({result})"); - } - let variant_data = f.parent.fields(db); - let field = match variant_data.shape { - FieldsShape::Record => { - variant_data.fields()[f.local_id].name.as_str().to_owned() - } - FieldsShape::Tuple => variant_data - .fields() - .iter() - .position(|it| it.0 == f.local_id) - .unwrap_or_default() - .to_string(), - FieldsShape::Unit => "[missing field]".to_owned(), - }; - result = format!("{result}.{field}"); - field_need_paren = false; - } - ProjectionElem::Field(Either::Right(f)) => { - let field = f.index; - if field_need_paren { - result = format!("({result})"); - } - result = format!("{result}.{field}"); - field_need_paren = false; - } - &ProjectionElem::ClosureField(field) => { - if field_need_paren { - result = format!("({result})"); - } - result = format!("{result}.{field}"); - field_need_paren = false; - } - ProjectionElem::Index(_) - | ProjectionElem::ConstantIndex { .. } - | ProjectionElem::Subslice { .. } - | ProjectionElem::OpaqueCast(_) => { - never!("Not happen in closure capture"); - continue; - } - } - } - result - } -} - -#[derive(Debug, Clone, PartialEq, Eq)] -pub(crate) struct CapturedItemWithoutTy { - pub(crate) place: HirPlace, - pub(crate) kind: CaptureKind, - /// The inner vec is the stacks; the outer vec is for each capture reference. - pub(crate) span_stacks: SmallVec<[SmallVec<[MirSpan; 3]>; 3]>, -} - -impl CapturedItemWithoutTy { - fn with_ty(self, ctx: &mut InferenceContext<'_>) -> CapturedItem { - let ty = self.place.ty(ctx); - let ty = match &self.kind { - CaptureKind::ByValue => ty, - CaptureKind::ByRef(bk) => { - let m = match bk { - BorrowKind::Mut { .. } => Mutability::Mut, - _ => Mutability::Not, - }; - TyKind::Ref(m, error_lifetime(), ty).intern(Interner) - } - }; - return CapturedItem { - place: self.place, - kind: self.kind, - span_stacks: self.span_stacks, - ty: replace_placeholder_with_binder(ctx, ty), - }; - - fn replace_placeholder_with_binder(ctx: &mut InferenceContext<'_>, ty: Ty) -> Binders<Ty> { - struct Filler<'a> { - db: &'a dyn HirDatabase, - generics: &'a Generics, - } - impl FallibleTypeFolder<Interner> for Filler<'_> { - type Error = (); - - fn as_dyn(&mut self) -> &mut dyn FallibleTypeFolder<Interner, Error = Self::Error> { - self - } - - fn interner(&self) -> Interner { - Interner - } - - fn try_fold_free_placeholder_const( - &mut self, - ty: chalk_ir::Ty<Interner>, - idx: chalk_ir::PlaceholderIndex, - outer_binder: DebruijnIndex, - ) -> Result<chalk_ir::Const<Interner>, Self::Error> { - let x = from_placeholder_idx(self.db, idx); - let Some(idx) = self.generics.type_or_const_param_idx(x) else { - return Err(()); - }; - Ok(BoundVar::new(outer_binder, idx).to_const(Interner, ty)) - } - - fn try_fold_free_placeholder_ty( - &mut self, - idx: chalk_ir::PlaceholderIndex, - outer_binder: DebruijnIndex, - ) -> std::result::Result<Ty, Self::Error> { - let x = from_placeholder_idx(self.db, idx); - let Some(idx) = self.generics.type_or_const_param_idx(x) else { - return Err(()); - }; - Ok(BoundVar::new(outer_binder, idx).to_ty(Interner)) - } - } - let filler = &mut Filler { db: ctx.db, generics: ctx.generics() }; - let result = ty.clone().try_fold_with(filler, DebruijnIndex::INNERMOST).unwrap_or(ty); - make_binders(ctx.db, filler.generics, result) - } - } -} - -impl InferenceContext<'_> { - fn place_of_expr(&mut self, tgt_expr: ExprId) -> Option<HirPlace> { - let r = self.place_of_expr_without_adjust(tgt_expr)?; - let adjustments = - self.result.expr_adjustments.get(&tgt_expr).map(|it| &**it).unwrap_or_default(); - apply_adjusts_to_place(&mut self.current_capture_span_stack, r, adjustments) - } - - /// Pushes the span into `current_capture_span_stack`, *without clearing it first*. - fn path_place(&mut self, path: &Path, id: ExprOrPatId) -> Option<HirPlace> { - if path.type_anchor().is_some() { - return None; - } - let hygiene = self.body.expr_or_pat_path_hygiene(id); - self.resolver.resolve_path_in_value_ns_fully(self.db, path, hygiene).and_then(|result| { - match result { - ValueNs::LocalBinding(binding) => { - let mir_span = match id { - ExprOrPatId::ExprId(id) => MirSpan::ExprId(id), - ExprOrPatId::PatId(id) => MirSpan::PatId(id), - }; - self.current_capture_span_stack.push(mir_span); - Some(HirPlace { local: binding, projections: Vec::new() }) - } - _ => None, - } - }) - } + /// The approach is as follows: + /// + /// - Let `S` be the (higher-ranked) signature that we derive from the user's annotations. + /// - Let `E` be the (higher-ranked) signature that we derive from the expectations, if any. + /// - If we have no expectation `E`, then the signature of the closure is `S`. + /// - Otherwise, the signature of the closure is E. Moreover: + /// - Skolemize the late-bound regions in `E`, yielding `E'`. + /// - Instantiate all the late-bound regions bound in the closure within `S` + /// with fresh (existential) variables, yielding `S'` + /// - Require that `E' = S'` + /// - We could use some kind of subtyping relationship here, + /// I imagine, but equality is easier and works fine for + /// our purposes. + /// + /// The key intuition here is that the user's types must be valid + /// from "the inside" of the closure, but the expectation + /// ultimately drives the overall signature. + /// + /// # Examples + /// + /// ```ignore (illustrative) + /// fn with_closure<F>(_: F) + /// where F: Fn(&u32) -> &u32 { .. } + /// + /// with_closure(|x: &u32| { ... }) + /// ``` + /// + /// Here: + /// - E would be `fn(&u32) -> &u32`. + /// - S would be `fn(&u32) -> ?T` + /// - E' is `&'!0 u32 -> &'!0 u32` + /// - S' is `&'?0 u32 -> ?T` + /// + /// S' can be unified with E' with `['?0 = '!0, ?T = &'!10 u32]`. + /// + /// # Arguments + /// + /// - `expr_def_id`: the `LocalDefId` of the closure expression + /// - `decl`: the HIR declaration of the closure + /// - `body`: the body of the closure + /// - `expected_sig`: the expected signature (if any). Note that + /// this is missing a binder: that is, there may be late-bound + /// regions with depth 1, which are bound then by the closure. + fn sig_of_closure_with_expectation( + &mut self, + decl_inputs: &[Option<TypeRefId>], + decl_output: Option<TypeRefId>, + expected_sig: PolyFnSig<'db>, + ) -> ClosureSignatures<'db> { + // Watch out for some surprises and just ignore the + // expectation if things don't see to match up with what we + // expect. + if expected_sig.c_variadic() { + return self.sig_of_closure_no_expectation(decl_inputs, decl_output); + } else if expected_sig.skip_binder().inputs_and_output.len() != decl_inputs.len() + 1 { + return self + .sig_of_closure_with_mismatched_number_of_arguments(decl_inputs, decl_output); + } + + // Create a `PolyFnSig`. Note the oddity that late bound + // regions appearing free in `expected_sig` are now bound up + // in this binder we are creating. + assert!(!expected_sig.skip_binder().has_vars_bound_above(rustc_type_ir::INNERMOST)); + let bound_sig = expected_sig.map_bound(|sig| { + self.interner().mk_fn_sig( + sig.inputs(), + sig.output(), + sig.c_variadic, + Safety::Safe, + FnAbi::RustCall, + ) + }); - /// Changes `current_capture_span_stack` to contain the stack of spans for this expr. - fn place_of_expr_without_adjust(&mut self, tgt_expr: ExprId) -> Option<HirPlace> { - self.current_capture_span_stack.clear(); - match &self.body[tgt_expr] { - Expr::Path(p) => { - let resolver_guard = - self.resolver.update_to_inner_scope(self.db, self.owner, tgt_expr); - let result = self.path_place(p, tgt_expr.into()); - self.resolver.reset_to_guard(resolver_guard); - return result; - } - Expr::Field { expr, name: _ } => { - let mut place = self.place_of_expr(*expr)?; - let field = self.result.field_resolution(tgt_expr)?; - self.current_capture_span_stack.push(MirSpan::ExprId(tgt_expr)); - place.projections.push(ProjectionElem::Field(field)); - return Some(place); - } - Expr::UnaryOp { expr, op: UnaryOp::Deref } => { - if matches!( - self.expr_ty_after_adjustments(*expr).kind(Interner), - TyKind::Ref(..) | TyKind::Raw(..) - ) { - let mut place = self.place_of_expr(*expr)?; - self.current_capture_span_stack.push(MirSpan::ExprId(tgt_expr)); - place.projections.push(ProjectionElem::Deref); - return Some(place); - } - } - _ => (), - } - None - } + // `deduce_expectations_from_expected_type` introduces + // late-bound lifetimes defined elsewhere, which we now + // anonymize away, so as not to confuse the user. + let bound_sig = self.interner().anonymize_bound_vars(bound_sig); - fn push_capture(&mut self, place: HirPlace, kind: CaptureKind) { - self.current_captures.push(CapturedItemWithoutTy { - place, - kind, - span_stacks: smallvec![self.current_capture_span_stack.iter().copied().collect()], - }); - } + let closure_sigs = self.closure_sigs(bound_sig); - fn truncate_capture_spans(&self, capture: &mut CapturedItemWithoutTy, mut truncate_to: usize) { - // The first span is the identifier, and it must always remain. - truncate_to += 1; - for span_stack in &mut capture.span_stacks { - let mut remained = truncate_to; - let mut actual_truncate_to = 0; - for &span in &*span_stack { - actual_truncate_to += 1; - if !span.is_ref_span(self.body) { - remained -= 1; - if remained == 0 { - break; - } - } - } - if actual_truncate_to < span_stack.len() - && span_stack[actual_truncate_to].is_ref_span(self.body) - { - // Include the ref operator if there is one, we will fix it later (in `strip_captures_ref_span()`) if it's incorrect. - actual_truncate_to += 1; - } - span_stack.truncate(actual_truncate_to); + // Up till this point, we have ignored the annotations that the user + // gave. This function will check that they unify successfully. + // Along the way, it also writes out entries for types that the user + // wrote into our typeck results, which are then later used by the privacy + // check. + match self.merge_supplied_sig_with_expectation(decl_inputs, decl_output, closure_sigs) { + Ok(infer_ok) => self.table.register_infer_ok(infer_ok), + Err(_) => self.sig_of_closure_no_expectation(decl_inputs, decl_output), } } - fn ref_expr(&mut self, expr: ExprId, place: Option<HirPlace>) { - if let Some(place) = place { - self.add_capture(place, CaptureKind::ByRef(BorrowKind::Shared)); - } - self.walk_expr(expr); - } + fn sig_of_closure_with_mismatched_number_of_arguments( + &mut self, + decl_inputs: &[Option<TypeRefId>], + decl_output: Option<TypeRefId>, + ) -> ClosureSignatures<'db> { + let error_sig = self.error_sig_of_closure(decl_inputs, decl_output); - fn add_capture(&mut self, place: HirPlace, kind: CaptureKind) { - if self.is_upvar(&place) { - self.push_capture(place, kind); - } + self.closure_sigs(error_sig) } - fn mutate_path_pat(&mut self, path: &Path, id: PatId) { - if let Some(place) = self.path_place(path, id.into()) { - self.add_capture( - place, - CaptureKind::ByRef(BorrowKind::Mut { kind: MutBorrowKind::Default }), + /// Enforce the user's types against the expectation. See + /// `sig_of_closure_with_expectation` for details on the overall + /// strategy. + fn merge_supplied_sig_with_expectation( + &mut self, + decl_inputs: &[Option<TypeRefId>], + decl_output: Option<TypeRefId>, + mut expected_sigs: ClosureSignatures<'db>, + ) -> InferResult<'db, ClosureSignatures<'db>> { + // Get the signature S that the user gave. + // + // (See comment on `sig_of_closure_with_expectation` for the + // meaning of these letters.) + let supplied_sig = self.supplied_sig_of_closure(decl_inputs, decl_output); + + debug!(?supplied_sig); + + // FIXME(#45727): As discussed in [this comment][c1], naively + // forcing equality here actually results in suboptimal error + // messages in some cases. For now, if there would have been + // an obvious error, we fallback to declaring the type of the + // closure to be the one the user gave, which allows other + // error message code to trigger. + // + // However, I think [there is potential to do even better + // here][c2], since in *this* code we have the precise span of + // the type parameter in question in hand when we report the + // error. + // + // [c1]: https://github.com/rust-lang/rust/pull/45072#issuecomment-341089706 + // [c2]: https://github.com/rust-lang/rust/pull/45072#issuecomment-341096796 + self.table.commit_if_ok(|table| { + let mut all_obligations = PredicateObligations::new(); + let supplied_sig = table.infer_ctxt.instantiate_binder_with_fresh_vars( + BoundRegionConversionTime::FnCall, + supplied_sig, ); - self.current_capture_span_stack.pop(); // Remove the pattern span. - } - } - fn mutate_expr(&mut self, expr: ExprId, place: Option<HirPlace>) { - if let Some(place) = place { - self.add_capture( - place, - CaptureKind::ByRef(BorrowKind::Mut { kind: MutBorrowKind::Default }), + // The liberated version of this signature should be a subtype + // of the liberated form of the expectation. + for (supplied_ty, expected_ty) in + iter::zip(supplied_sig.inputs(), expected_sigs.liberated_sig.inputs()) + { + // Check that E' = S'. + let cause = ObligationCause::new(); + let InferOk { value: (), obligations } = table + .infer_ctxt + .at(&cause, table.trait_env.env) + .eq(expected_ty, supplied_ty)?; + all_obligations.extend(obligations); + } + + let supplied_output_ty = supplied_sig.output(); + let cause = ObligationCause::new(); + let InferOk { value: (), obligations } = table + .infer_ctxt + .at(&cause, table.trait_env.env) + .eq(expected_sigs.liberated_sig.output(), supplied_output_ty)?; + all_obligations.extend(obligations); + + let inputs = supplied_sig + .inputs() + .into_iter() + .map(|ty| table.infer_ctxt.resolve_vars_if_possible(ty)); + + expected_sigs.liberated_sig = table.interner().mk_fn_sig( + inputs, + supplied_output_ty, + expected_sigs.liberated_sig.c_variadic, + Safety::Safe, + FnAbi::RustCall, ); - } - self.walk_expr(expr); - } - fn consume_expr(&mut self, expr: ExprId) { - if let Some(place) = self.place_of_expr(expr) { - self.consume_place(place); - } - self.walk_expr(expr); + Ok(InferOk { value: expected_sigs, obligations: all_obligations }) + }) } - fn consume_place(&mut self, place: HirPlace) { - if self.is_upvar(&place) { - let ty = place.ty(self); - let kind = if self.is_ty_copy(ty) { - CaptureKind::ByRef(BorrowKind::Shared) - } else { - CaptureKind::ByValue - }; - self.push_capture(place, kind); - } - } + /// If there is no expected signature, then we will convert the + /// types that the user gave into a signature. + /// + /// Also, record this closure signature for later. + fn supplied_sig_of_closure( + &mut self, + decl_inputs: &[Option<TypeRefId>], + decl_output: Option<TypeRefId>, + ) -> PolyFnSig<'db> { + let interner = self.interner(); - fn walk_expr_with_adjust(&mut self, tgt_expr: ExprId, adjustment: &[Adjustment]) { - if let Some((last, rest)) = adjustment.split_last() { - match &last.kind { - Adjust::NeverToAny | Adjust::Deref(None) | Adjust::Pointer(_) => { - self.walk_expr_with_adjust(tgt_expr, rest) - } - Adjust::Deref(Some(m)) => match m.0 { - Some(m) => { - self.ref_capture_with_adjusts(m, tgt_expr, rest); - } - None => unreachable!(), - }, - Adjust::Borrow(b) => { - self.ref_capture_with_adjusts(b.mutability(), tgt_expr, rest); - } + let supplied_return = match decl_output { + Some(output) => { + let output = self.make_body_ty(output); + self.process_user_written_ty(output) } - } else { - self.walk_expr_without_adjust(tgt_expr); - } - } - - fn ref_capture_with_adjusts(&mut self, m: Mutability, tgt_expr: ExprId, rest: &[Adjustment]) { - let capture_kind = match m { - Mutability::Mut => CaptureKind::ByRef(BorrowKind::Mut { kind: MutBorrowKind::Default }), - Mutability::Not => CaptureKind::ByRef(BorrowKind::Shared), + None => self.table.next_ty_var(), }; - if let Some(place) = self.place_of_expr_without_adjust(tgt_expr) { - if let Some(place) = - apply_adjusts_to_place(&mut self.current_capture_span_stack, place, rest) - { - self.add_capture(place, capture_kind); + // First, convert the types that the user supplied (if any). + let supplied_arguments = decl_inputs.iter().map(|&input| match input { + Some(input) => { + let input = self.make_body_ty(input); + self.process_user_written_ty(input) } - } - self.walk_expr_with_adjust(tgt_expr, rest); - } - - fn walk_expr(&mut self, tgt_expr: ExprId) { - if let Some(it) = self.result.expr_adjustments.get_mut(&tgt_expr) { - // FIXME: this take is completely unneeded, and just is here to make borrow checker - // happy. Remove it if you can. - let x_taken = mem::take(it); - self.walk_expr_with_adjust(tgt_expr, &x_taken); - *self.result.expr_adjustments.get_mut(&tgt_expr).unwrap() = x_taken; - } else { - self.walk_expr_without_adjust(tgt_expr); - } - } - - fn walk_expr_without_adjust(&mut self, tgt_expr: ExprId) { - match &self.body[tgt_expr] { - Expr::OffsetOf(_) => (), - Expr::InlineAsm(e) => e.operands.iter().for_each(|(_, op)| match op { - AsmOperand::In { expr, .. } - | AsmOperand::Out { expr: Some(expr), .. } - | AsmOperand::InOut { expr, .. } => self.walk_expr_without_adjust(*expr), - AsmOperand::SplitInOut { in_expr, out_expr, .. } => { - self.walk_expr_without_adjust(*in_expr); - if let Some(out_expr) = out_expr { - self.walk_expr_without_adjust(*out_expr); - } - } - AsmOperand::Out { expr: None, .. } - | AsmOperand::Const(_) - | AsmOperand::Label(_) - | AsmOperand::Sym(_) => (), - }), - Expr::If { condition, then_branch, else_branch } => { - self.consume_expr(*condition); - self.consume_expr(*then_branch); - if let &Some(expr) = else_branch { - self.consume_expr(expr); - } - } - Expr::Async { statements, tail, .. } - | Expr::Unsafe { statements, tail, .. } - | Expr::Block { statements, tail, .. } => { - for s in statements.iter() { - match s { - Statement::Let { pat, type_ref: _, initializer, else_branch } => { - if let Some(else_branch) = else_branch { - self.consume_expr(*else_branch); - } - if let Some(initializer) = initializer { - if else_branch.is_some() { - self.consume_expr(*initializer); - } else { - self.walk_expr(*initializer); - } - if let Some(place) = self.place_of_expr(*initializer) { - self.consume_with_pat(place, *pat); - } - } - } - Statement::Expr { expr, has_semi: _ } => { - self.consume_expr(*expr); - } - Statement::Item(_) => (), - } - } - if let Some(tail) = tail { - self.consume_expr(*tail); - } - } - Expr::Call { callee, args } => { - self.consume_expr(*callee); - self.consume_exprs(args.iter().copied()); - } - Expr::MethodCall { receiver, args, .. } => { - self.consume_expr(*receiver); - self.consume_exprs(args.iter().copied()); - } - Expr::Match { expr, arms } => { - for arm in arms.iter() { - self.consume_expr(arm.expr); - if let Some(guard) = arm.guard { - self.consume_expr(guard); - } - } - self.walk_expr(*expr); - if let Some(discr_place) = self.place_of_expr(*expr) { - if self.is_upvar(&discr_place) { - let mut capture_mode = None; - for arm in arms.iter() { - self.walk_pat(&mut capture_mode, arm.pat); - } - if let Some(c) = capture_mode { - self.push_capture(discr_place, c); - } - } - } - } - Expr::Break { expr, label: _ } - | Expr::Return { expr } - | Expr::Yield { expr } - | Expr::Yeet { expr } => { - if let &Some(expr) = expr { - self.consume_expr(expr); - } - } - &Expr::Become { expr } => { - self.consume_expr(expr); - } - Expr::RecordLit { fields, spread, .. } => { - if let &Some(expr) = spread { - self.consume_expr(expr); - } - self.consume_exprs(fields.iter().map(|it| it.expr)); - } - Expr::Field { expr, name: _ } => self.select_from_expr(*expr), - Expr::UnaryOp { expr, op: UnaryOp::Deref } => { - if matches!( - self.expr_ty_after_adjustments(*expr).kind(Interner), - TyKind::Ref(..) | TyKind::Raw(..) - ) { - self.select_from_expr(*expr); - } else if let Some((f, _)) = self.result.method_resolution(tgt_expr) { - let mutability = 'b: { - if let Some(deref_trait) = - self.resolve_lang_item(LangItem::DerefMut).and_then(|it| it.as_trait()) - { - if let Some(deref_fn) = deref_trait - .trait_items(self.db) - .method_by_name(&Name::new_symbol_root(sym::deref_mut)) - { - break 'b deref_fn == f; - } - } - false - }; - let place = self.place_of_expr(*expr); - if mutability { - self.mutate_expr(*expr, place); - } else { - self.ref_expr(*expr, place); - } - } else { - self.select_from_expr(*expr); - } - } - Expr::Let { pat, expr } => { - self.walk_expr(*expr); - if let Some(place) = self.place_of_expr(*expr) { - self.consume_with_pat(place, *pat); - } - } - Expr::UnaryOp { expr, op: _ } - | Expr::Array(Array::Repeat { initializer: expr, repeat: _ }) - | Expr::Await { expr } - | Expr::Loop { body: expr, label: _ } - | Expr::Box { expr } - | Expr::Cast { expr, type_ref: _ } => { - self.consume_expr(*expr); - } - Expr::Ref { expr, rawness: _, mutability } => { - // We need to do this before we push the span so the order will be correct. - let place = self.place_of_expr(*expr); - self.current_capture_span_stack.push(MirSpan::ExprId(tgt_expr)); - match mutability { - hir_def::type_ref::Mutability::Shared => self.ref_expr(*expr, place), - hir_def::type_ref::Mutability::Mut => self.mutate_expr(*expr, place), - } - } - Expr::BinaryOp { lhs, rhs, op } => { - let Some(op) = op else { - return; - }; - if matches!(op, BinaryOp::Assignment { .. }) { - let place = self.place_of_expr(*lhs); - self.mutate_expr(*lhs, place); - self.consume_expr(*rhs); - return; - } - self.consume_expr(*lhs); - self.consume_expr(*rhs); - } - Expr::Range { lhs, rhs, range_type: _ } => { - if let &Some(expr) = lhs { - self.consume_expr(expr); - } - if let &Some(expr) = rhs { - self.consume_expr(expr); - } - } - Expr::Index { base, index } => { - self.select_from_expr(*base); - self.consume_expr(*index); - } - Expr::Closure { .. } => { - let ty = self.expr_ty(tgt_expr); - let TyKind::Closure(id, _) = ty.kind(Interner) else { - never!("closure type is always closure"); - return; - }; - let (captures, _) = - self.result.closure_info.get(id).expect( - "We sort closures, so we should always have data for inner closures", - ); - let mut cc = mem::take(&mut self.current_captures); - cc.extend(captures.iter().filter(|it| self.is_upvar(&it.place)).map(|it| { - CapturedItemWithoutTy { - place: it.place.clone(), - kind: it.kind, - span_stacks: it.span_stacks.clone(), - } - })); - self.current_captures = cc; - } - Expr::Array(Array::ElementList { elements: exprs }) | Expr::Tuple { exprs } => { - self.consume_exprs(exprs.iter().copied()) - } - &Expr::Assignment { target, value } => { - self.walk_expr(value); - let resolver_guard = - self.resolver.update_to_inner_scope(self.db, self.owner, tgt_expr); - match self.place_of_expr(value) { - Some(rhs_place) => { - self.inside_assignment = true; - self.consume_with_pat(rhs_place, target); - self.inside_assignment = false; - } - None => self.body.walk_pats(target, &mut |pat| match &self.body[pat] { - Pat::Path(path) => self.mutate_path_pat(path, pat), - &Pat::Expr(expr) => { - let place = self.place_of_expr(expr); - self.mutate_expr(expr, place); - } - _ => {} - }), - } - self.resolver.reset_to_guard(resolver_guard); - } - - Expr::Missing - | Expr::Continue { .. } - | Expr::Path(_) - | Expr::Literal(_) - | Expr::Const(_) - | Expr::Underscore => (), - } - } - - fn walk_pat(&mut self, result: &mut Option<CaptureKind>, pat: PatId) { - let mut update_result = |ck: CaptureKind| match result { - Some(r) => { - *r = cmp::max(*r, ck); - } - None => *result = Some(ck), - }; + None => self.table.next_ty_var(), + }); - self.walk_pat_inner( - pat, - &mut update_result, - BorrowKind::Mut { kind: MutBorrowKind::Default }, - ); + Binder::dummy(interner.mk_fn_sig( + supplied_arguments, + supplied_return, + false, + Safety::Safe, + FnAbi::RustCall, + )) } - fn walk_pat_inner( + /// Converts the types that the user supplied, in case that doing + /// so should yield an error, but returns back a signature where + /// all parameters are of type `ty::Error`. + fn error_sig_of_closure( &mut self, - p: PatId, - update_result: &mut impl FnMut(CaptureKind), - mut for_mut: BorrowKind, - ) { - match &self.body[p] { - Pat::Ref { .. } - | Pat::Box { .. } - | Pat::Missing - | Pat::Wild - | Pat::Tuple { .. } - | Pat::Expr(_) - | Pat::Or(_) => (), - Pat::TupleStruct { .. } | Pat::Record { .. } => { - if let Some(variant) = self.result.variant_resolution_for_pat(p) { - let adt = variant.adt_id(self.db); - let is_multivariant = match adt { - hir_def::AdtId::EnumId(e) => e.enum_variants(self.db).variants.len() != 1, - _ => false, - }; - if is_multivariant { - update_result(CaptureKind::ByRef(BorrowKind::Shared)); - } - } - } - Pat::Slice { .. } - | Pat::ConstBlock(_) - | Pat::Path(_) - | Pat::Lit(_) - | Pat::Range { .. } => { - update_result(CaptureKind::ByRef(BorrowKind::Shared)); - } - Pat::Bind { id, .. } => match self.result.binding_modes[p] { - crate::BindingMode::Move => { - if self.is_ty_copy(self.result.type_of_binding[*id].clone()) { - update_result(CaptureKind::ByRef(BorrowKind::Shared)); - } else { - update_result(CaptureKind::ByValue); - } - } - crate::BindingMode::Ref(r) => match r { - Mutability::Mut => update_result(CaptureKind::ByRef(for_mut)), - Mutability::Not => update_result(CaptureKind::ByRef(BorrowKind::Shared)), - }, - }, - } - if self.result.pat_adjustments.get(&p).is_some_and(|it| !it.is_empty()) { - for_mut = BorrowKind::Mut { kind: MutBorrowKind::ClosureCapture }; - } - self.body.walk_pats_shallow(p, |p| self.walk_pat_inner(p, update_result, for_mut)); - } - - fn expr_ty(&self, expr: ExprId) -> Ty { - self.result[expr].clone() - } - - fn expr_ty_after_adjustments(&self, e: ExprId) -> Ty { - let mut ty = None; - if let Some(it) = self.result.expr_adjustments.get(&e) { - if let Some(it) = it.last() { - ty = Some(it.target.clone()); - } - } - ty.unwrap_or_else(|| self.expr_ty(e)) - } - - fn is_upvar(&self, place: &HirPlace) -> bool { - if let Some(c) = self.current_closure { - let InternedClosure(_, root) = self.db.lookup_intern_closure(c.into()); - return self.body.is_binding_upvar(place.local, root); - } - false - } - - fn is_ty_copy(&mut self, ty: Ty) -> bool { - if let TyKind::Closure(id, _) = ty.kind(Interner) { - // FIXME: We handle closure as a special case, since chalk consider every closure as copy. We - // should probably let chalk know which closures are copy, but I don't know how doing it - // without creating query cycles. - return self.result.closure_info.get(id).map(|it| it.1 == FnTrait::Fn).unwrap_or(true); - } - self.table.resolve_completely(ty).is_copy(self.db, self.owner) - } - - fn select_from_expr(&mut self, expr: ExprId) { - self.walk_expr(expr); - } - - fn restrict_precision_for_unsafe(&mut self) { - // FIXME: Borrow checker problems without this. - let mut current_captures = std::mem::take(&mut self.current_captures); - for capture in &mut current_captures { - let mut ty = self.table.resolve_completely(self.result[capture.place.local].clone()); - if ty.as_raw_ptr().is_some() || ty.is_union() { - capture.kind = CaptureKind::ByRef(BorrowKind::Shared); - self.truncate_capture_spans(capture, 0); - capture.place.projections.truncate(0); - continue; - } - for (i, p) in capture.place.projections.iter().enumerate() { - ty = p.projected_ty( - ty, - self.db, - |_, _, _| { - unreachable!("Closure field only happens in MIR"); - }, - self.owner.module(self.db).krate(), - ); - if ty.as_raw_ptr().is_some() || ty.is_union() { - capture.kind = CaptureKind::ByRef(BorrowKind::Shared); - self.truncate_capture_spans(capture, i + 1); - capture.place.projections.truncate(i + 1); - break; - } - } - } - self.current_captures = current_captures; - } + decl_inputs: &[Option<TypeRefId>], + decl_output: Option<TypeRefId>, + ) -> PolyFnSig<'db> { + let interner = self.interner(); + let err_ty = Ty::new_error(interner, ErrorGuaranteed); - fn adjust_for_move_closure(&mut self) { - // FIXME: Borrow checker won't allow without this. - let mut current_captures = std::mem::take(&mut self.current_captures); - for capture in &mut current_captures { - if let Some(first_deref) = - capture.place.projections.iter().position(|proj| *proj == ProjectionElem::Deref) - { - self.truncate_capture_spans(capture, first_deref); - capture.place.projections.truncate(first_deref); - } - capture.kind = CaptureKind::ByValue; + if let Some(output) = decl_output { + self.make_body_ty(output); } - self.current_captures = current_captures; - } - - fn minimize_captures(&mut self) { - self.current_captures.sort_unstable_by_key(|it| it.place.projections.len()); - let mut hash_map = FxHashMap::<HirPlace, usize>::default(); - let result = mem::take(&mut self.current_captures); - for mut item in result { - let mut lookup_place = HirPlace { local: item.place.local, projections: vec![] }; - let mut it = item.place.projections.iter(); - let prev_index = loop { - if let Some(k) = hash_map.get(&lookup_place) { - break Some(*k); - } - match it.next() { - Some(it) => { - lookup_place.projections.push(it.clone()); - } - None => break None, - } - }; - match prev_index { - Some(p) => { - let prev_projections_len = self.current_captures[p].place.projections.len(); - self.truncate_capture_spans(&mut item, prev_projections_len); - self.current_captures[p].span_stacks.extend(item.span_stacks); - let len = self.current_captures[p].place.projections.len(); - let kind_after_truncate = - item.place.capture_kind_of_truncated_place(item.kind, len); - self.current_captures[p].kind = - cmp::max(kind_after_truncate, self.current_captures[p].kind); - } - None => { - hash_map.insert(item.place.clone(), self.current_captures.len()); - self.current_captures.push(item); - } + let supplied_arguments = decl_inputs.iter().map(|&input| match input { + Some(input) => { + self.make_body_ty(input); + err_ty } - } - } - - fn consume_with_pat(&mut self, mut place: HirPlace, tgt_pat: PatId) { - let adjustments_count = - self.result.pat_adjustments.get(&tgt_pat).map(|it| it.len()).unwrap_or_default(); - place.projections.extend((0..adjustments_count).map(|_| ProjectionElem::Deref)); - self.current_capture_span_stack - .extend((0..adjustments_count).map(|_| MirSpan::PatId(tgt_pat))); - 'reset_span_stack: { - match &self.body[tgt_pat] { - Pat::Missing | Pat::Wild => (), - Pat::Tuple { args, ellipsis } => { - let (al, ar) = args.split_at(ellipsis.map_or(args.len(), |it| it as usize)); - let field_count = match self.result[tgt_pat].kind(Interner) { - TyKind::Tuple(_, s) => s.len(Interner), - _ => break 'reset_span_stack, - }; - let fields = 0..field_count; - let it = al.iter().zip(fields.clone()).chain(ar.iter().rev().zip(fields.rev())); - for (&arg, i) in it { - let mut p = place.clone(); - self.current_capture_span_stack.push(MirSpan::PatId(arg)); - p.projections.push(ProjectionElem::Field(Either::Right(TupleFieldId { - tuple: TupleId(!0), // dummy this, as its unused anyways - index: i as u32, - }))); - self.consume_with_pat(p, arg); - self.current_capture_span_stack.pop(); - } - } - Pat::Or(pats) => { - for pat in pats.iter() { - self.consume_with_pat(place.clone(), *pat); - } - } - Pat::Record { args, .. } => { - let Some(variant) = self.result.variant_resolution_for_pat(tgt_pat) else { - break 'reset_span_stack; - }; - match variant { - VariantId::EnumVariantId(_) | VariantId::UnionId(_) => { - self.consume_place(place) - } - VariantId::StructId(s) => { - let vd = s.fields(self.db); - for field_pat in args.iter() { - let arg = field_pat.pat; - let Some(local_id) = vd.field(&field_pat.name) else { - continue; - }; - let mut p = place.clone(); - self.current_capture_span_stack.push(MirSpan::PatId(arg)); - p.projections.push(ProjectionElem::Field(Either::Left(FieldId { - parent: variant, - local_id, - }))); - self.consume_with_pat(p, arg); - self.current_capture_span_stack.pop(); - } - } - } - } - Pat::Range { .. } | Pat::Slice { .. } | Pat::ConstBlock(_) | Pat::Lit(_) => { - self.consume_place(place) - } - Pat::Path(path) => { - if self.inside_assignment { - self.mutate_path_pat(path, tgt_pat); - } - self.consume_place(place); - } - &Pat::Bind { id, subpat: _ } => { - let mode = self.result.binding_modes[tgt_pat]; - let capture_kind = match mode { - BindingMode::Move => { - self.consume_place(place); - break 'reset_span_stack; - } - BindingMode::Ref(Mutability::Not) => BorrowKind::Shared, - BindingMode::Ref(Mutability::Mut) => { - BorrowKind::Mut { kind: MutBorrowKind::Default } - } - }; - self.current_capture_span_stack.push(MirSpan::BindingId(id)); - self.add_capture(place, CaptureKind::ByRef(capture_kind)); - self.current_capture_span_stack.pop(); - } - Pat::TupleStruct { path: _, args, ellipsis } => { - let Some(variant) = self.result.variant_resolution_for_pat(tgt_pat) else { - break 'reset_span_stack; - }; - match variant { - VariantId::EnumVariantId(_) | VariantId::UnionId(_) => { - self.consume_place(place) - } - VariantId::StructId(s) => { - let vd = s.fields(self.db); - let (al, ar) = - args.split_at(ellipsis.map_or(args.len(), |it| it as usize)); - let fields = vd.fields().iter(); - let it = al - .iter() - .zip(fields.clone()) - .chain(ar.iter().rev().zip(fields.rev())); - for (&arg, (i, _)) in it { - let mut p = place.clone(); - self.current_capture_span_stack.push(MirSpan::PatId(arg)); - p.projections.push(ProjectionElem::Field(Either::Left(FieldId { - parent: variant, - local_id: i, - }))); - self.consume_with_pat(p, arg); - self.current_capture_span_stack.pop(); - } - } - } - } - Pat::Ref { pat, mutability: _ } => { - self.current_capture_span_stack.push(MirSpan::PatId(tgt_pat)); - place.projections.push(ProjectionElem::Deref); - self.consume_with_pat(place, *pat); - self.current_capture_span_stack.pop(); - } - Pat::Box { .. } => (), // not supported - &Pat::Expr(expr) => { - self.consume_place(place); - let pat_capture_span_stack = mem::take(&mut self.current_capture_span_stack); - let old_inside_assignment = mem::replace(&mut self.inside_assignment, false); - let lhs_place = self.place_of_expr(expr); - self.mutate_expr(expr, lhs_place); - self.inside_assignment = old_inside_assignment; - self.current_capture_span_stack = pat_capture_span_stack; - } - } - } - self.current_capture_span_stack - .truncate(self.current_capture_span_stack.len() - adjustments_count); - } - - fn consume_exprs(&mut self, exprs: impl Iterator<Item = ExprId>) { - for expr in exprs { - self.consume_expr(expr); - } - } - - fn closure_kind(&self) -> FnTrait { - let mut r = FnTrait::Fn; - for it in &self.current_captures { - r = cmp::min( - r, - match &it.kind { - CaptureKind::ByRef(BorrowKind::Mut { .. }) => FnTrait::FnMut, - CaptureKind::ByRef(BorrowKind::Shallow | BorrowKind::Shared) => FnTrait::Fn, - CaptureKind::ByValue => FnTrait::FnOnce, - }, - ) - } - r - } + None => err_ty, + }); - fn analyze_closure(&mut self, closure: ClosureId) -> FnTrait { - let InternedClosure(_, root) = self.db.lookup_intern_closure(closure.into()); - self.current_closure = Some(closure); - let Expr::Closure { body, capture_by, .. } = &self.body[root] else { - unreachable!("Closure expression id is always closure"); - }; - self.consume_expr(*body); - for item in &self.current_captures { - if matches!( - item.kind, - CaptureKind::ByRef(BorrowKind::Mut { - kind: MutBorrowKind::Default | MutBorrowKind::TwoPhasedBorrow - }) - ) && !item.place.projections.contains(&ProjectionElem::Deref) - { - // FIXME: remove the `mutated_bindings_in_closure` completely and add proper fake reads in - // MIR. I didn't do that due duplicate diagnostics. - self.result.mutated_bindings_in_closure.insert(item.place.local); - } - } - self.restrict_precision_for_unsafe(); - // `closure_kind` should be done before adjust_for_move_closure - // If there exists pre-deduced kind of a closure, use it instead of one determined by capture, as rustc does. - // rustc also does diagnostics here if the latter is not a subtype of the former. - let closure_kind = self - .result - .closure_info - .get(&closure) - .map_or_else(|| self.closure_kind(), |info| info.1); - match capture_by { - CaptureBy::Value => self.adjust_for_move_closure(), - CaptureBy::Ref => (), - } - self.minimize_captures(); - self.strip_captures_ref_span(); - let result = mem::take(&mut self.current_captures); - let captures = result.into_iter().map(|it| it.with_ty(self)).collect::<Vec<_>>(); - self.result.closure_info.insert(closure, (captures, closure_kind)); - closure_kind - } + let result = Binder::dummy(interner.mk_fn_sig( + supplied_arguments, + err_ty, + false, + Safety::Safe, + FnAbi::RustCall, + )); - fn strip_captures_ref_span(&mut self) { - // FIXME: Borrow checker won't allow without this. - let mut captures = std::mem::take(&mut self.current_captures); - for capture in &mut captures { - if matches!(capture.kind, CaptureKind::ByValue) { - for span_stack in &mut capture.span_stacks { - if span_stack[span_stack.len() - 1].is_ref_span(self.body) { - span_stack.truncate(span_stack.len() - 1); - } - } - } - } - self.current_captures = captures; - } + debug!("supplied_sig_of_closure: result={:?}", result); - pub(crate) fn infer_closures(&mut self) { - let deferred_closures = self.sort_closures(); - for (closure, exprs) in deferred_closures.into_iter().rev() { - self.current_captures = vec![]; - let kind = self.analyze_closure(closure); - - for (derefed_callee, callee_ty, params, expr) in exprs { - if let &Expr::Call { callee, .. } = &self.body[expr] { - let mut adjustments = - self.result.expr_adjustments.remove(&callee).unwrap_or_default().into_vec(); - self.write_fn_trait_method_resolution( - kind, - &derefed_callee, - &mut adjustments, - &callee_ty, - ¶ms, - expr, - ); - self.result.expr_adjustments.insert(callee, adjustments.into_boxed_slice()); - } - } - } - } - - /// We want to analyze some closures before others, to have a correct analysis: - /// * We should analyze nested closures before the parent, since the parent should capture some of - /// the things that its children captures. - /// * If a closure calls another closure, we need to analyze the callee, to find out how we should - /// capture it (e.g. by move for FnOnce) - /// - /// These dependencies are collected in the main inference. We do a topological sort in this function. It - /// will consume the `deferred_closures` field and return its content in a sorted vector. - fn sort_closures(&mut self) -> Vec<(ClosureId, Vec<(Ty, Ty, Vec<Ty>, ExprId)>)> { - let mut deferred_closures = mem::take(&mut self.deferred_closures); - let mut dependents_count: FxHashMap<ClosureId, usize> = - deferred_closures.keys().map(|it| (*it, 0)).collect(); - for deps in self.closure_dependencies.values() { - for dep in deps { - *dependents_count.entry(*dep).or_default() += 1; - } - } - let mut queue: Vec<_> = - deferred_closures.keys().copied().filter(|it| dependents_count[it] == 0).collect(); - let mut result = vec![]; - while let Some(it) = queue.pop() { - if let Some(d) = deferred_closures.remove(&it) { - result.push((it, d)); - } - for dep in self.closure_dependencies.get(&it).into_iter().flat_map(|it| it.iter()) { - let cnt = dependents_count.get_mut(dep).unwrap(); - *cnt -= 1; - if *cnt == 0 { - queue.push(*dep); - } - } - } - assert!(deferred_closures.is_empty(), "we should have analyzed all closures"); result } - pub(super) fn add_current_closure_dependency(&mut self, dep: ClosureId) { - if let Some(c) = self.current_closure { - if !dep_creates_cycle(&self.closure_dependencies, &mut FxHashSet::default(), c, dep) { - self.closure_dependencies.entry(c).or_default().push(dep); - } - } - - fn dep_creates_cycle( - closure_dependencies: &FxHashMap<ClosureId, Vec<ClosureId>>, - visited: &mut FxHashSet<ClosureId>, - from: ClosureId, - to: ClosureId, - ) -> bool { - if !visited.insert(from) { - return false; - } - - if from == to { - return true; - } - - if let Some(deps) = closure_dependencies.get(&to) { - for dep in deps { - if dep_creates_cycle(closure_dependencies, visited, from, *dep) { - return true; - } - } - } - - false - } - } -} - -/// Call this only when the last span in the stack isn't a split. -fn apply_adjusts_to_place( - current_capture_span_stack: &mut Vec<MirSpan>, - mut r: HirPlace, - adjustments: &[Adjustment], -) -> Option<HirPlace> { - let span = *current_capture_span_stack.last().expect("empty capture span stack"); - for adj in adjustments { - match &adj.kind { - Adjust::Deref(None) => { - current_capture_span_stack.push(span); - r.projections.push(ProjectionElem::Deref); - } - _ => return None, - } + fn closure_sigs(&self, bound_sig: PolyFnSig<'db>) -> ClosureSignatures<'db> { + let liberated_sig = bound_sig.skip_binder(); + // FIXME: When we lower HRTB we'll need to actually liberate regions here. + ClosureSignatures { bound_sig, liberated_sig } } - Some(r) } |