Unnamed repository; edit this file 'description' to name the repository.
Diffstat (limited to 'crates/ide-assists/src/handlers/extract_function.rs')
| -rw-r--r-- | crates/ide-assists/src/handlers/extract_function.rs | 95 |
1 files changed, 88 insertions, 7 deletions
diff --git a/crates/ide-assists/src/handlers/extract_function.rs b/crates/ide-assists/src/handlers/extract_function.rs index 231df9b5b3..f2363c6f7b 100644 --- a/crates/ide-assists/src/handlers/extract_function.rs +++ b/crates/ide-assists/src/handlers/extract_function.rs @@ -25,7 +25,7 @@ use syntax::{ SyntaxKind::{self, COMMENT}, SyntaxNode, SyntaxToken, T, TextRange, TextSize, TokenAtOffset, WalkEvent, ast::{ - self, AstNode, AstToken, HasGenericParams, HasName, edit::IndentLevel, + self, AstNode, AstToken, HasAttrs, HasGenericParams, HasName, edit::IndentLevel, edit_in_place::Indent, }, match_ast, ted, @@ -120,7 +120,7 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext<'_>) -> Op let params = body.extracted_function_params(ctx, &container_info, locals_used); - let name = make_function_name(&semantics_scope); + let name = make_function_name(&semantics_scope, &body); let fun = Function { name, @@ -241,7 +241,10 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext<'_>) -> Op ) } -fn make_function_name(semantics_scope: &hir::SemanticsScope<'_>) -> ast::NameRef { +fn make_function_name( + semantics_scope: &hir::SemanticsScope<'_>, + body: &FunctionBody, +) -> ast::NameRef { let mut names_in_scope = vec![]; semantics_scope.process_all_names(&mut |name, _| { names_in_scope.push( @@ -252,7 +255,10 @@ fn make_function_name(semantics_scope: &hir::SemanticsScope<'_>) -> ast::NameRef let default_name = "fun_name"; - let mut name = default_name.to_owned(); + let mut name = body + .suggest_name() + .filter(|name| name.len() > 2) + .unwrap_or_else(|| default_name.to_owned()); let mut counter = 0; while names_in_scope.contains(&name) { counter += 1; @@ -375,6 +381,7 @@ struct ContainerInfo<'db> { ret_type: Option<hir::Type<'db>>, generic_param_lists: Vec<ast::GenericParamList>, where_clauses: Vec<ast::WhereClause>, + attrs: Vec<ast::Attr>, edition: Edition, } @@ -778,6 +785,16 @@ impl FunctionBody { fn contains_node(&self, node: &SyntaxNode) -> bool { self.contains_range(node.text_range()) } + + fn suggest_name(&self) -> Option<String> { + if let Some(ast::Pat::IdentPat(pat)) = self.parent().and_then(ast::LetStmt::cast)?.pat() + && let Some(name) = pat.name().and_then(|it| it.ident_token()) + { + Some(name.text().to_owned()) + } else { + None + } + } } impl FunctionBody { @@ -911,6 +928,7 @@ impl FunctionBody { let parents = generic_parents(&parent); let generic_param_lists = parents.iter().filter_map(|it| it.generic_param_list()).collect(); let where_clauses = parents.iter().filter_map(|it| it.where_clause()).collect(); + let attrs = parents.iter().flat_map(|it| it.attrs()).filter(is_inherit_attr).collect(); Some(( ContainerInfo { @@ -919,6 +937,7 @@ impl FunctionBody { ret_type: ty, generic_param_lists, where_clauses, + attrs, edition, }, contains_tail_expr, @@ -1103,6 +1122,14 @@ impl GenericParent { GenericParent::Trait(trait_) => trait_.where_clause(), } } + + fn attrs(&self) -> impl Iterator<Item = ast::Attr> { + match self { + GenericParent::Fn(fn_) => fn_.attrs(), + GenericParent::Impl(impl_) => impl_.attrs(), + GenericParent::Trait(trait_) => trait_.attrs(), + } + } } /// Search `parent`'s ancestors for items with potentially applicable generic parameters @@ -1578,7 +1605,7 @@ fn format_function( let (generic_params, where_clause) = make_generic_params_and_where_clause(ctx, fun); make::fn_( - None, + fun.mods.attrs.clone(), None, fun_name, generic_params, @@ -1958,6 +1985,11 @@ fn format_type(ty: &hir::Type<'_>, ctx: &AssistContext<'_>, module: hir::Module) ty.display_source_code(ctx.db(), module.into(), true).ok().unwrap_or_else(|| "_".to_owned()) } +fn is_inherit_attr(attr: &ast::Attr) -> bool { + let Some(name) = attr.simple_name() else { return false }; + matches!(name.as_str(), "track_caller" | "cfg") +} + fn make_ty(ty: &hir::Type<'_>, ctx: &AssistContext<'_>, module: hir::Module) -> ast::Type { let ty_str = format_type(ty, ctx, module); make::ty(&ty_str) @@ -5414,12 +5446,12 @@ impl Struct { impl Trait for Struct { fn bar(&self) -> i32 { - let three_squared = fun_name(); + let three_squared = three_squared(); self.0 + three_squared } } -fn $0fun_name() -> i32 { +fn $0three_squared() -> i32 { 3 * 3 } "#, @@ -6375,4 +6407,53 @@ fn $0fun_name(mut a: i32, mut b: i32) { "#, ); } + + #[test] + fn with_cfg_attr() { + check_assist( + extract_function, + r#" +//- /main.rs crate:main cfg:test +#[cfg(test)] +fn foo() { + foo($01 + 1$0); +} +"#, + r#" +#[cfg(test)] +fn foo() { + foo(fun_name()); +} + +#[cfg(test)] +fn $0fun_name() -> i32 { + 1 + 1 +} +"#, + ); + } + + #[test] + fn with_track_caller() { + check_assist( + extract_function, + r#" +#[track_caller] +fn foo() { + foo($01 + 1$0); +} +"#, + r#" +#[track_caller] +fn foo() { + foo(fun_name()); +} + +#[track_caller] +fn $0fun_name() -> i32 { + 1 + 1 +} +"#, + ); + } } |