Unnamed repository; edit this file 'description' to name the repository.
Merge pull request #22275 from ChayimFriedman2/future-output
fix: Infer the expected type as the return type for async blocks defined by async fns
Chayim Refael Friedman 2 weeks ago
parent b4ca747 · parent 9c58e93 · commit 217af23
-rw-r--r--crates/hir-ty/src/infer/closure.rs191
-rw-r--r--crates/hir-ty/src/tests/coercion.rs15
-rw-r--r--crates/hir-ty/src/tests/regression.rs10
-rw-r--r--crates/test-utils/src/minicore.rs31
4 files changed, 230 insertions, 17 deletions
diff --git a/crates/hir-ty/src/infer/closure.rs b/crates/hir-ty/src/infer/closure.rs
index 218d8e2f3e..2679efca7d 100644
--- a/crates/hir-ty/src/infer/closure.rs
+++ b/crates/hir-ty/src/infer/closure.rs
@@ -10,12 +10,12 @@ use hir_def::{
type_ref::TypeRefId,
};
use rustc_type_ir::{
- ClosureArgs, ClosureArgsParts, CoroutineArgs, CoroutineArgsParts, CoroutineClosureArgs,
- CoroutineClosureArgsParts, Interner, TypeSuperVisitable, TypeVisitable, TypeVisitableExt,
- TypeVisitor,
+ AliasTyKind, ClosureArgs, ClosureArgsParts, CoroutineArgs, CoroutineArgsParts,
+ CoroutineClosureArgs, CoroutineClosureArgsParts, InferTy, Interner, TypeSuperVisitable,
+ TypeVisitable, TypeVisitableExt, TypeVisitor,
inherent::{BoundExistentialPredicates, GenericArgs as _, IntoKind, Ty as _},
};
-use tracing::debug;
+use tracing::{debug, instrument};
use crate::{
FnAbi, Span,
@@ -86,8 +86,14 @@ impl<'db> InferenceContext<'_, 'db> {
None => (None, None),
};
- let ClosureSignatures { bound_sig, mut liberated_sig } =
- self.sig_of_closure(closure_expr, args, arg_types, ret_type, expected_sig);
+ let ClosureSignatures { bound_sig, mut liberated_sig } = self.sig_of_closure(
+ closure_expr,
+ args,
+ arg_types,
+ ret_type,
+ expected_sig,
+ closure_kind,
+ );
debug!(?bound_sig, ?liberated_sig);
@@ -685,6 +691,7 @@ impl<'db> InferenceContext<'_, 'db> {
decl_input_tys: &[Option<TypeRefId>],
decl_output_ty: Option<TypeRefId>,
expected_sig: Option<PolyFnSig<'db>>,
+ closure_kind: ClosureKind,
) -> ClosureSignatures<'db> {
if let Some(e) = expected_sig {
self.sig_of_closure_with_expectation(
@@ -693,9 +700,15 @@ impl<'db> InferenceContext<'_, 'db> {
decl_input_tys,
decl_output_ty,
e,
+ closure_kind,
)
} else {
- self.sig_of_closure_no_expectation(closure_expr, decl_input_tys, decl_output_ty)
+ self.sig_of_closure_no_expectation(
+ closure_expr,
+ decl_input_tys,
+ decl_output_ty,
+ closure_kind,
+ )
}
}
@@ -706,8 +719,10 @@ impl<'db> InferenceContext<'_, 'db> {
closure_expr: ExprId,
decl_inputs: &[Option<TypeRefId>],
decl_output: Option<TypeRefId>,
+ closure_kind: ClosureKind,
) -> ClosureSignatures<'db> {
- let bound_sig = self.supplied_sig_of_closure(closure_expr, decl_inputs, decl_output);
+ let bound_sig =
+ self.supplied_sig_of_closure(closure_expr, decl_inputs, decl_output, closure_kind);
self.closure_sigs(bound_sig)
}
@@ -766,6 +781,7 @@ impl<'db> InferenceContext<'_, 'db> {
decl_input_tys: &[Option<TypeRefId>],
decl_output_ty: Option<TypeRefId>,
expected_sig: PolyFnSig<'db>,
+ closure_kind: ClosureKind,
) -> ClosureSignatures<'db> {
// Watch out for some surprises and just ignore the
// expectation if things don't see to match up with what we
@@ -775,6 +791,7 @@ impl<'db> InferenceContext<'_, 'db> {
closure_expr,
decl_input_tys,
decl_output_ty,
+ closure_kind,
);
} else if expected_sig.skip_binder().inputs_and_output.len() != decl_input_tys.len() + 1 {
return self.sig_of_closure_with_mismatched_number_of_arguments(
@@ -815,11 +832,15 @@ impl<'db> InferenceContext<'_, 'db> {
decl_input_tys,
decl_output_ty,
closure_sigs,
+ closure_kind,
) {
Ok(infer_ok) => self.table.register_infer_ok(infer_ok),
- Err(_) => {
- self.sig_of_closure_no_expectation(closure_expr, decl_input_tys, decl_output_ty)
- }
+ Err(_) => self.sig_of_closure_no_expectation(
+ closure_expr,
+ decl_input_tys,
+ decl_output_ty,
+ closure_kind,
+ ),
}
}
@@ -843,13 +864,18 @@ impl<'db> InferenceContext<'_, 'db> {
decl_input_tys: &[Option<TypeRefId>],
decl_output_ty: Option<TypeRefId>,
mut expected_sigs: ClosureSignatures<'db>,
+ closure_kind: ClosureKind,
) -> InferResult<'db, ClosureSignatures<'db>> {
// Get the signature S that the user gave.
//
// (See comment on `sig_of_closure_with_expectation` for the
// meaning of these letters.)
- let supplied_sig =
- self.supplied_sig_of_closure(closure_expr, decl_input_tys, decl_output_ty);
+ let supplied_sig = self.supplied_sig_of_closure(
+ closure_expr,
+ decl_input_tys,
+ decl_output_ty,
+ closure_kind,
+ );
debug!(?supplied_sig);
@@ -923,12 +949,46 @@ impl<'db> InferenceContext<'_, 'db> {
closure_expr: ExprId,
decl_inputs: &[Option<TypeRefId>],
decl_output: Option<TypeRefId>,
+ closure_kind: ClosureKind,
) -> PolyFnSig<'db> {
let interner = self.interner();
let supplied_return = match decl_output {
Some(output) => self.make_body_ty(output),
- None => self.table.next_ty_var(closure_expr.into()),
+ None => match closure_kind {
+ // In the case of the async block that we create for a function body,
+ // we expect the return type of the block to match that of the enclosing
+ // function.
+ ClosureKind::Coroutine {
+ kind: CoroutineKind::Async,
+ source: CoroutineSource::Fn,
+ } => {
+ debug!("closure is async fn body");
+ self.deduce_future_output_from_obligations(closure_expr).unwrap_or_else(|| {
+ // AFAIK, deducing the future output
+ // always succeeds *except* in error cases
+ // like #65159. I'd like to return Error
+ // here, but I can't because I can't
+ // easily (and locally) prove that we
+ // *have* reported an
+ // error. --nikomatsakis
+ self.table.next_ty_var(closure_expr.into())
+ })
+ }
+ // All `gen {}` and `async gen {}` must return unit.
+ ClosureKind::Coroutine {
+ kind: CoroutineKind::Gen | CoroutineKind::AsyncGen,
+ ..
+ } => self.types.types.unit,
+
+ // For async blocks, we just fall back to `_` here.
+ // For closures/coroutines, we know nothing about the return
+ // type unless it was supplied.
+ ClosureKind::Coroutine { kind: CoroutineKind::Async, .. }
+ | ClosureKind::OldCoroutine(_)
+ | ClosureKind::Closure
+ | ClosureKind::CoroutineClosure(_) => self.table.next_ty_var(closure_expr.into()),
+ },
};
// First, convert the types that the user supplied (if any).
let supplied_arguments = decl_inputs.iter().map(|&input| match input {
@@ -945,6 +1005,109 @@ impl<'db> InferenceContext<'_, 'db> {
))
}
+ /// Invoked when we are translating the coroutine that results
+ /// from desugaring an `async fn`. Returns the "sugared" return
+ /// type of the `async fn` -- that is, the return type that the
+ /// user specified. The "desugared" return type is an `impl
+ /// Future<Output = T>`, so we do this by searching through the
+ /// obligations to extract the `T`.
+ #[instrument(skip(self), level = "debug", ret)]
+ fn deduce_future_output_from_obligations(&mut self, body_def_id: ExprId) -> Option<Ty<'db>> {
+ let ret_coercion = self
+ .return_coercion
+ .as_ref()
+ .unwrap_or_else(|| panic!("async fn coroutine outside of a fn"));
+
+ let ret_ty = ret_coercion.expected_ty();
+ let ret_ty = self.table.resolve_vars_with_obligations(ret_ty);
+
+ let get_future_output = |predicate: Predicate<'db>| {
+ // Search for a pending obligation like
+ //
+ // `<R as Future>::Output = T`
+ //
+ // where R is the return type we are expecting. This type `T`
+ // will be our output.
+ let bound_predicate = predicate.kind();
+ if let PredicateKind::Clause(ClauseKind::Projection(proj_predicate)) =
+ bound_predicate.skip_binder()
+ {
+ self.deduce_future_output_from_projection(bound_predicate.rebind(proj_predicate))
+ } else {
+ None
+ }
+ };
+
+ let output_ty = match ret_ty.kind() {
+ TyKind::Infer(InferTy::TyVar(ret_vid)) => self
+ .table
+ .obligations_for_self_ty(ret_vid)
+ .into_iter()
+ .find_map(|obligation| get_future_output(obligation.predicate))?,
+ TyKind::Alias(AliasTy { kind: AliasTyKind::Projection { .. }, .. }) => {
+ return Some(self.types.types.error);
+ }
+ TyKind::Alias(AliasTy { kind: AliasTyKind::Opaque { def_id }, args, .. }) => def_id
+ .expect_opaque_ty()
+ .predicates(self.db)
+ .iter_instantiated_copied(self.interner(), &args)
+ .find_map(|p| get_future_output(p.as_predicate()))?,
+ TyKind::Error(_) => return Some(ret_ty),
+ _ => {
+ panic!("invalid async fn coroutine return type: {ret_ty:?}")
+ }
+ };
+
+ Some(output_ty)
+ }
+
+ /// Given a projection like
+ ///
+ /// `<X as Future>::Output = T`
+ ///
+ /// where `X` is some type that has no late-bound regions, returns
+ /// `Some(T)`. If the projection is for some other trait, returns
+ /// `None`.
+ fn deduce_future_output_from_projection(
+ &self,
+ predicate: PolyProjectionPredicate<'db>,
+ ) -> Option<Ty<'db>> {
+ debug!("deduce_future_output_from_projection(predicate={:?})", predicate);
+
+ // We do not expect any bound regions in our predicate, so
+ // skip past the bound vars.
+ let Some(predicate) = predicate.no_bound_vars() else {
+ debug!("deduce_future_output_from_projection: has late-bound regions");
+ return None;
+ };
+
+ // Check that this is a projection from the `Future` trait.
+ let trait_def_id = predicate.projection_term.trait_def_id(self.interner()).0;
+ if Some(trait_def_id) != self.lang_items.Future {
+ debug!("deduce_future_output_from_projection: not a future");
+ return None;
+ }
+
+ // The `Future` trait has only one associated item, `Output`,
+ // so check that this is what we see.
+ let output_assoc_item = self.lang_items.FutureOutput;
+ if output_assoc_item != Some(predicate.projection_term.def_id.expect_type_alias()) {
+ panic!(
+ "projecting associated item `{:?}` from future, which is not Output `{:?}`",
+ predicate.projection_term.kind(self.interner()),
+ output_assoc_item,
+ );
+ }
+
+ // Extract the type from the projection. Note that there can
+ // be no bound variables in this type because the "self type"
+ // does not have any regions in it.
+ let output_ty = self.resolve_vars_if_possible(predicate.term);
+ debug!("deduce_future_output_from_projection: output_ty={:?}", output_ty);
+ // This is a projection on a Fn trait so will always be a type.
+ Some(output_ty.expect_type())
+ }
+
/// Converts the types that the user supplied, in case that doing
/// so should yield an error, but returns back a signature where
/// all parameters are of type `ty::Error`.
diff --git a/crates/hir-ty/src/tests/coercion.rs b/crates/hir-ty/src/tests/coercion.rs
index a7f65d4fe8..db06d55278 100644
--- a/crates/hir-ty/src/tests/coercion.rs
+++ b/crates/hir-ty/src/tests/coercion.rs
@@ -1054,3 +1054,18 @@ fn bar() {
"#,
);
}
+
+#[test]
+fn async_fn_ret() {
+ check_no_mismatches(
+ r#"
+//- minicore: coerce_unsized, unsize, future, index, slice, range
+async fn foo(a: &[i32]) -> &[i32] {
+ if true {
+ return &[];
+ }
+ &a[..0]
+}
+ "#,
+ );
+}
diff --git a/crates/hir-ty/src/tests/regression.rs b/crates/hir-ty/src/tests/regression.rs
index a5f349e593..87e6521e63 100644
--- a/crates/hir-ty/src/tests/regression.rs
+++ b/crates/hir-ty/src/tests/regression.rs
@@ -2,6 +2,8 @@ mod new_solver;
use expect_test::expect;
+use crate::tests::check;
+
use super::{check_infer, check_no_mismatches, check_types};
#[test]
@@ -2013,18 +2015,20 @@ where
#[test]
fn tait_async_stack_overflow_17199() {
- check_types(
+ // The error here is because we don't support TAITs.
+ check(
r#"
//- minicore: fmt, future
type Foo = impl core::fmt::Debug;
async fn foo() -> Foo {
()
+ // ^^ expected impl Debug, got ()
}
async fn test() {
let t = foo().await;
- // ^ impl Debug
+ // ^ type: impl Debug
}
"#,
);
@@ -2234,7 +2238,7 @@ type Bar = impl Foo;
async fn f<A, B, C>() -> Bar {}
"#,
expect![[r#"
- 64..66 '{}': ()
+ 64..66 '{}': impl Foo + ?Sized
"#]],
);
}
diff --git a/crates/test-utils/src/minicore.rs b/crates/test-utils/src/minicore.rs
index 29775590ea..802e6ab8ce 100644
--- a/crates/test-utils/src/minicore.rs
+++ b/crates/test-utils/src/minicore.rs
@@ -701,6 +701,37 @@ pub mod ops {
unsafe impl<T> SliceIndex<[T]> for usize {
type Output = T;
}
+
+ macro_rules! impl_index_range {
+ ( $($range:ty,)* ) => {
+ $(
+ unsafe impl<T> SliceIndex<[T]> for $range {
+ type Output = [T];
+ }
+ )*
+ }
+ }
+
+ // region:range
+ impl_index_range!(
+ crate::ops::RangeFull,
+ crate::ops::Range<usize>,
+ crate::ops::RangeFrom<usize>,
+ crate::ops::RangeTo<usize>,
+ crate::ops::RangeInclusive<usize>,
+ crate::ops::RangeToInclusive<usize>,
+ );
+ // endregion:range
+
+ // region:new_range
+ impl_index_range!(
+ crate::range::Range<usize>,
+ crate::range::RangeFrom<usize>,
+ crate::range::RangeInclusive<usize>,
+ crate::range::RangeToInclusive<usize>,
+ );
+ // endregion:new_range
+
// endregion:slice
}
pub use self::index::{Index, IndexMut};