Unnamed repository; edit this file 'description' to name the repository.
Merge pull request #21113 from ChayimFriedman2/byte-range
fix: Infer range patterns correctly
Lukas Wirth 5 months ago
parent 813c7d4 · parent ba27d77 · commit b2a6245
-rw-r--r--crates/hir-ty/src/infer/pat.rs54
-rw-r--r--crates/hir-ty/src/tests/patterns.rs30
2 files changed, 29 insertions, 55 deletions
diff --git a/crates/hir-ty/src/infer/pat.rs b/crates/hir-ty/src/infer/pat.rs
index d8b02dea15..ece2bdc4fd 100644
--- a/crates/hir-ty/src/infer/pat.rs
+++ b/crates/hir-ty/src/infer/pat.rs
@@ -11,7 +11,6 @@ use hir_expand::name::Name;
use rustc_ast_ir::Mutability;
use rustc_type_ir::inherent::{GenericArg as _, GenericArgs as _, IntoKind, SliceLike, Ty as _};
use stdx::TupleExt;
-use syntax::ast::RangeOp;
use crate::{
DeclContext, DeclOrigin, InferenceDiagnostic,
@@ -350,51 +349,16 @@ impl<'db> InferenceContext<'_, 'db> {
self.infer_slice_pat(expected, prefix, *slice, suffix, default_bm, decl)
}
Pat::Wild => expected,
- Pat::Range { start, end, range_type } => {
- // FIXME: Expectation
- let lhs_expectation = Expectation::none();
- let lhs_ty =
- start.map(|start| self.infer_expr(start, &lhs_expectation, ExprIsRead::Yes));
- let rhs_expectation = lhs_ty.map_or_else(Expectation::none, Expectation::HasType);
- let rhs_ty = end.map(|end| self.infer_expr(end, &rhs_expectation, ExprIsRead::Yes));
- let single_arg_adt = |adt, ty: Ty<'db>| {
- Ty::new_adt(
- self.interner(),
- adt,
- GenericArgs::new_from_iter(self.interner(), [ty.into()]),
- )
- };
- match (range_type, lhs_ty, rhs_ty) {
- (RangeOp::Exclusive, None, None) => match self.resolve_range_full() {
- Some(adt) => Ty::new_adt(self.interner(), adt, self.types.empty_args),
- None => self.err_ty(),
- },
- (RangeOp::Exclusive, None, Some(ty)) => match self.resolve_range_to() {
- Some(adt) => single_arg_adt(adt, ty),
- None => self.err_ty(),
- },
- (RangeOp::Inclusive, None, Some(ty)) => {
- match self.resolve_range_to_inclusive() {
- Some(adt) => single_arg_adt(adt, ty),
- None => self.err_ty(),
- }
- }
- (RangeOp::Exclusive, Some(_), Some(ty)) => match self.resolve_range() {
- Some(adt) => single_arg_adt(adt, ty),
- None => self.err_ty(),
- },
- (RangeOp::Inclusive, Some(_), Some(ty)) => {
- match self.resolve_range_inclusive() {
- Some(adt) => single_arg_adt(adt, ty),
- None => self.err_ty(),
- }
- }
- (RangeOp::Exclusive, Some(ty), None) => match self.resolve_range_from() {
- Some(adt) => single_arg_adt(adt, ty),
- None => self.err_ty(),
- },
- (RangeOp::Inclusive, _, None) => self.err_ty(),
+ Pat::Range { start, end, range_type: _ } => {
+ if let Some(start) = *start {
+ let start_ty = self.infer_expr(start, &Expectation::None, ExprIsRead::Yes);
+ _ = self.demand_eqtype(start.into(), expected, start_ty);
+ }
+ if let Some(end) = *end {
+ let end_ty = self.infer_expr(end, &Expectation::None, ExprIsRead::Yes);
+ _ = self.demand_eqtype(end.into(), expected, end_ty);
}
+ expected
}
&Pat::Lit(expr) => {
// Don't emit type mismatches again, the expression lowering already did that.
diff --git a/crates/hir-ty/src/tests/patterns.rs b/crates/hir-ty/src/tests/patterns.rs
index 5e150e2bcc..c312b16759 100644
--- a/crates/hir-ty/src/tests/patterns.rs
+++ b/crates/hir-ty/src/tests/patterns.rs
@@ -196,28 +196,38 @@ fn test(x..y: &core::ops::Range<u32>) {
}
"#,
expect![[r#"
- 8..9 'x': u32
+ 8..9 'x': Range<u32>
8..12 'x..y': Range<u32>
- 11..12 'y': u32
+ 11..12 'y': Range<u32>
38..96 '{ ...2 {} }': ()
44..66 'if let...u32 {}': ()
47..63 'let 1....= 2u32': bool
- 51..52 '1': i32
- 51..56 '1..76': Range<i32>
- 54..56 '76': i32
+ 51..52 '1': u32
+ 51..56 '1..76': u32
+ 54..56 '76': u32
59..63 '2u32': u32
64..66 '{}': ()
71..94 'if let...u32 {}': ()
74..91 'let 1....= 2u32': bool
- 78..79 '1': i32
- 78..84 '1..=76': RangeInclusive<i32>
- 82..84 '76': i32
+ 78..79 '1': u32
+ 78..84 '1..=76': u32
+ 82..84 '76': u32
87..91 '2u32': u32
92..94 '{}': ()
- 51..56: expected u32, got Range<i32>
- 78..84: expected u32, got RangeInclusive<i32>
"#]],
);
+ check_no_mismatches(
+ r#"
+//- minicore: range
+fn main() {
+ let byte: u8 = 0u8;
+ let b = match byte {
+ b'0'..=b'9' => true,
+ _ => false,
+ };
+}
+ "#,
+ );
}
#[test]