Unnamed repository; edit this file 'description' to name the repository.
-rw-r--r--crates/hir/src/semantics.rs17
-rw-r--r--crates/ide-assists/src/handlers/wrap_return_type_in_result.rs293
2 files changed, 294 insertions, 16 deletions
diff --git a/crates/hir/src/semantics.rs b/crates/hir/src/semantics.rs
index 763f53031e..c78b59826c 100644
--- a/crates/hir/src/semantics.rs
+++ b/crates/hir/src/semantics.rs
@@ -14,6 +14,7 @@ use hir_def::{
hir::Expr,
lower::LowerCtx,
nameres::MacroSubNs,
+ path::ModPath,
resolver::{self, HasResolver, Resolver, TypeNs},
type_ref::Mutability,
AsMacroCall, DefWithBodyId, FunctionId, MacroId, TraitId, VariantId,
@@ -46,9 +47,9 @@ use crate::{
source_analyzer::{resolve_hir_path, SourceAnalyzer},
Access, Adjust, Adjustment, Adt, AutoBorrow, BindingMode, BuiltinAttr, Callable, Const,
ConstParam, Crate, DeriveHelper, Enum, Field, Function, HasSource, HirFileId, Impl, InFile,
- Label, LifetimeParam, Local, Macro, Module, ModuleDef, Name, OverloadedDeref, Path, ScopeDef,
- Static, Struct, ToolModule, Trait, TraitAlias, TupleField, Type, TypeAlias, TypeParam, Union,
- Variant, VariantDef,
+ ItemInNs, Label, LifetimeParam, Local, Macro, Module, ModuleDef, Name, OverloadedDeref, Path,
+ ScopeDef, Static, Struct, ToolModule, Trait, TraitAlias, TupleField, Type, TypeAlias,
+ TypeParam, Union, Variant, VariantDef,
};
const CONTINUE_NO_BREAKS: ControlFlow<Infallible, ()> = ControlFlow::Continue(());
@@ -1384,6 +1385,16 @@ impl<'db> SemanticsImpl<'db> {
self.analyze(path.syntax())?.resolve_path(self.db, path)
}
+ pub fn resolve_mod_path(
+ &self,
+ scope: &SyntaxNode,
+ path: &ModPath,
+ ) -> Option<impl Iterator<Item = ItemInNs>> {
+ let analyze = self.analyze(scope)?;
+ let items = analyze.resolver.resolve_module_path_in_items(self.db.upcast(), path);
+ Some(items.iter_items().map(|(item, _)| item.into()))
+ }
+
fn resolve_variant(&self, record_lit: ast::RecordExpr) -> Option<VariantId> {
self.analyze(record_lit.syntax())?.resolve_variant(self.db, record_lit)
}
diff --git a/crates/ide-assists/src/handlers/wrap_return_type_in_result.rs b/crates/ide-assists/src/handlers/wrap_return_type_in_result.rs
index b68ed00f77..8f0e9b4fe0 100644
--- a/crates/ide-assists/src/handlers/wrap_return_type_in_result.rs
+++ b/crates/ide-assists/src/handlers/wrap_return_type_in_result.rs
@@ -1,12 +1,14 @@
use std::iter;
+use hir::HasSource;
use ide_db::{
famous_defs::FamousDefs,
syntax_helpers::node_ext::{for_each_tail_expr, walk_expr},
};
+use itertools::Itertools;
use syntax::{
- ast::{self, make, Expr},
- match_ast, ted, AstNode,
+ ast::{self, make, Expr, HasGenericParams},
+ match_ast, ted, AstNode, ToSmolStr,
};
use crate::{AssistContext, AssistId, AssistKind, Assists};
@@ -39,25 +41,22 @@ pub(crate) fn wrap_return_type_in_result(acc: &mut Assists, ctx: &AssistContext<
};
let type_ref = &ret_type.ty()?;
- let ty = ctx.sema.resolve_type(type_ref)?.as_adt();
- let result_enum =
+ let core_result =
FamousDefs(&ctx.sema, ctx.sema.scope(type_ref.syntax())?.krate()).core_result_Result()?;
- if matches!(ty, Some(hir::Adt::Enum(ret_type)) if ret_type == result_enum) {
+ let ty = ctx.sema.resolve_type(type_ref)?.as_adt();
+ if matches!(ty, Some(hir::Adt::Enum(ret_type)) if ret_type == core_result) {
+ // The return type is already wrapped in a Result
cov_mark::hit!(wrap_return_type_in_result_simple_return_type_already_result);
return None;
}
- let new_result_ty =
- make::ext::ty_result(type_ref.clone(), make::ty_placeholder()).clone_for_update();
- let generic_args = new_result_ty.syntax().descendants().find_map(ast::GenericArgList::cast)?;
- let last_genarg = generic_args.generic_args().last()?;
-
acc.add(
AssistId("wrap_return_type_in_result", AssistKind::RefactorRewrite),
"Wrap return type in Result",
type_ref.syntax().text_range(),
|edit| {
+ let new_result_ty = result_type(ctx, &core_result, type_ref).clone_for_update();
let body = edit.make_mut(ast::Expr::BlockExpr(body));
let mut exprs_to_wrap = Vec::new();
@@ -81,16 +80,72 @@ pub(crate) fn wrap_return_type_in_result(acc: &mut Assists, ctx: &AssistContext<
}
let old_result_ty = edit.make_mut(type_ref.clone());
-
ted::replace(old_result_ty.syntax(), new_result_ty.syntax());
- if let Some(cap) = ctx.config.snippet_cap {
- edit.add_placeholder_snippet(cap, last_genarg);
+ // Add a placeholder snippet at the first generic argument that doesn't equal the return type.
+ // This is normally the error type, but that may not be the case when we inserted a type alias.
+ let args = new_result_ty.syntax().descendants().find_map(ast::GenericArgList::cast);
+ let error_type_arg = args.and_then(|list| {
+ list.generic_args().find(|arg| match arg {
+ ast::GenericArg::TypeArg(_) => arg.syntax().text() != type_ref.syntax().text(),
+ ast::GenericArg::LifetimeArg(_) => false,
+ _ => true,
+ })
+ });
+ if let Some(error_type_arg) = error_type_arg {
+ if let Some(cap) = ctx.config.snippet_cap {
+ edit.add_placeholder_snippet(cap, error_type_arg);
+ }
}
},
)
}
+fn result_type(
+ ctx: &AssistContext<'_>,
+ core_result: &hir::Enum,
+ ret_type: &ast::Type,
+) -> ast::Type {
+ // Try to find a Result<T, ...> type alias in the current scope (shadowing the default).
+ let result_path = hir::ModPath::from_segments(
+ hir::PathKind::Plain,
+ iter::once(hir::Name::new_symbol_root(hir::sym::Result.clone())),
+ );
+ let alias = ctx.sema.resolve_mod_path(ret_type.syntax(), &result_path).and_then(|def| {
+ def.filter_map(|def| match def.as_module_def()? {
+ hir::ModuleDef::TypeAlias(alias) => {
+ let enum_ty = alias.ty(ctx.db()).as_adt()?.as_enum()?;
+ (&enum_ty == core_result).then_some(alias)
+ }
+ _ => None,
+ })
+ .find_map(|alias| {
+ let mut inserted_ret_type = false;
+ let generic_params = alias
+ .source(ctx.db())?
+ .value
+ .generic_param_list()?
+ .generic_params()
+ .map(|param| match param {
+ // Replace the very first type parameter with the functions return type.
+ ast::GenericParam::TypeParam(_) if !inserted_ret_type => {
+ inserted_ret_type = true;
+ ret_type.to_smolstr()
+ }
+ ast::GenericParam::LifetimeParam(_) => make::lifetime("'_").to_smolstr(),
+ _ => make::ty_placeholder().to_smolstr(),
+ })
+ .join(", ");
+
+ let name = alias.name(ctx.db());
+ let name = name.as_str();
+ Some(make::ty(&format!("{name}<{generic_params}>")))
+ })
+ });
+ // If there is no applicable alias in scope use the default Result type.
+ alias.unwrap_or_else(|| make::ext::ty_result(ret_type.clone(), make::ty_placeholder()))
+}
+
fn tail_cb_impl(acc: &mut Vec<ast::Expr>, e: &ast::Expr) {
match e {
Expr::BreakExpr(break_expr) => {
@@ -998,4 +1053,216 @@ fn foo(the_field: u32) -> Result<u32, ${0:_}> {
"#,
);
}
+
+ #[test]
+ fn wrap_return_type_in_local_result_type() {
+ check_assist(
+ wrap_return_type_in_result,
+ r#"
+//- minicore: result
+type Result<T> = core::result::Result<T, ()>;
+
+fn foo() -> i3$02 {
+ return 42i32;
+}
+"#,
+ r#"
+type Result<T> = core::result::Result<T, ()>;
+
+fn foo() -> Result<i32> {
+ return Ok(42i32);
+}
+"#,
+ );
+
+ check_assist(
+ wrap_return_type_in_result,
+ r#"
+//- minicore: result
+type Result2<T> = core::result::Result<T, ()>;
+
+fn foo() -> i3$02 {
+ return 42i32;
+}
+"#,
+ r#"
+type Result2<T> = core::result::Result<T, ()>;
+
+fn foo() -> Result<i32, ${0:_}> {
+ return Ok(42i32);
+}
+"#,
+ );
+ }
+
+ #[test]
+ fn wrap_return_type_in_imported_local_result_type() {
+ check_assist(
+ wrap_return_type_in_result,
+ r#"
+//- minicore: result
+mod some_module {
+ pub type Result<T> = core::result::Result<T, ()>;
+}
+
+use some_module::Result;
+
+fn foo() -> i3$02 {
+ return 42i32;
+}
+"#,
+ r#"
+mod some_module {
+ pub type Result<T> = core::result::Result<T, ()>;
+}
+
+use some_module::Result;
+
+fn foo() -> Result<i32> {
+ return Ok(42i32);
+}
+"#,
+ );
+
+ check_assist(
+ wrap_return_type_in_result,
+ r#"
+//- minicore: result
+mod some_module {
+ pub type Result<T> = core::result::Result<T, ()>;
+}
+
+use some_module::*;
+
+fn foo() -> i3$02 {
+ return 42i32;
+}
+"#,
+ r#"
+mod some_module {
+ pub type Result<T> = core::result::Result<T, ()>;
+}
+
+use some_module::*;
+
+fn foo() -> Result<i32> {
+ return Ok(42i32);
+}
+"#,
+ );
+ }
+
+ #[test]
+ fn wrap_return_type_in_local_result_type_from_function_body() {
+ check_assist(
+ wrap_return_type_in_result,
+ r#"
+//- minicore: result
+fn foo() -> i3$02 {
+ type Result<T> = core::result::Result<T, ()>;
+ 0
+}
+"#,
+ r#"
+fn foo() -> Result<i32, ${0:_}> {
+ type Result<T> = core::result::Result<T, ()>;
+ Ok(0)
+}
+"#,
+ );
+ }
+
+ #[test]
+ fn wrap_return_type_in_local_result_type_already_using_alias() {
+ check_assist_not_applicable(
+ wrap_return_type_in_result,
+ r#"
+//- minicore: result
+pub type Result<T> = core::result::Result<T, ()>;
+
+fn foo() -> Result<i3$02> {
+ return Ok(42i32);
+}
+"#,
+ );
+ }
+
+ #[test]
+ fn wrap_return_type_in_local_result_type_multiple_generics() {
+ check_assist(
+ wrap_return_type_in_result,
+ r#"
+//- minicore: result
+type Result<T, E> = core::result::Result<T, E>;
+
+fn foo() -> i3$02 {
+ 0
+}
+"#,
+ r#"
+type Result<T, E> = core::result::Result<T, E>;
+
+fn foo() -> Result<i32, ${0:_}> {
+ Ok(0)
+}
+"#,
+ );
+
+ check_assist(
+ wrap_return_type_in_result,
+ r#"
+//- minicore: result
+type Result<T, E> = core::result::Result<Foo<T, E>, ()>;
+
+fn foo() -> i3$02 {
+ 0
+}
+ "#,
+ r#"
+type Result<T, E> = core::result::Result<Foo<T, E>, ()>;
+
+fn foo() -> Result<i32, ${0:_}> {
+ Ok(0)
+}
+ "#,
+ );
+
+ check_assist(
+ wrap_return_type_in_result,
+ r#"
+//- minicore: result
+type Result<'a, T, E> = core::result::Result<Foo<T, E>, &'a ()>;
+
+fn foo() -> i3$02 {
+ 0
+}
+ "#,
+ r#"
+type Result<'a, T, E> = core::result::Result<Foo<T, E>, &'a ()>;
+
+fn foo() -> Result<'_, i32, ${0:_}> {
+ Ok(0)
+}
+ "#,
+ );
+
+ check_assist(
+ wrap_return_type_in_result,
+ r#"
+//- minicore: result
+type Result<T, const N: usize> = core::result::Result<Foo<T>, Bar<N>>;
+
+fn foo() -> i3$02 {
+ 0
+}
+ "#,
+ r#"
+type Result<T, const N: usize> = core::result::Result<Foo<T>, Bar<N>>;
+
+fn foo() -> Result<i32, ${0:_}> {
+ Ok(0)
+}
+ "#,
+ );
+ }
}