Unnamed repository; edit this file 'description' to name the repository.
-rw-r--r--crates/ide-assists/src/handlers/unnecessary_async.rs67
1 files changed, 65 insertions, 2 deletions
diff --git a/crates/ide-assists/src/handlers/unnecessary_async.rs b/crates/ide-assists/src/handlers/unnecessary_async.rs
index d5cd2d5513..44f8dbdef3 100644
--- a/crates/ide-assists/src/handlers/unnecessary_async.rs
+++ b/crates/ide-assists/src/handlers/unnecessary_async.rs
@@ -1,12 +1,14 @@
+use hir::AssocItem;
use ide_db::{
assists::{AssistId, AssistKind},
base_db::FileId,
defs::Definition,
search::FileReference,
syntax_helpers::node_ext::full_path_of_name_ref,
+ traits::resolve_target_trait,
};
use syntax::{
- ast::{self, NameLike, NameRef},
+ ast::{self, HasName, NameLike, NameRef},
AstNode, SyntaxKind, TextRange,
};
@@ -44,7 +46,16 @@ pub(crate) fn unnecessary_async(acc: &mut Assists, ctx: &AssistContext<'_>) -> O
if function.body()?.syntax().descendants().find_map(ast::AwaitExpr::cast).is_some() {
return None;
}
-
+ // Do nothing if the method is an async member of trait.
+ if let Some(fname) = function.name() {
+ if let Some(trait_item) = find_corresponding_trait_member(ctx, fname.to_string()) {
+ if let AssocItem::Function(method) = trait_item {
+ if method.is_async(ctx.db()) {
+ return None;
+ }
+ }
+ }
+ }
// Remove the `async` keyword plus whitespace after it, if any.
let async_range = {
let async_token = function.async_token()?;
@@ -88,6 +99,23 @@ pub(crate) fn unnecessary_async(acc: &mut Assists, ctx: &AssistContext<'_>) -> O
)
}
+fn find_corresponding_trait_member(
+ ctx: &AssistContext<'_>,
+ function_name: String,
+) -> Option<AssocItem> {
+ let impl_ = ctx.find_node_at_offset::<ast::Impl>()?;
+ let trait_ = resolve_target_trait(&ctx.sema, &impl_)?;
+
+ trait_
+ .items(ctx.db())
+ .iter()
+ .find(|item| match item.name(ctx.db()) {
+ Some(method_name) => method_name.to_string() == function_name,
+ _ => false,
+ })
+ .cloned()
+}
+
fn find_all_references(
ctx: &AssistContext<'_>,
def: &Definition,
@@ -254,4 +282,39 @@ pub async fn f(s: &S) { s.f2() }"#,
fn does_not_apply_when_not_on_prototype() {
check_assist_not_applicable(unnecessary_async, "pub async fn f() { $0f2() }")
}
+
+ #[test]
+ fn applies_on_unnecessary_async_on_trait_method() {
+ check_assist(
+ unnecessary_async,
+ r#"
+trait Trait {
+ fn foo();
+}
+impl Trait for () {
+ $0async fn foo() {}
+}"#,
+ r#"
+trait Trait {
+ fn foo();
+}
+impl Trait for () {
+ fn foo() {}
+}"#,
+ );
+ }
+
+ #[test]
+ fn does_not_apply_on_async_trait_method() {
+ check_assist_not_applicable(
+ unnecessary_async,
+ r#"
+trait Trait {
+ async fn foo();
+}
+impl Trait for () {
+ $0async fn foo() {}
+}"#,
+ );
+ }
}