Unnamed repository; edit this file 'description' to name the repository.
-rw-r--r--crates/ide-assists/src/handlers/extract_function.rs69
1 files changed, 67 insertions, 2 deletions
diff --git a/crates/ide-assists/src/handlers/extract_function.rs b/crates/ide-assists/src/handlers/extract_function.rs
index 231df9b5b3..294e5f7da8 100644
--- a/crates/ide-assists/src/handlers/extract_function.rs
+++ b/crates/ide-assists/src/handlers/extract_function.rs
@@ -25,7 +25,7 @@ use syntax::{
SyntaxKind::{self, COMMENT},
SyntaxNode, SyntaxToken, T, TextRange, TextSize, TokenAtOffset, WalkEvent,
ast::{
- self, AstNode, AstToken, HasGenericParams, HasName, edit::IndentLevel,
+ self, AstNode, AstToken, HasAttrs, HasGenericParams, HasName, edit::IndentLevel,
edit_in_place::Indent,
},
match_ast, ted,
@@ -375,6 +375,7 @@ struct ContainerInfo<'db> {
ret_type: Option<hir::Type<'db>>,
generic_param_lists: Vec<ast::GenericParamList>,
where_clauses: Vec<ast::WhereClause>,
+ attrs: Vec<ast::Attr>,
edition: Edition,
}
@@ -911,6 +912,7 @@ impl FunctionBody {
let parents = generic_parents(&parent);
let generic_param_lists = parents.iter().filter_map(|it| it.generic_param_list()).collect();
let where_clauses = parents.iter().filter_map(|it| it.where_clause()).collect();
+ let attrs = parents.iter().flat_map(|it| it.attrs()).filter(is_inherit_attr).collect();
Some((
ContainerInfo {
@@ -919,6 +921,7 @@ impl FunctionBody {
ret_type: ty,
generic_param_lists,
where_clauses,
+ attrs,
edition,
},
contains_tail_expr,
@@ -1103,6 +1106,14 @@ impl GenericParent {
GenericParent::Trait(trait_) => trait_.where_clause(),
}
}
+
+ fn attrs(&self) -> impl Iterator<Item = ast::Attr> {
+ match self {
+ GenericParent::Fn(fn_) => fn_.attrs(),
+ GenericParent::Impl(impl_) => impl_.attrs(),
+ GenericParent::Trait(trait_) => trait_.attrs(),
+ }
+ }
}
/// Search `parent`'s ancestors for items with potentially applicable generic parameters
@@ -1578,7 +1589,7 @@ fn format_function(
let (generic_params, where_clause) = make_generic_params_and_where_clause(ctx, fun);
make::fn_(
- None,
+ fun.mods.attrs.clone(),
None,
fun_name,
generic_params,
@@ -1958,6 +1969,11 @@ fn format_type(ty: &hir::Type<'_>, ctx: &AssistContext<'_>, module: hir::Module)
ty.display_source_code(ctx.db(), module.into(), true).ok().unwrap_or_else(|| "_".to_owned())
}
+fn is_inherit_attr(attr: &ast::Attr) -> bool {
+ let Some(name) = attr.simple_name() else { return false };
+ matches!(name.as_str(), "track_caller" | "cfg")
+}
+
fn make_ty(ty: &hir::Type<'_>, ctx: &AssistContext<'_>, module: hir::Module) -> ast::Type {
let ty_str = format_type(ty, ctx, module);
make::ty(&ty_str)
@@ -6375,4 +6391,53 @@ fn $0fun_name(mut a: i32, mut b: i32) {
"#,
);
}
+
+ #[test]
+ fn with_cfg_attr() {
+ check_assist(
+ extract_function,
+ r#"
+//- /main.rs crate:main cfg:test
+#[cfg(test)]
+fn foo() {
+ foo($01 + 1$0);
+}
+"#,
+ r#"
+#[cfg(test)]
+fn foo() {
+ foo(fun_name());
+}
+
+#[cfg(test)]
+fn $0fun_name() -> i32 {
+ 1 + 1
+}
+"#,
+ );
+ }
+
+ #[test]
+ fn with_track_caller() {
+ check_assist(
+ extract_function,
+ r#"
+#[track_caller]
+fn foo() {
+ foo($01 + 1$0);
+}
+"#,
+ r#"
+#[track_caller]
+fn foo() {
+ foo(fun_name());
+}
+
+#[track_caller]
+fn $0fun_name() -> i32 {
+ 1 + 1
+}
+"#,
+ );
+ }
}