Unnamed repository; edit this file 'description' to name the repository.
fix: only emit "unnecessary else" diagnostic for expr stmts
davidsemakula 2024-02-19
parent 1205853 · commit ff70310
-rw-r--r--crates/hir-ty/src/diagnostics/expr.rs64
-rw-r--r--crates/ide-diagnostics/src/handlers/remove_unnecessary_else.rs14
2 files changed, 49 insertions, 29 deletions
diff --git a/crates/hir-ty/src/diagnostics/expr.rs b/crates/hir-ty/src/diagnostics/expr.rs
index 718409e159..4fe75f24b8 100644
--- a/crates/hir-ty/src/diagnostics/expr.rs
+++ b/crates/hir-ty/src/diagnostics/expr.rs
@@ -109,7 +109,7 @@ impl ExprValidator {
self.check_for_trailing_return(*body_expr, &body);
}
Expr::If { .. } => {
- self.check_for_unnecessary_else(id, expr, db);
+ self.check_for_unnecessary_else(id, expr, &body, db);
}
Expr::Block { .. } => {
self.validate_block(db, expr);
@@ -337,35 +337,17 @@ impl ExprValidator {
}
}
- fn check_for_unnecessary_else(&mut self, id: ExprId, expr: &Expr, db: &dyn HirDatabase) {
+ fn check_for_unnecessary_else(
+ &mut self,
+ id: ExprId,
+ expr: &Expr,
+ body: &Body,
+ db: &dyn HirDatabase,
+ ) {
if let Expr::If { condition: _, then_branch, else_branch } = expr {
if else_branch.is_none() {
return;
}
- let (body, source_map) = db.body_with_source_map(self.owner);
- let Ok(source_ptr) = source_map.expr_syntax(id) else {
- return;
- };
- let root = source_ptr.file_syntax(db.upcast());
- let ast::Expr::IfExpr(if_expr) = source_ptr.value.to_node(&root) else {
- return;
- };
- let mut top_if_expr = if_expr;
- loop {
- let parent = top_if_expr.syntax().parent();
- let has_parent_let_stmt =
- parent.as_ref().map_or(false, |node| ast::LetStmt::can_cast(node.kind()));
- if has_parent_let_stmt {
- // Bail if parent or direct ancestor is a let stmt.
- return;
- }
- let Some(parent_if_expr) = parent.and_then(ast::IfExpr::cast) else {
- // Parent is neither an if expr nor a let stmt.
- break;
- };
- // Check parent if expr.
- top_if_expr = parent_if_expr;
- }
if let Expr::Block { statements, tail, .. } = &body.exprs[*then_branch] {
let last_then_expr = tail.or_else(|| match statements.last()? {
Statement::Expr { expr, .. } => Some(*expr),
@@ -374,6 +356,36 @@ impl ExprValidator {
if let Some(last_then_expr) = last_then_expr {
let last_then_expr_ty = &self.infer[last_then_expr];
if last_then_expr_ty.is_never() {
+ // Only look at sources if the then branch diverges and we have an else branch.
+ let (_, source_map) = db.body_with_source_map(self.owner);
+ let Ok(source_ptr) = source_map.expr_syntax(id) else {
+ return;
+ };
+ let root = source_ptr.file_syntax(db.upcast());
+ let ast::Expr::IfExpr(if_expr) = source_ptr.value.to_node(&root) else {
+ return;
+ };
+ let mut top_if_expr = if_expr;
+ loop {
+ let parent = top_if_expr.syntax().parent();
+ let has_parent_expr_stmt_or_stmt_list =
+ parent.as_ref().map_or(false, |node| {
+ ast::ExprStmt::can_cast(node.kind())
+ | ast::StmtList::can_cast(node.kind())
+ });
+ if has_parent_expr_stmt_or_stmt_list {
+ // Only emit diagnostic if parent or direct ancestor is either
+ // an expr stmt or a stmt list.
+ break;
+ }
+ let Some(parent_if_expr) = parent.and_then(ast::IfExpr::cast) else {
+ // Bail if parent is neither an if expr, an expr stmt nor a stmt list.
+ return;
+ };
+ // Check parent if expr.
+ top_if_expr = parent_if_expr;
+ }
+
self.diagnostics
.push(BodyValidationDiagnostic::RemoveUnnecessaryElse { if_expr: id })
}
diff --git a/crates/ide-diagnostics/src/handlers/remove_unnecessary_else.rs b/crates/ide-diagnostics/src/handlers/remove_unnecessary_else.rs
index 9564807a33..7bfd64596e 100644
--- a/crates/ide-diagnostics/src/handlers/remove_unnecessary_else.rs
+++ b/crates/ide-diagnostics/src/handlers/remove_unnecessary_else.rs
@@ -467,10 +467,10 @@ fn test() {
}
#[test]
- fn no_diagnostic_if_tail_exists_in_else_branch() {
+ fn no_diagnostic_if_not_expr_stmt() {
check_diagnostics_with_needless_return_disabled(
r#"
-fn test1(a: bool) {
+fn test1() {
let _x = if a {
return;
} else {
@@ -478,7 +478,7 @@ fn test1(a: bool) {
};
}
-fn test2(a: bool, b: bool, c: bool) {
+fn test2() {
let _x = if a {
return;
} else if b {
@@ -491,5 +491,13 @@ fn test2(a: bool, b: bool, c: bool) {
}
"#,
);
+ check_diagnostics_with_disabled(
+ r#"
+fn test3() {
+ foo(if a { return 1 } else { 0 })
+}
+"#,
+ std::iter::once("E0308".to_owned()),
+ );
}
}