Unnamed repository; edit this file 'description' to name the repository.
Merge pull request #20125 from nicolas-guichard/push-pypzwzspzznu
Use inferred type in “extract type as type alias” assist and display inferred type placeholder `_` inlay hints
| -rw-r--r-- | crates/hir-def/src/hir/type_ref.rs | 14 | ||||
| -rw-r--r-- | crates/hir-ty/src/infer.rs | 43 | ||||
| -rw-r--r-- | crates/hir-ty/src/lib.rs | 29 | ||||
| -rw-r--r-- | crates/hir-ty/src/tests.rs | 33 | ||||
| -rw-r--r-- | crates/hir-ty/src/tests/display_source_code.rs | 19 | ||||
| -rw-r--r-- | crates/hir/src/source_analyzer.rs | 32 | ||||
| -rw-r--r-- | crates/ide-assists/src/handlers/extract_type_alias.rs | 58 | ||||
| -rw-r--r-- | crates/ide/src/inlay_hints.rs | 5 | ||||
| -rw-r--r-- | crates/ide/src/inlay_hints/placeholders.rs | 76 |
9 files changed, 299 insertions, 10 deletions
diff --git a/crates/hir-def/src/hir/type_ref.rs b/crates/hir-def/src/hir/type_ref.rs index da0f058a9c..ad8535413d 100644 --- a/crates/hir-def/src/hir/type_ref.rs +++ b/crates/hir-def/src/hir/type_ref.rs @@ -195,12 +195,16 @@ impl TypeRef { TypeRef::Tuple(ThinVec::new()) } - pub fn walk(this: TypeRefId, map: &ExpressionStore, f: &mut impl FnMut(&TypeRef)) { + pub fn walk(this: TypeRefId, map: &ExpressionStore, f: &mut impl FnMut(TypeRefId, &TypeRef)) { go(this, f, map); - fn go(type_ref: TypeRefId, f: &mut impl FnMut(&TypeRef), map: &ExpressionStore) { - let type_ref = &map[type_ref]; - f(type_ref); + fn go( + type_ref_id: TypeRefId, + f: &mut impl FnMut(TypeRefId, &TypeRef), + map: &ExpressionStore, + ) { + let type_ref = &map[type_ref_id]; + f(type_ref_id, type_ref); match type_ref { TypeRef::Fn(fn_) => { fn_.params.iter().for_each(|&(_, param_type)| go(param_type, f, map)) @@ -224,7 +228,7 @@ impl TypeRef { }; } - fn go_path(path: &Path, f: &mut impl FnMut(&TypeRef), map: &ExpressionStore) { + fn go_path(path: &Path, f: &mut impl FnMut(TypeRefId, &TypeRef), map: &ExpressionStore) { if let Some(type_ref) = path.type_anchor() { go(type_ref, f, map); } diff --git a/crates/hir-ty/src/infer.rs b/crates/hir-ty/src/infer.rs index 02b8ab8cdd..15eb355128 100644 --- a/crates/hir-ty/src/infer.rs +++ b/crates/hir-ty/src/infer.rs @@ -41,7 +41,7 @@ use hir_def::{ layout::Integer, resolver::{HasResolver, ResolveValueResult, Resolver, TypeNs, ValueNs}, signatures::{ConstSignature, StaticSignature}, - type_ref::{ConstRef, LifetimeRefId, TypeRefId}, + type_ref::{ConstRef, LifetimeRefId, TypeRef, TypeRefId}, }; use hir_expand::{mod_path::ModPath, name::Name}; use indexmap::IndexSet; @@ -60,6 +60,7 @@ use triomphe::Arc; use crate::{ ImplTraitId, IncorrectGenericsLenKind, PathLoweringDiagnostic, TargetFeatures, + collect_type_inference_vars, db::{HirDatabase, InternedClosureId, InternedOpaqueTyId}, infer::{ coerce::{CoerceMany, DynamicCoerceMany}, @@ -497,6 +498,7 @@ pub struct InferenceResult<'db> { /// unresolved or missing subpatterns or subpatterns of mismatched types. pub(crate) type_of_pat: ArenaMap<PatId, Ty<'db>>, pub(crate) type_of_binding: ArenaMap<BindingId, Ty<'db>>, + pub(crate) type_of_type_placeholder: ArenaMap<TypeRefId, Ty<'db>>, pub(crate) type_of_opaque: FxHashMap<InternedOpaqueTyId, Ty<'db>>, pub(crate) type_mismatches: FxHashMap<ExprOrPatId, TypeMismatch<'db>>, /// Whether there are any type-mismatching errors in the result. @@ -542,6 +544,7 @@ impl<'db> InferenceResult<'db> { type_of_expr: Default::default(), type_of_pat: Default::default(), type_of_binding: Default::default(), + type_of_type_placeholder: Default::default(), type_of_opaque: Default::default(), type_mismatches: Default::default(), has_errors: Default::default(), @@ -606,6 +609,12 @@ impl<'db> InferenceResult<'db> { _ => None, }) } + pub fn placeholder_types(&self) -> impl Iterator<Item = (TypeRefId, &Ty<'db>)> { + self.type_of_type_placeholder.iter() + } + pub fn type_of_type_placeholder(&self, type_ref: TypeRefId) -> Option<Ty<'db>> { + self.type_of_type_placeholder.get(type_ref).copied() + } pub fn closure_info(&self, closure: InternedClosureId) -> &(Vec<CapturedItem<'db>>, FnTrait) { self.closure_info.get(&closure).unwrap() } @@ -1014,6 +1023,7 @@ impl<'body, 'db> InferenceContext<'body, 'db> { type_of_expr, type_of_pat, type_of_binding, + type_of_type_placeholder, type_of_opaque, type_mismatches, has_errors, @@ -1046,6 +1056,11 @@ impl<'body, 'db> InferenceContext<'body, 'db> { *has_errors = *has_errors || ty.references_non_lt_error(); } type_of_binding.shrink_to_fit(); + for ty in type_of_type_placeholder.values_mut() { + *ty = table.resolve_completely(*ty); + *has_errors = *has_errors || ty.references_non_lt_error(); + } + type_of_type_placeholder.shrink_to_fit(); type_of_opaque.shrink_to_fit(); *has_errors |= !type_mismatches.is_empty(); @@ -1285,6 +1300,10 @@ impl<'body, 'db> InferenceContext<'body, 'db> { self.result.type_of_pat.insert(pat, ty); } + fn write_type_placeholder_ty(&mut self, type_ref: TypeRefId, ty: Ty<'db>) { + self.result.type_of_type_placeholder.insert(type_ref, ty); + } + fn write_binding_ty(&mut self, id: BindingId, ty: Ty<'db>) { self.result.type_of_binding.insert(id, ty); } @@ -1333,7 +1352,27 @@ impl<'body, 'db> InferenceContext<'body, 'db> { ) -> Ty<'db> { let ty = self .with_ty_lowering(store, type_source, lifetime_elision, |ctx| ctx.lower_ty(type_ref)); - self.process_user_written_ty(ty) + let ty = self.process_user_written_ty(ty); + + // Record the association from placeholders' TypeRefId to type variables. + // We only record them if their number matches. This assumes TypeRef::walk and TypeVisitable process the items in the same order. + let type_variables = collect_type_inference_vars(&ty); + let mut placeholder_ids = vec![]; + TypeRef::walk(type_ref, store, &mut |type_ref_id, type_ref| { + if matches!(type_ref, TypeRef::Placeholder) { + placeholder_ids.push(type_ref_id); + } + }); + + if placeholder_ids.len() == type_variables.len() { + for (placeholder_id, type_variable) in + placeholder_ids.into_iter().zip(type_variables.into_iter()) + { + self.write_type_placeholder_ty(placeholder_id, type_variable); + } + } + + ty } pub(crate) fn make_body_ty(&mut self, type_ref: TypeRefId) -> Ty<'db> { diff --git a/crates/hir-ty/src/lib.rs b/crates/hir-ty/src/lib.rs index b29c7d252b..8819307c53 100644 --- a/crates/hir-ty/src/lib.rs +++ b/crates/hir-ty/src/lib.rs @@ -569,6 +569,35 @@ where Vec::from_iter(collector.params) } +struct TypeInferenceVarCollector<'db> { + type_inference_vars: Vec<Ty<'db>>, +} + +impl<'db> rustc_type_ir::TypeVisitor<DbInterner<'db>> for TypeInferenceVarCollector<'db> { + type Result = (); + + fn visit_ty(&mut self, ty: Ty<'db>) -> Self::Result { + use crate::rustc_type_ir::Flags; + if ty.is_ty_var() { + self.type_inference_vars.push(ty); + } else if ty.flags().intersects(rustc_type_ir::TypeFlags::HAS_TY_INFER) { + ty.super_visit_with(self); + } else { + // Fast path: don't visit inner types (e.g. generic arguments) when `flags` indicate + // that there are no placeholders. + } + } +} + +pub fn collect_type_inference_vars<'db, T>(value: &T) -> Vec<Ty<'db>> +where + T: ?Sized + rustc_type_ir::TypeVisitable<DbInterner<'db>>, +{ + let mut collector = TypeInferenceVarCollector { type_inference_vars: vec![] }; + value.visit_with(&mut collector); + collector.type_inference_vars +} + pub fn known_const_to_ast<'db>( konst: Const<'db>, db: &'db dyn HirDatabase, diff --git a/crates/hir-ty/src/tests.rs b/crates/hir-ty/src/tests.rs index 95a02d534b..002d58961d 100644 --- a/crates/hir-ty/src/tests.rs +++ b/crates/hir-ty/src/tests.rs @@ -23,6 +23,7 @@ use hir_def::{ item_scope::ItemScope, nameres::DefMap, src::HasSource, + type_ref::TypeRefId, }; use hir_expand::{FileRange, InFile, db::ExpandDatabase}; use itertools::Itertools; @@ -219,6 +220,24 @@ fn check_impl( } } } + + for (type_ref, ty) in inference_result.placeholder_types() { + let node = match type_node(&body_source_map, type_ref, &db) { + Some(value) => value, + None => continue, + }; + let range = node.as_ref().original_file_range_rooted(&db); + if let Some(expected) = types.remove(&range) { + let actual = salsa::attach(&db, || { + if display_source { + ty.display_source_code(&db, def.module(&db), true).unwrap() + } else { + ty.display_test(&db, display_target).to_string() + } + }); + assert_eq!(actual, expected, "type annotation differs at {:#?}", range.range); + } + } } let mut buf = String::new(); @@ -275,6 +294,20 @@ fn pat_node( }) } +fn type_node( + body_source_map: &BodySourceMap, + type_ref: TypeRefId, + db: &TestDB, +) -> Option<InFile<SyntaxNode>> { + Some(match body_source_map.type_syntax(type_ref) { + Ok(sp) => { + let root = db.parse_or_expand(sp.file_id); + sp.map(|ptr| ptr.to_node(&root).syntax().clone()) + } + Err(SyntheticSyntax) => return None, + }) +} + fn infer(#[rust_analyzer::rust_fixture] ra_fixture: &str) -> String { infer_with_mismatches(ra_fixture, false) } diff --git a/crates/hir-ty/src/tests/display_source_code.rs b/crates/hir-ty/src/tests/display_source_code.rs index a986b54a7b..dc3869930d 100644 --- a/crates/hir-ty/src/tests/display_source_code.rs +++ b/crates/hir-ty/src/tests/display_source_code.rs @@ -246,3 +246,22 @@ fn test() { "#, ); } + +#[test] +fn type_placeholder_type() { + check_types_source_code( + r#" +struct S<T>(T); +fn test() { + let f: S<_> = S(3); + //^ i32 + let f: [_; _] = [4_u32, 5, 6]; + //^ u32 + let f: (_, _, _) = (1_u32, 1_i32, false); + //^ u32 + //^ i32 + //^ bool +} +"#, + ); +} diff --git a/crates/hir/src/source_analyzer.rs b/crates/hir/src/source_analyzer.rs index ae328a9680..858426ceab 100644 --- a/crates/hir/src/source_analyzer.rs +++ b/crates/hir/src/source_analyzer.rs @@ -21,7 +21,7 @@ use hir_def::{ lang_item::LangItem, nameres::MacroSubNs, resolver::{HasResolver, Resolver, TypeNs, ValueNs, resolver_for_scope}, - type_ref::{Mutability, TypeRefId}, + type_ref::{Mutability, TypeRef, TypeRefId}, }; use hir_expand::{ HirFileId, InFile, @@ -267,8 +267,11 @@ impl<'db> SourceAnalyzer<'db> { db: &'db dyn HirDatabase, ty: &ast::Type, ) -> Option<Type<'db>> { + let interner = DbInterner::new_with(db, None, None); + let type_ref = self.type_id(ty)?; - let ty = TyLoweringContext::new( + + let mut ty = TyLoweringContext::new( db, &self.resolver, self.store()?, @@ -279,6 +282,31 @@ impl<'db> SourceAnalyzer<'db> { LifetimeElisionKind::Infer, ) .lower_ty(type_ref); + + // Try and substitute unknown types using InferenceResult + if let Some(infer) = self.infer() + && let Some(store) = self.store() + { + let mut inferred_types = vec![]; + TypeRef::walk(type_ref, store, &mut |type_ref_id, type_ref| { + if matches!(type_ref, TypeRef::Placeholder) { + inferred_types.push(infer.type_of_type_placeholder(type_ref_id)); + } + }); + let mut inferred_types = inferred_types.into_iter(); + + let substituted_ty = hir_ty::next_solver::fold::fold_tys(interner, ty, |ty| { + if ty.is_ty_error() { inferred_types.next().flatten().unwrap_or(ty) } else { ty } + }); + + // Only used the result if the placeholder and unknown type counts matched + let success = + inferred_types.next().is_none() && !substituted_ty.references_non_lt_error(); + if success { + ty = substituted_ty; + } + } + Some(Type::new_with_resolver(db, &self.resolver, ty)) } diff --git a/crates/ide-assists/src/handlers/extract_type_alias.rs b/crates/ide-assists/src/handlers/extract_type_alias.rs index 59522458af..769bbd976a 100644 --- a/crates/ide-assists/src/handlers/extract_type_alias.rs +++ b/crates/ide-assists/src/handlers/extract_type_alias.rs @@ -1,4 +1,5 @@ use either::Either; +use hir::HirDisplay; use ide_db::syntax_helpers::node_ext::walk_ty; use syntax::{ ast::{self, AstNode, HasGenericArgs, HasGenericParams, HasName, edit::IndentLevel, make}, @@ -39,6 +40,15 @@ pub(crate) fn extract_type_alias(acc: &mut Assists, ctx: &AssistContext<'_>) -> ); let target = ty.syntax().text_range(); + let resolved_ty = ctx.sema.resolve_type(&ty)?; + let resolved_ty = if !resolved_ty.contains_unknown() { + let module = ctx.sema.scope(ty.syntax())?.module(); + let resolved_ty = resolved_ty.display_source_code(ctx.db(), module.into(), false).ok()?; + make::ty(&resolved_ty) + } else { + ty.clone() + }; + acc.add( AssistId::refactor_extract("extract_type_alias"), "Extract type as type alias", @@ -72,7 +82,7 @@ pub(crate) fn extract_type_alias(acc: &mut Assists, ctx: &AssistContext<'_>) -> // Insert new alias let ty_alias = - make::ty_alias(None, "Type", generic_params, None, None, Some((ty, None))) + make::ty_alias(None, "Type", generic_params, None, None, Some((resolved_ty, None))) .clone_for_update(); if let Some(cap) = ctx.config.snippet_cap @@ -391,4 +401,50 @@ where "#, ); } + + #[test] + fn inferred_generic_type_parameter() { + check_assist( + extract_type_alias, + r#" +struct Wrap<T>(T); + +fn main() { + let wrap: $0Wrap<_>$0 = Wrap::<_>(3i32); +} + "#, + r#" +struct Wrap<T>(T); + +type $0Type = Wrap<i32>; + +fn main() { + let wrap: Type = Wrap::<_>(3i32); +} + "#, + ) + } + + #[test] + fn inferred_type() { + check_assist( + extract_type_alias, + r#" +struct Wrap<T>(T); + +fn main() { + let wrap: Wrap<$0_$0> = Wrap::<_>(3i32); +} + "#, + r#" +struct Wrap<T>(T); + +type $0Type = i32; + +fn main() { + let wrap: Wrap<Type> = Wrap::<_>(3i32); +} + "#, + ) + } } diff --git a/crates/ide/src/inlay_hints.rs b/crates/ide/src/inlay_hints.rs index 2b4fe54fc3..6dd9e84a57 100644 --- a/crates/ide/src/inlay_hints.rs +++ b/crates/ide/src/inlay_hints.rs @@ -40,6 +40,7 @@ mod implicit_static; mod implied_dyn_trait; mod lifetime; mod param_name; +mod placeholders; mod ra_fixture; mod range_exclusive; @@ -291,6 +292,10 @@ fn hints( implied_dyn_trait::hints(hints, famous_defs, config, Either::Right(dyn_)); Some(()) }, + ast::Type::InferType(placeholder) => { + placeholders::type_hints(hints, famous_defs, config, display_target, placeholder); + Some(()) + }, _ => Some(()), }, ast::GenericParamList(it) => bounds::hints(hints, famous_defs, config, it), diff --git a/crates/ide/src/inlay_hints/placeholders.rs b/crates/ide/src/inlay_hints/placeholders.rs new file mode 100644 index 0000000000..96d2c17c03 --- /dev/null +++ b/crates/ide/src/inlay_hints/placeholders.rs @@ -0,0 +1,76 @@ +//! Implementation of type placeholder inlay hints: +//! ```no_run +//! let a = Vec<_> = vec![4]; +//! //^ = i32 +//! ``` + +use hir::DisplayTarget; +use ide_db::famous_defs::FamousDefs; +use syntax::{ + AstNode, + ast::{InferType, Type}, +}; + +use crate::{InlayHint, InlayHintPosition, InlayHintsConfig, InlayKind, inlay_hints::label_of_ty}; + +pub(super) fn type_hints( + acc: &mut Vec<InlayHint>, + famous_defs @ FamousDefs(sema, _): &FamousDefs<'_, '_>, + config: &InlayHintsConfig<'_>, + display_target: DisplayTarget, + placeholder: InferType, +) -> Option<()> { + if !config.type_hints { + return None; + } + + let syntax = placeholder.syntax(); + let range = syntax.text_range(); + + let ty = sema.resolve_type(&Type::InferType(placeholder))?; + + let mut label = label_of_ty(famous_defs, config, &ty, display_target)?; + label.prepend_str("= "); + + acc.push(InlayHint { + range, + kind: InlayKind::Type, + label, + text_edit: None, + position: InlayHintPosition::After, + pad_left: true, + pad_right: false, + resolve_parent: None, + }); + Some(()) +} + +#[cfg(test)] +mod tests { + use crate::{ + InlayHintsConfig, + inlay_hints::tests::{DISABLED_CONFIG, check_with_config}, + }; + + #[track_caller] + fn check_type_infer(#[rust_analyzer::rust_fixture] ra_fixture: &str) { + check_with_config(InlayHintsConfig { type_hints: true, ..DISABLED_CONFIG }, ra_fixture); + } + + #[test] + fn inferred_types() { + check_type_infer( + r#" +struct S<T>(T); + +fn foo() { + let t: (_, _, [_; _]) = (1_u32, S(2), [false] as _); + //^ = u32 + //^ = S<i32> + //^ = bool + //^ = [bool; 1] +} +"#, + ); + } +} |