Unnamed repository; edit this file 'description' to name the repository.
Auto merge of #16460 - davidsemakula:trailing-return-diagnostic, r=Veykril
feat: Add diagnostic with fix to replace trailing `return <val>;` with `<val>` Works for functions and closures. Ignores desugared return expressions (e.g. from desugared try operators). Fixes: #10970 Completes: #11020
bors 2024-02-08
parent e071834 · parent 602acfc · commit e418c90
-rw-r--r--crates/hir-ty/src/diagnostics/expr.rs43
-rw-r--r--crates/hir/src/diagnostics.rs21
-rw-r--r--crates/ide-diagnostics/src/handlers/remove_trailing_return.rs375
-rw-r--r--crates/ide-diagnostics/src/handlers/remove_unnecessary_else.rs14
-rw-r--r--crates/ide-diagnostics/src/handlers/type_mismatch.rs7
-rw-r--r--crates/ide-diagnostics/src/lib.rs2
-rw-r--r--crates/ide-diagnostics/src/tests.rs28
7 files changed, 479 insertions, 11 deletions
diff --git a/crates/hir-ty/src/diagnostics/expr.rs b/crates/hir-ty/src/diagnostics/expr.rs
index c09351390a..7f8fb7f4b5 100644
--- a/crates/hir-ty/src/diagnostics/expr.rs
+++ b/crates/hir-ty/src/diagnostics/expr.rs
@@ -44,6 +44,9 @@ pub enum BodyValidationDiagnostic {
match_expr: ExprId,
uncovered_patterns: String,
},
+ RemoveTrailingReturn {
+ return_expr: ExprId,
+ },
RemoveUnnecessaryElse {
if_expr: ExprId,
},
@@ -75,6 +78,10 @@ impl ExprValidator {
let body = db.body(self.owner);
let mut filter_map_next_checker = None;
+ if matches!(self.owner, DefWithBodyId::FunctionId(_)) {
+ self.check_for_trailing_return(body.body_expr, &body);
+ }
+
for (id, expr) in body.exprs.iter() {
if let Some((variant, missed_fields, true)) =
record_literal_missing_fields(db, &self.infer, id, expr)
@@ -93,12 +100,16 @@ impl ExprValidator {
Expr::Call { .. } | Expr::MethodCall { .. } => {
self.validate_call(db, id, expr, &mut filter_map_next_checker);
}
+ Expr::Closure { body: body_expr, .. } => {
+ self.check_for_trailing_return(*body_expr, &body);
+ }
Expr::If { .. } => {
self.check_for_unnecessary_else(id, expr, &body);
}
_ => {}
}
}
+
for (id, pat) in body.pats.iter() {
if let Some((variant, missed_fields, true)) =
record_pattern_missing_fields(db, &self.infer, id, pat)
@@ -244,6 +255,38 @@ impl ExprValidator {
pattern
}
+ fn check_for_trailing_return(&mut self, body_expr: ExprId, body: &Body) {
+ match &body.exprs[body_expr] {
+ Expr::Block { statements, tail, .. } => {
+ let last_stmt = tail.or_else(|| match statements.last()? {
+ Statement::Expr { expr, .. } => Some(*expr),
+ _ => None,
+ });
+ if let Some(last_stmt) = last_stmt {
+ self.check_for_trailing_return(last_stmt, body);
+ }
+ }
+ Expr::If { then_branch, else_branch, .. } => {
+ self.check_for_trailing_return(*then_branch, body);
+ if let Some(else_branch) = else_branch {
+ self.check_for_trailing_return(*else_branch, body);
+ }
+ }
+ Expr::Match { arms, .. } => {
+ for arm in arms.iter() {
+ let MatchArm { expr, .. } = arm;
+ self.check_for_trailing_return(*expr, body);
+ }
+ }
+ Expr::Return { .. } => {
+ self.diagnostics.push(BodyValidationDiagnostic::RemoveTrailingReturn {
+ return_expr: body_expr,
+ });
+ }
+ _ => (),
+ }
+ }
+
fn check_for_unnecessary_else(&mut self, id: ExprId, expr: &Expr, body: &Body) {
if let Expr::If { condition: _, then_branch, else_branch } = expr {
if else_branch.is_none() {
diff --git a/crates/hir/src/diagnostics.rs b/crates/hir/src/diagnostics.rs
index 487e0c8f7a..08843a6c99 100644
--- a/crates/hir/src/diagnostics.rs
+++ b/crates/hir/src/diagnostics.rs
@@ -67,8 +67,9 @@ diagnostics![
NoSuchField,
PrivateAssocItem,
PrivateField,
- ReplaceFilterMapNextWithFindMap,
+ RemoveTrailingReturn,
RemoveUnnecessaryElse,
+ ReplaceFilterMapNextWithFindMap,
TraitImplIncorrectSafety,
TraitImplMissingAssocItems,
TraitImplOrphan,
@@ -344,6 +345,11 @@ pub struct TraitImplRedundantAssocItems {
}
#[derive(Debug)]
+pub struct RemoveTrailingReturn {
+ pub return_expr: InFile<AstPtr<ast::ReturnExpr>>,
+}
+
+#[derive(Debug)]
pub struct RemoveUnnecessaryElse {
pub if_expr: InFile<AstPtr<ast::IfExpr>>,
}
@@ -450,6 +456,19 @@ impl AnyDiagnostic {
Err(SyntheticSyntax) => (),
}
}
+ BodyValidationDiagnostic::RemoveTrailingReturn { return_expr } => {
+ if let Ok(source_ptr) = source_map.expr_syntax(return_expr) {
+ // Filters out desugared return expressions (e.g. desugared try operators).
+ if let Some(ptr) = source_ptr.value.cast::<ast::ReturnExpr>() {
+ return Some(
+ RemoveTrailingReturn {
+ return_expr: InFile::new(source_ptr.file_id, ptr),
+ }
+ .into(),
+ );
+ }
+ }
+ }
BodyValidationDiagnostic::RemoveUnnecessaryElse { if_expr } => {
if let Ok(source_ptr) = source_map.expr_syntax(if_expr) {
if let Some(ptr) = source_ptr.value.cast::<ast::IfExpr>() {
diff --git a/crates/ide-diagnostics/src/handlers/remove_trailing_return.rs b/crates/ide-diagnostics/src/handlers/remove_trailing_return.rs
new file mode 100644
index 0000000000..605e8baba0
--- /dev/null
+++ b/crates/ide-diagnostics/src/handlers/remove_trailing_return.rs
@@ -0,0 +1,375 @@
+use hir::{db::ExpandDatabase, diagnostics::RemoveTrailingReturn};
+use ide_db::{assists::Assist, base_db::FileRange, source_change::SourceChange};
+use syntax::{ast, AstNode};
+use text_edit::TextEdit;
+
+use crate::{adjusted_display_range, fix, Diagnostic, DiagnosticCode, DiagnosticsContext};
+
+// Diagnostic: remove-trailing-return
+//
+// This diagnostic is triggered when there is a redundant `return` at the end of a function
+// or closure.
+pub(crate) fn remove_trailing_return(
+ ctx: &DiagnosticsContext<'_>,
+ d: &RemoveTrailingReturn,
+) -> Diagnostic {
+ let display_range = adjusted_display_range(ctx, d.return_expr, &|return_expr| {
+ return_expr
+ .syntax()
+ .parent()
+ .and_then(ast::ExprStmt::cast)
+ .map(|stmt| stmt.syntax().text_range())
+ });
+ Diagnostic::new(
+ DiagnosticCode::Clippy("needless_return"),
+ "replace return <expr>; with <expr>",
+ display_range,
+ )
+ .with_fixes(fixes(ctx, d))
+}
+
+fn fixes(ctx: &DiagnosticsContext<'_>, d: &RemoveTrailingReturn) -> Option<Vec<Assist>> {
+ let root = ctx.sema.db.parse_or_expand(d.return_expr.file_id);
+ let return_expr = d.return_expr.value.to_node(&root);
+ let stmt = return_expr.syntax().parent().and_then(ast::ExprStmt::cast);
+
+ let FileRange { range, file_id } =
+ ctx.sema.original_range_opt(stmt.as_ref().map_or(return_expr.syntax(), AstNode::syntax))?;
+ if Some(file_id) != d.return_expr.file_id.file_id() {
+ return None;
+ }
+
+ let replacement =
+ return_expr.expr().map_or_else(String::new, |expr| format!("{}", expr.syntax().text()));
+ let edit = TextEdit::replace(range, replacement);
+ let source_change = SourceChange::from_text_edit(file_id, edit);
+
+ Some(vec![fix(
+ "remove_trailing_return",
+ "Replace return <expr>; with <expr>",
+ source_change,
+ range,
+ )])
+}
+
+#[cfg(test)]
+mod tests {
+ use crate::tests::{
+ check_diagnostics, check_diagnostics_with_disabled, check_fix, check_fix_with_disabled,
+ };
+
+ #[test]
+ fn remove_trailing_return() {
+ check_diagnostics(
+ r#"
+fn foo() -> u8 {
+ return 2;
+} //^^^^^^^^^ 💡 weak: replace return <expr>; with <expr>
+"#,
+ );
+ }
+
+ #[test]
+ fn remove_trailing_return_inner_function() {
+ check_diagnostics(
+ r#"
+fn foo() -> u8 {
+ fn bar() -> u8 {
+ return 2;
+ } //^^^^^^^^^ 💡 weak: replace return <expr>; with <expr>
+ bar()
+}
+"#,
+ );
+ }
+
+ #[test]
+ fn remove_trailing_return_closure() {
+ check_diagnostics(
+ r#"
+fn foo() -> u8 {
+ let bar = || return 2;
+ bar() //^^^^^^^^ 💡 weak: replace return <expr>; with <expr>
+}
+"#,
+ );
+ check_diagnostics(
+ r#"
+fn foo() -> u8 {
+ let bar = || {
+ return 2;
+ };//^^^^^^^^^ 💡 weak: replace return <expr>; with <expr>
+ bar()
+}
+"#,
+ );
+ }
+
+ #[test]
+ fn remove_trailing_return_unit() {
+ check_diagnostics(
+ r#"
+fn foo() {
+ return
+} //^^^^^^ 💡 weak: replace return <expr>; with <expr>
+"#,
+ );
+ }
+
+ #[test]
+ fn remove_trailing_return_no_semi() {
+ check_diagnostics(
+ r#"
+fn foo() -> u8 {
+ return 2
+} //^^^^^^^^ 💡 weak: replace return <expr>; with <expr>
+"#,
+ );
+ }
+
+ #[test]
+ fn remove_trailing_return_in_if() {
+ check_diagnostics_with_disabled(
+ r#"
+fn foo(x: usize) -> u8 {
+ if x > 0 {
+ return 1;
+ //^^^^^^^^^ 💡 weak: replace return <expr>; with <expr>
+ } else {
+ return 0;
+ } //^^^^^^^^^ 💡 weak: replace return <expr>; with <expr>
+}
+"#,
+ std::iter::once("remove-unnecessary-else".to_string()),
+ );
+ }
+
+ #[test]
+ fn remove_trailing_return_in_match() {
+ check_diagnostics(
+ r#"
+fn foo<T, E>(x: Result<T, E>) -> u8 {
+ match x {
+ Ok(_) => return 1,
+ //^^^^^^^^ 💡 weak: replace return <expr>; with <expr>
+ Err(_) => return 0,
+ } //^^^^^^^^ 💡 weak: replace return <expr>; with <expr>
+}
+"#,
+ );
+ }
+
+ #[test]
+ fn no_diagnostic_if_no_return_keyword() {
+ check_diagnostics(
+ r#"
+fn foo() -> u8 {
+ 3
+}
+"#,
+ );
+ }
+
+ #[test]
+ fn no_diagnostic_if_not_last_statement() {
+ check_diagnostics(
+ r#"
+fn foo() -> u8 {
+ if true { return 2; }
+ 3
+}
+"#,
+ );
+ }
+
+ #[test]
+ fn replace_with_expr() {
+ check_fix(
+ r#"
+fn foo() -> u8 {
+ return$0 2;
+}
+"#,
+ r#"
+fn foo() -> u8 {
+ 2
+}
+"#,
+ );
+ }
+
+ #[test]
+ fn replace_with_unit() {
+ check_fix(
+ r#"
+fn foo() {
+ return$0/*ensure tidy is happy*/
+}
+"#,
+ r#"
+fn foo() {
+ /*ensure tidy is happy*/
+}
+"#,
+ );
+ }
+
+ #[test]
+ fn replace_with_expr_no_semi() {
+ check_fix(
+ r#"
+fn foo() -> u8 {
+ return$0 2
+}
+"#,
+ r#"
+fn foo() -> u8 {
+ 2
+}
+"#,
+ );
+ }
+
+ #[test]
+ fn replace_in_inner_function() {
+ check_fix(
+ r#"
+fn foo() -> u8 {
+ fn bar() -> u8 {
+ return$0 2;
+ }
+ bar()
+}
+"#,
+ r#"
+fn foo() -> u8 {
+ fn bar() -> u8 {
+ 2
+ }
+ bar()
+}
+"#,
+ );
+ }
+
+ #[test]
+ fn replace_in_closure() {
+ check_fix(
+ r#"
+fn foo() -> u8 {
+ let bar = || return$0 2;
+ bar()
+}
+"#,
+ r#"
+fn foo() -> u8 {
+ let bar = || 2;
+ bar()
+}
+"#,
+ );
+ check_fix(
+ r#"
+fn foo() -> u8 {
+ let bar = || {
+ return$0 2;
+ };
+ bar()
+}
+"#,
+ r#"
+fn foo() -> u8 {
+ let bar = || {
+ 2
+ };
+ bar()
+}
+"#,
+ );
+ }
+
+ #[test]
+ fn replace_in_if() {
+ check_fix_with_disabled(
+ r#"
+fn foo(x: usize) -> u8 {
+ if x > 0 {
+ return$0 1;
+ } else {
+ 0
+ }
+}
+"#,
+ r#"
+fn foo(x: usize) -> u8 {
+ if x > 0 {
+ 1
+ } else {
+ 0
+ }
+}
+"#,
+ std::iter::once("remove-unnecessary-else".to_string()),
+ );
+ check_fix(
+ r#"
+fn foo(x: usize) -> u8 {
+ if x > 0 {
+ 1
+ } else {
+ return$0 0;
+ }
+}
+"#,
+ r#"
+fn foo(x: usize) -> u8 {
+ if x > 0 {
+ 1
+ } else {
+ 0
+ }
+}
+"#,
+ );
+ }
+
+ #[test]
+ fn replace_in_match() {
+ check_fix(
+ r#"
+fn foo<T, E>(x: Result<T, E>) -> u8 {
+ match x {
+ Ok(_) => return$0 1,
+ Err(_) => 0,
+ }
+}
+"#,
+ r#"
+fn foo<T, E>(x: Result<T, E>) -> u8 {
+ match x {
+ Ok(_) => 1,
+ Err(_) => 0,
+ }
+}
+"#,
+ );
+ check_fix(
+ r#"
+fn foo<T, E>(x: Result<T, E>) -> u8 {
+ match x {
+ Ok(_) => 1,
+ Err(_) => return$0 0,
+ }
+}
+"#,
+ r#"
+fn foo<T, E>(x: Result<T, E>) -> u8 {
+ match x {
+ Ok(_) => 1,
+ Err(_) => 0,
+ }
+}
+"#,
+ );
+ }
+}
diff --git a/crates/ide-diagnostics/src/handlers/remove_unnecessary_else.rs b/crates/ide-diagnostics/src/handlers/remove_unnecessary_else.rs
index c6c85256f9..124086c8fa 100644
--- a/crates/ide-diagnostics/src/handlers/remove_unnecessary_else.rs
+++ b/crates/ide-diagnostics/src/handlers/remove_unnecessary_else.rs
@@ -87,11 +87,15 @@ fn fixes(ctx: &DiagnosticsContext<'_>, d: &RemoveUnnecessaryElse) -> Option<Vec<
#[cfg(test)]
mod tests {
- use crate::tests::{check_diagnostics, check_fix};
+ use crate::tests::{check_diagnostics, check_diagnostics_with_disabled, check_fix};
+
+ fn check_diagnostics_with_needless_return_disabled(ra_fixture: &str) {
+ check_diagnostics_with_disabled(ra_fixture, std::iter::once("needless_return".to_string()));
+ }
#[test]
fn remove_unnecessary_else_for_return() {
- check_diagnostics(
+ check_diagnostics_with_needless_return_disabled(
r#"
fn test() {
if foo {
@@ -126,7 +130,7 @@ fn test() {
#[test]
fn remove_unnecessary_else_for_return2() {
- check_diagnostics(
+ check_diagnostics_with_needless_return_disabled(
r#"
fn test() {
if foo {
@@ -169,7 +173,7 @@ fn test() {
#[test]
fn remove_unnecessary_else_for_return_in_child_if_expr() {
- check_diagnostics(
+ check_diagnostics_with_needless_return_disabled(
r#"
fn test() {
if foo {
@@ -371,7 +375,7 @@ fn test() {
#[test]
fn no_diagnostic_if_no_divergence_in_else_branch() {
- check_diagnostics(
+ check_diagnostics_with_needless_return_disabled(
r#"
fn test() {
if foo {
diff --git a/crates/ide-diagnostics/src/handlers/type_mismatch.rs b/crates/ide-diagnostics/src/handlers/type_mismatch.rs
index 750189beec..eec8efe785 100644
--- a/crates/ide-diagnostics/src/handlers/type_mismatch.rs
+++ b/crates/ide-diagnostics/src/handlers/type_mismatch.rs
@@ -186,7 +186,9 @@ fn str_ref_to_owned(
#[cfg(test)]
mod tests {
- use crate::tests::{check_diagnostics, check_fix, check_no_fix};
+ use crate::tests::{
+ check_diagnostics, check_diagnostics_with_disabled, check_fix, check_no_fix,
+ };
#[test]
fn missing_reference() {
@@ -718,7 +720,7 @@ struct Bar {
#[test]
fn return_no_value() {
- check_diagnostics(
+ check_diagnostics_with_disabled(
r#"
fn f() -> i32 {
return;
@@ -727,6 +729,7 @@ fn f() -> i32 {
}
fn g() { return; }
"#,
+ std::iter::once("needless_return".to_string()),
);
}
diff --git a/crates/ide-diagnostics/src/lib.rs b/crates/ide-diagnostics/src/lib.rs
index 7423de0be7..7c5cf67330 100644
--- a/crates/ide-diagnostics/src/lib.rs
+++ b/crates/ide-diagnostics/src/lib.rs
@@ -43,6 +43,7 @@ mod handlers {
pub(crate) mod no_such_field;
pub(crate) mod private_assoc_item;
pub(crate) mod private_field;
+ pub(crate) mod remove_trailing_return;
pub(crate) mod remove_unnecessary_else;
pub(crate) mod replace_filter_map_next_with_find_map;
pub(crate) mod trait_impl_incorrect_safety;
@@ -383,6 +384,7 @@ pub fn diagnostics(
AnyDiagnostic::UnusedVariable(d) => handlers::unused_variables::unused_variables(&ctx, &d),
AnyDiagnostic::BreakOutsideOfLoop(d) => handlers::break_outside_of_loop::break_outside_of_loop(&ctx, &d),
AnyDiagnostic::MismatchedTupleStructPatArgCount(d) => handlers::mismatched_arg_count::mismatched_tuple_struct_pat_arg_count(&ctx, &d),
+ AnyDiagnostic::RemoveTrailingReturn(d) => handlers::remove_trailing_return::remove_trailing_return(&ctx, &d),
AnyDiagnostic::RemoveUnnecessaryElse(d) => handlers::remove_unnecessary_else::remove_unnecessary_else(&ctx, &d),
};
res.push(d)
diff --git a/crates/ide-diagnostics/src/tests.rs b/crates/ide-diagnostics/src/tests.rs
index c6f4d6be76..da563b874b 100644
--- a/crates/ide-diagnostics/src/tests.rs
+++ b/crates/ide-diagnostics/src/tests.rs
@@ -34,13 +34,35 @@ pub(crate) fn check_fixes(ra_fixture_before: &str, ra_fixtures_after: Vec<&str>)
#[track_caller]
fn check_nth_fix(nth: usize, ra_fixture_before: &str, ra_fixture_after: &str) {
+ let mut config = DiagnosticsConfig::test_sample();
+ config.expr_fill_default = ExprFillDefaultMode::Default;
+ check_nth_fix_with_config(config, nth, ra_fixture_before, ra_fixture_after)
+}
+
+#[track_caller]
+pub(crate) fn check_fix_with_disabled(
+ ra_fixture_before: &str,
+ ra_fixture_after: &str,
+ disabled: impl Iterator<Item = String>,
+) {
+ let mut config = DiagnosticsConfig::test_sample();
+ config.expr_fill_default = ExprFillDefaultMode::Default;
+ config.disabled.extend(disabled);
+ check_nth_fix_with_config(config, 0, ra_fixture_before, ra_fixture_after)
+}
+
+#[track_caller]
+fn check_nth_fix_with_config(
+ config: DiagnosticsConfig,
+ nth: usize,
+ ra_fixture_before: &str,
+ ra_fixture_after: &str,
+) {
let after = trim_indent(ra_fixture_after);
let (db, file_position) = RootDatabase::with_position(ra_fixture_before);
- let mut conf = DiagnosticsConfig::test_sample();
- conf.expr_fill_default = ExprFillDefaultMode::Default;
let diagnostic =
- super::diagnostics(&db, &conf, &AssistResolveStrategy::All, file_position.file_id)
+ super::diagnostics(&db, &config, &AssistResolveStrategy::All, file_position.file_id)
.pop()
.expect("no diagnostics");
let fix = &diagnostic