Unnamed repository; edit this file 'description' to name the repository.
Diffstat (limited to 'crates/hir-ty/src/infer.rs')
-rw-r--r--crates/hir-ty/src/infer.rs199
1 files changed, 111 insertions, 88 deletions
diff --git a/crates/hir-ty/src/infer.rs b/crates/hir-ty/src/infer.rs
index 30b420b6d5..4aeb5ec71c 100644
--- a/crates/hir-ty/src/infer.rs
+++ b/crates/hir-ty/src/infer.rs
@@ -29,7 +29,7 @@ mod path;
mod place_op;
pub(crate) mod unify;
-use std::{cell::OnceCell, convert::identity, fmt, iter, ops::Deref};
+use std::{cell::OnceCell, convert::identity, fmt, ops::Deref};
use base_db::{Crate, FxIndexMap};
use either::Either;
@@ -50,6 +50,7 @@ use hir_def::{
use hir_expand::{mod_path::ModPath, name::Name};
use indexmap::IndexSet;
use la_arena::ArenaMap;
+use macros::{TypeFoldable, TypeVisitable};
use rustc_ast_ir::Mutability;
use rustc_hash::{FxHashMap, FxHashSet};
use rustc_type_ir::{
@@ -76,14 +77,15 @@ use crate::{
diagnostics::{Diagnostics, InferenceTyLoweringContext as TyLoweringContext},
expr::ExprIsRead,
pat::PatOrigin,
+ unify::resolve_completely::WriteBackCtxt,
},
lower::{
ImplTraitIdx, ImplTraitLoweringMode, LifetimeElisionKind, diagnostics::TyLoweringDiagnostic,
},
method_resolution::CandidateId,
next_solver::{
- AliasTy, Const, DbInterner, ErrorGuaranteed, GenericArgs, Region, StoredGenericArgs,
- StoredTy, StoredTys, Ty, TyKind, Tys,
+ AliasTy, Const, DbInterner, ErrorGuaranteed, GenericArgs, Region, StoredGenericArg,
+ StoredGenericArgs, StoredTy, StoredTys, Term, Ty, TyKind, Tys,
abi::Safety,
infer::{InferCtxt, ObligationInspector, traits::ObligationCause},
},
@@ -188,7 +190,7 @@ fn infer_signature_query(db: &dyn HirDatabase, def: GenericDefId) -> InferenceRe
// Array lengths are always `usize`.
RootExprOrigin::ArrayLength => Expectation::has_type(ctx.types.types.usize),
// Const parameter default: look up the param's declared type.
- RootExprOrigin::ConstParam(local_id) => Expectation::has_type(db.const_param_ty_ns(
+ RootExprOrigin::ConstParam(local_id) => Expectation::has_type(db.const_param_ty(
ConstParamId::from_unchecked(TypeOrConstParamId { parent: def, local_id }),
)),
// Path const generic args: determining the expected type requires
@@ -307,107 +309,152 @@ pub enum InferenceTyDiagnosticSource {
Signature,
}
-#[derive(Debug, PartialEq, Eq, Clone)]
+#[derive(Debug, PartialEq, Eq, Clone, TypeVisitable, TypeFoldable)]
pub enum InferenceDiagnostic {
NoSuchField {
+ #[type_visitable(ignore)]
field: ExprOrPatId,
+ #[type_visitable(ignore)]
private: Option<LocalFieldId>,
+ #[type_visitable(ignore)]
variant: VariantId,
},
PrivateField {
+ #[type_visitable(ignore)]
expr: ExprId,
+ #[type_visitable(ignore)]
field: FieldId,
},
PrivateAssocItem {
+ #[type_visitable(ignore)]
id: ExprOrPatId,
+ #[type_visitable(ignore)]
item: AssocItemId,
},
UnresolvedField {
+ #[type_visitable(ignore)]
expr: ExprId,
receiver: StoredTy,
+ #[type_visitable(ignore)]
name: Name,
+ #[type_visitable(ignore)]
method_with_same_name_exists: bool,
},
UnresolvedMethodCall {
+ #[type_visitable(ignore)]
expr: ExprId,
receiver: StoredTy,
+ #[type_visitable(ignore)]
name: Name,
/// Contains the type the field resolves to
field_with_same_name: Option<StoredTy>,
+ #[type_visitable(ignore)]
assoc_func_with_same_name: Option<FunctionId>,
},
UnresolvedAssocItem {
+ #[type_visitable(ignore)]
id: ExprOrPatId,
},
UnresolvedIdent {
+ #[type_visitable(ignore)]
id: ExprOrPatId,
},
// FIXME: This should be emitted in body lowering
BreakOutsideOfLoop {
+ #[type_visitable(ignore)]
expr: ExprId,
+ #[type_visitable(ignore)]
is_break: bool,
+ #[type_visitable(ignore)]
bad_value_break: bool,
},
MismatchedArgCount {
+ #[type_visitable(ignore)]
call_expr: ExprId,
+ #[type_visitable(ignore)]
expected: usize,
+ #[type_visitable(ignore)]
found: usize,
},
MismatchedTupleStructPatArgCount {
+ #[type_visitable(ignore)]
pat: PatId,
+ #[type_visitable(ignore)]
expected: usize,
+ #[type_visitable(ignore)]
found: usize,
},
ExpectedFunction {
+ #[type_visitable(ignore)]
call_expr: ExprId,
found: StoredTy,
},
TypedHole {
+ #[type_visitable(ignore)]
expr: ExprId,
expected: StoredTy,
},
CastToUnsized {
+ #[type_visitable(ignore)]
expr: ExprId,
cast_ty: StoredTy,
},
InvalidCast {
+ #[type_visitable(ignore)]
expr: ExprId,
+ #[type_visitable(ignore)]
error: CastError,
expr_ty: StoredTy,
cast_ty: StoredTy,
},
TyDiagnostic {
+ #[type_visitable(ignore)]
source: InferenceTyDiagnosticSource,
+ #[type_visitable(ignore)]
diag: TyLoweringDiagnostic,
},
PathDiagnostic {
+ #[type_visitable(ignore)]
node: ExprOrPatId,
+ #[type_visitable(ignore)]
diag: PathLoweringDiagnostic,
},
MethodCallIncorrectGenericsLen {
+ #[type_visitable(ignore)]
expr: ExprId,
+ #[type_visitable(ignore)]
provided_count: u32,
+ #[type_visitable(ignore)]
expected_count: u32,
+ #[type_visitable(ignore)]
kind: IncorrectGenericsLenKind,
+ #[type_visitable(ignore)]
def: GenericDefId,
},
MethodCallIncorrectGenericsOrder {
+ #[type_visitable(ignore)]
expr: ExprId,
+ #[type_visitable(ignore)]
param_id: GenericParamId,
+ #[type_visitable(ignore)]
arg_idx: u32,
/// Whether the `GenericArgs` contains a `Self` arg.
+ #[type_visitable(ignore)]
has_self_arg: bool,
},
InvalidLhsOfAssignment {
+ #[type_visitable(ignore)]
lhs: ExprId,
},
TypeMustBeKnown {
- at_point: ExprOrPatId,
+ #[type_visitable(ignore)]
+ at_point: Span,
+ top_term: Option<StoredGenericArg>,
},
}
/// A mismatch between an expected and an inferred type.
-#[derive(Clone, PartialEq, Eq, Debug, Hash)]
+#[derive(Clone, PartialEq, Eq, Debug, Hash, TypeVisitable, TypeFoldable)]
pub struct TypeMismatch {
pub expected: StoredTy,
pub actual: StoredTy,
@@ -1181,7 +1228,7 @@ pub(crate) struct InferenceContext<'body, 'db> {
deferred_call_resolutions: FxHashMap<ExprId, Vec<DeferredCallResolution<'db>>>,
diagnostics: Diagnostics,
- vars_emitted_type_must_be_known_for: FxHashSet<Ty<'db>>,
+ vars_emitted_type_must_be_known_for: FxHashSet<Term<'db>>,
}
#[derive(Clone, Debug)]
@@ -1331,14 +1378,15 @@ impl<'body, 'db> InferenceContext<'body, 'db> {
// there is no problem in it being `pub(crate)`, remove this comment.
fn resolve_all(self) -> InferenceResult {
let InferenceContext {
- mut table,
+ table,
mut result,
tuple_field_accesses_rev,
diagnostics,
types,
+ vars_emitted_type_must_be_known_for,
..
} = self;
- let mut diagnostics = diagnostics.finish();
+ let diagnostics = diagnostics.finish();
// Destructure every single field so whenever new fields are added to `InferenceResult` we
// don't forget to handle them here.
let InferenceResult {
@@ -1359,30 +1407,28 @@ impl<'body, 'db> InferenceContext<'body, 'db> {
pat_adjustments,
binding_modes: _,
expr_adjustments,
- tuple_field_access_types: _,
+ tuple_field_access_types,
coercion_casts: _,
- diagnostics: _,
+ diagnostics: result_diagnostics,
} = &mut result;
+ let mut resolver =
+ WriteBackCtxt::new(table, diagnostics, vars_emitted_type_must_be_known_for);
skipped_ref_pats.shrink_to_fit();
for ty in type_of_expr.values_mut() {
- *ty = table.resolve_completely(ty.as_ref()).store();
- *has_errors = *has_errors || ty.as_ref().references_non_lt_error();
+ resolver.resolve_completely(ty);
}
type_of_expr.shrink_to_fit();
for ty in type_of_pat.values_mut() {
- *ty = table.resolve_completely(ty.as_ref()).store();
- *has_errors = *has_errors || ty.as_ref().references_non_lt_error();
+ resolver.resolve_completely(ty);
}
type_of_pat.shrink_to_fit();
for ty in type_of_binding.values_mut() {
- *ty = table.resolve_completely(ty.as_ref()).store();
- *has_errors = *has_errors || ty.as_ref().references_non_lt_error();
+ resolver.resolve_completely(ty);
}
type_of_binding.shrink_to_fit();
for ty in type_of_type_placeholder.values_mut() {
- *ty = table.resolve_completely(ty.as_ref()).store();
- *has_errors = *has_errors || ty.as_ref().references_non_lt_error();
+ resolver.resolve_completely(ty);
}
type_of_type_placeholder.shrink_to_fit();
type_of_opaque.shrink_to_fit();
@@ -1390,61 +1436,25 @@ impl<'body, 'db> InferenceContext<'body, 'db> {
if let Some(type_mismatches) = type_mismatches {
*has_errors = true;
for mismatch in type_mismatches.values_mut() {
- mismatch.expected = table.resolve_completely(mismatch.expected.as_ref()).store();
- mismatch.actual = table.resolve_completely(mismatch.actual.as_ref()).store();
+ resolver.resolve_type_mismatch(mismatch);
}
type_mismatches.shrink_to_fit();
}
- diagnostics.retain_mut(|diagnostic| {
- use InferenceDiagnostic::*;
- match diagnostic {
- ExpectedFunction { found: ty, .. }
- | UnresolvedField { receiver: ty, .. }
- | UnresolvedMethodCall { receiver: ty, .. } => {
- *ty = table.resolve_completely(ty.as_ref()).store();
- // FIXME: Remove this when we are on par with rustc in terms of inference
- if ty.as_ref().references_non_lt_error() {
- return false;
- }
-
- if let UnresolvedMethodCall { field_with_same_name, .. } = diagnostic
- && let Some(ty) = field_with_same_name
- {
- *ty = table.resolve_completely(ty.as_ref()).store();
- if ty.as_ref().references_non_lt_error() {
- *field_with_same_name = None;
- }
- }
- }
- TypedHole { expected: ty, .. } => {
- *ty = table.resolve_completely(ty.as_ref()).store();
- }
- _ => (),
- }
- true
- });
- diagnostics.shrink_to_fit();
for (_, subst) in method_resolutions.values_mut() {
- *subst = table.resolve_completely(subst.as_ref()).store();
- *has_errors =
- *has_errors || subst.as_ref().types().any(|ty| ty.references_non_lt_error());
+ resolver.resolve_completely(subst);
}
method_resolutions.shrink_to_fit();
for (_, subst) in assoc_resolutions.values_mut() {
- *subst = table.resolve_completely(subst.as_ref()).store();
- *has_errors =
- *has_errors || subst.as_ref().types().any(|ty| ty.references_non_lt_error());
+ resolver.resolve_completely(subst);
}
assoc_resolutions.shrink_to_fit();
for adjustment in expr_adjustments.values_mut().flatten() {
- adjustment.target = table.resolve_completely(adjustment.target.as_ref()).store();
- *has_errors = *has_errors || adjustment.target.as_ref().references_non_lt_error();
+ resolver.resolve_completely(&mut adjustment.target);
}
expr_adjustments.shrink_to_fit();
for adjustments in pat_adjustments.values_mut() {
for adjustment in &mut *adjustments {
- adjustment.source = table.resolve_completely(adjustment.source.as_ref()).store();
- *has_errors = *has_errors || adjustment.source.as_ref().references_non_lt_error();
+ resolver.resolve_completely(&mut adjustment.source);
}
adjustments.shrink_to_fit();
}
@@ -1458,7 +1468,7 @@ impl<'body, 'db> InferenceContext<'body, 'db> {
};
for (place, _, sources) in fake_reads {
- *place = table.resolve_completely(std::mem::replace(place, dummy_place()));
+ resolver.resolve_completely_with_default(place, dummy_place());
place.projections.shrink_to_fit();
for source in &mut *sources {
source.shrink_to_fit();
@@ -1469,7 +1479,7 @@ impl<'body, 'db> InferenceContext<'body, 'db> {
for min_capture in min_captures.values_mut() {
for captured in &mut *min_capture {
let CapturedPlace { place, info, mutability: _ } = captured;
- *place = table.resolve_completely(std::mem::replace(place, dummy_place()));
+ resolver.resolve_completely_with_default(place, dummy_place());
let CaptureInfo { sources, capture_kind: _ } = info;
for source in &mut *sources {
source.shrink_to_fit();
@@ -1481,17 +1491,18 @@ impl<'body, 'db> InferenceContext<'body, 'db> {
min_captures.shrink_to_fit();
}
closures_data.shrink_to_fit();
- result.tuple_field_access_types = tuple_field_accesses_rev
+ *tuple_field_access_types = tuple_field_accesses_rev
.into_iter()
- .map(|subst| table.resolve_completely(subst).store())
- .inspect(|subst| {
- *has_errors =
- *has_errors || subst.as_ref().iter().any(|ty| ty.references_non_lt_error());
+ .map(|mut subst| {
+ resolver.resolve_completely(&mut subst);
+ subst.store()
})
.collect();
- result.tuple_field_access_types.shrink_to_fit();
+ tuple_field_access_types.shrink_to_fit();
- result.diagnostics = diagnostics;
+ let (diagnostics, resolver_has_errors) = resolver.resolve_diagnostics();
+ *result_diagnostics = diagnostics;
+ *has_errors |= resolver_has_errors;
result
}
@@ -1502,6 +1513,7 @@ impl<'body, 'db> InferenceContext<'body, 'db> {
&data.store,
InferenceTyDiagnosticSource::Signature,
LifetimeElisionKind::for_const(self.interner(), id.loc(self.db).container),
+ Span::Dummy,
);
self.return_ty = return_ty;
@@ -1513,6 +1525,7 @@ impl<'body, 'db> InferenceContext<'body, 'db> {
&data.store,
InferenceTyDiagnosticSource::Signature,
LifetimeElisionKind::Elided(self.types.regions.statik),
+ Span::Dummy,
);
self.return_ty = return_ty;
@@ -1545,16 +1558,16 @@ impl<'body, 'db> InferenceContext<'body, 'db> {
param_tys.push(va_list_ty);
}
- let mut param_tys =
- param_tys.into_iter().chain(iter::repeat(self.table.next_ty_var(Span::Dummy)));
+ let mut param_tys = param_tys.into_iter();
if let Some(self_param) = self_param
&& let Some(ty) = param_tys.next()
{
- let ty = self.process_user_written_ty(ty);
+ let ty = self.process_user_written_ty(Span::Dummy, ty);
self.write_binding_ty(self_param, ty);
}
- for (ty, pat) in param_tys.zip(params) {
- let ty = self.process_user_written_ty(ty);
+ for pat in params {
+ let ty = param_tys.next().unwrap_or_else(|| self.table.next_ty_var(Span::Dummy));
+ let ty = self.process_user_written_ty(Span::Dummy, ty);
self.infer_top_pat(*pat, ty, PatOrigin::Param);
}
@@ -1569,7 +1582,7 @@ impl<'body, 'db> InferenceContext<'body, 'db> {
ctx.lower_ty(return_ty)
},
);
- self.process_user_written_ty(return_ty)
+ self.process_user_written_ty(Span::Dummy, return_ty)
}
None => self.types.types.unit,
};
@@ -1606,7 +1619,7 @@ impl<'body, 'db> InferenceContext<'body, 'db> {
let var = self.table.next_ty_var(Span::Dummy);
// Suppress future errors on this var. Add more things here when we add more diagnostics.
- self.vars_emitted_type_must_be_known_for.insert(var);
+ self.vars_emitted_type_must_be_known_for.insert(var.into());
var
} else {
@@ -1751,10 +1764,11 @@ impl<'body, 'db> InferenceContext<'body, 'db> {
store: &ExpressionStore,
type_source: InferenceTyDiagnosticSource,
lifetime_elision: LifetimeElisionKind<'db>,
+ span: Span,
) -> Ty<'db> {
let ty = self
.with_ty_lowering(store, type_source, lifetime_elision, |ctx| ctx.lower_ty(type_ref));
- let ty = self.process_user_written_ty(ty);
+ let ty = self.process_user_written_ty(span, ty);
// Record the association from placeholders' TypeRefId to type variables.
// We only record them if their number matches. This assumes TypeRef::walk and TypeVisitable process the items in the same order.
@@ -1781,6 +1795,7 @@ impl<'body, 'db> InferenceContext<'body, 'db> {
self.store,
InferenceTyDiagnosticSource::Body,
LifetimeElisionKind::Infer,
+ type_ref.into(),
)
}
@@ -1791,17 +1806,22 @@ impl<'body, 'db> InferenceContext<'body, 'db> {
LifetimeElisionKind::Infer,
|ctx| ctx.lower_const(const_ref, ty),
);
- self.insert_type_vars(const_, Span::Dummy)
+ self.insert_type_vars(const_, const_ref.expr.into())
}
- pub(crate) fn make_path_as_body_const(&mut self, path: &Path, ty: Ty<'db>) -> Const<'db> {
+ pub(crate) fn make_path_as_body_const(
+ &mut self,
+ type_ref: TypeRefId,
+ path: &Path,
+ ty: Ty<'db>,
+ ) -> Const<'db> {
let const_ = self.with_ty_lowering(
self.store,
InferenceTyDiagnosticSource::Body,
LifetimeElisionKind::Infer,
|ctx| ctx.lower_path_as_const(path, ty),
);
- self.insert_type_vars(const_, Span::Dummy)
+ self.insert_type_vars(const_, type_ref.into())
}
fn err_ty(&self) -> Ty<'db> {
@@ -1887,8 +1907,8 @@ impl<'body, 'db> InferenceContext<'body, 'db> {
}
/// Whenever you lower a user-written type, you should call this.
- fn process_user_written_ty(&mut self, ty: Ty<'db>) -> Ty<'db> {
- self.table.process_user_written_ty(ty)
+ fn process_user_written_ty(&mut self, span: Span, ty: Ty<'db>) -> Ty<'db> {
+ self.table.process_user_written_ty(span, ty)
}
/// The difference of this method from `process_user_written_ty()` is that this method doesn't register a well-formed obligation,
@@ -1979,8 +1999,11 @@ impl<'body, 'db> InferenceContext<'body, 'db> {
node: ExprOrPatId,
ty: Ty<'db>,
) -> Ty<'db> {
- if self.vars_emitted_type_must_be_known_for.insert(ty) {
- self.push_diagnostic(InferenceDiagnostic::TypeMustBeKnown { at_point: node });
+ if self.vars_emitted_type_must_be_known_for.insert(ty.into()) {
+ self.push_diagnostic(InferenceDiagnostic::TypeMustBeKnown {
+ at_point: node.into(),
+ top_term: None,
+ });
}
self.types.types.error
}
@@ -2029,7 +2052,7 @@ impl<'body, 'db> InferenceContext<'body, 'db> {
return (self.err_ty(), None);
}
let (mut ty, type_ns) = ctx.lower_ty_ext(type_anchor);
- ty = self.table.process_user_written_ty(ty);
+ ty = self.table.process_user_written_ty(type_anchor.into(), ty);
if let Some(TypeNs::SelfType(impl_)) = type_ns
&& let Some(trait_ref) = self.db.impl_trait(impl_)
@@ -2197,7 +2220,7 @@ impl<'body, 'db> InferenceContext<'body, 'db> {
(ty, _) = path_ctx.lower_partly_resolved_path(resolution, true);
tried_resolving_once = true;
- ty = self.table.process_user_written_ty(ty);
+ ty = self.table.process_user_written_ty(node.into(), ty);
if ty.is_ty_error() {
return (self.err_ty(), None);
}
@@ -2228,7 +2251,7 @@ impl<'body, 'db> InferenceContext<'body, 'db> {
}
let (mut ty, _) = path_ctx.lower_partly_resolved_path(resolution, true);
- ty = self.table.process_user_written_ty(ty);
+ ty = self.table.process_user_written_ty(node.into(), ty);
if let Some(segment) = remaining_segments.get(1)
&& let Some((AdtId::EnumId(id), _)) = ty.as_adt()