Unnamed repository; edit this file 'description' to name the repository.
Merge pull request #18458 from Giga-Bowser/master
feat: Add diagnostic fix to remove unnecessary wrapper in type mismatch
Lukas Wirth 2024-12-11
parent 087cb62 · parent 68b85ce · commit 41f3319
-rw-r--r--crates/ide-diagnostics/src/handlers/type_mismatch.rs458
-rw-r--r--crates/syntax/src/ast/syntax_factory/constructors.rs41
2 files changed, 432 insertions, 67 deletions
diff --git a/crates/ide-diagnostics/src/handlers/type_mismatch.rs b/crates/ide-diagnostics/src/handlers/type_mismatch.rs
index 93fe9374a3..bfdda53740 100644
--- a/crates/ide-diagnostics/src/handlers/type_mismatch.rs
+++ b/crates/ide-diagnostics/src/handlers/type_mismatch.rs
@@ -1,12 +1,16 @@
use either::Either;
-use hir::{db::ExpandDatabase, ClosureStyle, HirDisplay, HirFileIdExt, InFile, Type};
-use ide_db::text_edit::TextEdit;
-use ide_db::{famous_defs::FamousDefs, source_change::SourceChange};
+use hir::{db::ExpandDatabase, CallableKind, ClosureStyle, HirDisplay, HirFileIdExt, InFile, Type};
+use ide_db::{
+ famous_defs::FamousDefs,
+ source_change::{SourceChange, SourceChangeBuilder},
+ text_edit::TextEdit,
+};
use syntax::{
ast::{
self,
edit::{AstNodeEdit, IndentLevel},
- BlockExpr, Expr, ExprStmt,
+ syntax_factory::SyntaxFactory,
+ BlockExpr, Expr, ExprStmt, HasArgList,
},
AstNode, AstPtr, TextSize,
};
@@ -63,6 +67,7 @@ fn fixes(ctx: &DiagnosticsContext<'_>, d: &hir::TypeMismatch) -> Option<Vec<Assi
let expr_ptr = &InFile { file_id: d.expr_or_pat.file_id, value: expr_ptr };
add_reference(ctx, d, expr_ptr, &mut fixes);
add_missing_ok_or_some(ctx, d, expr_ptr, &mut fixes);
+ remove_unnecessary_wrapper(ctx, d, expr_ptr, &mut fixes);
remove_semicolon(ctx, d, expr_ptr, &mut fixes);
str_ref_to_owned(ctx, d, expr_ptr, &mut fixes);
}
@@ -184,6 +189,89 @@ fn add_missing_ok_or_some(
Some(())
}
+fn remove_unnecessary_wrapper(
+ ctx: &DiagnosticsContext<'_>,
+ d: &hir::TypeMismatch,
+ expr_ptr: &InFile<AstPtr<ast::Expr>>,
+ acc: &mut Vec<Assist>,
+) -> Option<()> {
+ let db = ctx.sema.db;
+ let root = db.parse_or_expand(expr_ptr.file_id);
+ let expr = expr_ptr.value.to_node(&root);
+ let expr = ctx.sema.original_ast_node(expr.clone())?;
+
+ let Expr::CallExpr(call_expr) = expr else {
+ return None;
+ };
+
+ let callable = ctx.sema.resolve_expr_as_callable(&call_expr.expr()?)?;
+ let CallableKind::TupleEnumVariant(variant) = callable.kind() else {
+ return None;
+ };
+
+ let actual_enum = d.actual.as_adt()?.as_enum()?;
+ let famous_defs = FamousDefs(&ctx.sema, ctx.sema.scope(call_expr.syntax())?.krate());
+ let core_option = famous_defs.core_option_Option();
+ let core_result = famous_defs.core_result_Result();
+ if Some(actual_enum) != core_option && Some(actual_enum) != core_result {
+ return None;
+ }
+
+ let inner_type = variant.fields(db).first()?.ty_with_args(db, d.actual.type_arguments());
+ if !d.expected.could_unify_with(db, &inner_type) {
+ return None;
+ }
+
+ let inner_arg = call_expr.arg_list()?.args().next()?;
+
+ let file_id = expr_ptr.file_id.original_file(db);
+ let mut builder = SourceChangeBuilder::new(file_id);
+ let mut editor;
+ match inner_arg {
+ // We're returning `()`
+ Expr::TupleExpr(tup) if tup.fields().next().is_none() => {
+ let parent = call_expr
+ .syntax()
+ .parent()
+ .and_then(Either::<ast::ReturnExpr, ast::StmtList>::cast)?;
+
+ editor = builder.make_editor(parent.syntax());
+ let make = SyntaxFactory::new();
+
+ match parent {
+ Either::Left(ret_expr) => {
+ editor.replace(ret_expr.syntax(), make.expr_return(None).syntax());
+ }
+ Either::Right(stmt_list) => {
+ let new_block = if stmt_list.statements().next().is_none() {
+ make.expr_empty_block()
+ } else {
+ make.block_expr(stmt_list.statements(), None)
+ };
+
+ editor.replace(stmt_list.syntax().parent()?, new_block.syntax());
+ }
+ }
+
+ editor.add_mappings(make.finish_with_mappings());
+ }
+ _ => {
+ editor = builder.make_editor(call_expr.syntax());
+ editor.replace(call_expr.syntax(), inner_arg.syntax());
+ }
+ }
+
+ builder.add_file_edits(file_id, editor);
+ let name = format!("Remove unnecessary {}() wrapper", variant.name(db).as_str());
+ acc.push(fix(
+ "remove_unnecessary_wrapper",
+ &name,
+ builder.finish(),
+ call_expr.syntax().text_range(),
+ ));
+ Some(())
+}
+
fn remove_semicolon(
ctx: &DiagnosticsContext<'_>,
d: &hir::TypeMismatch,
@@ -243,7 +331,7 @@ fn str_ref_to_owned(
#[cfg(test)]
mod tests {
use crate::tests::{
- check_diagnostics, check_diagnostics_with_disabled, check_fix, check_no_fix,
+ check_diagnostics, check_diagnostics_with_disabled, check_fix, check_has_fix, check_no_fix,
};
#[test]
@@ -260,7 +348,7 @@ fn test(_arg: &i32) {}
}
#[test]
- fn test_add_reference_to_int() {
+ fn add_reference_to_int() {
check_fix(
r#"
fn main() {
@@ -278,7 +366,7 @@ fn test(_arg: &i32) {}
}
#[test]
- fn test_add_mutable_reference_to_int() {
+ fn add_mutable_reference_to_int() {
check_fix(
r#"
fn main() {
@@ -296,7 +384,7 @@ fn test(_arg: &mut i32) {}
}
#[test]
- fn test_add_reference_to_array() {
+ fn add_reference_to_array() {
check_fix(
r#"
//- minicore: coerce_unsized
@@ -315,7 +403,7 @@ fn test(_arg: &[i32]) {}
}
#[test]
- fn test_add_reference_with_autoderef() {
+ fn add_reference_with_autoderef() {
check_fix(
r#"
//- minicore: coerce_unsized, deref
@@ -348,7 +436,7 @@ fn test(_arg: &Bar) {}
}
#[test]
- fn test_add_reference_to_method_call() {
+ fn add_reference_to_method_call() {
check_fix(
r#"
fn main() {
@@ -372,7 +460,7 @@ impl Test {
}
#[test]
- fn test_add_reference_to_let_stmt() {
+ fn add_reference_to_let_stmt() {
check_fix(
r#"
fn main() {
@@ -388,7 +476,7 @@ fn main() {
}
#[test]
- fn test_add_reference_to_macro_call() {
+ fn add_reference_to_macro_call() {
check_fix(
r#"
macro_rules! thousand {
@@ -416,7 +504,7 @@ fn main() {
}
#[test]
- fn test_add_mutable_reference_to_let_stmt() {
+ fn add_mutable_reference_to_let_stmt() {
check_fix(
r#"
fn main() {
@@ -432,29 +520,6 @@ fn main() {
}
#[test]
- fn test_wrap_return_type_option() {
- check_fix(
- r#"
-//- minicore: option, result
-fn div(x: i32, y: i32) -> Option<i32> {
- if y == 0 {
- return None;
- }
- x / y$0
-}
-"#,
- r#"
-fn div(x: i32, y: i32) -> Option<i32> {
- if y == 0 {
- return None;
- }
- Some(x / y)
-}
-"#,
- );
- }
-
- #[test]
fn const_generic_type_mismatch() {
check_diagnostics(
r#"
@@ -487,59 +552,82 @@ fn div(x: i32, y: i32) -> Option<i32> {
}
#[test]
- fn test_wrap_return_type_option_tails() {
+ fn wrap_return_type() {
+ check_fix(
+ r#"
+//- minicore: option, result
+fn div(x: i32, y: i32) -> Result<i32, ()> {
+ if y == 0 {
+ return Err(());
+ }
+ x / y$0
+}
+"#,
+ r#"
+fn div(x: i32, y: i32) -> Result<i32, ()> {
+ if y == 0 {
+ return Err(());
+ }
+ Ok(x / y)
+}
+"#,
+ );
+ }
+
+ #[test]
+ fn wrap_return_type_option() {
check_fix(
r#"
//- minicore: option, result
fn div(x: i32, y: i32) -> Option<i32> {
if y == 0 {
- Some(0)
- } else if true {
- 100$0
- } else {
- None
+ return None;
}
+ x / y$0
}
"#,
r#"
fn div(x: i32, y: i32) -> Option<i32> {
if y == 0 {
- Some(0)
- } else if true {
- Some(100)
- } else {
- None
+ return None;
}
+ Some(x / y)
}
"#,
);
}
#[test]
- fn test_wrap_return_type() {
+ fn wrap_return_type_option_tails() {
check_fix(
r#"
//- minicore: option, result
-fn div(x: i32, y: i32) -> Result<i32, ()> {
+fn div(x: i32, y: i32) -> Option<i32> {
if y == 0 {
- return Err(());
+ Some(0)
+ } else if true {
+ 100$0
+ } else {
+ None
}
- x / y$0
}
"#,
r#"
-fn div(x: i32, y: i32) -> Result<i32, ()> {
+fn div(x: i32, y: i32) -> Option<i32> {
if y == 0 {
- return Err(());
+ Some(0)
+ } else if true {
+ Some(100)
+ } else {
+ None
}
- Ok(x / y)
}
"#,
);
}
#[test]
- fn test_wrap_return_type_handles_generic_functions() {
+ fn wrap_return_type_handles_generic_functions() {
check_fix(
r#"
//- minicore: option, result
@@ -562,7 +650,7 @@ fn div<T>(x: T) -> Result<T, i32> {
}
#[test]
- fn test_wrap_return_type_handles_type_aliases() {
+ fn wrap_return_type_handles_type_aliases() {
check_fix(
r#"
//- minicore: option, result
@@ -589,7 +677,7 @@ fn div(x: i32, y: i32) -> MyResult<i32> {
}
#[test]
- fn test_wrapped_unit_as_block_tail_expr() {
+ fn wrapped_unit_as_block_tail_expr() {
check_fix(
r#"
//- minicore: result
@@ -619,7 +707,7 @@ fn foo() -> Result<(), ()> {
}
#[test]
- fn test_wrapped_unit_as_return_expr() {
+ fn wrapped_unit_as_return_expr() {
check_fix(
r#"
//- minicore: result
@@ -642,7 +730,7 @@ fn foo(b: bool) -> Result<(), String> {
}
#[test]
- fn test_in_const_and_static() {
+ fn wrap_in_const_and_static() {
check_fix(
r#"
//- minicore: option, result
@@ -664,7 +752,7 @@ const _: Option<()> = {Some(())};
}
#[test]
- fn test_wrap_return_type_not_applicable_when_expr_type_does_not_match_ok_type() {
+ fn wrap_return_type_not_applicable_when_expr_type_does_not_match_ok_type() {
check_no_fix(
r#"
//- minicore: option, result
@@ -674,7 +762,7 @@ fn foo() -> Result<(), i32> { 0$0 }
}
#[test]
- fn test_wrap_return_type_not_applicable_when_return_type_is_not_result_or_option() {
+ fn wrap_return_type_not_applicable_when_return_type_is_not_result_or_option() {
check_no_fix(
r#"
//- minicore: option, result
@@ -686,6 +774,254 @@ fn foo() -> SomeOtherEnum { 0$0 }
}
#[test]
+ fn unwrap_return_type() {
+ check_fix(
+ r#"
+//- minicore: option, result
+fn div(x: i32, y: i32) -> i32 {
+ if y == 0 {
+ panic!();
+ }
+ Ok(x / y)$0
+}
+"#,
+ r#"
+fn div(x: i32, y: i32) -> i32 {
+ if y == 0 {
+ panic!();
+ }
+ x / y
+}
+"#,
+ );
+ }
+
+ #[test]
+ fn unwrap_return_type_option() {
+ check_fix(
+ r#"
+//- minicore: option, result
+fn div(x: i32, y: i32) -> i32 {
+ if y == 0 {
+ panic!();
+ }
+ Some(x / y)$0
+}
+"#,
+ r#"
+fn div(x: i32, y: i32) -> i32 {
+ if y == 0 {
+ panic!();
+ }
+ x / y
+}
+"#,
+ );
+ }
+
+ #[test]
+ fn unwrap_return_type_option_tails() {
+ check_fix(
+ r#"
+//- minicore: option, result
+fn div(x: i32, y: i32) -> i32 {
+ if y == 0 {
+ 42
+ } else if true {
+ Some(100)$0
+ } else {
+ 0
+ }
+}
+"#,
+ r#"
+fn div(x: i32, y: i32) -> i32 {
+ if y == 0 {
+ 42
+ } else if true {
+ 100
+ } else {
+ 0
+ }
+}
+"#,
+ );
+ }
+
+ #[test]
+ fn unwrap_return_type_option_tail_unit() {
+ check_fix(
+ r#"
+//- minicore: option, result
+fn div(x: i32, y: i32) {
+ if y == 0 {
+ panic!();
+ }
+
+ Ok(())$0
+}
+"#,
+ r#"
+fn div(x: i32, y: i32) {
+ if y == 0 {
+ panic!();
+ }
+}
+"#,
+ );
+ }
+
+ #[test]
+ fn unwrap_return_type_handles_generic_functions() {
+ check_fix(
+ r#"
+//- minicore: option, result
+fn div<T>(x: T) -> T {
+ if x == 0 {
+ panic!();
+ }
+ $0Ok(x)
+}
+"#,
+ r#"
+fn div<T>(x: T) -> T {
+ if x == 0 {
+ panic!();
+ }
+ x
+}
+"#,
+ );
+ }
+
+ #[test]
+ fn unwrap_return_type_handles_type_aliases() {
+ check_fix(
+ r#"
+//- minicore: option, result
+type MyResult<T> = T;
+
+fn div(x: i32, y: i32) -> MyResult<i32> {
+ if y == 0 {
+ panic!();
+ }
+ Ok(x $0/ y)
+}
+"#,
+ r#"
+type MyResult<T> = T;
+
+fn div(x: i32, y: i32) -> MyResult<i32> {
+ if y == 0 {
+ panic!();
+ }
+ x / y
+}
+"#,
+ );
+ }
+
+ #[test]
+ fn unwrap_tail_expr() {
+ check_fix(
+ r#"
+//- minicore: result
+fn foo() -> () {
+ println!("Hello, world!");
+ Ok(())$0
+}
+ "#,
+ r#"
+fn foo() -> () {
+ println!("Hello, world!");
+}
+ "#,
+ );
+ }
+
+ #[test]
+ fn unwrap_to_empty_block() {
+ check_fix(
+ r#"
+//- minicore: result
+fn foo() -> () {
+ Ok(())$0
+}
+ "#,
+ r#"
+fn foo() -> () {}
+ "#,
+ );
+ }
+
+ #[test]
+ fn unwrap_to_return_expr() {
+ check_has_fix(
+ r#"
+//- minicore: result
+fn foo(b: bool) -> () {
+ if b {
+ return $0Ok(());
+ }
+
+ panic!("oh dear");
+}"#,
+ r#"
+fn foo(b: bool) -> () {
+ if b {
+ return;
+ }
+
+ panic!("oh dear");
+}"#,
+ );
+ }
+
+ #[test]
+ fn unwrap_in_const_and_static() {
+ check_fix(
+ r#"
+//- minicore: option, result
+static A: () = {Some(($0))};
+ "#,
+ r#"
+static A: () = {};
+ "#,
+ );
+ check_fix(
+ r#"
+//- minicore: option, result
+const _: () = {Some(($0))};
+ "#,
+ r#"
+const _: () = {};
+ "#,
+ );
+ }
+
+ #[test]
+ fn unwrap_return_type_not_applicable_when_inner_type_does_not_match_return_type() {
+ check_no_fix(
+ r#"
+//- minicore: result
+fn foo() -> i32 { $0Ok(()) }
+"#,
+ );
+ }
+
+ #[test]
+ fn unwrap_return_type_not_applicable_when_wrapper_type_is_not_result_or_option() {
+ check_no_fix(
+ r#"
+//- minicore: option, result
+enum SomeOtherEnum { Ok(i32), Err(String) }
+
+fn foo() -> i32 { SomeOtherEnum::Ok($042) }
+"#,
+ );
+ }
+
+ #[test]
fn remove_semicolon() {
check_fix(r#"fn f() -> i32 { 92$0; }"#, r#"fn f() -> i32 { 92 }"#);
}
diff --git a/crates/syntax/src/ast/syntax_factory/constructors.rs b/crates/syntax/src/ast/syntax_factory/constructors.rs
index aa894ef633..e86c291f76 100644
--- a/crates/syntax/src/ast/syntax_factory/constructors.rs
+++ b/crates/syntax/src/ast/syntax_factory/constructors.rs
@@ -66,28 +66,41 @@ impl SyntaxFactory {
tail_expr: Option<ast::Expr>,
) -> ast::BlockExpr {
let stmts = stmts.into_iter().collect_vec();
- let input = stmts.iter().map(|it| it.syntax().clone()).collect_vec();
+ let mut input = stmts.iter().map(|it| it.syntax().clone()).collect_vec();
let ast = make::block_expr(stmts, tail_expr.clone()).clone_for_update();
- if let Some((mut mapping, stmt_list)) = self.mappings().zip(ast.stmt_list()) {
+ if let Some(mut mapping) = self.mappings() {
+ let stmt_list = ast.stmt_list().unwrap();
let mut builder = SyntaxMappingBuilder::new(stmt_list.syntax().clone());
+ if let Some(input) = tail_expr {
+ builder.map_node(
+ input.syntax().clone(),
+ stmt_list.tail_expr().unwrap().syntax().clone(),
+ );
+ } else if let Some(ast_tail) = stmt_list.tail_expr() {
+ // The parser interpreted the last statement (probably a statement with a block) as an Expr
+ let last_stmt = input.pop().unwrap();
+
+ builder.map_node(last_stmt, ast_tail.syntax().clone());
+ }
+
builder.map_children(
input.into_iter(),
stmt_list.statements().map(|it| it.syntax().clone()),
);
- if let Some((input, output)) = tail_expr.zip(stmt_list.tail_expr()) {
- builder.map_node(input.syntax().clone(), output.syntax().clone());
- }
-
builder.finish(&mut mapping);
}
ast
}
+ pub fn expr_empty_block(&self) -> ast::BlockExpr {
+ ast::BlockExpr { syntax: make::expr_empty_block().syntax().clone_for_update() }
+ }
+
pub fn expr_bin(&self, lhs: ast::Expr, op: ast::BinaryOp, rhs: ast::Expr) -> ast::BinExpr {
let ast::Expr::BinExpr(ast) =
make::expr_bin_op(lhs.clone(), op, rhs.clone()).clone_for_update()
@@ -134,6 +147,22 @@ impl SyntaxFactory {
ast.into()
}
+ pub fn expr_return(&self, expr: Option<ast::Expr>) -> ast::ReturnExpr {
+ let ast::Expr::ReturnExpr(ast) = make::expr_return(expr.clone()).clone_for_update() else {
+ unreachable!()
+ };
+
+ if let Some(mut mapping) = self.mappings() {
+ let mut builder = SyntaxMappingBuilder::new(ast.syntax().clone());
+ if let Some(input) = expr {
+ builder.map_node(input.syntax().clone(), ast.expr().unwrap().syntax().clone());
+ }
+ builder.finish(&mut mapping);
+ }
+
+ ast
+ }
+
pub fn let_stmt(
&self,
pattern: ast::Pat,