Unnamed repository; edit this file 'description' to name the repository.
Diffstat (limited to 'crates/hir-ty/src/infer/expr.rs')
| -rw-r--r-- | crates/hir-ty/src/infer/expr.rs | 266 |
1 files changed, 78 insertions, 188 deletions
diff --git a/crates/hir-ty/src/infer/expr.rs b/crates/hir-ty/src/infer/expr.rs index e494fa813f..d80ea71674 100644 --- a/crates/hir-ty/src/infer/expr.rs +++ b/crates/hir-ty/src/infer/expr.rs @@ -11,7 +11,7 @@ use hir_def::{ InlineAsmKind, LabelId, Literal, Pat, PatId, RecordSpread, Statement, UnaryOp, }, resolver::ValueNs, - signatures::{FunctionSignature, VariantFields}, + signatures::VariantFields, }; use hir_def::{FunctionId, hir::ClosureKind}; use hir_expand::name::Name; @@ -24,9 +24,7 @@ use syntax::ast::RangeOp; use tracing::debug; use crate::{ - Adjust, Adjustment, CallableDefId, DeclContext, DeclOrigin, Rawness, - autoderef::InferenceContextAutoderef, - consteval, + Adjust, Adjustment, CallableDefId, DeclContext, DeclOrigin, Rawness, consteval, generics::generics, infer::{ AllowTwoPhase, BreakableKind, coerce::CoerceMany, find_continuable, @@ -35,8 +33,7 @@ use crate::{ lower::{GenericPredicates, lower_mutability}, method_resolution::{self, CandidateId, MethodCallee, MethodError}, next_solver::{ - ClauseKind, ErrorGuaranteed, FnSig, GenericArg, GenericArgs, TraitRef, Ty, TyKind, - TypeError, + ClauseKind, FnSig, GenericArg, GenericArgs, TraitRef, Ty, TyKind, TypeError, infer::{ BoundRegionConversionTime, InferOk, traits::{Obligation, ObligationCause}, @@ -44,7 +41,6 @@ use crate::{ obligation_ctxt::ObligationCtxt, util::clauses_as_obligations, }, - traits::FnTrait, }; use super::{ @@ -1180,68 +1176,6 @@ impl<'db> InferenceContext<'_, 'db> { } oprnd_t } - - pub(crate) fn write_fn_trait_method_resolution( - &mut self, - fn_x: FnTrait, - derefed_callee: Ty<'db>, - adjustments: &mut Vec<Adjustment>, - callee_ty: Ty<'db>, - params: &[Ty<'db>], - tgt_expr: ExprId, - ) { - match fn_x { - FnTrait::FnOnce | FnTrait::AsyncFnOnce => (), - FnTrait::FnMut | FnTrait::AsyncFnMut => { - if let TyKind::Ref(lt, inner, Mutability::Mut) = derefed_callee.kind() { - if adjustments - .last() - .map(|it| matches!(it.kind, Adjust::Borrow(_))) - .unwrap_or(true) - { - // prefer reborrow to move - adjustments - .push(Adjustment { kind: Adjust::Deref(None), target: inner.store() }); - adjustments.push(Adjustment::borrow( - self.interner(), - Mutability::Mut, - inner, - lt, - )) - } - } else { - adjustments.push(Adjustment::borrow( - self.interner(), - Mutability::Mut, - derefed_callee, - self.table.next_region_var(), - )); - } - } - FnTrait::Fn | FnTrait::AsyncFn => { - if !matches!(derefed_callee.kind(), TyKind::Ref(_, _, Mutability::Not)) { - adjustments.push(Adjustment::borrow( - self.interner(), - Mutability::Not, - derefed_callee, - self.table.next_region_var(), - )); - } - } - } - let Some(trait_) = fn_x.get_id(self.lang_items) else { - return; - }; - let trait_data = trait_.trait_items(self.db); - if let Some(func) = trait_data.method_by_name(&fn_x.method_name()) { - let subst = GenericArgs::new_from_slice(&[ - callee_ty.into(), - Ty::new_tup(self.interner(), params).into(), - ]); - self.write_method_resolution(tgt_expr, func, subst); - } - } - fn infer_expr_array(&mut self, array: &Array, expected: &Expectation<'db>) -> Ty<'db> { let elem_ty = match expected .to_option(&mut self.table) @@ -1658,76 +1592,6 @@ impl<'db> InferenceContext<'_, 'db> { MethodCallee { def_id, args, sig } } - fn infer_call( - &mut self, - tgt_expr: ExprId, - callee: ExprId, - args: &[ExprId], - expected: &Expectation<'db>, - ) -> Ty<'db> { - let callee_ty = self.infer_expr(callee, &Expectation::none(), ExprIsRead::Yes); - let callee_ty = self.table.try_structurally_resolve_type(callee_ty); - let interner = self.interner(); - let mut derefs = InferenceContextAutoderef::new_from_inference_context(self, callee_ty); - let (res, derefed_callee) = loop { - let Some((callee_deref_ty, _)) = derefs.next() else { - break (None, callee_ty); - }; - if let Some(res) = derefs.ctx().table.callable_sig(callee_deref_ty, args.len()) { - break (Some(res), callee_deref_ty); - } - }; - // if the function is unresolved, we use is_varargs=true to - // suppress the arg count diagnostic here - let is_varargs = derefed_callee.callable_sig(interner).is_some_and(|sig| sig.c_variadic()) - || res.is_none(); - let (param_tys, ret_ty) = match res { - Some((func, params, ret_ty)) => { - let infer_ok = derefs.adjust_steps_as_infer_ok(); - let mut adjustments = self.table.register_infer_ok(infer_ok); - if let Some(fn_x) = func { - self.write_fn_trait_method_resolution( - fn_x, - derefed_callee, - &mut adjustments, - callee_ty, - ¶ms, - tgt_expr, - ); - } - if let TyKind::Closure(c, _) = self.table.resolve_completely(callee_ty).kind() { - self.add_current_closure_dependency(c.into()); - self.deferred_closures.entry(c.into()).or_default().push(( - derefed_callee, - callee_ty, - params.clone(), - tgt_expr, - )); - } - self.write_expr_adj(callee, adjustments.into_boxed_slice()); - (params, ret_ty) - } - None => { - self.push_diagnostic(InferenceDiagnostic::ExpectedFunction { - call_expr: tgt_expr, - found: callee_ty.store(), - }); - (Vec::new(), Ty::new_error(interner, ErrorGuaranteed)) - } - }; - let indices_to_skip = self.check_legacy_const_generics(derefed_callee, args); - self.check_call( - tgt_expr, - args, - callee_ty, - ¶m_tys, - ret_ty, - &indices_to_skip, - is_varargs, - expected, - ) - } - fn check_call( &mut self, tgt_expr: ExprId, @@ -1749,6 +1613,7 @@ impl<'db> InferenceContext<'_, 'db> { args, indices_to_skip, is_varargs, + TupleArgumentsFlag::DontTupleArguments, ); ret_ty } @@ -1879,13 +1744,22 @@ impl<'db> InferenceContext<'_, 'db> { }; let ret_ty = sig.output(); - self.check_call_arguments(tgt_expr, param_tys, ret_ty, expected, args, &[], sig.c_variadic); + self.check_call_arguments( + tgt_expr, + param_tys, + ret_ty, + expected, + args, + &[], + sig.c_variadic, + TupleArgumentsFlag::DontTupleArguments, + ); ret_ty } /// Generic function that factors out common logic from function calls, /// method calls and overloaded operators. - pub(in super::super) fn check_call_arguments( + pub(super) fn check_call_arguments( &mut self, call_expr: ExprId, // Types (as defined in the *signature* of the target function) @@ -1898,6 +1772,8 @@ impl<'db> InferenceContext<'_, 'db> { skip_indices: &[u32], // Whether the function is variadic, for example when imported from C c_variadic: bool, + // Whether the arguments have been bundled in a tuple (ex: closures) + tuple_arguments: TupleArgumentsFlag, ) { let formal_input_tys: Vec<_> = formal_input_tys .iter() @@ -1949,12 +1825,46 @@ impl<'db> InferenceContext<'_, 'db> { }) .unwrap_or_default(); + // If the arguments should be wrapped in a tuple (ex: closures), unwrap them here + let (formal_input_tys, expected_input_tys) = + if tuple_arguments == TupleArgumentsFlag::TupleArguments { + let tuple_type = self.table.structurally_resolve_type(formal_input_tys[0]); + match tuple_type.kind() { + // We expected a tuple and got a tuple + TyKind::Tuple(arg_types) => { + // Argument length differs + if arg_types.len() != provided_args.len() { + // FIXME: Emit an error. + } + let expected_input_tys = match expected_input_tys { + Some(expected_input_tys) => match expected_input_tys.first() { + Some(ty) => match ty.kind() { + TyKind::Tuple(tys) => Some(tys.iter().collect()), + _ => None, + }, + None => None, + }, + None => None, + }; + (arg_types.iter().collect(), expected_input_tys) + } + _ => { + // Otherwise, there's a mismatch, so clear out what we're expecting, and set + // our input types to err_args so we don't blow up the error messages + // FIXME: Emit an error. + (vec![self.types.types.error; provided_args.len()], None) + } + } + } else { + (formal_input_tys.to_vec(), expected_input_tys) + }; + // If there are no external expectations at the call site, just use the types from the function defn - let expected_input_tys = if let Some(expected_input_tys) = &expected_input_tys { + let expected_input_tys = if let Some(expected_input_tys) = expected_input_tys { assert_eq!(expected_input_tys.len(), formal_input_tys.len()); expected_input_tys } else { - &formal_input_tys + formal_input_tys.clone() }; let minimum_input_count = expected_input_tys.len(); @@ -2127,51 +2037,6 @@ impl<'db> InferenceContext<'_, 'db> { } } - /// Returns the argument indices to skip. - fn check_legacy_const_generics(&mut self, callee: Ty<'db>, args: &[ExprId]) -> Box<[u32]> { - let (func, _subst) = match callee.kind() { - TyKind::FnDef(callable, subst) => { - let func = match callable.0 { - CallableDefId::FunctionId(f) => f, - _ => return Default::default(), - }; - (func, subst) - } - _ => return Default::default(), - }; - - let data = FunctionSignature::of(self.db, func); - let Some(legacy_const_generics_indices) = data.legacy_const_generics_indices(self.db, func) - else { - return Default::default(); - }; - let mut legacy_const_generics_indices = Box::<[u32]>::from(legacy_const_generics_indices); - - // only use legacy const generics if the param count matches with them - if data.params.len() + legacy_const_generics_indices.len() != args.len() { - if args.len() <= data.params.len() { - return Default::default(); - } else { - // there are more parameters than there should be without legacy - // const params; use them - legacy_const_generics_indices.sort_unstable(); - return legacy_const_generics_indices; - } - } - - // check legacy const parameters - for arg_idx in legacy_const_generics_indices.iter().copied() { - if arg_idx >= args.len() as u32 { - continue; - } - let expected = Expectation::none(); // FIXME use actual const ty, when that is lowered correctly - self.infer_expr(args[arg_idx as usize], &expected, ExprIsRead::Yes); - // FIXME: evaluate and unify with the const - } - legacy_const_generics_indices.sort_unstable(); - legacy_const_generics_indices - } - pub(super) fn with_breakable_ctx<T>( &mut self, kind: BreakableKind, @@ -2187,3 +2052,28 @@ impl<'db> InferenceContext<'_, 'db> { (if ctx.may_break { ctx.coerce.map(|ctx| ctx.complete(self)) } else { None }, res) } } + +/// Controls whether the arguments are tupled. This is used for the call +/// operator. +/// +/// Tupling means that all call-side arguments are packed into a tuple and +/// passed as a single parameter. For example, if tupling is enabled, this +/// function: +/// ``` +/// fn f(x: (isize, isize)) {} +/// ``` +/// Can be called as: +/// ```ignore UNSOLVED (can this be done in user code?) +/// # fn f(x: (isize, isize)) {} +/// f(1, 2); +/// ``` +/// Instead of: +/// ``` +/// # fn f(x: (isize, isize)) {} +/// f((1, 2)); +/// ``` +#[derive(Copy, Clone, Eq, PartialEq)] +pub(super) enum TupleArgumentsFlag { + DontTupleArguments, + TupleArguments, +} |