Unnamed repository; edit this file 'description' to name the repository.
Diffstat (limited to 'crates/hir-ty/src/infer.rs')
| -rw-r--r-- | crates/hir-ty/src/infer.rs | 199 |
1 files changed, 111 insertions, 88 deletions
diff --git a/crates/hir-ty/src/infer.rs b/crates/hir-ty/src/infer.rs index 30b420b6d5..4aeb5ec71c 100644 --- a/crates/hir-ty/src/infer.rs +++ b/crates/hir-ty/src/infer.rs @@ -29,7 +29,7 @@ mod path; mod place_op; pub(crate) mod unify; -use std::{cell::OnceCell, convert::identity, fmt, iter, ops::Deref}; +use std::{cell::OnceCell, convert::identity, fmt, ops::Deref}; use base_db::{Crate, FxIndexMap}; use either::Either; @@ -50,6 +50,7 @@ use hir_def::{ use hir_expand::{mod_path::ModPath, name::Name}; use indexmap::IndexSet; use la_arena::ArenaMap; +use macros::{TypeFoldable, TypeVisitable}; use rustc_ast_ir::Mutability; use rustc_hash::{FxHashMap, FxHashSet}; use rustc_type_ir::{ @@ -76,14 +77,15 @@ use crate::{ diagnostics::{Diagnostics, InferenceTyLoweringContext as TyLoweringContext}, expr::ExprIsRead, pat::PatOrigin, + unify::resolve_completely::WriteBackCtxt, }, lower::{ ImplTraitIdx, ImplTraitLoweringMode, LifetimeElisionKind, diagnostics::TyLoweringDiagnostic, }, method_resolution::CandidateId, next_solver::{ - AliasTy, Const, DbInterner, ErrorGuaranteed, GenericArgs, Region, StoredGenericArgs, - StoredTy, StoredTys, Ty, TyKind, Tys, + AliasTy, Const, DbInterner, ErrorGuaranteed, GenericArgs, Region, StoredGenericArg, + StoredGenericArgs, StoredTy, StoredTys, Term, Ty, TyKind, Tys, abi::Safety, infer::{InferCtxt, ObligationInspector, traits::ObligationCause}, }, @@ -188,7 +190,7 @@ fn infer_signature_query(db: &dyn HirDatabase, def: GenericDefId) -> InferenceRe // Array lengths are always `usize`. RootExprOrigin::ArrayLength => Expectation::has_type(ctx.types.types.usize), // Const parameter default: look up the param's declared type. - RootExprOrigin::ConstParam(local_id) => Expectation::has_type(db.const_param_ty_ns( + RootExprOrigin::ConstParam(local_id) => Expectation::has_type(db.const_param_ty( ConstParamId::from_unchecked(TypeOrConstParamId { parent: def, local_id }), )), // Path const generic args: determining the expected type requires @@ -307,107 +309,152 @@ pub enum InferenceTyDiagnosticSource { Signature, } -#[derive(Debug, PartialEq, Eq, Clone)] +#[derive(Debug, PartialEq, Eq, Clone, TypeVisitable, TypeFoldable)] pub enum InferenceDiagnostic { NoSuchField { + #[type_visitable(ignore)] field: ExprOrPatId, + #[type_visitable(ignore)] private: Option<LocalFieldId>, + #[type_visitable(ignore)] variant: VariantId, }, PrivateField { + #[type_visitable(ignore)] expr: ExprId, + #[type_visitable(ignore)] field: FieldId, }, PrivateAssocItem { + #[type_visitable(ignore)] id: ExprOrPatId, + #[type_visitable(ignore)] item: AssocItemId, }, UnresolvedField { + #[type_visitable(ignore)] expr: ExprId, receiver: StoredTy, + #[type_visitable(ignore)] name: Name, + #[type_visitable(ignore)] method_with_same_name_exists: bool, }, UnresolvedMethodCall { + #[type_visitable(ignore)] expr: ExprId, receiver: StoredTy, + #[type_visitable(ignore)] name: Name, /// Contains the type the field resolves to field_with_same_name: Option<StoredTy>, + #[type_visitable(ignore)] assoc_func_with_same_name: Option<FunctionId>, }, UnresolvedAssocItem { + #[type_visitable(ignore)] id: ExprOrPatId, }, UnresolvedIdent { + #[type_visitable(ignore)] id: ExprOrPatId, }, // FIXME: This should be emitted in body lowering BreakOutsideOfLoop { + #[type_visitable(ignore)] expr: ExprId, + #[type_visitable(ignore)] is_break: bool, + #[type_visitable(ignore)] bad_value_break: bool, }, MismatchedArgCount { + #[type_visitable(ignore)] call_expr: ExprId, + #[type_visitable(ignore)] expected: usize, + #[type_visitable(ignore)] found: usize, }, MismatchedTupleStructPatArgCount { + #[type_visitable(ignore)] pat: PatId, + #[type_visitable(ignore)] expected: usize, + #[type_visitable(ignore)] found: usize, }, ExpectedFunction { + #[type_visitable(ignore)] call_expr: ExprId, found: StoredTy, }, TypedHole { + #[type_visitable(ignore)] expr: ExprId, expected: StoredTy, }, CastToUnsized { + #[type_visitable(ignore)] expr: ExprId, cast_ty: StoredTy, }, InvalidCast { + #[type_visitable(ignore)] expr: ExprId, + #[type_visitable(ignore)] error: CastError, expr_ty: StoredTy, cast_ty: StoredTy, }, TyDiagnostic { + #[type_visitable(ignore)] source: InferenceTyDiagnosticSource, + #[type_visitable(ignore)] diag: TyLoweringDiagnostic, }, PathDiagnostic { + #[type_visitable(ignore)] node: ExprOrPatId, + #[type_visitable(ignore)] diag: PathLoweringDiagnostic, }, MethodCallIncorrectGenericsLen { + #[type_visitable(ignore)] expr: ExprId, + #[type_visitable(ignore)] provided_count: u32, + #[type_visitable(ignore)] expected_count: u32, + #[type_visitable(ignore)] kind: IncorrectGenericsLenKind, + #[type_visitable(ignore)] def: GenericDefId, }, MethodCallIncorrectGenericsOrder { + #[type_visitable(ignore)] expr: ExprId, + #[type_visitable(ignore)] param_id: GenericParamId, + #[type_visitable(ignore)] arg_idx: u32, /// Whether the `GenericArgs` contains a `Self` arg. + #[type_visitable(ignore)] has_self_arg: bool, }, InvalidLhsOfAssignment { + #[type_visitable(ignore)] lhs: ExprId, }, TypeMustBeKnown { - at_point: ExprOrPatId, + #[type_visitable(ignore)] + at_point: Span, + top_term: Option<StoredGenericArg>, }, } /// A mismatch between an expected and an inferred type. -#[derive(Clone, PartialEq, Eq, Debug, Hash)] +#[derive(Clone, PartialEq, Eq, Debug, Hash, TypeVisitable, TypeFoldable)] pub struct TypeMismatch { pub expected: StoredTy, pub actual: StoredTy, @@ -1181,7 +1228,7 @@ pub(crate) struct InferenceContext<'body, 'db> { deferred_call_resolutions: FxHashMap<ExprId, Vec<DeferredCallResolution<'db>>>, diagnostics: Diagnostics, - vars_emitted_type_must_be_known_for: FxHashSet<Ty<'db>>, + vars_emitted_type_must_be_known_for: FxHashSet<Term<'db>>, } #[derive(Clone, Debug)] @@ -1331,14 +1378,15 @@ impl<'body, 'db> InferenceContext<'body, 'db> { // there is no problem in it being `pub(crate)`, remove this comment. fn resolve_all(self) -> InferenceResult { let InferenceContext { - mut table, + table, mut result, tuple_field_accesses_rev, diagnostics, types, + vars_emitted_type_must_be_known_for, .. } = self; - let mut diagnostics = diagnostics.finish(); + let diagnostics = diagnostics.finish(); // Destructure every single field so whenever new fields are added to `InferenceResult` we // don't forget to handle them here. let InferenceResult { @@ -1359,30 +1407,28 @@ impl<'body, 'db> InferenceContext<'body, 'db> { pat_adjustments, binding_modes: _, expr_adjustments, - tuple_field_access_types: _, + tuple_field_access_types, coercion_casts: _, - diagnostics: _, + diagnostics: result_diagnostics, } = &mut result; + let mut resolver = + WriteBackCtxt::new(table, diagnostics, vars_emitted_type_must_be_known_for); skipped_ref_pats.shrink_to_fit(); for ty in type_of_expr.values_mut() { - *ty = table.resolve_completely(ty.as_ref()).store(); - *has_errors = *has_errors || ty.as_ref().references_non_lt_error(); + resolver.resolve_completely(ty); } type_of_expr.shrink_to_fit(); for ty in type_of_pat.values_mut() { - *ty = table.resolve_completely(ty.as_ref()).store(); - *has_errors = *has_errors || ty.as_ref().references_non_lt_error(); + resolver.resolve_completely(ty); } type_of_pat.shrink_to_fit(); for ty in type_of_binding.values_mut() { - *ty = table.resolve_completely(ty.as_ref()).store(); - *has_errors = *has_errors || ty.as_ref().references_non_lt_error(); + resolver.resolve_completely(ty); } type_of_binding.shrink_to_fit(); for ty in type_of_type_placeholder.values_mut() { - *ty = table.resolve_completely(ty.as_ref()).store(); - *has_errors = *has_errors || ty.as_ref().references_non_lt_error(); + resolver.resolve_completely(ty); } type_of_type_placeholder.shrink_to_fit(); type_of_opaque.shrink_to_fit(); @@ -1390,61 +1436,25 @@ impl<'body, 'db> InferenceContext<'body, 'db> { if let Some(type_mismatches) = type_mismatches { *has_errors = true; for mismatch in type_mismatches.values_mut() { - mismatch.expected = table.resolve_completely(mismatch.expected.as_ref()).store(); - mismatch.actual = table.resolve_completely(mismatch.actual.as_ref()).store(); + resolver.resolve_type_mismatch(mismatch); } type_mismatches.shrink_to_fit(); } - diagnostics.retain_mut(|diagnostic| { - use InferenceDiagnostic::*; - match diagnostic { - ExpectedFunction { found: ty, .. } - | UnresolvedField { receiver: ty, .. } - | UnresolvedMethodCall { receiver: ty, .. } => { - *ty = table.resolve_completely(ty.as_ref()).store(); - // FIXME: Remove this when we are on par with rustc in terms of inference - if ty.as_ref().references_non_lt_error() { - return false; - } - - if let UnresolvedMethodCall { field_with_same_name, .. } = diagnostic - && let Some(ty) = field_with_same_name - { - *ty = table.resolve_completely(ty.as_ref()).store(); - if ty.as_ref().references_non_lt_error() { - *field_with_same_name = None; - } - } - } - TypedHole { expected: ty, .. } => { - *ty = table.resolve_completely(ty.as_ref()).store(); - } - _ => (), - } - true - }); - diagnostics.shrink_to_fit(); for (_, subst) in method_resolutions.values_mut() { - *subst = table.resolve_completely(subst.as_ref()).store(); - *has_errors = - *has_errors || subst.as_ref().types().any(|ty| ty.references_non_lt_error()); + resolver.resolve_completely(subst); } method_resolutions.shrink_to_fit(); for (_, subst) in assoc_resolutions.values_mut() { - *subst = table.resolve_completely(subst.as_ref()).store(); - *has_errors = - *has_errors || subst.as_ref().types().any(|ty| ty.references_non_lt_error()); + resolver.resolve_completely(subst); } assoc_resolutions.shrink_to_fit(); for adjustment in expr_adjustments.values_mut().flatten() { - adjustment.target = table.resolve_completely(adjustment.target.as_ref()).store(); - *has_errors = *has_errors || adjustment.target.as_ref().references_non_lt_error(); + resolver.resolve_completely(&mut adjustment.target); } expr_adjustments.shrink_to_fit(); for adjustments in pat_adjustments.values_mut() { for adjustment in &mut *adjustments { - adjustment.source = table.resolve_completely(adjustment.source.as_ref()).store(); - *has_errors = *has_errors || adjustment.source.as_ref().references_non_lt_error(); + resolver.resolve_completely(&mut adjustment.source); } adjustments.shrink_to_fit(); } @@ -1458,7 +1468,7 @@ impl<'body, 'db> InferenceContext<'body, 'db> { }; for (place, _, sources) in fake_reads { - *place = table.resolve_completely(std::mem::replace(place, dummy_place())); + resolver.resolve_completely_with_default(place, dummy_place()); place.projections.shrink_to_fit(); for source in &mut *sources { source.shrink_to_fit(); @@ -1469,7 +1479,7 @@ impl<'body, 'db> InferenceContext<'body, 'db> { for min_capture in min_captures.values_mut() { for captured in &mut *min_capture { let CapturedPlace { place, info, mutability: _ } = captured; - *place = table.resolve_completely(std::mem::replace(place, dummy_place())); + resolver.resolve_completely_with_default(place, dummy_place()); let CaptureInfo { sources, capture_kind: _ } = info; for source in &mut *sources { source.shrink_to_fit(); @@ -1481,17 +1491,18 @@ impl<'body, 'db> InferenceContext<'body, 'db> { min_captures.shrink_to_fit(); } closures_data.shrink_to_fit(); - result.tuple_field_access_types = tuple_field_accesses_rev + *tuple_field_access_types = tuple_field_accesses_rev .into_iter() - .map(|subst| table.resolve_completely(subst).store()) - .inspect(|subst| { - *has_errors = - *has_errors || subst.as_ref().iter().any(|ty| ty.references_non_lt_error()); + .map(|mut subst| { + resolver.resolve_completely(&mut subst); + subst.store() }) .collect(); - result.tuple_field_access_types.shrink_to_fit(); + tuple_field_access_types.shrink_to_fit(); - result.diagnostics = diagnostics; + let (diagnostics, resolver_has_errors) = resolver.resolve_diagnostics(); + *result_diagnostics = diagnostics; + *has_errors |= resolver_has_errors; result } @@ -1502,6 +1513,7 @@ impl<'body, 'db> InferenceContext<'body, 'db> { &data.store, InferenceTyDiagnosticSource::Signature, LifetimeElisionKind::for_const(self.interner(), id.loc(self.db).container), + Span::Dummy, ); self.return_ty = return_ty; @@ -1513,6 +1525,7 @@ impl<'body, 'db> InferenceContext<'body, 'db> { &data.store, InferenceTyDiagnosticSource::Signature, LifetimeElisionKind::Elided(self.types.regions.statik), + Span::Dummy, ); self.return_ty = return_ty; @@ -1545,16 +1558,16 @@ impl<'body, 'db> InferenceContext<'body, 'db> { param_tys.push(va_list_ty); } - let mut param_tys = - param_tys.into_iter().chain(iter::repeat(self.table.next_ty_var(Span::Dummy))); + let mut param_tys = param_tys.into_iter(); if let Some(self_param) = self_param && let Some(ty) = param_tys.next() { - let ty = self.process_user_written_ty(ty); + let ty = self.process_user_written_ty(Span::Dummy, ty); self.write_binding_ty(self_param, ty); } - for (ty, pat) in param_tys.zip(params) { - let ty = self.process_user_written_ty(ty); + for pat in params { + let ty = param_tys.next().unwrap_or_else(|| self.table.next_ty_var(Span::Dummy)); + let ty = self.process_user_written_ty(Span::Dummy, ty); self.infer_top_pat(*pat, ty, PatOrigin::Param); } @@ -1569,7 +1582,7 @@ impl<'body, 'db> InferenceContext<'body, 'db> { ctx.lower_ty(return_ty) }, ); - self.process_user_written_ty(return_ty) + self.process_user_written_ty(Span::Dummy, return_ty) } None => self.types.types.unit, }; @@ -1606,7 +1619,7 @@ impl<'body, 'db> InferenceContext<'body, 'db> { let var = self.table.next_ty_var(Span::Dummy); // Suppress future errors on this var. Add more things here when we add more diagnostics. - self.vars_emitted_type_must_be_known_for.insert(var); + self.vars_emitted_type_must_be_known_for.insert(var.into()); var } else { @@ -1751,10 +1764,11 @@ impl<'body, 'db> InferenceContext<'body, 'db> { store: &ExpressionStore, type_source: InferenceTyDiagnosticSource, lifetime_elision: LifetimeElisionKind<'db>, + span: Span, ) -> Ty<'db> { let ty = self .with_ty_lowering(store, type_source, lifetime_elision, |ctx| ctx.lower_ty(type_ref)); - let ty = self.process_user_written_ty(ty); + let ty = self.process_user_written_ty(span, ty); // Record the association from placeholders' TypeRefId to type variables. // We only record them if their number matches. This assumes TypeRef::walk and TypeVisitable process the items in the same order. @@ -1781,6 +1795,7 @@ impl<'body, 'db> InferenceContext<'body, 'db> { self.store, InferenceTyDiagnosticSource::Body, LifetimeElisionKind::Infer, + type_ref.into(), ) } @@ -1791,17 +1806,22 @@ impl<'body, 'db> InferenceContext<'body, 'db> { LifetimeElisionKind::Infer, |ctx| ctx.lower_const(const_ref, ty), ); - self.insert_type_vars(const_, Span::Dummy) + self.insert_type_vars(const_, const_ref.expr.into()) } - pub(crate) fn make_path_as_body_const(&mut self, path: &Path, ty: Ty<'db>) -> Const<'db> { + pub(crate) fn make_path_as_body_const( + &mut self, + type_ref: TypeRefId, + path: &Path, + ty: Ty<'db>, + ) -> Const<'db> { let const_ = self.with_ty_lowering( self.store, InferenceTyDiagnosticSource::Body, LifetimeElisionKind::Infer, |ctx| ctx.lower_path_as_const(path, ty), ); - self.insert_type_vars(const_, Span::Dummy) + self.insert_type_vars(const_, type_ref.into()) } fn err_ty(&self) -> Ty<'db> { @@ -1887,8 +1907,8 @@ impl<'body, 'db> InferenceContext<'body, 'db> { } /// Whenever you lower a user-written type, you should call this. - fn process_user_written_ty(&mut self, ty: Ty<'db>) -> Ty<'db> { - self.table.process_user_written_ty(ty) + fn process_user_written_ty(&mut self, span: Span, ty: Ty<'db>) -> Ty<'db> { + self.table.process_user_written_ty(span, ty) } /// The difference of this method from `process_user_written_ty()` is that this method doesn't register a well-formed obligation, @@ -1979,8 +1999,11 @@ impl<'body, 'db> InferenceContext<'body, 'db> { node: ExprOrPatId, ty: Ty<'db>, ) -> Ty<'db> { - if self.vars_emitted_type_must_be_known_for.insert(ty) { - self.push_diagnostic(InferenceDiagnostic::TypeMustBeKnown { at_point: node }); + if self.vars_emitted_type_must_be_known_for.insert(ty.into()) { + self.push_diagnostic(InferenceDiagnostic::TypeMustBeKnown { + at_point: node.into(), + top_term: None, + }); } self.types.types.error } @@ -2029,7 +2052,7 @@ impl<'body, 'db> InferenceContext<'body, 'db> { return (self.err_ty(), None); } let (mut ty, type_ns) = ctx.lower_ty_ext(type_anchor); - ty = self.table.process_user_written_ty(ty); + ty = self.table.process_user_written_ty(type_anchor.into(), ty); if let Some(TypeNs::SelfType(impl_)) = type_ns && let Some(trait_ref) = self.db.impl_trait(impl_) @@ -2197,7 +2220,7 @@ impl<'body, 'db> InferenceContext<'body, 'db> { (ty, _) = path_ctx.lower_partly_resolved_path(resolution, true); tried_resolving_once = true; - ty = self.table.process_user_written_ty(ty); + ty = self.table.process_user_written_ty(node.into(), ty); if ty.is_ty_error() { return (self.err_ty(), None); } @@ -2228,7 +2251,7 @@ impl<'body, 'db> InferenceContext<'body, 'db> { } let (mut ty, _) = path_ctx.lower_partly_resolved_path(resolution, true); - ty = self.table.process_user_written_ty(ty); + ty = self.table.process_user_written_ty(node.into(), ty); if let Some(segment) = remaining_segments.get(1) && let Some((AdtId::EnumId(id), _)) = ty.as_adt() |