Unnamed repository; edit this file 'description' to name the repository.
Auto merge of #15074 - oxalica:fix/inline-async-fn, r=lowr
Correctly handle inlining of async fn Fixes #10198
bors 2023-06-19
parent 67e1f52 · parent 2cd7b0c · commit f993485
-rw-r--r--crates/ide-assists/src/handlers/inline_call.rs141
-rw-r--r--crates/syntax/src/ast/edit_in_place.rs6
-rw-r--r--crates/syntax/src/ast/make.rs15
3 files changed, 146 insertions, 16 deletions
diff --git a/crates/ide-assists/src/handlers/inline_call.rs b/crates/ide-assists/src/handlers/inline_call.rs
index 797180fa18..67036029f5 100644
--- a/crates/ide-assists/src/handlers/inline_call.rs
+++ b/crates/ide-assists/src/handlers/inline_call.rs
@@ -15,7 +15,7 @@ use ide_db::{
};
use itertools::{izip, Itertools};
use syntax::{
- ast::{self, edit_in_place::Indent, HasArgList, PathExpr},
+ ast::{self, edit::IndentLevel, edit_in_place::Indent, HasArgList, PathExpr},
ted, AstNode, NodeOrToken, SyntaxKind,
};
@@ -306,7 +306,7 @@ fn inline(
params: &[(ast::Pat, Option<ast::Type>, hir::Param)],
CallInfo { node, arguments, generic_arg_list }: &CallInfo,
) -> ast::Expr {
- let body = if sema.hir_file_for(fn_body.syntax()).is_macro() {
+ let mut body = if sema.hir_file_for(fn_body.syntax()).is_macro() {
cov_mark::hit!(inline_call_defined_in_macro);
if let Some(body) = ast::BlockExpr::cast(insert_ws_into(fn_body.syntax().clone())) {
body
@@ -391,19 +391,19 @@ fn inline(
}
}
+ let mut let_stmts = Vec::new();
+
// Inline parameter expressions or generate `let` statements depending on whether inlining works or not.
- for ((pat, param_ty, _), usages, expr) in izip!(params, param_use_nodes, arguments).rev() {
+ for ((pat, param_ty, _), usages, expr) in izip!(params, param_use_nodes, arguments) {
// izip confuses RA due to our lack of hygiene info currently losing us type info causing incorrect errors
let usages: &[ast::PathExpr] = &usages;
let expr: &ast::Expr = expr;
- let insert_let_stmt = || {
+ let mut insert_let_stmt = || {
let ty = sema.type_of_expr(expr).filter(TypeInfo::has_adjustment).and(param_ty.clone());
- if let Some(stmt_list) = body.stmt_list() {
- stmt_list.push_front(
- make::let_stmt(pat.clone(), ty, Some(expr.clone())).clone_for_update().into(),
- )
- }
+ let_stmts.push(
+ make::let_stmt(pat.clone(), ty, Some(expr.clone())).clone_for_update().into(),
+ );
};
// check if there is a local var in the function that conflicts with parameter
@@ -457,6 +457,24 @@ fn inline(
}
}
+ let is_async_fn = function.is_async(sema.db);
+ if is_async_fn {
+ cov_mark::hit!(inline_call_async_fn);
+ body = make::async_move_block_expr(body.statements(), body.tail_expr()).clone_for_update();
+
+ // Arguments should be evaluated outside the async block, and then moved into it.
+ if !let_stmts.is_empty() {
+ cov_mark::hit!(inline_call_async_fn_with_let_stmts);
+ body.indent(IndentLevel(1));
+ body = make::block_expr(let_stmts, Some(body.into())).clone_for_update();
+ }
+ } else if let Some(stmt_list) = body.stmt_list() {
+ ted::insert_all(
+ ted::Position::after(stmt_list.l_curly_token().unwrap()),
+ let_stmts.into_iter().map(|stmt| stmt.syntax().clone().into()).collect(),
+ );
+ }
+
let original_indentation = match node {
ast::CallableExpr::Call(it) => it.indent_level(),
ast::CallableExpr::MethodCall(it) => it.indent_level(),
@@ -464,7 +482,7 @@ fn inline(
body.reindent_to(original_indentation);
match body.tail_expr() {
- Some(expr) if body.statements().next().is_none() => expr,
+ Some(expr) if !is_async_fn && body.statements().next().is_none() => expr,
_ => match node
.syntax()
.parent()
@@ -1353,4 +1371,107 @@ fn main() {
"#,
);
}
+
+ #[test]
+ fn async_fn_single_expression() {
+ cov_mark::check!(inline_call_async_fn);
+ check_assist(
+ inline_call,
+ r#"
+async fn bar(x: u32) -> u32 { x + 1 }
+async fn foo(arg: u32) -> u32 {
+ bar(arg).await * 2
+}
+fn spawn<T>(_: T) {}
+fn main() {
+ spawn(foo$0(42));
+}
+"#,
+ r#"
+async fn bar(x: u32) -> u32 { x + 1 }
+async fn foo(arg: u32) -> u32 {
+ bar(arg).await * 2
+}
+fn spawn<T>(_: T) {}
+fn main() {
+ spawn(async move {
+ bar(42).await * 2
+ });
+}
+"#,
+ );
+ }
+
+ #[test]
+ fn async_fn_multiple_statements() {
+ cov_mark::check!(inline_call_async_fn);
+ check_assist(
+ inline_call,
+ r#"
+async fn bar(x: u32) -> u32 { x + 1 }
+async fn foo(arg: u32) -> u32 {
+ bar(arg).await;
+ 42
+}
+fn spawn<T>(_: T) {}
+fn main() {
+ spawn(foo$0(42));
+}
+"#,
+ r#"
+async fn bar(x: u32) -> u32 { x + 1 }
+async fn foo(arg: u32) -> u32 {
+ bar(arg).await;
+ 42
+}
+fn spawn<T>(_: T) {}
+fn main() {
+ spawn(async move {
+ bar(42).await;
+ 42
+ });
+}
+"#,
+ );
+ }
+
+ #[test]
+ fn async_fn_with_let_statements() {
+ cov_mark::check!(inline_call_async_fn);
+ cov_mark::check!(inline_call_async_fn_with_let_stmts);
+ check_assist(
+ inline_call,
+ r#"
+async fn bar(x: u32) -> u32 { x + 1 }
+async fn foo(x: u32, y: u32, z: &u32) -> u32 {
+ bar(x).await;
+ y + y + *z
+}
+fn spawn<T>(_: T) {}
+fn main() {
+ let var = 42;
+ spawn(foo$0(var, var + 1, &var));
+}
+"#,
+ r#"
+async fn bar(x: u32) -> u32 { x + 1 }
+async fn foo(x: u32, y: u32, z: &u32) -> u32 {
+ bar(x).await;
+ y + y + *z
+}
+fn spawn<T>(_: T) {}
+fn main() {
+ let var = 42;
+ spawn({
+ let y = var + 1;
+ let z: &u32 = &var;
+ async move {
+ bar(var).await;
+ y + y + *z
+ }
+ });
+}
+"#,
+ );
+ }
}
diff --git a/crates/syntax/src/ast/edit_in_place.rs b/crates/syntax/src/ast/edit_in_place.rs
index b3ea6ca8d4..5919609d91 100644
--- a/crates/syntax/src/ast/edit_in_place.rs
+++ b/crates/syntax/src/ast/edit_in_place.rs
@@ -676,12 +676,6 @@ fn get_or_insert_comma_after(syntax: &SyntaxNode) -> SyntaxToken {
}
}
-impl ast::StmtList {
- pub fn push_front(&self, statement: ast::Stmt) {
- ted::insert(Position::after(self.l_curly_token().unwrap()), statement.syntax());
- }
-}
-
impl ast::VariantList {
pub fn add_variant(&self, variant: ast::Variant) {
let (indent, position) = match self.variants().last() {
diff --git a/crates/syntax/src/ast/make.rs b/crates/syntax/src/ast/make.rs
index 3c2b7e56b0..e435766000 100644
--- a/crates/syntax/src/ast/make.rs
+++ b/crates/syntax/src/ast/make.rs
@@ -447,6 +447,21 @@ pub fn block_expr(
ast_from_text(&format!("fn f() {buf}"))
}
+pub fn async_move_block_expr(
+ stmts: impl IntoIterator<Item = ast::Stmt>,
+ tail_expr: Option<ast::Expr>,
+) -> ast::BlockExpr {
+ let mut buf = "async move {\n".to_string();
+ for stmt in stmts.into_iter() {
+ format_to!(buf, " {stmt}\n");
+ }
+ if let Some(tail_expr) = tail_expr {
+ format_to!(buf, " {tail_expr}\n");
+ }
+ buf += "}";
+ ast_from_text(&format!("const _: () = {buf};"))
+}
+
pub fn tail_only_block_expr(tail_expr: ast::Expr) -> ast::BlockExpr {
ast_from_text(&format!("fn f() {{ {tail_expr} }}"))
}