Unnamed repository; edit this file 'description' to name the repository.
Diffstat (limited to 'crates/hir-ty/src/infer/cast.rs')
| -rw-r--r-- | crates/hir-ty/src/infer/cast.rs | 454 |
1 files changed, 429 insertions, 25 deletions
diff --git a/crates/hir-ty/src/infer/cast.rs b/crates/hir-ty/src/infer/cast.rs index 060b5f36f2..caa3960a22 100644 --- a/crates/hir-ty/src/infer/cast.rs +++ b/crates/hir-ty/src/infer/cast.rs @@ -1,47 +1,451 @@ //! Type cast logic. Basically coercion + additional casts. -use crate::{infer::unify::InferenceTable, Interner, Ty, TyExt, TyKind}; +use chalk_ir::{Mutability, Scalar, TyVariableKind, UintTy}; +use hir_def::{hir::ExprId, AdtId}; +use stdx::never; + +use crate::{ + infer::unify::InferenceTable, Adjustment, Binders, DynTy, InferenceDiagnostic, Interner, + PlaceholderIndex, QuantifiedWhereClauses, Ty, TyExt, TyKind, TypeFlags, WhereClause, +}; + +#[derive(Debug)] +pub(crate) enum Int { + I, + U(UintTy), + Bool, + Char, + CEnum, + InferenceVar, +} + +#[derive(Debug)] +pub(crate) enum CastTy { + Int(Int), + Float, + FnPtr, + Ptr(Ty, Mutability), + // `DynStar` is Not supported yet in r-a +} + +impl CastTy { + pub(crate) fn from_ty(table: &mut InferenceTable<'_>, t: &Ty) -> Option<Self> { + match t.kind(Interner) { + TyKind::Scalar(Scalar::Bool) => Some(Self::Int(Int::Bool)), + TyKind::Scalar(Scalar::Char) => Some(Self::Int(Int::Char)), + TyKind::Scalar(Scalar::Int(_)) => Some(Self::Int(Int::I)), + TyKind::Scalar(Scalar::Uint(it)) => Some(Self::Int(Int::U(*it))), + TyKind::InferenceVar(_, TyVariableKind::Integer) => Some(Self::Int(Int::InferenceVar)), + TyKind::InferenceVar(_, TyVariableKind::Float) => Some(Self::Float), + TyKind::Scalar(Scalar::Float(_)) => Some(Self::Float), + TyKind::Adt(..) => { + let (AdtId::EnumId(id), _) = t.as_adt()? else { + return None; + }; + let enum_data = table.db.enum_data(id); + if enum_data.is_payload_free(table.db.upcast()) { + Some(Self::Int(Int::CEnum)) + } else { + None + } + } + TyKind::Raw(m, ty) => Some(Self::Ptr(table.resolve_ty_shallow(ty), *m)), + TyKind::Function(_) => Some(Self::FnPtr), + _ => None, + } + } +} + +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +pub enum CastError { + Unknown, + CastToBool, + CastToChar, + DifferingKinds, + SizedUnsizedCast, + IllegalCast, + IntToFatCast, + NeedDeref, + NeedViaPtr, + NeedViaThinPtr, + NeedViaInt, + NonScalar, + UnknownCastPtrKind, + UnknownExprPtrKind, +} + +impl CastError { + fn into_diagnostic(self, expr: ExprId, expr_ty: Ty, cast_ty: Ty) -> InferenceDiagnostic { + InferenceDiagnostic::InvalidCast { expr, error: self, expr_ty, cast_ty } + } +} #[derive(Clone, Debug)] pub(super) struct CastCheck { + expr: ExprId, + source_expr: ExprId, expr_ty: Ty, cast_ty: Ty, } impl CastCheck { - pub(super) fn new(expr_ty: Ty, cast_ty: Ty) -> Self { - Self { expr_ty, cast_ty } + pub(super) fn new(expr: ExprId, source_expr: ExprId, expr_ty: Ty, cast_ty: Ty) -> Self { + Self { expr, source_expr, expr_ty, cast_ty } } - pub(super) fn check(self, table: &mut InferenceTable<'_>) { - // FIXME: This function currently only implements the bits that influence the type - // inference. We should return the adjustments on success and report diagnostics on error. - let expr_ty = table.resolve_ty_shallow(&self.expr_ty); - let cast_ty = table.resolve_ty_shallow(&self.cast_ty); + pub(super) fn check<F, G>( + &mut self, + table: &mut InferenceTable<'_>, + apply_adjustments: &mut F, + set_coercion_cast: &mut G, + ) -> Result<(), InferenceDiagnostic> + where + F: FnMut(ExprId, Vec<Adjustment>), + G: FnMut(ExprId), + { + table.resolve_obligations_as_possible(); + self.expr_ty = table.resolve_ty_shallow(&self.expr_ty); + self.cast_ty = table.resolve_ty_shallow(&self.cast_ty); + + if self.expr_ty.contains_unknown() || self.cast_ty.contains_unknown() { + return Ok(()); + } + + if !self.cast_ty.data(Interner).flags.contains(TypeFlags::HAS_TY_INFER) + && !table.is_sized(&self.cast_ty) + { + return Err(InferenceDiagnostic::CastToUnsized { + expr: self.expr, + cast_ty: self.cast_ty.clone(), + }); + } - if table.coerce(&expr_ty, &cast_ty).is_ok() { - return; + // Chalk doesn't support trait upcasting and fails to solve some obvious goals + // when the trait environment contains some recursive traits (See issue #18047) + // We skip cast checks for such cases for now, until the next-gen solver. + if contains_dyn_trait(&self.cast_ty) { + return Ok(()); } - if check_ref_to_ptr_cast(expr_ty, cast_ty, table) { - // Note that this type of cast is actually split into a coercion to a - // pointer type and a cast: - // &[T; N] -> *[T; N] -> *T + if let Ok((adj, _)) = table.coerce(&self.expr_ty, &self.cast_ty) { + apply_adjustments(self.source_expr, adj); + set_coercion_cast(self.source_expr); + return Ok(()); } - // FIXME: Check other kinds of non-coercion casts and report error if any? + self.do_check(table, apply_adjustments) + .map_err(|e| e.into_diagnostic(self.expr, self.expr_ty.clone(), self.cast_ty.clone())) + } + + fn do_check<F>( + &self, + table: &mut InferenceTable<'_>, + apply_adjustments: &mut F, + ) -> Result<(), CastError> + where + F: FnMut(ExprId, Vec<Adjustment>), + { + let (t_from, t_cast) = + match (CastTy::from_ty(table, &self.expr_ty), CastTy::from_ty(table, &self.cast_ty)) { + (Some(t_from), Some(t_cast)) => (t_from, t_cast), + (None, Some(t_cast)) => match self.expr_ty.kind(Interner) { + TyKind::FnDef(..) => { + let sig = self.expr_ty.callable_sig(table.db).expect("FnDef had no sig"); + let sig = table.normalize_associated_types_in(sig); + let fn_ptr = TyKind::Function(sig.to_fn_ptr()).intern(Interner); + if let Ok((adj, _)) = table.coerce(&self.expr_ty, &fn_ptr) { + apply_adjustments(self.source_expr, adj); + } else { + return Err(CastError::IllegalCast); + } + + (CastTy::FnPtr, t_cast) + } + TyKind::Ref(mutbl, _, inner_ty) => { + let inner_ty = table.resolve_ty_shallow(inner_ty); + return match t_cast { + CastTy::Int(_) | CastTy::Float => match inner_ty.kind(Interner) { + TyKind::Scalar( + Scalar::Int(_) | Scalar::Uint(_) | Scalar::Float(_), + ) + | TyKind::InferenceVar( + _, + TyVariableKind::Integer | TyVariableKind::Float, + ) => Err(CastError::NeedDeref), + + _ => Err(CastError::NeedViaPtr), + }, + // array-ptr-cast + CastTy::Ptr(t, m) => { + let t = table.resolve_ty_shallow(&t); + if !table.is_sized(&t) { + return Err(CastError::IllegalCast); + } + self.check_ref_cast( + table, + &inner_ty, + *mutbl, + &t, + m, + apply_adjustments, + ) + } + _ => Err(CastError::NonScalar), + }; + } + _ => return Err(CastError::NonScalar), + }, + _ => return Err(CastError::NonScalar), + }; + + // rustc checks whether the `expr_ty` is foreign adt with `non_exhaustive` sym + + match (t_from, t_cast) { + (_, CastTy::Int(Int::CEnum) | CastTy::FnPtr) => Err(CastError::NonScalar), + (_, CastTy::Int(Int::Bool)) => Err(CastError::CastToBool), + (CastTy::Int(Int::U(UintTy::U8)), CastTy::Int(Int::Char)) => Ok(()), + (_, CastTy::Int(Int::Char)) => Err(CastError::CastToChar), + (CastTy::Int(Int::Bool | Int::CEnum | Int::Char), CastTy::Float) => { + Err(CastError::NeedViaInt) + } + (CastTy::Int(Int::Bool | Int::CEnum | Int::Char) | CastTy::Float, CastTy::Ptr(..)) + | (CastTy::Ptr(..) | CastTy::FnPtr, CastTy::Float) => Err(CastError::IllegalCast), + (CastTy::Ptr(src, _), CastTy::Ptr(dst, _)) => { + self.check_ptr_ptr_cast(table, &src, &dst) + } + (CastTy::Ptr(src, _), CastTy::Int(_)) => self.check_ptr_addr_cast(table, &src), + (CastTy::Int(_), CastTy::Ptr(dst, _)) => self.check_addr_ptr_cast(table, &dst), + (CastTy::FnPtr, CastTy::Ptr(dst, _)) => self.check_fptr_ptr_cast(table, &dst), + (CastTy::Int(Int::CEnum), CastTy::Int(_)) => Ok(()), + (CastTy::Int(Int::Char | Int::Bool), CastTy::Int(_)) => Ok(()), + (CastTy::Int(_) | CastTy::Float, CastTy::Int(_) | CastTy::Float) => Ok(()), + (CastTy::FnPtr, CastTy::Int(_)) => Ok(()), + } + } + + fn check_ref_cast<F>( + &self, + table: &mut InferenceTable<'_>, + t_expr: &Ty, + m_expr: Mutability, + t_cast: &Ty, + m_cast: Mutability, + apply_adjustments: &mut F, + ) -> Result<(), CastError> + where + F: FnMut(ExprId, Vec<Adjustment>), + { + // Mutability order is opposite to rustc. `Mut < Not` + if m_expr <= m_cast { + if let TyKind::Array(ety, _) = t_expr.kind(Interner) { + // Coerce to a raw pointer so that we generate RawPtr in MIR. + let array_ptr_type = TyKind::Raw(m_expr, t_expr.clone()).intern(Interner); + if let Ok((adj, _)) = table.coerce(&self.expr_ty, &array_ptr_type) { + apply_adjustments(self.source_expr, adj); + } else { + never!( + "could not cast from reference to array to pointer to array ({:?} to {:?})", + self.expr_ty, + array_ptr_type + ); + } + + // This is a less strict condition than rustc's `demand_eqtype`, + // but false negative is better than false positive + if table.coerce(ety, t_cast).is_ok() { + return Ok(()); + } + } + } + + Err(CastError::IllegalCast) + } + + fn check_ptr_ptr_cast( + &self, + table: &mut InferenceTable<'_>, + src: &Ty, + dst: &Ty, + ) -> Result<(), CastError> { + let src_kind = pointer_kind(src, table).map_err(|_| CastError::Unknown)?; + let dst_kind = pointer_kind(dst, table).map_err(|_| CastError::Unknown)?; + + match (src_kind, dst_kind) { + (Some(PointerKind::Error), _) | (_, Some(PointerKind::Error)) => Ok(()), + (_, None) => Err(CastError::UnknownCastPtrKind), + (_, Some(PointerKind::Thin)) => Ok(()), + (None, _) => Err(CastError::UnknownExprPtrKind), + (Some(PointerKind::Thin), _) => Err(CastError::SizedUnsizedCast), + (Some(PointerKind::VTable(src_tty)), Some(PointerKind::VTable(dst_tty))) => { + let principal = |tty: &Binders<QuantifiedWhereClauses>| { + tty.skip_binders().as_slice(Interner).first().and_then(|pred| { + if let WhereClause::Implemented(tr) = pred.skip_binders() { + Some(tr.trait_id) + } else { + None + } + }) + }; + match (principal(&src_tty), principal(&dst_tty)) { + (Some(src_principal), Some(dst_principal)) => { + if src_principal == dst_principal { + return Ok(()); + } + let src_principal = + table.db.trait_datum(table.trait_env.krate, src_principal); + let dst_principal = + table.db.trait_datum(table.trait_env.krate, dst_principal); + if src_principal.is_auto_trait() && dst_principal.is_auto_trait() { + Ok(()) + } else { + Err(CastError::DifferingKinds) + } + } + _ => Err(CastError::Unknown), + } + } + (Some(src_kind), Some(dst_kind)) if src_kind == dst_kind => Ok(()), + (_, _) => Err(CastError::DifferingKinds), + } + } + + fn check_ptr_addr_cast( + &self, + table: &mut InferenceTable<'_>, + expr_ty: &Ty, + ) -> Result<(), CastError> { + match pointer_kind(expr_ty, table).map_err(|_| CastError::Unknown)? { + None => Err(CastError::UnknownExprPtrKind), + Some(PointerKind::Error) => Ok(()), + Some(PointerKind::Thin) => Ok(()), + _ => Err(CastError::NeedViaThinPtr), + } + } + + fn check_addr_ptr_cast( + &self, + table: &mut InferenceTable<'_>, + cast_ty: &Ty, + ) -> Result<(), CastError> { + match pointer_kind(cast_ty, table).map_err(|_| CastError::Unknown)? { + None => Err(CastError::UnknownCastPtrKind), + Some(PointerKind::Error) => Ok(()), + Some(PointerKind::Thin) => Ok(()), + Some(PointerKind::VTable(_)) => Err(CastError::IntToFatCast), + Some(PointerKind::Length) => Err(CastError::IntToFatCast), + Some(PointerKind::OfAlias | PointerKind::OfParam(_)) => Err(CastError::IntToFatCast), + } + } + + fn check_fptr_ptr_cast( + &self, + table: &mut InferenceTable<'_>, + cast_ty: &Ty, + ) -> Result<(), CastError> { + match pointer_kind(cast_ty, table).map_err(|_| CastError::Unknown)? { + None => Err(CastError::UnknownCastPtrKind), + Some(PointerKind::Error) => Ok(()), + Some(PointerKind::Thin) => Ok(()), + _ => Err(CastError::IllegalCast), + } } } -fn check_ref_to_ptr_cast(expr_ty: Ty, cast_ty: Ty, table: &mut InferenceTable<'_>) -> bool { - let Some((expr_inner_ty, _, _)) = expr_ty.as_reference() else { - return false; - }; - let Some((cast_inner_ty, _)) = cast_ty.as_raw_ptr() else { - return false; - }; - let TyKind::Array(expr_elt_ty, _) = expr_inner_ty.kind(Interner) else { - return false; +#[derive(PartialEq, Eq)] +enum PointerKind { + // thin pointer + Thin, + // trait object + VTable(Binders<QuantifiedWhereClauses>), + // slice + Length, + OfAlias, + OfParam(PlaceholderIndex), + Error, +} + +fn pointer_kind(ty: &Ty, table: &mut InferenceTable<'_>) -> Result<Option<PointerKind>, ()> { + let ty = table.resolve_ty_shallow(ty); + + if table.is_sized(&ty) { + return Ok(Some(PointerKind::Thin)); + } + + match ty.kind(Interner) { + TyKind::Slice(_) | TyKind::Str => Ok(Some(PointerKind::Length)), + TyKind::Dyn(DynTy { bounds, .. }) => Ok(Some(PointerKind::VTable(bounds.clone()))), + TyKind::Adt(chalk_ir::AdtId(id), subst) => { + let AdtId::StructId(id) = *id else { + never!("`{:?}` should be sized but is not?", ty); + return Err(()); + }; + + let struct_data = table.db.struct_data(id); + if let Some((last_field, _)) = struct_data.variant_data.fields().iter().last() { + let last_field_ty = + table.db.field_types(id.into())[last_field].clone().substitute(Interner, subst); + pointer_kind(&last_field_ty, table) + } else { + Ok(Some(PointerKind::Thin)) + } + } + TyKind::Tuple(_, subst) => { + match subst.iter(Interner).last().and_then(|arg| arg.ty(Interner)) { + None => Ok(Some(PointerKind::Thin)), + Some(ty) => pointer_kind(ty, table), + } + } + TyKind::Foreign(_) => Ok(Some(PointerKind::Thin)), + TyKind::Alias(_) | TyKind::AssociatedType(..) | TyKind::OpaqueType(..) => { + Ok(Some(PointerKind::OfAlias)) + } + TyKind::Error => Ok(Some(PointerKind::Error)), + TyKind::Placeholder(idx) => Ok(Some(PointerKind::OfParam(*idx))), + TyKind::BoundVar(_) | TyKind::InferenceVar(..) => Ok(None), + TyKind::Scalar(_) + | TyKind::Array(..) + | TyKind::CoroutineWitness(..) + | TyKind::Raw(..) + | TyKind::Ref(..) + | TyKind::FnDef(..) + | TyKind::Function(_) + | TyKind::Closure(..) + | TyKind::Coroutine(..) + | TyKind::Never => { + never!("`{:?}` should be sized but is not?", ty); + Err(()) + } + } +} + +fn contains_dyn_trait(ty: &Ty) -> bool { + use std::ops::ControlFlow; + + use chalk_ir::{ + visit::{TypeSuperVisitable, TypeVisitable, TypeVisitor}, + DebruijnIndex, }; - table.coerce(expr_elt_ty, cast_inner_ty).is_ok() + + struct DynTraitVisitor; + + impl TypeVisitor<Interner> for DynTraitVisitor { + type BreakTy = (); + + fn as_dyn(&mut self) -> &mut dyn TypeVisitor<Interner, BreakTy = Self::BreakTy> { + self + } + + fn interner(&self) -> Interner { + Interner + } + + fn visit_ty(&mut self, ty: &Ty, outer_binder: DebruijnIndex) -> ControlFlow<Self::BreakTy> { + match ty.kind(Interner) { + TyKind::Dyn(_) => ControlFlow::Break(()), + _ => ty.super_visit_with(self.as_dyn(), outer_binder), + } + } + } + + ty.visit_with(DynTraitVisitor.as_dyn(), DebruijnIndex::INNERMOST).is_break() } |