Unnamed repository; edit this file 'description' to name the repository.
Infer the expected type as the return type for async blocks defined by async fns
Ported from rustc.
| -rw-r--r-- | crates/hir-ty/src/infer/closure.rs | 191 | ||||
| -rw-r--r-- | crates/hir-ty/src/tests/coercion.rs | 15 | ||||
| -rw-r--r-- | crates/hir-ty/src/tests/regression.rs | 10 | ||||
| -rw-r--r-- | crates/test-utils/src/minicore.rs | 31 |
4 files changed, 230 insertions, 17 deletions
diff --git a/crates/hir-ty/src/infer/closure.rs b/crates/hir-ty/src/infer/closure.rs index 218d8e2f3e..2679efca7d 100644 --- a/crates/hir-ty/src/infer/closure.rs +++ b/crates/hir-ty/src/infer/closure.rs @@ -10,12 +10,12 @@ use hir_def::{ type_ref::TypeRefId, }; use rustc_type_ir::{ - ClosureArgs, ClosureArgsParts, CoroutineArgs, CoroutineArgsParts, CoroutineClosureArgs, - CoroutineClosureArgsParts, Interner, TypeSuperVisitable, TypeVisitable, TypeVisitableExt, - TypeVisitor, + AliasTyKind, ClosureArgs, ClosureArgsParts, CoroutineArgs, CoroutineArgsParts, + CoroutineClosureArgs, CoroutineClosureArgsParts, InferTy, Interner, TypeSuperVisitable, + TypeVisitable, TypeVisitableExt, TypeVisitor, inherent::{BoundExistentialPredicates, GenericArgs as _, IntoKind, Ty as _}, }; -use tracing::debug; +use tracing::{debug, instrument}; use crate::{ FnAbi, Span, @@ -86,8 +86,14 @@ impl<'db> InferenceContext<'_, 'db> { None => (None, None), }; - let ClosureSignatures { bound_sig, mut liberated_sig } = - self.sig_of_closure(closure_expr, args, arg_types, ret_type, expected_sig); + let ClosureSignatures { bound_sig, mut liberated_sig } = self.sig_of_closure( + closure_expr, + args, + arg_types, + ret_type, + expected_sig, + closure_kind, + ); debug!(?bound_sig, ?liberated_sig); @@ -685,6 +691,7 @@ impl<'db> InferenceContext<'_, 'db> { decl_input_tys: &[Option<TypeRefId>], decl_output_ty: Option<TypeRefId>, expected_sig: Option<PolyFnSig<'db>>, + closure_kind: ClosureKind, ) -> ClosureSignatures<'db> { if let Some(e) = expected_sig { self.sig_of_closure_with_expectation( @@ -693,9 +700,15 @@ impl<'db> InferenceContext<'_, 'db> { decl_input_tys, decl_output_ty, e, + closure_kind, ) } else { - self.sig_of_closure_no_expectation(closure_expr, decl_input_tys, decl_output_ty) + self.sig_of_closure_no_expectation( + closure_expr, + decl_input_tys, + decl_output_ty, + closure_kind, + ) } } @@ -706,8 +719,10 @@ impl<'db> InferenceContext<'_, 'db> { closure_expr: ExprId, decl_inputs: &[Option<TypeRefId>], decl_output: Option<TypeRefId>, + closure_kind: ClosureKind, ) -> ClosureSignatures<'db> { - let bound_sig = self.supplied_sig_of_closure(closure_expr, decl_inputs, decl_output); + let bound_sig = + self.supplied_sig_of_closure(closure_expr, decl_inputs, decl_output, closure_kind); self.closure_sigs(bound_sig) } @@ -766,6 +781,7 @@ impl<'db> InferenceContext<'_, 'db> { decl_input_tys: &[Option<TypeRefId>], decl_output_ty: Option<TypeRefId>, expected_sig: PolyFnSig<'db>, + closure_kind: ClosureKind, ) -> ClosureSignatures<'db> { // Watch out for some surprises and just ignore the // expectation if things don't see to match up with what we @@ -775,6 +791,7 @@ impl<'db> InferenceContext<'_, 'db> { closure_expr, decl_input_tys, decl_output_ty, + closure_kind, ); } else if expected_sig.skip_binder().inputs_and_output.len() != decl_input_tys.len() + 1 { return self.sig_of_closure_with_mismatched_number_of_arguments( @@ -815,11 +832,15 @@ impl<'db> InferenceContext<'_, 'db> { decl_input_tys, decl_output_ty, closure_sigs, + closure_kind, ) { Ok(infer_ok) => self.table.register_infer_ok(infer_ok), - Err(_) => { - self.sig_of_closure_no_expectation(closure_expr, decl_input_tys, decl_output_ty) - } + Err(_) => self.sig_of_closure_no_expectation( + closure_expr, + decl_input_tys, + decl_output_ty, + closure_kind, + ), } } @@ -843,13 +864,18 @@ impl<'db> InferenceContext<'_, 'db> { decl_input_tys: &[Option<TypeRefId>], decl_output_ty: Option<TypeRefId>, mut expected_sigs: ClosureSignatures<'db>, + closure_kind: ClosureKind, ) -> InferResult<'db, ClosureSignatures<'db>> { // Get the signature S that the user gave. // // (See comment on `sig_of_closure_with_expectation` for the // meaning of these letters.) - let supplied_sig = - self.supplied_sig_of_closure(closure_expr, decl_input_tys, decl_output_ty); + let supplied_sig = self.supplied_sig_of_closure( + closure_expr, + decl_input_tys, + decl_output_ty, + closure_kind, + ); debug!(?supplied_sig); @@ -923,12 +949,46 @@ impl<'db> InferenceContext<'_, 'db> { closure_expr: ExprId, decl_inputs: &[Option<TypeRefId>], decl_output: Option<TypeRefId>, + closure_kind: ClosureKind, ) -> PolyFnSig<'db> { let interner = self.interner(); let supplied_return = match decl_output { Some(output) => self.make_body_ty(output), - None => self.table.next_ty_var(closure_expr.into()), + None => match closure_kind { + // In the case of the async block that we create for a function body, + // we expect the return type of the block to match that of the enclosing + // function. + ClosureKind::Coroutine { + kind: CoroutineKind::Async, + source: CoroutineSource::Fn, + } => { + debug!("closure is async fn body"); + self.deduce_future_output_from_obligations(closure_expr).unwrap_or_else(|| { + // AFAIK, deducing the future output + // always succeeds *except* in error cases + // like #65159. I'd like to return Error + // here, but I can't because I can't + // easily (and locally) prove that we + // *have* reported an + // error. --nikomatsakis + self.table.next_ty_var(closure_expr.into()) + }) + } + // All `gen {}` and `async gen {}` must return unit. + ClosureKind::Coroutine { + kind: CoroutineKind::Gen | CoroutineKind::AsyncGen, + .. + } => self.types.types.unit, + + // For async blocks, we just fall back to `_` here. + // For closures/coroutines, we know nothing about the return + // type unless it was supplied. + ClosureKind::Coroutine { kind: CoroutineKind::Async, .. } + | ClosureKind::OldCoroutine(_) + | ClosureKind::Closure + | ClosureKind::CoroutineClosure(_) => self.table.next_ty_var(closure_expr.into()), + }, }; // First, convert the types that the user supplied (if any). let supplied_arguments = decl_inputs.iter().map(|&input| match input { @@ -945,6 +1005,109 @@ impl<'db> InferenceContext<'_, 'db> { )) } + /// Invoked when we are translating the coroutine that results + /// from desugaring an `async fn`. Returns the "sugared" return + /// type of the `async fn` -- that is, the return type that the + /// user specified. The "desugared" return type is an `impl + /// Future<Output = T>`, so we do this by searching through the + /// obligations to extract the `T`. + #[instrument(skip(self), level = "debug", ret)] + fn deduce_future_output_from_obligations(&mut self, body_def_id: ExprId) -> Option<Ty<'db>> { + let ret_coercion = self + .return_coercion + .as_ref() + .unwrap_or_else(|| panic!("async fn coroutine outside of a fn")); + + let ret_ty = ret_coercion.expected_ty(); + let ret_ty = self.table.resolve_vars_with_obligations(ret_ty); + + let get_future_output = |predicate: Predicate<'db>| { + // Search for a pending obligation like + // + // `<R as Future>::Output = T` + // + // where R is the return type we are expecting. This type `T` + // will be our output. + let bound_predicate = predicate.kind(); + if let PredicateKind::Clause(ClauseKind::Projection(proj_predicate)) = + bound_predicate.skip_binder() + { + self.deduce_future_output_from_projection(bound_predicate.rebind(proj_predicate)) + } else { + None + } + }; + + let output_ty = match ret_ty.kind() { + TyKind::Infer(InferTy::TyVar(ret_vid)) => self + .table + .obligations_for_self_ty(ret_vid) + .into_iter() + .find_map(|obligation| get_future_output(obligation.predicate))?, + TyKind::Alias(AliasTy { kind: AliasTyKind::Projection { .. }, .. }) => { + return Some(self.types.types.error); + } + TyKind::Alias(AliasTy { kind: AliasTyKind::Opaque { def_id }, args, .. }) => def_id + .expect_opaque_ty() + .predicates(self.db) + .iter_instantiated_copied(self.interner(), &args) + .find_map(|p| get_future_output(p.as_predicate()))?, + TyKind::Error(_) => return Some(ret_ty), + _ => { + panic!("invalid async fn coroutine return type: {ret_ty:?}") + } + }; + + Some(output_ty) + } + + /// Given a projection like + /// + /// `<X as Future>::Output = T` + /// + /// where `X` is some type that has no late-bound regions, returns + /// `Some(T)`. If the projection is for some other trait, returns + /// `None`. + fn deduce_future_output_from_projection( + &self, + predicate: PolyProjectionPredicate<'db>, + ) -> Option<Ty<'db>> { + debug!("deduce_future_output_from_projection(predicate={:?})", predicate); + + // We do not expect any bound regions in our predicate, so + // skip past the bound vars. + let Some(predicate) = predicate.no_bound_vars() else { + debug!("deduce_future_output_from_projection: has late-bound regions"); + return None; + }; + + // Check that this is a projection from the `Future` trait. + let trait_def_id = predicate.projection_term.trait_def_id(self.interner()).0; + if Some(trait_def_id) != self.lang_items.Future { + debug!("deduce_future_output_from_projection: not a future"); + return None; + } + + // The `Future` trait has only one associated item, `Output`, + // so check that this is what we see. + let output_assoc_item = self.lang_items.FutureOutput; + if output_assoc_item != Some(predicate.projection_term.def_id.expect_type_alias()) { + panic!( + "projecting associated item `{:?}` from future, which is not Output `{:?}`", + predicate.projection_term.kind(self.interner()), + output_assoc_item, + ); + } + + // Extract the type from the projection. Note that there can + // be no bound variables in this type because the "self type" + // does not have any regions in it. + let output_ty = self.resolve_vars_if_possible(predicate.term); + debug!("deduce_future_output_from_projection: output_ty={:?}", output_ty); + // This is a projection on a Fn trait so will always be a type. + Some(output_ty.expect_type()) + } + /// Converts the types that the user supplied, in case that doing /// so should yield an error, but returns back a signature where /// all parameters are of type `ty::Error`. diff --git a/crates/hir-ty/src/tests/coercion.rs b/crates/hir-ty/src/tests/coercion.rs index a7f65d4fe8..db06d55278 100644 --- a/crates/hir-ty/src/tests/coercion.rs +++ b/crates/hir-ty/src/tests/coercion.rs @@ -1054,3 +1054,18 @@ fn bar() { "#, ); } + +#[test] +fn async_fn_ret() { + check_no_mismatches( + r#" +//- minicore: coerce_unsized, unsize, future, index, slice, range +async fn foo(a: &[i32]) -> &[i32] { + if true { + return &[]; + } + &a[..0] +} + "#, + ); +} diff --git a/crates/hir-ty/src/tests/regression.rs b/crates/hir-ty/src/tests/regression.rs index a5f349e593..87e6521e63 100644 --- a/crates/hir-ty/src/tests/regression.rs +++ b/crates/hir-ty/src/tests/regression.rs @@ -2,6 +2,8 @@ mod new_solver; use expect_test::expect; +use crate::tests::check; + use super::{check_infer, check_no_mismatches, check_types}; #[test] @@ -2013,18 +2015,20 @@ where #[test] fn tait_async_stack_overflow_17199() { - check_types( + // The error here is because we don't support TAITs. + check( r#" //- minicore: fmt, future type Foo = impl core::fmt::Debug; async fn foo() -> Foo { () + // ^^ expected impl Debug, got () } async fn test() { let t = foo().await; - // ^ impl Debug + // ^ type: impl Debug } "#, ); @@ -2234,7 +2238,7 @@ type Bar = impl Foo; async fn f<A, B, C>() -> Bar {} "#, expect![[r#" - 64..66 '{}': () + 64..66 '{}': impl Foo + ?Sized "#]], ); } diff --git a/crates/test-utils/src/minicore.rs b/crates/test-utils/src/minicore.rs index 29775590ea..802e6ab8ce 100644 --- a/crates/test-utils/src/minicore.rs +++ b/crates/test-utils/src/minicore.rs @@ -701,6 +701,37 @@ pub mod ops { unsafe impl<T> SliceIndex<[T]> for usize { type Output = T; } + + macro_rules! impl_index_range { + ( $($range:ty,)* ) => { + $( + unsafe impl<T> SliceIndex<[T]> for $range { + type Output = [T]; + } + )* + } + } + + // region:range + impl_index_range!( + crate::ops::RangeFull, + crate::ops::Range<usize>, + crate::ops::RangeFrom<usize>, + crate::ops::RangeTo<usize>, + crate::ops::RangeInclusive<usize>, + crate::ops::RangeToInclusive<usize>, + ); + // endregion:range + + // region:new_range + impl_index_range!( + crate::range::Range<usize>, + crate::range::RangeFrom<usize>, + crate::range::RangeInclusive<usize>, + crate::range::RangeToInclusive<usize>, + ); + // endregion:new_range + // endregion:slice } pub use self::index::{Index, IndexMut}; |