Unnamed repository; edit this file 'description' to name the repository.
Diffstat (limited to 'crates/hir-ty/src/infer/closure.rs')
| -rw-r--r-- | crates/hir-ty/src/infer/closure.rs | 214 |
1 files changed, 130 insertions, 84 deletions
diff --git a/crates/hir-ty/src/infer/closure.rs b/crates/hir-ty/src/infer/closure.rs index 1d5d8dd13e..89ebd2b21d 100644 --- a/crates/hir-ty/src/infer/closure.rs +++ b/crates/hir-ty/src/infer/closure.rs @@ -2,54 +2,52 @@ pub(crate) mod analysis; -use std::ops::ControlFlow; -use std::{iter, mem}; +use std::{iter, mem, ops::ControlFlow}; use hir_def::{ TraitId, hir::{ClosureKind, ExprId, PatId}, - lang_item::LangItem, type_ref::TypeRefId, }; use rustc_type_ir::{ - ClosureArgs, ClosureArgsParts, CoroutineArgs, CoroutineArgsParts, Interner, TypeSuperVisitable, - TypeVisitable, TypeVisitableExt, TypeVisitor, + ClosureArgs, ClosureArgsParts, CoroutineArgs, CoroutineArgsParts, CoroutineClosureArgs, + CoroutineClosureArgsParts, Interner, TypeSuperVisitable, TypeVisitable, TypeVisitableExt, + TypeVisitor, inherent::{BoundExistentialPredicates, GenericArgs as _, IntoKind, SliceLike, Ty as _}, }; use tracing::debug; -use crate::traits::FnTrait; use crate::{ FnAbi, db::{InternedClosure, InternedCoroutine}, infer::{BreakableKind, Diverges, coerce::CoerceMany}, next_solver::{ - AliasTy, Binder, ClauseKind, DbInterner, ErrorGuaranteed, FnSig, GenericArgs, PolyFnSig, - PolyProjectionPredicate, Predicate, PredicateKind, SolverDefId, Ty, TyKind, + AliasTy, Binder, BoundRegionKind, BoundVarKind, BoundVarKinds, ClauseKind, DbInterner, + ErrorGuaranteed, FnSig, GenericArgs, PolyFnSig, PolyProjectionPredicate, Predicate, + PredicateKind, SolverDefId, Ty, TyKind, abi::Safety, infer::{ - BoundRegionConversionTime, DefineOpaqueTypes, InferOk, InferResult, + BoundRegionConversionTime, InferOk, InferResult, traits::{ObligationCause, PredicateObligations}, }, - mapping::{ChalkToNextSolver, NextSolverToChalk}, - util::explicit_item_bounds, }, + traits::FnTrait, }; use super::{Expectation, InferenceContext}; #[derive(Debug)] -struct ClosureSignatures<'tcx> { +struct ClosureSignatures<'db> { /// The signature users of the closure see. - bound_sig: PolyFnSig<'tcx>, + bound_sig: PolyFnSig<'db>, /// The signature within the function body. /// This mostly differs in the sense that lifetimes are now early bound and any /// opaque types from the signature expectation are overridden in case there are /// explicit hidden types written by the user in the closure signature. - liberated_sig: FnSig<'tcx>, + liberated_sig: FnSig<'db>, } -impl<'db> InferenceContext<'db> { +impl<'db> InferenceContext<'_, 'db> { pub(super) fn infer_closure( &mut self, body: ExprId, @@ -58,15 +56,13 @@ impl<'db> InferenceContext<'db> { arg_types: &[Option<TypeRefId>], closure_kind: ClosureKind, tgt_expr: ExprId, - expected: &Expectation, - ) -> crate::Ty { + expected: &Expectation<'db>, + ) -> Ty<'db> { assert_eq!(args.len(), arg_types.len()); - let interner = self.table.interner; + let interner = self.interner(); let (expected_sig, expected_kind) = match expected.to_option(&mut self.table) { - Some(expected_ty) => { - self.deduce_closure_signature(expected_ty.to_nextsolver(interner), closure_kind) - } + Some(expected_ty) => self.deduce_closure_signature(expected_ty, closure_kind), None => (None, None), }; @@ -76,22 +72,21 @@ impl<'db> InferenceContext<'db> { let sig_ty = Ty::new_fn_ptr(interner, bound_sig); let parent_args = GenericArgs::identity_for_item(interner, self.generic_def.into()); + // FIXME: Make this an infer var and infer it later. + let tupled_upvars_ty = self.types.unit; let (id, ty, resume_yield_tys) = match closure_kind { ClosureKind::Coroutine(_) => { let yield_ty = self.table.next_ty_var(); - let resume_ty = liberated_sig - .inputs() - .get(0) - .unwrap_or(self.result.standard_types.unit.to_nextsolver(interner)); + let resume_ty = liberated_sig.inputs().get(0).unwrap_or(self.types.unit); // FIXME: Infer the upvars later. let parts = CoroutineArgsParts { parent_args, - kind_ty: Ty::new_unit(interner), + kind_ty: self.types.unit, resume_ty, yield_ty, return_ty: body_ret_ty, - tupled_upvars_ty: Ty::new_unit(interner), + tupled_upvars_ty, }; let coroutine_id = @@ -102,20 +97,14 @@ impl<'db> InferenceContext<'db> { CoroutineArgs::new(interner, parts).args, ); - ( - None, - coroutine_ty, - Some((resume_ty.to_chalk(interner), yield_ty.to_chalk(interner))), - ) + (None, coroutine_ty, Some((resume_ty, yield_ty))) } - // FIXME(next-solver): `ClosureKind::Async` should really be a separate arm that creates a `CoroutineClosure`. - // But for now we treat it as a closure. - ClosureKind::Closure | ClosureKind::Async => { + ClosureKind::Closure => { let closure_id = self.db.intern_closure(InternedClosure(self.owner, tgt_expr)); match expected_kind { Some(kind) => { self.result.closure_info.insert( - closure_id.into(), + closure_id, ( Vec::new(), match kind { @@ -128,7 +117,7 @@ impl<'db> InferenceContext<'db> { } None => {} }; - // FIXME: Infer the kind and the upvars later when needed. + // FIXME: Infer the kind later if needed. let parts = ClosureArgsParts { parent_args, closure_kind_ty: Ty::from_closure_kind( @@ -136,7 +125,7 @@ impl<'db> InferenceContext<'db> { expected_kind.unwrap_or(rustc_type_ir::ClosureKind::Fn), ), closure_sig_as_fn_ptr_ty: sig_ty, - tupled_upvars_ty: Ty::new_unit(interner), + tupled_upvars_ty, }; let closure_ty = Ty::new_closure( interner, @@ -147,17 +136,72 @@ impl<'db> InferenceContext<'db> { self.add_current_closure_dependency(closure_id); (Some(closure_id), closure_ty, None) } + ClosureKind::Async => { + // async closures always return the type ascribed after the `->` (if present), + // and yield `()`. + let bound_return_ty = bound_sig.skip_binder().output(); + let bound_yield_ty = self.types.unit; + // rustc uses a special lang item type for the resume ty. I don't believe this can cause us problems. + let resume_ty = self.types.unit; + + // FIXME: Infer the kind later if needed. + let closure_kind_ty = Ty::from_closure_kind( + interner, + expected_kind.unwrap_or(rustc_type_ir::ClosureKind::Fn), + ); + + // FIXME: Infer captures later. + // `for<'env> fn() -> ()`, for no captures. + let coroutine_captures_by_ref_ty = Ty::new_fn_ptr( + interner, + Binder::bind_with_vars( + interner.mk_fn_sig([], self.types.unit, false, Safety::Safe, FnAbi::Rust), + BoundVarKinds::new_from_iter( + interner, + [BoundVarKind::Region(BoundRegionKind::ClosureEnv)], + ), + ), + ); + let closure_args = CoroutineClosureArgs::new( + interner, + CoroutineClosureArgsParts { + parent_args, + closure_kind_ty, + signature_parts_ty: Ty::new_fn_ptr( + interner, + bound_sig.map_bound(|sig| { + interner.mk_fn_sig( + [ + resume_ty, + Ty::new_tup_from_iter(interner, sig.inputs().iter()), + ], + Ty::new_tup(interner, &[bound_yield_ty, bound_return_ty]), + sig.c_variadic, + sig.safety, + sig.abi, + ) + }), + ), + tupled_upvars_ty, + coroutine_captures_by_ref_ty, + }, + ); + + let coroutine_id = + self.db.intern_coroutine(InternedCoroutine(self.owner, tgt_expr)).into(); + (None, Ty::new_coroutine_closure(interner, coroutine_id, closure_args.args), None) + } }; // Now go through the argument patterns for (arg_pat, arg_ty) in args.iter().zip(bound_sig.skip_binder().inputs()) { - self.infer_top_pat(*arg_pat, &arg_ty.to_chalk(interner), None); + self.infer_top_pat(*arg_pat, arg_ty, None); } // FIXME: lift these out into a struct let prev_diverges = mem::replace(&mut self.diverges, Diverges::Maybe); let prev_closure = mem::replace(&mut self.current_closure, id); - let prev_ret_ty = mem::replace(&mut self.return_ty, body_ret_ty.to_chalk(interner)); + let prev_ret_ty = mem::replace(&mut self.return_ty, body_ret_ty); let prev_ret_coercion = self.return_coercion.replace(CoerceMany::new(body_ret_ty)); let prev_resume_yield_tys = mem::replace(&mut self.resume_yield_tys, resume_yield_tys); @@ -171,15 +215,16 @@ impl<'db> InferenceContext<'db> { self.current_closure = prev_closure; self.resume_yield_tys = prev_resume_yield_tys; - ty.to_chalk(interner) + ty } fn fn_trait_kind_from_def_id(&self, trait_id: TraitId) -> Option<rustc_type_ir::ClosureKind> { - let lang_item = self.db.lang_attr(trait_id.into())?; - match lang_item { - LangItem::Fn => Some(rustc_type_ir::ClosureKind::Fn), - LangItem::FnMut => Some(rustc_type_ir::ClosureKind::FnMut), - LangItem::FnOnce => Some(rustc_type_ir::ClosureKind::FnOnce), + match trait_id { + _ if self.lang_items.Fn == Some(trait_id) => Some(rustc_type_ir::ClosureKind::Fn), + _ if self.lang_items.FnMut == Some(trait_id) => Some(rustc_type_ir::ClosureKind::FnMut), + _ if self.lang_items.FnOnce == Some(trait_id) => { + Some(rustc_type_ir::ClosureKind::FnOnce) + } _ => None, } } @@ -188,11 +233,14 @@ impl<'db> InferenceContext<'db> { &self, trait_id: TraitId, ) -> Option<rustc_type_ir::ClosureKind> { - let lang_item = self.db.lang_attr(trait_id.into())?; - match lang_item { - LangItem::AsyncFn => Some(rustc_type_ir::ClosureKind::Fn), - LangItem::AsyncFnMut => Some(rustc_type_ir::ClosureKind::FnMut), - LangItem::AsyncFnOnce => Some(rustc_type_ir::ClosureKind::FnOnce), + match trait_id { + _ if self.lang_items.AsyncFn == Some(trait_id) => Some(rustc_type_ir::ClosureKind::Fn), + _ if self.lang_items.AsyncFnMut == Some(trait_id) => { + Some(rustc_type_ir::ClosureKind::FnMut) + } + _ if self.lang_items.AsyncFnOnce == Some(trait_id) => { + Some(rustc_type_ir::ClosureKind::FnOnce) + } _ => None, } } @@ -209,14 +257,15 @@ impl<'db> InferenceContext<'db> { .deduce_closure_signature_from_predicates( expected_ty, closure_kind, - explicit_item_bounds(self.table.interner, def_id) - .iter_instantiated(self.table.interner, args) + def_id + .expect_opaque_ty() + .predicates(self.db) + .iter_instantiated_copied(self.interner(), args.as_slice()) .map(|clause| clause.as_predicate()), ), TyKind::Dynamic(object_type, ..) => { let sig = object_type.projection_bounds().into_iter().find_map(|pb| { - let pb = - pb.with_self_ty(self.table.interner, Ty::new_unit(self.table.interner)); + let pb = pb.with_self_ty(self.interner(), Ty::new_unit(self.interner())); self.deduce_sig_from_projection(closure_kind, pb) }); let kind = object_type @@ -226,7 +275,7 @@ impl<'db> InferenceContext<'db> { } TyKind::Infer(rustc_type_ir::TyVar(vid)) => self .deduce_closure_signature_from_predicates( - Ty::new_var(self.table.interner, self.table.infer_ctxt.root_var(vid)), + Ty::new_var(self.interner(), self.table.infer_ctxt.root_var(vid)), closure_kind, self.table.obligations_for_self_ty(vid).into_iter().map(|obl| obl.predicate), ), @@ -251,7 +300,7 @@ impl<'db> InferenceContext<'db> { let mut expected_kind = None; for pred in rustc_type_ir::elaborate::elaborate( - self.table.interner, + self.interner(), // Reverse the obligations here, since `elaborate_*` uses a stack, // and we want to keep inference generally in the same order of // the registered obligations. @@ -313,20 +362,20 @@ impl<'db> InferenceContext<'db> { // even though the normalized form may not name `expected_ty`. However, this matches the existing // behaviour of the old solver and would be technically a breaking change to fix. let generalized_fnptr_sig = self.table.next_ty_var(); - let inferred_fnptr_sig = Ty::new_fn_ptr(self.table.interner, inferred_sig); + let inferred_fnptr_sig = Ty::new_fn_ptr(self.interner(), inferred_sig); // FIXME: Report diagnostics. _ = self .table .infer_ctxt - .at(&ObligationCause::new(), self.table.param_env) - .eq(DefineOpaqueTypes::Yes, inferred_fnptr_sig, generalized_fnptr_sig) + .at(&ObligationCause::new(), self.table.trait_env.env) + .eq(inferred_fnptr_sig, generalized_fnptr_sig) .map(|infer_ok| self.table.register_infer_ok(infer_ok)); let resolved_sig = self.table.infer_ctxt.resolve_vars_if_possible(generalized_fnptr_sig); if resolved_sig.visit_with(&mut MentionsTy { expected_ty }).is_continue() { - expected_sig = Some(resolved_sig.fn_sig(self.table.interner)); + expected_sig = Some(resolved_sig.fn_sig(self.interner())); } } else if inferred_sig.visit_with(&mut MentionsTy { expected_ty }).is_continue() { expected_sig = inferred_sig; @@ -339,7 +388,7 @@ impl<'db> InferenceContext<'db> { // many viable options, so pick the most restrictive. let trait_def_id = match bound_predicate.skip_binder() { PredicateKind::Clause(ClauseKind::Projection(data)) => { - Some(data.projection_term.trait_def_id(self.table.interner).0) + Some(data.projection_term.trait_def_id(self.interner()).0) } PredicateKind::Clause(ClauseKind::Trait(data)) => Some(data.def_id().0), _ => None, @@ -387,21 +436,20 @@ impl<'db> InferenceContext<'db> { projection: PolyProjectionPredicate<'db>, ) -> Option<PolyFnSig<'db>> { let SolverDefId::TypeAliasId(def_id) = projection.item_def_id() else { unreachable!() }; - let lang_item = self.db.lang_attr(def_id.into()); // For now, we only do signature deduction based off of the `Fn` and `AsyncFn` traits, // for closures and async closures, respectively. match closure_kind { - ClosureKind::Closure if lang_item == Some(LangItem::FnOnceOutput) => { + ClosureKind::Closure if Some(def_id) == self.lang_items.FnOnceOutput => { self.extract_sig_from_projection(projection) } - ClosureKind::Async if lang_item == Some(LangItem::AsyncFnOnceOutput) => { + ClosureKind::Async if Some(def_id) == self.lang_items.AsyncFnOnceOutput => { self.extract_sig_from_projection(projection) } // It's possible we've passed the closure to a (somewhat out-of-fashion) // `F: FnOnce() -> Fut, Fut: Future<Output = T>` style bound. Let's still // guide inference here, since it's beneficial for the user. - ClosureKind::Async if lang_item == Some(LangItem::FnOnceOutput) => { + ClosureKind::Async if Some(def_id) == self.lang_items.FnOnceOutput => { self.extract_sig_from_projection_and_future_bound(projection) } _ => None, @@ -427,7 +475,7 @@ impl<'db> InferenceContext<'db> { let ret_param_ty = projection.skip_binder().term.expect_type(); debug!(?ret_param_ty); - let sig = projection.rebind(self.table.interner.mk_fn_sig( + let sig = projection.rebind(self.interner().mk_fn_sig( input_tys, ret_param_ty, false, @@ -492,7 +540,7 @@ impl<'db> InferenceContext<'db> { && let ret_projection = bound.predicate.kind().rebind(ret_projection) && let Some(ret_projection) = ret_projection.no_bound_vars() && let SolverDefId::TypeAliasId(assoc_type) = ret_projection.def_id() - && self.db.lang_attr(assoc_type.into()) == Some(LangItem::FutureOutput) + && Some(assoc_type) == self.lang_items.FutureOutput { return_ty = Some(ret_projection.term.expect_type()); break; @@ -515,7 +563,7 @@ impl<'db> InferenceContext<'db> { // that does not misuse a `FnSig` type, but that can be done separately. let return_ty = return_ty.unwrap_or_else(|| self.table.next_ty_var()); - let sig = projection.rebind(self.table.interner.mk_fn_sig( + let sig = projection.rebind(self.interner().mk_fn_sig( input_tys, return_ty, false, @@ -619,7 +667,7 @@ impl<'db> InferenceContext<'db> { // in this binder we are creating. assert!(!expected_sig.skip_binder().has_vars_bound_above(rustc_type_ir::INNERMOST)); let bound_sig = expected_sig.map_bound(|sig| { - self.table.interner.mk_fn_sig( + self.interner().mk_fn_sig( sig.inputs(), sig.output(), sig.c_variadic, @@ -631,7 +679,7 @@ impl<'db> InferenceContext<'db> { // `deduce_expectations_from_expected_type` introduces // late-bound lifetimes defined elsewhere, which we now // anonymize away, so as not to confuse the user. - let bound_sig = self.table.interner.anonymize_bound_vars(bound_sig); + let bound_sig = self.interner().anonymize_bound_vars(bound_sig); let closure_sigs = self.closure_sigs(bound_sig); @@ -703,19 +751,17 @@ impl<'db> InferenceContext<'db> { let cause = ObligationCause::new(); let InferOk { value: (), obligations } = table .infer_ctxt - .at(&cause, table.param_env) - .eq(DefineOpaqueTypes::Yes, expected_ty, supplied_ty)?; + .at(&cause, table.trait_env.env) + .eq(expected_ty, supplied_ty)?; all_obligations.extend(obligations); } let supplied_output_ty = supplied_sig.output(); let cause = ObligationCause::new(); - let InferOk { value: (), obligations } = - table.infer_ctxt.at(&cause, table.param_env).eq( - DefineOpaqueTypes::Yes, - expected_sigs.liberated_sig.output(), - supplied_output_ty, - )?; + let InferOk { value: (), obligations } = table + .infer_ctxt + .at(&cause, table.trait_env.env) + .eq(expected_sigs.liberated_sig.output(), supplied_output_ty)?; all_obligations.extend(obligations); let inputs = supplied_sig @@ -723,7 +769,7 @@ impl<'db> InferenceContext<'db> { .into_iter() .map(|ty| table.infer_ctxt.resolve_vars_if_possible(ty)); - expected_sigs.liberated_sig = table.interner.mk_fn_sig( + expected_sigs.liberated_sig = table.interner().mk_fn_sig( inputs, supplied_output_ty, expected_sigs.liberated_sig.c_variadic, @@ -744,12 +790,12 @@ impl<'db> InferenceContext<'db> { decl_inputs: &[Option<TypeRefId>], decl_output: Option<TypeRefId>, ) -> PolyFnSig<'db> { - let interner = self.table.interner; + let interner = self.interner(); let supplied_return = match decl_output { Some(output) => { let output = self.make_body_ty(output); - self.process_user_written_ty(output).to_nextsolver(interner) + self.process_user_written_ty(output) } None => self.table.next_ty_var(), }; @@ -757,7 +803,7 @@ impl<'db> InferenceContext<'db> { let supplied_arguments = decl_inputs.iter().map(|&input| match input { Some(input) => { let input = self.make_body_ty(input); - self.process_user_written_ty(input).to_nextsolver(interner) + self.process_user_written_ty(input) } None => self.table.next_ty_var(), }); @@ -779,7 +825,7 @@ impl<'db> InferenceContext<'db> { decl_inputs: &[Option<TypeRefId>], decl_output: Option<TypeRefId>, ) -> PolyFnSig<'db> { - let interner = self.table.interner; + let interner = self.interner(); let err_ty = Ty::new_error(interner, ErrorGuaranteed); if let Some(output) = decl_output { |