Unnamed repository; edit this file 'description' to name the repository.
Diffstat (limited to 'crates/ide_assists/src/handlers/extract_struct_from_enum_variant.rs')
-rw-r--r--crates/ide_assists/src/handlers/extract_struct_from_enum_variant.rs232
1 files changed, 203 insertions, 29 deletions
diff --git a/crates/ide_assists/src/handlers/extract_struct_from_enum_variant.rs b/crates/ide_assists/src/handlers/extract_struct_from_enum_variant.rs
index 82e0970cc4..1cdd4187af 100644
--- a/crates/ide_assists/src/handlers/extract_struct_from_enum_variant.rs
+++ b/crates/ide_assists/src/handlers/extract_struct_from_enum_variant.rs
@@ -11,15 +11,14 @@ use ide_db::{
search::FileReference,
RootDatabase,
};
-use itertools::Itertools;
+use itertools::{Itertools, Position};
use rustc_hash::FxHashSet;
use syntax::{
ast::{
self, edit::IndentLevel, edit_in_place::Indent, make, AstNode, HasAttrs, HasGenericParams,
- HasName, HasTypeBounds, HasVisibility,
+ HasName, HasVisibility,
},
- match_ast,
- ted::{self, Position},
+ match_ast, ted, SyntaxElement,
SyntaxKind::*,
SyntaxNode, T,
};
@@ -106,7 +105,12 @@ pub(crate) fn extract_struct_from_enum_variant(
}
let indent = enum_ast.indent_level();
- let def = create_struct_def(variant_name.clone(), &variant, &field_list, &enum_ast);
+ let generic_params = enum_ast
+ .generic_param_list()
+ .and_then(|known_generics| extract_generic_params(&known_generics, &field_list));
+ let generics = generic_params.as_ref().map(|generics| generics.clone_for_update());
+ let def =
+ create_struct_def(variant_name.clone(), &variant, &field_list, generics, &enum_ast);
def.reindent_to(indent);
let start_offset = &variant.parent_enum().syntax().clone();
@@ -118,7 +122,7 @@ pub(crate) fn extract_struct_from_enum_variant(
],
);
- update_variant(&variant, enum_ast.generic_param_list());
+ update_variant(&variant, generic_params.map(|g| g.clone_for_update()));
},
)
}
@@ -159,10 +163,77 @@ fn existing_definition(db: &RootDatabase, variant_name: &ast::Name, variant: &Va
.any(|(name, _)| name.to_string() == variant_name.to_string())
}
+fn extract_generic_params(
+ known_generics: &ast::GenericParamList,
+ field_list: &Either<ast::RecordFieldList, ast::TupleFieldList>,
+) -> Option<ast::GenericParamList> {
+ let mut generics = known_generics.generic_params().map(|param| (param, false)).collect_vec();
+
+ let tagged_one = match field_list {
+ Either::Left(field_list) => field_list
+ .fields()
+ .filter_map(|f| f.ty())
+ .fold(false, |tagged, ty| tag_generics_in_variant(&ty, &mut generics) || tagged),
+ Either::Right(field_list) => field_list
+ .fields()
+ .filter_map(|f| f.ty())
+ .fold(false, |tagged, ty| tag_generics_in_variant(&ty, &mut generics) || tagged),
+ };
+
+ let generics = generics.into_iter().filter_map(|(param, tag)| tag.then(|| param));
+ tagged_one.then(|| make::generic_param_list(generics))
+}
+
+fn tag_generics_in_variant(ty: &ast::Type, generics: &mut [(ast::GenericParam, bool)]) -> bool {
+ let mut tagged_one = false;
+
+ for token in ty.syntax().descendants_with_tokens().filter_map(SyntaxElement::into_token) {
+ for (param, tag) in generics.iter_mut().filter(|(_, tag)| !tag) {
+ match param {
+ ast::GenericParam::LifetimeParam(lt)
+ if matches!(token.kind(), T![lifetime_ident]) =>
+ {
+ if let Some(lt) = lt.lifetime() {
+ if lt.text().as_str() == token.text() {
+ *tag = true;
+ tagged_one = true;
+ break;
+ }
+ }
+ }
+ param if matches!(token.kind(), T![ident]) => {
+ if match param {
+ ast::GenericParam::ConstParam(konst) => konst
+ .name()
+ .map(|name| name.text().as_str() == token.text())
+ .unwrap_or_default(),
+ ast::GenericParam::TypeParam(ty) => ty
+ .name()
+ .map(|name| name.text().as_str() == token.text())
+ .unwrap_or_default(),
+ ast::GenericParam::LifetimeParam(lt) => lt
+ .lifetime()
+ .map(|lt| lt.text().as_str() == token.text())
+ .unwrap_or_default(),
+ } {
+ *tag = true;
+ tagged_one = true;
+ break;
+ }
+ }
+ _ => (),
+ }
+ }
+ }
+
+ tagged_one
+}
+
fn create_struct_def(
variant_name: ast::Name,
variant: &ast::Variant,
field_list: &Either<ast::RecordFieldList, ast::TupleFieldList>,
+ generics: Option<ast::GenericParamList>,
enum_: &ast::Enum,
) -> ast::Struct {
let enum_vis = enum_.visibility();
@@ -204,9 +275,7 @@ fn create_struct_def(
field_list.reindent_to(IndentLevel::single());
- // FIXME: This uses all the generic params of the enum, but the variant might not use all of them.
- let strukt = make::struct_(enum_vis, variant_name, enum_.generic_param_list(), field_list)
- .clone_for_update();
+ let strukt = make::struct_(enum_vis, variant_name, generics, field_list).clone_for_update();
// FIXME: Consider making this an actual function somewhere (like in `AttrsOwnerEdit`) after some deliberation
let attrs_and_docs = |node: &SyntaxNode| {
@@ -233,36 +302,53 @@ fn create_struct_def(
_ => tok,
})
.collect();
- ted::insert_all(Position::first_child_of(strukt.syntax()), variant_attrs);
+ ted::insert_all(ted::Position::first_child_of(strukt.syntax()), variant_attrs);
// copy attributes from enum
ted::insert_all(
- Position::first_child_of(strukt.syntax()),
+ ted::Position::first_child_of(strukt.syntax()),
enum_.attrs().map(|it| it.syntax().clone_for_update().into()).collect(),
);
strukt
}
-fn update_variant(variant: &ast::Variant, generic: Option<ast::GenericParamList>) -> Option<()> {
+fn update_variant(variant: &ast::Variant, generics: Option<ast::GenericParamList>) -> Option<()> {
let name = variant.name()?;
- let ty = match generic {
- // FIXME: This uses all the generic params of the enum, but the variant might not use all of them.
- Some(gpl) => {
- let gpl = gpl.clone_for_update();
- gpl.generic_params().for_each(|gp| {
- let tbl = match gp {
- ast::GenericParam::LifetimeParam(it) => it.type_bound_list(),
- ast::GenericParam::TypeParam(it) => it.type_bound_list(),
- ast::GenericParam::ConstParam(_) => return,
- };
- if let Some(tbl) = tbl {
- tbl.remove();
+ let ty = generics
+ .filter(|generics| generics.generic_params().count() > 0)
+ .map(|generics| {
+ let mut generic_str = String::with_capacity(8);
+
+ for (p, more) in generics.generic_params().with_position().map(|p| match p {
+ Position::First(p) | Position::Middle(p) => (p, true),
+ Position::Last(p) | Position::Only(p) => (p, false),
+ }) {
+ match p {
+ ast::GenericParam::ConstParam(konst) => {
+ if let Some(name) = konst.name() {
+ generic_str.push_str(name.text().as_str());
+ }
+ }
+ ast::GenericParam::LifetimeParam(lt) => {
+ if let Some(lt) = lt.lifetime() {
+ generic_str.push_str(lt.text().as_str());
+ }
+ }
+ ast::GenericParam::TypeParam(ty) => {
+ if let Some(name) = ty.name() {
+ generic_str.push_str(name.text().as_str());
+ }
+ }
}
- });
- make::ty(&format!("{}<{}>", name.text(), gpl.generic_params().join(", ")))
- }
- None => make::ty(&name.text()),
- };
+ if more {
+ generic_str.push_str(", ");
+ }
+ }
+
+ make::ty(&format!("{}<{}>", &name.text(), &generic_str))
+ })
+ .unwrap_or_else(|| make::ty(&name.text()));
+
let tuple_field = make::tuple_field(None, ty);
let replacement = make::variant(
name,
@@ -902,4 +988,92 @@ enum A { $0One(u8, u32) }
fn test_extract_not_applicable_no_field_named() {
check_assist_not_applicable(extract_struct_from_enum_variant, r"enum A { $0None {} }");
}
+
+ #[test]
+ fn test_extract_struct_only_copies_needed_generics() {
+ check_assist(
+ extract_struct_from_enum_variant,
+ r#"
+enum X<'a, 'b, 'x> {
+ $0A { a: &'a &'x mut () },
+ B { b: &'b () },
+ C { c: () },
+}
+"#,
+ r#"
+struct A<'a, 'x>{ a: &'a &'x mut () }
+
+enum X<'a, 'b, 'x> {
+ A(A<'a, 'x>),
+ B { b: &'b () },
+ C { c: () },
+}
+"#,
+ );
+ }
+
+ #[test]
+ fn test_extract_struct_with_liftime_type_const() {
+ check_assist(
+ extract_struct_from_enum_variant,
+ r#"
+enum X<'b, T, V, const C: usize> {
+ $0A { a: T, b: X<'b>, c: [u8; C] },
+ D { d: V },
+}
+"#,
+ r#"
+struct A<'b, T, const C: usize>{ a: T, b: X<'b>, c: [u8; C] }
+
+enum X<'b, T, V, const C: usize> {
+ A(A<'b, T, C>),
+ D { d: V },
+}
+"#,
+ );
+ }
+
+ #[test]
+ fn test_extract_struct_without_generics() {
+ check_assist(
+ extract_struct_from_enum_variant,
+ r#"
+enum X<'a, 'b> {
+ A { a: &'a () },
+ B { b: &'b () },
+ $0C { c: () },
+}
+"#,
+ r#"
+struct C{ c: () }
+
+enum X<'a, 'b> {
+ A { a: &'a () },
+ B { b: &'b () },
+ C(C),
+}
+"#,
+ );
+ }
+
+ #[test]
+ fn test_extract_struct_keeps_trait_bounds() {
+ check_assist(
+ extract_struct_from_enum_variant,
+ r#"
+enum En<T: TraitT, V: TraitV> {
+ $0A { a: T },
+ B { b: V },
+}
+"#,
+ r#"
+struct A<T: TraitT>{ a: T }
+
+enum En<T: TraitT, V: TraitV> {
+ A(A<T>),
+ B { b: V },
+}
+"#,
+ );
+ }
}