Unnamed repository; edit this file 'description' to name the repository.
Diffstat (limited to 'crates/hir-ty/src/infer/closure.rs')
-rw-r--r--crates/hir-ty/src/infer/closure.rs932
1 files changed, 928 insertions, 4 deletions
diff --git a/crates/hir-ty/src/infer/closure.rs b/crates/hir-ty/src/infer/closure.rs
index a6449d019f..23189f383e 100644
--- a/crates/hir-ty/src/infer/closure.rs
+++ b/crates/hir-ty/src/infer/closure.rs
@@ -1,12 +1,33 @@
//! Inference of closure parameter types based on the closure's expected type.
-use chalk_ir::{cast::Cast, AliasEq, AliasTy, FnSubst, WhereClause};
-use hir_def::{expr::ExprId, HasModule};
+use std::{cmp, collections::HashMap, convert::Infallible, mem};
+
+use chalk_ir::{
+ cast::Cast,
+ fold::{FallibleTypeFolder, TypeFoldable},
+ AliasEq, AliasTy, BoundVar, DebruijnIndex, FnSubst, Mutability, TyKind, WhereClause,
+};
+use hir_def::{
+ data::adt::VariantData,
+ hir::{Array, BinaryOp, BindingId, CaptureBy, Expr, ExprId, Pat, PatId, Statement, UnaryOp},
+ lang_item::LangItem,
+ resolver::{resolver_for_expr, ResolveValueResult, ValueNs},
+ DefWithBodyId, FieldId, HasModule, VariantId,
+};
+use hir_expand::name;
+use rustc_hash::FxHashMap;
use smallvec::SmallVec;
+use stdx::never;
use crate::{
- to_chalk_trait_id, utils, ChalkTraitId, DynTy, FnPointer, FnSig, Interner, Substitution, Ty,
- TyExt, TyKind,
+ db::HirDatabase,
+ from_placeholder_idx, make_binders,
+ mir::{BorrowKind, MirSpan, ProjectionElem},
+ static_lifetime, to_chalk_trait_id,
+ traits::FnTrait,
+ utils::{self, generics, Generics},
+ Adjust, Adjustment, Binders, BindingMode, ChalkTraitId, ClosureId, DynTy, FnPointer, FnSig,
+ Interner, Substitution, Ty, TyExt,
};
use super::{Expectation, InferenceContext};
@@ -86,3 +107,906 @@ impl InferenceContext<'_> {
None
}
}
+
+// The below functions handle capture and closure kind (Fn, FnMut, ..)
+
+#[derive(Debug, Clone, PartialEq, Eq, Hash)]
+pub(crate) struct HirPlace {
+ pub(crate) local: BindingId,
+ pub(crate) projections: Vec<ProjectionElem<Infallible, Ty>>,
+}
+
+impl HirPlace {
+ fn ty(&self, ctx: &mut InferenceContext<'_>) -> Ty {
+ let mut ty = ctx.table.resolve_completely(ctx.result[self.local].clone());
+ for p in &self.projections {
+ ty = p.projected_ty(
+ ty,
+ ctx.db,
+ |_, _, _| {
+ unreachable!("Closure field only happens in MIR");
+ },
+ ctx.owner.module(ctx.db.upcast()).krate(),
+ );
+ }
+ ty.clone()
+ }
+
+ fn capture_kind_of_truncated_place(
+ &self,
+ mut current_capture: CaptureKind,
+ len: usize,
+ ) -> CaptureKind {
+ match current_capture {
+ CaptureKind::ByRef(BorrowKind::Mut { .. }) => {
+ if self.projections[len..].iter().any(|x| *x == ProjectionElem::Deref) {
+ current_capture = CaptureKind::ByRef(BorrowKind::Unique);
+ }
+ }
+ _ => (),
+ }
+ current_capture
+ }
+}
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
+pub enum CaptureKind {
+ ByRef(BorrowKind),
+ ByValue,
+}
+
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub struct CapturedItem {
+ pub(crate) place: HirPlace,
+ pub(crate) kind: CaptureKind,
+ pub(crate) span: MirSpan,
+ pub(crate) ty: Binders<Ty>,
+}
+
+impl CapturedItem {
+ pub fn local(&self) -> BindingId {
+ self.place.local
+ }
+
+ pub fn ty(&self, subst: &Substitution) -> Ty {
+ self.ty.clone().substitute(Interner, utils::ClosureSubst(subst).parent_subst())
+ }
+
+ pub fn kind(&self) -> CaptureKind {
+ self.kind
+ }
+
+ pub fn display_place(&self, owner: DefWithBodyId, db: &dyn HirDatabase) -> String {
+ let body = db.body(owner);
+ let mut result = body[self.place.local].name.display(db.upcast()).to_string();
+ let mut field_need_paren = false;
+ for proj in &self.place.projections {
+ match proj {
+ ProjectionElem::Deref => {
+ result = format!("*{result}");
+ field_need_paren = true;
+ }
+ ProjectionElem::Field(f) => {
+ if field_need_paren {
+ result = format!("({result})");
+ }
+ let variant_data = f.parent.variant_data(db.upcast());
+ let field = match &*variant_data {
+ VariantData::Record(fields) => fields[f.local_id]
+ .name
+ .as_str()
+ .unwrap_or("[missing field]")
+ .to_string(),
+ VariantData::Tuple(fields) => fields
+ .iter()
+ .position(|x| x.0 == f.local_id)
+ .unwrap_or_default()
+ .to_string(),
+ VariantData::Unit => "[missing field]".to_string(),
+ };
+ result = format!("{result}.{field}");
+ field_need_paren = false;
+ }
+ &ProjectionElem::TupleOrClosureField(field) => {
+ if field_need_paren {
+ result = format!("({result})");
+ }
+ result = format!("{result}.{field}");
+ field_need_paren = false;
+ }
+ ProjectionElem::Index(_)
+ | ProjectionElem::ConstantIndex { .. }
+ | ProjectionElem::Subslice { .. }
+ | ProjectionElem::OpaqueCast(_) => {
+ never!("Not happen in closure capture");
+ continue;
+ }
+ }
+ }
+ result
+ }
+}
+
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub(crate) struct CapturedItemWithoutTy {
+ pub(crate) place: HirPlace,
+ pub(crate) kind: CaptureKind,
+ pub(crate) span: MirSpan,
+}
+
+impl CapturedItemWithoutTy {
+ fn with_ty(self, ctx: &mut InferenceContext<'_>) -> CapturedItem {
+ let ty = self.place.ty(ctx).clone();
+ let ty = match &self.kind {
+ CaptureKind::ByValue => ty,
+ CaptureKind::ByRef(bk) => {
+ let m = match bk {
+ BorrowKind::Mut { .. } => Mutability::Mut,
+ _ => Mutability::Not,
+ };
+ TyKind::Ref(m, static_lifetime(), ty).intern(Interner)
+ }
+ };
+ return CapturedItem {
+ place: self.place,
+ kind: self.kind,
+ span: self.span,
+ ty: replace_placeholder_with_binder(ctx.db, ctx.owner, ty),
+ };
+
+ fn replace_placeholder_with_binder(
+ db: &dyn HirDatabase,
+ owner: DefWithBodyId,
+ ty: Ty,
+ ) -> Binders<Ty> {
+ struct Filler<'a> {
+ db: &'a dyn HirDatabase,
+ generics: Generics,
+ }
+ impl FallibleTypeFolder<Interner> for Filler<'_> {
+ type Error = ();
+
+ fn as_dyn(&mut self) -> &mut dyn FallibleTypeFolder<Interner, Error = Self::Error> {
+ self
+ }
+
+ fn interner(&self) -> Interner {
+ Interner
+ }
+
+ fn try_fold_free_placeholder_const(
+ &mut self,
+ ty: chalk_ir::Ty<Interner>,
+ idx: chalk_ir::PlaceholderIndex,
+ outer_binder: DebruijnIndex,
+ ) -> Result<chalk_ir::Const<Interner>, Self::Error> {
+ let x = from_placeholder_idx(self.db, idx);
+ let Some(idx) = self.generics.param_idx(x) else {
+ return Err(());
+ };
+ Ok(BoundVar::new(outer_binder, idx).to_const(Interner, ty))
+ }
+
+ fn try_fold_free_placeholder_ty(
+ &mut self,
+ idx: chalk_ir::PlaceholderIndex,
+ outer_binder: DebruijnIndex,
+ ) -> std::result::Result<Ty, Self::Error> {
+ let x = from_placeholder_idx(self.db, idx);
+ let Some(idx) = self.generics.param_idx(x) else {
+ return Err(());
+ };
+ Ok(BoundVar::new(outer_binder, idx).to_ty(Interner))
+ }
+ }
+ let Some(generic_def) = owner.as_generic_def_id() else {
+ return Binders::empty(Interner, ty);
+ };
+ let filler = &mut Filler { db, generics: generics(db.upcast(), generic_def) };
+ let result = ty.clone().try_fold_with(filler, DebruijnIndex::INNERMOST).unwrap_or(ty);
+ make_binders(db, &filler.generics, result)
+ }
+ }
+}
+
+impl InferenceContext<'_> {
+ fn place_of_expr(&mut self, tgt_expr: ExprId) -> Option<HirPlace> {
+ let r = self.place_of_expr_without_adjust(tgt_expr)?;
+ let default = vec![];
+ let adjustments = self.result.expr_adjustments.get(&tgt_expr).unwrap_or(&default);
+ apply_adjusts_to_place(r, adjustments)
+ }
+
+ fn place_of_expr_without_adjust(&mut self, tgt_expr: ExprId) -> Option<HirPlace> {
+ match &self.body[tgt_expr] {
+ Expr::Path(p) => {
+ let resolver = resolver_for_expr(self.db.upcast(), self.owner, tgt_expr);
+ if let Some(r) = resolver.resolve_path_in_value_ns(self.db.upcast(), p) {
+ if let ResolveValueResult::ValueNs(v) = r {
+ if let ValueNs::LocalBinding(b) = v {
+ return Some(HirPlace { local: b, projections: vec![] });
+ }
+ }
+ }
+ }
+ Expr::Field { expr, name } => {
+ let mut place = self.place_of_expr(*expr)?;
+ if let TyKind::Tuple(..) = self.expr_ty(*expr).kind(Interner) {
+ let index = name.as_tuple_index()?;
+ place.projections.push(ProjectionElem::TupleOrClosureField(index))
+ } else {
+ let field = self.result.field_resolution(tgt_expr)?;
+ place.projections.push(ProjectionElem::Field(field));
+ }
+ return Some(place);
+ }
+ Expr::UnaryOp { expr, op: UnaryOp::Deref } => {
+ if matches!(
+ self.expr_ty_after_adjustments(*expr).kind(Interner),
+ TyKind::Ref(..) | TyKind::Raw(..)
+ ) {
+ let mut place = self.place_of_expr(*expr)?;
+ place.projections.push(ProjectionElem::Deref);
+ return Some(place);
+ }
+ }
+ _ => (),
+ }
+ None
+ }
+
+ fn push_capture(&mut self, capture: CapturedItemWithoutTy) {
+ self.current_captures.push(capture);
+ }
+
+ fn ref_expr(&mut self, expr: ExprId) {
+ if let Some(place) = self.place_of_expr(expr) {
+ self.add_capture(place, CaptureKind::ByRef(BorrowKind::Shared), expr.into());
+ }
+ self.walk_expr(expr);
+ }
+
+ fn add_capture(&mut self, place: HirPlace, kind: CaptureKind, span: MirSpan) {
+ if self.is_upvar(&place) {
+ self.push_capture(CapturedItemWithoutTy { place, kind, span });
+ }
+ }
+
+ fn mutate_expr(&mut self, expr: ExprId) {
+ if let Some(place) = self.place_of_expr(expr) {
+ self.add_capture(
+ place,
+ CaptureKind::ByRef(BorrowKind::Mut { allow_two_phase_borrow: false }),
+ expr.into(),
+ );
+ }
+ self.walk_expr(expr);
+ }
+
+ fn consume_expr(&mut self, expr: ExprId) {
+ if let Some(place) = self.place_of_expr(expr) {
+ self.consume_place(place, expr.into());
+ }
+ self.walk_expr(expr);
+ }
+
+ fn consume_place(&mut self, place: HirPlace, span: MirSpan) {
+ if self.is_upvar(&place) {
+ let ty = place.ty(self).clone();
+ let kind = if self.is_ty_copy(ty) {
+ CaptureKind::ByRef(BorrowKind::Shared)
+ } else {
+ CaptureKind::ByValue
+ };
+ self.push_capture(CapturedItemWithoutTy { place, kind, span });
+ }
+ }
+
+ fn walk_expr_with_adjust(&mut self, tgt_expr: ExprId, adjustment: &[Adjustment]) {
+ if let Some((last, rest)) = adjustment.split_last() {
+ match last.kind {
+ Adjust::NeverToAny | Adjust::Deref(None) | Adjust::Pointer(_) => {
+ self.walk_expr_with_adjust(tgt_expr, rest)
+ }
+ Adjust::Deref(Some(m)) => match m.0 {
+ Some(m) => {
+ self.ref_capture_with_adjusts(m, tgt_expr, rest);
+ }
+ None => unreachable!(),
+ },
+ Adjust::Borrow(b) => {
+ self.ref_capture_with_adjusts(b.mutability(), tgt_expr, rest);
+ }
+ }
+ } else {
+ self.walk_expr_without_adjust(tgt_expr);
+ }
+ }
+
+ fn ref_capture_with_adjusts(&mut self, m: Mutability, tgt_expr: ExprId, rest: &[Adjustment]) {
+ let capture_kind = match m {
+ Mutability::Mut => {
+ CaptureKind::ByRef(BorrowKind::Mut { allow_two_phase_borrow: false })
+ }
+ Mutability::Not => CaptureKind::ByRef(BorrowKind::Shared),
+ };
+ if let Some(place) = self.place_of_expr_without_adjust(tgt_expr) {
+ if let Some(place) = apply_adjusts_to_place(place, rest) {
+ self.add_capture(place, capture_kind, tgt_expr.into());
+ }
+ }
+ self.walk_expr_with_adjust(tgt_expr, rest);
+ }
+
+ fn walk_expr(&mut self, tgt_expr: ExprId) {
+ if let Some(x) = self.result.expr_adjustments.get_mut(&tgt_expr) {
+ // FIXME: this take is completely unneeded, and just is here to make borrow checker
+ // happy. Remove it if you can.
+ let x_taken = mem::take(x);
+ self.walk_expr_with_adjust(tgt_expr, &x_taken);
+ *self.result.expr_adjustments.get_mut(&tgt_expr).unwrap() = x_taken;
+ } else {
+ self.walk_expr_without_adjust(tgt_expr);
+ }
+ }
+
+ fn walk_expr_without_adjust(&mut self, tgt_expr: ExprId) {
+ match &self.body[tgt_expr] {
+ Expr::If { condition, then_branch, else_branch } => {
+ self.consume_expr(*condition);
+ self.consume_expr(*then_branch);
+ if let &Some(expr) = else_branch {
+ self.consume_expr(expr);
+ }
+ }
+ Expr::Async { statements, tail, .. }
+ | Expr::Unsafe { statements, tail, .. }
+ | Expr::Block { statements, tail, .. } => {
+ for s in statements.iter() {
+ match s {
+ Statement::Let { pat, type_ref: _, initializer, else_branch } => {
+ if let Some(else_branch) = else_branch {
+ self.consume_expr(*else_branch);
+ if let Some(initializer) = initializer {
+ self.consume_expr(*initializer);
+ }
+ return;
+ }
+ if let Some(initializer) = initializer {
+ self.walk_expr(*initializer);
+ if let Some(place) = self.place_of_expr(*initializer) {
+ self.consume_with_pat(place, *pat);
+ }
+ }
+ }
+ Statement::Expr { expr, has_semi: _ } => {
+ self.consume_expr(*expr);
+ }
+ }
+ }
+ if let Some(tail) = tail {
+ self.consume_expr(*tail);
+ }
+ }
+ Expr::While { condition, body, label: _ } => {
+ self.consume_expr(*condition);
+ self.consume_expr(*body);
+ }
+ Expr::Call { callee, args, is_assignee_expr: _ } => {
+ self.consume_expr(*callee);
+ self.consume_exprs(args.iter().copied());
+ }
+ Expr::MethodCall { receiver, args, .. } => {
+ self.consume_expr(*receiver);
+ self.consume_exprs(args.iter().copied());
+ }
+ Expr::Match { expr, arms } => {
+ for arm in arms.iter() {
+ self.consume_expr(arm.expr);
+ if let Some(guard) = arm.guard {
+ self.consume_expr(guard);
+ }
+ }
+ self.walk_expr(*expr);
+ if let Some(discr_place) = self.place_of_expr(*expr) {
+ if self.is_upvar(&discr_place) {
+ let mut capture_mode = None;
+ for arm in arms.iter() {
+ self.walk_pat(&mut capture_mode, arm.pat);
+ }
+ if let Some(c) = capture_mode {
+ self.push_capture(CapturedItemWithoutTy {
+ place: discr_place,
+ kind: c,
+ span: (*expr).into(),
+ })
+ }
+ }
+ }
+ }
+ Expr::Break { expr, label: _ }
+ | Expr::Return { expr }
+ | Expr::Yield { expr }
+ | Expr::Yeet { expr } => {
+ if let &Some(expr) = expr {
+ self.consume_expr(expr);
+ }
+ }
+ Expr::RecordLit { fields, spread, .. } => {
+ if let &Some(expr) = spread {
+ self.consume_expr(expr);
+ }
+ self.consume_exprs(fields.iter().map(|x| x.expr));
+ }
+ Expr::Field { expr, name: _ } => self.select_from_expr(*expr),
+ Expr::UnaryOp { expr, op: UnaryOp::Deref } => {
+ if matches!(
+ self.expr_ty_after_adjustments(*expr).kind(Interner),
+ TyKind::Ref(..) | TyKind::Raw(..)
+ ) {
+ self.select_from_expr(*expr);
+ } else if let Some((f, _)) = self.result.method_resolution(tgt_expr) {
+ let mutability = 'b: {
+ if let Some(deref_trait) =
+ self.resolve_lang_item(LangItem::DerefMut).and_then(|x| x.as_trait())
+ {
+ if let Some(deref_fn) =
+ self.db.trait_data(deref_trait).method_by_name(&name![deref_mut])
+ {
+ break 'b deref_fn == f;
+ }
+ }
+ false
+ };
+ if mutability {
+ self.mutate_expr(*expr);
+ } else {
+ self.ref_expr(*expr);
+ }
+ } else {
+ self.select_from_expr(*expr);
+ }
+ }
+ Expr::UnaryOp { expr, op: _ }
+ | Expr::Array(Array::Repeat { initializer: expr, repeat: _ })
+ | Expr::Await { expr }
+ | Expr::Loop { body: expr, label: _ }
+ | Expr::Let { pat: _, expr }
+ | Expr::Box { expr }
+ | Expr::Cast { expr, type_ref: _ } => {
+ self.consume_expr(*expr);
+ }
+ Expr::Ref { expr, rawness: _, mutability } => match mutability {
+ hir_def::type_ref::Mutability::Shared => self.ref_expr(*expr),
+ hir_def::type_ref::Mutability::Mut => self.mutate_expr(*expr),
+ },
+ Expr::BinaryOp { lhs, rhs, op } => {
+ let Some(op) = op else {
+ return;
+ };
+ if matches!(op, BinaryOp::Assignment { .. }) {
+ self.mutate_expr(*lhs);
+ self.consume_expr(*rhs);
+ return;
+ }
+ self.consume_expr(*lhs);
+ self.consume_expr(*rhs);
+ }
+ Expr::Range { lhs, rhs, range_type: _ } => {
+ if let &Some(expr) = lhs {
+ self.consume_expr(expr);
+ }
+ if let &Some(expr) = rhs {
+ self.consume_expr(expr);
+ }
+ }
+ Expr::Index { base, index } => {
+ self.select_from_expr(*base);
+ self.consume_expr(*index);
+ }
+ Expr::Closure { .. } => {
+ let ty = self.expr_ty(tgt_expr);
+ let TyKind::Closure(id, _) = ty.kind(Interner) else {
+ never!("closure type is always closure");
+ return;
+ };
+ let (captures, _) =
+ self.result.closure_info.get(id).expect(
+ "We sort closures, so we should always have data for inner closures",
+ );
+ let mut cc = mem::take(&mut self.current_captures);
+ cc.extend(captures.iter().filter(|x| self.is_upvar(&x.place)).map(|x| {
+ CapturedItemWithoutTy { place: x.place.clone(), kind: x.kind, span: x.span }
+ }));
+ self.current_captures = cc;
+ }
+ Expr::Array(Array::ElementList { elements: exprs, is_assignee_expr: _ })
+ | Expr::Tuple { exprs, is_assignee_expr: _ } => {
+ self.consume_exprs(exprs.iter().copied())
+ }
+ Expr::Missing
+ | Expr::Continue { .. }
+ | Expr::Path(_)
+ | Expr::Literal(_)
+ | Expr::Const(_)
+ | Expr::Underscore => (),
+ }
+ }
+
+ fn walk_pat(&mut self, result: &mut Option<CaptureKind>, pat: PatId) {
+ let mut update_result = |ck: CaptureKind| match result {
+ Some(r) => {
+ *r = cmp::max(*r, ck);
+ }
+ None => *result = Some(ck),
+ };
+
+ self.walk_pat_inner(
+ pat,
+ &mut update_result,
+ BorrowKind::Mut { allow_two_phase_borrow: false },
+ );
+ }
+
+ fn walk_pat_inner(
+ &mut self,
+ p: PatId,
+ update_result: &mut impl FnMut(CaptureKind),
+ mut for_mut: BorrowKind,
+ ) {
+ match &self.body[p] {
+ Pat::Ref { .. }
+ | Pat::Box { .. }
+ | Pat::Missing
+ | Pat::Wild
+ | Pat::Tuple { .. }
+ | Pat::Or(_) => (),
+ Pat::TupleStruct { .. } | Pat::Record { .. } => {
+ if let Some(variant) = self.result.variant_resolution_for_pat(p) {
+ let adt = variant.adt_id();
+ let is_multivariant = match adt {
+ hir_def::AdtId::EnumId(e) => self.db.enum_data(e).variants.len() != 1,
+ _ => false,
+ };
+ if is_multivariant {
+ update_result(CaptureKind::ByRef(BorrowKind::Shared));
+ }
+ }
+ }
+ Pat::Slice { .. }
+ | Pat::ConstBlock(_)
+ | Pat::Path(_)
+ | Pat::Lit(_)
+ | Pat::Range { .. } => {
+ update_result(CaptureKind::ByRef(BorrowKind::Shared));
+ }
+ Pat::Bind { id, .. } => match self.result.binding_modes[*id] {
+ crate::BindingMode::Move => {
+ if self.is_ty_copy(self.result.type_of_binding[*id].clone()) {
+ update_result(CaptureKind::ByRef(BorrowKind::Shared));
+ } else {
+ update_result(CaptureKind::ByValue);
+ }
+ }
+ crate::BindingMode::Ref(r) => match r {
+ Mutability::Mut => update_result(CaptureKind::ByRef(for_mut)),
+ Mutability::Not => update_result(CaptureKind::ByRef(BorrowKind::Shared)),
+ },
+ },
+ }
+ if self.result.pat_adjustments.get(&p).map_or(false, |x| !x.is_empty()) {
+ for_mut = BorrowKind::Unique;
+ }
+ self.body.walk_pats_shallow(p, |p| self.walk_pat_inner(p, update_result, for_mut));
+ }
+
+ fn expr_ty(&self, expr: ExprId) -> Ty {
+ self.result[expr].clone()
+ }
+
+ fn expr_ty_after_adjustments(&self, e: ExprId) -> Ty {
+ let mut ty = None;
+ if let Some(x) = self.result.expr_adjustments.get(&e) {
+ if let Some(x) = x.last() {
+ ty = Some(x.target.clone());
+ }
+ }
+ ty.unwrap_or_else(|| self.expr_ty(e))
+ }
+
+ fn is_upvar(&self, place: &HirPlace) -> bool {
+ let b = &self.body[place.local];
+ if let Some(c) = self.current_closure {
+ let (_, root) = self.db.lookup_intern_closure(c.into());
+ return b.is_upvar(root);
+ }
+ false
+ }
+
+ fn is_ty_copy(&mut self, ty: Ty) -> bool {
+ if let TyKind::Closure(id, _) = ty.kind(Interner) {
+ // FIXME: We handle closure as a special case, since chalk consider every closure as copy. We
+ // should probably let chalk know which closures are copy, but I don't know how doing it
+ // without creating query cycles.
+ return self.result.closure_info.get(id).map(|x| x.1 == FnTrait::Fn).unwrap_or(true);
+ }
+ self.table.resolve_completely(ty).is_copy(self.db, self.owner)
+ }
+
+ fn select_from_expr(&mut self, expr: ExprId) {
+ self.walk_expr(expr);
+ }
+
+ fn adjust_for_move_closure(&mut self) {
+ for capture in &mut self.current_captures {
+ if let Some(first_deref) =
+ capture.place.projections.iter().position(|proj| *proj == ProjectionElem::Deref)
+ {
+ capture.place.projections.truncate(first_deref);
+ }
+ capture.kind = CaptureKind::ByValue;
+ }
+ }
+
+ fn minimize_captures(&mut self) {
+ self.current_captures.sort_by_key(|x| x.place.projections.len());
+ let mut hash_map = HashMap::<HirPlace, usize>::new();
+ let result = mem::take(&mut self.current_captures);
+ for item in result {
+ let mut lookup_place = HirPlace { local: item.place.local, projections: vec![] };
+ let mut it = item.place.projections.iter();
+ let prev_index = loop {
+ if let Some(k) = hash_map.get(&lookup_place) {
+ break Some(*k);
+ }
+ match it.next() {
+ Some(x) => lookup_place.projections.push(x.clone()),
+ None => break None,
+ }
+ };
+ match prev_index {
+ Some(p) => {
+ let len = self.current_captures[p].place.projections.len();
+ let kind_after_truncate =
+ item.place.capture_kind_of_truncated_place(item.kind, len);
+ self.current_captures[p].kind =
+ cmp::max(kind_after_truncate, self.current_captures[p].kind);
+ }
+ None => {
+ hash_map.insert(item.place.clone(), self.current_captures.len());
+ self.current_captures.push(item);
+ }
+ }
+ }
+ }
+
+ fn consume_with_pat(&mut self, mut place: HirPlace, pat: PatId) {
+ let cnt = self.result.pat_adjustments.get(&pat).map(|x| x.len()).unwrap_or_default();
+ place.projections = place
+ .projections
+ .iter()
+ .cloned()
+ .chain((0..cnt).map(|_| ProjectionElem::Deref))
+ .collect::<Vec<_>>()
+ .into();
+ match &self.body[pat] {
+ Pat::Missing | Pat::Wild => (),
+ Pat::Tuple { args, ellipsis } => {
+ let (al, ar) = args.split_at(ellipsis.unwrap_or(args.len()));
+ let field_count = match self.result[pat].kind(Interner) {
+ TyKind::Tuple(_, s) => s.len(Interner),
+ _ => return,
+ };
+ let fields = 0..field_count;
+ let it = al.iter().zip(fields.clone()).chain(ar.iter().rev().zip(fields.rev()));
+ for (arg, i) in it {
+ let mut p = place.clone();
+ p.projections.push(ProjectionElem::TupleOrClosureField(i));
+ self.consume_with_pat(p, *arg);
+ }
+ }
+ Pat::Or(pats) => {
+ for pat in pats.iter() {
+ self.consume_with_pat(place.clone(), *pat);
+ }
+ }
+ Pat::Record { args, .. } => {
+ let Some(variant) = self.result.variant_resolution_for_pat(pat) else {
+ return;
+ };
+ match variant {
+ VariantId::EnumVariantId(_) | VariantId::UnionId(_) => {
+ self.consume_place(place, pat.into())
+ }
+ VariantId::StructId(s) => {
+ let vd = &*self.db.struct_data(s).variant_data;
+ for field_pat in args.iter() {
+ let arg = field_pat.pat;
+ let Some(local_id) = vd.field(&field_pat.name) else {
+ continue;
+ };
+ let mut p = place.clone();
+ p.projections.push(ProjectionElem::Field(FieldId {
+ parent: variant.into(),
+ local_id,
+ }));
+ self.consume_with_pat(p, arg);
+ }
+ }
+ }
+ }
+ Pat::Range { .. }
+ | Pat::Slice { .. }
+ | Pat::ConstBlock(_)
+ | Pat::Path(_)
+ | Pat::Lit(_) => self.consume_place(place, pat.into()),
+ Pat::Bind { id, subpat: _ } => {
+ let mode = self.result.binding_modes[*id];
+ let capture_kind = match mode {
+ BindingMode::Move => {
+ self.consume_place(place, pat.into());
+ return;
+ }
+ BindingMode::Ref(Mutability::Not) => BorrowKind::Shared,
+ BindingMode::Ref(Mutability::Mut) => {
+ BorrowKind::Mut { allow_two_phase_borrow: false }
+ }
+ };
+ self.add_capture(place, CaptureKind::ByRef(capture_kind), pat.into());
+ }
+ Pat::TupleStruct { path: _, args, ellipsis } => {
+ let Some(variant) = self.result.variant_resolution_for_pat(pat) else {
+ return;
+ };
+ match variant {
+ VariantId::EnumVariantId(_) | VariantId::UnionId(_) => {
+ self.consume_place(place, pat.into())
+ }
+ VariantId::StructId(s) => {
+ let vd = &*self.db.struct_data(s).variant_data;
+ let (al, ar) = args.split_at(ellipsis.unwrap_or(args.len()));
+ let fields = vd.fields().iter();
+ let it =
+ al.iter().zip(fields.clone()).chain(ar.iter().rev().zip(fields.rev()));
+ for (arg, (i, _)) in it {
+ let mut p = place.clone();
+ p.projections.push(ProjectionElem::Field(FieldId {
+ parent: variant.into(),
+ local_id: i,
+ }));
+ self.consume_with_pat(p, *arg);
+ }
+ }
+ }
+ }
+ Pat::Ref { pat, mutability: _ } => {
+ place.projections.push(ProjectionElem::Deref);
+ self.consume_with_pat(place, *pat)
+ }
+ Pat::Box { .. } => (), // not supported
+ }
+ }
+
+ fn consume_exprs(&mut self, exprs: impl Iterator<Item = ExprId>) {
+ for expr in exprs {
+ self.consume_expr(expr);
+ }
+ }
+
+ fn closure_kind(&self) -> FnTrait {
+ let mut r = FnTrait::Fn;
+ for x in &self.current_captures {
+ r = cmp::min(
+ r,
+ match &x.kind {
+ CaptureKind::ByRef(BorrowKind::Unique | BorrowKind::Mut { .. }) => {
+ FnTrait::FnMut
+ }
+ CaptureKind::ByRef(BorrowKind::Shallow | BorrowKind::Shared) => FnTrait::Fn,
+ CaptureKind::ByValue => FnTrait::FnOnce,
+ },
+ )
+ }
+ r
+ }
+
+ fn analyze_closure(&mut self, closure: ClosureId) -> FnTrait {
+ let (_, root) = self.db.lookup_intern_closure(closure.into());
+ self.current_closure = Some(closure);
+ let Expr::Closure { body, capture_by, .. } = &self.body[root] else {
+ unreachable!("Closure expression id is always closure");
+ };
+ self.consume_expr(*body);
+ for item in &self.current_captures {
+ if matches!(item.kind, CaptureKind::ByRef(BorrowKind::Mut { .. }))
+ && !item.place.projections.contains(&ProjectionElem::Deref)
+ {
+ // FIXME: remove the `mutated_bindings_in_closure` completely and add proper fake reads in
+ // MIR. I didn't do that due duplicate diagnostics.
+ self.result.mutated_bindings_in_closure.insert(item.place.local);
+ }
+ }
+ // closure_kind should be done before adjust_for_move_closure
+ let closure_kind = self.closure_kind();
+ match capture_by {
+ CaptureBy::Value => self.adjust_for_move_closure(),
+ CaptureBy::Ref => (),
+ }
+ self.minimize_captures();
+ let result = mem::take(&mut self.current_captures);
+ let captures = result.into_iter().map(|x| x.with_ty(self)).collect::<Vec<_>>();
+ self.result.closure_info.insert(closure, (captures, closure_kind));
+ closure_kind
+ }
+
+ pub(crate) fn infer_closures(&mut self) {
+ let deferred_closures = self.sort_closures();
+ for (closure, exprs) in deferred_closures.into_iter().rev() {
+ self.current_captures = vec![];
+ let kind = self.analyze_closure(closure);
+
+ for (derefed_callee, callee_ty, params, expr) in exprs {
+ if let &Expr::Call { callee, .. } = &self.body[expr] {
+ let mut adjustments =
+ self.result.expr_adjustments.remove(&callee).unwrap_or_default();
+ self.write_fn_trait_method_resolution(
+ kind,
+ &derefed_callee,
+ &mut adjustments,
+ &callee_ty,
+ &params,
+ expr,
+ );
+ self.result.expr_adjustments.insert(callee, adjustments);
+ }
+ }
+ }
+ }
+
+ /// We want to analyze some closures before others, to have a correct analysis:
+ /// * We should analyze nested closures before the parent, since the parent should capture some of
+ /// the things that its children captures.
+ /// * If a closure calls another closure, we need to analyze the callee, to find out how we should
+ /// capture it (e.g. by move for FnOnce)
+ ///
+ /// These dependencies are collected in the main inference. We do a topological sort in this function. It
+ /// will consume the `deferred_closures` field and return its content in a sorted vector.
+ fn sort_closures(&mut self) -> Vec<(ClosureId, Vec<(Ty, Ty, Vec<Ty>, ExprId)>)> {
+ let mut deferred_closures = mem::take(&mut self.deferred_closures);
+ let mut dependents_count: FxHashMap<ClosureId, usize> =
+ deferred_closures.keys().map(|x| (*x, 0)).collect();
+ for (_, deps) in &self.closure_dependencies {
+ for dep in deps {
+ *dependents_count.entry(*dep).or_default() += 1;
+ }
+ }
+ let mut queue: Vec<_> =
+ deferred_closures.keys().copied().filter(|x| dependents_count[x] == 0).collect();
+ let mut result = vec![];
+ while let Some(x) = queue.pop() {
+ if let Some(d) = deferred_closures.remove(&x) {
+ result.push((x, d));
+ }
+ for dep in self.closure_dependencies.get(&x).into_iter().flat_map(|x| x.iter()) {
+ let cnt = dependents_count.get_mut(dep).unwrap();
+ *cnt -= 1;
+ if *cnt == 0 {
+ queue.push(*dep);
+ }
+ }
+ }
+ result
+ }
+}
+
+fn apply_adjusts_to_place(mut r: HirPlace, adjustments: &[Adjustment]) -> Option<HirPlace> {
+ for adj in adjustments {
+ match &adj.kind {
+ Adjust::Deref(None) => {
+ r.projections.push(ProjectionElem::Deref);
+ }
+ _ => return None,
+ }
+ }
+ Some(r)
+}