Unnamed repository; edit this file 'description' to name the repository.
Diffstat (limited to 'crates/hir-ty/src/tests.rs')
| -rw-r--r-- | crates/hir-ty/src/tests.rs | 109 |
1 files changed, 90 insertions, 19 deletions
diff --git a/crates/hir-ty/src/tests.rs b/crates/hir-ty/src/tests.rs index 69ec35f406..96e7130ecf 100644 --- a/crates/hir-ty/src/tests.rs +++ b/crates/hir-ty/src/tests.rs @@ -16,6 +16,7 @@ use std::env; use std::sync::LazyLock; use base_db::SourceDatabaseFileInputExt as _; +use either::Either; use expect_test::Expect; use hir_def::{ db::DefDatabase, @@ -23,12 +24,14 @@ use hir_def::{ hir::{ExprId, Pat, PatId}, item_scope::ItemScope, nameres::DefMap, - src::HasSource, - AssocItemId, DefWithBodyId, HasModule, LocalModuleId, Lookup, ModuleDefId, SyntheticSyntax, + src::{HasChildSource, HasSource}, + AdtId, AssocItemId, DefWithBodyId, FieldId, HasModule, LocalModuleId, Lookup, ModuleDefId, + SyntheticSyntax, }; use hir_expand::{db::ExpandDatabase, FileRange, InFile}; use itertools::Itertools; use rustc_hash::FxHashMap; +use span::TextSize; use stdx::format_to; use syntax::{ ast::{self, AstNode, HasName}, @@ -132,14 +135,40 @@ fn check_impl( None => continue, }; let def_map = module.def_map(&db); - visit_module(&db, &def_map, module.local_id, &mut |it| { - defs.push(match it { - ModuleDefId::FunctionId(it) => it.into(), - ModuleDefId::EnumVariantId(it) => it.into(), - ModuleDefId::ConstId(it) => it.into(), - ModuleDefId::StaticId(it) => it.into(), - _ => return, - }) + visit_module(&db, &def_map, module.local_id, &mut |it| match it { + ModuleDefId::FunctionId(it) => defs.push(it.into()), + ModuleDefId::EnumVariantId(it) => { + defs.push(it.into()); + let variant_id = it.into(); + let vd = db.variant_data(variant_id); + defs.extend(vd.fields().iter().filter_map(|(local_id, fd)| { + if fd.has_default { + let field = FieldId { parent: variant_id, local_id }; + Some(DefWithBodyId::FieldId(field)) + } else { + None + } + })); + } + ModuleDefId::ConstId(it) => defs.push(it.into()), + ModuleDefId::StaticId(it) => defs.push(it.into()), + ModuleDefId::AdtId(it) => { + let variant_id = match it { + AdtId::StructId(it) => it.into(), + AdtId::UnionId(it) => it.into(), + AdtId::EnumId(_) => return, + }; + let vd = db.variant_data(variant_id); + defs.extend(vd.fields().iter().filter_map(|(local_id, fd)| { + if fd.has_default { + let field = FieldId { parent: variant_id, local_id }; + Some(DefWithBodyId::FieldId(field)) + } else { + None + } + })); + } + _ => {} }); } defs.sort_by_key(|def| match def { @@ -160,6 +189,14 @@ fn check_impl( loc.source(&db).value.syntax().text_range().start() } DefWithBodyId::InTypeConstId(it) => it.source(&db).syntax().text_range().start(), + DefWithBodyId::FieldId(it) => { + let cs = it.parent.child_source(&db); + match cs.value.get(it.local_id) { + Some(Either::Left(it)) => it.syntax().text_range().start(), + Some(Either::Right(it)) => it.syntax().text_range().end(), + None => TextSize::new(u32::MAX), + } + } }); let mut unexpected_type_mismatches = String::new(); for def in defs { @@ -388,14 +425,40 @@ fn infer_with_mismatches(content: &str, include_mismatches: bool) -> String { let def_map = module.def_map(&db); let mut defs: Vec<DefWithBodyId> = Vec::new(); - visit_module(&db, &def_map, module.local_id, &mut |it| { - defs.push(match it { - ModuleDefId::FunctionId(it) => it.into(), - ModuleDefId::EnumVariantId(it) => it.into(), - ModuleDefId::ConstId(it) => it.into(), - ModuleDefId::StaticId(it) => it.into(), - _ => return, - }) + visit_module(&db, &def_map, module.local_id, &mut |it| match it { + ModuleDefId::FunctionId(it) => defs.push(it.into()), + ModuleDefId::EnumVariantId(it) => { + defs.push(it.into()); + let variant_id = it.into(); + let vd = db.variant_data(variant_id); + defs.extend(vd.fields().iter().filter_map(|(local_id, fd)| { + if fd.has_default { + let field = FieldId { parent: variant_id, local_id }; + Some(DefWithBodyId::FieldId(field)) + } else { + None + } + })); + } + ModuleDefId::ConstId(it) => defs.push(it.into()), + ModuleDefId::StaticId(it) => defs.push(it.into()), + ModuleDefId::AdtId(it) => { + let variant_id = match it { + AdtId::StructId(it) => it.into(), + AdtId::UnionId(it) => it.into(), + AdtId::EnumId(_) => return, + }; + let vd = db.variant_data(variant_id); + defs.extend(vd.fields().iter().filter_map(|(local_id, fd)| { + if fd.has_default { + let field = FieldId { parent: variant_id, local_id }; + Some(DefWithBodyId::FieldId(field)) + } else { + None + } + })); + } + _ => {} }); defs.sort_by_key(|def| match def { DefWithBodyId::FunctionId(it) => { @@ -415,6 +478,14 @@ fn infer_with_mismatches(content: &str, include_mismatches: bool) -> String { loc.source(&db).value.syntax().text_range().start() } DefWithBodyId::InTypeConstId(it) => it.source(&db).syntax().text_range().start(), + DefWithBodyId::FieldId(it) => { + let cs = it.parent.child_source(&db); + match cs.value.get(it.local_id) { + Some(Either::Left(it)) => it.syntax().text_range().start(), + Some(Either::Right(it)) => it.syntax().text_range().end(), + None => TextSize::new(u32::MAX), + } + } }); for def in defs { let (body, source_map) = db.body_with_source_map(def); @@ -475,7 +546,7 @@ pub(crate) fn visit_module( let body = db.body(it.into()); visit_body(db, &body, cb); } - ModuleDefId::AdtId(hir_def::AdtId::EnumId(it)) => { + ModuleDefId::AdtId(AdtId::EnumId(it)) => { db.enum_data(it).variants.iter().for_each(|&(it, _)| { let body = db.body(it.into()); cb(it.into()); |