Unnamed repository; edit this file 'description' to name the repository.
Diffstat (limited to 'crates/hir-ty/src/infer/closure.rs')
| -rw-r--r-- | crates/hir-ty/src/infer/closure.rs | 456 |
1 files changed, 346 insertions, 110 deletions
diff --git a/crates/hir-ty/src/infer/closure.rs b/crates/hir-ty/src/infer/closure.rs index 2207bc37e8..ab111736d5 100644 --- a/crates/hir-ty/src/infer/closure.rs +++ b/crates/hir-ty/src/infer/closure.rs @@ -5,25 +5,27 @@ pub(crate) mod analysis; use std::{iter, mem, ops::ControlFlow}; use hir_def::{ - TraitId, - hir::{ClosureKind, CoroutineSource, ExprId, PatId}, + AdtId, TraitId, + hir::{ClosureKind, CoroutineKind, CoroutineSource, ExprId, PatId}, type_ref::TypeRefId, }; +use rustc_abi::ExternAbi; use rustc_type_ir::{ - ClosureArgs, ClosureArgsParts, CoroutineArgs, CoroutineArgsParts, CoroutineClosureArgs, - CoroutineClosureArgsParts, Interner, TypeSuperVisitable, TypeVisitable, TypeVisitableExt, - TypeVisitor, + AliasTyKind, ClosureArgs, ClosureArgsParts, CoroutineArgs, CoroutineArgsParts, + CoroutineClosureArgs, CoroutineClosureArgsParts, InferTy, Interner, TypeSuperVisitable, + TypeVisitable, TypeVisitableExt, TypeVisitor, inherent::{BoundExistentialPredicates, GenericArgs as _, IntoKind, Ty as _}, }; -use tracing::debug; +use tracing::{debug, instrument}; use crate::{ - FnAbi, + Span, db::{InternedClosure, InternedClosureId, InternedCoroutineClosureId, InternedCoroutineId}, - infer::{BreakableKind, Diverges, coerce::CoerceMany}, + infer::{BreakableKind, Diverges, coerce::CoerceMany, pat::PatOrigin}, next_solver::{ - AliasTy, Binder, ClauseKind, DbInterner, ErrorGuaranteed, FnSig, GenericArgs, PolyFnSig, - PolyProjectionPredicate, Predicate, PredicateKind, SolverDefId, Ty, TyKind, + AliasTy, Binder, ClauseKind, DbInterner, ErrorGuaranteed, FnSig, GenericArg, PolyFnSig, + PolyProjectionPredicate, Predicate, PredicateKind, SolverDefId, TermId, Ty, TyKind, + Unnormalized, abi::Safety, infer::{ BoundRegionConversionTime, InferOk, InferResult, @@ -46,6 +48,22 @@ struct ClosureSignatures<'db> { } impl<'db> InferenceContext<'_, 'db> { + fn poll_option_ty(&mut self, item_ty: Ty<'db>) -> Ty<'db> { + let interner = self.interner(); + + let (Some(option), Some(poll)) = (self.lang_items.Option, self.lang_items.Poll) else { + return self.types.types.error; + }; + + let option_ty = Ty::new_adt( + interner, + AdtId::EnumId(option), + interner.mk_args(&[GenericArg::from(item_ty)]), + ); + + Ty::new_adt(interner, AdtId::EnumId(poll), interner.mk_args(&[GenericArg::from(option_ty)])) + } + pub(super) fn infer_closure( &mut self, body: ExprId, @@ -62,23 +80,31 @@ impl<'db> InferenceContext<'_, 'db> { // It's always helpful for inference if we know the kind of // closure sooner rather than later, so first examine the expected // type, and see if can glean a closure kind from there. - let (expected_sig, expected_kind) = match expected.to_option(&mut self.table) { + let (expected_sig, expected_kind) = match expected.to_option(&self.table) { Some(ty) => { - let ty = self.table.try_structurally_resolve_type(ty); - self.deduce_closure_signature(ty, closure_kind) + let ty = self.table.try_structurally_resolve_type(closure_expr.into(), ty); + self.deduce_closure_signature(closure_expr, ty, closure_kind) } None => (None, None), }; - let ClosureSignatures { bound_sig, mut liberated_sig } = - self.sig_of_closure(arg_types, ret_type, expected_sig); + let ClosureSignatures { bound_sig, mut liberated_sig } = self.sig_of_closure( + closure_expr, + args, + arg_types, + ret_type, + expected_sig, + closure_kind, + ); debug!(?bound_sig, ?liberated_sig); - let parent_args = GenericArgs::identity_for_item(interner, self.generic_def.into()); + let parent_args = self.identity_args(); - let tupled_upvars_ty = self.table.next_ty_var(); + let tupled_upvars_ty = self.table.next_ty_var(closure_expr.into()); + let closure_loc = + InternedClosure { owner: self.owner, expr: closure_expr, kind: closure_kind }; // FIXME: We could probably actually just unify this further -- // instead of having a `FnSig` and a `Option<CoroutineTypes>`, // we can have a `ClosureSignature { Coroutine { .. }, Closure { .. } }`, @@ -91,9 +117,9 @@ impl<'db> InferenceContext<'_, 'db> { interner.mk_fn_sig( [Ty::new_tup(interner, sig.inputs())], sig.output(), - sig.c_variadic, - sig.safety, - sig.abi, + sig.c_variadic(), + sig.safety(), + sig.abi(), ) }); @@ -103,7 +129,7 @@ impl<'db> InferenceContext<'_, 'db> { Some(kind) => Ty::from_closure_kind(interner, kind), // Create a type variable (for now) to represent the closure kind. // It will be unified during the upvar inference phase (`upvar.rs`) - None => self.table.next_ty_var(), + None => self.table.next_ty_var(closure_expr.into()), }; let closure_args = ClosureArgs::new( @@ -116,15 +142,26 @@ impl<'db> InferenceContext<'_, 'db> { }, ); - let closure_id = - InternedClosureId::new(self.db, InternedClosure(self.owner, closure_expr)); + let closure_id = InternedClosureId::new(self.db, closure_loc); (Ty::new_closure(interner, closure_id.into(), closure_args.args), None) } - ClosureKind::Coroutine(_) | ClosureKind::AsyncBlock { .. } => { + ClosureKind::OldCoroutine(_) | ClosureKind::Coroutine { .. } => { let yield_ty = match closure_kind { - ClosureKind::Coroutine(_) => self.table.next_ty_var(), - ClosureKind::AsyncBlock { .. } => self.types.types.unit, + ClosureKind::OldCoroutine(_) + | ClosureKind::Coroutine { kind: CoroutineKind::Gen, .. } => { + let yield_ty = self.table.next_ty_var(closure_expr.into()); + self.require_type_is_sized(yield_ty, closure_expr.into()); + yield_ty + } + ClosureKind::Coroutine { kind: CoroutineKind::Async, .. } => { + self.types.types.unit + } + ClosureKind::Coroutine { kind: CoroutineKind::AsyncGen, .. } => { + let yield_ty = self.table.next_ty_var(closure_expr.into()); + self.require_type_is_sized(yield_ty, closure_expr.into()); + self.poll_option_ty(yield_ty) + } _ => unreachable!(), }; @@ -137,8 +174,8 @@ impl<'db> InferenceContext<'_, 'db> { // later during upvar analysis. Regular coroutines always have the kind // ty of `().` let kind_ty = match closure_kind { - ClosureKind::AsyncBlock { source: CoroutineSource::Closure } => { - self.table.next_ty_var() + ClosureKind::Coroutine { source: CoroutineSource::Closure, .. } => { + self.table.next_ty_var(closure_expr.into()) } _ => self.types.types.unit, }; @@ -155,31 +192,39 @@ impl<'db> InferenceContext<'_, 'db> { }, ); - let coroutine_id = - InternedCoroutineId::new(self.db, InternedClosure(self.owner, closure_expr)); + let coroutine_id = InternedCoroutineId::new(self.db, closure_loc); ( Ty::new_coroutine(interner, coroutine_id.into(), coroutine_args.args), Some((resume_ty, yield_ty)), ) } - ClosureKind::AsyncClosure => { - // async closures always return the type ascribed after the `->` (if present), - // and yield `()`. - let (bound_return_ty, bound_yield_ty) = - (bound_sig.skip_binder().output(), self.types.types.unit); + ClosureKind::CoroutineClosure(coroutine_kind) => { + let (bound_return_ty, bound_yield_ty) = match coroutine_kind { + CoroutineKind::Gen => { + (self.types.types.unit, self.table.next_ty_var(closure_expr.into())) + } + CoroutineKind::Async => { + (bound_sig.skip_binder().output(), self.types.types.unit) + } + CoroutineKind::AsyncGen => { + let yield_ty = self.table.next_ty_var(closure_expr.into()); + (self.types.types.unit, self.poll_option_ty(yield_ty)) + } + }; + // Compute all of the variables that will be used to populate the coroutine. - let resume_ty = self.table.next_ty_var(); + let resume_ty = self.table.next_ty_var(closure_expr.into()); let closure_kind_ty = match expected_kind { Some(kind) => Ty::from_closure_kind(interner, kind), // Create a type variable (for now) to represent the closure kind. // It will be unified during the upvar inference phase (`upvar.rs`) - None => self.table.next_ty_var(), + None => self.table.next_ty_var(closure_expr.into()), }; - let coroutine_captures_by_ref_ty = self.table.next_ty_var(); + let coroutine_captures_by_ref_ty = self.table.next_ty_var(closure_expr.into()); let closure_args = CoroutineClosureArgs::new( interner, @@ -198,9 +243,9 @@ impl<'db> InferenceContext<'_, 'db> { ), ], Ty::new_tup(interner, &[bound_yield_ty, bound_return_ty]), - sig.c_variadic, - sig.safety, - sig.abi, + sig.c_variadic(), + sig.safety(), + sig.abi(), ) }), ), @@ -214,15 +259,12 @@ impl<'db> InferenceContext<'_, 'db> { // Create a type variable (for now) to represent the closure kind. // It will be unified during the upvar inference phase (`upvar.rs`) - None => self.table.next_ty_var(), + None => self.table.next_ty_var(closure_expr.into()), }; - let coroutine_upvars_ty = self.table.next_ty_var(); + let coroutine_upvars_ty = self.table.next_ty_var(closure_expr.into()); - let coroutine_closure_id = InternedCoroutineClosureId::new( - self.db, - InternedClosure(self.owner, closure_expr), - ); + let coroutine_closure_id = InternedCoroutineClosureId::new(self.db, closure_loc); // We need to turn the liberated signature that we got from HIR, which // looks something like `|Args...| -> T`, into a signature that is suitable @@ -245,9 +287,9 @@ impl<'db> InferenceContext<'_, 'db> { liberated_sig = interner.mk_fn_sig( liberated_sig.inputs().iter().copied(), coroutine_output_ty, - liberated_sig.c_variadic, - liberated_sig.safety, - liberated_sig.abi, + liberated_sig.c_variadic(), + liberated_sig.safety(), + liberated_sig.abi(), ); ( @@ -263,7 +305,7 @@ impl<'db> InferenceContext<'_, 'db> { // Now go through the argument patterns for (arg_pat, arg_ty) in args.iter().zip(bound_sig.skip_binder().inputs()) { - self.infer_top_pat(*arg_pat, *arg_ty, None); + self.infer_top_pat(*arg_pat, *arg_ty, PatOrigin::Param); } // FIXME: lift these out into a struct @@ -316,24 +358,27 @@ impl<'db> InferenceContext<'_, 'db> { /// are about to type check: fn deduce_closure_signature( &mut self, + closure_expr: ExprId, expected_ty: Ty<'db>, closure_kind: ClosureKind, ) -> (Option<PolyFnSig<'db>>, Option<rustc_type_ir::ClosureKind>) { match expected_ty.kind() { TyKind::Alias(AliasTy { kind: rustc_type_ir::Opaque { def_id }, args, .. }) => self .deduce_closure_signature_from_predicates( + closure_expr, expected_ty, closure_kind, def_id - .expect_opaque_ty() + .0 .predicates(self.db) .iter_instantiated_copied(self.interner(), args.as_slice()) + .map(Unnormalized::skip_norm_wip) .map(|clause| clause.as_predicate()), ), TyKind::Dynamic(object_type, ..) => { let sig = object_type.projection_bounds().into_iter().find_map(|pb| { let pb = pb.with_self_ty(self.interner(), Ty::new_unit(self.interner())); - self.deduce_sig_from_projection(closure_kind, pb) + self.deduce_sig_from_projection(closure_expr, closure_kind, pb) }); let kind = object_type .principal_def_id() @@ -342,6 +387,7 @@ impl<'db> InferenceContext<'_, 'db> { } TyKind::Infer(rustc_type_ir::TyVar(vid)) => self .deduce_closure_signature_from_predicates( + closure_expr, Ty::new_var(self.interner(), self.table.infer_ctxt.root_var(vid)), closure_kind, self.table.obligations_for_self_ty(vid).into_iter().map(|obl| obl.predicate), @@ -351,9 +397,9 @@ impl<'db> InferenceContext<'_, 'db> { let expected_sig = sig_tys.with(hdr); (Some(expected_sig), Some(rustc_type_ir::ClosureKind::Fn)) } - ClosureKind::Coroutine(_) - | ClosureKind::AsyncClosure - | ClosureKind::AsyncBlock { .. } => (None, None), + ClosureKind::OldCoroutine(_) + | ClosureKind::Coroutine { .. } + | ClosureKind::CoroutineClosure(_) => (None, None), }, _ => (None, None), } @@ -361,6 +407,7 @@ impl<'db> InferenceContext<'_, 'db> { fn deduce_closure_signature_from_predicates( &mut self, + closure_expr: ExprId, expected_ty: Ty<'db>, closure_kind: ClosureKind, predicates: impl DoubleEndedIterator<Item = Predicate<'db>>, @@ -388,6 +435,7 @@ impl<'db> InferenceContext<'_, 'db> { bound_predicate.skip_binder() { let inferred_sig = self.deduce_sig_from_projection( + closure_expr, closure_kind, bound_predicate.rebind(proj_predicate), ); @@ -430,18 +478,17 @@ impl<'db> InferenceContext<'_, 'db> { // This is a bit weird and means we may wind up discarding the goal due to it naming `expected_ty` // even though the normalized form may not name `expected_ty`. However, this matches the existing // behaviour of the old solver and would be technically a breaking change to fix. - let generalized_fnptr_sig = self.table.next_ty_var(); + let generalized_fnptr_sig = self.table.next_ty_var(closure_expr.into()); let inferred_fnptr_sig = Ty::new_fn_ptr(self.interner(), inferred_sig); // FIXME: Report diagnostics. _ = self .table .infer_ctxt - .at(&ObligationCause::new(), self.table.param_env) + .at(&ObligationCause::new(closure_expr), self.table.param_env) .eq(inferred_fnptr_sig, generalized_fnptr_sig) .map(|infer_ok| self.table.register_infer_ok(infer_ok)); - let resolved_sig = - self.table.infer_ctxt.resolve_vars_if_possible(generalized_fnptr_sig); + let resolved_sig = self.resolve_vars_if_possible(generalized_fnptr_sig); if resolved_sig.visit_with(&mut MentionsTy { expected_ty }).is_continue() { expected_sig = Some(resolved_sig.fn_sig(self.interner())); @@ -465,8 +512,10 @@ impl<'db> InferenceContext<'_, 'db> { if let Some(trait_def_id) = trait_def_id { let found_kind = match closure_kind { - ClosureKind::Closure => self.fn_trait_kind_from_def_id(trait_def_id), - ClosureKind::AsyncClosure => self + ClosureKind::Closure | ClosureKind::CoroutineClosure(CoroutineKind::Gen) => { + self.fn_trait_kind_from_def_id(trait_def_id) + } + ClosureKind::CoroutineClosure(CoroutineKind::Async) => self .async_fn_trait_kind_from_def_id(trait_def_id) .or_else(|| self.fn_trait_kind_from_def_id(trait_def_id)), _ => None, @@ -501,6 +550,7 @@ impl<'db> InferenceContext<'_, 'db> { /// know that. fn deduce_sig_from_projection( &mut self, + closure_expr: ExprId, closure_kind: ClosureKind, projection: PolyProjectionPredicate<'db>, ) -> Option<PolyFnSig<'db>> { @@ -512,14 +562,18 @@ impl<'db> InferenceContext<'_, 'db> { ClosureKind::Closure if Some(def_id) == self.lang_items.FnOnceOutput => { self.extract_sig_from_projection(projection) } - ClosureKind::AsyncClosure if Some(def_id) == self.lang_items.AsyncFnOnceOutput => { + ClosureKind::CoroutineClosure(CoroutineKind::Async) + if Some(def_id) == self.lang_items.AsyncFnOnceOutput => + { self.extract_sig_from_projection(projection) } // It's possible we've passed the closure to a (somewhat out-of-fashion) // `F: FnOnce() -> Fut, Fut: Future<Output = T>` style bound. Let's still // guide inference here, since it's beneficial for the user. - ClosureKind::AsyncClosure if Some(def_id) == self.lang_items.FnOnceOutput => { - self.extract_sig_from_projection_and_future_bound(projection) + ClosureKind::CoroutineClosure(CoroutineKind::Async) + if Some(def_id) == self.lang_items.FnOnceOutput => + { + self.extract_sig_from_projection_and_future_bound(closure_expr, projection) } _ => None, } @@ -531,7 +585,7 @@ impl<'db> InferenceContext<'_, 'db> { &self, projection: PolyProjectionPredicate<'db>, ) -> Option<PolyFnSig<'db>> { - let projection = self.table.infer_ctxt.resolve_vars_if_possible(projection); + let projection = self.resolve_vars_if_possible(projection); let arg_param_ty = projection.skip_binder().projection_term.args.type_at(1); debug!(?arg_param_ty); @@ -574,9 +628,10 @@ impl<'db> InferenceContext<'_, 'db> { /// projection, and the output will be an unconstrained type variable instead. fn extract_sig_from_projection_and_future_bound( &mut self, + closure_expr: ExprId, projection: PolyProjectionPredicate<'db>, ) -> Option<PolyFnSig<'db>> { - let projection = self.table.infer_ctxt.resolve_vars_if_possible(projection); + let projection = self.resolve_vars_if_possible(projection); let arg_param_ty = projection.skip_binder().projection_term.args.type_at(1); debug!(?arg_param_ty); @@ -603,7 +658,7 @@ impl<'db> InferenceContext<'_, 'db> { bound.predicate.kind().skip_binder() && let ret_projection = bound.predicate.kind().rebind(ret_projection) && let Some(ret_projection) = ret_projection.no_bound_vars() - && let SolverDefId::TypeAliasId(assoc_type) = ret_projection.def_id() + && let TermId::TypeAliasId(assoc_type) = ret_projection.def_id().0 && Some(assoc_type) == self.lang_items.FutureOutput { return_ty = Some(ret_projection.term.expect_type()); @@ -625,7 +680,7 @@ impl<'db> InferenceContext<'_, 'db> { // // FIXME: We probably should store this signature inference output in a way // that does not misuse a `FnSig` type, but that can be done separately. - let return_ty = return_ty.unwrap_or_else(|| self.table.next_ty_var()); + let return_ty = return_ty.unwrap_or_else(|| self.table.next_ty_var(closure_expr.into())); let sig = projection.rebind(self.interner().mk_fn_sig_safe_rust_abi(input_tys, return_ty)); @@ -634,14 +689,29 @@ impl<'db> InferenceContext<'_, 'db> { fn sig_of_closure( &mut self, - decl_inputs: &[Option<TypeRefId>], - decl_output: Option<TypeRefId>, + closure_expr: ExprId, + decl_inputs: &[PatId], + decl_input_tys: &[Option<TypeRefId>], + decl_output_ty: Option<TypeRefId>, expected_sig: Option<PolyFnSig<'db>>, + closure_kind: ClosureKind, ) -> ClosureSignatures<'db> { if let Some(e) = expected_sig { - self.sig_of_closure_with_expectation(decl_inputs, decl_output, e) + self.sig_of_closure_with_expectation( + closure_expr, + decl_inputs, + decl_input_tys, + decl_output_ty, + e, + closure_kind, + ) } else { - self.sig_of_closure_no_expectation(decl_inputs, decl_output) + self.sig_of_closure_no_expectation( + closure_expr, + decl_input_tys, + decl_output_ty, + closure_kind, + ) } } @@ -649,10 +719,13 @@ impl<'db> InferenceContext<'_, 'db> { /// types that the user gave into a signature. fn sig_of_closure_no_expectation( &mut self, + closure_expr: ExprId, decl_inputs: &[Option<TypeRefId>], decl_output: Option<TypeRefId>, + closure_kind: ClosureKind, ) -> ClosureSignatures<'db> { - let bound_sig = self.supplied_sig_of_closure(decl_inputs, decl_output); + let bound_sig = + self.supplied_sig_of_closure(closure_expr, decl_inputs, decl_output, closure_kind); self.closure_sigs(bound_sig) } @@ -706,18 +779,28 @@ impl<'db> InferenceContext<'_, 'db> { /// regions with depth 1, which are bound then by the closure. fn sig_of_closure_with_expectation( &mut self, - decl_inputs: &[Option<TypeRefId>], - decl_output: Option<TypeRefId>, + closure_expr: ExprId, + decl_inputs: &[PatId], + decl_input_tys: &[Option<TypeRefId>], + decl_output_ty: Option<TypeRefId>, expected_sig: PolyFnSig<'db>, + closure_kind: ClosureKind, ) -> ClosureSignatures<'db> { // Watch out for some surprises and just ignore the // expectation if things don't see to match up with what we // expect. if expected_sig.c_variadic() { - return self.sig_of_closure_no_expectation(decl_inputs, decl_output); - } else if expected_sig.skip_binder().inputs_and_output.len() != decl_inputs.len() + 1 { - return self - .sig_of_closure_with_mismatched_number_of_arguments(decl_inputs, decl_output); + return self.sig_of_closure_no_expectation( + closure_expr, + decl_input_tys, + decl_output_ty, + closure_kind, + ); + } else if expected_sig.skip_binder().inputs_and_output.len() != decl_input_tys.len() + 1 { + return self.sig_of_closure_with_mismatched_number_of_arguments( + decl_input_tys, + decl_output_ty, + ); } // Create a `PolyFnSig`. Note the oddity that late bound @@ -728,9 +811,9 @@ impl<'db> InferenceContext<'_, 'db> { self.interner().mk_fn_sig( sig.inputs().iter().copied(), sig.output(), - sig.c_variadic, + sig.c_variadic(), Safety::Safe, - FnAbi::RustCall, + ExternAbi::RustCall, ) }); @@ -746,9 +829,21 @@ impl<'db> InferenceContext<'_, 'db> { // Along the way, it also writes out entries for types that the user // wrote into our typeck results, which are then later used by the privacy // check. - match self.merge_supplied_sig_with_expectation(decl_inputs, decl_output, closure_sigs) { + match self.merge_supplied_sig_with_expectation( + closure_expr, + decl_inputs, + decl_input_tys, + decl_output_ty, + closure_sigs, + closure_kind, + ) { Ok(infer_ok) => self.table.register_infer_ok(infer_ok), - Err(_) => self.sig_of_closure_no_expectation(decl_inputs, decl_output), + Err(_) => self.sig_of_closure_no_expectation( + closure_expr, + decl_input_tys, + decl_output_ty, + closure_kind, + ), } } @@ -767,15 +862,23 @@ impl<'db> InferenceContext<'_, 'db> { /// strategy. fn merge_supplied_sig_with_expectation( &mut self, - decl_inputs: &[Option<TypeRefId>], - decl_output: Option<TypeRefId>, + closure_expr: ExprId, + decl_inputs: &[PatId], + decl_input_tys: &[Option<TypeRefId>], + decl_output_ty: Option<TypeRefId>, mut expected_sigs: ClosureSignatures<'db>, + closure_kind: ClosureKind, ) -> InferResult<'db, ClosureSignatures<'db>> { // Get the signature S that the user gave. // // (See comment on `sig_of_closure_with_expectation` for the // meaning of these letters.) - let supplied_sig = self.supplied_sig_of_closure(decl_inputs, decl_output); + let supplied_sig = self.supplied_sig_of_closure( + closure_expr, + decl_input_tys, + decl_output_ty, + closure_kind, + ); debug!(?supplied_sig); @@ -796,25 +899,28 @@ impl<'db> InferenceContext<'_, 'db> { self.table.commit_if_ok(|table| { let mut all_obligations = PredicateObligations::new(); let supplied_sig = table.infer_ctxt.instantiate_binder_with_fresh_vars( + closure_expr.into(), BoundRegionConversionTime::FnCall, supplied_sig, ); // The liberated version of this signature should be a subtype // of the liberated form of the expectation. - for (supplied_ty, expected_ty) in iter::zip( - supplied_sig.inputs().iter().copied(), + for ((decl_input, supplied_ty), expected_ty) in iter::zip( + iter::zip(decl_inputs, supplied_sig.inputs().iter().copied()), expected_sigs.liberated_sig.inputs().iter().copied(), ) { // Check that E' = S'. - let cause = ObligationCause::new(); + let cause = ObligationCause::new(*decl_input); let InferOk { value: (), obligations } = table.infer_ctxt.at(&cause, table.param_env).eq(expected_ty, supplied_ty)?; all_obligations.extend(obligations); } let supplied_output_ty = supplied_sig.output(); - let cause = ObligationCause::new(); + let cause = ObligationCause::new( + decl_output_ty.map(Span::TypeRefId).unwrap_or(closure_expr.into()), + ); let InferOk { value: (), obligations } = table .infer_ctxt @@ -822,18 +928,15 @@ impl<'db> InferenceContext<'_, 'db> { .eq(expected_sigs.liberated_sig.output(), supplied_output_ty)?; all_obligations.extend(obligations); - let inputs = supplied_sig - .inputs() - .iter() - .copied() - .map(|ty| table.infer_ctxt.resolve_vars_if_possible(ty)); + let inputs = + supplied_sig.inputs().iter().copied().map(|ty| table.resolve_vars_if_possible(ty)); expected_sigs.liberated_sig = table.interner().mk_fn_sig( inputs, supplied_output_ty, - expected_sigs.liberated_sig.c_variadic, + expected_sigs.liberated_sig.c_variadic(), Safety::Safe, - FnAbi::RustCall, + ExternAbi::RustCall, ); Ok(InferOk { value: expected_sigs, obligations: all_obligations }) @@ -846,25 +949,54 @@ impl<'db> InferenceContext<'_, 'db> { /// Also, record this closure signature for later. fn supplied_sig_of_closure( &mut self, + closure_expr: ExprId, decl_inputs: &[Option<TypeRefId>], decl_output: Option<TypeRefId>, + closure_kind: ClosureKind, ) -> PolyFnSig<'db> { let interner = self.interner(); let supplied_return = match decl_output { - Some(output) => { - let output = self.make_body_ty(output); - self.process_user_written_ty(output) - } - None => self.table.next_ty_var(), + Some(output) => self.make_body_ty(output), + None => match closure_kind { + // In the case of the async block that we create for a function body, + // we expect the return type of the block to match that of the enclosing + // function. + ClosureKind::Coroutine { + kind: CoroutineKind::Async, + source: CoroutineSource::Fn, + } => { + debug!("closure is async fn body"); + self.deduce_future_output_from_obligations(closure_expr).unwrap_or_else(|| { + // AFAIK, deducing the future output + // always succeeds *except* in error cases + // like #65159. I'd like to return Error + // here, but I can't because I can't + // easily (and locally) prove that we + // *have* reported an + // error. --nikomatsakis + self.table.next_ty_var(closure_expr.into()) + }) + } + // All `gen {}` and `async gen {}` must return unit. + ClosureKind::Coroutine { + kind: CoroutineKind::Gen | CoroutineKind::AsyncGen, + .. + } => self.types.types.unit, + + // For async blocks, we just fall back to `_` here. + // For closures/coroutines, we know nothing about the return + // type unless it was supplied. + ClosureKind::Coroutine { kind: CoroutineKind::Async, .. } + | ClosureKind::OldCoroutine(_) + | ClosureKind::Closure + | ClosureKind::CoroutineClosure(_) => self.table.next_ty_var(closure_expr.into()), + }, }; // First, convert the types that the user supplied (if any). let supplied_arguments = decl_inputs.iter().map(|&input| match input { - Some(input) => { - let input = self.make_body_ty(input); - self.process_user_written_ty(input) - } - None => self.table.next_ty_var(), + Some(input) => self.make_body_ty(input), + None => self.table.next_ty_var(closure_expr.into()), }); Binder::dummy(interner.mk_fn_sig( @@ -872,10 +1004,114 @@ impl<'db> InferenceContext<'_, 'db> { supplied_return, false, Safety::Safe, - FnAbi::RustCall, + ExternAbi::RustCall, )) } + /// Invoked when we are translating the coroutine that results + /// from desugaring an `async fn`. Returns the "sugared" return + /// type of the `async fn` -- that is, the return type that the + /// user specified. The "desugared" return type is an `impl + /// Future<Output = T>`, so we do this by searching through the + /// obligations to extract the `T`. + #[instrument(skip(self), level = "debug", ret)] + fn deduce_future_output_from_obligations(&mut self, body_def_id: ExprId) -> Option<Ty<'db>> { + let ret_coercion = self + .return_coercion + .as_ref() + .unwrap_or_else(|| panic!("async fn coroutine outside of a fn")); + + let ret_ty = ret_coercion.expected_ty(); + let ret_ty = self.table.resolve_vars_with_obligations(ret_ty); + + let get_future_output = |predicate: Predicate<'db>| { + // Search for a pending obligation like + // + // `<R as Future>::Output = T` + // + // where R is the return type we are expecting. This type `T` + // will be our output. + let bound_predicate = predicate.kind(); + if let PredicateKind::Clause(ClauseKind::Projection(proj_predicate)) = + bound_predicate.skip_binder() + { + self.deduce_future_output_from_projection(bound_predicate.rebind(proj_predicate)) + } else { + None + } + }; + + let output_ty = match ret_ty.kind() { + TyKind::Infer(InferTy::TyVar(ret_vid)) => self + .table + .obligations_for_self_ty(ret_vid) + .into_iter() + .find_map(|obligation| get_future_output(obligation.predicate))?, + TyKind::Alias(AliasTy { kind: AliasTyKind::Projection { .. }, .. }) => { + return Some(self.types.types.error); + } + TyKind::Alias(AliasTy { kind: AliasTyKind::Opaque { def_id }, args, .. }) => def_id + .0 + .predicates(self.db) + .iter_instantiated_copied(self.interner(), &args) + .map(Unnormalized::skip_norm_wip) + .find_map(|p| get_future_output(p.as_predicate()))?, + TyKind::Error(_) => return Some(ret_ty), + _ => { + panic!("invalid async fn coroutine return type: {ret_ty:?}") + } + }; + + Some(output_ty) + } + + /// Given a projection like + /// + /// `<X as Future>::Output = T` + /// + /// where `X` is some type that has no late-bound regions, returns + /// `Some(T)`. If the projection is for some other trait, returns + /// `None`. + fn deduce_future_output_from_projection( + &self, + predicate: PolyProjectionPredicate<'db>, + ) -> Option<Ty<'db>> { + debug!("deduce_future_output_from_projection(predicate={:?})", predicate); + + // We do not expect any bound regions in our predicate, so + // skip past the bound vars. + let Some(predicate) = predicate.no_bound_vars() else { + debug!("deduce_future_output_from_projection: has late-bound regions"); + return None; + }; + + // Check that this is a projection from the `Future` trait. + let trait_def_id = predicate.projection_term.trait_def_id(self.interner()).0; + if Some(trait_def_id) != self.lang_items.Future { + debug!("deduce_future_output_from_projection: not a future"); + return None; + } + + // The `Future` trait has only one associated item, `Output`, + // so check that this is what we see. + let output_assoc_item = self.lang_items.FutureOutput; + if output_assoc_item.map(Into::into) != Some(predicate.def_id().0) { + panic!( + "projecting associated item `{:?}` from future, which is not Output `{:?}`", + predicate.projection_term.kind(self.interner()), + output_assoc_item, + ); + } + + // Extract the type from the projection. Note that there can + // be no bound variables in this type because the "self type" + // does not have any regions in it. + let output_ty = self.resolve_vars_if_possible(predicate.term); + debug!("deduce_future_output_from_projection: output_ty={:?}", output_ty); + // This is a projection on a Fn trait so will always be a type. + Some(output_ty.expect_type()) + } + /// Converts the types that the user supplied, in case that doing /// so should yield an error, but returns back a signature where /// all parameters are of type `ty::Error`. @@ -903,7 +1139,7 @@ impl<'db> InferenceContext<'_, 'db> { err_ty, false, Safety::Safe, - FnAbi::RustCall, + ExternAbi::RustCall, )); debug!("supplied_sig_of_closure: result={:?}", result); |