Unnamed repository; edit this file 'description' to name the repository.
Diffstat (limited to 'crates/hir-ty/src/infer/closure.rs')
-rw-r--r--crates/hir-ty/src/infer/closure.rs56
1 files changed, 46 insertions, 10 deletions
diff --git a/crates/hir-ty/src/infer/closure.rs b/crates/hir-ty/src/infer/closure.rs
index 800897c6fc..bd57ca8916 100644
--- a/crates/hir-ty/src/infer/closure.rs
+++ b/crates/hir-ty/src/infer/closure.rs
@@ -38,7 +38,7 @@ use crate::{
infer::{BreakableKind, CoerceMany, Diverges, coerce::CoerceNever},
make_binders,
mir::{BorrowKind, MirSpan, MutBorrowKind, ProjectionElem},
- to_chalk_trait_id,
+ to_assoc_type_id, to_chalk_trait_id,
traits::FnTrait,
utils::{self, elaborate_clause_supertraits},
};
@@ -245,7 +245,7 @@ impl InferenceContext<'_> {
}
fn deduce_closure_kind_from_predicate_clauses(
- &self,
+ &mut self,
expected_ty: &Ty,
clauses: impl DoubleEndedIterator<Item = WhereClause>,
closure_kind: ClosureKind,
@@ -378,7 +378,7 @@ impl InferenceContext<'_> {
}
fn deduce_sig_from_projection(
- &self,
+ &mut self,
closure_kind: ClosureKind,
projection_ty: &ProjectionTy,
projected_ty: &Ty,
@@ -392,13 +392,16 @@ impl InferenceContext<'_> {
// For now, we only do signature deduction based off of the `Fn` and `AsyncFn` traits,
// for closures and async closures, respectively.
- match closure_kind {
- ClosureKind::Closure | ClosureKind::Async
- if self.fn_trait_kind_from_trait_id(trait_).is_some() =>
- {
- self.extract_sig_from_projection(projection_ty, projected_ty)
- }
- _ => None,
+ let fn_trait_kind = self.fn_trait_kind_from_trait_id(trait_)?;
+ if !matches!(closure_kind, ClosureKind::Closure | ClosureKind::Async) {
+ return None;
+ }
+ if fn_trait_kind.is_async() {
+ // If the expected trait is `AsyncFn(...) -> X`, we don't know what the return type is,
+ // but we do know it must implement `Future<Output = X>`.
+ self.extract_async_fn_sig_from_projection(projection_ty, projected_ty)
+ } else {
+ self.extract_sig_from_projection(projection_ty, projected_ty)
}
}
@@ -424,6 +427,39 @@ impl InferenceContext<'_> {
)))
}
+ fn extract_async_fn_sig_from_projection(
+ &mut self,
+ projection_ty: &ProjectionTy,
+ projected_ty: &Ty,
+ ) -> Option<FnSubst<Interner>> {
+ let arg_param_ty = projection_ty.substitution.as_slice(Interner)[1].assert_ty_ref(Interner);
+
+ let TyKind::Tuple(_, input_tys) = arg_param_ty.kind(Interner) else {
+ return None;
+ };
+
+ let ret_param_future_output = projected_ty;
+ let ret_param_future = self.table.new_type_var();
+ let future_output =
+ LangItem::FutureOutput.resolve_type_alias(self.db, self.resolver.krate())?;
+ let future_projection = crate::AliasTy::Projection(crate::ProjectionTy {
+ associated_ty_id: to_assoc_type_id(future_output),
+ substitution: Substitution::from1(Interner, ret_param_future.clone()),
+ });
+ self.table.register_obligation(
+ crate::AliasEq { alias: future_projection, ty: ret_param_future_output.clone() }
+ .cast(Interner),
+ );
+
+ Some(FnSubst(Substitution::from_iter(
+ Interner,
+ input_tys.iter(Interner).map(|t| t.cast(Interner)).chain(Some(GenericArg::new(
+ Interner,
+ chalk_ir::GenericArgData::Ty(ret_param_future),
+ ))),
+ )))
+ }
+
fn fn_trait_kind_from_trait_id(&self, trait_id: hir_def::TraitId) -> Option<FnTrait> {
FnTrait::from_lang_item(self.db.lang_attr(trait_id.into())?)
}