Unnamed repository; edit this file 'description' to name the repository.
Diffstat (limited to 'crates/hir-ty/src/infer/closure.rs')
-rw-r--r--crates/hir-ty/src/infer/closure.rs83
1 files changed, 63 insertions, 20 deletions
diff --git a/crates/hir-ty/src/infer/closure.rs b/crates/hir-ty/src/infer/closure.rs
index 93c98f2542..f5b974b1db 100644
--- a/crates/hir-ty/src/infer/closure.rs
+++ b/crates/hir-ty/src/infer/closure.rs
@@ -5,8 +5,8 @@ pub(crate) mod analysis;
use std::{iter, mem, ops::ControlFlow};
use hir_def::{
- TraitId,
- hir::{ClosureKind, CoroutineSource, ExprId, PatId},
+ AdtId, TraitId,
+ hir::{ClosureKind, CoroutineKind, CoroutineSource, ExprId, PatId},
type_ref::TypeRefId,
};
use rustc_type_ir::{
@@ -22,8 +22,8 @@ use crate::{
db::{InternedClosure, InternedClosureId, InternedCoroutineClosureId, InternedCoroutineId},
infer::{BreakableKind, Diverges, coerce::CoerceMany, pat::PatOrigin},
next_solver::{
- AliasTy, Binder, ClauseKind, DbInterner, ErrorGuaranteed, FnSig, GenericArgs, PolyFnSig,
- PolyProjectionPredicate, Predicate, PredicateKind, SolverDefId, Ty, TyKind,
+ AliasTy, Binder, ClauseKind, DbInterner, ErrorGuaranteed, FnSig, GenericArg, GenericArgs,
+ PolyFnSig, PolyProjectionPredicate, Predicate, PredicateKind, SolverDefId, Ty, TyKind,
abi::Safety,
infer::{
BoundRegionConversionTime, InferOk, InferResult,
@@ -46,6 +46,22 @@ struct ClosureSignatures<'db> {
}
impl<'db> InferenceContext<'_, 'db> {
+ fn poll_option_ty(&mut self, item_ty: Ty<'db>) -> Ty<'db> {
+ let interner = self.interner();
+
+ let (Some(option), Some(poll)) = (self.lang_items.Option, self.lang_items.Poll) else {
+ return self.types.types.error;
+ };
+
+ let option_ty = Ty::new_adt(
+ interner,
+ AdtId::EnumId(option),
+ interner.mk_args(&[GenericArg::from(item_ty)]),
+ );
+
+ Ty::new_adt(interner, AdtId::EnumId(poll), interner.mk_args(&[GenericArg::from(option_ty)]))
+ }
+
pub(super) fn infer_closure(
&mut self,
body: ExprId,
@@ -121,10 +137,22 @@ impl<'db> InferenceContext<'_, 'db> {
(Ty::new_closure(interner, closure_id.into(), closure_args.args), None)
}
- ClosureKind::Coroutine(_) | ClosureKind::AsyncBlock { .. } => {
+ ClosureKind::OldCoroutine(_) | ClosureKind::Coroutine { .. } => {
let yield_ty = match closure_kind {
- ClosureKind::Coroutine(_) => self.table.next_ty_var(closure_expr.into()),
- ClosureKind::AsyncBlock { .. } => self.types.types.unit,
+ ClosureKind::OldCoroutine(_)
+ | ClosureKind::Coroutine { kind: CoroutineKind::Gen, .. } => {
+ let yield_ty = self.table.next_ty_var(closure_expr.into());
+ self.require_type_is_sized(yield_ty);
+ yield_ty
+ }
+ ClosureKind::Coroutine { kind: CoroutineKind::Async, .. } => {
+ self.types.types.unit
+ }
+ ClosureKind::Coroutine { kind: CoroutineKind::AsyncGen, .. } => {
+ let yield_ty = self.table.next_ty_var(closure_expr.into());
+ self.require_type_is_sized(yield_ty);
+ self.poll_option_ty(yield_ty)
+ }
_ => unreachable!(),
};
@@ -137,7 +165,7 @@ impl<'db> InferenceContext<'_, 'db> {
// later during upvar analysis. Regular coroutines always have the kind
// ty of `().`
let kind_ty = match closure_kind {
- ClosureKind::AsyncBlock { source: CoroutineSource::Closure } => {
+ ClosureKind::Coroutine { source: CoroutineSource::Closure, .. } => {
self.table.next_ty_var(closure_expr.into())
}
_ => self.types.types.unit,
@@ -163,11 +191,20 @@ impl<'db> InferenceContext<'_, 'db> {
Some((resume_ty, yield_ty)),
)
}
- ClosureKind::AsyncClosure => {
- // async closures always return the type ascribed after the `->` (if present),
- // and yield `()`.
- let (bound_return_ty, bound_yield_ty) =
- (bound_sig.skip_binder().output(), self.types.types.unit);
+ ClosureKind::CoroutineClosure(coroutine_kind) => {
+ let (bound_return_ty, bound_yield_ty) = match coroutine_kind {
+ CoroutineKind::Gen => {
+ (self.types.types.unit, self.table.next_ty_var(closure_expr.into()))
+ }
+ CoroutineKind::Async => {
+ (bound_sig.skip_binder().output(), self.types.types.unit)
+ }
+ CoroutineKind::AsyncGen => {
+ let yield_ty = self.table.next_ty_var(closure_expr.into());
+ (self.types.types.unit, self.poll_option_ty(yield_ty))
+ }
+ };
+
// Compute all of the variables that will be used to populate the coroutine.
let resume_ty = self.table.next_ty_var(closure_expr.into());
@@ -354,9 +391,9 @@ impl<'db> InferenceContext<'_, 'db> {
let expected_sig = sig_tys.with(hdr);
(Some(expected_sig), Some(rustc_type_ir::ClosureKind::Fn))
}
- ClosureKind::Coroutine(_)
- | ClosureKind::AsyncClosure
- | ClosureKind::AsyncBlock { .. } => (None, None),
+ ClosureKind::OldCoroutine(_)
+ | ClosureKind::Coroutine { .. }
+ | ClosureKind::CoroutineClosure(_) => (None, None),
},
_ => (None, None),
}
@@ -469,8 +506,10 @@ impl<'db> InferenceContext<'_, 'db> {
if let Some(trait_def_id) = trait_def_id {
let found_kind = match closure_kind {
- ClosureKind::Closure => self.fn_trait_kind_from_def_id(trait_def_id),
- ClosureKind::AsyncClosure => self
+ ClosureKind::Closure | ClosureKind::CoroutineClosure(CoroutineKind::Gen) => {
+ self.fn_trait_kind_from_def_id(trait_def_id)
+ }
+ ClosureKind::CoroutineClosure(CoroutineKind::Async) => self
.async_fn_trait_kind_from_def_id(trait_def_id)
.or_else(|| self.fn_trait_kind_from_def_id(trait_def_id)),
_ => None,
@@ -517,13 +556,17 @@ impl<'db> InferenceContext<'_, 'db> {
ClosureKind::Closure if Some(def_id) == self.lang_items.FnOnceOutput => {
self.extract_sig_from_projection(projection)
}
- ClosureKind::AsyncClosure if Some(def_id) == self.lang_items.AsyncFnOnceOutput => {
+ ClosureKind::CoroutineClosure(CoroutineKind::Async)
+ if Some(def_id) == self.lang_items.AsyncFnOnceOutput =>
+ {
self.extract_sig_from_projection(projection)
}
// It's possible we've passed the closure to a (somewhat out-of-fashion)
// `F: FnOnce() -> Fut, Fut: Future<Output = T>` style bound. Let's still
// guide inference here, since it's beneficial for the user.
- ClosureKind::AsyncClosure if Some(def_id) == self.lang_items.FnOnceOutput => {
+ ClosureKind::CoroutineClosure(CoroutineKind::Async)
+ if Some(def_id) == self.lang_items.FnOnceOutput =>
+ {
self.extract_sig_from_projection_and_future_bound(closure_expr, projection)
}
_ => None,