Unnamed repository; edit this file 'description' to name the repository.
| -rw-r--r-- | crates/hir-def/src/expr_store/lower.rs | 113 | ||||
| -rw-r--r-- | crates/hir-def/src/expr_store/tests/body.rs | 8 | ||||
| -rw-r--r-- | crates/hir-ty/src/tests/simple.rs | 15 | ||||
| -rw-r--r-- | crates/hir-ty/src/tests/traits.rs | 3 | ||||
| -rw-r--r-- | crates/ide-assists/src/handlers/inline_call.rs | 19 |
5 files changed, 132 insertions, 26 deletions
diff --git a/crates/hir-def/src/expr_store/lower.rs b/crates/hir-def/src/expr_store/lower.rs index 76c65361e5..a82046149a 100644 --- a/crates/hir-def/src/expr_store/lower.rs +++ b/crates/hir-def/src/expr_store/lower.rs @@ -961,44 +961,99 @@ impl<'db> ExprCollector<'db> { /// into the body. This is to make sure that the future actually owns the /// arguments that are passed to the function, and to ensure things like /// drop order are stable. - fn lower_coroutine_with_moved_arguments( + fn lower_coroutine_body_with_moved_arguments( &mut self, params: &mut [PatId], body: ExprId, kind: CoroutineKind, coroutine_source: CoroutineSource, ) -> ExprId { + // Async function parameters are lowered into the closure body so that they are + // captured and so that the drop order matches the equivalent non-async functions. + // + // from: + // + // async fn foo(<pattern>: <ty>, <pattern>: <ty>, <pattern>: <ty>) { + // <body> + // } + // + // into: + // + // fn foo(__arg0: <ty>, __arg1: <ty>, __arg2: <ty>) { + // async move { + // let __arg2 = __arg2; + // let <pattern> = __arg2; + // let __arg1 = __arg1; + // let <pattern> = __arg1; + // let __arg0 = __arg0; + // let <pattern> = __arg0; + // drop-temps { <body> } // see comments later in fn for details + // } + // } + // + // If `<pattern>` is a simple ident, then it is lowered to a single + // `let <pattern> = <pattern>;` statement as an optimization. + let mut statements = Vec::new(); for param in params { - let (name, hygiene) = match self.store.pats[*param] { - Pat::Bind { id, .. } + let (name, hygiene, is_simple_parameter) = match self.store.pats[*param] { + // Check if this is a binding pattern, if so, we can optimize and avoid adding a + // `let <pat> = __argN;` statement. In this case, we do not rename the parameter. + Pat::Bind { id, subpat: None, .. } if matches!( self.store.bindings[id].mode, BindingAnnotation::Unannotated | BindingAnnotation::Mutable ) => { - // If this is a direct binding, we can leave it as-is, as it'll always be captured anyway. - continue; + (self.store.bindings[id].name.clone(), self.store.bindings[id].hygiene, true) } Pat::Bind { id, .. } => { // If this is a `ref` binding, we can't leave it as is but we can at least reuse the name, for better display. - (self.store.bindings[id].name.clone(), self.store.bindings[id].hygiene) + (self.store.bindings[id].name.clone(), self.store.bindings[id].hygiene, false) } - _ => (self.generate_new_name(), HygieneId::ROOT), + _ => (self.generate_new_name(), HygieneId::ROOT, false), }; - let binding_id = self.alloc_binding(name.clone(), BindingAnnotation::Mutable, hygiene); - let pat_id = self.alloc_pat_desugared(Pat::Bind { id: binding_id, subpat: None }); - let expr = self.alloc_expr_desugared(Expr::Path(name.into())); + let pat_syntax = self.store.pat_map_back.get(*param).copied(); + let child_binding_id = + self.alloc_binding(name.clone(), BindingAnnotation::Mutable, hygiene); + let child_pat_id = + self.alloc_pat_desugared(Pat::Bind { id: child_binding_id, subpat: None }); + self.add_definition_to_binding(child_binding_id, child_pat_id); + if let Some(pat_syntax) = pat_syntax { + self.store.pat_map_back.insert(child_pat_id, pat_syntax); + } + let expr = self.alloc_expr_desugared(Expr::Path(name.clone().into())); if !hygiene.is_root() { self.store.ident_hygiene.insert(expr.into(), hygiene); } statements.push(Statement::Let { - pat: *param, + pat: child_pat_id, type_ref: None, initializer: Some(expr), else_branch: None, }); - *param = pat_id; + if !is_simple_parameter { + let expr = self.alloc_expr_desugared(Expr::Path(name.clone().into())); + if !hygiene.is_root() { + self.store.ident_hygiene.insert(expr.into(), hygiene); + } + statements.push(Statement::Let { + pat: *param, + type_ref: None, + initializer: Some(expr), + else_branch: None, + }); + + let parent_binding_id = + self.alloc_binding(name.clone(), BindingAnnotation::Mutable, hygiene); + let parent_pat_id = + self.alloc_pat_desugared(Pat::Bind { id: parent_binding_id, subpat: None }); + self.add_definition_to_binding(parent_binding_id, parent_pat_id); + if let Some(pat_syntax) = pat_syntax { + self.store.pat_map_back.insert(parent_pat_id, pat_syntax); + } + *param = parent_pat_id; + } } let coroutine = self.desugared_coroutine_expr( @@ -1055,7 +1110,12 @@ impl<'db> ExprCollector<'db> { (false, true) => CoroutineKind::Gen, (false, false) => unreachable!(), }; - this.lower_coroutine_with_moved_arguments(params, body, kind, CoroutineSource::Fn) + this.lower_coroutine_body_with_moved_arguments( + params, + body, + kind, + CoroutineSource::Fn, + ) } else { body } @@ -1479,11 +1539,11 @@ impl<'db> ExprCollector<'db> { } } ast::Expr::ClosureExpr(e) => self.with_label_rib(RibKind::Closure, |this| { - this.with_binding_owner_and_return(|this| { + let mut is_coroutine_closure = false; + let closure = this.with_binding_owner_and_return(|this| { let mut args = Vec::new(); let mut arg_types = Vec::new(); // For coroutine closures, the body, aka. the coroutine is the bindings owner, and not the closure. - let mut body_is_bindings_owner = false; if let Some(pl) = e.param_list() { let num_params = pl.params().count(); args.reserve_exact(num_params); @@ -1526,13 +1586,13 @@ impl<'db> ExprCollector<'db> { let closure_kind = if let Some(kind) = kind { // It's important that this expr is allocated immediately before the closure. // We rely on it for `coroutine_for_closure()`. - body = this.lower_coroutine_with_moved_arguments( + body = this.lower_coroutine_body_with_moved_arguments( &mut args, body, kind, CoroutineSource::Closure, ); - body_is_bindings_owner = true; + is_coroutine_closure = true; ClosureKind::CoroutineClosure(kind) } else if this.is_lowering_coroutine { @@ -1561,8 +1621,23 @@ impl<'db> ExprCollector<'db> { syntax_ptr, ); - (if body_is_bindings_owner { body } else { closure }, closure) - }) + (if is_coroutine_closure { body } else { closure }, closure) + }); + + if is_coroutine_closure { + let Expr::Closure { args, .. } = &this.store.exprs[closure] else { + unreachable!() + }; + for &arg in args { + let Pat::Bind { id, .. } = this.store.pats[arg] else { + never!("`lower_coroutine_body_with_moved_arguments()` should make sure the coroutine closure only have simple bind args"); + continue; + }; + this.store.binding_owners.insert(id, closure); + } + } + + closure }), ast::Expr::BinExpr(e) => { let op = e.op_kind(); diff --git a/crates/hir-def/src/expr_store/tests/body.rs b/crates/hir-def/src/expr_store/tests/body.rs index db12775df9..9727d87cf0 100644 --- a/crates/hir-def/src/expr_store/tests/body.rs +++ b/crates/hir-def/src/expr_store/tests/body.rs @@ -652,9 +652,15 @@ fn async_fn_weird_param_patterns() { async fn main(&self, param1: i32, ref mut param2: i32, _: i32, param4 @ _: i32, 123: i32) {} "#, expect![[r#" - fn main(self, param1, mut param2, mut <ra@gennew>0, param4 @ _, mut <ra@gennew>1) async { + fn main(self, param1, mut param2, mut <ra@gennew>0, mut param4, mut <ra@gennew>1) async { + let mut param1 = param1; + let mut param2 = param2; let ref mut param2 = param2; + let mut <ra@gennew>0 = <ra@gennew>0; let _ = <ra@gennew>0; + let mut param4 = param4; + let param4 @ _ = param4; + let mut <ra@gennew>1 = <ra@gennew>1; let 123 = <ra@gennew>1; {} }"#]], diff --git a/crates/hir-ty/src/tests/simple.rs b/crates/hir-ty/src/tests/simple.rs index 1e75c31fa1..fbe7c3bd37 100644 --- a/crates/hir-ty/src/tests/simple.rs +++ b/crates/hir-ty/src/tests/simple.rs @@ -4032,6 +4032,7 @@ fn main() { 100..147 'async_... })': () 114..146 'async ... }': impl AsyncFnOnce(i32) 121..124 'arg': i32 + 121..124 'arg': i32 126..146 '{ ... }': () 136..139 'arg': i32 153..160 'closure': fn closure<impl FnOnce(i32)>(impl FnOnce(i32)) @@ -4279,3 +4280,17 @@ union U { "#]], ); } + +#[test] +fn async_closure_with_params() { + check_no_mismatches( + r#" +fn foo() { + let capture = false; + async move |param: i32| { + capture; + }; +} + "#, + ); +} diff --git a/crates/hir-ty/src/tests/traits.rs b/crates/hir-ty/src/tests/traits.rs index bcb5e5de16..18e4a5b41d 100644 --- a/crates/hir-ty/src/tests/traits.rs +++ b/crates/hir-ty/src/tests/traits.rs @@ -4930,6 +4930,7 @@ async fn baz<T: AsyncFnOnce(u32) -> i32>(c: T) { "#, expect![[r#" 37..38 'a': T + 37..38 'a': T 43..83 '{ ...ait; }': () 53..57 'fut1': <T as AsyncFnMut<(u32,)>>::CallRefFuture<'?> 60..61 'a': T @@ -4938,6 +4939,7 @@ async fn baz<T: AsyncFnOnce(u32) -> i32>(c: T) { 70..74 'fut1': <T as AsyncFnMut<(u32,)>>::CallRefFuture<'?> 70..80 'fut1.await': i32 124..129 'mut b': T + 124..129 'mut b': T 134..174 '{ ...ait; }': () 144..148 'fut2': <T as AsyncFnMut<(u32,)>>::CallRefFuture<'?> 151..152 'b': T @@ -4946,6 +4948,7 @@ async fn baz<T: AsyncFnOnce(u32) -> i32>(c: T) { 161..165 'fut2': <T as AsyncFnMut<(u32,)>>::CallRefFuture<'?> 161..171 'fut2.await': i32 216..217 'c': T + 216..217 'c': T 222..262 '{ ...ait; }': () 232..236 'fut3': <T as AsyncFnOnce<(u32,)>>::CallOnceFuture 239..240 'c': T diff --git a/crates/ide-assists/src/handlers/inline_call.rs b/crates/ide-assists/src/handlers/inline_call.rs index 5299680980..af048c6ae0 100644 --- a/crates/ide-assists/src/handlers/inline_call.rs +++ b/crates/ide-assists/src/handlers/inline_call.rs @@ -1566,8 +1566,11 @@ async fn foo(arg: u32) -> u32 { } fn spawn<T>(_: T) {} fn main() { - spawn(async move { - bar(42).await * 2 + spawn({ + let arg = 42; + async move { + bar(arg).await * 2 + } }); } "#, @@ -1598,9 +1601,12 @@ async fn foo(arg: u32) -> u32 { } fn spawn<T>(_: T) {} fn main() { - spawn(async move { - bar(42).await; - 42 + spawn({ + let arg = 42; + async move { + bar(arg).await; + 42 + } }); } "#, @@ -1635,10 +1641,11 @@ fn spawn<T>(_: T) {} fn main() { let var = 42; spawn({ + let x = var; let y = var + 1; let z: &u32 = &var; async move { - bar(var).await; + bar(x).await; y + y + *z } }); |