Unnamed repository; edit this file 'description' to name the repository.
Diffstat (limited to 'crates/hir-ty/src/infer/callee.rs')
-rw-r--r--crates/hir-ty/src/infer/callee.rs543
1 files changed, 543 insertions, 0 deletions
diff --git a/crates/hir-ty/src/infer/callee.rs b/crates/hir-ty/src/infer/callee.rs
new file mode 100644
index 0000000000..3d478912a3
--- /dev/null
+++ b/crates/hir-ty/src/infer/callee.rs
@@ -0,0 +1,543 @@
+//! Inference of calls.
+
+use std::iter;
+
+use intern::sym;
+use tracing::debug;
+
+use hir_def::{CallableDefId, hir::ExprId, signatures::FunctionSignature};
+use rustc_type_ir::{
+ InferTy, Interner,
+ inherent::{GenericArgs as _, IntoKind, Ty as _},
+};
+
+use crate::{
+ Adjust, Adjustment, AutoBorrow, FnAbi,
+ autoderef::{GeneralAutoderef, InferenceContextAutoderef},
+ infer::{
+ AllowTwoPhase, AutoBorrowMutability, Expectation, InferenceContext, InferenceDiagnostic,
+ expr::{ExprIsRead, TupleArgumentsFlag},
+ },
+ method_resolution::{MethodCallee, TreatNotYetDefinedOpaques},
+ next_solver::{
+ FnSig, Ty, TyKind,
+ infer::{BoundRegionConversionTime, traits::ObligationCause},
+ },
+};
+
+#[derive(Debug)]
+enum CallStep<'db> {
+ Builtin(Ty<'db>),
+ DeferredClosure(ExprId, FnSig<'db>),
+ /// Call overloading when callee implements one of the Fn* traits.
+ Overloaded(MethodCallee<'db>),
+}
+
+impl<'db> InferenceContext<'_, 'db> {
+ pub(crate) fn infer_call(
+ &mut self,
+ call_expr: ExprId,
+ callee_expr: ExprId,
+ arg_exprs: &[ExprId],
+ expected: &Expectation<'db>,
+ ) -> Ty<'db> {
+ let original_callee_ty = self.infer_expr_no_expect(callee_expr, ExprIsRead::Yes);
+
+ let expr_ty = self.table.try_structurally_resolve_type(original_callee_ty);
+
+ let mut autoderef = GeneralAutoderef::new_from_inference_context(self, expr_ty);
+ let mut result = None;
+ while result.is_none() && autoderef.next().is_some() {
+ result =
+ Self::try_overloaded_call_step(call_expr, callee_expr, arg_exprs, &mut autoderef);
+ }
+
+ // FIXME: rustc does some ABI checks here, but the ABI mapping is in rustc_target and we don't have access to that crate.
+
+ let obligations = autoderef.take_obligations();
+ self.table.register_predicates(obligations);
+
+ let output = match result {
+ None => {
+ // Check all of the arg expressions, but with no expectations
+ // since we don't have a signature to compare them to.
+ for &arg in arg_exprs {
+ self.infer_expr_no_expect(arg, ExprIsRead::Yes);
+ }
+
+ self.push_diagnostic(InferenceDiagnostic::ExpectedFunction {
+ call_expr,
+ found: original_callee_ty.store(),
+ });
+
+ self.types.types.error
+ }
+
+ Some(CallStep::Builtin(callee_ty)) => {
+ self.confirm_builtin_call(call_expr, callee_ty, arg_exprs, expected)
+ }
+
+ Some(CallStep::DeferredClosure(_def_id, fn_sig)) => {
+ self.confirm_deferred_closure_call(call_expr, arg_exprs, expected, fn_sig)
+ }
+
+ Some(CallStep::Overloaded(method_callee)) => {
+ self.confirm_overloaded_call(call_expr, arg_exprs, expected, method_callee)
+ }
+ };
+
+ // we must check that return type of called functions is WF:
+ self.table.register_wf_obligation(output.into(), ObligationCause::new());
+
+ output
+ }
+
+ fn try_overloaded_call_step(
+ call_expr: ExprId,
+ callee_expr: ExprId,
+ arg_exprs: &[ExprId],
+ autoderef: &mut InferenceContextAutoderef<'_, '_, 'db>,
+ ) -> Option<CallStep<'db>> {
+ let final_ty = autoderef.final_ty();
+ let adjusted_ty = autoderef.ctx().table.try_structurally_resolve_type(final_ty);
+
+ // If the callee is a function pointer or a closure, then we're all set.
+ match adjusted_ty.kind() {
+ TyKind::FnDef(..) | TyKind::FnPtr(..) => {
+ let adjust_steps = autoderef.adjust_steps_as_infer_ok();
+ let adjustments =
+ autoderef.ctx().table.register_infer_ok(adjust_steps).into_boxed_slice();
+ autoderef.ctx().write_expr_adj(callee_expr, adjustments);
+ return Some(CallStep::Builtin(adjusted_ty));
+ }
+
+ // Check whether this is a call to a closure where we
+ // haven't yet decided on whether the closure is fn vs
+ // fnmut vs fnonce. If so, we have to defer further processing.
+ TyKind::Closure(def_id, args)
+ if autoderef.ctx().infcx().closure_kind(adjusted_ty).is_none() =>
+ {
+ let closure_sig = args.as_closure().sig();
+ let closure_sig = autoderef.ctx().infcx().instantiate_binder_with_fresh_vars(
+ BoundRegionConversionTime::FnCall,
+ closure_sig,
+ );
+ let adjust_steps = autoderef.adjust_steps_as_infer_ok();
+ let adjustments = autoderef.ctx().table.register_infer_ok(adjust_steps);
+ let def_id = def_id.0.loc(autoderef.ctx().db).1;
+ autoderef.ctx().record_deferred_call_resolution(
+ def_id,
+ DeferredCallResolution {
+ call_expr,
+ callee_expr,
+ closure_ty: adjusted_ty,
+ adjustments,
+ fn_sig: closure_sig,
+ },
+ );
+ return Some(CallStep::DeferredClosure(def_id, closure_sig));
+ }
+
+ // When calling a `CoroutineClosure` that is local to the body, we will
+ // not know what its `closure_kind` is yet. Instead, just fill in the
+ // signature with an infer var for the `tupled_upvars_ty` of the coroutine,
+ // and record a deferred call resolution which will constrain that var
+ // as part of `AsyncFn*` trait confirmation.
+ TyKind::CoroutineClosure(def_id, args)
+ if autoderef.ctx().infcx().closure_kind(adjusted_ty).is_none() =>
+ {
+ let closure_args = args.as_coroutine_closure();
+ let coroutine_closure_sig =
+ autoderef.ctx().infcx().instantiate_binder_with_fresh_vars(
+ BoundRegionConversionTime::FnCall,
+ closure_args.coroutine_closure_sig(),
+ );
+ let tupled_upvars_ty = autoderef.ctx().table.next_ty_var();
+ // We may actually receive a coroutine back whose kind is different
+ // from the closure that this dispatched from. This is because when
+ // we have no captures, we automatically implement `FnOnce`. This
+ // impl forces the closure kind to `FnOnce` i.e. `u8`.
+ let kind_ty = autoderef.ctx().table.next_ty_var();
+ let interner = autoderef.ctx().interner();
+ let call_sig = interner.mk_fn_sig(
+ [coroutine_closure_sig.tupled_inputs_ty],
+ coroutine_closure_sig.to_coroutine(
+ interner,
+ closure_args.parent_args(),
+ kind_ty,
+ interner.coroutine_for_closure(def_id),
+ tupled_upvars_ty,
+ ),
+ coroutine_closure_sig.c_variadic,
+ coroutine_closure_sig.safety,
+ coroutine_closure_sig.abi,
+ );
+ let adjust_steps = autoderef.adjust_steps_as_infer_ok();
+ let adjustments = autoderef.ctx().table.register_infer_ok(adjust_steps);
+ let def_id = def_id.0.loc(autoderef.ctx().db).1;
+ autoderef.ctx().record_deferred_call_resolution(
+ def_id,
+ DeferredCallResolution {
+ call_expr,
+ callee_expr,
+ closure_ty: adjusted_ty,
+ adjustments,
+ fn_sig: call_sig,
+ },
+ );
+ return Some(CallStep::DeferredClosure(def_id, call_sig));
+ }
+
+ // Hack: we know that there are traits implementing Fn for &F
+ // where F:Fn and so forth. In the particular case of types
+ // like `f: &mut FnMut()`, if there is a call `f()`, we would
+ // normally translate to `FnMut::call_mut(&mut f, ())`, but
+ // that winds up potentially requiring the user to mark their
+ // variable as `mut` which feels unnecessary and unexpected.
+ //
+ // fn foo(f: &mut impl FnMut()) { f() }
+ // ^ without this hack `f` would have to be declared as mutable
+ //
+ // The simplest fix by far is to just ignore this case and deref again,
+ // so we wind up with `FnMut::call_mut(&mut *f, ())`.
+ TyKind::Ref(..) if autoderef.step_count() == 0 => {
+ return None;
+ }
+
+ TyKind::Infer(InferTy::TyVar(vid))
+ // If we end up with an inference variable which is not the hidden type of
+ // an opaque, emit an error.
+ if !autoderef.ctx().infcx().has_opaques_with_sub_unified_hidden_type(vid) => {
+ autoderef
+ .ctx()
+ .type_must_be_known_at_this_point(callee_expr.into(), adjusted_ty);
+ return None;
+ }
+
+ TyKind::Error(_) => {
+ return None;
+ }
+
+ _ => {}
+ }
+
+ // Now, we look for the implementation of a Fn trait on the object's type.
+ // We first do it with the explicit instruction to look for an impl of
+ // `Fn<Tuple>`, with the tuple `Tuple` having an arity corresponding
+ // to the number of call parameters.
+ // If that fails (or_else branch), we try again without specifying the
+ // shape of the tuple (hence the None). This allows to detect an Fn trait
+ // is implemented, and use this information for diagnostic.
+ autoderef
+ .ctx()
+ .try_overloaded_call_traits(adjusted_ty, Some(arg_exprs))
+ .or_else(|| autoderef.ctx().try_overloaded_call_traits(adjusted_ty, None))
+ .map(|(autoref, method)| {
+ let adjustments = autoderef.adjust_steps_as_infer_ok();
+ let mut adjustments = autoderef.ctx().table.register_infer_ok(adjustments);
+ adjustments.extend(autoref);
+ autoderef.ctx().write_expr_adj(callee_expr, adjustments.into_boxed_slice());
+ CallStep::Overloaded(method)
+ })
+ }
+
+ fn try_overloaded_call_traits(
+ &mut self,
+ adjusted_ty: Ty<'db>,
+ opt_arg_exprs: Option<&[ExprId]>,
+ ) -> Option<(Option<Adjustment>, MethodCallee<'db>)> {
+ // HACK(async_closures): For async closures, prefer `AsyncFn*`
+ // over `Fn*`, since all async closures implement `FnOnce`, but
+ // choosing that over `AsyncFn`/`AsyncFnMut` would be more restrictive.
+ // For other callables, just prefer `Fn*` for perf reasons.
+ //
+ // The order of trait choices here is not that big of a deal,
+ // since it just guides inference (and our choice of autoref).
+ // Though in the future, I'd like typeck to choose:
+ // `Fn > AsyncFn > FnMut > AsyncFnMut > FnOnce > AsyncFnOnce`
+ // ...or *ideally*, we just have `LendingFn`/`LendingFnMut`, which
+ // would naturally unify these two trait hierarchies in the most
+ // general way.
+ let call_trait_choices = if self.shallow_resolve(adjusted_ty).is_coroutine_closure() {
+ [
+ (self.lang_items.AsyncFn, sym::async_call, true),
+ (self.lang_items.AsyncFnMut, sym::async_call_mut, true),
+ (self.lang_items.AsyncFnOnce, sym::async_call_once, false),
+ (self.lang_items.Fn, sym::call, true),
+ (self.lang_items.FnMut, sym::call_mut, true),
+ (self.lang_items.FnOnce, sym::call_once, false),
+ ]
+ } else {
+ [
+ (self.lang_items.Fn, sym::call, true),
+ (self.lang_items.FnMut, sym::call_mut, true),
+ (self.lang_items.FnOnce, sym::call_once, false),
+ (self.lang_items.AsyncFn, sym::async_call, true),
+ (self.lang_items.AsyncFnMut, sym::async_call_mut, true),
+ (self.lang_items.AsyncFnOnce, sym::async_call_once, false),
+ ]
+ };
+
+ // Try the options that are least restrictive on the caller first.
+ for (opt_trait_def_id, method_name, borrow) in call_trait_choices {
+ let Some(trait_def_id) = opt_trait_def_id else {
+ continue;
+ };
+
+ let opt_input_type = opt_arg_exprs.map(|arg_exprs| {
+ Ty::new_tup_from_iter(
+ self.interner(),
+ arg_exprs.iter().map(|_| self.table.next_ty_var()),
+ )
+ });
+
+ // We use `TreatNotYetDefinedOpaques::AsRigid` here so that if the `adjusted_ty`
+ // is `Box<impl FnOnce()>` we choose `FnOnce` instead of `Fn`.
+ //
+ // We try all the different call traits in order and choose the first
+ // one which may apply. So if we treat opaques as inference variables
+ // `Box<impl FnOnce()>: Fn` is considered ambiguous and chosen.
+ if let Some(ok) = self.table.lookup_method_for_operator(
+ ObligationCause::new(),
+ method_name,
+ trait_def_id,
+ adjusted_ty,
+ opt_input_type,
+ TreatNotYetDefinedOpaques::AsRigid,
+ ) {
+ let method = self.table.register_infer_ok(ok);
+ let mut autoref = None;
+ if borrow {
+ // Check for &self vs &mut self in the method signature. Since this is either
+ // the Fn or FnMut trait, it should be one of those.
+ let TyKind::Ref(_, _, mutbl) = method.sig.inputs_and_output.inputs()[0].kind()
+ else {
+ panic!("Expected `FnMut`/`Fn` to take receiver by-ref/by-mut")
+ };
+
+ // For initial two-phase borrow
+ // deployment, conservatively omit
+ // overloaded function call ops.
+ let mutbl = AutoBorrowMutability::new(mutbl, AllowTwoPhase::No);
+
+ autoref = Some(Adjustment {
+ kind: Adjust::Borrow(AutoBorrow::Ref(mutbl)),
+ target: method.sig.inputs_and_output.inputs()[0].store(),
+ });
+ }
+
+ return Some((autoref, method));
+ }
+ }
+
+ None
+ }
+
+ /// Returns the argument indices to skip.
+ fn check_legacy_const_generics(
+ &mut self,
+ callee: Option<CallableDefId>,
+ args: &[ExprId],
+ ) -> Box<[u32]> {
+ let func = match callee {
+ Some(CallableDefId::FunctionId(func)) => func,
+ _ => return Default::default(),
+ };
+
+ let data = FunctionSignature::of(self.db, func);
+ let Some(legacy_const_generics_indices) = data.legacy_const_generics_indices(self.db, func)
+ else {
+ return Default::default();
+ };
+ let mut legacy_const_generics_indices = Box::<[u32]>::from(legacy_const_generics_indices);
+
+ // only use legacy const generics if the param count matches with them
+ if data.params.len() + legacy_const_generics_indices.len() != args.len() {
+ if args.len() <= data.params.len() {
+ return Default::default();
+ } else {
+ // there are more parameters than there should be without legacy
+ // const params; use them
+ legacy_const_generics_indices.sort_unstable();
+ return legacy_const_generics_indices;
+ }
+ }
+
+ // check legacy const parameters
+ for arg_idx in legacy_const_generics_indices.iter().copied() {
+ if arg_idx >= args.len() as u32 {
+ continue;
+ }
+ let expected = Expectation::none(); // FIXME use actual const ty, when that is lowered correctly
+ self.infer_expr(args[arg_idx as usize], &expected, ExprIsRead::Yes);
+ // FIXME: evaluate and unify with the const
+ }
+ legacy_const_generics_indices.sort_unstable();
+ legacy_const_generics_indices
+ }
+
+ fn confirm_builtin_call(
+ &mut self,
+ call_expr: ExprId,
+ callee_ty: Ty<'db>,
+ arg_exprs: &[ExprId],
+ expected: &Expectation<'db>,
+ ) -> Ty<'db> {
+ let (fn_sig, def_id) = match callee_ty.kind() {
+ TyKind::FnDef(def_id, args) => {
+ let fn_sig =
+ self.db.callable_item_signature(def_id.0).instantiate(self.interner(), args);
+ (fn_sig, Some(def_id.0))
+ }
+
+ // FIXME(const_trait_impl): these arms should error because we can't enforce them
+ TyKind::FnPtr(sig_tys, hdr) => (sig_tys.with(hdr), None),
+
+ _ => unreachable!(),
+ };
+
+ // Replace any late-bound regions that appear in the function
+ // signature with region variables. We also have to
+ // renormalize the associated types at this point, since they
+ // previously appeared within a `Binder<>` and hence would not
+ // have been normalized before.
+ let fn_sig = self
+ .infcx()
+ .instantiate_binder_with_fresh_vars(BoundRegionConversionTime::FnCall, fn_sig);
+
+ let indices_to_skip = self.check_legacy_const_generics(def_id, arg_exprs);
+ self.check_call_arguments(
+ call_expr,
+ fn_sig.inputs(),
+ fn_sig.output(),
+ expected,
+ arg_exprs,
+ &indices_to_skip,
+ fn_sig.c_variadic,
+ TupleArgumentsFlag::DontTupleArguments,
+ );
+
+ if fn_sig.abi == FnAbi::RustCall
+ && let Some(ty) = fn_sig.inputs().last().copied()
+ && let Some(tuple_trait) = self.lang_items.Tuple
+ {
+ self.table.register_bound(ty, tuple_trait, ObligationCause::new());
+ self.require_type_is_sized(ty);
+ }
+
+ fn_sig.output()
+ }
+
+ fn confirm_deferred_closure_call(
+ &mut self,
+ call_expr: ExprId,
+ arg_exprs: &[ExprId],
+ expected: &Expectation<'db>,
+ fn_sig: FnSig<'db>,
+ ) -> Ty<'db> {
+ // `fn_sig` is the *signature* of the closure being called. We
+ // don't know the full details yet (`Fn` vs `FnMut` etc), but we
+ // do know the types expected for each argument and the return
+ // type.
+ self.check_call_arguments(
+ call_expr,
+ fn_sig.inputs(),
+ fn_sig.output(),
+ expected,
+ arg_exprs,
+ &[],
+ fn_sig.c_variadic,
+ TupleArgumentsFlag::TupleArguments,
+ );
+
+ fn_sig.output()
+ }
+
+ fn confirm_overloaded_call(
+ &mut self,
+ call_expr: ExprId,
+ arg_exprs: &[ExprId],
+ expected: &Expectation<'db>,
+ method: MethodCallee<'db>,
+ ) -> Ty<'db> {
+ self.check_call_arguments(
+ call_expr,
+ &method.sig.inputs()[1..],
+ method.sig.output(),
+ expected,
+ arg_exprs,
+ &[],
+ method.sig.c_variadic,
+ TupleArgumentsFlag::TupleArguments,
+ );
+
+ self.write_method_resolution(call_expr, method.def_id, method.args);
+
+ method.sig.output()
+ }
+}
+
+#[derive(Debug, Clone)]
+pub(crate) struct DeferredCallResolution<'db> {
+ call_expr: ExprId,
+ callee_expr: ExprId,
+ closure_ty: Ty<'db>,
+ adjustments: Vec<Adjustment>,
+ fn_sig: FnSig<'db>,
+}
+
+impl<'a, 'db> DeferredCallResolution<'db> {
+ pub(crate) fn resolve(self, ctx: &mut InferenceContext<'a, 'db>) {
+ debug!("DeferredCallResolution::resolve() {:?}", self);
+
+ // we should not be invoked until the closure kind has been
+ // determined by upvar inference
+ assert!(ctx.infcx().closure_kind(self.closure_ty).is_some());
+
+ // We may now know enough to figure out fn vs fnmut etc.
+ match ctx.try_overloaded_call_traits(self.closure_ty, None) {
+ Some((autoref, method_callee)) => {
+ // One problem is that when we get here, we are going
+ // to have a newly instantiated function signature
+ // from the call trait. This has to be reconciled with
+ // the older function signature we had before. In
+ // principle we *should* be able to fn_sigs(), but we
+ // can't because of the annoying need for a TypeTrace.
+ // (This always bites me, should find a way to
+ // refactor it.)
+ let method_sig = method_callee.sig;
+
+ debug!("attempt_resolution: method_callee={:?}", method_callee);
+
+ for (method_arg_ty, self_arg_ty) in
+ iter::zip(method_sig.inputs().iter().skip(1), self.fn_sig.inputs())
+ {
+ _ = ctx.demand_eqtype(self.call_expr.into(), *self_arg_ty, *method_arg_ty);
+ }
+
+ _ = ctx.demand_eqtype(
+ self.call_expr.into(),
+ method_sig.output(),
+ self.fn_sig.output(),
+ );
+
+ let mut adjustments = self.adjustments;
+ adjustments.extend(autoref);
+ ctx.write_expr_adj(self.callee_expr, adjustments.into_boxed_slice());
+
+ ctx.write_method_resolution(
+ self.call_expr,
+ method_callee.def_id,
+ method_callee.args,
+ );
+ }
+ None => {
+ assert!(
+ ctx.lang_items.FnOnce.is_none(),
+ "Expected to find a suitable `Fn`/`FnMut`/`FnOnce` implementation for `{:?}`",
+ self.closure_ty
+ )
+ }
+ }
+ }
+}