Unnamed repository; edit this file 'description' to name the repository.
-rw-r--r--crates/hir-def/src/expr_store/body.rs5
-rw-r--r--crates/hir-def/src/expr_store/lower.rs125
-rw-r--r--crates/hir-def/src/expr_store/pretty.rs28
-rw-r--r--crates/hir-def/src/hir.rs13
-rw-r--r--crates/hir-def/src/lang_item.rs2
-rw-r--r--crates/hir-def/src/signatures.rs26
-rw-r--r--crates/hir-expand/src/mod_path.rs1
-rw-r--r--crates/hir-ty/src/db.rs6
-rw-r--r--crates/hir-ty/src/display.rs139
-rw-r--r--crates/hir-ty/src/infer/callee.rs1
-rw-r--r--crates/hir-ty/src/infer/closure.rs83
-rw-r--r--crates/hir-ty/src/infer/closure/analysis.rs5
-rw-r--r--crates/hir-ty/src/infer/expr.rs2
-rw-r--r--crates/hir-ty/src/next_solver/interner.rs51
-rw-r--r--crates/hir-ty/src/next_solver/ty.rs28
-rw-r--r--crates/hir-ty/src/tests/coercion.rs28
-rw-r--r--crates/hir-ty/src/tests/simple.rs132
-rw-r--r--crates/hir-ty/src/tests/traits.rs19
-rw-r--r--crates/hir/src/display.rs5
-rw-r--r--crates/ide/src/references.rs2
-rw-r--r--crates/intern/src/symbol/symbols.rs4
-rw-r--r--crates/test-utils/src/minicore.rs19
22 files changed, 596 insertions, 128 deletions
diff --git a/crates/hir-def/src/expr_store/body.rs b/crates/hir-def/src/expr_store/body.rs
index 6be3e49a70..2764677226 100644
--- a/crates/hir-def/src/expr_store/body.rs
+++ b/crates/hir-def/src/expr_store/body.rs
@@ -74,6 +74,7 @@ impl Body {
let mut params = None;
let mut is_async_fn = false;
+ let mut is_gen_fn = false;
let InFile { file_id, value: body } = {
match def {
DefWithBodyId::FunctionId(f) => {
@@ -81,6 +82,7 @@ impl Body {
let src = f.source(db);
params = src.value.param_list();
is_async_fn = src.value.async_token().is_some();
+ is_gen_fn = src.value.gen_token().is_some();
src.map(|it| it.body().map(ast::Expr::from))
}
DefWithBodyId::ConstId(c) => {
@@ -101,7 +103,8 @@ impl Body {
}
};
let module = def.module(db);
- let (body, source_map) = lower_body(db, def, file_id, module, params, body, is_async_fn);
+ let (body, source_map) =
+ lower_body(db, def, file_id, module, params, body, is_async_fn, is_gen_fn);
(Arc::new(body), source_map)
}
diff --git a/crates/hir-def/src/expr_store/lower.rs b/crates/hir-def/src/expr_store/lower.rs
index fd8b50d714..5c8e87c0e3 100644
--- a/crates/hir-def/src/expr_store/lower.rs
+++ b/crates/hir-def/src/expr_store/lower.rs
@@ -47,8 +47,8 @@ use crate::{
},
hir::{
Array, Binding, BindingAnnotation, BindingId, BindingProblems, CaptureBy, ClosureKind,
- CoroutineSource, Expr, ExprId, Item, Label, LabelId, Literal, MatchArm, Movability,
- OffsetOf, Pat, PatId, RecordFieldPat, RecordLitField, RecordSpread, Statement,
+ CoroutineKind, CoroutineSource, Expr, ExprId, Item, Label, LabelId, Literal, MatchArm,
+ Movability, OffsetOf, Pat, PatId, RecordFieldPat, RecordLitField, RecordSpread, Statement,
generics::GenericParams,
},
item_scope::BuiltinShadowMode,
@@ -72,6 +72,7 @@ pub(super) fn lower_body(
parameters: Option<ast::ParamList>,
body: Option<ast::Expr>,
is_async_fn: bool,
+ is_gen_fn: bool,
) -> (Body, BodySourceMap) {
// We cannot leave the root span map empty and let any identifier from it be treated as root,
// because when inside nested macros `SyntaxContextId`s from the outer macro will be interleaved
@@ -176,6 +177,8 @@ pub(super) fn lower_body(
DefWithBodyId::VariantId(..) => Awaitable::No("enum variant"),
}
},
+ is_async_fn,
+ is_gen_fn,
);
collector.store.inference_roots = Some(smallvec![(body_expr, RootExprOrigin::BodyRoot)]);
@@ -376,12 +379,20 @@ pub(crate) fn lower_function(
expr_collector.lower_type_ref_opt(ret_type.ty(), &mut ExprCollector::impl_trait_allocator)
});
- let return_type = if fn_.value.async_token().is_some() {
- let path = hir_expand::mod_path::path![core::future::Future];
+ let return_type = if fn_.value.async_token().is_some() || fn_.value.gen_token().is_some() {
+ let (path, assoc_name) =
+ match (fn_.value.async_token().is_some(), fn_.value.gen_token().is_some()) {
+ (true, true) => {
+ (hir_expand::mod_path::path![core::async_iter::AsyncIterator], sym::Item)
+ }
+ (true, false) => (hir_expand::mod_path::path![core::future::Future], sym::Output),
+ (false, true) => (hir_expand::mod_path::path![core::iter::Iterator], sym::Item),
+ (false, false) => unreachable!(),
+ };
let mut generic_args: Vec<_> =
std::iter::repeat_n(None, path.segments().len() - 1).collect();
let binding = AssociatedTypeBinding {
- name: Name::new_symbol_root(sym::Output),
+ name: Name::new_symbol_root(assoc_name),
args: None,
type_ref: Some(
return_type
@@ -950,10 +961,11 @@ impl<'db> ExprCollector<'db> {
/// into the body. This is to make sure that the future actually owns the
/// arguments that are passed to the function, and to ensure things like
/// drop order are stable.
- fn lower_async_block_with_moved_arguments(
+ fn lower_coroutine_with_moved_arguments(
&mut self,
params: &mut [PatId],
body: ExprId,
+ kind: CoroutineKind,
coroutine_source: CoroutineSource,
) -> ExprId {
let mut statements = Vec::new();
@@ -989,7 +1001,8 @@ impl<'db> ExprCollector<'db> {
*param = pat_id;
}
- let async_ = self.async_block(
+ let coroutine = self.desugared_coroutine_expr(
+ kind,
coroutine_source,
// The default capture mode here is by-ref. Later on during upvar analysis,
// we will force the captured arguments to by-move, but for async closures,
@@ -1001,11 +1014,12 @@ impl<'db> ExprCollector<'db> {
Some(body),
);
// It's important that this comes last, see the lowering of async closures for why.
- self.alloc_expr_desugared(async_)
+ self.alloc_expr_desugared(coroutine)
}
- fn async_block(
+ fn desugared_coroutine_expr(
&mut self,
+ kind: CoroutineKind,
source: CoroutineSource,
capture_by: CaptureBy,
id: Option<BlockId>,
@@ -1018,7 +1032,7 @@ impl<'db> ExprCollector<'db> {
arg_types: Box::default(),
ret_type: None,
body: block,
- closure_kind: ClosureKind::AsyncBlock { source },
+ closure_kind: ClosureKind::Coroutine { kind, source },
capture_by,
}
}
@@ -1028,12 +1042,20 @@ impl<'db> ExprCollector<'db> {
params: &mut [PatId],
expr: Option<ast::Expr>,
awaitable: Awaitable,
+ is_async_fn: bool,
+ is_gen_fn: bool,
) -> ExprId {
self.awaitable_context.replace(awaitable);
self.with_label_rib(RibKind::Closure, |this| {
let body = this.collect_expr_opt(expr);
- if awaitable == Awaitable::Yes {
- this.lower_async_block_with_moved_arguments(params, body, CoroutineSource::Fn)
+ if is_async_fn || is_gen_fn {
+ let kind = match (is_async_fn, is_gen_fn) {
+ (true, true) => CoroutineKind::AsyncGen,
+ (true, false) => CoroutineKind::Async,
+ (false, true) => CoroutineKind::Gen,
+ (false, false) => unreachable!(),
+ };
+ this.lower_coroutine_with_moved_arguments(params, body, kind, CoroutineSource::Fn)
} else {
body
}
@@ -1192,7 +1214,44 @@ impl<'db> ExprCollector<'db> {
self.with_label_rib(RibKind::Closure, |this| {
this.with_awaitable_block(Awaitable::Yes, |this| {
this.collect_block_(e, |this, id, statements, tail| {
- this.async_block(
+ this.desugared_coroutine_expr(
+ CoroutineKind::Async,
+ CoroutineSource::Block,
+ capture_by,
+ id,
+ statements,
+ tail,
+ )
+ })
+ })
+ })
+ }
+ Some(ast::BlockModifier::Gen(_)) => {
+ let capture_by =
+ if e.move_token().is_some() { CaptureBy::Value } else { CaptureBy::Ref };
+ self.with_label_rib(RibKind::Closure, |this| {
+ this.with_awaitable_block(Awaitable::No("non-async gen block"), |this| {
+ this.collect_block_(e, |this, id, statements, tail| {
+ this.desugared_coroutine_expr(
+ CoroutineKind::Gen,
+ CoroutineSource::Block,
+ capture_by,
+ id,
+ statements,
+ tail,
+ )
+ })
+ })
+ })
+ }
+ Some(ast::BlockModifier::AsyncGen(_)) => {
+ let capture_by =
+ if e.move_token().is_some() { CaptureBy::Value } else { CaptureBy::Ref };
+ self.with_label_rib(RibKind::Closure, |this| {
+ this.with_awaitable_block(Awaitable::Yes, |this| {
+ this.collect_block_(e, |this, id, statements, tail| {
+ this.desugared_coroutine_expr(
+ CoroutineKind::AsyncGen,
CoroutineSource::Block,
capture_by,
id,
@@ -1213,14 +1272,6 @@ impl<'db> ExprCollector<'db> {
})
})
}
- // FIXME
- Some(ast::BlockModifier::AsyncGen(_)) => {
- self.with_awaitable_block(Awaitable::Yes, |this| this.collect_block(e))
- }
- Some(ast::BlockModifier::Gen(_)) => self
- .with_awaitable_block(Awaitable::No("non-async gen block"), |this| {
- this.collect_block(e)
- }),
None => self.collect_block(e),
},
ast::Expr::LoopExpr(e) => {
@@ -1460,25 +1511,37 @@ impl<'db> ExprCollector<'db> {
};
let mut body = this
.with_awaitable_block(awaitable, |this| this.collect_expr_opt(e.body()));
-
- let closure_kind = if this.is_lowering_coroutine {
- let movability = if e.static_token().is_some() {
- Movability::Static
+ let kind = {
+ if e.async_token().is_some() && e.gen_token().is_some() {
+ Some(CoroutineKind::AsyncGen)
+ } else if e.async_token().is_some() {
+ Some(CoroutineKind::Async)
+ } else if e.gen_token().is_some() {
+ Some(CoroutineKind::Gen)
} else {
- Movability::Movable
- };
- ClosureKind::Coroutine(movability)
- } else if e.async_token().is_some() {
+ None
+ }
+ };
+
+ let closure_kind = if let Some(kind) = kind {
// It's important that this expr is allocated immediately before the closure.
// We rely on it for `coroutine_for_closure()`.
- body = this.lower_async_block_with_moved_arguments(
+ body = this.lower_coroutine_with_moved_arguments(
&mut args,
body,
+ kind,
CoroutineSource::Closure,
);
body_is_bindings_owner = true;
- ClosureKind::AsyncClosure
+ ClosureKind::CoroutineClosure(kind)
+ } else if this.is_lowering_coroutine {
+ let movability = if e.static_token().is_some() {
+ Movability::Static
+ } else {
+ Movability::Movable
+ };
+ ClosureKind::OldCoroutine(movability)
} else {
ClosureKind::Closure
};
diff --git a/crates/hir-def/src/expr_store/pretty.rs b/crates/hir-def/src/expr_store/pretty.rs
index bb35009f36..5d90191503 100644
--- a/crates/hir-def/src/expr_store/pretty.rs
+++ b/crates/hir-def/src/expr_store/pretty.rs
@@ -17,8 +17,8 @@ use crate::{
attrs::AttrFlags,
expr_store::path::{GenericArg, GenericArgs},
hir::{
- Array, BindingAnnotation, CaptureBy, ClosureKind, Literal, Movability, RecordSpread,
- Statement,
+ Array, BindingAnnotation, CaptureBy, ClosureKind, CoroutineKind, Literal, Movability,
+ RecordSpread, Statement,
generics::{GenericParams, WherePredicate},
},
lang_item::LangItemTarget,
@@ -761,28 +761,36 @@ impl Printer<'_> {
let mut body = *body;
let mut print_pipes = true;
match closure_kind {
- ClosureKind::Coroutine(Movability::Static) => {
+ ClosureKind::OldCoroutine(Movability::Static) => {
w!(self, "static ");
}
- ClosureKind::AsyncClosure => {
+ ClosureKind::CoroutineClosure(kind) => {
if let Expr::Closure {
body: inner_body,
- closure_kind: ClosureKind::AsyncBlock { .. },
+ closure_kind: ClosureKind::Coroutine { .. },
..
} = self.store[body]
{
body = inner_body;
} else {
- never!("async closure should always have an async block body");
+ never!("coroutine closure should always have a coroutine body");
}
- w!(self, "async ");
+ match kind {
+ CoroutineKind::Async => w!(self, "async "),
+ CoroutineKind::Gen => w!(self, "gen "),
+ CoroutineKind::AsyncGen => w!(self, "async gen "),
+ }
}
- ClosureKind::AsyncBlock { .. } => {
- w!(self, "async ");
+ ClosureKind::Coroutine { kind, .. } => {
+ match kind {
+ CoroutineKind::Async => w!(self, "async "),
+ CoroutineKind::Gen => w!(self, "gen "),
+ CoroutineKind::AsyncGen => w!(self, "async gen "),
+ }
print_pipes = false;
}
- ClosureKind::Closure | ClosureKind::Coroutine(Movability::Movable) => (),
+ ClosureKind::Closure | ClosureKind::OldCoroutine(Movability::Movable) => (),
}
match capture_by {
CaptureBy::Value => {
diff --git a/crates/hir-def/src/hir.rs b/crates/hir-def/src/hir.rs
index 93fa7ff961..6bea505757 100644
--- a/crates/hir-def/src/hir.rs
+++ b/crates/hir-def/src/hir.rs
@@ -525,11 +525,18 @@ pub enum InlineAsmRegOrRegClass {
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub enum CoroutineKind {
+ Async,
+ Gen,
+ AsyncGen,
+}
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ClosureKind {
Closure,
- Coroutine(Movability),
- AsyncBlock { source: CoroutineSource },
- AsyncClosure,
+ OldCoroutine(Movability),
+ Coroutine { kind: CoroutineKind, source: CoroutineSource },
+ CoroutineClosure(CoroutineKind),
}
/// In the case of a coroutine created as part of an async/gen construct,
diff --git a/crates/hir-def/src/lang_item.rs b/crates/hir-def/src/lang_item.rs
index adc445c2a8..e5a9b5d46c 100644
--- a/crates/hir-def/src/lang_item.rs
+++ b/crates/hir-def/src/lang_item.rs
@@ -413,6 +413,7 @@ language_item_table! { LangItems =>
FnOnceOutput, sym::fn_once_output, TypeAliasId;
Future, sym::future_trait, TraitId;
+ AsyncIterator, sym::async_iterator, TraitId;
CoroutineState, sym::coroutine_state, EnumId;
Coroutine, sym::coroutine, TraitId;
CoroutineReturn, sym::coroutine_return, TypeAliasId;
@@ -522,7 +523,6 @@ language_item_table! { LangItems =>
IteratorNext, sym::next, FunctionId;
Iterator, sym::iterator, TraitId;
FusedIterator, sym::fused_iterator, TraitId;
- AsyncIterator, sym::async_iterator, TraitId;
PinNewUnchecked, sym::new_unchecked, FunctionId;
diff --git a/crates/hir-def/src/signatures.rs b/crates/hir-def/src/signatures.rs
index f03ad5dae8..474b238add 100644
--- a/crates/hir-def/src/signatures.rs
+++ b/crates/hir-def/src/signatures.rs
@@ -567,19 +567,20 @@ bitflags! {
const DEFAULT = 1 << 2;
const CONST = 1 << 3;
const ASYNC = 1 << 4;
- const UNSAFE = 1 << 5;
- const HAS_VARARGS = 1 << 6;
- const RUSTC_ALLOW_INCOHERENT_IMPL = 1 << 7;
- const HAS_SELF_PARAM = 1 << 8;
+ const GEN = 1 << 5;
+ const UNSAFE = 1 << 6;
+ const HAS_VARARGS = 1 << 7;
+ const RUSTC_ALLOW_INCOHERENT_IMPL = 1 << 8;
+ const HAS_SELF_PARAM = 1 << 9;
/// The `#[target_feature]` attribute is necessary to check safety (with RFC 2396),
/// but keeping it for all functions will consume a lot of memory when there are
/// only very few functions with it. So we only encode its existence here, and lookup
/// it if needed.
- const HAS_TARGET_FEATURE = 1 << 9;
- const DEPRECATED_SAFE_2024 = 1 << 10;
- const EXPLICIT_SAFE = 1 << 11;
- const HAS_LEGACY_CONST_GENERICS = 1 << 12;
- const RUSTC_INTRINSIC = 1 << 13;
+ const HAS_TARGET_FEATURE = 1 << 10;
+ const DEPRECATED_SAFE_2024 = 1 << 11;
+ const EXPLICIT_SAFE = 1 << 12;
+ const HAS_LEGACY_CONST_GENERICS = 1 << 13;
+ const RUSTC_INTRINSIC = 1 << 14;
}
}
@@ -638,6 +639,9 @@ impl FunctionSignature {
if source.value.async_token().is_some() {
flags.insert(FnFlags::ASYNC);
}
+ if source.value.gen_token().is_some() {
+ flags.insert(FnFlags::GEN);
+ }
if source.value.const_token().is_some() {
flags.insert(FnFlags::CONST);
}
@@ -711,6 +715,10 @@ impl FunctionSignature {
self.flags.contains(FnFlags::ASYNC)
}
+ pub fn is_gen(&self) -> bool {
+ self.flags.contains(FnFlags::GEN)
+ }
+
pub fn is_unsafe(&self) -> bool {
self.flags.contains(FnFlags::UNSAFE)
}
diff --git a/crates/hir-expand/src/mod_path.rs b/crates/hir-expand/src/mod_path.rs
index 78228cf82e..3e108f72be 100644
--- a/crates/hir-expand/src/mod_path.rs
+++ b/crates/hir-expand/src/mod_path.rs
@@ -427,6 +427,7 @@ macro_rules! __known_path {
(core::range::RangeFrom) => {};
(core::range::RangeInclusive) => {};
(core::range::RangeToInclusive) => {};
+ (core::async_iter::AsyncIterator) => {};
(core::future::Future) => {};
(core::future::IntoFuture) => {};
(core::fmt::Debug) => {};
diff --git a/crates/hir-ty/src/db.rs b/crates/hir-ty/src/db.rs
index 99bad2682b..821ab5fc04 100644
--- a/crates/hir-ty/src/db.rs
+++ b/crates/hir-ty/src/db.rs
@@ -276,8 +276,8 @@ impl InternedCoroutineId {
matches!(
expr,
hir_def::hir::Expr::Closure {
- closure_kind: hir_def::hir::ClosureKind::Coroutine(_)
- | hir_def::hir::ClosureKind::AsyncBlock { .. },
+ closure_kind: hir_def::hir::ClosureKind::OldCoroutine(_)
+ | hir_def::hir::ClosureKind::Coroutine { .. },
..
}
),
@@ -305,7 +305,7 @@ impl InternedCoroutineClosureId {
matches!(
expr,
hir_def::hir::Expr::Closure {
- closure_kind: hir_def::hir::ClosureKind::AsyncClosure,
+ closure_kind: hir_def::hir::ClosureKind::CoroutineClosure(_),
..
}
),
diff --git a/crates/hir-ty/src/display.rs b/crates/hir-ty/src/display.rs
index e4a8def442..8b00bdf391 100644
--- a/crates/hir-ty/src/display.rs
+++ b/crates/hir-ty/src/display.rs
@@ -14,7 +14,10 @@ use hir_def::{
Lookup, ModuleDefId, ModuleId, TraitId,
expr_store::{ExpressionStore, path::Path},
find_path::{self, PrefixKind},
- hir::generics::{GenericParams, TypeOrConstParamData, TypeParamProvenance, WherePredicate},
+ hir::{
+ ClosureKind as HirClosureKind, CoroutineKind,
+ generics::{GenericParams, TypeOrConstParamData, TypeParamProvenance, WherePredicate},
+ },
item_scope::ItemInNs,
item_tree::FieldsShape,
lang_item::LangItems,
@@ -65,6 +68,36 @@ use crate::{
utils::{detect_variant_from_bytes, fn_traits},
};
+fn async_gen_item_ty_from_yield_ty<'db>(
+ lang_items: &LangItems,
+ yield_ty: Ty<'db>,
+) -> Option<Ty<'db>> {
+ let poll_id = lang_items.Poll.map(hir_def::AdtId::EnumId)?;
+ let option_id = lang_items.Option.map(hir_def::AdtId::EnumId)?;
+
+ let TyKind::Adt(poll_def, poll_args) = yield_ty.kind() else {
+ return None;
+ };
+ if poll_def.inner().id != poll_id {
+ return None;
+ }
+ let [poll_inner] = poll_args.as_slice() else {
+ return None;
+ };
+ let poll_inner = poll_inner.ty()?;
+
+ let TyKind::Adt(option_def, option_args) = poll_inner.kind() else {
+ return None;
+ };
+ if option_def.inner().id != option_id {
+ return None;
+ }
+ let [item] = option_args.as_slice() else {
+ return None;
+ };
+ item.ty()
+}
+
pub type Result<T = (), E = HirDisplayError> = std::result::Result<T, E>;
pub trait HirWrite: fmt::Write {
@@ -1519,6 +1552,22 @@ impl<'db> HirDisplay<'db> for Ty<'db> {
}
TyKind::CoroutineClosure(id, args) => {
let id = id.0;
+ let closure_kind = match id.loc(db) {
+ InternedClosure(owner, expr_id) => {
+ match &ExpressionStore::of(db, owner)[expr_id] {
+ hir_def::hir::Expr::Closure {
+ closure_kind: HirClosureKind::CoroutineClosure(kind),
+ ..
+ } => *kind,
+ expr => panic!("invalid expr for coroutine closure: {expr:?}"),
+ }
+ }
+ };
+ let closure_label = match closure_kind {
+ CoroutineKind::Async => "async closure",
+ CoroutineKind::Gen => "gen closure",
+ CoroutineKind::AsyncGen => "async gen closure",
+ };
if f.display_kind.is_source_code() {
if !f.display_kind.allows_opaque() {
return Err(HirDisplayError::DisplaySourceCodeError(
@@ -1533,25 +1582,28 @@ impl<'db> HirDisplay<'db> for Ty<'db> {
ClosureStyle::ClosureWithId => {
return write!(
f,
- "{{async closure#{:?}}}",
+ "{{{closure_label}#{:?}}}",
salsa::plumbing::AsId::as_id(&id).index()
);
}
ClosureStyle::ClosureWithSubst => {
write!(
f,
- "{{async closure#{:?}}}",
+ "{{{closure_label}#{:?}}}",
salsa::plumbing::AsId::as_id(&id).index()
)?;
return hir_fmt_generics(f, args.as_slice(), None, None);
}
_ => (),
}
- let kind = args.as_coroutine_closure().kind();
- let kind = match kind {
- rustc_type_ir::ClosureKind::Fn => "AsyncFn",
- rustc_type_ir::ClosureKind::FnMut => "AsyncFnMut",
- rustc_type_ir::ClosureKind::FnOnce => "AsyncFnOnce",
+ let callable_kind = args.as_coroutine_closure().kind();
+ let kind = match (closure_kind, callable_kind) {
+ (CoroutineKind::Async, rustc_type_ir::ClosureKind::Fn) => "AsyncFn",
+ (CoroutineKind::Async, rustc_type_ir::ClosureKind::FnMut) => "AsyncFnMut",
+ (CoroutineKind::Async, rustc_type_ir::ClosureKind::FnOnce) => "AsyncFnOnce",
+ (_, rustc_type_ir::ClosureKind::Fn) => "Fn",
+ (_, rustc_type_ir::ClosureKind::FnMut) => "FnMut",
+ (_, rustc_type_ir::ClosureKind::FnOnce) => "FnOnce",
};
let coroutine_sig = args.as_coroutine_closure().coroutine_closure_sig();
let coroutine_sig = coroutine_sig.skip_binder();
@@ -1559,7 +1611,11 @@ impl<'db> HirDisplay<'db> for Ty<'db> {
let coroutine_output = coroutine_sig.return_ty;
match f.closure_style {
ClosureStyle::ImplFn => write!(f, "impl {kind}(")?,
- ClosureStyle::RANotation => write!(f, "async |")?,
+ ClosureStyle::RANotation => match closure_kind {
+ CoroutineKind::Async => write!(f, "async |")?,
+ CoroutineKind::Gen => write!(f, "gen |")?,
+ CoroutineKind::AsyncGen => write!(f, "async gen |")?,
+ },
_ => unreachable!(),
}
if coroutine_inputs.is_empty() {
@@ -1677,7 +1733,7 @@ impl<'db> HirDisplay<'db> for Ty<'db> {
let expr = &body[expr_id];
match expr {
hir_def::hir::Expr::Closure {
- closure_kind: hir_def::hir::ClosureKind::AsyncBlock { .. },
+ closure_kind: HirClosureKind::Coroutine { kind: CoroutineKind::Async, .. },
..
} => {
let future_trait = f.lang_items().Future;
@@ -1706,7 +1762,68 @@ impl<'db> HirDisplay<'db> for Ty<'db> {
write!(f, ">")?;
}
hir_def::hir::Expr::Closure {
- closure_kind: hir_def::hir::ClosureKind::Coroutine(..),
+ closure_kind: HirClosureKind::Coroutine { kind: CoroutineKind::Gen, .. },
+ ..
+ } => {
+ let iterator_trait = f.lang_items().Iterator;
+ let item = iterator_trait.and_then(|t| {
+ t.trait_items(db)
+ .associated_type_by_name(&Name::new_symbol_root(sym::Item))
+ });
+ write!(f, "impl ")?;
+ if let Some(t) = iterator_trait {
+ f.start_location_link(t.into());
+ }
+ write!(f, "Iterator")?;
+ if iterator_trait.is_some() {
+ f.end_location_link();
+ }
+ write!(f, "<")?;
+ if let Some(t) = item {
+ f.start_location_link(t.into());
+ }
+ write!(f, "Item")?;
+ if item.is_some() {
+ f.end_location_link();
+ }
+ write!(f, " = ")?;
+ yield_ty.hir_fmt(f)?;
+ write!(f, ">")?;
+ }
+ hir_def::hir::Expr::Closure {
+ closure_kind:
+ HirClosureKind::Coroutine { kind: CoroutineKind::AsyncGen, .. },
+ ..
+ } => {
+ let async_iterator_trait = f.lang_items().AsyncIterator;
+ let item = async_iterator_trait.and_then(|t| {
+ t.trait_items(db)
+ .associated_type_by_name(&Name::new_symbol_root(sym::Item))
+ });
+ write!(f, "impl ")?;
+ if let Some(t) = async_iterator_trait {
+ f.start_location_link(t.into());
+ }
+ write!(f, "AsyncIterator")?;
+ if async_iterator_trait.is_some() {
+ f.end_location_link();
+ }
+ write!(f, "<")?;
+ if let Some(t) = item {
+ f.start_location_link(t.into());
+ }
+ write!(f, "Item")?;
+ if item.is_some() {
+ f.end_location_link();
+ }
+ write!(f, " = ")?;
+ let item_ty = async_gen_item_ty_from_yield_ty(f.lang_items(), yield_ty)
+ .unwrap_or(yield_ty);
+ item_ty.hir_fmt(f)?;
+ write!(f, ">")?;
+ }
+ hir_def::hir::Expr::Closure {
+ closure_kind: HirClosureKind::OldCoroutine(..),
..
} => {
if f.display_kind.is_source_code() {
diff --git a/crates/hir-ty/src/infer/callee.rs b/crates/hir-ty/src/infer/callee.rs
index ffdde58c48..6c86b6720f 100644
--- a/crates/hir-ty/src/infer/callee.rs
+++ b/crates/hir-ty/src/infer/callee.rs
@@ -272,6 +272,7 @@ impl<'db> InferenceContext<'_, 'db> {
// ...or *ideally*, we just have `LendingFn`/`LendingFnMut`, which
// would naturally unify these two trait hierarchies in the most
// general way.
+
let call_trait_choices = if self.shallow_resolve(adjusted_ty).is_coroutine_closure() {
[
(self.lang_items.AsyncFn, sym::async_call, true),
diff --git a/crates/hir-ty/src/infer/closure.rs b/crates/hir-ty/src/infer/closure.rs
index 93c98f2542..f5b974b1db 100644
--- a/crates/hir-ty/src/infer/closure.rs
+++ b/crates/hir-ty/src/infer/closure.rs
@@ -5,8 +5,8 @@ pub(crate) mod analysis;
use std::{iter, mem, ops::ControlFlow};
use hir_def::{
- TraitId,
- hir::{ClosureKind, CoroutineSource, ExprId, PatId},
+ AdtId, TraitId,
+ hir::{ClosureKind, CoroutineKind, CoroutineSource, ExprId, PatId},
type_ref::TypeRefId,
};
use rustc_type_ir::{
@@ -22,8 +22,8 @@ use crate::{
db::{InternedClosure, InternedClosureId, InternedCoroutineClosureId, InternedCoroutineId},
infer::{BreakableKind, Diverges, coerce::CoerceMany, pat::PatOrigin},
next_solver::{
- AliasTy, Binder, ClauseKind, DbInterner, ErrorGuaranteed, FnSig, GenericArgs, PolyFnSig,
- PolyProjectionPredicate, Predicate, PredicateKind, SolverDefId, Ty, TyKind,
+ AliasTy, Binder, ClauseKind, DbInterner, ErrorGuaranteed, FnSig, GenericArg, GenericArgs,
+ PolyFnSig, PolyProjectionPredicate, Predicate, PredicateKind, SolverDefId, Ty, TyKind,
abi::Safety,
infer::{
BoundRegionConversionTime, InferOk, InferResult,
@@ -46,6 +46,22 @@ struct ClosureSignatures<'db> {
}
impl<'db> InferenceContext<'_, 'db> {
+ fn poll_option_ty(&mut self, item_ty: Ty<'db>) -> Ty<'db> {
+ let interner = self.interner();
+
+ let (Some(option), Some(poll)) = (self.lang_items.Option, self.lang_items.Poll) else {
+ return self.types.types.error;
+ };
+
+ let option_ty = Ty::new_adt(
+ interner,
+ AdtId::EnumId(option),
+ interner.mk_args(&[GenericArg::from(item_ty)]),
+ );
+
+ Ty::new_adt(interner, AdtId::EnumId(poll), interner.mk_args(&[GenericArg::from(option_ty)]))
+ }
+
pub(super) fn infer_closure(
&mut self,
body: ExprId,
@@ -121,10 +137,22 @@ impl<'db> InferenceContext<'_, 'db> {
(Ty::new_closure(interner, closure_id.into(), closure_args.args), None)
}
- ClosureKind::Coroutine(_) | ClosureKind::AsyncBlock { .. } => {
+ ClosureKind::OldCoroutine(_) | ClosureKind::Coroutine { .. } => {
let yield_ty = match closure_kind {
- ClosureKind::Coroutine(_) => self.table.next_ty_var(closure_expr.into()),
- ClosureKind::AsyncBlock { .. } => self.types.types.unit,
+ ClosureKind::OldCoroutine(_)
+ | ClosureKind::Coroutine { kind: CoroutineKind::Gen, .. } => {
+ let yield_ty = self.table.next_ty_var(closure_expr.into());
+ self.require_type_is_sized(yield_ty);
+ yield_ty
+ }
+ ClosureKind::Coroutine { kind: CoroutineKind::Async, .. } => {
+ self.types.types.unit
+ }
+ ClosureKind::Coroutine { kind: CoroutineKind::AsyncGen, .. } => {
+ let yield_ty = self.table.next_ty_var(closure_expr.into());
+ self.require_type_is_sized(yield_ty);
+ self.poll_option_ty(yield_ty)
+ }
_ => unreachable!(),
};
@@ -137,7 +165,7 @@ impl<'db> InferenceContext<'_, 'db> {
// later during upvar analysis. Regular coroutines always have the kind
// ty of `().`
let kind_ty = match closure_kind {
- ClosureKind::AsyncBlock { source: CoroutineSource::Closure } => {
+ ClosureKind::Coroutine { source: CoroutineSource::Closure, .. } => {
self.table.next_ty_var(closure_expr.into())
}
_ => self.types.types.unit,
@@ -163,11 +191,20 @@ impl<'db> InferenceContext<'_, 'db> {
Some((resume_ty, yield_ty)),
)
}
- ClosureKind::AsyncClosure => {
- // async closures always return the type ascribed after the `->` (if present),
- // and yield `()`.
- let (bound_return_ty, bound_yield_ty) =
- (bound_sig.skip_binder().output(), self.types.types.unit);
+ ClosureKind::CoroutineClosure(coroutine_kind) => {
+ let (bound_return_ty, bound_yield_ty) = match coroutine_kind {
+ CoroutineKind::Gen => {
+ (self.types.types.unit, self.table.next_ty_var(closure_expr.into()))
+ }
+ CoroutineKind::Async => {
+ (bound_sig.skip_binder().output(), self.types.types.unit)
+ }
+ CoroutineKind::AsyncGen => {
+ let yield_ty = self.table.next_ty_var(closure_expr.into());
+ (self.types.types.unit, self.poll_option_ty(yield_ty))
+ }
+ };
+
// Compute all of the variables that will be used to populate the coroutine.
let resume_ty = self.table.next_ty_var(closure_expr.into());
@@ -354,9 +391,9 @@ impl<'db> InferenceContext<'_, 'db> {
let expected_sig = sig_tys.with(hdr);
(Some(expected_sig), Some(rustc_type_ir::ClosureKind::Fn))
}
- ClosureKind::Coroutine(_)
- | ClosureKind::AsyncClosure
- | ClosureKind::AsyncBlock { .. } => (None, None),
+ ClosureKind::OldCoroutine(_)
+ | ClosureKind::Coroutine { .. }
+ | ClosureKind::CoroutineClosure(_) => (None, None),
},
_ => (None, None),
}
@@ -469,8 +506,10 @@ impl<'db> InferenceContext<'_, 'db> {
if let Some(trait_def_id) = trait_def_id {
let found_kind = match closure_kind {
- ClosureKind::Closure => self.fn_trait_kind_from_def_id(trait_def_id),
- ClosureKind::AsyncClosure => self
+ ClosureKind::Closure | ClosureKind::CoroutineClosure(CoroutineKind::Gen) => {
+ self.fn_trait_kind_from_def_id(trait_def_id)
+ }
+ ClosureKind::CoroutineClosure(CoroutineKind::Async) => self
.async_fn_trait_kind_from_def_id(trait_def_id)
.or_else(|| self.fn_trait_kind_from_def_id(trait_def_id)),
_ => None,
@@ -517,13 +556,17 @@ impl<'db> InferenceContext<'_, 'db> {
ClosureKind::Closure if Some(def_id) == self.lang_items.FnOnceOutput => {
self.extract_sig_from_projection(projection)
}
- ClosureKind::AsyncClosure if Some(def_id) == self.lang_items.AsyncFnOnceOutput => {
+ ClosureKind::CoroutineClosure(CoroutineKind::Async)
+ if Some(def_id) == self.lang_items.AsyncFnOnceOutput =>
+ {
self.extract_sig_from_projection(projection)
}
// It's possible we've passed the closure to a (somewhat out-of-fashion)
// `F: FnOnce() -> Fut, Fut: Future<Output = T>` style bound. Let's still
// guide inference here, since it's beneficial for the user.
- ClosureKind::AsyncClosure if Some(def_id) == self.lang_items.FnOnceOutput => {
+ ClosureKind::CoroutineClosure(CoroutineKind::Async)
+ if Some(def_id) == self.lang_items.FnOnceOutput =>
+ {
self.extract_sig_from_projection_and_future_bound(closure_expr, projection)
}
_ => None,
diff --git a/crates/hir-ty/src/infer/closure/analysis.rs b/crates/hir-ty/src/infer/closure/analysis.rs
index 31b6252475..79bdc6cea1 100644
--- a/crates/hir-ty/src/infer/closure/analysis.rs
+++ b/crates/hir-ty/src/infer/closure/analysis.rs
@@ -284,7 +284,7 @@ impl<'a, 'db> InferenceContext<'a, 'db> {
// coroutine-closures that are `move` since otherwise they themselves will
// be borrowing from the outer environment, so there's no self-borrows occurring.
if let UpvarArgs::Coroutine(..) = args
- && let hir_def::hir::ClosureKind::AsyncBlock { source: CoroutineSource::Closure } =
+ && let hir_def::hir::ClosureKind::Coroutine { source: CoroutineSource::Closure, .. } =
closure_kind
&& let parent_hir_id = ExpressionStore::closure_for_coroutine(closure_expr_id)
&& let parent_ty = self.result.expr_ty(parent_hir_id)
@@ -310,8 +310,9 @@ impl<'a, 'db> InferenceContext<'a, 'db> {
//
// FIXME(async_closures): This could be cleaned up. It's a bit janky that we're just
// moving all of the `LocalSource::AsyncFn` locals here.
- if let hir_def::hir::ClosureKind::AsyncBlock {
+ if let hir_def::hir::ClosureKind::Coroutine {
source: CoroutineSource::Fn | CoroutineSource::Closure,
+ ..
} = closure_kind
{
let Expr::Block { statements, .. } = &self.store[body] else {
diff --git a/crates/hir-ty/src/infer/expr.rs b/crates/hir-ty/src/infer/expr.rs
index a6c8cda404..2b19c445ed 100644
--- a/crates/hir-ty/src/infer/expr.rs
+++ b/crates/hir-ty/src/infer/expr.rs
@@ -2129,7 +2129,7 @@ impl<'db> InferenceContext<'_, 'db> {
// closure wrapped in a block.
// See <https://github.com/rust-lang/rust/issues/112225>.
let is_closure = if let Expr::Closure { closure_kind, .. } = self.store[*arg] {
- !matches!(closure_kind, ClosureKind::Coroutine(_))
+ !matches!(closure_kind, ClosureKind::OldCoroutine(_))
} else {
false
};
diff --git a/crates/hir-ty/src/next_solver/interner.rs b/crates/hir-ty/src/next_solver/interner.rs
index af3798f1da..14e20dfe80 100644
--- a/crates/hir-ty/src/next_solver/interner.rs
+++ b/crates/hir-ty/src/next_solver/interner.rs
@@ -1318,11 +1318,11 @@ impl<'db> Interner for DbInterner<'db> {
let expr = &store[expr_id];
match *expr {
hir_def::hir::Expr::Closure { closure_kind, .. } => match closure_kind {
- hir_def::hir::ClosureKind::Coroutine(movability) => match movability {
+ hir_def::hir::ClosureKind::OldCoroutine(movability) => match movability {
hir_def::hir::Movability::Static => rustc_ast_ir::Movability::Static,
hir_def::hir::Movability::Movable => rustc_ast_ir::Movability::Movable,
},
- hir_def::hir::ClosureKind::AsyncBlock { .. } => rustc_ast_ir::Movability::Static,
+ hir_def::hir::ClosureKind::Coroutine { .. } => rustc_ast_ir::Movability::Static,
_ => panic!("unexpected expression for a coroutine: {expr:?}"),
},
_ => panic!("unexpected expression for a coroutine: {expr:?}"),
@@ -1560,7 +1560,6 @@ impl<'db> Interner for DbInterner<'db> {
ignore = {
AsyncFnKindHelper,
- AsyncIterator,
BikeshedGuaranteedNoDrop,
FusedIterator,
Field,
@@ -1587,6 +1586,7 @@ impl<'db> Interner for DbInterner<'db> {
Unpin,
Tuple,
Iterator,
+ AsyncIterator,
AsyncFn,
AsyncFnMut,
AsyncFnOnce,
@@ -1652,7 +1652,6 @@ impl<'db> Interner for DbInterner<'db> {
ignore = {
AsyncFnKindHelper,
- AsyncIterator,
BikeshedGuaranteedNoDrop,
FusedIterator,
Field,
@@ -1679,6 +1678,7 @@ impl<'db> Interner for DbInterner<'db> {
Unpin,
Tuple,
Iterator,
+ AsyncIterator,
AsyncFn,
AsyncFnMut,
AsyncFnOnce,
@@ -1943,7 +1943,7 @@ impl<'db> Interner for DbInterner<'db> {
matches!(
store[expr_id],
hir_def::hir::Expr::Closure {
- closure_kind: hir_def::hir::ClosureKind::Coroutine(_),
+ closure_kind: hir_def::hir::ClosureKind::OldCoroutine(_),
..
}
)
@@ -1957,20 +1957,43 @@ impl<'db> Interner for DbInterner<'db> {
matches!(
store[expr_id],
hir_def::hir::Expr::Closure {
- closure_kind: hir_def::hir::ClosureKind::AsyncBlock { .. },
+ closure_kind: hir_def::hir::ClosureKind::Coroutine {
+ kind: hir_def::hir::CoroutineKind::Async,
+ ..
+ },
..
}
)
}
- fn coroutine_is_gen(self, _coroutine_def_id: Self::CoroutineId) -> bool {
- // We don't handle gen coroutines yet.
- false
+ fn coroutine_is_gen(self, def_id: Self::CoroutineId) -> bool {
+ let InternedClosure(owner, expr_id) = def_id.0.loc(self.db);
+ let store = ExpressionStore::of(self.db, owner);
+ matches!(
+ store[expr_id],
+ hir_def::hir::Expr::Closure {
+ closure_kind: hir_def::hir::ClosureKind::Coroutine {
+ kind: hir_def::hir::CoroutineKind::Gen,
+ ..
+ },
+ ..
+ }
+ )
}
- fn coroutine_is_async_gen(self, _coroutine_def_id: Self::CoroutineId) -> bool {
- // We don't handle gen coroutines yet.
- false
+ fn coroutine_is_async_gen(self, def_id: Self::CoroutineId) -> bool {
+ let InternedClosure(owner, expr_id) = def_id.0.loc(self.db);
+ let store = ExpressionStore::of(self.db, owner);
+ matches!(
+ store[expr_id],
+ hir_def::hir::Expr::Closure {
+ closure_kind: hir_def::hir::ClosureKind::Coroutine {
+ kind: hir_def::hir::CoroutineKind::AsyncGen,
+ ..
+ },
+ ..
+ }
+ )
}
fn unsizing_params_for_adt(self, id: Self::AdtId) -> Self::UnsizingParams {
@@ -2084,8 +2107,8 @@ impl<'db> Interner for DbInterner<'db> {
if matches!(
expr,
hir_def::hir::Expr::Closure {
- closure_kind: hir_def::hir::ClosureKind::AsyncBlock { .. }
- | hir_def::hir::ClosureKind::Coroutine(_),
+ closure_kind: hir_def::hir::ClosureKind::Coroutine { .. }
+ | hir_def::hir::ClosureKind::OldCoroutine(_),
..
}
) {
diff --git a/crates/hir-ty/src/next_solver/ty.rs b/crates/hir-ty/src/next_solver/ty.rs
index 511259ecd8..3bd20e9064 100644
--- a/crates/hir-ty/src/next_solver/ty.rs
+++ b/crates/hir-ty/src/next_solver/ty.rs
@@ -12,10 +12,10 @@ use macros::GenericTypeVisitable;
use rustc_abi::{Float, Integer, Size};
use rustc_ast_ir::{Mutability, try_visit, visit::VisitorResult};
use rustc_type_ir::{
- AliasTyKind, BoundVar, BoundVarIndexKind, ClosureKind, CoroutineArgs, CoroutineArgsParts,
- DebruijnIndex, FlagComputation, Flags, FloatTy, FloatVid, GenericTypeVisitable, InferTy, IntTy,
- IntVid, Interner, TyVid, TypeFoldable, TypeSuperFoldable, TypeSuperVisitable, TypeVisitable,
- TypeVisitableExt, TypeVisitor, UintTy, Upcast, WithCachedTypeInfo,
+ AliasTyKind, BoundVar, BoundVarIndexKind, ClosureKind, DebruijnIndex, FlagComputation, Flags,
+ FloatTy, FloatVid, GenericTypeVisitable, InferTy, IntTy, IntVid, Interner, TyVid, TypeFoldable,
+ TypeSuperFoldable, TypeSuperVisitable, TypeVisitable, TypeVisitableExt, TypeVisitor, UintTy,
+ Upcast, WithCachedTypeInfo,
inherent::{
AdtDef as _, BoundExistentialPredicates, GenericArgs as _, IntoKind, ParamLike,
Safety as _, SliceLike, Ty as _,
@@ -575,23 +575,13 @@ impl<'db> Ty<'db> {
}
TyKind::CoroutineClosure(coroutine_id, args) => {
Some(args.as_coroutine_closure().coroutine_closure_sig().map_bound(|sig| {
- let unit_ty = Ty::new_unit(interner);
- let return_ty = Ty::new_coroutine(
+ let closure_args = args.as_coroutine_closure();
+ let return_ty = sig.to_coroutine(
interner,
+ closure_args.parent_args(),
+ closure_args.kind_ty(),
interner.coroutine_for_closure(coroutine_id),
- CoroutineArgs::new(
- interner,
- CoroutineArgsParts {
- parent_args: args.as_coroutine_closure().parent_args(),
- kind_ty: unit_ty,
- resume_ty: unit_ty,
- yield_ty: unit_ty,
- return_ty: sig.return_ty,
- // FIXME: Deduce this from the coroutine closure's upvars.
- tupled_upvars_ty: unit_ty,
- },
- )
- .args,
+ closure_args.tupled_upvars_ty(),
);
FnSig {
inputs_and_output: Tys::new_from_iter(
diff --git a/crates/hir-ty/src/tests/coercion.rs b/crates/hir-ty/src/tests/coercion.rs
index cc3464cd7f..05dbb1a8ac 100644
--- a/crates/hir-ty/src/tests/coercion.rs
+++ b/crates/hir-ty/src/tests/coercion.rs
@@ -392,6 +392,34 @@ fn test() {
}
#[test]
+fn gen_yield_coerce() {
+ check_no_mismatches(
+ r#"
+fn test() {
+ let g = gen {
+ yield &1u32;
+ yield &&1u32;
+ };
+}
+ "#,
+ );
+}
+
+#[test]
+fn async_gen_yield_coerce() {
+ check_no_mismatches(
+ r#"
+fn test() {
+ let g = async gen {
+ yield &1u32;
+ yield &&1u32;
+ };
+}
+ "#,
+ );
+}
+
+#[test]
fn assign_coerce() {
check_no_mismatches(
r"
diff --git a/crates/hir-ty/src/tests/simple.rs b/crates/hir-ty/src/tests/simple.rs
index fb46e4b58b..1e75c31fa1 100644
--- a/crates/hir-ty/src/tests/simple.rs
+++ b/crates/hir-ty/src/tests/simple.rs
@@ -2066,6 +2066,138 @@ fn test() {
}
#[test]
+fn gen_block_types_inferred() {
+ check_infer(
+ r#"
+//- minicore: iterator, deref
+use core::iter::Iterator;
+
+fn test() {
+ let mut generator = gen {
+ yield 1i8;
+ };
+ let result = generator.next();
+}
+ "#,
+ expect![[r#"
+ 37..131 '{ ...t(); }': ()
+ 47..60 'mut generator': impl Iterator<Item = i8>
+ 63..93 'gen { ... }': impl Iterator<Item = i8>
+ 77..86 'yield 1i8': ()
+ 83..86 '1i8': i8
+ 103..109 'result': Option<i8>
+ 112..121 'generator': impl Iterator<Item = i8>
+ 112..128 'genera...next()': Option<i8>
+ "#]],
+ );
+}
+
+#[test]
+fn async_gen_block_types_inferred() {
+ check_infer(
+ r#"
+//- minicore: async_iterator, option, future, deref, pin
+use core::async_iter::AsyncIterator;
+use core::pin::Pin;
+use core::task::Context;
+
+fn test(mut cx: Context<'_>) {
+ let mut generator = async gen {
+ yield 1i8;
+ };
+ let result = Pin::new(&mut generator).poll_next(&mut cx);
+}
+ "#,
+ expect![[r#"
+ 91..97 'mut cx': Context<'?>
+ 112..239 '{ ...cx); }': ()
+ 122..135 'mut generator': impl AsyncIterator<Item = {unknown}>
+ 138..174 'async ... }': impl AsyncIterator<Item = {unknown}>
+ 158..167 'yield 1i8': ()
+ 164..167 '1i8': i8
+ 184..190 'result': Poll<Option<{unknown}>>
+ 193..201 'Pin::new': fn new<&'? mut impl AsyncIterator<Item = {unknown}>>(&'? mut impl AsyncIterator<Item = {unknown}>) -> Pin<&'? mut impl AsyncIterator<Item = {unknown}>>
+ 193..217 'Pin::n...rator)': Pin<&'? mut impl AsyncIterator<Item = {unknown}>>
+ 193..236 'Pin::n...ut cx)': Poll<Option<{unknown}>>
+ 202..216 '&mut generator': &'? mut impl AsyncIterator<Item = {unknown}>
+ 207..216 'generator': impl AsyncIterator<Item = {unknown}>
+ 228..235 '&mut cx': &'? mut Context<'?>
+ 233..235 'cx': Context<'?>
+ "#]],
+ );
+}
+
+#[test]
+fn gen_fn_types_inferred() {
+ check_infer(
+ r#"
+//- minicore: iterator, deref
+use core::iter::Iterator;
+
+gen fn html() {
+ yield ();
+}
+
+fn test() {
+ let mut generator = html();
+ let result = generator.next();
+}
+ "#,
+ expect![[r#"
+ 41..58 '{ ... (); }': ()
+ 47..55 'yield ()': ()
+ 53..55 '()': ()
+ 70..140 '{ ...t(); }': ()
+ 80..93 'mut generator': impl Iterator<Item = ()>
+ 96..100 'html': fn html() -> impl Iterator<Item = ()>
+ 96..102 'html()': impl Iterator<Item = ()>
+ 112..118 'result': Option<()>
+ 121..130 'generator': impl Iterator<Item = ()>
+ 121..137 'genera...next()': Option<()>
+ "#]],
+ );
+}
+
+#[test]
+fn async_gen_fn_types_inferred() {
+ check_infer(
+ r#"
+//- minicore: async_iterator, option, future, deref, pin
+use core::async_iter::AsyncIterator;
+use core::pin::Pin;
+use core::task::Context;
+
+async gen fn html() {
+ yield ();
+}
+
+fn test(mut cx: Context<'_>) {
+ let mut generator = html();
+ let result = Pin::new(&mut generator).poll_next(&mut cx);
+}
+ "#,
+ expect![[r#"
+ 103..120 '{ ... (); }': ()
+ 109..117 'yield ()': ()
+ 115..117 '()': ()
+ 130..136 'mut cx': Context<'?>
+ 151..248 '{ ...cx); }': ()
+ 161..174 'mut generator': impl AsyncIterator<Item = ()>
+ 177..181 'html': fn html() -> impl AsyncIterator<Item = ()>
+ 177..183 'html()': impl AsyncIterator<Item = ()>
+ 193..199 'result': Poll<Option<()>>
+ 202..210 'Pin::new': fn new<&'? mut impl AsyncIterator<Item = ()>>(&'? mut impl AsyncIterator<Item = ()>) -> Pin<&'? mut impl AsyncIterator<Item = ()>>
+ 202..226 'Pin::n...rator)': Pin<&'? mut impl AsyncIterator<Item = ()>>
+ 202..245 'Pin::n...ut cx)': Poll<Option<()>>
+ 211..225 '&mut generator': &'? mut impl AsyncIterator<Item = ()>
+ 216..225 'generator': impl AsyncIterator<Item = ()>
+ 237..244 '&mut cx': &'? mut Context<'?>
+ 242..244 'cx': Context<'?>
+ "#]],
+ );
+}
+
+#[test]
fn tuple_pattern_nested_match_ergonomics() {
check_no_mismatches(
r#"
diff --git a/crates/hir-ty/src/tests/traits.rs b/crates/hir-ty/src/tests/traits.rs
index 278666ef35..bcb5e5de16 100644
--- a/crates/hir-ty/src/tests/traits.rs
+++ b/crates/hir-ty/src/tests/traits.rs
@@ -124,6 +124,25 @@ async fn test() {
}
#[test]
+fn infer_async_gen_closure() {
+ check(
+ r#"
+//- minicore: async_iterator, fn, add, builtin_impls
+//- /main.rs edition:2024
+fn test() {
+ let f = async gen move |x: i32| {
+ yield x + 42;
+ //^^^^^^ expected Poll<Option<{unknown}>>, got i32
+ };
+ let a = f(4);
+ a;
+// ^ type: impl AsyncIterator<Item = {unknown}>
+}
+"#,
+ );
+}
+
+#[test]
fn auto_sized_async_block() {
check_no_mismatches(
r#"
diff --git a/crates/hir/src/display.rs b/crates/hir/src/display.rs
index 53f24713cd..139f078eef 100644
--- a/crates/hir/src/display.rs
+++ b/crates/hir/src/display.rs
@@ -163,6 +163,9 @@ fn write_function<'db>(f: &mut HirFormatter<'_, 'db>, func_id: FunctionId) -> Re
if data.is_async() {
f.write_str("async ")?;
}
+ if data.is_gen() {
+ f.write_str("gen ")?;
+ }
// FIXME: This will show `unsafe` for functions that are `#[target_feature]` but not unsafe
// (they are conditionally unsafe to call). We probably should show something else.
if func.is_unsafe_to_call(db, None, f.edition()) {
@@ -223,7 +226,7 @@ fn write_function<'db>(f: &mut HirFormatter<'_, 'db>, func_id: FunctionId) -> Re
// `FunctionData::ret_type` will be `::core::future::Future<Output = ...>` for async fns.
// Use ugly pattern match to strip the Future trait.
// Better way?
- let ret_type = if !data.is_async() {
+ let ret_type = if !data.is_async() && !data.is_gen() {
data.ret_type
} else if let Some(ret_type) = data.ret_type {
match &data.store[ret_type] {
diff --git a/crates/ide/src/references.rs b/crates/ide/src/references.rs
index 4ed3d1c7d7..a2b317be58 100644
--- a/crates/ide/src/references.rs
+++ b/crates/ide/src/references.rs
@@ -599,7 +599,7 @@ fn main() {
false,
false,
expect![[r#"
- Some Variant FileId(1) 5999..6031 6024..6028
+ Some Variant FileId(1) 6022..6054 6047..6051
FileId(0) 46..50
"#]],
diff --git a/crates/intern/src/symbol/symbols.rs b/crates/intern/src/symbol/symbols.rs
index 25c2e3f733..c0053a3f21 100644
--- a/crates/intern/src/symbol/symbols.rs
+++ b/crates/intern/src/symbol/symbols.rs
@@ -128,6 +128,9 @@ define_symbols! {
as_str,
asm,
assert,
+ async_iter,
+ async_iterator,
+ AsyncIterator,
attr,
attributes,
begin_panic,
@@ -304,7 +307,6 @@ define_symbols! {
Iterator,
iterator,
fused_iterator,
- async_iterator,
keyword,
lang,
lang_items,
diff --git a/crates/test-utils/src/minicore.rs b/crates/test-utils/src/minicore.rs
index 4b3e7c4673..e9ab066160 100644
--- a/crates/test-utils/src/minicore.rs
+++ b/crates/test-utils/src/minicore.rs
@@ -11,6 +11,7 @@
//! add:
//! asm:
//! assert:
+//! async_iterator: option, future, pin
//! as_mut: sized
//! as_ref: sized
//! async_fn: fn, tuple, future, copy
@@ -1531,6 +1532,7 @@ pub mod slice {
// region:option
pub mod option {
+ #[lang = "Option"]
pub enum Option<T> {
#[lang = "None"]
None,
@@ -1680,6 +1682,7 @@ pub mod future {
}
}
pub mod task {
+ #[lang = "Poll"]
pub enum Poll<T> {
#[lang = "Ready"]
Ready(T),
@@ -1693,6 +1696,22 @@ pub mod task {
}
// endregion:future
+// region:async_iterator
+pub mod async_iter {
+ use crate::{
+ pin::Pin,
+ task::{Context, Poll},
+ };
+
+ #[lang = "async_iterator"]
+ pub trait AsyncIterator {
+ type Item;
+
+ fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>>;
+ }
+}
+// endregion:async_iterator
+
// region:iterator
pub mod iter {
// region:iterators