Unnamed repository; edit this file 'description' to name the repository.
Diffstat (limited to 'crates/hir-def/src/expr_store/lower.rs')
| -rw-r--r-- | crates/hir-def/src/expr_store/lower.rs | 167 |
1 files changed, 119 insertions, 48 deletions
diff --git a/crates/hir-def/src/expr_store/lower.rs b/crates/hir-def/src/expr_store/lower.rs index 7fe91a3d02..04437a59ac 100644 --- a/crates/hir-def/src/expr_store/lower.rs +++ b/crates/hir-def/src/expr_store/lower.rs @@ -46,8 +46,9 @@ use crate::{ }, hir::{ Array, Binding, BindingAnnotation, BindingId, BindingProblems, CaptureBy, ClosureKind, - Expr, ExprId, Item, Label, LabelId, Literal, MatchArm, Movability, OffsetOf, Pat, PatId, - RecordFieldPat, RecordLitField, RecordSpread, Statement, generics::GenericParams, + CoroutineSource, Expr, ExprId, Item, Label, LabelId, Literal, MatchArm, Movability, + OffsetOf, Pat, PatId, RecordFieldPat, RecordLitField, RecordSpread, Statement, + generics::GenericParams, }, item_scope::BuiltinShadowMode, item_tree::FieldsShape, @@ -944,12 +945,19 @@ impl<'db> ExprCollector<'db> { }) } - /// An `async fn` needs to capture all parameters in the generated `async` block, even if they have - /// non-captured patterns such as wildcards (to ensure consistent drop order). - fn lower_async_fn(&mut self, params: &mut Vec<PatId>, body: ExprId) -> ExprId { + /// Lowers a desugared coroutine body after moving all of the arguments + /// into the body. This is to make sure that the future actually owns the + /// arguments that are passed to the function, and to ensure things like + /// drop order are stable. + fn lower_async_block_with_moved_arguments( + &mut self, + params: &mut [PatId], + body: ExprId, + coroutine_source: CoroutineSource, + ) -> ExprId { let mut statements = Vec::new(); for param in params { - let name = match self.store.pats[*param] { + let (name, hygiene) = match self.store.pats[*param] { Pat::Bind { id, .. } if matches!( self.store.bindings[id].mode, @@ -961,14 +969,16 @@ impl<'db> ExprCollector<'db> { } Pat::Bind { id, .. } => { // If this is a `ref` binding, we can't leave it as is but we can at least reuse the name, for better display. - self.store.bindings[id].name.clone() + (self.store.bindings[id].name.clone(), self.store.bindings[id].hygiene) } - _ => self.generate_new_name(), + _ => (self.generate_new_name(), HygieneId::ROOT), }; - let binding_id = - self.alloc_binding(name.clone(), BindingAnnotation::Mutable, HygieneId::ROOT); + let binding_id = self.alloc_binding(name.clone(), BindingAnnotation::Mutable, hygiene); let pat_id = self.alloc_pat_desugared(Pat::Bind { id: binding_id, subpat: None }); let expr = self.alloc_expr_desugared(Expr::Path(name.into())); + if !hygiene.is_root() { + self.store.ident_hygiene.insert(expr.into(), hygiene); + } statements.push(Statement::Let { pat: *param, type_ref: None, @@ -978,23 +988,54 @@ impl<'db> ExprCollector<'db> { *param = pat_id; } - self.alloc_expr_desugared(Expr::Async { - id: None, - statements: statements.into_boxed_slice(), - tail: Some(body), - }) + let async_ = self.async_block( + coroutine_source, + // The default capture mode here is by-ref. Later on during upvar analysis, + // we will force the captured arguments to by-move, but for async closures, + // we want to make sure that we avoid unnecessarily moving captures, or else + // all async closures would default to `FnOnce` as their calling mode. + CaptureBy::Ref, + None, + statements.into_boxed_slice(), + Some(body), + ); + // It's important that this comes last, see the lowering of async closures for why. + self.alloc_expr_desugared(async_) + } + + fn async_block( + &mut self, + source: CoroutineSource, + capture_by: CaptureBy, + id: Option<BlockId>, + statements: Box<[Statement]>, + tail: Option<ExprId>, + ) -> Expr { + let block = self.alloc_expr_desugared(Expr::Block { label: None, id, statements, tail }); + Expr::Closure { + args: Box::default(), + arg_types: Box::default(), + ret_type: None, + body: block, + closure_kind: ClosureKind::AsyncBlock { source }, + capture_by, + } } fn collect( &mut self, - params: &mut Vec<PatId>, + params: &mut [PatId], expr: Option<ast::Expr>, awaitable: Awaitable, ) -> ExprId { self.awaitable_context.replace(awaitable); self.with_label_rib(RibKind::Closure, |this| { let body = this.collect_expr_opt(expr); - if awaitable == Awaitable::Yes { this.lower_async_fn(params, body) } else { body } + if awaitable == Awaitable::Yes { + this.lower_async_block_with_moved_arguments(params, body, CoroutineSource::Fn) + } else { + body + } }) } @@ -1126,7 +1167,7 @@ impl<'db> ExprCollector<'db> { self.desugar_try_block(e, result_type) } Some(ast::BlockModifier::Unsafe(_)) => { - self.collect_block_(e, |id, statements, tail| Expr::Unsafe { + self.collect_block_(e, |_, id, statements, tail| Expr::Unsafe { id, statements, tail, @@ -1136,7 +1177,7 @@ impl<'db> ExprCollector<'db> { let label_hygiene = self.hygiene_id_for(label.syntax().text_range()); let label_id = self.collect_label(label); self.with_labeled_rib(label_id, label_hygiene, |this| { - this.collect_block_(e, |id, statements, tail| Expr::Block { + this.collect_block_(e, |_, id, statements, tail| Expr::Block { id, statements, tail, @@ -1145,12 +1186,18 @@ impl<'db> ExprCollector<'db> { }) } Some(ast::BlockModifier::Async(_)) => { + let capture_by = + if e.move_token().is_some() { CaptureBy::Value } else { CaptureBy::Ref }; self.with_label_rib(RibKind::Closure, |this| { this.with_awaitable_block(Awaitable::Yes, |this| { - this.collect_block_(e, |id, statements, tail| Expr::Async { - id, - statements, - tail, + this.collect_block_(e, |this, id, statements, tail| { + this.async_block( + CoroutineSource::Block, + capture_by, + id, + statements, + tail, + ) }) }) }) @@ -1378,9 +1425,11 @@ impl<'db> ExprCollector<'db> { } } ast::Expr::ClosureExpr(e) => self.with_label_rib(RibKind::Closure, |this| { - this.with_binding_owner(|this| { + this.with_binding_owner_and_return(|this| { let mut args = Vec::new(); let mut arg_types = Vec::new(); + // For coroutine closures, the body, aka. the coroutine is the bindings owner, and not the closure. + let mut body_is_bindings_owner = false; if let Some(pl) = e.param_list() { let num_params = pl.params().count(); args.reserve_exact(num_params); @@ -1406,7 +1455,7 @@ impl<'db> ExprCollector<'db> { } else { Awaitable::No("non-async closure") }; - let body = this + let mut body = this .with_awaitable_block(awaitable, |this| this.collect_expr_opt(e.body())); let closure_kind = if this.is_lowering_coroutine { @@ -1417,7 +1466,16 @@ impl<'db> ExprCollector<'db> { }; ClosureKind::Coroutine(movability) } else if e.async_token().is_some() { - ClosureKind::Async + // It's important that this expr is allocated immediately before the closure. + // We rely on it for `coroutine_for_closure()`. + body = this.lower_async_block_with_moved_arguments( + &mut args, + body, + CoroutineSource::Closure, + ); + body_is_bindings_owner = true; + + ClosureKind::AsyncClosure } else { ClosureKind::Closure }; @@ -1425,7 +1483,7 @@ impl<'db> ExprCollector<'db> { if e.move_token().is_some() { CaptureBy::Value } else { CaptureBy::Ref }; this.is_lowering_coroutine = prev_is_lowering_coroutine; this.current_try_block = prev_try_block; - this.alloc_expr( + let closure = this.alloc_expr( Expr::Closure { args: args.into(), arg_types: arg_types.into(), @@ -1435,7 +1493,9 @@ impl<'db> ExprCollector<'db> { capture_by, }, syntax_ptr, - ) + ); + + (if body_is_bindings_owner { body } else { closure }, closure) }) }), ast::Expr::BinExpr(e) => { @@ -1737,13 +1797,24 @@ impl<'db> ExprCollector<'db> { } } - fn with_binding_owner(&mut self, create_expr: impl FnOnce(&mut Self) -> ExprId) -> ExprId { + /// The callback should return two exprs: the first is the bindings owner, the second is the expr to return. + fn with_binding_owner_and_return( + &mut self, + create_expr: impl FnOnce(&mut Self) -> (ExprId, ExprId), + ) -> ExprId { let prev_unowned_bindings_len = self.unowned_bindings.len(); - let expr_id = create_expr(self); + let (bindings_owner, expr_to_return) = create_expr(self); for binding in self.unowned_bindings.drain(prev_unowned_bindings_len..) { - self.store.binding_owners.insert(binding, expr_id); + self.store.binding_owners.insert(binding, bindings_owner); } - expr_id + expr_to_return + } + + fn with_binding_owner(&mut self, create_expr: impl FnOnce(&mut Self) -> ExprId) -> ExprId { + self.with_binding_owner_and_return(move |this| { + let expr = create_expr(this); + (expr, expr) + }) } /// Desugar `try { <stmts>; <expr> }` into `'<new_label>: { <stmts>; ::std::ops::Try::from_output(<expr>) }`, @@ -1762,7 +1833,7 @@ impl<'db> ExprCollector<'db> { let ptr = AstPtr::new(&e).upcast(); let (btail, expr_id) = self.with_labeled_rib(label, HygieneId::ROOT, |this| { let mut btail = None; - let block = this.collect_block_(e, |id, statements, tail| { + let block = this.collect_block_(e, |_, id, statements, tail| { btail = tail; Expr::Block { id, statements, tail, label: Some(label) } }); @@ -2220,7 +2291,7 @@ impl<'db> ExprCollector<'db> { } fn collect_block(&mut self, block: ast::BlockExpr) -> ExprId { - self.collect_block_(block, |id, statements, tail| Expr::Block { + self.collect_block_(block, |_, id, statements, tail| Expr::Block { id, statements, tail, @@ -2231,7 +2302,7 @@ impl<'db> ExprCollector<'db> { fn collect_block_( &mut self, block: ast::BlockExpr, - mk_block: impl FnOnce(Option<BlockId>, Box<[Statement]>, Option<ExprId>) -> Expr, + mk_block: impl FnOnce(&mut Self, Option<BlockId>, Box<[Statement]>, Option<ExprId>) -> Expr, ) -> ExprId { let block_id = self.expander.ast_id_map().ast_id_for_block(&block).map(|file_local_id| { let ast_id = self.expander.in_file(file_local_id); @@ -2266,8 +2337,8 @@ impl<'db> ExprCollector<'db> { }); let syntax_node_ptr = AstPtr::new(&block.into()); - let expr_id = self - .alloc_expr(mk_block(block_id, statements.into_boxed_slice(), tail), syntax_node_ptr); + let expr = mk_block(self, block_id, statements.into_boxed_slice(), tail); + let expr_id = self.alloc_expr(expr, syntax_node_ptr); self.def_map = prev_def_map; self.module = prev_local_module; @@ -2693,17 +2764,17 @@ impl<'db> ExprCollector<'db> { for (rib_idx, rib) in self.label_ribs.iter().enumerate().rev() { match &rib.kind { - RibKind::Normal(label_name, id, label_hygiene) => { - if *label_name == name && *label_hygiene == hygiene_id { - return if self.is_label_valid_from_rib(rib_idx) { - Ok(Some(*id)) - } else { - Err(ExpressionStoreDiagnostics::UnreachableLabel { - name, - node: self.expander.in_file(AstPtr::new(&lifetime)), - }) - }; - } + RibKind::Normal(label_name, id, label_hygiene) + if *label_name == name && *label_hygiene == hygiene_id => + { + return if self.is_label_valid_from_rib(rib_idx) { + Ok(Some(*id)) + } else { + Err(ExpressionStoreDiagnostics::UnreachableLabel { + name, + node: self.expander.in_file(AstPtr::new(&lifetime)), + }) + }; } RibKind::MacroDef(macro_id) => { if let Some((parent_ctx, label_macro_id)) = hygiene_info |