Unnamed repository; edit this file 'description' to name the repository.
In "Wrap return type" assist, don't wrap exit points if they already have the right type
Chayim Refael Friedman 10 months ago
parent 0100bc7 · commit 78427be
-rw-r--r--crates/hir/src/lib.rs4
-rw-r--r--crates/ide-assists/src/handlers/wrap_return_type.rs163
2 files changed, 133 insertions, 34 deletions
diff --git a/crates/hir/src/lib.rs b/crates/hir/src/lib.rs
index 3b39707cf6..46d2e88160 100644
--- a/crates/hir/src/lib.rs
+++ b/crates/hir/src/lib.rs
@@ -1727,10 +1727,10 @@ impl Adt {
pub fn ty_with_args<'db>(
self,
db: &'db dyn HirDatabase,
- args: impl Iterator<Item = Type<'db>>,
+ args: impl IntoIterator<Item = Type<'db>>,
) -> Type<'db> {
let id = AdtId::from(self);
- let mut it = args.map(|t| t.ty);
+ let mut it = args.into_iter().map(|t| t.ty);
let ty = TyBuilder::def_ty(db, id.into(), None)
.fill(|x| {
let r = it.next().unwrap_or_else(|| TyKind::Error.intern(Interner));
diff --git a/crates/ide-assists/src/handlers/wrap_return_type.rs b/crates/ide-assists/src/handlers/wrap_return_type.rs
index 9ea78719b2..d7189aa5db 100644
--- a/crates/ide-assists/src/handlers/wrap_return_type.rs
+++ b/crates/ide-assists/src/handlers/wrap_return_type.rs
@@ -56,7 +56,8 @@ pub(crate) fn wrap_return_type(acc: &mut Assists, ctx: &AssistContext<'_>) -> Op
};
let type_ref = &ret_type.ty()?;
- let ty = ctx.sema.resolve_type(type_ref)?.as_adt();
+ let ty = ctx.sema.resolve_type(type_ref)?;
+ let ty_adt = ty.as_adt();
let famous_defs = FamousDefs(&ctx.sema, ctx.sema.scope(type_ref.syntax())?.krate());
for kind in WrapperKind::ALL {
@@ -64,7 +65,7 @@ pub(crate) fn wrap_return_type(acc: &mut Assists, ctx: &AssistContext<'_>) -> Op
continue;
};
- if matches!(ty, Some(hir::Adt::Enum(ret_type)) if ret_type == core_wrapper) {
+ if matches!(ty_adt, Some(hir::Adt::Enum(ret_type)) if ret_type == core_wrapper) {
// The return type is already wrapped
cov_mark::hit!(wrap_return_type_simple_return_type_already_wrapped);
continue;
@@ -78,10 +79,23 @@ pub(crate) fn wrap_return_type(acc: &mut Assists, ctx: &AssistContext<'_>) -> Op
|builder| {
let mut editor = builder.make_editor(&parent);
let make = SyntaxFactory::with_mappings();
- let alias = wrapper_alias(ctx, &make, &core_wrapper, type_ref, kind.symbol());
- let new_return_ty = alias.unwrap_or_else(|| match kind {
- WrapperKind::Option => make.ty_option(type_ref.clone()),
- WrapperKind::Result => make.ty_result(type_ref.clone(), make.ty_infer().into()),
+ let alias = wrapper_alias(ctx, &make, core_wrapper, type_ref, &ty, kind.symbol());
+ let (ast_new_return_ty, semantic_new_return_ty) = alias.unwrap_or_else(|| {
+ let (ast_ty, ty_constructor) = match kind {
+ WrapperKind::Option => {
+ (make.ty_option(type_ref.clone()), famous_defs.core_option_Option())
+ }
+ WrapperKind::Result => (
+ make.ty_result(type_ref.clone(), make.ty_infer().into()),
+ famous_defs.core_result_Result(),
+ ),
+ };
+ let semantic_ty = ty_constructor
+ .map(|ty_constructor| {
+ hir::Adt::from(ty_constructor).ty_with_args(ctx.db(), [ty.clone()])
+ })
+ .unwrap_or_else(|| ty.clone());
+ (ast_ty, semantic_ty)
});
let mut exprs_to_wrap = Vec::new();
@@ -96,6 +110,17 @@ pub(crate) fn wrap_return_type(acc: &mut Assists, ctx: &AssistContext<'_>) -> Op
for_each_tail_expr(&body_expr, tail_cb);
for ret_expr_arg in exprs_to_wrap {
+ if let Some(ty) = ctx.sema.type_of_expr(&ret_expr_arg) {
+ if ty.adjusted().could_unify_with(ctx.db(), &semantic_new_return_ty) {
+ // The type is already correct, don't wrap it.
+ // We deliberately don't use `could_unify_with_deeply()`, because as long as the outer
+ // enum matches it's okay for us, as we don't trigger the assist if the return type
+ // is already `Option`/`Result`, so mismatched exact type is more likely a mistake
+ // than something intended.
+ continue;
+ }
+ }
+
let happy_wrapped = make.expr_call(
make.expr_path(make.ident_path(kind.happy_ident())),
make.arg_list(iter::once(ret_expr_arg.clone())),
@@ -103,12 +128,12 @@ pub(crate) fn wrap_return_type(acc: &mut Assists, ctx: &AssistContext<'_>) -> Op
editor.replace(ret_expr_arg.syntax(), happy_wrapped.syntax());
}
- editor.replace(type_ref.syntax(), new_return_ty.syntax());
+ editor.replace(type_ref.syntax(), ast_new_return_ty.syntax());
if let WrapperKind::Result = kind {
// Add a placeholder snippet at the first generic argument that doesn't equal the return type.
// This is normally the error type, but that may not be the case when we inserted a type alias.
- let args = new_return_ty
+ let args = ast_new_return_ty
.path()
.unwrap()
.segment()
@@ -188,27 +213,28 @@ impl WrapperKind {
}
// Try to find an wrapper type alias in the current scope (shadowing the default).
-fn wrapper_alias(
- ctx: &AssistContext<'_>,
+fn wrapper_alias<'db>(
+ ctx: &AssistContext<'db>,
make: &SyntaxFactory,
- core_wrapper: &hir::Enum,
- ret_type: &ast::Type,
+ core_wrapper: hir::Enum,
+ ast_ret_type: &ast::Type,
+ semantic_ret_type: &hir::Type<'db>,
wrapper: hir::Symbol,
-) -> Option<ast::PathType> {
+) -> Option<(ast::PathType, hir::Type<'db>)> {
let wrapper_path = hir::ModPath::from_segments(
hir::PathKind::Plain,
iter::once(hir::Name::new_symbol_root(wrapper)),
);
- ctx.sema.resolve_mod_path(ret_type.syntax(), &wrapper_path).and_then(|def| {
+ ctx.sema.resolve_mod_path(ast_ret_type.syntax(), &wrapper_path).and_then(|def| {
def.filter_map(|def| match def.into_module_def() {
hir::ModuleDef::TypeAlias(alias) => {
let enum_ty = alias.ty(ctx.db()).as_adt()?.as_enum()?;
- (&enum_ty == core_wrapper).then_some(alias)
+ (enum_ty == core_wrapper).then_some((alias, enum_ty))
}
_ => None,
})
- .find_map(|alias| {
+ .find_map(|(alias, enum_ty)| {
let mut inserted_ret_type = false;
let generic_args =
alias.source(ctx.db())?.value.generic_param_list()?.generic_params().map(|param| {
@@ -216,7 +242,7 @@ fn wrapper_alias(
// Replace the very first type parameter with the function's return type.
ast::GenericParam::TypeParam(_) if !inserted_ret_type => {
inserted_ret_type = true;
- make.type_arg(ret_type.clone()).into()
+ make.type_arg(ast_ret_type.clone()).into()
}
ast::GenericParam::LifetimeParam(_) => {
make.lifetime_arg(make.lifetime("'_")).into()
@@ -231,7 +257,10 @@ fn wrapper_alias(
make.path_segment_generics(make.name_ref(name.as_str()), generic_arg_list),
);
- Some(make.ty_path(path))
+ let new_ty =
+ hir::Adt::from(enum_ty).ty_with_args(ctx.db(), [semantic_ret_type.clone()]);
+
+ Some((make.ty_path(path), new_ty))
})
})
}
@@ -605,29 +634,39 @@ fn foo() -> Option<i32> {
check_assist_by_label(
wrap_return_type,
r#"
-//- minicore: option
+//- minicore: option, future
+struct F(i32);
+impl core::future::Future for F {
+ type Output = i32;
+ fn poll(self: core::pin::Pin<&mut Self>, cx: &mut core::task::Context<'_>) -> core::task::Poll<Self::Output> { 0 }
+}
async fn foo() -> i$032 {
if true {
if false {
- 1.await
+ F(1).await
} else {
- 2.await
+ F(2).await
}
} else {
- 24i32.await
+ F(24i32).await
}
}
"#,
r#"
+struct F(i32);
+impl core::future::Future for F {
+ type Output = i32;
+ fn poll(self: core::pin::Pin<&mut Self>, cx: &mut core::task::Context<'_>) -> core::task::Poll<Self::Output> { 0 }
+}
async fn foo() -> Option<i32> {
if true {
if false {
- Some(1.await)
+ Some(F(1).await)
} else {
- Some(2.await)
+ Some(F(2).await)
}
} else {
- Some(24i32.await)
+ Some(F(24i32).await)
}
}
"#,
@@ -1666,29 +1705,39 @@ fn foo() -> Result<i32, ${0:_}> {
check_assist_by_label(
wrap_return_type,
r#"
-//- minicore: result
+//- minicore: result, future
+struct F(i32);
+impl core::future::Future for F {
+ type Output = i32;
+ fn poll(self: core::pin::Pin<&mut Self>, cx: &mut core::task::Context<'_>) -> core::task::Poll<Self::Output> { 0 }
+}
async fn foo() -> i$032 {
if true {
if false {
- 1.await
+ F(1).await
} else {
- 2.await
+ F(2).await
}
} else {
- 24i32.await
+ F(24i32).await
}
}
"#,
r#"
+struct F(i32);
+impl core::future::Future for F {
+ type Output = i32;
+ fn poll(self: core::pin::Pin<&mut Self>, cx: &mut core::task::Context<'_>) -> core::task::Poll<Self::Output> { 0 }
+}
async fn foo() -> Result<i32, ${0:_}> {
if true {
if false {
- Ok(1.await)
+ Ok(F(1).await)
} else {
- Ok(2.await)
+ Ok(F(2).await)
}
} else {
- Ok(24i32.await)
+ Ok(F(24i32).await)
}
}
"#,
@@ -2460,4 +2509,54 @@ fn foo() -> Result<i32, ${0:_}> {
WrapperKind::Result.label(),
);
}
+
+ #[test]
+ fn already_wrapped() {
+ check_assist_by_label(
+ wrap_return_type,
+ r#"
+//- minicore: option
+fn foo() -> i32$0 {
+ if false {
+ 0
+ } else {
+ Some(1)
+ }
+}
+ "#,
+ r#"
+fn foo() -> Option<i32> {
+ if false {
+ Some(0)
+ } else {
+ Some(1)
+ }
+}
+ "#,
+ WrapperKind::Option.label(),
+ );
+ check_assist_by_label(
+ wrap_return_type,
+ r#"
+//- minicore: result
+fn foo() -> i32$0 {
+ if false {
+ 0
+ } else {
+ Ok(1)
+ }
+}
+ "#,
+ r#"
+fn foo() -> Result<i32, ${0:_}> {
+ if false {
+ Ok(0)
+ } else {
+ Ok(1)
+ }
+}
+ "#,
+ WrapperKind::Result.label(),
+ );
+ }
}