Unnamed repository; edit this file 'description' to name the repository.
Diffstat (limited to 'crates/hir-ty/src/opaques.rs')
-rw-r--r--crates/hir-ty/src/opaques.rs199
1 files changed, 199 insertions, 0 deletions
diff --git a/crates/hir-ty/src/opaques.rs b/crates/hir-ty/src/opaques.rs
new file mode 100644
index 0000000000..8531f24377
--- /dev/null
+++ b/crates/hir-ty/src/opaques.rs
@@ -0,0 +1,199 @@
+//! Handling of opaque types, detection of defining scope and hidden type.
+
+use hir_def::{
+ AssocItemId, AssocItemLoc, DefWithBodyId, FunctionId, HasModule, ItemContainerId, TypeAliasId,
+};
+use hir_expand::name::Name;
+use la_arena::ArenaMap;
+use rustc_type_ir::inherent::Ty as _;
+use syntax::ast;
+use triomphe::Arc;
+
+use crate::{
+ ImplTraitId,
+ db::{HirDatabase, InternedOpaqueTyId},
+ lower::{ImplTraitIdx, ImplTraits},
+ next_solver::{
+ DbInterner, EarlyBinder, ErrorGuaranteed, SolverDefId, Ty, TypingMode,
+ infer::{DbInternerInferExt, traits::ObligationCause},
+ obligation_ctxt::ObligationCtxt,
+ },
+};
+
+pub(crate) fn opaque_types_defined_by(
+ db: &dyn HirDatabase,
+ def_id: DefWithBodyId,
+ result: &mut Vec<SolverDefId>,
+) {
+ if let DefWithBodyId::FunctionId(func) = def_id {
+ // A function may define its own RPITs.
+ extend_with_opaques(
+ db,
+ db.return_type_impl_traits(func),
+ |opaque_idx| ImplTraitId::ReturnTypeImplTrait(func, opaque_idx),
+ result,
+ );
+ }
+
+ let extend_with_taits = |type_alias| {
+ extend_with_opaques(
+ db,
+ db.type_alias_impl_traits(type_alias),
+ |opaque_idx| ImplTraitId::TypeAliasImplTrait(type_alias, opaque_idx),
+ result,
+ );
+ };
+
+ // Collect opaques from assoc items.
+ let extend_with_atpit_from_assoc_items = |assoc_items: &[(Name, AssocItemId)]| {
+ assoc_items
+ .iter()
+ .filter_map(|&(_, assoc_id)| match assoc_id {
+ AssocItemId::TypeAliasId(it) => Some(it),
+ AssocItemId::FunctionId(_) | AssocItemId::ConstId(_) => None,
+ })
+ .for_each(extend_with_taits);
+ };
+ let extend_with_atpit_from_container = |container| match container {
+ ItemContainerId::ImplId(impl_id) => {
+ if db.impl_signature(impl_id).target_trait.is_some() {
+ extend_with_atpit_from_assoc_items(&impl_id.impl_items(db).items);
+ }
+ }
+ ItemContainerId::TraitId(trait_id) => {
+ extend_with_atpit_from_assoc_items(&trait_id.trait_items(db).items);
+ }
+ _ => {}
+ };
+ match def_id {
+ DefWithBodyId::ConstId(id) => extend_with_atpit_from_container(id.loc(db).container),
+ DefWithBodyId::FunctionId(id) => extend_with_atpit_from_container(id.loc(db).container),
+ DefWithBodyId::StaticId(_) | DefWithBodyId::VariantId(_) => {}
+ }
+
+ // FIXME: Collect opaques from `#[define_opaque]`.
+
+ fn extend_with_opaques<'db>(
+ db: &'db dyn HirDatabase,
+ opaques: Option<Arc<EarlyBinder<'db, ImplTraits<'db>>>>,
+ mut make_impl_trait: impl FnMut(ImplTraitIdx<'db>) -> ImplTraitId<'db>,
+ result: &mut Vec<SolverDefId>,
+ ) {
+ if let Some(opaques) = opaques {
+ for (opaque_idx, _) in (*opaques).as_ref().skip_binder().impl_traits.iter() {
+ let opaque_id = InternedOpaqueTyId::new(db, make_impl_trait(opaque_idx));
+ result.push(opaque_id.into());
+ }
+ }
+ }
+}
+
+// These are firewall queries to prevent drawing dependencies between infers:
+
+#[salsa::tracked(returns(ref), unsafe(non_update_return_type))]
+pub(crate) fn rpit_hidden_types<'db>(
+ db: &'db dyn HirDatabase,
+ function: FunctionId,
+) -> ArenaMap<ImplTraitIdx<'db>, EarlyBinder<'db, Ty<'db>>> {
+ let infer = db.infer(function.into());
+ let mut result = ArenaMap::new();
+ for (opaque, hidden_type) in infer.return_position_impl_trait_types(db) {
+ result.insert(opaque, EarlyBinder::bind(hidden_type));
+ }
+ result.shrink_to_fit();
+ result
+}
+
+#[salsa::tracked(returns(ref), unsafe(non_update_return_type))]
+pub(crate) fn tait_hidden_types<'db>(
+ db: &'db dyn HirDatabase,
+ type_alias: TypeAliasId,
+) -> ArenaMap<ImplTraitIdx<'db>, EarlyBinder<'db, Ty<'db>>> {
+ let loc = type_alias.loc(db);
+ let module = loc.module(db);
+ let interner = DbInterner::new_with(db, Some(module.krate()), module.containing_block());
+ let infcx = interner.infer_ctxt().build(TypingMode::non_body_analysis());
+ let mut ocx = ObligationCtxt::new(&infcx);
+ let cause = ObligationCause::dummy();
+ let param_env = db.trait_environment(type_alias.into()).env;
+
+ let defining_bodies = tait_defining_bodies(db, &loc);
+
+ let taits_count = db
+ .type_alias_impl_traits(type_alias)
+ .map_or(0, |taits| (*taits).as_ref().skip_binder().impl_traits.len());
+
+ let mut result = ArenaMap::with_capacity(taits_count);
+ for defining_body in defining_bodies {
+ let infer = db.infer(defining_body);
+ for (&opaque, &hidden_type) in &infer.type_of_opaque {
+ let ImplTraitId::TypeAliasImplTrait(opaque_owner, opaque_idx) = opaque.loc(db) else {
+ continue;
+ };
+ if opaque_owner != type_alias {
+ continue;
+ }
+ // In the presence of errors, we attempt to create a unified type from all
+ // types. rustc doesn't do that, but this should improve the experience.
+ let hidden_type = infcx.insert_type_vars(hidden_type);
+ match result.entry(opaque_idx) {
+ la_arena::Entry::Vacant(entry) => {
+ entry.insert(EarlyBinder::bind(hidden_type));
+ }
+ la_arena::Entry::Occupied(entry) => {
+ _ = ocx.eq(&cause, param_env, entry.get().instantiate_identity(), hidden_type);
+ }
+ }
+ }
+ }
+
+ _ = ocx.try_evaluate_obligations();
+
+ // Fill missing entries.
+ for idx in 0..taits_count {
+ let idx = la_arena::Idx::from_raw(la_arena::RawIdx::from_u32(idx as u32));
+ match result.entry(idx) {
+ la_arena::Entry::Vacant(entry) => {
+ entry.insert(EarlyBinder::bind(Ty::new_error(interner, ErrorGuaranteed)));
+ }
+ la_arena::Entry::Occupied(mut entry) => {
+ *entry.get_mut() = entry.get().map_bound(|hidden_type| {
+ infcx.resolve_vars_if_possible(hidden_type).replace_infer_with_error(interner)
+ });
+ }
+ }
+ }
+
+ result
+}
+
+fn tait_defining_bodies(
+ db: &dyn HirDatabase,
+ loc: &AssocItemLoc<ast::TypeAlias>,
+) -> Vec<DefWithBodyId> {
+ let from_assoc_items = |assoc_items: &[(Name, AssocItemId)]| {
+ // Associated Type Position Impl Trait.
+ assoc_items
+ .iter()
+ .filter_map(|&(_, assoc_id)| match assoc_id {
+ AssocItemId::FunctionId(it) => Some(it.into()),
+ AssocItemId::ConstId(it) => Some(it.into()),
+ AssocItemId::TypeAliasId(_) => None,
+ })
+ .collect()
+ };
+ match loc.container {
+ ItemContainerId::ImplId(impl_id) => {
+ if db.impl_signature(impl_id).target_trait.is_some() {
+ return from_assoc_items(&impl_id.impl_items(db).items);
+ }
+ }
+ ItemContainerId::TraitId(trait_id) => {
+ return from_assoc_items(&trait_id.trait_items(db).items);
+ }
+ _ => {}
+ }
+
+ // FIXME: Support general TAITs, or decisively decide not to.
+ Vec::new()
+}