Unnamed repository; edit this file 'description' to name the repository.
Diffstat (limited to 'crates/hir-ty/src/infer/closure.rs')
| -rw-r--r-- | crates/hir-ty/src/infer/closure.rs | 385 |
1 files changed, 249 insertions, 136 deletions
diff --git a/crates/hir-ty/src/infer/closure.rs b/crates/hir-ty/src/infer/closure.rs index 910e19dc58..68e1a6ee82 100644 --- a/crates/hir-ty/src/infer/closure.rs +++ b/crates/hir-ty/src/infer/closure.rs @@ -18,7 +18,7 @@ use hir_def::{ use hir_expand::name::Name; use intern::sym; use rustc_hash::FxHashMap; -use smallvec::SmallVec; +use smallvec::{smallvec, SmallVec}; use stdx::never; use crate::{ @@ -236,7 +236,13 @@ pub enum CaptureKind { pub struct CapturedItem { pub(crate) place: HirPlace, pub(crate) kind: CaptureKind, - pub(crate) span: MirSpan, + /// The inner vec is the stacks; the outer vec is for each capture reference. + /// + /// Even though we always report only the last span (i.e. the most inclusive span), + /// we need to keep them all, since when a closure occurs inside a closure, we + /// copy all captures of the inner closure to the outer closure, and then we may + /// truncate them, and we want the correct span to be reported. + span_stacks: SmallVec<[SmallVec<[MirSpan; 3]>; 3]>, pub(crate) ty: Binders<Ty>, } @@ -253,6 +259,10 @@ impl CapturedItem { self.kind } + pub fn spans(&self) -> SmallVec<[MirSpan; 3]> { + self.span_stacks.iter().map(|stack| *stack.last().expect("empty span stack")).collect() + } + pub fn display_place(&self, owner: DefWithBodyId, db: &dyn HirDatabase) -> String { let body = db.body(owner); let krate = owner.krate(db.upcast()); @@ -314,7 +324,8 @@ impl CapturedItem { pub(crate) struct CapturedItemWithoutTy { pub(crate) place: HirPlace, pub(crate) kind: CaptureKind, - pub(crate) span: MirSpan, + /// The inner vec is the stacks; the outer vec is for each capture reference. + pub(crate) span_stacks: SmallVec<[SmallVec<[MirSpan; 3]>; 3]>, } impl CapturedItemWithoutTy { @@ -333,7 +344,7 @@ impl CapturedItemWithoutTy { return CapturedItem { place: self.place, kind: self.kind, - span: self.span, + span_stacks: self.span_stacks, ty: replace_placeholder_with_binder(ctx, ty), }; @@ -393,22 +404,26 @@ impl InferenceContext<'_> { let r = self.place_of_expr_without_adjust(tgt_expr)?; let default = vec![]; let adjustments = self.result.expr_adjustments.get(&tgt_expr).unwrap_or(&default); - apply_adjusts_to_place(r, adjustments) + apply_adjusts_to_place(&mut self.current_capture_span_stack, r, adjustments) } + /// Changes `current_capture_span_stack` to contain the stack of spans for this expr. fn place_of_expr_without_adjust(&mut self, tgt_expr: ExprId) -> Option<HirPlace> { + self.current_capture_span_stack.clear(); match &self.body[tgt_expr] { Expr::Path(p) => { let resolver = resolver_for_expr(self.db.upcast(), self.owner, tgt_expr); if let Some(ResolveValueResult::ValueNs(ValueNs::LocalBinding(b), _)) = resolver.resolve_path_in_value_ns(self.db.upcast(), p) { + self.current_capture_span_stack.push(MirSpan::ExprId(tgt_expr)); return Some(HirPlace { local: b, projections: vec![] }); } } Expr::Field { expr, name: _ } => { let mut place = self.place_of_expr(*expr)?; let field = self.result.field_resolution(tgt_expr)?; + self.current_capture_span_stack.push(MirSpan::ExprId(tgt_expr)); place.projections.push(ProjectionElem::Field(field)); return Some(place); } @@ -418,6 +433,7 @@ impl InferenceContext<'_> { TyKind::Ref(..) | TyKind::Raw(..) ) { let mut place = self.place_of_expr(*expr)?; + self.current_capture_span_stack.push(MirSpan::ExprId(tgt_expr)); place.projections.push(ProjectionElem::Deref); return Some(place); } @@ -427,29 +443,65 @@ impl InferenceContext<'_> { None } - fn push_capture(&mut self, capture: CapturedItemWithoutTy) { - self.current_captures.push(capture); + fn push_capture(&mut self, place: HirPlace, kind: CaptureKind) { + self.current_captures.push(CapturedItemWithoutTy { + place, + kind, + span_stacks: smallvec![self.current_capture_span_stack.iter().copied().collect()], + }); } - fn ref_expr(&mut self, expr: ExprId) { - if let Some(place) = self.place_of_expr(expr) { - self.add_capture(place, CaptureKind::ByRef(BorrowKind::Shared), expr.into()); + fn is_ref_span(&self, span: MirSpan) -> bool { + match span { + MirSpan::ExprId(expr) => matches!(self.body[expr], Expr::Ref { .. }), + MirSpan::BindingId(_) => true, + MirSpan::PatId(_) | MirSpan::SelfParam | MirSpan::Unknown => false, + } + } + + fn truncate_capture_spans(&self, capture: &mut CapturedItemWithoutTy, mut truncate_to: usize) { + // The first span is the identifier, and it must always remain. + truncate_to += 1; + for span_stack in &mut capture.span_stacks { + let mut remained = truncate_to; + let mut actual_truncate_to = 0; + for &span in &*span_stack { + actual_truncate_to += 1; + if !self.is_ref_span(span) { + remained -= 1; + if remained == 0 { + break; + } + } + } + if actual_truncate_to < span_stack.len() + && self.is_ref_span(span_stack[actual_truncate_to]) + { + // Include the ref operator if there is one, we will fix it later (in `strip_captures_ref_span()`) if it's incorrect. + actual_truncate_to += 1; + } + span_stack.truncate(actual_truncate_to); + } + } + + fn ref_expr(&mut self, expr: ExprId, place: Option<HirPlace>) { + if let Some(place) = place { + self.add_capture(place, CaptureKind::ByRef(BorrowKind::Shared)); } self.walk_expr(expr); } - fn add_capture(&mut self, place: HirPlace, kind: CaptureKind, span: MirSpan) { + fn add_capture(&mut self, place: HirPlace, kind: CaptureKind) { if self.is_upvar(&place) { - self.push_capture(CapturedItemWithoutTy { place, kind, span }); + self.push_capture(place, kind); } } - fn mutate_expr(&mut self, expr: ExprId) { - if let Some(place) = self.place_of_expr(expr) { + fn mutate_expr(&mut self, expr: ExprId, place: Option<HirPlace>) { + if let Some(place) = place { self.add_capture( place, CaptureKind::ByRef(BorrowKind::Mut { kind: MutBorrowKind::Default }), - expr.into(), ); } self.walk_expr(expr); @@ -457,12 +509,12 @@ impl InferenceContext<'_> { fn consume_expr(&mut self, expr: ExprId) { if let Some(place) = self.place_of_expr(expr) { - self.consume_place(place, expr.into()); + self.consume_place(place); } self.walk_expr(expr); } - fn consume_place(&mut self, place: HirPlace, span: MirSpan) { + fn consume_place(&mut self, place: HirPlace) { if self.is_upvar(&place) { let ty = place.ty(self); let kind = if self.is_ty_copy(ty) { @@ -470,7 +522,7 @@ impl InferenceContext<'_> { } else { CaptureKind::ByValue }; - self.push_capture(CapturedItemWithoutTy { place, kind, span }); + self.push_capture(place, kind); } } @@ -501,8 +553,10 @@ impl InferenceContext<'_> { Mutability::Not => CaptureKind::ByRef(BorrowKind::Shared), }; if let Some(place) = self.place_of_expr_without_adjust(tgt_expr) { - if let Some(place) = apply_adjusts_to_place(place, rest) { - self.add_capture(place, capture_kind, tgt_expr.into()); + if let Some(place) = + apply_adjusts_to_place(&mut self.current_capture_span_stack, place, rest) + { + self.add_capture(place, capture_kind); } } self.walk_expr_with_adjust(tgt_expr, rest); @@ -584,11 +638,7 @@ impl InferenceContext<'_> { self.walk_pat(&mut capture_mode, arm.pat); } if let Some(c) = capture_mode { - self.push_capture(CapturedItemWithoutTy { - place: discr_place, - kind: c, - span: (*expr).into(), - }) + self.push_capture(discr_place, c); } } } @@ -632,10 +682,11 @@ impl InferenceContext<'_> { } false }; + let place = self.place_of_expr(*expr); if mutability { - self.mutate_expr(*expr); + self.mutate_expr(*expr, place); } else { - self.ref_expr(*expr); + self.ref_expr(*expr, place); } } else { self.select_from_expr(*expr); @@ -650,16 +701,22 @@ impl InferenceContext<'_> { | Expr::Cast { expr, type_ref: _ } => { self.consume_expr(*expr); } - Expr::Ref { expr, rawness: _, mutability } => match mutability { - hir_def::type_ref::Mutability::Shared => self.ref_expr(*expr), - hir_def::type_ref::Mutability::Mut => self.mutate_expr(*expr), - }, + Expr::Ref { expr, rawness: _, mutability } => { + // We need to do this before we push the span so the order will be correct. + let place = self.place_of_expr(*expr); + self.current_capture_span_stack.push(MirSpan::ExprId(tgt_expr)); + match mutability { + hir_def::type_ref::Mutability::Shared => self.ref_expr(*expr, place), + hir_def::type_ref::Mutability::Mut => self.mutate_expr(*expr, place), + } + } Expr::BinaryOp { lhs, rhs, op } => { let Some(op) = op else { return; }; if matches!(op, BinaryOp::Assignment { .. }) { - self.mutate_expr(*lhs); + let place = self.place_of_expr(*lhs); + self.mutate_expr(*lhs, place); self.consume_expr(*rhs); return; } @@ -690,7 +747,11 @@ impl InferenceContext<'_> { ); let mut cc = mem::take(&mut self.current_captures); cc.extend(captures.iter().filter(|it| self.is_upvar(&it.place)).map(|it| { - CapturedItemWithoutTy { place: it.place.clone(), kind: it.kind, span: it.span } + CapturedItemWithoutTy { + place: it.place.clone(), + kind: it.kind, + span_stacks: it.span_stacks.clone(), + } })); self.current_captures = cc; } @@ -812,10 +873,13 @@ impl InferenceContext<'_> { } fn restrict_precision_for_unsafe(&mut self) { - for capture in &mut self.current_captures { + // FIXME: Borrow checker problems without this. + let mut current_captures = std::mem::take(&mut self.current_captures); + for capture in &mut current_captures { let mut ty = self.table.resolve_completely(self.result[capture.place.local].clone()); if ty.as_raw_ptr().is_some() || ty.is_union() { capture.kind = CaptureKind::ByRef(BorrowKind::Shared); + self.truncate_capture_spans(capture, 0); capture.place.projections.truncate(0); continue; } @@ -830,29 +894,35 @@ impl InferenceContext<'_> { ); if ty.as_raw_ptr().is_some() || ty.is_union() { capture.kind = CaptureKind::ByRef(BorrowKind::Shared); + self.truncate_capture_spans(capture, i + 1); capture.place.projections.truncate(i + 1); break; } } } + self.current_captures = current_captures; } fn adjust_for_move_closure(&mut self) { - for capture in &mut self.current_captures { + // FIXME: Borrow checker won't allow without this. + let mut current_captures = std::mem::take(&mut self.current_captures); + for capture in &mut current_captures { if let Some(first_deref) = capture.place.projections.iter().position(|proj| *proj == ProjectionElem::Deref) { + self.truncate_capture_spans(capture, first_deref); capture.place.projections.truncate(first_deref); } capture.kind = CaptureKind::ByValue; } + self.current_captures = current_captures; } fn minimize_captures(&mut self) { - self.current_captures.sort_by_key(|it| it.place.projections.len()); + self.current_captures.sort_unstable_by_key(|it| it.place.projections.len()); let mut hash_map = FxHashMap::<HirPlace, usize>::default(); let result = mem::take(&mut self.current_captures); - for item in result { + for mut item in result { let mut lookup_place = HirPlace { local: item.place.local, projections: vec![] }; let mut it = item.place.projections.iter(); let prev_index = loop { @@ -860,12 +930,17 @@ impl InferenceContext<'_> { break Some(*k); } match it.next() { - Some(it) => lookup_place.projections.push(it.clone()), + Some(it) => { + lookup_place.projections.push(it.clone()); + } None => break None, } }; match prev_index { Some(p) => { + let prev_projections_len = self.current_captures[p].place.projections.len(); + self.truncate_capture_spans(&mut item, prev_projections_len); + self.current_captures[p].span_stacks.extend(item.span_stacks); let len = self.current_captures[p].place.projections.len(); let kind_after_truncate = item.place.capture_kind_of_truncated_place(item.kind, len); @@ -880,113 +955,128 @@ impl InferenceContext<'_> { } } - fn consume_with_pat(&mut self, mut place: HirPlace, pat: PatId) { - let cnt = self.result.pat_adjustments.get(&pat).map(|it| it.len()).unwrap_or_default(); - place.projections = place - .projections - .iter() - .cloned() - .chain((0..cnt).map(|_| ProjectionElem::Deref)) - .collect::<Vec<_>>(); - match &self.body[pat] { - Pat::Missing | Pat::Wild => (), - Pat::Tuple { args, ellipsis } => { - let (al, ar) = args.split_at(ellipsis.map_or(args.len(), |it| it as usize)); - let field_count = match self.result[pat].kind(Interner) { - TyKind::Tuple(_, s) => s.len(Interner), - _ => return, - }; - let fields = 0..field_count; - let it = al.iter().zip(fields.clone()).chain(ar.iter().rev().zip(fields.rev())); - for (arg, i) in it { - let mut p = place.clone(); - p.projections.push(ProjectionElem::Field(Either::Right(TupleFieldId { - tuple: TupleId(!0), // dummy this, as its unused anyways - index: i as u32, - }))); - self.consume_with_pat(p, *arg); - } - } - Pat::Or(pats) => { - for pat in pats.iter() { - self.consume_with_pat(place.clone(), *pat); + fn consume_with_pat(&mut self, mut place: HirPlace, tgt_pat: PatId) { + let adjustments_count = + self.result.pat_adjustments.get(&tgt_pat).map(|it| it.len()).unwrap_or_default(); + place.projections.extend((0..adjustments_count).map(|_| ProjectionElem::Deref)); + self.current_capture_span_stack + .extend((0..adjustments_count).map(|_| MirSpan::PatId(tgt_pat))); + 'reset_span_stack: { + match &self.body[tgt_pat] { + Pat::Missing | Pat::Wild => (), + Pat::Tuple { args, ellipsis } => { + let (al, ar) = args.split_at(ellipsis.map_or(args.len(), |it| it as usize)); + let field_count = match self.result[tgt_pat].kind(Interner) { + TyKind::Tuple(_, s) => s.len(Interner), + _ => break 'reset_span_stack, + }; + let fields = 0..field_count; + let it = al.iter().zip(fields.clone()).chain(ar.iter().rev().zip(fields.rev())); + for (&arg, i) in it { + let mut p = place.clone(); + self.current_capture_span_stack.push(MirSpan::PatId(arg)); + p.projections.push(ProjectionElem::Field(Either::Right(TupleFieldId { + tuple: TupleId(!0), // dummy this, as its unused anyways + index: i as u32, + }))); + self.consume_with_pat(p, arg); + self.current_capture_span_stack.pop(); + } } - } - Pat::Record { args, .. } => { - let Some(variant) = self.result.variant_resolution_for_pat(pat) else { - return; - }; - match variant { - VariantId::EnumVariantId(_) | VariantId::UnionId(_) => { - self.consume_place(place, pat.into()) + Pat::Or(pats) => { + for pat in pats.iter() { + self.consume_with_pat(place.clone(), *pat); } - VariantId::StructId(s) => { - let vd = &*self.db.struct_data(s).variant_data; - for field_pat in args.iter() { - let arg = field_pat.pat; - let Some(local_id) = vd.field(&field_pat.name) else { - continue; - }; - let mut p = place.clone(); - p.projections.push(ProjectionElem::Field(Either::Left(FieldId { - parent: variant, - local_id, - }))); - self.consume_with_pat(p, arg); + } + Pat::Record { args, .. } => { + let Some(variant) = self.result.variant_resolution_for_pat(tgt_pat) else { + break 'reset_span_stack; + }; + match variant { + VariantId::EnumVariantId(_) | VariantId::UnionId(_) => { + self.consume_place(place) + } + VariantId::StructId(s) => { + let vd = &*self.db.struct_data(s).variant_data; + for field_pat in args.iter() { + let arg = field_pat.pat; + let Some(local_id) = vd.field(&field_pat.name) else { + continue; + }; + let mut p = place.clone(); + self.current_capture_span_stack.push(MirSpan::PatId(arg)); + p.projections.push(ProjectionElem::Field(Either::Left(FieldId { + parent: variant, + local_id, + }))); + self.consume_with_pat(p, arg); + self.current_capture_span_stack.pop(); + } } } } - } - Pat::Range { .. } - | Pat::Slice { .. } - | Pat::ConstBlock(_) - | Pat::Path(_) - | Pat::Lit(_) => self.consume_place(place, pat.into()), - Pat::Bind { id: _, subpat: _ } => { - let mode = self.result.binding_modes[pat]; - let capture_kind = match mode { - BindingMode::Move => { - self.consume_place(place, pat.into()); - return; - } - BindingMode::Ref(Mutability::Not) => BorrowKind::Shared, - BindingMode::Ref(Mutability::Mut) => { - BorrowKind::Mut { kind: MutBorrowKind::Default } - } - }; - self.add_capture(place, CaptureKind::ByRef(capture_kind), pat.into()); - } - Pat::TupleStruct { path: _, args, ellipsis } => { - let Some(variant) = self.result.variant_resolution_for_pat(pat) else { - return; - }; - match variant { - VariantId::EnumVariantId(_) | VariantId::UnionId(_) => { - self.consume_place(place, pat.into()) - } - VariantId::StructId(s) => { - let vd = &*self.db.struct_data(s).variant_data; - let (al, ar) = args.split_at(ellipsis.map_or(args.len(), |it| it as usize)); - let fields = vd.fields().iter(); - let it = - al.iter().zip(fields.clone()).chain(ar.iter().rev().zip(fields.rev())); - for (arg, (i, _)) in it { - let mut p = place.clone(); - p.projections.push(ProjectionElem::Field(Either::Left(FieldId { - parent: variant, - local_id: i, - }))); - self.consume_with_pat(p, *arg); + Pat::Range { .. } + | Pat::Slice { .. } + | Pat::ConstBlock(_) + | Pat::Path(_) + | Pat::Lit(_) => self.consume_place(place), + &Pat::Bind { id, subpat: _ } => { + let mode = self.result.binding_modes[tgt_pat]; + let capture_kind = match mode { + BindingMode::Move => { + self.consume_place(place); + break 'reset_span_stack; + } + BindingMode::Ref(Mutability::Not) => BorrowKind::Shared, + BindingMode::Ref(Mutability::Mut) => { + BorrowKind::Mut { kind: MutBorrowKind::Default } + } + }; + self.current_capture_span_stack.push(MirSpan::BindingId(id)); + self.add_capture(place, CaptureKind::ByRef(capture_kind)); + self.current_capture_span_stack.pop(); + } + Pat::TupleStruct { path: _, args, ellipsis } => { + let Some(variant) = self.result.variant_resolution_for_pat(tgt_pat) else { + break 'reset_span_stack; + }; + match variant { + VariantId::EnumVariantId(_) | VariantId::UnionId(_) => { + self.consume_place(place) + } + VariantId::StructId(s) => { + let vd = &*self.db.struct_data(s).variant_data; + let (al, ar) = + args.split_at(ellipsis.map_or(args.len(), |it| it as usize)); + let fields = vd.fields().iter(); + let it = al + .iter() + .zip(fields.clone()) + .chain(ar.iter().rev().zip(fields.rev())); + for (&arg, (i, _)) in it { + let mut p = place.clone(); + self.current_capture_span_stack.push(MirSpan::PatId(arg)); + p.projections.push(ProjectionElem::Field(Either::Left(FieldId { + parent: variant, + local_id: i, + }))); + self.consume_with_pat(p, arg); + self.current_capture_span_stack.pop(); + } } } } + Pat::Ref { pat, mutability: _ } => { + self.current_capture_span_stack.push(MirSpan::PatId(tgt_pat)); + place.projections.push(ProjectionElem::Deref); + self.consume_with_pat(place, *pat); + self.current_capture_span_stack.pop(); + } + Pat::Box { .. } => (), // not supported } - Pat::Ref { pat, mutability: _ } => { - place.projections.push(ProjectionElem::Deref); - self.consume_with_pat(place, *pat) - } - Pat::Box { .. } => (), // not supported } + self.current_capture_span_stack + .truncate(self.current_capture_span_stack.len() - adjustments_count); } fn consume_exprs(&mut self, exprs: impl Iterator<Item = ExprId>) { @@ -1044,12 +1134,28 @@ impl InferenceContext<'_> { CaptureBy::Ref => (), } self.minimize_captures(); + self.strip_captures_ref_span(); let result = mem::take(&mut self.current_captures); let captures = result.into_iter().map(|it| it.with_ty(self)).collect::<Vec<_>>(); self.result.closure_info.insert(closure, (captures, closure_kind)); closure_kind } + fn strip_captures_ref_span(&mut self) { + // FIXME: Borrow checker won't allow without this. + let mut captures = std::mem::take(&mut self.current_captures); + for capture in &mut captures { + if matches!(capture.kind, CaptureKind::ByValue) { + for span_stack in &mut capture.span_stacks { + if self.is_ref_span(span_stack[span_stack.len() - 1]) { + span_stack.truncate(span_stack.len() - 1); + } + } + } + } + self.current_captures = captures; + } + pub(crate) fn infer_closures(&mut self) { let deferred_closures = self.sort_closures(); for (closure, exprs) in deferred_closures.into_iter().rev() { @@ -1110,10 +1216,17 @@ impl InferenceContext<'_> { } } -fn apply_adjusts_to_place(mut r: HirPlace, adjustments: &[Adjustment]) -> Option<HirPlace> { +/// Call this only when the last span in the stack isn't a split. +fn apply_adjusts_to_place( + current_capture_span_stack: &mut Vec<MirSpan>, + mut r: HirPlace, + adjustments: &[Adjustment], +) -> Option<HirPlace> { + let span = *current_capture_span_stack.last().expect("empty capture span stack"); for adj in adjustments { match &adj.kind { Adjust::Deref(None) => { + current_capture_span_stack.push(span); r.projections.push(ProjectionElem::Deref); } _ => return None, |