Unnamed repository; edit this file 'description' to name the repository.
Diffstat (limited to 'crates/hir-ty/src/object_safety.rs')
| -rw-r--r-- | crates/hir-ty/src/object_safety.rs | 612 |
1 files changed, 612 insertions, 0 deletions
diff --git a/crates/hir-ty/src/object_safety.rs b/crates/hir-ty/src/object_safety.rs new file mode 100644 index 0000000000..a4c6626855 --- /dev/null +++ b/crates/hir-ty/src/object_safety.rs @@ -0,0 +1,612 @@ +//! Compute the object-safety of a trait + +use std::ops::ControlFlow; + +use chalk_ir::{ + cast::Cast, + visit::{TypeSuperVisitable, TypeVisitable, TypeVisitor}, + DebruijnIndex, +}; +use chalk_solve::rust_ir::InlineBound; +use hir_def::{ + lang_item::LangItem, AssocItemId, ConstId, FunctionId, GenericDefId, HasModule, TraitId, + TypeAliasId, +}; +use rustc_hash::FxHashSet; +use smallvec::SmallVec; + +use crate::{ + all_super_traits, + db::HirDatabase, + from_assoc_type_id, from_chalk_trait_id, + generics::{generics, trait_self_param_idx}, + lower::callable_item_sig, + to_assoc_type_id, to_chalk_trait_id, + utils::elaborate_clause_supertraits, + AliasEq, AliasTy, Binders, BoundVar, CallableSig, GoalData, ImplTraitId, Interner, OpaqueTyId, + ProjectionTyExt, Solution, Substitution, TraitRef, Ty, TyKind, WhereClause, +}; + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub enum ObjectSafetyViolation { + SizedSelf, + SelfReferential, + Method(FunctionId, MethodViolationCode), + AssocConst(ConstId), + GAT(TypeAliasId), + // This doesn't exist in rustc, but added for better visualization + HasNonSafeSuperTrait(TraitId), +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub enum MethodViolationCode { + StaticMethod, + ReferencesSelfInput, + ReferencesSelfOutput, + ReferencesImplTraitInTrait, + AsyncFn, + WhereClauseReferencesSelf, + Generic, + UndispatchableReceiver, +} + +pub fn object_safety(db: &dyn HirDatabase, trait_: TraitId) -> Option<ObjectSafetyViolation> { + for super_trait in all_super_traits(db.upcast(), trait_).into_iter().skip(1).rev() { + if db.object_safety_of_trait(super_trait).is_some() { + return Some(ObjectSafetyViolation::HasNonSafeSuperTrait(super_trait)); + } + } + + db.object_safety_of_trait(trait_) +} + +pub fn object_safety_with_callback<F>( + db: &dyn HirDatabase, + trait_: TraitId, + cb: &mut F, +) -> ControlFlow<()> +where + F: FnMut(ObjectSafetyViolation) -> ControlFlow<()>, +{ + for super_trait in all_super_traits(db.upcast(), trait_).into_iter().skip(1).rev() { + if db.object_safety_of_trait(super_trait).is_some() { + cb(ObjectSafetyViolation::HasNonSafeSuperTrait(trait_))?; + } + } + + object_safety_of_trait_with_callback(db, trait_, cb) +} + +pub fn object_safety_of_trait_with_callback<F>( + db: &dyn HirDatabase, + trait_: TraitId, + cb: &mut F, +) -> ControlFlow<()> +where + F: FnMut(ObjectSafetyViolation) -> ControlFlow<()>, +{ + // Check whether this has a `Sized` bound + if generics_require_sized_self(db, trait_.into()) { + cb(ObjectSafetyViolation::SizedSelf)?; + } + + // Check if there exist bounds that referencing self + if predicates_reference_self(db, trait_) { + cb(ObjectSafetyViolation::SelfReferential)?; + } + if bounds_reference_self(db, trait_) { + cb(ObjectSafetyViolation::SelfReferential)?; + } + + // rustc checks for non-lifetime binders here, but we don't support HRTB yet + + let trait_data = db.trait_data(trait_); + for (_, assoc_item) in &trait_data.items { + object_safety_violation_for_assoc_item(db, trait_, *assoc_item, cb)?; + } + + ControlFlow::Continue(()) +} + +pub fn object_safety_of_trait_query( + db: &dyn HirDatabase, + trait_: TraitId, +) -> Option<ObjectSafetyViolation> { + let mut res = None; + object_safety_of_trait_with_callback(db, trait_, &mut |osv| { + res = Some(osv); + ControlFlow::Break(()) + }); + + res +} + +fn generics_require_sized_self(db: &dyn HirDatabase, def: GenericDefId) -> bool { + let krate = def.module(db.upcast()).krate(); + let Some(sized) = db.lang_item(krate, LangItem::Sized).and_then(|l| l.as_trait()) else { + return false; + }; + + let Some(trait_self_param_idx) = trait_self_param_idx(db.upcast(), def) else { + return false; + }; + + let predicates = &*db.generic_predicates(def); + let predicates = predicates.iter().map(|p| p.skip_binders().skip_binders().clone()); + elaborate_clause_supertraits(db, predicates).any(|pred| match pred { + WhereClause::Implemented(trait_ref) => { + if from_chalk_trait_id(trait_ref.trait_id) == sized { + if let TyKind::BoundVar(it) = + *trait_ref.self_type_parameter(Interner).kind(Interner) + { + // Since `generic_predicates` is `Binder<Binder<..>>`, the `DebrujinIndex` of + // self-parameter is `1` + return it + .index_if_bound_at(DebruijnIndex::ONE) + .is_some_and(|idx| idx == trait_self_param_idx); + } + } + false + } + _ => false, + }) +} + +// rustc gathers all the spans that references `Self` for error rendering, +// but we don't have good way to render such locations. +// So, just return single boolean value for existence of such `Self` reference +fn predicates_reference_self(db: &dyn HirDatabase, trait_: TraitId) -> bool { + db.generic_predicates(trait_.into()) + .iter() + .any(|pred| predicate_references_self(db, trait_, pred, AllowSelfProjection::No)) +} + +// Same as the above, `predicates_reference_self` +fn bounds_reference_self(db: &dyn HirDatabase, trait_: TraitId) -> bool { + let trait_data = db.trait_data(trait_); + trait_data + .items + .iter() + .filter_map(|(_, it)| match *it { + AssocItemId::TypeAliasId(id) => { + let assoc_ty_id = to_assoc_type_id(id); + let assoc_ty_data = db.associated_ty_data(assoc_ty_id); + Some(assoc_ty_data) + } + _ => None, + }) + .any(|assoc_ty_data| { + assoc_ty_data.binders.skip_binders().bounds.iter().any(|bound| { + let def = from_assoc_type_id(assoc_ty_data.id).into(); + match bound.skip_binders() { + InlineBound::TraitBound(it) => it.args_no_self.iter().any(|arg| { + contains_illegal_self_type_reference( + db, + def, + trait_, + arg, + DebruijnIndex::ONE, + AllowSelfProjection::Yes, + ) + }), + InlineBound::AliasEqBound(it) => it.parameters.iter().any(|arg| { + contains_illegal_self_type_reference( + db, + def, + trait_, + arg, + DebruijnIndex::ONE, + AllowSelfProjection::Yes, + ) + }), + } + }) + }) +} + +#[derive(Clone, Copy)] +enum AllowSelfProjection { + Yes, + No, +} + +fn predicate_references_self( + db: &dyn HirDatabase, + trait_: TraitId, + predicate: &Binders<Binders<WhereClause>>, + allow_self_projection: AllowSelfProjection, +) -> bool { + match predicate.skip_binders().skip_binders() { + WhereClause::Implemented(trait_ref) => { + trait_ref.substitution.iter(Interner).skip(1).any(|arg| { + contains_illegal_self_type_reference( + db, + trait_.into(), + trait_, + arg, + DebruijnIndex::ONE, + allow_self_projection, + ) + }) + } + WhereClause::AliasEq(AliasEq { alias: AliasTy::Projection(proj), .. }) => { + proj.substitution.iter(Interner).skip(1).any(|arg| { + contains_illegal_self_type_reference( + db, + trait_.into(), + trait_, + arg, + DebruijnIndex::ONE, + allow_self_projection, + ) + }) + } + _ => false, + } +} + +fn contains_illegal_self_type_reference<T: TypeVisitable<Interner>>( + db: &dyn HirDatabase, + def: GenericDefId, + trait_: TraitId, + t: &T, + outer_binder: DebruijnIndex, + allow_self_projection: AllowSelfProjection, +) -> bool { + let Some(trait_self_param_idx) = trait_self_param_idx(db.upcast(), def) else { + return false; + }; + struct IllegalSelfTypeVisitor<'a> { + db: &'a dyn HirDatabase, + trait_: TraitId, + super_traits: Option<SmallVec<[TraitId; 4]>>, + trait_self_param_idx: usize, + allow_self_projection: AllowSelfProjection, + } + impl<'a> TypeVisitor<Interner> for IllegalSelfTypeVisitor<'a> { + type BreakTy = (); + + fn as_dyn(&mut self) -> &mut dyn TypeVisitor<Interner, BreakTy = Self::BreakTy> { + self + } + + fn interner(&self) -> Interner { + Interner + } + + fn visit_ty(&mut self, ty: &Ty, outer_binder: DebruijnIndex) -> ControlFlow<Self::BreakTy> { + match ty.kind(Interner) { + TyKind::BoundVar(BoundVar { debruijn, index }) => { + if *debruijn == outer_binder && *index == self.trait_self_param_idx { + ControlFlow::Break(()) + } else { + ty.super_visit_with(self.as_dyn(), outer_binder) + } + } + TyKind::Alias(AliasTy::Projection(proj)) => match self.allow_self_projection { + AllowSelfProjection::Yes => { + let trait_ = proj.trait_(self.db); + if self.super_traits.is_none() { + self.super_traits = + Some(all_super_traits(self.db.upcast(), self.trait_)); + } + if self.super_traits.as_ref().is_some_and(|s| s.contains(&trait_)) { + ControlFlow::Continue(()) + } else { + ty.super_visit_with(self.as_dyn(), outer_binder) + } + } + AllowSelfProjection::No => ty.super_visit_with(self.as_dyn(), outer_binder), + }, + _ => ty.super_visit_with(self.as_dyn(), outer_binder), + } + } + + fn visit_const( + &mut self, + constant: &chalk_ir::Const<Interner>, + outer_binder: DebruijnIndex, + ) -> std::ops::ControlFlow<Self::BreakTy> { + constant.data(Interner).ty.super_visit_with(self.as_dyn(), outer_binder) + } + } + + let mut visitor = IllegalSelfTypeVisitor { + db, + trait_, + super_traits: None, + trait_self_param_idx, + allow_self_projection, + }; + t.visit_with(visitor.as_dyn(), outer_binder).is_break() +} + +fn object_safety_violation_for_assoc_item<F>( + db: &dyn HirDatabase, + trait_: TraitId, + item: AssocItemId, + cb: &mut F, +) -> ControlFlow<()> +where + F: FnMut(ObjectSafetyViolation) -> ControlFlow<()>, +{ + // Any item that has a `Self : Sized` requisite is otherwise + // exempt from the regulations. + if generics_require_sized_self(db, item.into()) { + return ControlFlow::Continue(()); + } + + match item { + AssocItemId::ConstId(it) => cb(ObjectSafetyViolation::AssocConst(it)), + AssocItemId::FunctionId(it) => { + virtual_call_violations_for_method(db, trait_, it, &mut |mvc| { + cb(ObjectSafetyViolation::Method(it, mvc)) + }) + } + AssocItemId::TypeAliasId(it) => { + let def_map = db.crate_def_map(trait_.krate(db.upcast())); + if def_map.is_unstable_feature_enabled(&intern::sym::generic_associated_type_extended) { + ControlFlow::Continue(()) + } else { + let generic_params = db.generic_params(item.into()); + if !generic_params.is_empty() { + cb(ObjectSafetyViolation::GAT(it)) + } else { + ControlFlow::Continue(()) + } + } + } + } +} + +fn virtual_call_violations_for_method<F>( + db: &dyn HirDatabase, + trait_: TraitId, + func: FunctionId, + cb: &mut F, +) -> ControlFlow<()> +where + F: FnMut(MethodViolationCode) -> ControlFlow<()>, +{ + let func_data = db.function_data(func); + if !func_data.has_self_param() { + cb(MethodViolationCode::StaticMethod)?; + } + + if func_data.is_async() { + cb(MethodViolationCode::AsyncFn)?; + } + + let sig = callable_item_sig(db, func.into()); + if sig.skip_binders().params().iter().skip(1).any(|ty| { + contains_illegal_self_type_reference( + db, + func.into(), + trait_, + ty, + DebruijnIndex::INNERMOST, + AllowSelfProjection::Yes, + ) + }) { + cb(MethodViolationCode::ReferencesSelfInput)?; + } + + if contains_illegal_self_type_reference( + db, + func.into(), + trait_, + sig.skip_binders().ret(), + DebruijnIndex::INNERMOST, + AllowSelfProjection::Yes, + ) { + cb(MethodViolationCode::ReferencesSelfOutput)?; + } + + if !func_data.is_async() { + if let Some(mvc) = contains_illegal_impl_trait_in_trait(db, &sig) { + cb(mvc)?; + } + } + + let generic_params = db.generic_params(func.into()); + if generic_params.len_type_or_consts() > 0 { + cb(MethodViolationCode::Generic)?; + } + + if func_data.has_self_param() && !receiver_is_dispatchable(db, trait_, func, &sig) { + cb(MethodViolationCode::UndispatchableReceiver)?; + } + + let predicates = &*db.generic_predicates_without_parent(func.into()); + let trait_self_idx = trait_self_param_idx(db.upcast(), func.into()); + for pred in predicates { + let pred = pred.skip_binders().skip_binders(); + + if matches!(pred, WhereClause::TypeOutlives(_)) { + continue; + } + + // Allow `impl AutoTrait` predicates + if let WhereClause::Implemented(TraitRef { trait_id, substitution }) = pred { + let trait_data = db.trait_data(from_chalk_trait_id(*trait_id)); + if trait_data.is_auto + && substitution + .as_slice(Interner) + .first() + .and_then(|arg| arg.ty(Interner)) + .and_then(|ty| ty.bound_var(Interner)) + .is_some_and(|b| { + b.debruijn == DebruijnIndex::ONE && Some(b.index) == trait_self_idx + }) + { + continue; + } + } + + if contains_illegal_self_type_reference( + db, + func.into(), + trait_, + pred, + DebruijnIndex::ONE, + AllowSelfProjection::Yes, + ) { + cb(MethodViolationCode::WhereClauseReferencesSelf)?; + break; + } + } + + ControlFlow::Continue(()) +} + +fn receiver_is_dispatchable( + db: &dyn HirDatabase, + trait_: TraitId, + func: FunctionId, + sig: &Binders<CallableSig>, +) -> bool { + let Some(trait_self_idx) = trait_self_param_idx(db.upcast(), func.into()) else { + return false; + }; + + // `self: Self` can't be dispatched on, but this is already considered object safe. + // See rustc's comment on https://github.com/rust-lang/rust/blob/3f121b9461cce02a703a0e7e450568849dfaa074/compiler/rustc_trait_selection/src/traits/object_safety.rs#L433-L437 + if sig + .skip_binders() + .params() + .first() + .and_then(|receiver| receiver.bound_var(Interner)) + .is_some_and(|b| { + b == BoundVar { debruijn: DebruijnIndex::INNERMOST, index: trait_self_idx } + }) + { + return true; + } + + let placeholder_subst = generics(db.upcast(), func.into()).placeholder_subst(db); + + let substituted_sig = sig.clone().substitute(Interner, &placeholder_subst); + let Some(receiver_ty) = substituted_sig.params().first() else { + return false; + }; + + let krate = func.module(db.upcast()).krate(); + let traits = ( + db.lang_item(krate, LangItem::Unsize).and_then(|it| it.as_trait()), + db.lang_item(krate, LangItem::DispatchFromDyn).and_then(|it| it.as_trait()), + ); + let (Some(unsize_did), Some(dispatch_from_dyn_did)) = traits else { + return false; + }; + + // Type `U` + let unsized_self_ty = + TyKind::Scalar(chalk_ir::Scalar::Uint(chalk_ir::UintTy::U32)).intern(Interner); + // `Receiver[Self => U]` + let Some(unsized_receiver_ty) = receiver_for_self_ty(db, func, unsized_self_ty.clone()) else { + return false; + }; + + let self_ty = placeholder_subst.as_slice(Interner)[trait_self_idx].assert_ty_ref(Interner); + let unsized_predicate = WhereClause::Implemented(TraitRef { + trait_id: to_chalk_trait_id(unsize_did), + substitution: Substitution::from_iter(Interner, [self_ty.clone(), unsized_self_ty.clone()]), + }); + let trait_predicate = WhereClause::Implemented(TraitRef { + trait_id: to_chalk_trait_id(trait_), + substitution: Substitution::from_iter( + Interner, + std::iter::once(unsized_self_ty.clone().cast(Interner)) + .chain(placeholder_subst.iter(Interner).skip(1).cloned()), + ), + }); + + let generic_predicates = &*db.generic_predicates(func.into()); + + let clauses = std::iter::once(unsized_predicate) + .chain(std::iter::once(trait_predicate)) + .chain(generic_predicates.iter().map(|pred| { + pred.clone().substitute(Interner, &placeholder_subst).into_value_and_skipped_binders().0 + })) + .map(|pred| { + pred.cast::<chalk_ir::ProgramClause<Interner>>(Interner).into_from_env_clause(Interner) + }); + let env = chalk_ir::Environment::new(Interner).add_clauses(Interner, clauses); + + let obligation = WhereClause::Implemented(TraitRef { + trait_id: to_chalk_trait_id(dispatch_from_dyn_did), + substitution: Substitution::from_iter(Interner, [receiver_ty.clone(), unsized_receiver_ty]), + }); + let goal = GoalData::DomainGoal(chalk_ir::DomainGoal::Holds(obligation)).intern(Interner); + + let in_env = chalk_ir::InEnvironment::new(&env, goal); + + let mut table = chalk_solve::infer::InferenceTable::<Interner>::new(); + let canonicalized = table.canonicalize(Interner, in_env); + let solution = db.trait_solve(krate, None, canonicalized.quantified); + + matches!(solution, Some(Solution::Unique(_))) +} + +fn receiver_for_self_ty(db: &dyn HirDatabase, func: FunctionId, ty: Ty) -> Option<Ty> { + let generics = generics(db.upcast(), func.into()); + let trait_self_idx = trait_self_param_idx(db.upcast(), func.into())?; + let subst = generics.placeholder_subst(db); + let subst = Substitution::from_iter( + Interner, + subst.iter(Interner).enumerate().map(|(idx, arg)| { + if idx == trait_self_idx { + ty.clone().cast(Interner) + } else { + arg.clone() + } + }), + ); + let sig = callable_item_sig(db, func.into()); + let sig = sig.substitute(Interner, &subst); + sig.params_and_return.first().cloned() +} + +fn contains_illegal_impl_trait_in_trait( + db: &dyn HirDatabase, + sig: &Binders<CallableSig>, +) -> Option<MethodViolationCode> { + struct OpaqueTypeCollector(FxHashSet<OpaqueTyId>); + + impl TypeVisitor<Interner> for OpaqueTypeCollector { + type BreakTy = (); + + fn as_dyn(&mut self) -> &mut dyn TypeVisitor<Interner, BreakTy = Self::BreakTy> { + self + } + + fn interner(&self) -> Interner { + Interner + } + + fn visit_ty(&mut self, ty: &Ty, outer_binder: DebruijnIndex) -> ControlFlow<Self::BreakTy> { + if let TyKind::OpaqueType(opaque_ty_id, _) = ty.kind(Interner) { + self.0.insert(*opaque_ty_id); + } + ty.super_visit_with(self.as_dyn(), outer_binder) + } + } + + let ret = sig.skip_binders().ret(); + let mut visitor = OpaqueTypeCollector(FxHashSet::default()); + ret.visit_with(visitor.as_dyn(), DebruijnIndex::INNERMOST); + + // Since we haven't implemented RPITIT in proper way like rustc yet, + // just check whether `ret` contains RPIT for now + for opaque_ty in visitor.0 { + let impl_trait_id = db.lookup_intern_impl_trait_id(opaque_ty.into()); + if matches!(impl_trait_id, ImplTraitId::ReturnTypeImplTrait(..)) { + return Some(MethodViolationCode::ReferencesImplTraitInTrait); + } + } + + None +} + +#[cfg(test)] +mod tests; |