Unnamed repository; edit this file 'description' to name the repository.
Diffstat (limited to 'crates/hir-ty/src/mir/lower/pattern_matching.rs')
| -rw-r--r-- | crates/hir-ty/src/mir/lower/pattern_matching.rs | 385 |
1 files changed, 202 insertions, 183 deletions
diff --git a/crates/hir-ty/src/mir/lower/pattern_matching.rs b/crates/hir-ty/src/mir/lower/pattern_matching.rs index ee2a0306d5..5cd1be6842 100644 --- a/crates/hir-ty/src/mir/lower/pattern_matching.rs +++ b/crates/hir-ty/src/mir/lower/pattern_matching.rs @@ -2,7 +2,7 @@ use hir_def::{hir::LiteralOrConst, resolver::HasResolver, AssocItemId}; -use crate::utils::pattern_matching_dereference_count; +use crate::BindingMode; use super::*; @@ -18,6 +18,26 @@ pub(super) enum AdtPatternShape<'a> { Unit, } +/// We need to do pattern matching in two phases: One to check if the pattern matches, and one to fill the bindings +/// of patterns. This is necessary to prevent double moves and similar problems. For example: +/// ```ignore +/// struct X; +/// match (X, 3) { +/// (b, 2) | (b, 3) => {}, +/// _ => {} +/// } +/// ``` +/// If we do everything in one pass, we will move `X` to the first `b`, then we see that the second field of tuple +/// doesn't match and we should move the `X` to the second `b` (which here is the same thing, but doesn't need to be) and +/// it might even doesn't match the second pattern and we may want to not move `X` at all. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum MatchingMode { + /// Check that if this pattern matches + Check, + /// Assume that this pattern matches, fill bindings + Bind, +} + impl MirLowerCtx<'_> { /// It gets a `current` unterminated block, appends some statements and possibly a terminator to it to check if /// the pattern matches and write bindings, and returns two unterminated blocks, one for the matched path (which @@ -30,19 +50,49 @@ impl MirLowerCtx<'_> { /// so it should be an empty block. pub(super) fn pattern_match( &mut self, + current: BasicBlockId, + current_else: Option<BasicBlockId>, + cond_place: Place, + pattern: PatId, + ) -> Result<(BasicBlockId, Option<BasicBlockId>)> { + let (current, current_else) = self.pattern_match_inner( + current, + current_else, + cond_place.clone(), + pattern, + MatchingMode::Check, + )?; + let (current, current_else) = self.pattern_match_inner( + current, + current_else, + cond_place, + pattern, + MatchingMode::Bind, + )?; + Ok((current, current_else)) + } + + fn pattern_match_inner( + &mut self, mut current: BasicBlockId, mut current_else: Option<BasicBlockId>, mut cond_place: Place, - mut cond_ty: Ty, pattern: PatId, - mut binding_mode: BindingAnnotation, + mode: MatchingMode, ) -> Result<(BasicBlockId, Option<BasicBlockId>)> { + let cnt = self.infer.pat_adjustments.get(&pattern).map(|x| x.len()).unwrap_or_default(); + cond_place.projection = cond_place + .projection + .iter() + .cloned() + .chain((0..cnt).map(|_| ProjectionElem::Deref)) + .collect::<Vec<_>>() + .into(); Ok(match &self.body.pats[pattern] { Pat::Missing => return Err(MirLowerError::IncompletePattern), Pat::Wild => (current, current_else), Pat::Tuple { args, ellipsis } => { - pattern_matching_dereference(&mut cond_ty, &mut binding_mode, &mut cond_place); - let subst = match cond_ty.kind(Interner) { + let subst = match self.infer[pattern].kind(Interner) { TyKind::Tuple(_, s) => s, _ => { return Err(MirLowerError::TypeError( @@ -55,25 +105,31 @@ impl MirLowerCtx<'_> { current_else, args, *ellipsis, - subst.iter(Interner).enumerate().map(|(i, x)| { - (PlaceElem::TupleOrClosureField(i), x.assert_ty_ref(Interner).clone()) - }), - &cond_place, - binding_mode, + (0..subst.len(Interner)).map(|i| PlaceElem::TupleOrClosureField(i)), + &(&mut cond_place), + mode, )? } Pat::Or(pats) => { let then_target = self.new_basic_block(); let mut finished = false; for pat in &**pats { - let (next, next_else) = self.pattern_match( + let (mut next, next_else) = self.pattern_match_inner( current, None, - cond_place.clone(), - cond_ty.clone(), + (&mut cond_place).clone(), *pat, - binding_mode, + MatchingMode::Check, )?; + if mode == MatchingMode::Bind { + (next, _) = self.pattern_match_inner( + next, + None, + (&mut cond_place).clone(), + *pat, + MatchingMode::Bind, + )?; + } self.set_goto(next, then_target, pattern.into()); match next_else { Some(t) => { @@ -86,8 +142,12 @@ impl MirLowerCtx<'_> { } } if !finished { - let ce = *current_else.get_or_insert_with(|| self.new_basic_block()); - self.set_goto(current, ce, pattern.into()); + if mode == MatchingMode::Bind { + self.set_terminator(current, TerminatorKind::Unreachable, pattern.into()); + } else { + let ce = *current_else.get_or_insert_with(|| self.new_basic_block()); + self.set_goto(current, ce, pattern.into()); + } } (then_target, current_else) } @@ -96,19 +156,19 @@ impl MirLowerCtx<'_> { not_supported!("unresolved variant for record"); }; self.pattern_matching_variant( - cond_ty, - binding_mode, cond_place, variant, current, pattern.into(), current_else, AdtPatternShape::Record { args: &*args }, + mode, )? } Pat::Range { start, end } => { let mut add_check = |l: &LiteralOrConst, binop| -> Result<()> { - let lv = self.lower_literal_or_const_to_operand(cond_ty.clone(), l)?; + let lv = + self.lower_literal_or_const_to_operand(self.infer[pattern].clone(), l)?; let else_target = *current_else.get_or_insert_with(|| self.new_basic_block()); let next = self.new_basic_block(); let discr: Place = @@ -116,7 +176,11 @@ impl MirLowerCtx<'_> { self.push_assignment( current, discr.clone(), - Rvalue::CheckedBinaryOp(binop, lv, Operand::Copy(cond_place.clone())), + Rvalue::CheckedBinaryOp( + binop, + lv, + Operand::Copy((&mut cond_place).clone()), + ), pattern.into(), ); let discr = Operand::Copy(discr); @@ -131,24 +195,25 @@ impl MirLowerCtx<'_> { current = next; Ok(()) }; - if let Some(start) = start { - add_check(start, BinOp::Le)?; - } - if let Some(end) = end { - add_check(end, BinOp::Ge)?; + if mode == MatchingMode::Check { + if let Some(start) = start { + add_check(start, BinOp::Le)?; + } + if let Some(end) = end { + add_check(end, BinOp::Ge)?; + } } (current, current_else) } Pat::Slice { prefix, slice, suffix } => { - pattern_matching_dereference(&mut cond_ty, &mut binding_mode, &mut cond_place); - if let TyKind::Slice(_) = cond_ty.kind(Interner) { + if let TyKind::Slice(_) = self.infer[pattern].kind(Interner) { let pattern_len = prefix.len() + suffix.len(); let place_len: Place = self.temp(TyBuilder::usize(), current, pattern.into())?.into(); self.push_assignment( current, place_len.clone(), - Rvalue::Len(cond_place.clone()), + Rvalue::Len((&mut cond_place).clone()), pattern.into(), ); let else_target = *current_else.get_or_insert_with(|| self.new_basic_block()); @@ -193,63 +258,49 @@ impl MirLowerCtx<'_> { current = next; } for (i, &pat) in prefix.iter().enumerate() { - let next_place = cond_place.project(ProjectionElem::ConstantIndex { + let next_place = (&mut cond_place).project(ProjectionElem::ConstantIndex { offset: i as u64, from_end: false, }); - let cond_ty = self.infer[pat].clone(); - (current, current_else) = self.pattern_match( - current, - current_else, - next_place, - cond_ty, - pat, - binding_mode, - )?; + (current, current_else) = + self.pattern_match_inner(current, current_else, next_place, pat, mode)?; } if let Some(slice) = slice { - if let Pat::Bind { id, subpat: _ } = self.body[*slice] { - let next_place = cond_place.project(ProjectionElem::Subslice { - from: prefix.len() as u64, - to: suffix.len() as u64, - }); - (current, current_else) = self.pattern_match_binding( - id, - &mut binding_mode, - next_place, - (*slice).into(), - current, - current_else, - )?; + if mode == MatchingMode::Bind { + if let Pat::Bind { id, subpat: _ } = self.body[*slice] { + let next_place = (&mut cond_place).project(ProjectionElem::Subslice { + from: prefix.len() as u64, + to: suffix.len() as u64, + }); + (current, current_else) = self.pattern_match_binding( + id, + next_place, + (*slice).into(), + current, + current_else, + )?; + } } } for (i, &pat) in suffix.iter().enumerate() { - let next_place = cond_place.project(ProjectionElem::ConstantIndex { + let next_place = (&mut cond_place).project(ProjectionElem::ConstantIndex { offset: i as u64, from_end: true, }); - let cond_ty = self.infer[pat].clone(); - (current, current_else) = self.pattern_match( - current, - current_else, - next_place, - cond_ty, - pat, - binding_mode, - )?; + (current, current_else) = + self.pattern_match_inner(current, current_else, next_place, pat, mode)?; } (current, current_else) } Pat::Path(p) => match self.infer.variant_resolution_for_pat(pattern) { Some(variant) => self.pattern_matching_variant( - cond_ty, - binding_mode, cond_place, variant, current, pattern.into(), current_else, AdtPatternShape::Unit, + mode, )?, None => { let unresolved_name = || MirLowerError::unresolved_path(self.db, p); @@ -270,9 +321,17 @@ impl MirLowerCtx<'_> { } not_supported!("path in pattern position that is not const or variant") }; - let tmp: Place = self.temp(cond_ty.clone(), current, pattern.into())?.into(); + let tmp: Place = + self.temp(self.infer[pattern].clone(), current, pattern.into())?.into(); let span = pattern.into(); - self.lower_const(c.into(), current, tmp.clone(), subst, span, cond_ty.clone())?; + self.lower_const( + c.into(), + current, + tmp.clone(), + subst, + span, + self.infer[pattern].clone(), + )?; let tmp2: Place = self.temp(TyBuilder::bool(), current, pattern.into())?.into(); self.push_assignment( current, @@ -299,61 +358,58 @@ impl MirLowerCtx<'_> { }, Pat::Lit(l) => match &self.body.exprs[*l] { Expr::Literal(l) => { - let c = self.lower_literal_to_operand(cond_ty, l)?; - self.pattern_match_const(current_else, current, c, cond_place, pattern)? + let c = self.lower_literal_to_operand(self.infer[pattern].clone(), l)?; + if mode == MatchingMode::Check { + self.pattern_match_const(current_else, current, c, cond_place, pattern)? + } else { + (current, current_else) + } } _ => not_supported!("expression path literal"), }, Pat::Bind { id, subpat } => { if let Some(subpat) = subpat { - (current, current_else) = self.pattern_match( + (current, current_else) = self.pattern_match_inner( current, current_else, - cond_place.clone(), - cond_ty, + (&mut cond_place).clone(), *subpat, - binding_mode, + mode, )? } - self.pattern_match_binding( - *id, - &mut binding_mode, - cond_place, - pattern.into(), - current, - current_else, - )? + if mode == MatchingMode::Bind { + self.pattern_match_binding( + *id, + cond_place, + pattern.into(), + current, + current_else, + )? + } else { + (current, current_else) + } } Pat::TupleStruct { path: _, args, ellipsis } => { let Some(variant) = self.infer.variant_resolution_for_pat(pattern) else { not_supported!("unresolved variant"); }; self.pattern_matching_variant( - cond_ty, - binding_mode, cond_place, variant, current, pattern.into(), current_else, AdtPatternShape::Tuple { args, ellipsis: *ellipsis }, + mode, )? } - Pat::Ref { pat, mutability: _ } => { - if let Some((ty, _, _)) = cond_ty.as_reference() { - cond_ty = ty.clone(); - self.pattern_match( - current, - current_else, - cond_place.project(ProjectionElem::Deref), - cond_ty, - *pat, - binding_mode, - )? - } else { - return Err(MirLowerError::TypeError("& pattern for non reference")); - } - } + Pat::Ref { pat, mutability: _ } => self.pattern_match_inner( + current, + current_else, + cond_place.project(ProjectionElem::Deref), + *pat, + mode, + )?, Pat::Box { .. } => not_supported!("box pattern"), Pat::ConstBlock(_) => not_supported!("const block pattern"), }) @@ -362,27 +418,21 @@ impl MirLowerCtx<'_> { fn pattern_match_binding( &mut self, id: BindingId, - binding_mode: &mut BindingAnnotation, cond_place: Place, span: MirSpan, current: BasicBlockId, current_else: Option<BasicBlockId>, ) -> Result<(BasicBlockId, Option<BasicBlockId>)> { let target_place = self.binding_local(id)?; - let mode = self.body.bindings[id].mode; - if matches!(mode, BindingAnnotation::Ref | BindingAnnotation::RefMut) { - *binding_mode = mode; - } + let mode = self.infer.binding_modes[id]; self.push_storage_live(id, current)?; self.push_assignment( current, target_place.into(), - match *binding_mode { - BindingAnnotation::Unannotated | BindingAnnotation::Mutable => { - Operand::Copy(cond_place).into() - } - BindingAnnotation::Ref => Rvalue::Ref(BorrowKind::Shared, cond_place), - BindingAnnotation::RefMut => { + match mode { + BindingMode::Move => Operand::Copy(cond_place).into(), + BindingMode::Ref(Mutability::Not) => Rvalue::Ref(BorrowKind::Shared, cond_place), + BindingMode::Ref(Mutability::Mut) => { Rvalue::Ref(BorrowKind::Mut { allow_two_phase_borrow: false }, cond_place) } }, @@ -420,52 +470,48 @@ impl MirLowerCtx<'_> { Ok((then_target, Some(else_target))) } - pub(super) fn pattern_matching_variant( + fn pattern_matching_variant( &mut self, - mut cond_ty: Ty, - mut binding_mode: BindingAnnotation, - mut cond_place: Place, + cond_place: Place, variant: VariantId, - current: BasicBlockId, + mut current: BasicBlockId, span: MirSpan, - current_else: Option<BasicBlockId>, + mut current_else: Option<BasicBlockId>, shape: AdtPatternShape<'_>, + mode: MatchingMode, ) -> Result<(BasicBlockId, Option<BasicBlockId>)> { - pattern_matching_dereference(&mut cond_ty, &mut binding_mode, &mut cond_place); - let subst = match cond_ty.kind(Interner) { - TyKind::Adt(_, s) => s, - _ => return Err(MirLowerError::TypeError("non adt type matched with tuple struct")), - }; Ok(match variant { VariantId::EnumVariantId(v) => { - let e = self.const_eval_discriminant(v)? as u128; - let tmp = self.discr_temp_place(current); - self.push_assignment( - current, - tmp.clone(), - Rvalue::Discriminant(cond_place.clone()), - span, - ); - let next = self.new_basic_block(); - let else_target = current_else.unwrap_or_else(|| self.new_basic_block()); - self.set_terminator( - current, - TerminatorKind::SwitchInt { - discr: Operand::Copy(tmp), - targets: SwitchTargets::static_if(e, next, else_target), - }, - span, - ); + if mode == MatchingMode::Check { + let e = self.const_eval_discriminant(v)? as u128; + let tmp = self.discr_temp_place(current); + self.push_assignment( + current, + tmp.clone(), + Rvalue::Discriminant(cond_place.clone()), + span, + ); + let next = self.new_basic_block(); + let else_target = current_else.get_or_insert_with(|| self.new_basic_block()); + self.set_terminator( + current, + TerminatorKind::SwitchInt { + discr: Operand::Copy(tmp), + targets: SwitchTargets::static_if(e, next, *else_target), + }, + span, + ); + current = next; + } let enum_data = self.db.enum_data(v.parent); self.pattern_matching_variant_fields( shape, &enum_data.variants[v.local_id].variant_data, variant, - subst, - next, - Some(else_target), + current, + current_else, &cond_place, - binding_mode, + mode, )? } VariantId::StructId(s) => { @@ -474,11 +520,10 @@ impl MirLowerCtx<'_> { shape, &struct_data.variant_data, variant, - subst, current, current_else, &cond_place, - binding_mode, + mode, )? } VariantId::UnionId(_) => { @@ -492,13 +537,11 @@ impl MirLowerCtx<'_> { shape: AdtPatternShape<'_>, variant_data: &VariantData, v: VariantId, - subst: &Substitution, current: BasicBlockId, current_else: Option<BasicBlockId>, cond_place: &Place, - binding_mode: BindingAnnotation, + mode: MatchingMode, ) -> Result<(BasicBlockId, Option<BasicBlockId>)> { - let fields_type = self.db.field_types(v); Ok(match shape { AdtPatternShape::Record { args } => { let it = args @@ -509,25 +552,16 @@ impl MirLowerCtx<'_> { Ok(( PlaceElem::Field(FieldId { parent: v.into(), local_id: field_id }), x.pat, - fields_type[field_id].clone().substitute(Interner, subst), )) }) .collect::<Result<Vec<_>>>()?; - self.pattern_match_adt( - current, - current_else, - it.into_iter(), - cond_place, - binding_mode, - )? + self.pattern_match_adt(current, current_else, it.into_iter(), cond_place, mode)? } AdtPatternShape::Tuple { args, ellipsis } => { - let fields = variant_data.fields().iter().map(|(x, _)| { - ( - PlaceElem::Field(FieldId { parent: v.into(), local_id: x }), - fields_type[x].clone().substitute(Interner, subst), - ) - }); + let fields = variant_data + .fields() + .iter() + .map(|(x, _)| PlaceElem::Field(FieldId { parent: v.into(), local_id: x })); self.pattern_match_tuple_like( current, current_else, @@ -535,7 +569,7 @@ impl MirLowerCtx<'_> { ellipsis, fields, cond_place, - binding_mode, + mode, )? } AdtPatternShape::Unit => (current, current_else), @@ -546,14 +580,14 @@ impl MirLowerCtx<'_> { &mut self, mut current: BasicBlockId, mut current_else: Option<BasicBlockId>, - args: impl Iterator<Item = (PlaceElem, PatId, Ty)>, + args: impl Iterator<Item = (PlaceElem, PatId)>, cond_place: &Place, - binding_mode: BindingAnnotation, + mode: MatchingMode, ) -> Result<(BasicBlockId, Option<BasicBlockId>)> { - for (proj, arg, ty) in args { + for (proj, arg) in args { let cond_place = cond_place.project(proj); (current, current_else) = - self.pattern_match(current, current_else, cond_place, ty, arg, binding_mode)?; + self.pattern_match_inner(current, current_else, cond_place, arg, mode)?; } Ok((current, current_else)) } @@ -564,31 +598,16 @@ impl MirLowerCtx<'_> { current_else: Option<BasicBlockId>, args: &[PatId], ellipsis: Option<usize>, - fields: impl DoubleEndedIterator<Item = (PlaceElem, Ty)> + Clone, + fields: impl DoubleEndedIterator<Item = PlaceElem> + Clone, cond_place: &Place, - binding_mode: BindingAnnotation, + mode: MatchingMode, ) -> Result<(BasicBlockId, Option<BasicBlockId>)> { let (al, ar) = args.split_at(ellipsis.unwrap_or(args.len())); let it = al .iter() .zip(fields.clone()) .chain(ar.iter().rev().zip(fields.rev())) - .map(|(x, y)| (y.0, *x, y.1)); - self.pattern_match_adt(current, current_else, it, cond_place, binding_mode) + .map(|(x, y)| (y, *x)); + self.pattern_match_adt(current, current_else, it, cond_place, mode) } } - -fn pattern_matching_dereference( - cond_ty: &mut Ty, - binding_mode: &mut BindingAnnotation, - cond_place: &mut Place, -) { - let cnt = pattern_matching_dereference_count(cond_ty, binding_mode); - cond_place.projection = cond_place - .projection - .iter() - .cloned() - .chain((0..cnt).map(|_| ProjectionElem::Deref)) - .collect::<Vec<_>>() - .into(); -} |