Unnamed repository; edit this file 'description' to name the repository.
Diffstat (limited to 'crates/hir-ty/src/infer/cast.rs')
-rw-r--r--crates/hir-ty/src/infer/cast.rs31
1 files changed, 17 insertions, 14 deletions
diff --git a/crates/hir-ty/src/infer/cast.rs b/crates/hir-ty/src/infer/cast.rs
index daf954c217..1dade4dfd7 100644
--- a/crates/hir-ty/src/infer/cast.rs
+++ b/crates/hir-ty/src/infer/cast.rs
@@ -125,8 +125,9 @@ impl<'db> CastCheck<'db> {
&mut self,
ctx: &mut InferenceContext<'_, 'db>,
) -> Result<(), InferenceDiagnostic> {
- self.expr_ty = ctx.table.try_structurally_resolve_type(self.expr_ty);
- self.cast_ty = ctx.table.try_structurally_resolve_type(self.cast_ty);
+ self.expr_ty =
+ ctx.table.try_structurally_resolve_type(self.source_expr.into(), self.expr_ty);
+ self.cast_ty = ctx.table.try_structurally_resolve_type(self.expr.into(), self.cast_ty);
// This should always come first so that we apply the coercion, which impacts infer vars.
if ctx
@@ -199,7 +200,8 @@ impl<'db> CastCheck<'db> {
},
// array-ptr-cast
CastTy::Ptr(t, m) => {
- let t = ctx.table.try_structurally_resolve_type(t);
+ let t =
+ ctx.table.try_structurally_resolve_type(self.expr.into(), t);
if !ctx.table.type_is_sized_modulo_regions(t) {
return Err(CastError::IllegalCast);
}
@@ -262,8 +264,8 @@ impl<'db> CastCheck<'db> {
t_cast: Ty<'db>,
m_cast: Mutability,
) -> Result<(), CastError> {
- let t_expr = ctx.table.try_structurally_resolve_type(t_expr);
- let t_cast = ctx.table.try_structurally_resolve_type(t_cast);
+ let t_expr = ctx.table.try_structurally_resolve_type(self.expr.into(), t_expr);
+ let t_cast = ctx.table.try_structurally_resolve_type(self.expr.into(), t_cast);
if m_expr >= m_cast
&& let TyKind::Array(ety, _) = t_expr.kind()
@@ -306,8 +308,8 @@ impl<'db> CastCheck<'db> {
src: Ty<'db>,
dst: Ty<'db>,
) -> Result<(), CastError> {
- let src_kind = pointer_kind(src, ctx).map_err(|_| CastError::Unknown)?;
- let dst_kind = pointer_kind(dst, ctx).map_err(|_| CastError::Unknown)?;
+ let src_kind = pointer_kind(self.expr, src, ctx).map_err(|_| CastError::Unknown)?;
+ let dst_kind = pointer_kind(self.expr, dst, ctx).map_err(|_| CastError::Unknown)?;
match (src_kind, dst_kind) {
(Some(PointerKind::Error), _) | (_, Some(PointerKind::Error)) => Ok(()),
@@ -372,7 +374,7 @@ impl<'db> CastCheck<'db> {
// This is `fcx.demand_eqtype`, but inlined to give a better error.
if ctx
.table
- .at(&ObligationCause::dummy())
+ .at(&ObligationCause::new(self.expr))
.eq(src_obj, dst_obj)
.map(|infer_ok| ctx.table.register_infer_ok(infer_ok))
.is_err()
@@ -457,7 +459,7 @@ impl<'db> CastCheck<'db> {
ctx: &mut InferenceContext<'_, 'db>,
expr_ty: Ty<'db>,
) -> Result<(), CastError> {
- match pointer_kind(expr_ty, ctx).map_err(|_| CastError::Unknown)? {
+ match pointer_kind(self.expr, expr_ty, ctx).map_err(|_| CastError::Unknown)? {
// None => Err(CastError::UnknownExprPtrKind),
None => Ok(()),
Some(PointerKind::Error) => Ok(()),
@@ -471,7 +473,7 @@ impl<'db> CastCheck<'db> {
ctx: &mut InferenceContext<'_, 'db>,
cast_ty: Ty<'db>,
) -> Result<(), CastError> {
- match pointer_kind(cast_ty, ctx).map_err(|_| CastError::Unknown)? {
+ match pointer_kind(self.expr, cast_ty, ctx).map_err(|_| CastError::Unknown)? {
// None => Err(CastError::UnknownCastPtrKind),
None => Ok(()),
Some(PointerKind::Error) => Ok(()),
@@ -487,7 +489,7 @@ impl<'db> CastCheck<'db> {
ctx: &mut InferenceContext<'_, 'db>,
cast_ty: Ty<'db>,
) -> Result<(), CastError> {
- match pointer_kind(cast_ty, ctx).map_err(|_| CastError::Unknown)? {
+ match pointer_kind(self.expr, cast_ty, ctx).map_err(|_| CastError::Unknown)? {
// None => Err(CastError::UnknownCastPtrKind),
None => Ok(()),
Some(PointerKind::Error) => Ok(()),
@@ -516,10 +518,11 @@ enum PointerKind<'db> {
}
fn pointer_kind<'db>(
+ expr: ExprId,
ty: Ty<'db>,
ctx: &mut InferenceContext<'_, 'db>,
) -> Result<Option<PointerKind<'db>>, ()> {
- let ty = ctx.table.try_structurally_resolve_type(ty);
+ let ty = ctx.table.try_structurally_resolve_type(expr.into(), ty);
if ctx.table.type_is_sized_modulo_regions(ty) {
return Ok(Some(PointerKind::Thin));
@@ -540,14 +543,14 @@ fn pointer_kind<'db>(
let last_field_ty = ctx.db.field_types(id.into())[last_field]
.get()
.instantiate(ctx.interner(), subst);
- pointer_kind(last_field_ty, ctx)
+ pointer_kind(expr, last_field_ty, ctx)
} else {
Ok(Some(PointerKind::Thin))
}
}
TyKind::Tuple(subst) => match subst.iter().next_back() {
None => Ok(Some(PointerKind::Thin)),
- Some(ty) => pointer_kind(ty, ctx),
+ Some(ty) => pointer_kind(expr, ty, ctx),
},
TyKind::Foreign(_) => Ok(Some(PointerKind::Thin)),
TyKind::Alias(..) => Ok(Some(PointerKind::OfAlias)),