//! Interface with `rustc_pattern_analysis`.
use std::{cell::LazyCell, fmt};
use hir_def::{
EnumId, EnumVariantId, HasModule, LocalFieldId, ModuleId, VariantId, attrs::AttrFlags,
};
use intern::sym;
use rustc_pattern_analysis::{
IndexVec, PatCx, PrivateUninhabitedField,
constructor::{Constructor, ConstructorSet, VariantVisibility},
usefulness::{PlaceValidity, UsefulnessReport, compute_match_usefulness},
};
use rustc_type_ir::inherent::{AdtDef, IntoKind};
use smallvec::{SmallVec, smallvec};
use stdx::never;
use crate::{
db::HirDatabase,
inhabitedness::{is_enum_variant_uninhabited_from, is_ty_uninhabited_from},
next_solver::{
ParamEnv, Ty, TyKind,
infer::{InferCtxt, traits::ObligationCause},
},
};
use super::{FieldPat, Pat, PatKind};
use Constructor::*;
// Aliases of the generic `rustc_pattern_analysis` types, all instantiated with the
// rust-analyzer context `MatchCheckCtx`.

/// A pattern produced by the usefulness algorithm as a witness (e.g. of non-exhaustiveness).
pub(crate) type WitnessPat<'a, 'db> =
    rustc_pattern_analysis::pat::WitnessPat<MatchCheckCtx<'a, 'db>>;
/// One match arm as handed to the usefulness algorithm; `'p` is the lifetime of its pattern.
pub(crate) type MatchArm<'a, 'p, 'db> =
    rustc_pattern_analysis::MatchArm<'p, MatchCheckCtx<'a, 'db>>;
/// A pattern split into a constructor plus its fields, the form the algorithm works on.
pub(crate) type DeconstructedPat<'a, 'db> =
    rustc_pattern_analysis::pat::DeconstructedPat<MatchCheckCtx<'a, 'db>>;
/// An uninhabited stand-in for the payload of [Constructor] variants that rust-analyzer does
/// not support (it is used as the string-literal type). Since no value of `Void` can exist,
/// arms handling such variants can be written as `match void {}`, which keeps match
/// expressions ported from the upstream algorithm intact without changing their meaning.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub(crate) enum Void {}
/// Zero-based, dense position of a variant within its enum's variant list, ranging over
/// `0..variants.len()`. Unlike `EnumVariantId`, whose raw values can be arbitrarily large,
/// this is suitable as a key for `IndexVec`/`BitSet`.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub(crate) struct EnumVariantContiguousIndex(usize);

impl EnumVariantContiguousIndex {
    /// Returns the position of `target_evid` among its enum's variants, as recorded in the
    /// variant's lookup data.
    fn from_enum_variant_id(db: &dyn HirDatabase, target_evid: EnumVariantId) -> Self {
        use hir_def::Lookup;
        let position = target_evid.lookup(db).index as usize;
        Self(position)
    }

    /// Maps this position back to the `EnumVariantId` of the corresponding variant of `eid`.
    ///
    /// Panics if the position is out of range for `eid`'s variants.
    fn to_enum_variant_id(self, db: &dyn HirDatabase, eid: EnumId) -> EnumVariantId {
        let Self(position) = self;
        let (variant_id, ..) = eid.enum_variants(db).variants[position];
        variant_id
    }
}
/// Lets `IndexVec`s in `rustc_pattern_analysis` be keyed by contiguous variant positions.
impl rustc_pattern_analysis::Idx for EnumVariantContiguousIndex {
    fn new(idx: usize) -> Self {
        Self(idx)
    }

    fn index(self) -> usize {
        let Self(raw) = self;
        raw
    }
}
/// The rust-analyzer implementation of `rustc_pattern_analysis`'s `PatCx`.
///
/// It carries everything the analysis needs to inspect types: the database, the inference
/// context and parameter environment (used to instantiate and normalize field types), the
/// module the match lives in (visibility and inhabitedness are judged from there), and whether
/// the `exhaustive_patterns` feature is enabled.
#[derive(Clone)]
pub(crate) struct MatchCheckCtx<'a, 'db> {
    /// The HIR database.
    pub(crate) db: &'db dyn HirDatabase,
    /// Inference context used for normalization and inhabitedness queries.
    infcx: &'a InferCtxt<'db>,
    /// Parameter environment used together with `infcx`.
    env: ParamEnv<'db>,
    /// Module containing the match; visibility and inhabitedness are computed from here.
    module: ModuleId,
    /// Whether the crate enables the unstable `exhaustive_patterns` feature.
    exhaustive_patterns: bool,
}
impl<'a, 'db> MatchCheckCtx<'a, 'db> {
    /// Creates a context for checking matches that appear in `module`.
    ///
    /// The `exhaustive_patterns` flag is read from the def map of `module`'s crate, i.e. it
    /// reflects whether that crate enables the unstable `exhaustive_patterns` feature.
    pub(crate) fn new(module: ModuleId, infcx: &'a InferCtxt<'db>, env: ParamEnv<'db>) -> Self {
        let db = infcx.interner.db;
        let def_map = module.crate_def_map(db);
        let exhaustive_patterns = def_map.is_unstable_feature_enabled(&sym::exhaustive_patterns);
        Self { module, db, exhaustive_patterns, env, infcx }
    }

    /// Runs `rustc_pattern_analysis`'s usefulness algorithm over `arms` for a scrutinee of
    /// type `scrut_ty`.
    ///
    /// `known_valid_scrutinee` selects the [`PlaceValidity`] of the scrutinee; `None` is
    /// treated like `Some(true)`.
    ///
    /// # Errors
    /// Returns `Err(())` without running the analysis if `scrut_ty`, or the type of any arm's
    /// pattern, references a non-lifetime error type. Errors reported by the algorithm itself
    /// are propagated unchanged; note that this context's `complexity_exceeded` hook returns
    /// `Err(())`.
    pub(crate) fn compute_match_usefulness<'b>(
        &self,
        arms: &[MatchArm<'a, 'b, 'db>],
        scrut_ty: Ty<'db>,
        known_valid_scrutinee: Option<bool>,
    ) -> Result<UsefulnessReport<'b, Self>, ()> {
        if scrut_ty.references_non_lt_error() {
            return Err(());
        }
        for arm in arms {
            if arm.pat.ty().references_non_lt_error() {
                return Err(());
            }
        }

        let place_validity = PlaceValidity::from_bool(known_valid_scrutinee.unwrap_or(true));
        // Measured to take ~100ms on modern hardware.
        let complexity_limit = 500000;
        compute_match_usefulness(self, arms, scrut_ty, place_validity, complexity_limit)
    }

    /// Returns whether `ty` is uninhabited when viewed from `self.module`.
    fn is_uninhabited(&self, ty: Ty<'db>) -> bool {
        is_ty_uninhabited_from(self.infcx, ty, self.module, self.env)
    }

    /// Returns whether the given ADT is from another crate and declared `#[non_exhaustive]`.
    fn is_foreign_non_exhaustive(&self, adt: hir_def::AdtId) -> bool {
        let is_local = adt.krate(self.db) == self.module.krate(self.db);
        !is_local && AttrFlags::query(self.db, adt.into()).contains(AttrFlags::NON_EXHAUSTIVE)
    }

    /// Returns the `VariantId` whose fields a struct-like constructor on `adt` refers to.
    ///
    /// `Variant(_)` maps to the corresponding enum variant; `Struct` and `UnionField` map to
    /// the struct/union itself, or to `None` when `adt` is an enum.
    ///
    /// # Panics
    /// If `ctor` is `Variant(_)` but `adt` is not an enum, or if `ctor` is any constructor
    /// other than `Variant`, `Struct` or `UnionField`.
    fn variant_id_for_adt(
        db: &'db dyn HirDatabase,
        ctor: &Constructor<Self>,
        adt: hir_def::AdtId,
    ) -> Option<VariantId> {
        match ctor {
            Variant(id) => {
                let hir_def::AdtId::EnumId(eid) = adt else {
                    panic!("bad constructor {ctor:?} for adt {adt:?}")
                };
                Some(id.to_enum_variant_id(db, eid).into())
            }
            Struct | UnionField => match adt {
                hir_def::AdtId::EnumId(_) => None,
                hir_def::AdtId::StructId(id) => Some(id.into()),
                hir_def::AdtId::UnionId(id) => Some(id.into()),
            },
            _ => panic!("bad constructor {ctor:?} for adt {adt:?}"),
        }
    }

    /// Lists the fields of `variant` together with their types.
    ///
    /// Each field type is instantiated with the generic arguments of `ty` and then deeply
    /// normalized in `self.env`; if normalization fails, the instantiated type is kept.
    ///
    /// # Panics
    /// If `ty` is not an ADT type (`as_adt` returns `None`).
    fn list_variant_fields(
        &self,
        ty: Ty<'db>,
        variant: VariantId,
    ) -> impl Iterator<Item = (LocalFieldId, Ty<'db>)> {
        let (_, substs) = ty.as_adt().unwrap();

        let field_tys = self.db.field_types(variant);
        let fields_len = variant.fields(self.db).fields().len() as u32;

        (0..fields_len).map(|idx| LocalFieldId::from_raw(idx.into())).map(move |fid| {
            let ty = field_tys[fid].get().instantiate(self.infcx.interner, substs);
            let ty = self
                .infcx
                .at(&ObligationCause::dummy(), self.env)
                .deeply_normalize(ty)
                .unwrap_or(ty);
            (fid, ty)
        })
    }

    /// Lowers a pattern into the [`DeconstructedPat`] form consumed by `rustc_pattern_analysis`.
    ///
    /// Bindings are transparent: a binding with a subpattern lowers as that subpattern, and a
    /// plain binding lowers as a wildcard. A `Deref` pattern whose type is not a reference, or a
    /// struct/variant pattern whose type is neither a tuple nor an ADT, is reported with
    /// `never!` and lowered as a wildcard with no fields (arity 0).
    pub(crate) fn lower_pat(&self, pat: &Pat<'db>) -> DeconstructedPat<'a, 'db> {
        let singleton = |pat: DeconstructedPat<'a, 'db>| vec![pat.at_index(0)];
        let ctor;
        let mut fields: Vec<_>;
        let arity;

        match pat.kind.as_ref() {
            PatKind::Binding { subpattern: Some(subpat), .. } => return self.lower_pat(subpat),
            PatKind::Binding { subpattern: None, .. } | PatKind::Wild => {
                ctor = Wildcard;
                fields = Vec::new();
                arity = 0;
            }
            PatKind::Deref { subpattern } => match pat.ty.kind() {
                TyKind::Ref(..) => {
                    ctor = Ref;
                    fields = singleton(self.lower_pat(subpattern));
                    arity = 1;
                }
                _ => {
                    never!("pattern has unexpected type: pat: {:?}, ty: {:?}", pat, &pat.ty);
                    // A wildcard has no fields: keep `fields` and `arity` consistent with the
                    // constructor instead of attaching the lowered subpattern to it.
                    ctor = Wildcard;
                    fields = Vec::new();
                    arity = 0;
                }
            },
            PatKind::Leaf { subpatterns } | PatKind::Variant { subpatterns, .. } => {
                fields = subpatterns
                    .iter()
                    .map(|pat| {
                        let idx: u32 = pat.field.into_raw().into();
                        self.lower_pat(&pat.pattern).at_index(idx as usize)
                    })
                    .collect();
                match pat.ty.kind() {
                    TyKind::Tuple(substs) => {
                        ctor = Struct;
                        arity = substs.len();
                    }
                    TyKind::Adt(adt_def, _) => {
                        let adt = adt_def.def_id().0;
                        ctor = match pat.kind.as_ref() {
                            PatKind::Leaf { .. } if matches!(adt, hir_def::AdtId::UnionId(_)) => {
                                UnionField
                            }
                            PatKind::Leaf { .. } => Struct,
                            PatKind::Variant { enum_variant, .. } => {
                                Variant(EnumVariantContiguousIndex::from_enum_variant_id(
                                    self.db,
                                    *enum_variant,
                                ))
                            }
                            _ => {
                                never!();
                                Wildcard
                            }
                        };
                        let variant = Self::variant_id_for_adt(self.db, &ctor, adt).unwrap();
                        arity = variant.fields(self.db).fields().len();
                    }
                    _ => {
                        never!("pattern has unexpected type: pat: {:?}, ty: {:?}", pat, &pat.ty);
                        ctor = Wildcard;
                        fields.clear();
                        arity = 0;
                    }
                }
            }
            &PatKind::LiteralBool { value } => {
                ctor = Bool(value);
                fields = Vec::new();
                arity = 0;
            }
            PatKind::Never => {
                ctor = Never;
                fields = Vec::new();
                arity = 0;
            }
            PatKind::Or { pats } => {
                ctor = Or;
                fields = pats
                    .iter()
                    .enumerate()
                    .map(|(i, pat)| self.lower_pat(pat).at_index(i))
                    .collect();
                arity = pats.len();
            }
        }
        DeconstructedPat::new(ctor, fields, arity, pat.ty, ())
    }

    /// Converts a witness pattern produced by the usefulness algorithm back into a [`Pat`].
    ///
    /// Constructors this file never produces (`IntRange`, `Slice`, `DerefPattern`) hit
    /// `unimplemented!()`; `Str` carries the uninhabited `Void`, so its arm cannot run.
    /// Constructors that have no pattern form are reported with `never!` and become `Wild`.
    pub(crate) fn hoist_witness_pat(&self, pat: &WitnessPat<'a, 'db>) -> Pat<'db> {
        let mut subpatterns = pat.iter_fields().map(|p| self.hoist_witness_pat(p));
        let kind = match pat.ctor() {
            &Bool(value) => PatKind::LiteralBool { value },
            IntRange(_) => unimplemented!(),
            Struct | Variant(_) | UnionField => match pat.ty().kind() {
                TyKind::Tuple(..) => PatKind::Leaf {
                    subpatterns: subpatterns
                        .zip(0u32..)
                        .map(|(p, i)| FieldPat {
                            field: LocalFieldId::from_raw(i.into()),
                            pattern: p,
                        })
                        .collect(),
                },
                TyKind::Adt(adt, substs) => {
                    let variant =
                        Self::variant_id_for_adt(self.db, pat.ctor(), adt.def_id().0).unwrap();
                    let subpatterns = self
                        .list_variant_fields(*pat.ty(), variant)
                        .zip(subpatterns)
                        .map(|((field, _ty), pattern)| FieldPat { field, pattern })
                        .collect();

                    if let VariantId::EnumVariantId(enum_variant) = variant {
                        PatKind::Variant { substs, enum_variant, subpatterns }
                    } else {
                        PatKind::Leaf { subpatterns }
                    }
                }
                _ => {
                    never!("unexpected ctor for type {:?} {:?}", pat.ctor(), pat.ty());
                    PatKind::Wild
                }
            },
            // Note: given the expansion of `&str` patterns done in `expand_pattern`, we should
            // be careful to reconstruct the correct constant pattern here. However a string
            // literal pattern will never be reported as a non-exhaustiveness witness, so we
            // ignore this issue.
            Ref => PatKind::Deref { subpattern: subpatterns.next().unwrap() },
            Slice(_) => unimplemented!(),
            DerefPattern(_) => unimplemented!(),
            &Str(void) => match void {},
            Wildcard | NonExhaustive | Hidden | PrivateUninhabited => PatKind::Wild,
            Never => PatKind::Never,
            Missing | F16Range(..) | F32Range(..) | F64Range(..) | F128Range(..) | Opaque(..)
            | Or => {
                never!("can't convert to pattern: {:?}", pat.ctor());
                PatKind::Wild
            }
        };
        Pat { ty: *pat.ty(), kind: Box::new(kind) }
    }
}
impl<'a, 'db> PatCx for MatchCheckCtx<'a, 'db> {
    type Error = ();
    type Ty = Ty<'db>;
    type VariantIdx = EnumVariantContiguousIndex;
    type StrLit = Void;
    type ArmData = ();
    type PatData = ();

    /// Whether the crate enables the unstable `exhaustive_patterns` feature.
    fn is_exhaustive_patterns_feature_on(&self) -> bool {
        self.exhaustive_patterns
    }

    /// Number of fields of `ctor` at type `ty`. This must agree with the number of items
    /// yielded by `ctor_sub_tys` for the same inputs.
    fn ctor_arity(
        &self,
        ctor: &rustc_pattern_analysis::constructor::Constructor<Self>,
        ty: &Self::Ty,
    ) -> usize {
        match ctor {
            Struct | Variant(_) | UnionField => match ty.kind() {
                TyKind::Tuple(tys) => tys.len(),
                TyKind::Adt(adt_def, ..) => {
                    let variant =
                        Self::variant_id_for_adt(self.db, ctor, adt_def.def_id().0).unwrap();
                    variant.fields(self.db).fields().len()
                }
                _ => {
                    never!("Unexpected type for `Single` constructor: {:?}", ty);
                    0
                }
            },
            Ref => 1,
            Slice(..) => unimplemented!(),
            DerefPattern(..) => unimplemented!(),
            Never | Bool(..) | IntRange(..) | F16Range(..) | F32Range(..) | F64Range(..)
            | F128Range(..) | Str(..) | Opaque(..) | NonExhaustive | PrivateUninhabited
            | Hidden | Missing | Wildcard => 0,
            Or => {
                never!("The `Or` constructor doesn't have a fixed arity");
                0
            }
        }
    }

    /// Types of the fields of `ctor` at type `ty`, each paired with a flag saying whether the
    /// field is both uninhabited and not visible from `self.module`.
    ///
    /// Yields exactly `ctor_arity(ctor, ty)` items, including on the `never!` fallback paths.
    fn ctor_sub_tys(
        &self,
        ctor: &rustc_pattern_analysis::constructor::Constructor<Self>,
        ty: &Self::Ty,
    ) -> impl ExactSizeIterator<Item = (Self::Ty, PrivateUninhabitedField)> {
        let single = |ty| smallvec![(ty, PrivateUninhabitedField(false))];
        let tys: SmallVec<[_; 2]> = match ctor {
            Struct | Variant(_) | UnionField => match ty.kind() {
                TyKind::Tuple(substs) => {
                    substs.iter().map(|ty| (ty, PrivateUninhabitedField(false))).collect()
                }
                TyKind::Adt(adt_def, ..) => {
                    let adt = adt_def.def_id().0;
                    let variant = Self::variant_id_for_adt(self.db, ctor, adt).unwrap();

                    // Only needed for struct/union fields; enum variant fields are always
                    // treated as visible.
                    let visibilities = LazyCell::new(|| self.db.field_visibilities(variant));

                    self.list_variant_fields(*ty, variant)
                        .map(move |(fid, ty)| {
                            let is_visible = || {
                                matches!(adt, hir_def::AdtId::EnumId(..))
                                    || visibilities[fid].is_visible_from(self.db, self.module)
                            };
                            let is_uninhabited = self.is_uninhabited(ty);
                            let private_uninhabited = is_uninhabited && !is_visible();
                            (ty, PrivateUninhabitedField(private_uninhabited))
                        })
                        .collect()
                }
                ty_kind => {
                    never!("Unexpected type for `{:?}` constructor: {:?}", ctor, ty_kind);
                    // `ctor_arity` reports 0 fields in this case, so yield no sub-types
                    // either; the two must agree.
                    smallvec![]
                }
            },
            Ref => match ty.kind() {
                TyKind::Ref(_, rty, _) => single(rty),
                ty_kind => {
                    never!("Unexpected type for `{:?}` constructor: {:?}", ctor, ty_kind);
                    single(*ty)
                }
            },
            Slice(_) => unreachable!("Found a `Slice` constructor in match checking"),
            DerefPattern(_) => unreachable!("Found a `DerefPattern` constructor in match checking"),
            Never | Bool(..) | IntRange(..) | F16Range(..) | F32Range(..) | F64Range(..)
            | F128Range(..) | Str(..) | Opaque(..) | NonExhaustive | PrivateUninhabited
            | Hidden | Missing | Wildcard => {
                smallvec![]
            }
            Or => {
                never!("called `Fields::wildcards` on an `Or` ctor");
                smallvec![]
            }
        };
        tys.into_iter()
    }

    /// The set of all constructors of `ty`. Types this file does not model (`char`,
    /// integers, arrays, slices, and anything not matched explicitly) are `Unlistable`.
    fn ctors_for_ty(
        &self,
        ty: &Self::Ty,
    ) -> Result<rustc_pattern_analysis::constructor::ConstructorSet<Self>, Self::Error> {
        let cx = self;

        // Unhandled types are treated as non-exhaustive. Being explicit here instead of falling
        // to catchall arm to ease further implementation.
        let unhandled = || ConstructorSet::Unlistable;

        // This determines the set of all possible constructors for the type `ty`. For numbers,
        // arrays and slices we use ranges and variable-length slices when appropriate.
        //
        // If the `exhaustive_patterns` feature is enabled, we make sure to omit constructors that
        // are statically impossible. E.g., for `Option<!>`, we do not include `Some(_)` in the
        // returned list of constructors.
        // Invariant: this is empty if and only if the type is uninhabited (as determined by
        // `cx.is_uninhabited()`).
        Ok(match ty.kind() {
            TyKind::Bool => ConstructorSet::Bool,
            TyKind::Char => unhandled(),
            TyKind::Int(..) | TyKind::Uint(..) => unhandled(),
            TyKind::Array(..) | TyKind::Slice(..) => unhandled(),
            TyKind::Adt(adt_def, subst) => {
                let adt = adt_def.def_id().0;
                match adt {
                    hir_def::AdtId::EnumId(enum_id) => {
                        let enum_data = enum_id.enum_variants(cx.db);
                        let is_declared_nonexhaustive = cx.is_foreign_non_exhaustive(adt);

                        if enum_data.variants.is_empty() && !is_declared_nonexhaustive {
                            ConstructorSet::NoConstructors
                        } else {
                            let mut variants = IndexVec::with_capacity(enum_data.variants.len());
                            for &(variant, _, _) in enum_data.variants.iter() {
                                let is_uninhabited = is_enum_variant_uninhabited_from(
                                    cx.infcx, variant, subst, cx.module, self.env,
                                );
                                let visibility = if is_uninhabited {
                                    VariantVisibility::Empty
                                } else {
                                    VariantVisibility::Visible
                                };
                                variants.push(visibility);
                            }

                            ConstructorSet::Variants {
                                variants,
                                non_exhaustive: is_declared_nonexhaustive,
                            }
                        }
                    }
                    hir_def::AdtId::UnionId(_) => ConstructorSet::Union,
                    hir_def::AdtId::StructId(_) => {
                        ConstructorSet::Struct { empty: cx.is_uninhabited(*ty) }
                    }
                }
            }
            TyKind::Tuple(..) => ConstructorSet::Struct { empty: cx.is_uninhabited(*ty) },
            TyKind::Ref(..) => ConstructorSet::Ref,
            TyKind::Never => ConstructorSet::NoConstructors,
            // This type is one for which we cannot list constructors, like `str` or `f64`.
            _ => ConstructorSet::Unlistable,
        })
    }

    /// Writes a placeholder instead of the variant name.
    fn write_variant_name(
        f: &mut fmt::Formatter<'_>,
        _ctor: &Constructor<Self>,
        _ty: &Self::Ty,
    ) -> fmt::Result {
        // FIXME: print the real name. This hook receives no `self`, so the database needed to
        // resolve `ty`'s ADT (via `variant_id_for_adt`) and look up the name of the enum
        // variant, struct or union is not available here.
        write!(f, "<write_variant_name unsupported>")
    }

    /// Internal bugs detected by the analysis are reported through `never!`.
    fn bug(&self, fmt: fmt::Arguments<'_>) {
        never!("{}", fmt)
    }

    /// Aborts the analysis with `Err(())` once the complexity limit is hit.
    fn complexity_exceeded(&self) -> Result<(), Self::Error> {
        Err(())
    }

    fn report_mixed_deref_pat_ctors(
        &self,
        _deref_pat: &DeconstructedPat<'a, 'db>,
        _normal_pat: &DeconstructedPat<'a, 'db>,
    ) {
        // FIXME(deref_patterns): This could report an error comparable to the one in rustc.
    }
}
/// Prints only the type name, `MatchCheckCtx`; none of the context's fields are shown.
impl fmt::Debug for MatchCheckCtx<'_, '_> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // A field-less `debug_struct(..).finish()` emits exactly the name, so write it directly.
        f.write_str("MatchCheckCtx")
    }
}