Unnamed repository; edit this file 'description' to name the repository.
Diffstat (limited to 'crates/ide-assists/src/handlers/extract_function.rs')
-rw-r--r--crates/ide-assists/src/handlers/extract_function.rs95
1 files changed, 88 insertions, 7 deletions
diff --git a/crates/ide-assists/src/handlers/extract_function.rs b/crates/ide-assists/src/handlers/extract_function.rs
index 231df9b5b3..f2363c6f7b 100644
--- a/crates/ide-assists/src/handlers/extract_function.rs
+++ b/crates/ide-assists/src/handlers/extract_function.rs
@@ -25,7 +25,7 @@ use syntax::{
SyntaxKind::{self, COMMENT},
SyntaxNode, SyntaxToken, T, TextRange, TextSize, TokenAtOffset, WalkEvent,
ast::{
- self, AstNode, AstToken, HasGenericParams, HasName, edit::IndentLevel,
+ self, AstNode, AstToken, HasAttrs, HasGenericParams, HasName, edit::IndentLevel,
edit_in_place::Indent,
},
match_ast, ted,
@@ -120,7 +120,7 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext<'_>) -> Op
let params = body.extracted_function_params(ctx, &container_info, locals_used);
- let name = make_function_name(&semantics_scope);
+ let name = make_function_name(&semantics_scope, &body);
let fun = Function {
name,
@@ -241,7 +241,10 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext<'_>) -> Op
)
}
-fn make_function_name(semantics_scope: &hir::SemanticsScope<'_>) -> ast::NameRef {
+fn make_function_name(
+ semantics_scope: &hir::SemanticsScope<'_>,
+ body: &FunctionBody,
+) -> ast::NameRef {
let mut names_in_scope = vec![];
semantics_scope.process_all_names(&mut |name, _| {
names_in_scope.push(
@@ -252,7 +255,10 @@ fn make_function_name(semantics_scope: &hir::SemanticsScope<'_>) -> ast::NameRef
let default_name = "fun_name";
- let mut name = default_name.to_owned();
+ let mut name = body
+ .suggest_name()
+ .filter(|name| name.len() > 2)
+ .unwrap_or_else(|| default_name.to_owned());
let mut counter = 0;
while names_in_scope.contains(&name) {
counter += 1;
@@ -375,6 +381,7 @@ struct ContainerInfo<'db> {
ret_type: Option<hir::Type<'db>>,
generic_param_lists: Vec<ast::GenericParamList>,
where_clauses: Vec<ast::WhereClause>,
+ attrs: Vec<ast::Attr>,
edition: Edition,
}
@@ -778,6 +785,16 @@ impl FunctionBody {
fn contains_node(&self, node: &SyntaxNode) -> bool {
self.contains_range(node.text_range())
}
+
+ fn suggest_name(&self) -> Option<String> {
+ if let Some(ast::Pat::IdentPat(pat)) = self.parent().and_then(ast::LetStmt::cast)?.pat()
+ && let Some(name) = pat.name().and_then(|it| it.ident_token())
+ {
+ Some(name.text().to_owned())
+ } else {
+ None
+ }
+ }
}
impl FunctionBody {
@@ -911,6 +928,7 @@ impl FunctionBody {
let parents = generic_parents(&parent);
let generic_param_lists = parents.iter().filter_map(|it| it.generic_param_list()).collect();
let where_clauses = parents.iter().filter_map(|it| it.where_clause()).collect();
+ let attrs = parents.iter().flat_map(|it| it.attrs()).filter(is_inherit_attr).collect();
Some((
ContainerInfo {
@@ -919,6 +937,7 @@ impl FunctionBody {
ret_type: ty,
generic_param_lists,
where_clauses,
+ attrs,
edition,
},
contains_tail_expr,
@@ -1103,6 +1122,14 @@ impl GenericParent {
GenericParent::Trait(trait_) => trait_.where_clause(),
}
}
+
+ fn attrs(&self) -> impl Iterator<Item = ast::Attr> {
+ match self {
+ GenericParent::Fn(fn_) => fn_.attrs(),
+ GenericParent::Impl(impl_) => impl_.attrs(),
+ GenericParent::Trait(trait_) => trait_.attrs(),
+ }
+ }
}
/// Search `parent`'s ancestors for items with potentially applicable generic parameters
@@ -1578,7 +1605,7 @@ fn format_function(
let (generic_params, where_clause) = make_generic_params_and_where_clause(ctx, fun);
make::fn_(
- None,
+ fun.mods.attrs.clone(),
None,
fun_name,
generic_params,
@@ -1958,6 +1985,11 @@ fn format_type(ty: &hir::Type<'_>, ctx: &AssistContext<'_>, module: hir::Module)
ty.display_source_code(ctx.db(), module.into(), true).ok().unwrap_or_else(|| "_".to_owned())
}
+fn is_inherit_attr(attr: &ast::Attr) -> bool {
+ let Some(name) = attr.simple_name() else { return false };
+ matches!(name.as_str(), "track_caller" | "cfg")
+}
+
fn make_ty(ty: &hir::Type<'_>, ctx: &AssistContext<'_>, module: hir::Module) -> ast::Type {
let ty_str = format_type(ty, ctx, module);
make::ty(&ty_str)
@@ -5414,12 +5446,12 @@ impl Struct {
impl Trait for Struct {
fn bar(&self) -> i32 {
- let three_squared = fun_name();
+ let three_squared = three_squared();
self.0 + three_squared
}
}
-fn $0fun_name() -> i32 {
+fn $0three_squared() -> i32 {
3 * 3
}
"#,
@@ -6375,4 +6407,53 @@ fn $0fun_name(mut a: i32, mut b: i32) {
"#,
);
}
+
+ #[test]
+ fn with_cfg_attr() {
+ check_assist(
+ extract_function,
+ r#"
+//- /main.rs crate:main cfg:test
+#[cfg(test)]
+fn foo() {
+ foo($01 + 1$0);
+}
+"#,
+ r#"
+#[cfg(test)]
+fn foo() {
+ foo(fun_name());
+}
+
+#[cfg(test)]
+fn $0fun_name() -> i32 {
+ 1 + 1
+}
+"#,
+ );
+ }
+
+ #[test]
+ fn with_track_caller() {
+ check_assist(
+ extract_function,
+ r#"
+#[track_caller]
+fn foo() {
+ foo($01 + 1$0);
+}
+"#,
+ r#"
+#[track_caller]
+fn foo() {
+ foo(fun_name());
+}
+
+#[track_caller]
+fn $0fun_name() -> i32 {
+ 1 + 1
+}
+"#,
+ );
+ }
}