Unnamed repository; edit this file 'description' to name the repository.
Diffstat (limited to 'crates/hir-ty/src/variance.rs')
-rw-r--r--crates/hir-ty/src/variance.rs1172
1 files changed, 1172 insertions, 0 deletions
diff --git a/crates/hir-ty/src/variance.rs b/crates/hir-ty/src/variance.rs
new file mode 100644
index 0000000000..2339e37fa1
--- /dev/null
+++ b/crates/hir-ty/src/variance.rs
@@ -0,0 +1,1172 @@
+//! Module for inferring the variance of type and lifetime parameters. See the [rustc dev guide]
+//! chapter for more info.
+//!
+//! [rustc dev guide]: https://rustc-dev-guide.rust-lang.org/variance.html
+
+use crate::db::HirDatabase;
+use crate::generics::{generics, Generics};
+use crate::{
+ AliasTy, Const, ConstScalar, DynTyExt, FnPointer, GenericArg, GenericArgData, Interner,
+ Lifetime, LifetimeData, Ty, TyKind,
+};
+use base_db::ra_salsa::Cycle;
+use chalk_ir::Mutability;
+use hir_def::data::adt::StructFlags;
+use hir_def::{AdtId, GenericDefId, GenericParamId, VariantId};
+use std::fmt;
+use std::ops::Not;
+use triomphe::Arc;
+
+pub(crate) fn variances_of(db: &dyn HirDatabase, def: GenericDefId) -> Option<Arc<[Variance]>> {
+ tracing::debug!("variances_of(def={:?})", def);
+ match def {
+ GenericDefId::FunctionId(_) => (),
+ GenericDefId::AdtId(adt) => {
+ if let AdtId::StructId(id) = adt {
+ let flags = &db.struct_data(id).flags;
+ if flags.contains(StructFlags::IS_UNSAFE_CELL) {
+ return Some(Arc::from_iter(vec![Variance::Invariant; 1]));
+ } else if flags.contains(StructFlags::IS_PHANTOM_DATA) {
+ return Some(Arc::from_iter(vec![Variance::Covariant; 1]));
+ }
+ }
+ }
+ _ => return None,
+ }
+
+ let generics = generics(db.upcast(), def);
+ let count = generics.len();
+ if count == 0 {
+ return None;
+ }
+ let mut ctxt = Context {
+ def,
+ has_trait_self: generics.parent_generics().map_or(false, |it| it.has_trait_self()),
+ len_self: generics.len_self(),
+ len_self_lifetimes: generics.len_self_lifetimes(),
+ generics,
+ constraints: Vec::new(),
+ db,
+ };
+
+ ctxt.build_constraints_for_item();
+ let res = ctxt.solve();
+ res.is_empty().not().then(|| Arc::from_iter(res))
+}
+
+pub(crate) fn variances_of_cycle(
+ db: &dyn HirDatabase,
+ _cycle: &Cycle,
+ def: &GenericDefId,
+) -> Option<Arc<[Variance]>> {
+ let generics = generics(db.upcast(), *def);
+ let count = generics.len();
+
+ if count == 0 {
+ return None;
+ }
+ Some(Arc::from(vec![Variance::Bivariant; count]))
+}
+
+#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
+pub enum Variance {
+ Covariant, // T<A> <: T<B> iff A <: B -- e.g., function return type
+ Invariant, // T<A> <: T<B> iff B == A -- e.g., type of mutable cell
+ Contravariant, // T<A> <: T<B> iff B <: A -- e.g., function param type
+ Bivariant, // T<A> <: T<B> -- e.g., unused type parameter
+}
+
+impl fmt::Display for Variance {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ match self {
+ Variance::Covariant => write!(f, "covariant"),
+ Variance::Invariant => write!(f, "invariant"),
+ Variance::Contravariant => write!(f, "contravariant"),
+ Variance::Bivariant => write!(f, "bivariant"),
+ }
+ }
+}
+
+impl Variance {
+ /// `a.xform(b)` combines the variance of a context with the
+ /// variance of a type with the following meaning. If we are in a
+ /// context with variance `a`, and we encounter a type argument in
+ /// a position with variance `b`, then `a.xform(b)` is the new
+ /// variance with which the argument appears.
+ ///
+ /// Example 1:
+ /// ```ignore (illustrative)
+ /// *mut Vec<i32>
+ /// ```
+ /// Here, the "ambient" variance starts as covariant. `*mut T` is
+ /// invariant with respect to `T`, so the variance in which the
+ /// `Vec<i32>` appears is `Covariant.xform(Invariant)`, which
+ /// yields `Invariant`. Now, the type `Vec<T>` is covariant with
+ /// respect to its type argument `T`, and hence the variance of
+ /// the `i32` here is `Invariant.xform(Covariant)`, which results
+ /// (again) in `Invariant`.
+ ///
+ /// Example 2:
+ /// ```ignore (illustrative)
+ /// fn(*const Vec<i32>, *mut Vec<i32)
+ /// ```
+ /// The ambient variance is covariant. A `fn` type is
+ /// contravariant with respect to its parameters, so the variance
+ /// within which both pointer types appear is
+ /// `Covariant.xform(Contravariant)`, or `Contravariant`. `*const
+ /// T` is covariant with respect to `T`, so the variance within
+ /// which the first `Vec<i32>` appears is
+ /// `Contravariant.xform(Covariant)` or `Contravariant`. The same
+ /// is true for its `i32` argument. In the `*mut T` case, the
+ /// variance of `Vec<i32>` is `Contravariant.xform(Invariant)`,
+ /// and hence the outermost type is `Invariant` with respect to
+ /// `Vec<i32>` (and its `i32` argument).
+ ///
+ /// Source: Figure 1 of "Taming the Wildcards:
+ /// Combining Definition- and Use-Site Variance" published in PLDI'11.
+ fn xform(self, v: Variance) -> Variance {
+ match (self, v) {
+ // Figure 1, column 1.
+ (Variance::Covariant, Variance::Covariant) => Variance::Covariant,
+ (Variance::Covariant, Variance::Contravariant) => Variance::Contravariant,
+ (Variance::Covariant, Variance::Invariant) => Variance::Invariant,
+ (Variance::Covariant, Variance::Bivariant) => Variance::Bivariant,
+
+ // Figure 1, column 2.
+ (Variance::Contravariant, Variance::Covariant) => Variance::Contravariant,
+ (Variance::Contravariant, Variance::Contravariant) => Variance::Covariant,
+ (Variance::Contravariant, Variance::Invariant) => Variance::Invariant,
+ (Variance::Contravariant, Variance::Bivariant) => Variance::Bivariant,
+
+ // Figure 1, column 3.
+ (Variance::Invariant, _) => Variance::Invariant,
+
+ // Figure 1, column 4.
+ (Variance::Bivariant, _) => Variance::Bivariant,
+ }
+ }
+
+ fn glb(self, v: Variance) -> Variance {
+ // Greatest lower bound of the variance lattice as
+ // defined in The Paper:
+ //
+ // *
+ // - +
+ // o
+ match (self, v) {
+ (Variance::Invariant, _) | (_, Variance::Invariant) => Variance::Invariant,
+
+ (Variance::Covariant, Variance::Contravariant) => Variance::Invariant,
+ (Variance::Contravariant, Variance::Covariant) => Variance::Invariant,
+
+ (Variance::Covariant, Variance::Covariant) => Variance::Covariant,
+
+ (Variance::Contravariant, Variance::Contravariant) => Variance::Contravariant,
+
+ (x, Variance::Bivariant) | (Variance::Bivariant, x) => x,
+ }
+ }
+}
+#[derive(Copy, Clone, Debug)]
+struct InferredIndex(usize);
+
+#[derive(Clone)]
+enum VarianceTerm {
+ ConstantTerm(Variance),
+ TransformTerm(Box<VarianceTerm>, Box<VarianceTerm>),
+}
+
+impl fmt::Debug for VarianceTerm {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ match self {
+ VarianceTerm::ConstantTerm(c1) => write!(f, "{c1:?}"),
+ VarianceTerm::TransformTerm(v1, v2) => write!(f, "({v1:?} \u{00D7} {v2:?})"),
+ }
+ }
+}
+
+struct Context<'db> {
+ db: &'db dyn HirDatabase,
+ def: GenericDefId,
+ has_trait_self: bool,
+ len_self: usize,
+ len_self_lifetimes: usize,
+ generics: Generics,
+ constraints: Vec<Constraint>,
+}
+
+/// Declares that the variable `decl_id` appears in a location with
+/// variance `variance`.
+#[derive(Clone)]
+struct Constraint {
+ inferred: InferredIndex,
+ variance: VarianceTerm,
+}
+
+impl Context<'_> {
+ fn build_constraints_for_item(&mut self) {
+ match self.def {
+ GenericDefId::AdtId(adt) => {
+ let db = self.db;
+ let mut add_constraints_from_variant = |variant| {
+ let subst = self.generics.placeholder_subst(db);
+ for (_, field) in db.field_types(variant).iter() {
+ self.add_constraints_from_ty(
+ &field.clone().substitute(Interner, &subst),
+ &VarianceTerm::ConstantTerm(Variance::Covariant),
+ );
+ }
+ };
+ match adt {
+ AdtId::StructId(s) => add_constraints_from_variant(VariantId::StructId(s)),
+ AdtId::UnionId(u) => add_constraints_from_variant(VariantId::UnionId(u)),
+ AdtId::EnumId(e) => {
+ db.enum_data(e).variants.iter().for_each(|&(variant, _)| {
+ add_constraints_from_variant(VariantId::EnumVariantId(variant))
+ });
+ }
+ }
+ }
+ GenericDefId::FunctionId(f) => {
+ let subst = self.generics.placeholder_subst(self.db);
+ self.add_constraints_from_sig2(
+ &self
+ .db
+ .callable_item_signature(f.into())
+ .substitute(Interner, &subst)
+ .params_and_return,
+ &VarianceTerm::ConstantTerm(Variance::Covariant),
+ );
+ }
+ _ => {}
+ }
+ }
+
+ fn contravariant(&mut self, variance: &VarianceTerm) -> VarianceTerm {
+ self.xform(variance, &VarianceTerm::ConstantTerm(Variance::Contravariant))
+ }
+
+ fn invariant(&mut self, variance: &VarianceTerm) -> VarianceTerm {
+ self.xform(variance, &VarianceTerm::ConstantTerm(Variance::Invariant))
+ }
+
+ fn xform(&mut self, v1: &VarianceTerm, v2: &VarianceTerm) -> VarianceTerm {
+ match (v1, v2) {
+ // Applying a "covariant" transform is always a no-op
+ (_, VarianceTerm::ConstantTerm(Variance::Covariant)) => v1.clone(),
+ (VarianceTerm::ConstantTerm(c1), VarianceTerm::ConstantTerm(c2)) => {
+ VarianceTerm::ConstantTerm(c1.xform(*c2))
+ }
+ _ => VarianceTerm::TransformTerm(Box::new(v1.clone()), Box::new(v2.clone())),
+ }
+ }
+
+ fn add_constraints_from_invariant_args(
+ &mut self,
+ args: &[GenericArg],
+ variance: &VarianceTerm,
+ ) {
+ tracing::debug!(
+ "add_constraints_from_invariant_args(args={:?}, variance={:?})",
+ args,
+ variance
+ );
+ let variance_i = self.invariant(variance);
+
+ for k in args {
+ match k.data(Interner) {
+ GenericArgData::Lifetime(lt) => self.add_constraints_from_region(lt, &variance_i),
+ GenericArgData::Ty(ty) => self.add_constraints_from_ty(ty, &variance_i),
+ GenericArgData::Const(val) => self.add_constraints_from_const(val, &variance_i),
+ }
+ }
+ }
+
+ /// Adds constraints appropriate for an instance of `ty` appearing
+ /// in a context with the generics defined in `generics` and
+ /// ambient variance `variance`
+ fn add_constraints_from_ty(&mut self, ty: &Ty, variance: &VarianceTerm) {
+ tracing::debug!("add_constraints_from_ty(ty={:?}, variance={:?})", ty, variance);
+ match ty.kind(Interner) {
+ TyKind::Scalar(_) | TyKind::Never | TyKind::Str | TyKind::Foreign(..) => {
+ // leaf type -- noop
+ }
+
+ TyKind::FnDef(..) | TyKind::Coroutine(..) | TyKind::Closure(..) => {
+ panic!("Unexpected unnameable type in variance computation: {ty:?}");
+ }
+
+ TyKind::Ref(mutbl, lifetime, ty) => {
+ self.add_constraints_from_region(lifetime, variance);
+ self.add_constraints_from_mt(ty, *mutbl, variance);
+ }
+
+ TyKind::Array(typ, len) => {
+ self.add_constraints_from_const(len, variance);
+ self.add_constraints_from_ty(typ, variance);
+ }
+
+ TyKind::Slice(typ) => {
+ self.add_constraints_from_ty(typ, variance);
+ }
+
+ TyKind::Raw(mutbl, ty) => {
+ self.add_constraints_from_mt(ty, *mutbl, variance);
+ }
+
+ TyKind::Tuple(_, subtys) => {
+ for subty in subtys.type_parameters(Interner) {
+ self.add_constraints_from_ty(&subty, variance);
+ }
+ }
+
+ TyKind::Adt(def, args) => {
+ self.add_constraints_from_args(def.0.into(), args.as_slice(Interner), variance);
+ }
+
+ TyKind::Alias(AliasTy::Opaque(opaque)) => {
+ self.add_constraints_from_invariant_args(
+ opaque.substitution.as_slice(Interner),
+ variance,
+ );
+ }
+ TyKind::Alias(AliasTy::Projection(proj)) => {
+ self.add_constraints_from_invariant_args(
+ proj.substitution.as_slice(Interner),
+ variance,
+ );
+ }
+ // FIXME: check this
+ TyKind::AssociatedType(_, subst) => {
+ self.add_constraints_from_invariant_args(subst.as_slice(Interner), variance);
+ }
+ // FIXME: check this
+ TyKind::OpaqueType(_, subst) => {
+ self.add_constraints_from_invariant_args(subst.as_slice(Interner), variance);
+ }
+
+ TyKind::Dyn(it) => {
+ // The type `dyn Trait<T> +'a` is covariant w/r/t `'a`:
+ self.add_constraints_from_region(&it.lifetime, variance);
+
+ if let Some(trait_ref) = it.principal() {
+ // Trait are always invariant so we can take advantage of that.
+ self.add_constraints_from_invariant_args(
+ trait_ref
+ .map(|it| it.map(|it| it.substitution.clone()))
+ .substitute(
+ Interner,
+ &[GenericArg::new(
+ Interner,
+ chalk_ir::GenericArgData::Ty(TyKind::Error.intern(Interner)),
+ )],
+ )
+ .skip_binders()
+ .as_slice(Interner),
+ variance,
+ );
+ }
+
+ // FIXME
+ // for projection in data.projection_bounds() {
+ // match projection.skip_binder().term.unpack() {
+ // TyKind::TermKind::Ty(ty) => {
+ // self.add_constraints_from_ty( ty, self.invariant);
+ // }
+ // TyKind::TermKind::Const(c) => {
+ // self.add_constraints_from_const( c, self.invariant)
+ // }
+ // }
+ // }
+ }
+
+ // Chalk has no params, so use placeholders for now?
+ TyKind::Placeholder(index) => {
+ let idx = crate::from_placeholder_idx(self.db, *index);
+ let index = idx.local_id.into_raw().into_u32() as usize + self.len_self_lifetimes;
+ let inferred = if idx.parent == self.def {
+ InferredIndex(self.has_trait_self as usize + index)
+ } else {
+ InferredIndex(self.len_self + index)
+ };
+ tracing::debug!("add_constraint(index={:?}, variance={:?})", inferred, variance);
+ self.constraints.push(Constraint { inferred, variance: variance.clone() });
+ }
+ TyKind::Function(f) => {
+ self.add_constraints_from_sig(f, variance);
+ }
+
+ TyKind::Error => {
+ // we encounter this when walking the trait references for object
+ // types, where we use Error as the Self type
+ }
+
+ TyKind::CoroutineWitness(..) | TyKind::BoundVar(..) | TyKind::InferenceVar(..) => {
+ panic!("unexpected type encountered in variance inference: {:?}", ty);
+ }
+ }
+ }
+
+ /// Adds constraints appropriate for a nominal type (enum, struct,
+ /// object, etc) appearing in a context with ambient variance `variance`
+ fn add_constraints_from_args(
+ &mut self,
+ def_id: GenericDefId,
+ args: &[GenericArg],
+ variance: &VarianceTerm,
+ ) {
+ tracing::debug!(
+ "add_constraints_from_args(def_id={:?}, args={:?}, variance={:?})",
+ def_id,
+ args,
+ variance
+ );
+
+ // We don't record `inferred_starts` entries for empty generics.
+ if args.is_empty() {
+ return;
+ }
+ if def_id == self.def {
+ // HACK: Workaround for the trivial cycle salsa case (see
+ // recursive_one_bivariant_more_non_bivariant_params test)
+ let variance_i = self.xform(variance, &VarianceTerm::ConstantTerm(Variance::Bivariant));
+ for k in args {
+ match k.data(Interner) {
+ GenericArgData::Lifetime(lt) => {
+ self.add_constraints_from_region(lt, &variance_i)
+ }
+ GenericArgData::Ty(ty) => self.add_constraints_from_ty(ty, &variance_i),
+ GenericArgData::Const(val) => self.add_constraints_from_const(val, variance),
+ }
+ }
+ } else {
+ let Some(variances) = self.db.variances_of(def_id) else {
+ return;
+ };
+
+ for (i, k) in args.iter().enumerate() {
+ let variance_decl = &VarianceTerm::ConstantTerm(variances[i]);
+ let variance_i = self.xform(variance, variance_decl);
+ match k.data(Interner) {
+ GenericArgData::Lifetime(lt) => {
+ self.add_constraints_from_region(lt, &variance_i)
+ }
+ GenericArgData::Ty(ty) => self.add_constraints_from_ty(ty, &variance_i),
+ GenericArgData::Const(val) => self.add_constraints_from_const(val, variance),
+ }
+ }
+ }
+ }
+
+ /// Adds constraints appropriate for a const expression `val`
+ /// in a context with ambient variance `variance`
+ fn add_constraints_from_const(&mut self, c: &Const, variance: &VarianceTerm) {
+ match &c.data(Interner).value {
+ chalk_ir::ConstValue::Concrete(c) => {
+ if let ConstScalar::UnevaluatedConst(_, subst) = &c.interned {
+ self.add_constraints_from_invariant_args(subst.as_slice(Interner), variance);
+ }
+ }
+ _ => {}
+ }
+ }
+
+ /// Adds constraints appropriate for a function with signature
+ /// `sig` appearing in a context with ambient variance `variance`
+ fn add_constraints_from_sig(&mut self, sig: &FnPointer, variance: &VarianceTerm) {
+ let contra = self.contravariant(variance);
+ let mut tys = sig.substitution.0.iter(Interner).filter_map(move |p| p.ty(Interner));
+ self.add_constraints_from_ty(tys.next_back().unwrap(), variance);
+ for input in tys {
+ self.add_constraints_from_ty(input, &contra);
+ }
+ }
+
+ fn add_constraints_from_sig2(&mut self, sig: &[Ty], variance: &VarianceTerm) {
+ let contra = self.contravariant(variance);
+ let mut tys = sig.iter();
+ self.add_constraints_from_ty(tys.next_back().unwrap(), variance);
+ for input in tys {
+ self.add_constraints_from_ty(input, &contra);
+ }
+ }
+
+ /// Adds constraints appropriate for a region appearing in a
+ /// context with ambient variance `variance`
+ fn add_constraints_from_region(&mut self, region: &Lifetime, variance: &VarianceTerm) {
+ match region.data(Interner) {
+ // FIXME: chalk has no params?
+ LifetimeData::Placeholder(index) => {
+ let idx = crate::lt_from_placeholder_idx(self.db, *index);
+ let index = idx.local_id.into_raw().into_u32() as usize;
+ let inferred = if idx.parent == self.def {
+ InferredIndex(index)
+ } else {
+ InferredIndex(self.has_trait_self as usize + self.len_self + index)
+ };
+ tracing::debug!("add_constraint(index={:?}, variance={:?})", inferred, variance);
+ self.constraints.push(Constraint { inferred, variance: variance.clone() });
+ }
+ LifetimeData::Static => {}
+
+ LifetimeData::BoundVar(..) => {
+ // Either a higher-ranked region inside of a type or a
+ // late-bound function parameter.
+ //
+ // We do not compute constraints for either of these.
+ }
+
+ LifetimeData::Error => {}
+
+ LifetimeData::Phantom(..) | LifetimeData::InferenceVar(..) | LifetimeData::Erased => {
+ // We don't expect to see anything but 'static or bound
+ // regions when visiting member types or method types.
+ panic!(
+ "unexpected region encountered in variance \
+ inference: {:?}",
+ region
+ );
+ }
+ }
+ }
+
+ /// Adds constraints appropriate for a mutability-type pair
+ /// appearing in a context with ambient variance `variance`
+ fn add_constraints_from_mt(&mut self, ty: &Ty, mt: Mutability, variance: &VarianceTerm) {
+ match mt {
+ Mutability::Mut => {
+ let invar = self.invariant(variance);
+ self.add_constraints_from_ty(ty, &invar);
+ }
+
+ Mutability::Not => {
+ self.add_constraints_from_ty(ty, variance);
+ }
+ }
+ }
+}
+
+impl Context<'_> {
+ fn solve(self) -> Vec<Variance> {
+ let mut solutions = vec![Variance::Bivariant; self.generics.len()];
+ // Propagate constraints until a fixed point is reached. Note
+ // that the maximum number of iterations is 2C where C is the
+ // number of constraints (each variable can change values at most
+ // twice). Since number of constraints is linear in size of the
+ // input, so is the inference process.
+ let mut changed = true;
+ while changed {
+ changed = false;
+
+ for constraint in &self.constraints {
+ let Constraint { inferred, variance: term } = constraint;
+ let InferredIndex(inferred) = inferred;
+ let variance = Self::evaluate(term);
+ let old_value = solutions[*inferred];
+ let new_value = variance.glb(old_value);
+ if old_value != new_value {
+ solutions[*inferred] = new_value;
+ changed = true;
+ }
+ }
+ }
+
+ // Const parameters are always invariant.
+ // Make all const parameters invariant.
+ for (idx, param) in self.generics.iter_id().enumerate() {
+ if let GenericParamId::ConstParamId(_) = param {
+ solutions[idx] = Variance::Invariant;
+ }
+ }
+
+ // Functions are permitted to have unused generic parameters: make those invariant.
+ if let GenericDefId::FunctionId(_) = self.def {
+ for variance in &mut solutions {
+ if *variance == Variance::Bivariant {
+ *variance = Variance::Invariant;
+ }
+ }
+ }
+
+ solutions
+ }
+
+ fn evaluate(term: &VarianceTerm) -> Variance {
+ match term {
+ VarianceTerm::ConstantTerm(v) => *v,
+ VarianceTerm::TransformTerm(t1, t2) => {
+ let v1 = Self::evaluate(t1);
+ let v2 = Self::evaluate(t2);
+ v1.xform(v2)
+ }
+ }
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use expect_test::{expect, Expect};
+ use hir_def::{
+ generics::GenericParamDataRef, src::HasSource, AdtId, GenericDefId, ModuleDefId,
+ };
+ use itertools::Itertools;
+ use stdx::format_to;
+ use syntax::{ast::HasName, AstNode};
+ use test_fixture::WithFixture;
+
+ use hir_def::Lookup;
+
+ use crate::{db::HirDatabase, test_db::TestDB, variance::generics};
+
+ #[test]
+ fn phantom_data() {
+ check(
+ r#"
+//- minicore: phantom_data
+
+struct Covariant<A> {
+ t: core::marker::PhantomData<A>
+}
+"#,
+ expect![[r#"
+ Covariant[A: covariant]
+ "#]],
+ );
+ }
+
+ #[test]
+ fn rustc_test_variance_types() {
+ check(
+ r#"
+//- minicore: cell
+
+use core::cell::UnsafeCell;
+
+struct InvariantMut<'a,A:'a,B:'a> { //~ ERROR ['a: +, A: o, B: o]
+ t: &'a mut (A,B)
+}
+
+struct InvariantCell<A> { //~ ERROR [A: o]
+ t: UnsafeCell<A>
+}
+
+struct InvariantIndirect<A> { //~ ERROR [A: o]
+ t: InvariantCell<A>
+}
+
+struct Covariant<A> { //~ ERROR [A: +]
+ t: A, u: fn() -> A
+}
+
+struct Contravariant<A> { //~ ERROR [A: -]
+ t: fn(A)
+}
+
+enum Enum<A,B,C> { //~ ERROR [A: +, B: -, C: o]
+ Foo(Covariant<A>),
+ Bar(Contravariant<B>),`
+ Zed(Covariant<C>,Contravariant<C>)
+}
+"#,
+ expect![[r#"
+ InvariantMut['a: covariant, A: invariant, B: invariant]
+ InvariantCell[A: invariant]
+ InvariantIndirect[A: invariant]
+ Covariant[A: covariant]
+ Contravariant[A: contravariant]
+ Enum[A: covariant, B: contravariant, C: invariant]
+ "#]],
+ );
+ }
+
+ #[test]
+ fn type_resolve_error_two_structs_deep() {
+ check(
+ r#"
+struct Hello<'a> {
+ missing: Missing<'a>,
+}
+
+struct Other<'a> {
+ hello: Hello<'a>,
+}
+"#,
+ expect![[r#"
+ Hello['a: bivariant]
+ Other['a: bivariant]
+ "#]],
+ );
+ }
+
+ #[test]
+ fn rustc_test_variance_associated_consts() {
+ // FIXME: Should be invariant
+ check(
+ r#"
+trait Trait {
+ const Const: usize;
+}
+
+struct Foo<T: Trait> { //~ ERROR [T: o]
+ field: [u8; <T as Trait>::Const]
+}
+"#,
+ expect![[r#"
+ Foo[T: bivariant]
+ "#]],
+ );
+ }
+
+ #[test]
+ fn rustc_test_variance_associated_types() {
+ check(
+ r#"
+trait Trait<'a> {
+ type Type;
+
+ fn method(&'a self) { }
+}
+
+struct Foo<'a, T : Trait<'a>> { //~ ERROR ['a: +, T: +]
+ field: (T, &'a ())
+}
+
+struct Bar<'a, T : Trait<'a>> { //~ ERROR ['a: o, T: o]
+ field: <T as Trait<'a>>::Type
+}
+
+"#,
+ expect![[r#"
+ method[Self: contravariant, 'a: contravariant]
+ Foo['a: covariant, T: covariant]
+ Bar['a: invariant, T: invariant]
+ "#]],
+ );
+ }
+
+ #[test]
+ fn rustc_test_variance_associated_types2() {
+ // FIXME: RPITs have variance, but we can't treat them as their own thing right now
+ check(
+ r#"
+trait Foo {
+ type Bar;
+}
+
+fn make() -> *const dyn Foo<Bar = &'static u32> {}
+"#,
+ expect![""],
+ );
+ }
+
+ #[test]
+ fn rustc_test_variance_trait_bounds() {
+ check(
+ r#"
+trait Getter<T> {
+ fn get(&self) -> T;
+}
+
+trait Setter<T> {
+ fn get(&self, _: T);
+}
+
+struct TestStruct<U,T:Setter<U>> { //~ ERROR [U: +, T: +]
+ t: T, u: U
+}
+
+enum TestEnum<U,T:Setter<U>> { //~ ERROR [U: *, T: +]
+ //~^ ERROR: `U` is never used
+ Foo(T)
+}
+
+struct TestContraStruct<U,T:Setter<U>> { //~ ERROR [U: *, T: +]
+ //~^ ERROR: `U` is never used
+ t: T
+}
+
+struct TestBox<U,T:Getter<U>+Setter<U>> { //~ ERROR [U: *, T: +]
+ //~^ ERROR: `U` is never used
+ t: T
+}
+"#,
+ expect![[r#"
+ get[Self: contravariant, T: covariant]
+ get[Self: contravariant, T: contravariant]
+ TestStruct[U: covariant, T: covariant]
+ TestEnum[U: bivariant, T: covariant]
+ TestContraStruct[U: bivariant, T: covariant]
+ TestBox[U: bivariant, T: covariant]
+ "#]],
+ );
+ }
+
+ #[test]
+ fn rustc_test_variance_trait_matching() {
+ check(
+ r#"
+
+trait Get<T> {
+ fn get(&self) -> T;
+}
+
+struct Cloner<T:Clone> {
+ t: T
+}
+
+impl<T:Clone> Get<T> for Cloner<T> {
+ fn get(&self) -> T {}
+}
+
+fn get<'a, G>(get: &G) -> i32
+ where G : Get<&'a i32>
+{}
+
+fn pick<'b, G>(get: &'b G, if_odd: &'b i32) -> i32
+ where G : Get<&'b i32>
+{}
+"#,
+ expect![[r#"
+ get[Self: contravariant, T: covariant]
+ Cloner[T: covariant]
+ get[T: invariant]
+ get['a: invariant, G: contravariant]
+ pick['b: contravariant, G: contravariant]
+ "#]],
+ );
+ }
+
+ #[test]
+ fn rustc_test_variance_trait_object_bound() {
+ check(
+ r#"
+enum Option<T> {
+ Some(T),
+ None
+}
+trait T { fn foo(&self); }
+
+struct TOption<'a> { //~ ERROR ['a: +]
+ v: Option<*const (dyn T + 'a)>,
+}
+"#,
+ expect![[r#"
+ Option[T: covariant]
+ foo[Self: contravariant]
+ TOption['a: covariant]
+ "#]],
+ );
+ }
+
+ #[test]
+ fn rustc_test_variance_types_bounds() {
+ check(
+ r#"
+//- minicore: send
+struct TestImm<A, B> { //~ ERROR [A: +, B: +]
+ x: A,
+ y: B,
+}
+
+struct TestMut<A, B:'static> { //~ ERROR [A: +, B: o]
+ x: A,
+ y: &'static mut B,
+}
+
+struct TestIndirect<A:'static, B:'static> { //~ ERROR [A: +, B: o]
+ m: TestMut<A, B>
+}
+
+struct TestIndirect2<A:'static, B:'static> { //~ ERROR [A: o, B: o]
+ n: TestMut<A, B>,
+ m: TestMut<B, A>
+}
+
+trait Getter<A> {
+ fn get(&self) -> A;
+}
+
+trait Setter<A> {
+ fn set(&mut self, a: A);
+}
+
+struct TestObject<A, R> { //~ ERROR [A: o, R: o]
+ n: *const (dyn Setter<A> + Send),
+ m: *const (dyn Getter<R> + Send),
+}
+"#,
+ expect![[r#"
+ TestImm[A: covariant, B: covariant]
+ TestMut[A: covariant, B: invariant]
+ TestIndirect[A: covariant, B: invariant]
+ TestIndirect2[A: invariant, B: invariant]
+ get[Self: contravariant, A: covariant]
+ set[Self: invariant, A: contravariant]
+ TestObject[A: invariant, R: invariant]
+ "#]],
+ );
+ }
+
+ #[test]
+ fn rustc_test_variance_unused_region_param() {
+ check(
+ r#"
+struct SomeStruct<'a> { x: u32 } //~ ERROR parameter `'a` is never used
+enum SomeEnum<'a> { Nothing } //~ ERROR parameter `'a` is never used
+trait SomeTrait<'a> { fn foo(&self); } // OK on traits.
+"#,
+ expect![[r#"
+ SomeStruct['a: bivariant]
+ SomeEnum['a: bivariant]
+ foo[Self: contravariant, 'a: invariant]
+ "#]],
+ );
+ }
+
+ #[test]
+ fn rustc_test_variance_unused_type_param() {
+ check(
+ r#"
+//- minicore: sized
+struct SomeStruct<A> { x: u32 }
+enum SomeEnum<A> { Nothing }
+enum ListCell<T> {
+ Cons(*const ListCell<T>),
+ Nil
+}
+
+struct SelfTyAlias<T>(*const Self);
+struct WithBounds<T: Sized> {}
+struct WithWhereBounds<T> where T: Sized {}
+struct WithOutlivesBounds<T: 'static> {}
+struct DoubleNothing<T> {
+ s: SomeStruct<T>,
+}
+
+"#,
+ expect![[r#"
+ SomeStruct[A: bivariant]
+ SomeEnum[A: bivariant]
+ ListCell[T: bivariant]
+ SelfTyAlias[T: bivariant]
+ WithBounds[T: bivariant]
+ WithWhereBounds[T: bivariant]
+ WithOutlivesBounds[T: bivariant]
+ DoubleNothing[T: bivariant]
+ "#]],
+ );
+ }
+
+ #[test]
+ fn rustc_test_variance_use_contravariant_struct1() {
+ check(
+ r#"
+struct SomeStruct<T>(fn(T));
+
+fn foo<'min,'max>(v: SomeStruct<&'max ()>)
+ -> SomeStruct<&'min ()>
+ where 'max : 'min
+{}
+"#,
+ expect![[r#"
+ SomeStruct[T: contravariant]
+ foo['min: contravariant, 'max: covariant]
+ "#]],
+ );
+ }
+
+ #[test]
+ fn rustc_test_variance_use_contravariant_struct2() {
+ check(
+ r#"
+struct SomeStruct<T>(fn(T));
+
+fn bar<'min,'max>(v: SomeStruct<&'min ()>)
+ -> SomeStruct<&'max ()>
+ where 'max : 'min
+{}
+"#,
+ expect![[r#"
+ SomeStruct[T: contravariant]
+ bar['min: covariant, 'max: contravariant]
+ "#]],
+ );
+ }
+
+ #[test]
+ fn rustc_test_variance_use_covariant_struct1() {
+ check(
+ r#"
+struct SomeStruct<T>(T);
+
+fn foo<'min,'max>(v: SomeStruct<&'min ()>)
+ -> SomeStruct<&'max ()>
+ where 'max : 'min
+{}
+"#,
+ expect![[r#"
+ SomeStruct[T: covariant]
+ foo['min: contravariant, 'max: covariant]
+ "#]],
+ );
+ }
+
+ #[test]
+ fn rustc_test_variance_use_covariant_struct2() {
+ check(
+ r#"
+struct SomeStruct<T>(T);
+
+fn foo<'min,'max>(v: SomeStruct<&'max ()>)
+ -> SomeStruct<&'min ()>
+ where 'max : 'min
+{}
+"#,
+ expect![[r#"
+ SomeStruct[T: covariant]
+ foo['min: covariant, 'max: contravariant]
+ "#]],
+ );
+ }
+
+ #[test]
+ fn rustc_test_variance_use_invariant_struct1() {
+ check(
+ r#"
+struct SomeStruct<T>(*mut T);
+
+fn foo<'min,'max>(v: SomeStruct<&'max ()>)
+ -> SomeStruct<&'min ()>
+ where 'max : 'min
+{}
+
+fn bar<'min,'max>(v: SomeStruct<&'min ()>)
+ -> SomeStruct<&'max ()>
+ where 'max : 'min
+{}
+"#,
+ expect![[r#"
+ SomeStruct[T: invariant]
+ foo['min: invariant, 'max: invariant]
+ bar['min: invariant, 'max: invariant]
+ "#]],
+ );
+ }
+
+ #[test]
+ fn recursive_one_bivariant_more_non_bivariant_params() {
+ // FIXME: This is wrong, this should be `BivariantPartialIndirect[T: bivariant, U: covariant]` (likewise for Wrapper)
+ // This is a limitation of current salsa where a cycle may only set a fallback value to the
+ // query result which is not what we want! We want to treat the cycle call as fallback
+ // without setting the query result to the fallback.
+ // `BivariantPartial` works as we workaround for the trivial case of being self-referential
+ check(
+ r#"
+struct BivariantPartial<T, U>(*const BivariantPartial<T, U>, U);
+struct Wrapper<T, U>(BivariantPartialIndirect<T, U>);
+struct BivariantPartialIndirect<T, U>(*const Wrapper<T, U>, U);
+"#,
+ expect![[r#"
+ BivariantPartial[T: bivariant, U: covariant]
+ Wrapper[T: bivariant, U: bivariant]
+ BivariantPartialIndirect[T: bivariant, U: bivariant]
+ "#]],
+ );
+ }
+
+ #[track_caller]
+ fn check(ra_fixture: &str, expected: Expect) {
+ // use tracing_subscriber::{layer::SubscriberExt, Layer};
+ // let my_layer = tracing_subscriber::fmt::layer();
+ // let _g = tracing::subscriber::set_default(tracing_subscriber::registry().with(
+ // my_layer.with_filter(tracing_subscriber::filter::filter_fn(|metadata| {
+ // metadata.target().starts_with("hir_ty::variance")
+ // })),
+ // ));
+ let (db, file_id) = TestDB::with_single_file(ra_fixture);
+
+ let mut defs: Vec<GenericDefId> = Vec::new();
+ let module = db.module_for_file_opt(file_id).unwrap();
+ let def_map = module.def_map(&db);
+ crate::tests::visit_module(&db, &def_map, module.local_id, &mut |it| {
+ defs.push(match it {
+ ModuleDefId::FunctionId(it) => it.into(),
+ ModuleDefId::AdtId(it) => it.into(),
+ ModuleDefId::ConstId(it) => it.into(),
+ ModuleDefId::TraitId(it) => it.into(),
+ ModuleDefId::TraitAliasId(it) => it.into(),
+ ModuleDefId::TypeAliasId(it) => it.into(),
+ _ => return,
+ })
+ });
+ let defs = defs
+ .into_iter()
+ .filter_map(|def| {
+ Some((
+ def,
+ match def {
+ GenericDefId::FunctionId(it) => {
+ let loc = it.lookup(&db);
+ loc.source(&db).value.name().unwrap()
+ }
+ GenericDefId::AdtId(AdtId::EnumId(it)) => {
+ let loc = it.lookup(&db);
+ loc.source(&db).value.name().unwrap()
+ }
+ GenericDefId::AdtId(AdtId::StructId(it)) => {
+ let loc = it.lookup(&db);
+ loc.source(&db).value.name().unwrap()
+ }
+ GenericDefId::AdtId(AdtId::UnionId(it)) => {
+ let loc = it.lookup(&db);
+ loc.source(&db).value.name().unwrap()
+ }
+ GenericDefId::TraitId(it) => {
+ let loc = it.lookup(&db);
+ loc.source(&db).value.name().unwrap()
+ }
+ GenericDefId::TraitAliasId(it) => {
+ let loc = it.lookup(&db);
+ loc.source(&db).value.name().unwrap()
+ }
+ GenericDefId::TypeAliasId(it) => {
+ let loc = it.lookup(&db);
+ loc.source(&db).value.name().unwrap()
+ }
+ GenericDefId::ImplId(_) => return None,
+ GenericDefId::ConstId(_) => return None,
+ },
+ ))
+ })
+ .sorted_by_key(|(_, n)| n.syntax().text_range().start());
+ let mut res = String::new();
+ for (def, name) in defs {
+ let Some(variances) = db.variances_of(def) else {
+ continue;
+ };
+ format_to!(
+ res,
+ "{name}[{}]\n",
+ generics(&db, def)
+ .iter()
+ .map(|(_, param)| match param {
+ GenericParamDataRef::TypeParamData(type_param_data) => {
+ type_param_data.name.as_ref().unwrap()
+ }
+ GenericParamDataRef::ConstParamData(const_param_data) =>
+ &const_param_data.name,
+ GenericParamDataRef::LifetimeParamData(lifetime_param_data) => {
+ &lifetime_param_data.name
+ }
+ })
+ .zip_eq(&*variances)
+ .format_with(", ", |(name, var), f| f(&format_args!(
+ "{}: {var}",
+ name.as_str()
+ )))
+ );
+ }
+
+ expected.assert_eq(&res);
+ }
+}