Unnamed repository; edit this file 'description' to name the repository.
Merge pull request #22050 from ChayimFriedman2/make-asyncs-closures
internal: Represent lowered coroutines with closures
26 files changed, 483 insertions, 308 deletions
diff --git a/crates/hir-def/src/expr_store.rs b/crates/hir-def/src/expr_store.rs index 62a17168b1..75278f778b 100644 --- a/crates/hir-def/src/expr_store.rs +++ b/crates/hir-def/src/expr_store.rs @@ -642,9 +642,7 @@ impl ExpressionStore { self.walk_exprs_in_pat(*pat, &mut f); f(*expr); } - Expr::Block { statements, tail, .. } - | Expr::Unsafe { statements, tail, .. } - | Expr::Async { statements, tail, .. } => { + Expr::Block { statements, tail, .. } | Expr::Unsafe { statements, tail, .. } => { for stmt in statements.iter() { match stmt { Statement::Let { initializer, else_branch, pat, .. } => { @@ -777,9 +775,7 @@ impl ExpressionStore { Expr::Let { expr, .. } => { f(*expr); } - Expr::Block { statements, tail, .. } - | Expr::Unsafe { statements, tail, .. } - | Expr::Async { statements, tail, .. } => { + Expr::Block { statements, tail, .. } | Expr::Unsafe { statements, tail, .. } => { for stmt in statements.iter() { match stmt { Statement::Let { initializer, else_branch, .. } => { @@ -923,6 +919,13 @@ impl ExpressionStore { None => const { &Arena::new() }.iter(), } } + + /// The coroutine associated with a coroutine closure. + #[inline] + pub fn coroutine_for_closure(coroutine_closure: ExprId) -> ExprId { + // We keep the async closure exactly one expr before. + ExprId::from_raw(la_arena::RawIdx::from_u32(coroutine_closure.into_raw().into_u32() - 1)) + } } impl Index<ExprId> for ExpressionStore { diff --git a/crates/hir-def/src/expr_store/lower.rs b/crates/hir-def/src/expr_store/lower.rs index 7fe91a3d02..8a23ea69b8 100644 --- a/crates/hir-def/src/expr_store/lower.rs +++ b/crates/hir-def/src/expr_store/lower.rs @@ -46,8 +46,9 @@ use crate::{ }, hir::{ Array, Binding, BindingAnnotation, BindingId, BindingProblems, CaptureBy, ClosureKind, - Expr, ExprId, Item, Label, LabelId, Literal, MatchArm, Movability, OffsetOf, Pat, PatId, - RecordFieldPat, RecordLitField, RecordSpread, Statement, generics::GenericParams, + CoroutineSource, Expr, ExprId, Item, Label, LabelId, Literal, MatchArm, Movability, + OffsetOf, Pat, PatId, RecordFieldPat, RecordLitField, RecordSpread, Statement, + generics::GenericParams, }, item_scope::BuiltinShadowMode, item_tree::FieldsShape, @@ -978,11 +979,33 @@ impl<'db> ExprCollector<'db> { *param = pat_id; } - self.alloc_expr_desugared(Expr::Async { - id: None, - statements: statements.into_boxed_slice(), - tail: Some(body), - }) + let async_ = self.async_block( + CoroutineSource::Fn, + CaptureBy::Value, + None, + statements.into_boxed_slice(), + Some(body), + ); + self.alloc_expr_desugared(async_) + } + + fn async_block( + &mut self, + source: CoroutineSource, + capture_by: CaptureBy, + id: Option<BlockId>, + statements: Box<[Statement]>, + tail: Option<ExprId>, + ) -> Expr { + let block = self.alloc_expr_desugared(Expr::Block { label: None, id, statements, tail }); + Expr::Closure { + args: Box::default(), + arg_types: Box::default(), + ret_type: None, + body: block, + closure_kind: ClosureKind::AsyncBlock { source }, + capture_by, + } } fn collect( @@ -1126,7 +1149,7 @@ impl<'db> ExprCollector<'db> { self.desugar_try_block(e, result_type) } Some(ast::BlockModifier::Unsafe(_)) => { - self.collect_block_(e, |id, statements, tail| Expr::Unsafe { + self.collect_block_(e, |_, id, statements, tail| Expr::Unsafe { id, statements, tail, @@ -1136,7 +1159,7 @@ impl<'db> ExprCollector<'db> { let label_hygiene = self.hygiene_id_for(label.syntax().text_range()); let label_id = self.collect_label(label); self.with_labeled_rib(label_id, label_hygiene, |this| { - this.collect_block_(e, |id, statements, tail| Expr::Block { + this.collect_block_(e, |_, id, statements, tail| Expr::Block { id, statements, tail, @@ -1145,12 +1168,18 @@ impl<'db> ExprCollector<'db> { }) } Some(ast::BlockModifier::Async(_)) => { + let capture_by = + if e.move_token().is_some() { CaptureBy::Value } else { CaptureBy::Ref }; self.with_label_rib(RibKind::Closure, |this| { this.with_awaitable_block(Awaitable::Yes, |this| { - this.collect_block_(e, |id, statements, tail| Expr::Async { - id, - statements, - tail, + this.collect_block_(e, |this, id, statements, tail| { + this.async_block( + CoroutineSource::Block, + capture_by, + id, + statements, + tail, + ) }) }) }) @@ -1406,7 +1435,7 @@ impl<'db> ExprCollector<'db> { } else { Awaitable::No("non-async closure") }; - let body = this + let mut body = this .with_awaitable_block(awaitable, |this| this.collect_expr_opt(e.body())); let closure_kind = if this.is_lowering_coroutine { @@ -1417,7 +1446,22 @@ impl<'db> ExprCollector<'db> { }; ClosureKind::Coroutine(movability) } else if e.async_token().is_some() { - ClosureKind::Async + // It's important that this expr is allocated immediately before the closure. + // We rely on it for `coroutine_for_closure()`. + body = this.alloc_expr_desugared(Expr::Closure { + args: Box::default(), + arg_types: Box::default(), + ret_type: None, + body, + closure_kind: ClosureKind::AsyncBlock { + source: CoroutineSource::Closure, + }, + // The block may need to capture by move, but we cannot know it now. + // It will be fixed in capture analysis. + capture_by: CaptureBy::Ref, + }); + + ClosureKind::AsyncClosure } else { ClosureKind::Closure }; @@ -1762,7 +1806,7 @@ impl<'db> ExprCollector<'db> { let ptr = AstPtr::new(&e).upcast(); let (btail, expr_id) = self.with_labeled_rib(label, HygieneId::ROOT, |this| { let mut btail = None; - let block = this.collect_block_(e, |id, statements, tail| { + let block = this.collect_block_(e, |_, id, statements, tail| { btail = tail; Expr::Block { id, statements, tail, label: Some(label) } }); @@ -2220,7 +2264,7 @@ impl<'db> ExprCollector<'db> { } fn collect_block(&mut self, block: ast::BlockExpr) -> ExprId { - self.collect_block_(block, |id, statements, tail| Expr::Block { + self.collect_block_(block, |_, id, statements, tail| Expr::Block { id, statements, tail, @@ -2231,7 +2275,7 @@ impl<'db> ExprCollector<'db> { fn collect_block_( &mut self, block: ast::BlockExpr, - mk_block: impl FnOnce(Option<BlockId>, Box<[Statement]>, Option<ExprId>) -> Expr, + mk_block: impl FnOnce(&mut Self, Option<BlockId>, Box<[Statement]>, Option<ExprId>) -> Expr, ) -> ExprId { let block_id = self.expander.ast_id_map().ast_id_for_block(&block).map(|file_local_id| { let ast_id = self.expander.in_file(file_local_id); @@ -2266,8 +2310,8 @@ impl<'db> ExprCollector<'db> { }); let syntax_node_ptr = AstPtr::new(&block.into()); - let expr_id = self - .alloc_expr(mk_block(block_id, statements.into_boxed_slice(), tail), syntax_node_ptr); + let expr = mk_block(self, block_id, statements.into_boxed_slice(), tail); + let expr_id = self.alloc_expr(expr, syntax_node_ptr); self.def_map = prev_def_map; self.module = prev_local_module; diff --git a/crates/hir-def/src/expr_store/pretty.rs b/crates/hir-def/src/expr_store/pretty.rs index 9c9c4db3b2..71d59c904d 100644 --- a/crates/hir-def/src/expr_store/pretty.rs +++ b/crates/hir-def/src/expr_store/pretty.rs @@ -9,6 +9,7 @@ use std::{ use hir_expand::{Lookup, mod_path::PathKind}; use itertools::Itertools; use span::Edition; +use stdx::never; use syntax::ast::{HasName, RangeOp}; use crate::{ @@ -760,14 +761,31 @@ impl Printer<'_> { w!(self, "]"); } Expr::Closure { args, arg_types, ret_type, body, closure_kind, capture_by } => { + let mut body = *body; + let mut print_pipes = true; match closure_kind { ClosureKind::Coroutine(Movability::Static) => { w!(self, "static "); } - ClosureKind::Async => { + ClosureKind::AsyncClosure => { + if let Expr::Closure { + body: inner_body, + closure_kind: ClosureKind::AsyncBlock { .. }, + .. + } = self.store[body] + { + body = inner_body; + } else { + never!("async closure should always have an async block body"); + } + w!(self, "async "); } - _ => (), + ClosureKind::AsyncBlock { .. } => { + w!(self, "async "); + print_pipes = false; + } + ClosureKind::Closure | ClosureKind::Coroutine(Movability::Movable) => (), } match capture_by { CaptureBy::Value => { @@ -775,24 +793,26 @@ impl Printer<'_> { } CaptureBy::Ref => (), } - w!(self, "|"); - for (i, (pat, ty)) in args.iter().zip(arg_types.iter()).enumerate() { - if i != 0 { - w!(self, ", "); + if print_pipes { + w!(self, "|"); + for (i, (pat, ty)) in args.iter().zip(arg_types.iter()).enumerate() { + if i != 0 { + w!(self, ", "); + } + self.print_pat(*pat); + if let Some(ty) = ty { + w!(self, ": "); + self.print_type_ref(*ty); + } } - self.print_pat(*pat); - if let Some(ty) = ty { - w!(self, ": "); - self.print_type_ref(*ty); + w!(self, "|"); + if let Some(ret_ty) = ret_type { + w!(self, " -> "); + self.print_type_ref(*ret_ty); } + self.whitespace(); } - w!(self, "|"); - if let Some(ret_ty) = ret_type { - w!(self, " -> "); - self.print_type_ref(*ret_ty); - } - self.whitespace(); - self.print_expr(*body); + self.print_expr(body); } Expr::Tuple { exprs } => { w!(self, "("); @@ -832,9 +852,6 @@ impl Printer<'_> { Expr::Unsafe { id: _, statements, tail } => { self.print_block(Some("unsafe "), statements, tail); } - Expr::Async { id: _, statements, tail } => { - self.print_block(Some("async "), statements, tail); - } Expr::Const(id) => { w!(self, "const {{ /* {id:?} */ }}"); } diff --git a/crates/hir-def/src/expr_store/scope.rs b/crates/hir-def/src/expr_store/scope.rs index 9738ac5c44..c6ba0241b7 100644 --- a/crates/hir-def/src/expr_store/scope.rs +++ b/crates/hir-def/src/expr_store/scope.rs @@ -324,7 +324,7 @@ fn compute_expr_scopes( let mut scope = scopes.root_scope(); compute_expr_scopes(scopes, *id, &mut scope); } - Expr::Unsafe { id, statements, tail } | Expr::Async { id, statements, tail } => { + Expr::Unsafe { id, statements, tail } => { let mut scope = scopes.new_block_scope(*scope, *id, None); // Overwrite the old scope for the block expr, so that every block scope can be found // via the block itself (important for blocks that only contain items, no expressions). diff --git a/crates/hir-def/src/expr_store/tests/body.rs b/crates/hir-def/src/expr_store/tests/body.rs index 4e5f2ca893..6e711e3a38 100644 --- a/crates/hir-def/src/expr_store/tests/body.rs +++ b/crates/hir-def/src/expr_store/tests/body.rs @@ -652,12 +652,12 @@ fn async_fn_weird_param_patterns() { async fn main(&self, param1: i32, ref mut param2: i32, _: i32, param4 @ _: i32, 123: i32) {} "#, expect![[r#" - fn main(self, param1, mut param2, mut <ra@gennew>0, param4 @ _, mut <ra@gennew>1) async { - let ref mut param2 = param2; - let _ = <ra@gennew>0; - let 123 = <ra@gennew>1; - {} - }"#]], + fn main(self, param1, mut param2, mut <ra@gennew>0, param4 @ _, mut <ra@gennew>1) async move { + let ref mut param2 = param2; + let _ = <ra@gennew>0; + let 123 = <ra@gennew>1; + {} + }"#]], ) } diff --git a/crates/hir-def/src/hir.rs b/crates/hir-def/src/hir.rs index 7781a8fe54..4dd113d419 100644 --- a/crates/hir-def/src/hir.rs +++ b/crates/hir-def/src/hir.rs @@ -214,11 +214,6 @@ pub enum Expr { tail: Option<ExprId>, label: Option<LabelId>, }, - Async { - id: Option<BlockId>, - statements: Box<[Statement]>, - tail: Option<ExprId>, - }, Const(ExprId), // FIXME: Fold this into Block with an unsafe flag? Unsafe { @@ -339,7 +334,6 @@ impl Expr { | Expr::Block { .. } | Expr::Unsafe { .. } | Expr::Const(_) - | Expr::Async { .. } | Expr::If { .. } | Expr::Literal(_) | Expr::Loop { .. } @@ -534,7 +528,25 @@ pub enum InlineAsmRegOrRegClass { pub enum ClosureKind { Closure, Coroutine(Movability), - Async, + AsyncBlock { source: CoroutineSource }, + AsyncClosure, +} + +/// In the case of a coroutine created as part of an async/gen construct, +/// which kind of async/gen construct caused it to be created? +/// +/// This helps error messages but is also used to drive coercions in +/// type-checking (see #60424). +#[derive(Clone, PartialEq, Eq, Hash, Debug, Copy)] +pub enum CoroutineSource { + /// An explicit `async`/`gen` block written by the user. + Block, + + /// An explicit `async`/`gen` closure written by the user. + Closure, + + /// The `async`/`gen` block generated as the body of an async/gen function. + Fn, } #[derive(Debug, Clone, Copy, PartialEq, Eq)] diff --git a/crates/hir-ty/src/db.rs b/crates/hir-ty/src/db.rs index a0fb75397a..54c4b8d3ac 100644 --- a/crates/hir-ty/src/db.rs +++ b/crates/hir-ty/src/db.rs @@ -7,7 +7,7 @@ use hir_def::{ AdtId, BuiltinDeriveImplId, CallableDefId, ConstId, ConstParamId, DefWithBodyId, EnumVariantId, ExpressionStoreOwnerId, FunctionId, GenericDefId, ImplId, LifetimeParamId, LocalFieldId, StaticId, TraitId, TypeAliasId, VariantId, builtin_derive::BuiltinDeriveImplMethod, - db::DefDatabase, hir::ExprId, layout::TargetDataLayout, + db::DefDatabase, expr_store::ExpressionStore, hir::ExprId, layout::TargetDataLayout, }; use la_arena::ArenaMap; use salsa::plumbing::AsId; @@ -200,12 +200,6 @@ pub trait HirDatabase: DefDatabase + std::fmt::Debug { #[salsa::interned] fn intern_impl_trait_id(&self, id: ImplTraitId) -> InternedOpaqueTyId; - #[salsa::interned] - fn intern_closure(&self, id: InternedClosure) -> InternedClosureId; - - #[salsa::interned] - fn intern_coroutine(&self, id: InternedCoroutine) -> InternedCoroutineId; - #[salsa::invoke(crate::variance::variances_of)] #[salsa::transparent] fn variances_of<'db>(&'db self, def: GenericDefId) -> VariancesOf<'db>; @@ -238,17 +232,87 @@ pub struct InternedOpaqueTyId { #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)] pub struct InternedClosure(pub ExpressionStoreOwnerId, pub ExprId); -#[salsa_macros::interned(no_lifetime, debug, revisions = usize::MAX)] +#[salsa_macros::interned(constructor = new_impl, no_lifetime, debug, revisions = usize::MAX)] #[derive(PartialOrd, Ord)] pub struct InternedClosureId { pub loc: InternedClosure, } -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)] -pub struct InternedCoroutine(pub ExpressionStoreOwnerId, pub ExprId); +impl InternedClosureId { + #[inline] + pub fn new(db: &dyn HirDatabase, loc: InternedClosure) -> Self { + if cfg!(debug_assertions) { + let store = ExpressionStore::of(db, loc.0); + let expr = &store[loc.1]; + assert!( + matches!( + expr, + hir_def::hir::Expr::Closure { + closure_kind: hir_def::hir::ClosureKind::Closure, + .. + } + ), + "expected a closure, found {expr:?}" + ); + } + + Self::new_impl(db, loc) + } +} -#[salsa_macros::interned(no_lifetime, debug, revisions = usize::MAX)] +#[salsa_macros::interned(constructor = new_impl, no_lifetime, debug, revisions = usize::MAX)] #[derive(PartialOrd, Ord)] pub struct InternedCoroutineId { - pub loc: InternedCoroutine, + pub loc: InternedClosure, +} + +impl InternedCoroutineId { + #[inline] + pub fn new(db: &dyn HirDatabase, loc: InternedClosure) -> Self { + if cfg!(debug_assertions) { + let store = ExpressionStore::of(db, loc.0); + let expr = &store[loc.1]; + assert!( + matches!( + expr, + hir_def::hir::Expr::Closure { + closure_kind: hir_def::hir::ClosureKind::Coroutine(_) + | hir_def::hir::ClosureKind::AsyncBlock { .. }, + .. + } + ), + "expected a coroutine, found {expr:?}" + ); + } + + Self::new_impl(db, loc) + } +} + +#[salsa_macros::interned(constructor = new_impl, no_lifetime, debug, revisions = usize::MAX)] +#[derive(PartialOrd, Ord)] +pub struct InternedCoroutineClosureId { + pub loc: InternedClosure, +} + +impl InternedCoroutineClosureId { + #[inline] + pub fn new(db: &dyn HirDatabase, loc: InternedClosure) -> Self { + if cfg!(debug_assertions) { + let store = ExpressionStore::of(db, loc.0); + let expr = &store[loc.1]; + assert!( + matches!( + expr, + hir_def::hir::Expr::Closure { + closure_kind: hir_def::hir::ClosureKind::AsyncClosure, + .. + } + ), + "expected a coroutine closure, found {expr:?}" + ); + } + + Self::new_impl(db, loc) + } } diff --git a/crates/hir-ty/src/diagnostics/expr.rs b/crates/hir-ty/src/diagnostics/expr.rs index 33d9dd538d..6706e92fc1 100644 --- a/crates/hir-ty/src/diagnostics/expr.rs +++ b/crates/hir-ty/src/diagnostics/expr.rs @@ -146,7 +146,7 @@ impl<'db> ExprValidator<'db> { Expr::If { .. } => { self.check_for_unnecessary_else(id, expr); } - Expr::Block { .. } | Expr::Async { .. } | Expr::Unsafe { .. } => { + Expr::Block { .. } | Expr::Unsafe { .. } => { self.validate_block(expr); } _ => {} @@ -325,10 +325,7 @@ impl<'db> ExprValidator<'db> { } fn validate_block(&mut self, expr: &Expr) { - let (Expr::Block { statements, .. } - | Expr::Async { statements, .. } - | Expr::Unsafe { statements, .. }) = expr - else { + let (Expr::Block { statements, .. } | Expr::Unsafe { statements, .. }) = expr else { return; }; let pattern_arena = Arena::new(); diff --git a/crates/hir-ty/src/diagnostics/unsafe_check.rs b/crates/hir-ty/src/diagnostics/unsafe_check.rs index 09c648139c..ee33f7d158 100644 --- a/crates/hir-ty/src/diagnostics/unsafe_check.rs +++ b/crates/hir-ty/src/diagnostics/unsafe_check.rs @@ -406,7 +406,7 @@ impl<'db> UnsafeVisitor<'db> { }); return; } - Expr::Block { statements, .. } | Expr::Async { statements, .. } => { + Expr::Block { statements, .. } => { self.walk_pats_top( statements.iter().filter_map(|statement| match statement { &Statement::Let { pat, .. } => Some(pat), diff --git a/crates/hir-ty/src/display.rs b/crates/hir-ty/src/display.rs index 0c4e34db7d..7f1b1ecbd2 100644 --- a/crates/hir-ty/src/display.rs +++ b/crates/hir-ty/src/display.rs @@ -39,8 +39,7 @@ use rustc_apfloat::{ use rustc_ast_ir::FloatTy; use rustc_hash::FxHashSet; use rustc_type_ir::{ - AliasTyKind, BoundVarIndexKind, CoroutineArgsParts, CoroutineClosureArgsParts, RegionKind, - Upcast, + AliasTyKind, BoundVarIndexKind, CoroutineArgsParts, RegionKind, Upcast, inherent::{AdtDef, GenericArgs as _, IntoKind, Term as _, Ty as _, Tys as _}, }; use smallvec::SmallVec; @@ -49,7 +48,7 @@ use stdx::never; use crate::{ CallableDefId, FnAbi, ImplTraitId, InferenceResult, MemoryMap, ParamEnvAndCrate, consteval, - db::{HirDatabase, InternedClosure, InternedCoroutine}, + db::{HirDatabase, InternedClosure}, generics::generics, layout::Layout, lower::GenericPredicates, @@ -1349,7 +1348,7 @@ impl<'db> HirDisplay<'db> for Ty<'db> { } let sig = interner.signature_unclosure(substs.as_closure().sig(), Safety::Safe); let sig = sig.skip_binder(); - let InternedClosure(owner, _) = db.lookup_intern_closure(id); + let InternedClosure(owner, _) = id.loc(db); let infer = InferenceResult::of(db, owner); let (_, kind) = infer.closure_info(id); match f.closure_style { @@ -1403,26 +1402,16 @@ impl<'db> HirDisplay<'db> for Ty<'db> { } _ => (), } - let CoroutineClosureArgsParts { closure_kind_ty, signature_parts_ty, .. } = - args.split_coroutine_closure_args(); - let kind = closure_kind_ty.to_opt_closure_kind().unwrap(); + let kind = args.as_coroutine_closure().kind(); let kind = match kind { rustc_type_ir::ClosureKind::Fn => "AsyncFn", rustc_type_ir::ClosureKind::FnMut => "AsyncFnMut", rustc_type_ir::ClosureKind::FnOnce => "AsyncFnOnce", }; - let TyKind::FnPtr(coroutine_sig, _) = signature_parts_ty.kind() else { - unreachable!("invalid coroutine closure signature"); - }; + let coroutine_sig = args.as_coroutine_closure().coroutine_closure_sig(); let coroutine_sig = coroutine_sig.skip_binder(); - let coroutine_inputs = coroutine_sig.inputs(); - let TyKind::Tuple(coroutine_inputs) = coroutine_inputs[1].kind() else { - unreachable!("invalid coroutine closure signature"); - }; - let TyKind::Tuple(coroutine_output) = coroutine_sig.output().kind() else { - unreachable!("invalid coroutine closure signature"); - }; - let coroutine_output = coroutine_output.as_slice()[1]; + let coroutine_inputs = coroutine_sig.tupled_inputs_ty.tuple_fields(); + let coroutine_output = coroutine_sig.return_ty; match f.closure_style { ClosureStyle::ImplFn => write!(f, "impl {kind}(")?, ClosureStyle::RANotation => write!(f, "async |")?, @@ -1536,17 +1525,16 @@ impl<'db> HirDisplay<'db> for Ty<'db> { } TyKind::Infer(..) => write!(f, "_")?, TyKind::Coroutine(coroutine_id, subst) => { - let InternedCoroutine(owner, expr_id) = coroutine_id.0.loc(db); + let InternedClosure(owner, expr_id) = coroutine_id.0.loc(db); let CoroutineArgsParts { resume_ty, yield_ty, return_ty, .. } = subst.split_coroutine_args(); let body = ExpressionStore::of(db, owner); let expr = &body[expr_id]; match expr { hir_def::hir::Expr::Closure { - closure_kind: hir_def::hir::ClosureKind::Async, + closure_kind: hir_def::hir::ClosureKind::AsyncBlock { .. }, .. - } - | hir_def::hir::Expr::Async { .. } => { + } => { let future_trait = f.lang_items().Future; let output = future_trait.and_then(|t| { t.trait_items(db) diff --git a/crates/hir-ty/src/drop.rs b/crates/hir-ty/src/drop.rs index ddc4e4ce85..d41a06c167 100644 --- a/crates/hir-ty/src/drop.rs +++ b/crates/hir-ty/src/drop.rs @@ -133,7 +133,7 @@ fn has_drop_glue_impl<'db>( } TyKind::Slice(ty) => has_drop_glue_impl(infcx, ty, env, visited), TyKind::Closure(closure_id, subst) => { - let owner = db.lookup_intern_closure(closure_id.0).0; + let owner = closure_id.0.loc(db).0; let infer = InferenceResult::of(db, owner); let (captures, _) = infer.closure_info(closure_id.0); let env = db.trait_environment(owner); diff --git a/crates/hir-ty/src/infer/closure.rs b/crates/hir-ty/src/infer/closure.rs index ce99016470..2cb936fec3 100644 --- a/crates/hir-ty/src/infer/closure.rs +++ b/crates/hir-ty/src/infer/closure.rs @@ -6,7 +6,7 @@ use std::{iter, mem, ops::ControlFlow}; use hir_def::{ TraitId, - hir::{ClosureKind, ExprId, PatId}, + hir::{ClosureKind, CoroutineSource, ExprId, PatId}, type_ref::TypeRefId, }; use rustc_type_ir::{ @@ -19,11 +19,11 @@ use tracing::debug; use crate::{ FnAbi, - db::{InternedClosure, InternedCoroutine}, + db::{InternedClosure, InternedClosureId, InternedCoroutineClosureId, InternedCoroutineId}, infer::{BreakableKind, Diverges, coerce::CoerceMany}, next_solver::{ AliasTy, Binder, ClauseKind, DbInterner, ErrorGuaranteed, FnSig, GenericArgs, PolyFnSig, - PolyProjectionPredicate, Predicate, PredicateKind, SolverDefId, Ty, TyKind, + PolyProjectionPredicate, Predicate, PredicateKind, SolverDefId, Ty, TyKind, Tys, abi::Safety, infer::{ BoundRegionConversionTime, InferOk, InferResult, @@ -54,52 +54,47 @@ impl<'db> InferenceContext<'_, 'db> { ret_type: Option<TypeRefId>, arg_types: &[Option<TypeRefId>], closure_kind: ClosureKind, - tgt_expr: ExprId, + closure_expr: ExprId, expected: &Expectation<'db>, ) -> Ty<'db> { assert_eq!(args.len(), arg_types.len()); let interner = self.interner(); + // It's always helpful for inference if we know the kind of + // closure sooner rather than later, so first examine the expected + // type, and see if can glean a closure kind from there. let (expected_sig, expected_kind) = match expected.to_option(&mut self.table) { - Some(expected_ty) => self.deduce_closure_signature(expected_ty, closure_kind), + Some(ty) => { + let ty = self.table.try_structurally_resolve_type(ty); + self.deduce_closure_signature(ty, closure_kind) + } None => (None, None), }; - let ClosureSignatures { bound_sig, liberated_sig } = + let ClosureSignatures { bound_sig, mut liberated_sig } = self.sig_of_closure(arg_types, ret_type, expected_sig); - let body_ret_ty = bound_sig.output().skip_binder(); - let parent_args = GenericArgs::identity_for_item(interner, self.generic_def.into()); - // FIXME: Make this an infer var and infer it later. - let tupled_upvars_ty = self.types.types.unit; - let (id, ty, resume_yield_tys) = match closure_kind { - ClosureKind::Coroutine(_) => { - let yield_ty = self.table.next_ty_var(); - let resume_ty = - liberated_sig.inputs().first().copied().unwrap_or(self.types.types.unit); + debug!(?bound_sig, ?liberated_sig); - // FIXME: Infer the upvars later. - let parts = CoroutineArgsParts { - parent_args: parent_args.as_slice(), - kind_ty: self.types.types.unit, - resume_ty, - yield_ty, - return_ty: body_ret_ty, - tupled_upvars_ty, - }; + let parent_args = GenericArgs::identity_for_item(interner, self.generic_def.into()); - let coroutine_id = - self.db.intern_coroutine(InternedCoroutine(self.owner, tgt_expr)).into(); - let coroutine_ty = Ty::new_coroutine( - interner, - coroutine_id, - CoroutineArgs::new(interner, parts).args, - ); + // FIXME: Do this when we infer closures correctly: + // let tupled_upvars_ty = self.table.next_ty_var(); + let tupled_upvars_ty = self.types.types.unit; - (None, coroutine_ty, Some((resume_ty, yield_ty))) - } + let mut current_closure_id = None; + // FIXME: We could probably actually just unify this further -- + // instead of having a `FnSig` and a `Option<CoroutineTypes>`, + // we can have a `ClosureSignature { Coroutine { .. }, Closure { .. } }`, + // similar to how `ty::GenSig` is a distinct data structure. + let (closure_ty, resume_yield_tys) = match closure_kind { ClosureKind::Closure => { - let closure_id = self.db.intern_closure(InternedClosure(self.owner, tgt_expr)); + let closure_id = + InternedClosureId::new(self.db, InternedClosure(self.owner, closure_expr)); + current_closure_id = Some(closure_id); + self.deferred_closures.entry(closure_id).or_default(); + self.add_current_closure_dependency(closure_id); + match expected_kind { Some(kind) => { self.result.closure_info.insert( @@ -116,6 +111,9 @@ impl<'db> InferenceContext<'_, 'db> { } None => {} }; + + // Tuple up the arguments and insert the resulting function type into + // the `closures` table. let sig = bound_sig.map_bound(|sig| { interner.mk_fn_sig( [Ty::new_tup(interner, sig.inputs())], @@ -125,52 +123,101 @@ impl<'db> InferenceContext<'_, 'db> { sig.abi, ) }); - let sig_ty = Ty::new_fn_ptr(interner, sig); - // FIXME: Infer the kind later if needed. - let parts = ClosureArgsParts { - parent_args: parent_args.as_slice(), - closure_kind_ty: Ty::from_closure_kind( - interner, - expected_kind.unwrap_or(rustc_type_ir::ClosureKind::Fn), - ), - closure_sig_as_fn_ptr_ty: sig_ty, - tupled_upvars_ty, + + debug!(?sig, ?expected_kind); + + let closure_kind_ty = match expected_kind { + Some(kind) => Ty::from_closure_kind(interner, kind), + // Create a type variable (for now) to represent the closure kind. + // It will be unified during the upvar inference phase (`upvar.rs`) + // FIXME: This too should be the next line: + // None => self.table.next_ty_var(), + None => self.types.types.i8, }; - let closure_ty = Ty::new_closure( + + let closure_args = ClosureArgs::new( interner, - closure_id.into(), - ClosureArgs::new(interner, parts).args, + ClosureArgsParts { + parent_args: parent_args.as_slice(), + closure_kind_ty, + closure_sig_as_fn_ptr_ty: Ty::new_fn_ptr(interner, sig), + tupled_upvars_ty, + }, ); - self.deferred_closures.entry(closure_id).or_default(); - self.add_current_closure_dependency(closure_id); - (Some(closure_id), closure_ty, None) + + (Ty::new_closure(interner, closure_id.into(), closure_args.args), None) } - ClosureKind::Async => { - // async closures always return the type ascribed after the `->` (if present), - // and yield `()`. - let bound_return_ty = bound_sig.skip_binder().output(); - let bound_yield_ty = self.types.types.unit; - // rustc uses a special lang item type for the resume ty. I don't believe this can cause us problems. - let resume_ty = self.types.types.unit; + ClosureKind::Coroutine(_) | ClosureKind::AsyncBlock { .. } => { + let yield_ty = match closure_kind { + ClosureKind::Coroutine(_) => self.table.next_ty_var(), + ClosureKind::AsyncBlock { .. } => self.types.types.unit, + _ => unreachable!(), + }; + + // Resume type defaults to `()` if the coroutine has no argument. + let resume_ty = + liberated_sig.inputs().first().copied().unwrap_or(self.types.types.unit); - // FIXME: Infer the kind later if needed. - let closure_kind_ty = Ty::from_closure_kind( + // Coroutines that come from coroutine closures have not yet determined + // their kind ty, so make a fresh infer var which will be constrained + // later during upvar analysis. Regular coroutines always have the kind + // ty of `().` + let kind_ty = match closure_kind { + ClosureKind::AsyncBlock { source: CoroutineSource::Closure } => { + self.table.next_ty_var() + } + _ => self.types.types.unit, + }; + + let coroutine_args = CoroutineArgs::new( interner, - expected_kind.unwrap_or(rustc_type_ir::ClosureKind::Fn), + CoroutineArgsParts { + parent_args: parent_args.as_slice(), + kind_ty, + resume_ty, + yield_ty, + return_ty: liberated_sig.output(), + tupled_upvars_ty, + }, ); - // FIXME: Infer captures later. - // `for<'env> fn() -> ()`, for no captures. + let coroutine_id = + InternedCoroutineId::new(self.db, InternedClosure(self.owner, closure_expr)); + + ( + Ty::new_coroutine(interner, coroutine_id.into(), coroutine_args.args), + Some((resume_ty, yield_ty)), + ) + } + ClosureKind::AsyncClosure => { + // async closures always return the type ascribed after the `->` (if present), + // and yield `()`. + let (bound_return_ty, bound_yield_ty) = + (bound_sig.skip_binder().output(), self.types.types.unit); + // Compute all of the variables that will be used to populate the coroutine. + let resume_ty = self.table.next_ty_var(); + + let closure_kind_ty = match expected_kind { + Some(kind) => Ty::from_closure_kind(interner, kind), + + // Create a type variable (for now) to represent the closure kind. + // It will be unified during the upvar inference phase (`upvar.rs`) + // FIXME: Here again the next line should be active. + // None => self.table.next_ty_var(), + None => self.types.types.i8, + }; + + // FIXME: Another line that should be enabled. + // let coroutine_captures_by_ref_ty = self.table.next_ty_var(); let coroutine_captures_by_ref_ty = Ty::new_fn_ptr( interner, Binder::bind_with_vars( - interner.mk_fn_sig( - [], - self.types.types.unit, - false, - Safety::Safe, - FnAbi::Rust, - ), + FnSig { + inputs_and_output: Tys::new_from_slice(&[self.types.types.unit]), + c_variadic: false, + safety: Safety::Safe, + abi: FnAbi::Rust, + }, self.types.coroutine_captures_by_ref_bound_var_kinds, ), ); @@ -183,7 +230,13 @@ impl<'db> InferenceContext<'_, 'db> { interner, bound_sig.map_bound(|sig| { interner.mk_fn_sig( - [resume_ty, Ty::new_tup(interner, sig.inputs())], + [ + resume_ty, + Ty::new_tup_from_iter( + interner, + sig.inputs().iter().copied(), + ), + ], Ty::new_tup(interner, &[bound_yield_ty, bound_return_ty]), sig.c_variadic, sig.safety, @@ -196,9 +249,57 @@ impl<'db> InferenceContext<'_, 'db> { }, ); - let coroutine_id = - self.db.intern_coroutine(InternedCoroutine(self.owner, tgt_expr)).into(); - (None, Ty::new_coroutine_closure(interner, coroutine_id, closure_args.args), None) + let coroutine_kind_ty = match expected_kind { + Some(kind) => Ty::from_coroutine_closure_kind(interner, kind), + + // Create a type variable (for now) to represent the closure kind. + // It will be unified during the upvar inference phase (`upvar.rs`) + // FIXME: And here again. + // None => self.table.next_ty_var(), + None => self.types.types.i16, + }; + + let coroutine_upvars_ty = self.table.next_ty_var(); + + let coroutine_closure_id = InternedCoroutineClosureId::new( + self.db, + InternedClosure(self.owner, closure_expr), + ); + + // We need to turn the liberated signature that we got from HIR, which + // looks something like `|Args...| -> T`, into a signature that is suitable + // for type checking the inner body of the closure, which always returns a + // coroutine. To do so, we use the `CoroutineClosureSignature` to compute + // the coroutine type, filling in the tupled_upvars_ty and kind_ty with infer + // vars which will get constrained during upvar analysis. + let coroutine_output_ty = closure_args + .coroutine_closure_sig() + .map_bound(|sig| { + sig.to_coroutine( + interner, + parent_args.as_slice(), + coroutine_kind_ty, + interner.coroutine_for_closure(coroutine_closure_id.into()), + coroutine_upvars_ty, + ) + }) + .skip_binder(); + liberated_sig = interner.mk_fn_sig( + liberated_sig.inputs().iter().copied(), + coroutine_output_ty, + liberated_sig.c_variadic, + liberated_sig.safety, + liberated_sig.abi, + ); + + ( + Ty::new_coroutine_closure( + interner, + coroutine_closure_id.into(), + closure_args.args, + ), + None, + ) } }; @@ -209,9 +310,10 @@ impl<'db> InferenceContext<'_, 'db> { // FIXME: lift these out into a struct let prev_diverges = mem::replace(&mut self.diverges, Diverges::Maybe); - let prev_closure = mem::replace(&mut self.current_closure, id); - let prev_ret_ty = mem::replace(&mut self.return_ty, body_ret_ty); - let prev_ret_coercion = self.return_coercion.replace(CoerceMany::new(body_ret_ty)); + let prev_closure = mem::replace(&mut self.current_closure, current_closure_id); + let prev_ret_ty = mem::replace(&mut self.return_ty, liberated_sig.output()); + let prev_ret_coercion = + self.return_coercion.replace(CoerceMany::new(liberated_sig.output())); let prev_resume_yield_tys = mem::replace(&mut self.resume_yield_tys, resume_yield_tys); self.with_breakable_ctx(BreakableKind::Border, None, None, |this| { @@ -224,7 +326,7 @@ impl<'db> InferenceContext<'_, 'db> { self.current_closure = prev_closure; self.resume_yield_tys = prev_resume_yield_tys; - ty + closure_ty } fn fn_trait_kind_from_def_id(&self, trait_id: TraitId) -> Option<rustc_type_ir::ClosureKind> { @@ -293,7 +395,9 @@ impl<'db> InferenceContext<'_, 'db> { let expected_sig = sig_tys.with(hdr); (Some(expected_sig), Some(rustc_type_ir::ClosureKind::Fn)) } - ClosureKind::Coroutine(_) | ClosureKind::Async => (None, None), + ClosureKind::Coroutine(_) + | ClosureKind::AsyncClosure + | ClosureKind::AsyncBlock { .. } => (None, None), }, _ => (None, None), } @@ -406,7 +510,7 @@ impl<'db> InferenceContext<'_, 'db> { if let Some(trait_def_id) = trait_def_id { let found_kind = match closure_kind { ClosureKind::Closure => self.fn_trait_kind_from_def_id(trait_def_id), - ClosureKind::Async => self + ClosureKind::AsyncClosure => self .async_fn_trait_kind_from_def_id(trait_def_id) .or_else(|| self.fn_trait_kind_from_def_id(trait_def_id)), _ => None, @@ -452,13 +556,13 @@ impl<'db> InferenceContext<'_, 'db> { ClosureKind::Closure if Some(def_id) == self.lang_items.FnOnceOutput => { self.extract_sig_from_projection(projection) } - ClosureKind::Async if Some(def_id) == self.lang_items.AsyncFnOnceOutput => { + ClosureKind::AsyncClosure if Some(def_id) == self.lang_items.AsyncFnOnceOutput => { self.extract_sig_from_projection(projection) } // It's possible we've passed the closure to a (somewhat out-of-fashion) // `F: FnOnce() -> Fut, Fut: Future<Output = T>` style bound. Let's still // guide inference here, since it's beneficial for the user. - ClosureKind::Async if Some(def_id) == self.lang_items.FnOnceOutput => { + ClosureKind::AsyncClosure if Some(def_id) == self.lang_items.FnOnceOutput => { self.extract_sig_from_projection_and_future_bound(projection) } _ => None, diff --git a/crates/hir-ty/src/infer/closure/analysis.rs b/crates/hir-ty/src/infer/closure/analysis.rs index ce0ccfe82f..2d999b596b 100644 --- a/crates/hir-ty/src/infer/closure/analysis.rs +++ b/crates/hir-ty/src/infer/closure/analysis.rs @@ -609,9 +609,7 @@ impl<'db> InferenceContext<'_, 'db> { self.consume_expr(expr); } } - Expr::Async { statements, tail, .. } - | Expr::Unsafe { statements, tail, .. } - | Expr::Block { statements, tail, .. } => { + Expr::Unsafe { statements, tail, .. } | Expr::Block { statements, tail, .. } => { for s in statements.iter() { match s { Statement::Let { pat, type_ref: _, initializer, else_branch } => { @@ -755,7 +753,7 @@ impl<'db> InferenceContext<'_, 'db> { Expr::Closure { .. } => { let ty = self.expr_ty(tgt_expr); let TyKind::Closure(id, _) = ty.kind() else { - never!("closure type is always closure"); + // A coroutine or a coroutine closure. return; }; let (captures, _) = @@ -876,7 +874,7 @@ impl<'db> InferenceContext<'_, 'db> { fn is_upvar(&self, place: &HirPlace) -> bool { if let Some(c) = self.current_closure { - let InternedClosure(_, root) = self.db.lookup_intern_closure(c); + let InternedClosure(_, root) = c.loc(self.db); return self.store.is_binding_upvar(place.local, root); } false @@ -1139,7 +1137,7 @@ impl<'db> InferenceContext<'_, 'db> { } fn analyze_closure(&mut self, closure: InternedClosureId) -> FnTrait { - let InternedClosure(_, root) = self.db.lookup_intern_closure(closure); + let InternedClosure(_, root) = closure.loc(self.db); self.current_closure = Some(closure); let Expr::Closure { body, capture_by, .. } = &self.store[root] else { unreachable!("Closure expression id is always closure"); diff --git a/crates/hir-ty/src/infer/expr.rs b/crates/hir-ty/src/infer/expr.rs index ee34a30eba..06615cb691 100644 --- a/crates/hir-ty/src/infer/expr.rs +++ b/crates/hir-ty/src/infer/expr.rs @@ -17,7 +17,7 @@ use hir_def::{FunctionId, hir::ClosureKind}; use hir_expand::name::Name; use rustc_ast_ir::Mutability; use rustc_type_ir::{ - CoroutineArgs, CoroutineArgsParts, InferTy, Interner, + InferTy, Interner, inherent::{AdtDef, GenericArgs as _, IntoKind, Ty as _}, }; use syntax::ast::RangeOp; @@ -27,7 +27,6 @@ use crate::{ Adjust, Adjustment, CallableDefId, DeclContext, DeclOrigin, Rawness, autoderef::InferenceContextAutoderef, consteval, - db::InternedCoroutine, generics::generics, infer::{ AllowTwoPhase, BreakableKind, coerce::CoerceMany, find_continuable, @@ -244,7 +243,6 @@ impl<'db> InferenceContext<'_, 'db> { | Expr::Assignment { .. } | Expr::Yield { .. } | Expr::Cast { .. } - | Expr::Async { .. } | Expr::Unsafe { .. } | Expr::Await { .. } | Expr::Ref { .. } @@ -390,9 +388,6 @@ impl<'db> InferenceContext<'_, 'db> { }) .1 } - Expr::Async { id: _, statements, tail } => { - self.infer_async_block(tgt_expr, statements, tail) - } &Expr::Loop { body, label } => { // FIXME: should be: // let ty = expected.coercion_target_type(&mut self.table); @@ -1185,72 +1180,6 @@ impl<'db> InferenceContext<'_, 'db> { oprnd_t } - fn infer_async_block( - &mut self, - tgt_expr: ExprId, - statements: &[Statement], - tail: &Option<ExprId>, - ) -> Ty<'db> { - let ret_ty = self.table.next_ty_var(); - let prev_diverges = mem::replace(&mut self.diverges, Diverges::Maybe); - let prev_ret_ty = mem::replace(&mut self.return_ty, ret_ty); - let prev_ret_coercion = self.return_coercion.replace(CoerceMany::new(ret_ty)); - - // FIXME: We should handle async blocks like we handle closures - let expected = &Expectation::has_type(ret_ty); - let (_, inner_ty) = self.with_breakable_ctx(BreakableKind::Border, None, None, |this| { - let ty = this.infer_block(tgt_expr, statements, *tail, None, expected); - if let Some(target) = expected.only_has_type(&mut this.table) { - match this.coerce(tgt_expr.into(), ty, target, AllowTwoPhase::No, ExprIsRead::Yes) { - Ok(res) => res, - Err(_) => { - this.result.type_mismatches.get_or_insert_default().insert( - tgt_expr.into(), - TypeMismatch { expected: target.store(), actual: ty.store() }, - ); - target - } - } - } else { - ty - } - }); - - self.diverges = prev_diverges; - self.return_ty = prev_ret_ty; - self.return_coercion = prev_ret_coercion; - - self.lower_async_block_type_impl_trait(inner_ty, tgt_expr) - } - - pub(crate) fn lower_async_block_type_impl_trait( - &mut self, - inner_ty: Ty<'db>, - tgt_expr: ExprId, - ) -> Ty<'db> { - let coroutine_id = InternedCoroutine(self.owner, tgt_expr); - let coroutine_id = self.db.intern_coroutine(coroutine_id).into(); - let parent_args = GenericArgs::identity_for_item(self.interner(), self.generic_def.into()); - Ty::new_coroutine( - self.interner(), - coroutine_id, - CoroutineArgs::new( - self.interner(), - CoroutineArgsParts { - parent_args: parent_args.as_slice(), - kind_ty: self.types.types.unit, - // rustc uses a special lang item type for the resume ty. I don't believe this can cause us problems. - resume_ty: self.types.types.unit, - yield_ty: self.types.types.unit, - return_ty: inner_ty, - // FIXME: Infer upvars. - tupled_upvars_ty: self.types.types.unit, - }, - ) - .args, - ) - } - pub(crate) fn write_fn_trait_method_resolution( &mut self, fn_x: FnTrait, diff --git a/crates/hir-ty/src/infer/mutability.rs b/crates/hir-ty/src/infer/mutability.rs index bfe43fc928..5aba123435 100644 --- a/crates/hir-ty/src/infer/mutability.rs +++ b/crates/hir-ty/src/infer/mutability.rs @@ -86,7 +86,6 @@ impl<'db> InferenceContext<'_, 'db> { } Expr::Let { pat, expr } => self.infer_mut_expr(*expr, self.pat_bound_mutability(*pat)), Expr::Block { id: _, statements, tail, label: _ } - | Expr::Async { id: _, statements, tail } | Expr::Unsafe { id: _, statements, tail } => { for st in statements.iter() { match st { diff --git a/crates/hir-ty/src/layout.rs b/crates/hir-ty/src/layout.rs index 54332122d0..4ba39b1b45 100644 --- a/crates/hir-ty/src/layout.rs +++ b/crates/hir-ty/src/layout.rs @@ -332,7 +332,7 @@ pub fn layout_of_ty_query( Layout::scalar(dl, ptr) } TyKind::Closure(id, args) => { - let def = db.lookup_intern_closure(id.0); + let def = id.0.loc(db); let infer = InferenceResult::of(db, def.0); let (captures, _) = infer.closure_info(id.0); let fields = captures diff --git a/crates/hir-ty/src/mir/borrowck.rs b/crates/hir-ty/src/mir/borrowck.rs index 3ff2db15aa..d843359dcb 100644 --- a/crates/hir-ty/src/mir/borrowck.rs +++ b/crates/hir-ty/src/mir/borrowck.rs @@ -121,7 +121,7 @@ fn make_fetch_closure_field<'db>( db: &'db dyn HirDatabase, ) -> impl FnOnce(InternedClosureId, GenericArgs<'db>, usize) -> Ty<'db> + use<'db> { |c: InternedClosureId, subst: GenericArgs<'db>, f: usize| { - let InternedClosure(owner, _) = db.lookup_intern_closure(c); + let InternedClosure(owner, _) = c.loc(db); let interner = DbInterner::new_no_crate(db); let infer = InferenceResult::of(db, owner); let (captures, _) = infer.closure_info(c); diff --git a/crates/hir-ty/src/mir/eval.rs b/crates/hir-ty/src/mir/eval.rs index 88376f14d1..79b1c5cb7c 100644 --- a/crates/hir-ty/src/mir/eval.rs +++ b/crates/hir-ty/src/mir/eval.rs @@ -736,7 +736,7 @@ impl<'db> Evaluator<'db> { self.param_env.param_env, ty, |c, subst, f| { - let InternedClosure(owner, _) = self.db.lookup_intern_closure(c); + let InternedClosure(owner, _) = c.loc(self.db); let infer = InferenceResult::of(self.db, owner); let (captures, _) = infer.closure_info(c); let parent_subst = subst.as_closure().parent_args(); diff --git a/crates/hir-ty/src/mir/eval/shim.rs b/crates/hir-ty/src/mir/eval/shim.rs index 2aed76ec90..1f3cee2a03 100644 --- a/crates/hir-ty/src/mir/eval/shim.rs +++ b/crates/hir-ty/src/mir/eval/shim.rs @@ -152,7 +152,7 @@ impl<'db> Evaluator<'db> { not_supported!("wrong arg count for clone"); }; let addr = Address::from_bytes(arg.get(self)?)?; - let InternedClosure(owner, _) = self.db.lookup_intern_closure(id.0); + let InternedClosure(owner, _) = id.0.loc(self.db); let infer = InferenceResult::of(self.db, owner); let (captures, _) = infer.closure_info(id.0); let layout = self.layout(self_ty)?; diff --git a/crates/hir-ty/src/mir/lower.rs b/crates/hir-ty/src/mir/lower.rs index 44785d948a..d044019629 100644 --- a/crates/hir-ty/src/mir/lower.rs +++ b/crates/hir-ty/src/mir/lower.rs @@ -8,8 +8,9 @@ use hir_def::{ HasModule, ItemContainerId, LocalFieldId, Lookup, TraitId, TupleId, expr_store::{Body, ExpressionStore, HygieneId, path::Path}, hir::{ - ArithOp, Array, BinaryOp, BindingAnnotation, BindingId, ExprId, LabelId, Literal, MatchArm, - Pat, PatId, RecordFieldPat, RecordLitField, RecordSpread, generics::GenericParams, + ArithOp, Array, BinaryOp, BindingAnnotation, BindingId, ClosureKind, ExprId, LabelId, + Literal, MatchArm, Pat, PatId, RecordFieldPat, RecordLitField, RecordSpread, + generics::GenericParams, }, item_tree::FieldsShape, lang_item::LangItems, @@ -956,7 +957,6 @@ impl<'a, 'db> MirLowerCtx<'a, 'db> { } Expr::Await { .. } => not_supported!("await"), Expr::Yeet { .. } => not_supported!("yeet"), - Expr::Async { .. } => not_supported!("async block"), &Expr::Const(_) => { // let subst = self.placeholder_subst(); // self.lower_const( @@ -1245,7 +1245,7 @@ impl<'a, 'db> MirLowerCtx<'a, 'db> { ); Ok(Some(current)) } - Expr::Closure { .. } => { + Expr::Closure { closure_kind: ClosureKind::Closure, .. } => { let ty = self.expr_ty_without_adjust(expr_id); let TyKind::Closure(id, _) = ty.kind() else { not_supported!("closure with non closure type"); @@ -1304,6 +1304,7 @@ impl<'a, 'db> MirLowerCtx<'a, 'db> { ); Ok(Some(current)) } + Expr::Closure { closure_kind, .. } => not_supported!("{closure_kind:?} closure"), Expr::Tuple { exprs } => { let Some(values) = exprs .iter() @@ -2110,7 +2111,7 @@ pub fn mir_body_for_closure_query<'db>( db: &'db dyn HirDatabase, closure: InternedClosureId, ) -> Result<'db, Arc<MirBody>> { - let InternedClosure(owner, expr) = db.lookup_intern_closure(closure); + let InternedClosure(owner, expr) = closure.loc(db); let body_owner = owner.as_def_with_body().expect("MIR lowering should only happen for body-owned closures"); let body = Body::of(db, body_owner); diff --git a/crates/hir-ty/src/next_solver/def_id.rs b/crates/hir-ty/src/next_solver/def_id.rs index 00161d6d08..542eca3ded 100644 --- a/crates/hir-ty/src/next_solver/def_id.rs +++ b/crates/hir-ty/src/next_solver/def_id.rs @@ -12,7 +12,9 @@ use hir_def::{ use rustc_type_ir::inherent; use stdx::impl_from; -use crate::db::{InternedClosureId, InternedCoroutineId, InternedOpaqueTyId}; +use crate::db::{ + InternedClosureId, InternedCoroutineClosureId, InternedCoroutineId, InternedOpaqueTyId, +}; use super::DbInterner; @@ -35,6 +37,7 @@ pub enum SolverDefId { TypeAliasId(TypeAliasId), InternedClosureId(InternedClosureId), InternedCoroutineId(InternedCoroutineId), + InternedCoroutineClosureId(InternedCoroutineClosureId), InternedOpaqueTyId(InternedOpaqueTyId), EnumVariantId(EnumVariantId), Ctor(Ctor), @@ -80,6 +83,9 @@ impl std::fmt::Debug for SolverDefId { SolverDefId::InternedCoroutineId(id) => { f.debug_tuple("InternedCoroutineId").field(&id).finish() } + SolverDefId::InternedCoroutineClosureId(id) => { + f.debug_tuple("InternedCoroutineClosureId").field(&id).finish() + } SolverDefId::InternedOpaqueTyId(id) => { f.debug_tuple("InternedOpaqueTyId").field(&id).finish() } @@ -123,6 +129,7 @@ impl_from!( TypeAliasId, InternedClosureId, InternedCoroutineId, + InternedCoroutineClosureId, InternedOpaqueTyId, EnumVariantId, Ctor @@ -206,6 +213,7 @@ impl TryFrom<SolverDefId> for AttrDefId { SolverDefId::BuiltinDeriveImplId(_) | SolverDefId::InternedClosureId(_) | SolverDefId::InternedCoroutineId(_) + | SolverDefId::InternedCoroutineClosureId(_) | SolverDefId::InternedOpaqueTyId(_) | SolverDefId::AnonConstId(_) => Err(()), } @@ -229,6 +237,7 @@ impl TryFrom<SolverDefId> for DefWithBodyId { | SolverDefId::BuiltinDeriveImplId(_) | SolverDefId::InternedClosureId(_) | SolverDefId::InternedCoroutineId(_) + | SolverDefId::InternedCoroutineClosureId(_) | SolverDefId::Ctor(Ctor::Struct(_)) | SolverDefId::AnonConstId(_) | SolverDefId::AdtId(_) => return Err(()), @@ -251,6 +260,7 @@ impl TryFrom<SolverDefId> for GenericDefId { SolverDefId::TypeAliasId(type_alias_id) => GenericDefId::TypeAliasId(type_alias_id), SolverDefId::InternedClosureId(_) | SolverDefId::InternedCoroutineId(_) + | SolverDefId::InternedCoroutineClosureId(_) | SolverDefId::InternedOpaqueTyId(_) | SolverDefId::EnumVariantId(_) | SolverDefId::BuiltinDeriveImplId(_) @@ -348,6 +358,7 @@ declare_id_wrapper!(TraitIdWrapper, TraitId); declare_id_wrapper!(TypeAliasIdWrapper, TypeAliasId); declare_id_wrapper!(ClosureIdWrapper, InternedClosureId); declare_id_wrapper!(CoroutineIdWrapper, InternedCoroutineId); +declare_id_wrapper!(CoroutineClosureIdWrapper, InternedCoroutineClosureId); declare_id_wrapper!(AdtIdWrapper, AdtId); #[derive(Clone, Copy, PartialEq, Eq, Hash)] diff --git a/crates/hir-ty/src/next_solver/interner.rs b/crates/hir-ty/src/next_solver/interner.rs index 622648bc8d..4f30fc7a89 100644 --- a/crates/hir-ty/src/next_solver/interner.rs +++ b/crates/hir-ty/src/next_solver/interner.rs @@ -38,14 +38,14 @@ use rustc_type_ir::{ use crate::{ FnAbi, - db::{HirDatabase, InternedCoroutine, InternedCoroutineId}, + db::{HirDatabase, InternedClosure, InternedCoroutineId}, lower::GenericPredicates, method_resolution::TraitImpls, next_solver::{ AdtIdWrapper, AnyImplId, BoundConst, CallableIdWrapper, CanonicalVarKind, ClosureIdWrapper, - CoroutineIdWrapper, Ctor, FnSig, FxIndexMap, GeneralConstIdWrapper, OpaqueTypeKey, - RegionAssumptions, SimplifiedType, SolverContext, SolverDefIds, TraitIdWrapper, - TypeAliasIdWrapper, UnevaluatedConst, + CoroutineClosureIdWrapper, CoroutineIdWrapper, Ctor, FnSig, FxIndexMap, + GeneralConstIdWrapper, OpaqueTypeKey, RegionAssumptions, SimplifiedType, SolverContext, + SolverDefIds, TraitIdWrapper, TypeAliasIdWrapper, UnevaluatedConst, util::{explicit_item_bounds, explicit_item_self_bounds}, }, }; @@ -1022,7 +1022,7 @@ impl<'db> Interner for DbInterner<'db> { type ForeignId = TypeAliasIdWrapper; type FunctionId = CallableIdWrapper; type ClosureId = ClosureIdWrapper; - type CoroutineClosureId = CoroutineIdWrapper; + type CoroutineClosureId = CoroutineClosureIdWrapper; type CoroutineId = CoroutineIdWrapper; type AdtId = AdtIdWrapper; type ImplId = AnyImplId; @@ -1198,6 +1198,7 @@ impl<'db> Interner for DbInterner<'db> { | SolverDefId::BuiltinDeriveImplId(_) | SolverDefId::InternedClosureId(_) | SolverDefId::InternedCoroutineId(_) + | SolverDefId::InternedCoroutineClosureId(_) | SolverDefId::AnonConstId(_) => { return VariancesOf::empty(self); } @@ -1315,10 +1316,13 @@ impl<'db> Interner for DbInterner<'db> { SolverDefId::TypeAliasId(it) => it.lookup(self.db()).container, SolverDefId::ConstId(it) => it.lookup(self.db()).container, SolverDefId::InternedClosureId(it) => { - return self.db().lookup_intern_closure(it).0.generic_def(self.db()).into(); + return it.loc(self.db).0.generic_def(self.db()).into(); } SolverDefId::InternedCoroutineId(it) => { - return self.db().lookup_intern_coroutine(it).0.generic_def(self.db()).into(); + return it.loc(self.db).0.generic_def(self.db()).into(); + } + SolverDefId::InternedCoroutineClosureId(it) => { + return it.loc(self.db).0.generic_def(self.db()).into(); } SolverDefId::StaticId(_) | SolverDefId::AdtId(_) @@ -1356,7 +1360,7 @@ impl<'db> Interner for DbInterner<'db> { fn coroutine_movability(self, def_id: Self::CoroutineId) -> rustc_ast_ir::Movability { // FIXME: Make this a query? I don't believe this can be accessed from bodies other than // the current infer query, except with revealed opaques - is it rare enough to not matter? - let InternedCoroutine(owner, expr_id) = def_id.0.loc(self.db); + let InternedClosure(owner, expr_id) = def_id.0.loc(self.db); let store = ExpressionStore::of(self.db, owner); let expr = &store[expr_id]; match *expr { @@ -1365,16 +1369,17 @@ impl<'db> Interner for DbInterner<'db> { hir_def::hir::Movability::Static => rustc_ast_ir::Movability::Static, hir_def::hir::Movability::Movable => rustc_ast_ir::Movability::Movable, }, - hir_def::hir::ClosureKind::Async => rustc_ast_ir::Movability::Static, + hir_def::hir::ClosureKind::AsyncBlock { .. } => rustc_ast_ir::Movability::Static, _ => panic!("unexpected expression for a coroutine: {expr:?}"), }, - hir_def::hir::Expr::Async { .. } => rustc_ast_ir::Movability::Static, _ => panic!("unexpected expression for a coroutine: {expr:?}"), } } fn coroutine_for_closure(self, def_id: Self::CoroutineClosureId) -> Self::CoroutineId { - def_id + let InternedClosure(owner, coroutine_closure_expr) = def_id.0.loc(self.db); + let coroutine_expr = ExpressionStore::coroutine_for_closure(coroutine_closure_expr); + InternedCoroutineId::new(self.db, InternedClosure(owner, coroutine_expr)).into() } fn generics_require_sized_self(self, def_id: Self::DefId) -> bool { @@ -1763,6 +1768,7 @@ impl<'db> Interner for DbInterner<'db> { | SolverDefId::StaticId(_) | SolverDefId::InternedClosureId(_) | SolverDefId::InternedCoroutineId(_) + | SolverDefId::InternedCoroutineClosureId(_) | SolverDefId::InternedOpaqueTyId(_) | SolverDefId::EnumVariantId(_) | SolverDefId::AnonConstId(_) @@ -1976,7 +1982,7 @@ impl<'db> Interner for DbInterner<'db> { fn is_general_coroutine(self, def_id: Self::CoroutineId) -> bool { // FIXME: Make this a query? I don't believe this can be accessed from bodies other than // the current infer query, except with revealed opaques - is it rare enough to not matter? - let InternedCoroutine(owner, expr_id) = def_id.0.loc(self.db); + let InternedClosure(owner, expr_id) = def_id.0.loc(self.db); let store = ExpressionStore::of(self.db, owner); matches!( store[expr_id], @@ -1990,12 +1996,14 @@ impl<'db> Interner for DbInterner<'db> { fn coroutine_is_async(self, def_id: Self::CoroutineId) -> bool { // FIXME: Make this a query? I don't believe this can be accessed from bodies other than // the current infer query, except with revealed opaques - is it rare enough to not matter? - let InternedCoroutine(owner, expr_id) = def_id.0.loc(self.db); + let InternedClosure(owner, expr_id) = def_id.0.loc(self.db); let store = ExpressionStore::of(self.db, owner); matches!( store[expr_id], - hir_def::hir::Expr::Closure { closure_kind: hir_def::hir::ClosureKind::Async, .. } - | hir_def::hir::Expr::Async { .. } + hir_def::hir::Expr::Closure { + closure_kind: hir_def::hir::ClosureKind::AsyncBlock { .. }, + .. + } ) } @@ -2118,16 +2126,15 @@ impl<'db> Interner for DbInterner<'db> { body.exprs().for_each(|(expr_id, expr)| { if matches!( expr, - hir_def::hir::Expr::Async { .. } - | hir_def::hir::Expr::Closure { - closure_kind: hir_def::hir::ClosureKind::Async - | hir_def::hir::ClosureKind::Coroutine(_), - .. - } + hir_def::hir::Expr::Closure { + closure_kind: hir_def::hir::ClosureKind::AsyncBlock { .. } + | hir_def::hir::ClosureKind::Coroutine(_), + .. + } ) { let coroutine = InternedCoroutineId::new( self.db, - InternedCoroutine(ExpressionStoreOwnerId::Body(def_id), expr_id), + InternedClosure(ExpressionStoreOwnerId::Body(def_id), expr_id), ); result.push(coroutine.into()); } @@ -2414,6 +2421,7 @@ TrivialTypeTraversalImpls! { CallableIdWrapper, ClosureIdWrapper, CoroutineIdWrapper, + CoroutineClosureIdWrapper, AdtIdWrapper, AnyImplId, GeneralConstIdWrapper, diff --git a/crates/hir-ty/src/next_solver/ty.rs b/crates/hir-ty/src/next_solver/ty.rs index 8e892b65ea..dccb8c7936 100644 --- a/crates/hir-ty/src/next_solver/ty.rs +++ b/crates/hir-ty/src/next_solver/ty.rs @@ -27,11 +27,12 @@ use rustc_type_ir::{ use crate::{ FnAbi, - db::{HirDatabase, InternedCoroutine}, + db::{HirDatabase, InternedClosure}, lower::GenericPredicates, next_solver::{ AdtDef, AliasTy, Binder, CallableIdWrapper, Clause, ClauseKind, ClosureIdWrapper, Const, - CoroutineIdWrapper, FnSig, GenericArgKind, PolyFnSig, Region, TraitRef, TypeAliasIdWrapper, + CoroutineClosureIdWrapper, CoroutineIdWrapper, FnSig, GenericArgKind, PolyFnSig, Region, + TraitRef, TypeAliasIdWrapper, abi::Safety, impl_foldable_for_interned_slice, impl_stored_interned, interned_slice, util::{CoroutineArgsExt, IntegerTypeExt}, @@ -527,7 +528,7 @@ impl<'db> Ty<'db> { let unit_ty = Ty::new_unit(interner); let return_ty = Ty::new_coroutine( interner, - coroutine_id, + interner.coroutine_for_closure(coroutine_id), CoroutineArgs::new( interner, CoroutineArgsParts { @@ -713,7 +714,7 @@ impl<'db> Ty<'db> { } } TyKind::Coroutine(coroutine_id, _args) => { - let InternedCoroutine(owner, _) = coroutine_id.0.loc(db); + let InternedClosure(owner, _) = coroutine_id.0.loc(db); let krate = owner.krate(db); if let Some(future_trait) = hir_def::lang_item::lang_items(db, krate).Future { // This is only used by type walking. @@ -1107,7 +1108,7 @@ impl<'db> rustc_type_ir::inherent::Ty<DbInterner<'db>> for Ty<'db> { fn new_coroutine_closure( interner: DbInterner<'db>, - def_id: CoroutineIdWrapper, + def_id: CoroutineClosureIdWrapper, args: <DbInterner<'db> as Interner>::GenericArgs, ) -> Self { Ty::new(interner, TyKind::CoroutineClosure(def_id, args)) diff --git a/crates/hir-ty/src/tests/closure_captures.rs b/crates/hir-ty/src/tests/closure_captures.rs index 9e68756821..9d1a1fbd11 100644 --- a/crates/hir-ty/src/tests/closure_captures.rs +++ b/crates/hir-ty/src/tests/closure_captures.rs @@ -11,7 +11,6 @@ use test_fixture::WithFixture; use crate::{ InferenceResult, - db::HirDatabase, display::{DisplayTarget, HirDisplay}, mir::MirSpan, test_db::TestDB, @@ -42,7 +41,7 @@ fn check_closure_captures(#[rust_analyzer::rust_fixture] ra_fixture: &str, expec let db = &db; captures_info.extend(infer.closure_info.iter().flat_map( |(closure_id, (captures, _))| { - let closure = db.lookup_intern_closure(*closure_id); + let closure = closure_id.loc(db); let body_owner = closure.0; let source_map = ExpressionStore::with_source_map(db, body_owner).1; let closure_text_range = source_map diff --git a/crates/hir/src/has_source.rs b/crates/hir/src/has_source.rs index f9badc0b79..45c9811cc0 100644 --- a/crates/hir/src/has_source.rs +++ b/crates/hir/src/has_source.rs @@ -293,7 +293,7 @@ impl HasSource for Param<'_> { .map(|value| InFile { file_id, value }) } Callee::Closure(closure, _) => { - let InternedClosure(owner, expr_id) = db.lookup_intern_closure(closure); + let InternedClosure(owner, expr_id) = closure.loc(db); let (_, source_map) = ExpressionStore::with_source_map(db, owner); let ast @ InFile { file_id, value } = source_map.expr_syntax(expr_id).ok()?; let root = db.parse_or_expand(file_id); diff --git a/crates/hir/src/lib.rs b/crates/hir/src/lib.rs index 7a4085c474..53240259e0 100644 --- a/crates/hir/src/lib.rs +++ b/crates/hir/src/lib.rs @@ -85,7 +85,7 @@ use hir_ty::{ GenericPredicates, InferenceResult, ParamEnvAndCrate, TyDefId, TyLoweringDiagnostic, ValueTyDefId, all_super_traits, autoderef, check_orphan_rules, consteval::try_const_usize, - db::{InternedClosureId, InternedCoroutineId}, + db::{InternedClosureId, InternedCoroutineClosureId}, diagnostics::BodyValidationDiagnostic, direct_super_traits, known_const_to_ast, layout::{Layout as TyLayout, RustcEnumVariantIdx, RustcFieldIdx, TagEncoding}, @@ -2950,7 +2950,7 @@ impl<'db> Param<'db> { } } Callee::Closure(closure, _) => { - let c = db.lookup_intern_closure(closure); + let c = closure.loc(db); let body_owner = c.0; let store = ExpressionStore::of(db, c.0); @@ -5092,7 +5092,7 @@ impl<'db> TraitRef<'db> { #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] enum AnyClosureId { ClosureId(InternedClosureId), - CoroutineClosureId(InternedCoroutineId), + CoroutineClosureId(InternedCoroutineClosureId), } #[derive(Clone, Debug, PartialEq, Eq, Hash)] @@ -5131,7 +5131,7 @@ impl<'db> Closure<'db> { // FIXME: Infer coroutine closures' captures. return Vec::new(); }; - let owner = db.lookup_intern_closure(id).0; + let owner = id.loc(db).0; let infer = InferenceResult::of(db, owner); let info = infer.closure_info(id); info.0 @@ -5151,7 +5151,7 @@ impl<'db> Closure<'db> { // FIXME: Infer coroutine closures' captures. return Vec::new(); }; - let owner = db.lookup_intern_closure(id).0; + let owner = id.loc(db).0; let Some(body_owner) = owner.as_def_with_body() else { return Vec::new(); }; @@ -5164,7 +5164,7 @@ impl<'db> Closure<'db> { pub fn fn_trait(&self, db: &dyn HirDatabase) -> FnTrait { match self.id { AnyClosureId::ClosureId(id) => { - let owner = db.lookup_intern_closure(id).0; + let owner = id.loc(db).0; let Some(body_owner) = owner.as_def_with_body() else { return FnTrait::FnOnce; }; @@ -6532,7 +6532,7 @@ pub struct Callable<'db> { enum Callee<'db> { Def(CallableDefId), Closure(InternedClosureId, GenericArgs<'db>), - CoroutineClosure(InternedCoroutineId, GenericArgs<'db>), + CoroutineClosure(InternedCoroutineClosureId, GenericArgs<'db>), FnPtr, FnImpl(traits::FnTrait), BuiltinDeriveImplMethod { method: BuiltinDeriveImplMethod, impl_: BuiltinDeriveImplId }, |