Unnamed repository; edit this file 'description' to name the repository.
Diffstat (limited to 'crates/hir-ty/src/object_safety.rs')
-rw-r--r--crates/hir-ty/src/object_safety.rs612
1 files changed, 612 insertions, 0 deletions
diff --git a/crates/hir-ty/src/object_safety.rs b/crates/hir-ty/src/object_safety.rs
new file mode 100644
index 0000000000..a4c6626855
--- /dev/null
+++ b/crates/hir-ty/src/object_safety.rs
@@ -0,0 +1,612 @@
+//! Compute the object-safety of a trait
+
+use std::ops::ControlFlow;
+
+use chalk_ir::{
+ cast::Cast,
+ visit::{TypeSuperVisitable, TypeVisitable, TypeVisitor},
+ DebruijnIndex,
+};
+use chalk_solve::rust_ir::InlineBound;
+use hir_def::{
+ lang_item::LangItem, AssocItemId, ConstId, FunctionId, GenericDefId, HasModule, TraitId,
+ TypeAliasId,
+};
+use rustc_hash::FxHashSet;
+use smallvec::SmallVec;
+
+use crate::{
+ all_super_traits,
+ db::HirDatabase,
+ from_assoc_type_id, from_chalk_trait_id,
+ generics::{generics, trait_self_param_idx},
+ lower::callable_item_sig,
+ to_assoc_type_id, to_chalk_trait_id,
+ utils::elaborate_clause_supertraits,
+ AliasEq, AliasTy, Binders, BoundVar, CallableSig, GoalData, ImplTraitId, Interner, OpaqueTyId,
+ ProjectionTyExt, Solution, Substitution, TraitRef, Ty, TyKind, WhereClause,
+};
+
+#[derive(Debug, Clone, PartialEq, Eq, Hash)]
+pub enum ObjectSafetyViolation {
+ SizedSelf,
+ SelfReferential,
+ Method(FunctionId, MethodViolationCode),
+ AssocConst(ConstId),
+ GAT(TypeAliasId),
+ // This doesn't exist in rustc, but added for better visualization
+ HasNonSafeSuperTrait(TraitId),
+}
+
+#[derive(Debug, Clone, PartialEq, Eq, Hash)]
+pub enum MethodViolationCode {
+ StaticMethod,
+ ReferencesSelfInput,
+ ReferencesSelfOutput,
+ ReferencesImplTraitInTrait,
+ AsyncFn,
+ WhereClauseReferencesSelf,
+ Generic,
+ UndispatchableReceiver,
+}
+
+pub fn object_safety(db: &dyn HirDatabase, trait_: TraitId) -> Option<ObjectSafetyViolation> {
+ for super_trait in all_super_traits(db.upcast(), trait_).into_iter().skip(1).rev() {
+ if db.object_safety_of_trait(super_trait).is_some() {
+ return Some(ObjectSafetyViolation::HasNonSafeSuperTrait(super_trait));
+ }
+ }
+
+ db.object_safety_of_trait(trait_)
+}
+
+pub fn object_safety_with_callback<F>(
+ db: &dyn HirDatabase,
+ trait_: TraitId,
+ cb: &mut F,
+) -> ControlFlow<()>
+where
+ F: FnMut(ObjectSafetyViolation) -> ControlFlow<()>,
+{
+ for super_trait in all_super_traits(db.upcast(), trait_).into_iter().skip(1).rev() {
+ if db.object_safety_of_trait(super_trait).is_some() {
+ cb(ObjectSafetyViolation::HasNonSafeSuperTrait(trait_))?;
+ }
+ }
+
+ object_safety_of_trait_with_callback(db, trait_, cb)
+}
+
+pub fn object_safety_of_trait_with_callback<F>(
+ db: &dyn HirDatabase,
+ trait_: TraitId,
+ cb: &mut F,
+) -> ControlFlow<()>
+where
+ F: FnMut(ObjectSafetyViolation) -> ControlFlow<()>,
+{
+ // Check whether this has a `Sized` bound
+ if generics_require_sized_self(db, trait_.into()) {
+ cb(ObjectSafetyViolation::SizedSelf)?;
+ }
+
+ // Check if there exist bounds that referencing self
+ if predicates_reference_self(db, trait_) {
+ cb(ObjectSafetyViolation::SelfReferential)?;
+ }
+ if bounds_reference_self(db, trait_) {
+ cb(ObjectSafetyViolation::SelfReferential)?;
+ }
+
+ // rustc checks for non-lifetime binders here, but we don't support HRTB yet
+
+ let trait_data = db.trait_data(trait_);
+ for (_, assoc_item) in &trait_data.items {
+ object_safety_violation_for_assoc_item(db, trait_, *assoc_item, cb)?;
+ }
+
+ ControlFlow::Continue(())
+}
+
+pub fn object_safety_of_trait_query(
+ db: &dyn HirDatabase,
+ trait_: TraitId,
+) -> Option<ObjectSafetyViolation> {
+ let mut res = None;
+ object_safety_of_trait_with_callback(db, trait_, &mut |osv| {
+ res = Some(osv);
+ ControlFlow::Break(())
+ });
+
+ res
+}
+
+fn generics_require_sized_self(db: &dyn HirDatabase, def: GenericDefId) -> bool {
+ let krate = def.module(db.upcast()).krate();
+ let Some(sized) = db.lang_item(krate, LangItem::Sized).and_then(|l| l.as_trait()) else {
+ return false;
+ };
+
+ let Some(trait_self_param_idx) = trait_self_param_idx(db.upcast(), def) else {
+ return false;
+ };
+
+ let predicates = &*db.generic_predicates(def);
+ let predicates = predicates.iter().map(|p| p.skip_binders().skip_binders().clone());
+ elaborate_clause_supertraits(db, predicates).any(|pred| match pred {
+ WhereClause::Implemented(trait_ref) => {
+ if from_chalk_trait_id(trait_ref.trait_id) == sized {
+ if let TyKind::BoundVar(it) =
+ *trait_ref.self_type_parameter(Interner).kind(Interner)
+ {
+ // Since `generic_predicates` is `Binder<Binder<..>>`, the `DebrujinIndex` of
+ // self-parameter is `1`
+ return it
+ .index_if_bound_at(DebruijnIndex::ONE)
+ .is_some_and(|idx| idx == trait_self_param_idx);
+ }
+ }
+ false
+ }
+ _ => false,
+ })
+}
+
+// rustc gathers all the spans that references `Self` for error rendering,
+// but we don't have good way to render such locations.
+// So, just return single boolean value for existence of such `Self` reference
+fn predicates_reference_self(db: &dyn HirDatabase, trait_: TraitId) -> bool {
+ db.generic_predicates(trait_.into())
+ .iter()
+ .any(|pred| predicate_references_self(db, trait_, pred, AllowSelfProjection::No))
+}
+
+// Same as the above, `predicates_reference_self`
+fn bounds_reference_self(db: &dyn HirDatabase, trait_: TraitId) -> bool {
+ let trait_data = db.trait_data(trait_);
+ trait_data
+ .items
+ .iter()
+ .filter_map(|(_, it)| match *it {
+ AssocItemId::TypeAliasId(id) => {
+ let assoc_ty_id = to_assoc_type_id(id);
+ let assoc_ty_data = db.associated_ty_data(assoc_ty_id);
+ Some(assoc_ty_data)
+ }
+ _ => None,
+ })
+ .any(|assoc_ty_data| {
+ assoc_ty_data.binders.skip_binders().bounds.iter().any(|bound| {
+ let def = from_assoc_type_id(assoc_ty_data.id).into();
+ match bound.skip_binders() {
+ InlineBound::TraitBound(it) => it.args_no_self.iter().any(|arg| {
+ contains_illegal_self_type_reference(
+ db,
+ def,
+ trait_,
+ arg,
+ DebruijnIndex::ONE,
+ AllowSelfProjection::Yes,
+ )
+ }),
+ InlineBound::AliasEqBound(it) => it.parameters.iter().any(|arg| {
+ contains_illegal_self_type_reference(
+ db,
+ def,
+ trait_,
+ arg,
+ DebruijnIndex::ONE,
+ AllowSelfProjection::Yes,
+ )
+ }),
+ }
+ })
+ })
+}
+
+#[derive(Clone, Copy)]
+enum AllowSelfProjection {
+ Yes,
+ No,
+}
+
+fn predicate_references_self(
+ db: &dyn HirDatabase,
+ trait_: TraitId,
+ predicate: &Binders<Binders<WhereClause>>,
+ allow_self_projection: AllowSelfProjection,
+) -> bool {
+ match predicate.skip_binders().skip_binders() {
+ WhereClause::Implemented(trait_ref) => {
+ trait_ref.substitution.iter(Interner).skip(1).any(|arg| {
+ contains_illegal_self_type_reference(
+ db,
+ trait_.into(),
+ trait_,
+ arg,
+ DebruijnIndex::ONE,
+ allow_self_projection,
+ )
+ })
+ }
+ WhereClause::AliasEq(AliasEq { alias: AliasTy::Projection(proj), .. }) => {
+ proj.substitution.iter(Interner).skip(1).any(|arg| {
+ contains_illegal_self_type_reference(
+ db,
+ trait_.into(),
+ trait_,
+ arg,
+ DebruijnIndex::ONE,
+ allow_self_projection,
+ )
+ })
+ }
+ _ => false,
+ }
+}
+
+fn contains_illegal_self_type_reference<T: TypeVisitable<Interner>>(
+ db: &dyn HirDatabase,
+ def: GenericDefId,
+ trait_: TraitId,
+ t: &T,
+ outer_binder: DebruijnIndex,
+ allow_self_projection: AllowSelfProjection,
+) -> bool {
+ let Some(trait_self_param_idx) = trait_self_param_idx(db.upcast(), def) else {
+ return false;
+ };
+ struct IllegalSelfTypeVisitor<'a> {
+ db: &'a dyn HirDatabase,
+ trait_: TraitId,
+ super_traits: Option<SmallVec<[TraitId; 4]>>,
+ trait_self_param_idx: usize,
+ allow_self_projection: AllowSelfProjection,
+ }
+ impl<'a> TypeVisitor<Interner> for IllegalSelfTypeVisitor<'a> {
+ type BreakTy = ();
+
+ fn as_dyn(&mut self) -> &mut dyn TypeVisitor<Interner, BreakTy = Self::BreakTy> {
+ self
+ }
+
+ fn interner(&self) -> Interner {
+ Interner
+ }
+
+ fn visit_ty(&mut self, ty: &Ty, outer_binder: DebruijnIndex) -> ControlFlow<Self::BreakTy> {
+ match ty.kind(Interner) {
+ TyKind::BoundVar(BoundVar { debruijn, index }) => {
+ if *debruijn == outer_binder && *index == self.trait_self_param_idx {
+ ControlFlow::Break(())
+ } else {
+ ty.super_visit_with(self.as_dyn(), outer_binder)
+ }
+ }
+ TyKind::Alias(AliasTy::Projection(proj)) => match self.allow_self_projection {
+ AllowSelfProjection::Yes => {
+ let trait_ = proj.trait_(self.db);
+ if self.super_traits.is_none() {
+ self.super_traits =
+ Some(all_super_traits(self.db.upcast(), self.trait_));
+ }
+ if self.super_traits.as_ref().is_some_and(|s| s.contains(&trait_)) {
+ ControlFlow::Continue(())
+ } else {
+ ty.super_visit_with(self.as_dyn(), outer_binder)
+ }
+ }
+ AllowSelfProjection::No => ty.super_visit_with(self.as_dyn(), outer_binder),
+ },
+ _ => ty.super_visit_with(self.as_dyn(), outer_binder),
+ }
+ }
+
+ fn visit_const(
+ &mut self,
+ constant: &chalk_ir::Const<Interner>,
+ outer_binder: DebruijnIndex,
+ ) -> std::ops::ControlFlow<Self::BreakTy> {
+ constant.data(Interner).ty.super_visit_with(self.as_dyn(), outer_binder)
+ }
+ }
+
+ let mut visitor = IllegalSelfTypeVisitor {
+ db,
+ trait_,
+ super_traits: None,
+ trait_self_param_idx,
+ allow_self_projection,
+ };
+ t.visit_with(visitor.as_dyn(), outer_binder).is_break()
+}
+
+fn object_safety_violation_for_assoc_item<F>(
+ db: &dyn HirDatabase,
+ trait_: TraitId,
+ item: AssocItemId,
+ cb: &mut F,
+) -> ControlFlow<()>
+where
+ F: FnMut(ObjectSafetyViolation) -> ControlFlow<()>,
+{
+ // Any item that has a `Self : Sized` requisite is otherwise
+ // exempt from the regulations.
+ if generics_require_sized_self(db, item.into()) {
+ return ControlFlow::Continue(());
+ }
+
+ match item {
+ AssocItemId::ConstId(it) => cb(ObjectSafetyViolation::AssocConst(it)),
+ AssocItemId::FunctionId(it) => {
+ virtual_call_violations_for_method(db, trait_, it, &mut |mvc| {
+ cb(ObjectSafetyViolation::Method(it, mvc))
+ })
+ }
+ AssocItemId::TypeAliasId(it) => {
+ let def_map = db.crate_def_map(trait_.krate(db.upcast()));
+ if def_map.is_unstable_feature_enabled(&intern::sym::generic_associated_type_extended) {
+ ControlFlow::Continue(())
+ } else {
+ let generic_params = db.generic_params(item.into());
+ if !generic_params.is_empty() {
+ cb(ObjectSafetyViolation::GAT(it))
+ } else {
+ ControlFlow::Continue(())
+ }
+ }
+ }
+ }
+}
+
+fn virtual_call_violations_for_method<F>(
+ db: &dyn HirDatabase,
+ trait_: TraitId,
+ func: FunctionId,
+ cb: &mut F,
+) -> ControlFlow<()>
+where
+ F: FnMut(MethodViolationCode) -> ControlFlow<()>,
+{
+ let func_data = db.function_data(func);
+ if !func_data.has_self_param() {
+ cb(MethodViolationCode::StaticMethod)?;
+ }
+
+ if func_data.is_async() {
+ cb(MethodViolationCode::AsyncFn)?;
+ }
+
+ let sig = callable_item_sig(db, func.into());
+ if sig.skip_binders().params().iter().skip(1).any(|ty| {
+ contains_illegal_self_type_reference(
+ db,
+ func.into(),
+ trait_,
+ ty,
+ DebruijnIndex::INNERMOST,
+ AllowSelfProjection::Yes,
+ )
+ }) {
+ cb(MethodViolationCode::ReferencesSelfInput)?;
+ }
+
+ if contains_illegal_self_type_reference(
+ db,
+ func.into(),
+ trait_,
+ sig.skip_binders().ret(),
+ DebruijnIndex::INNERMOST,
+ AllowSelfProjection::Yes,
+ ) {
+ cb(MethodViolationCode::ReferencesSelfOutput)?;
+ }
+
+ if !func_data.is_async() {
+ if let Some(mvc) = contains_illegal_impl_trait_in_trait(db, &sig) {
+ cb(mvc)?;
+ }
+ }
+
+ let generic_params = db.generic_params(func.into());
+ if generic_params.len_type_or_consts() > 0 {
+ cb(MethodViolationCode::Generic)?;
+ }
+
+ if func_data.has_self_param() && !receiver_is_dispatchable(db, trait_, func, &sig) {
+ cb(MethodViolationCode::UndispatchableReceiver)?;
+ }
+
+ let predicates = &*db.generic_predicates_without_parent(func.into());
+ let trait_self_idx = trait_self_param_idx(db.upcast(), func.into());
+ for pred in predicates {
+ let pred = pred.skip_binders().skip_binders();
+
+ if matches!(pred, WhereClause::TypeOutlives(_)) {
+ continue;
+ }
+
+ // Allow `impl AutoTrait` predicates
+ if let WhereClause::Implemented(TraitRef { trait_id, substitution }) = pred {
+ let trait_data = db.trait_data(from_chalk_trait_id(*trait_id));
+ if trait_data.is_auto
+ && substitution
+ .as_slice(Interner)
+ .first()
+ .and_then(|arg| arg.ty(Interner))
+ .and_then(|ty| ty.bound_var(Interner))
+ .is_some_and(|b| {
+ b.debruijn == DebruijnIndex::ONE && Some(b.index) == trait_self_idx
+ })
+ {
+ continue;
+ }
+ }
+
+ if contains_illegal_self_type_reference(
+ db,
+ func.into(),
+ trait_,
+ pred,
+ DebruijnIndex::ONE,
+ AllowSelfProjection::Yes,
+ ) {
+ cb(MethodViolationCode::WhereClauseReferencesSelf)?;
+ break;
+ }
+ }
+
+ ControlFlow::Continue(())
+}
+
+fn receiver_is_dispatchable(
+ db: &dyn HirDatabase,
+ trait_: TraitId,
+ func: FunctionId,
+ sig: &Binders<CallableSig>,
+) -> bool {
+ let Some(trait_self_idx) = trait_self_param_idx(db.upcast(), func.into()) else {
+ return false;
+ };
+
+ // `self: Self` can't be dispatched on, but this is already considered object safe.
+ // See rustc's comment on https://github.com/rust-lang/rust/blob/3f121b9461cce02a703a0e7e450568849dfaa074/compiler/rustc_trait_selection/src/traits/object_safety.rs#L433-L437
+ if sig
+ .skip_binders()
+ .params()
+ .first()
+ .and_then(|receiver| receiver.bound_var(Interner))
+ .is_some_and(|b| {
+ b == BoundVar { debruijn: DebruijnIndex::INNERMOST, index: trait_self_idx }
+ })
+ {
+ return true;
+ }
+
+ let placeholder_subst = generics(db.upcast(), func.into()).placeholder_subst(db);
+
+ let substituted_sig = sig.clone().substitute(Interner, &placeholder_subst);
+ let Some(receiver_ty) = substituted_sig.params().first() else {
+ return false;
+ };
+
+ let krate = func.module(db.upcast()).krate();
+ let traits = (
+ db.lang_item(krate, LangItem::Unsize).and_then(|it| it.as_trait()),
+ db.lang_item(krate, LangItem::DispatchFromDyn).and_then(|it| it.as_trait()),
+ );
+ let (Some(unsize_did), Some(dispatch_from_dyn_did)) = traits else {
+ return false;
+ };
+
+ // Type `U`
+ let unsized_self_ty =
+ TyKind::Scalar(chalk_ir::Scalar::Uint(chalk_ir::UintTy::U32)).intern(Interner);
+ // `Receiver[Self => U]`
+ let Some(unsized_receiver_ty) = receiver_for_self_ty(db, func, unsized_self_ty.clone()) else {
+ return false;
+ };
+
+ let self_ty = placeholder_subst.as_slice(Interner)[trait_self_idx].assert_ty_ref(Interner);
+ let unsized_predicate = WhereClause::Implemented(TraitRef {
+ trait_id: to_chalk_trait_id(unsize_did),
+ substitution: Substitution::from_iter(Interner, [self_ty.clone(), unsized_self_ty.clone()]),
+ });
+ let trait_predicate = WhereClause::Implemented(TraitRef {
+ trait_id: to_chalk_trait_id(trait_),
+ substitution: Substitution::from_iter(
+ Interner,
+ std::iter::once(unsized_self_ty.clone().cast(Interner))
+ .chain(placeholder_subst.iter(Interner).skip(1).cloned()),
+ ),
+ });
+
+ let generic_predicates = &*db.generic_predicates(func.into());
+
+ let clauses = std::iter::once(unsized_predicate)
+ .chain(std::iter::once(trait_predicate))
+ .chain(generic_predicates.iter().map(|pred| {
+ pred.clone().substitute(Interner, &placeholder_subst).into_value_and_skipped_binders().0
+ }))
+ .map(|pred| {
+ pred.cast::<chalk_ir::ProgramClause<Interner>>(Interner).into_from_env_clause(Interner)
+ });
+ let env = chalk_ir::Environment::new(Interner).add_clauses(Interner, clauses);
+
+ let obligation = WhereClause::Implemented(TraitRef {
+ trait_id: to_chalk_trait_id(dispatch_from_dyn_did),
+ substitution: Substitution::from_iter(Interner, [receiver_ty.clone(), unsized_receiver_ty]),
+ });
+ let goal = GoalData::DomainGoal(chalk_ir::DomainGoal::Holds(obligation)).intern(Interner);
+
+ let in_env = chalk_ir::InEnvironment::new(&env, goal);
+
+ let mut table = chalk_solve::infer::InferenceTable::<Interner>::new();
+ let canonicalized = table.canonicalize(Interner, in_env);
+ let solution = db.trait_solve(krate, None, canonicalized.quantified);
+
+ matches!(solution, Some(Solution::Unique(_)))
+}
+
+fn receiver_for_self_ty(db: &dyn HirDatabase, func: FunctionId, ty: Ty) -> Option<Ty> {
+ let generics = generics(db.upcast(), func.into());
+ let trait_self_idx = trait_self_param_idx(db.upcast(), func.into())?;
+ let subst = generics.placeholder_subst(db);
+ let subst = Substitution::from_iter(
+ Interner,
+ subst.iter(Interner).enumerate().map(|(idx, arg)| {
+ if idx == trait_self_idx {
+ ty.clone().cast(Interner)
+ } else {
+ arg.clone()
+ }
+ }),
+ );
+ let sig = callable_item_sig(db, func.into());
+ let sig = sig.substitute(Interner, &subst);
+ sig.params_and_return.first().cloned()
+}
+
+fn contains_illegal_impl_trait_in_trait(
+ db: &dyn HirDatabase,
+ sig: &Binders<CallableSig>,
+) -> Option<MethodViolationCode> {
+ struct OpaqueTypeCollector(FxHashSet<OpaqueTyId>);
+
+ impl TypeVisitor<Interner> for OpaqueTypeCollector {
+ type BreakTy = ();
+
+ fn as_dyn(&mut self) -> &mut dyn TypeVisitor<Interner, BreakTy = Self::BreakTy> {
+ self
+ }
+
+ fn interner(&self) -> Interner {
+ Interner
+ }
+
+ fn visit_ty(&mut self, ty: &Ty, outer_binder: DebruijnIndex) -> ControlFlow<Self::BreakTy> {
+ if let TyKind::OpaqueType(opaque_ty_id, _) = ty.kind(Interner) {
+ self.0.insert(*opaque_ty_id);
+ }
+ ty.super_visit_with(self.as_dyn(), outer_binder)
+ }
+ }
+
+ let ret = sig.skip_binders().ret();
+ let mut visitor = OpaqueTypeCollector(FxHashSet::default());
+ ret.visit_with(visitor.as_dyn(), DebruijnIndex::INNERMOST);
+
+ // Since we haven't implemented RPITIT in proper way like rustc yet,
+ // just check whether `ret` contains RPIT for now
+ for opaque_ty in visitor.0 {
+ let impl_trait_id = db.lookup_intern_impl_trait_id(opaque_ty.into());
+ if matches!(impl_trait_id, ImplTraitId::ReturnTypeImplTrait(..)) {
+ return Some(MethodViolationCode::ReferencesImplTraitInTrait);
+ }
+ }
+
+ None
+}
+
+#[cfg(test)]
+mod tests;