Unnamed repository; edit this file 'description' to name the repository.
Merge pull request #22214 from ChayimFriedman2/type-annotations-const-block
fix: Add missing exprs to visiting
Shoyu Vanilla (Flint) 3 weeks ago
parent aa64e48 · parent 5fe0582 · commit 941cf42
-rw-r--r--crates/hir-def/src/expr_store.rs407
-rw-r--r--crates/hir-ty/src/diagnostics/unsafe_check.rs1
-rw-r--r--crates/hir/src/source_analyzer.rs4
-rw-r--r--crates/ide-diagnostics/src/handlers/type_must_be_known.rs13
4 files changed, 191 insertions, 234 deletions
diff --git a/crates/hir-def/src/expr_store.rs b/crates/hir-def/src/expr_store.rs
index f72419a3ae..2c2b477115 100644
--- a/crates/hir-def/src/expr_store.rs
+++ b/crates/hir-def/src/expr_store.rs
@@ -9,7 +9,10 @@ pub mod scope;
#[cfg(test)]
mod tests;
-use std::ops::{Deref, Index};
+use std::{
+ borrow::Borrow,
+ ops::{Deref, Index},
+};
use cfg::{CfgExpr, CfgOptions};
use either::Either;
@@ -27,8 +30,8 @@ use crate::{
db::DefDatabase,
expr_store::path::Path,
hir::{
- Array, AsmOperand, Binding, BindingId, Expr, ExprId, ExprOrPatId, Label, LabelId, Pat,
- PatId, RecordFieldPat, RecordSpread, Statement,
+ Array, AsmOperand, Binding, BindingId, Expr, ExprId, ExprOrPatId, InlineAsm, Label,
+ LabelId, MatchArm, Pat, PatId, RecordFieldPat, RecordLitField, RecordSpread, Statement,
},
nameres::{DefMap, block_def_map},
signatures::VariantFields,
@@ -554,31 +557,35 @@ impl ExpressionStore {
}
pub fn walk_pats_shallow(&self, pat_id: PatId, mut f: impl FnMut(PatId)) {
+ // Do not use `..` patterns or field accesses here, only destructuring, to ensure we cover all cases
+ // (we've had multiple bugs with this in the past).
let pat = &self[pat_id];
match pat {
- Pat::Range { .. }
- | Pat::Lit(..)
- | Pat::Path(..)
- | Pat::ConstBlock(..)
+ Pat::Range { start: _, end: _, range_type: _ }
+ | Pat::Lit(_)
+ | Pat::Path(_)
+ | Pat::ConstBlock(_)
| Pat::Wild
| Pat::Missing
| Pat::Rest
| Pat::Expr(_) => {}
- &Pat::Bind { subpat, .. } => {
+ &Pat::Bind { subpat, id: _ } => {
if let Some(subpat) = subpat {
f(subpat);
}
}
- Pat::Or(args) | Pat::Tuple { args, .. } | Pat::TupleStruct { args, .. } => {
+ Pat::Or(args)
+ | Pat::Tuple { args, ellipsis: _ }
+ | Pat::TupleStruct { args, ellipsis: _, path: _ } => {
args.iter().copied().for_each(f);
}
- Pat::Ref { pat, .. } => f(*pat),
+ Pat::Ref { pat, mutability: _ } => f(*pat),
Pat::Slice { prefix, slice, suffix } => {
let total_iter = prefix.iter().chain(slice.iter()).chain(suffix.iter());
total_iter.copied().for_each(f);
}
- Pat::Record { args, .. } => {
- args.iter().for_each(|RecordFieldPat { pat, .. }| f(*pat));
+ Pat::Record { args, ellipsis: _, path: _ } => {
+ args.iter().for_each(|RecordFieldPat { pat, name: _ }| f(*pat));
}
Pat::Box { inner } => f(*inner),
}
@@ -606,273 +613,197 @@ impl ExpressionStore {
self.expr_only.as_ref()?.binding_owners.get(&id).copied()
}
- /// Walks the immediate children expressions and calls `f` for each child expression.
- ///
- /// Note that this does not walk const blocks.
- pub fn walk_child_exprs(&self, expr_id: ExprId, mut f: impl FnMut(ExprId)) {
- let expr = &self[expr_id];
- match expr {
- Expr::Continue { .. }
- | Expr::Const(_)
+ fn walk_child_exprs_impl(&self, expr_id: ExprId, mut visitor: impl ExprVisitor) {
+ // Do not use `..` patterns or field accesses here, only destructuring, to ensure we cover all cases
+ // (we've had multiple bugs with this in the past).
+ match &self[expr_id] {
+ Expr::Continue { label: _ }
| Expr::Missing
| Expr::Path(_)
| Expr::OffsetOf(_)
| Expr::Literal(_)
| Expr::Underscore => {}
- Expr::InlineAsm(it) => it.operands.iter().for_each(|(_, op)| match op {
- AsmOperand::In { expr, .. }
- | AsmOperand::Out { expr: Some(expr), .. }
- | AsmOperand::InOut { expr, .. }
- | AsmOperand::Const(expr)
- | AsmOperand::Label(expr) => f(*expr),
- AsmOperand::SplitInOut { in_expr, out_expr, .. } => {
- f(*in_expr);
- if let Some(out_expr) = out_expr {
- f(*out_expr);
+ Expr::InlineAsm(InlineAsm { operands, options: _, kind: _ }) => {
+ operands.iter().for_each(|(_, op)| match op {
+ AsmOperand::In { expr, reg: _ }
+ | AsmOperand::Out { expr: Some(expr), late: _, reg: _ }
+ | AsmOperand::InOut { expr, late: _, reg: _ }
+ | AsmOperand::Const(expr)
+ | AsmOperand::Label(expr) => visitor.on_expr(*expr),
+ AsmOperand::SplitInOut { in_expr, out_expr, late: _, reg: _ } => {
+ visitor.on_expr(*in_expr);
+ visitor.on_expr_opt(*out_expr);
}
- }
- AsmOperand::Out { expr: None, .. } | AsmOperand::Sym(_) => (),
- }),
+ AsmOperand::Out { expr: None, late: _, reg: _ } | AsmOperand::Sym(_) => (),
+ })
+ }
Expr::If { condition, then_branch, else_branch } => {
- f(*condition);
- f(*then_branch);
- if let &Some(else_branch) = else_branch {
- f(else_branch);
- }
+ visitor.on_expr(*condition);
+ visitor.on_expr(*then_branch);
+ visitor.on_expr_opt(*else_branch);
}
Expr::Let { expr, pat } => {
- self.walk_exprs_in_pat(*pat, &mut f);
- f(*expr);
+ visitor.on_pat(*pat);
+ visitor.on_expr(*expr);
}
- Expr::Block { statements, tail, .. } | Expr::Unsafe { statements, tail, .. } => {
- for stmt in statements.iter() {
+ Expr::Block { statements, tail, id: _, label: _ }
+ | Expr::Unsafe { statements, tail, id: _ } => {
+ for stmt in statements {
match stmt {
- Statement::Let { initializer, else_branch, pat, .. } => {
- if let &Some(expr) = initializer {
- f(expr);
- }
- if let &Some(expr) = else_branch {
- f(expr);
- }
- self.walk_exprs_in_pat(*pat, &mut f);
+ Statement::Let { initializer, else_branch, pat, type_ref: _ } => {
+ visitor.on_expr_opt(*initializer);
+ visitor.on_expr_opt(*else_branch);
+ visitor.on_pat(*pat);
+ }
+ Statement::Expr { expr: expression, has_semi: _ } => {
+ visitor.on_expr(*expression)
}
- Statement::Expr { expr: expression, .. } => f(*expression),
Statement::Item(_) => (),
}
}
- if let &Some(expr) = tail {
- f(expr);
- }
+ visitor.on_expr_opt(*tail);
}
- Expr::Loop { body, .. } => f(*body),
- Expr::Call { callee, args, .. } => {
- f(*callee);
- args.iter().copied().for_each(f);
+ Expr::Loop { body, label: _ } => visitor.on_expr(*body),
+ Expr::Call { callee, args } => {
+ visitor.on_expr(*callee);
+ visitor.on_exprs(args);
}
- Expr::MethodCall { receiver, args, .. } => {
- f(*receiver);
- args.iter().copied().for_each(f);
+ Expr::MethodCall { receiver, args, generic_args: _, method_name: _ } => {
+ visitor.on_expr(*receiver);
+ visitor.on_exprs(args);
}
Expr::Match { expr, arms } => {
- f(*expr);
- arms.iter().for_each(|arm| {
- f(arm.expr);
- if let Some(guard) = arm.guard {
- f(guard);
- }
- self.walk_exprs_in_pat(arm.pat, &mut f);
+ visitor.on_expr(*expr);
+ arms.iter().for_each(|MatchArm { pat, guard, expr }| {
+ visitor.on_expr(*expr);
+ visitor.on_expr_opt(*guard);
+ visitor.on_pat(*pat);
});
}
- Expr::Break { expr, .. }
+ Expr::Break { expr, label: _ }
| Expr::Return { expr }
| Expr::Yield { expr }
- | Expr::Yeet { expr } => {
- if let &Some(expr) = expr {
- f(expr);
+ | Expr::Yeet { expr } => visitor.on_expr_opt(*expr),
+ Expr::Become { expr } => visitor.on_expr(*expr),
+ Expr::RecordLit { fields, spread, path: _ } => {
+ for RecordLitField { name: _, expr } in fields.iter() {
+ visitor.on_expr(*expr);
}
- }
- Expr::Become { expr } => f(*expr),
- Expr::RecordLit { fields, spread, .. } => {
- for field in fields.iter() {
- f(field.expr);
- }
- if let RecordSpread::Expr(expr) = spread {
- f(*expr);
+ match spread {
+ RecordSpread::Expr(expr) => visitor.on_expr(*expr),
+ RecordSpread::None | RecordSpread::FieldDefaults => {}
}
}
- Expr::Closure { body, .. } => {
- f(*body);
+ Expr::Closure {
+ body,
+ args,
+ arg_types: _,
+ capture_by: _,
+ closure_kind: _,
+ ret_type: _,
+ } => {
+ visitor.on_expr(*body);
+ visitor.on_pats(args);
}
- Expr::BinaryOp { lhs, rhs, .. } => {
- f(*lhs);
- f(*rhs);
+ Expr::BinaryOp { lhs, rhs, op: _ } => {
+ visitor.on_expr(*lhs);
+ visitor.on_expr(*rhs);
}
- Expr::Range { lhs, rhs, .. } => {
- if let &Some(lhs) = rhs {
- f(lhs);
- }
- if let &Some(rhs) = lhs {
- f(rhs);
- }
+ Expr::Range { lhs, rhs, range_type: _ } => {
+ visitor.on_expr_opt(*lhs);
+ visitor.on_expr_opt(*rhs);
}
- Expr::Index { base, index, .. } => {
- f(*base);
- f(*index);
+ Expr::Index { base, index } => {
+ visitor.on_expr(*base);
+ visitor.on_expr(*index);
}
- Expr::Field { expr, .. }
+ Expr::Field { expr, name: _ }
| Expr::Await { expr }
- | Expr::Cast { expr, .. }
- | Expr::Ref { expr, .. }
- | Expr::UnaryOp { expr, .. }
- | Expr::Box { expr } => {
- f(*expr);
+ | Expr::Cast { expr, type_ref: _ }
+ | Expr::Ref { expr, mutability: _, rawness: _ }
+ | Expr::UnaryOp { expr, op: _ }
+ | Expr::Box { expr }
+ | Expr::Const(expr) => {
+ visitor.on_expr(*expr);
}
- Expr::Tuple { exprs, .. } => exprs.iter().copied().for_each(f),
+ Expr::Tuple { exprs } => visitor.on_exprs(exprs),
Expr::Array(a) => match a {
- Array::ElementList { elements, .. } => elements.iter().copied().for_each(f),
+ Array::ElementList { elements } => visitor.on_exprs(elements),
Array::Repeat { initializer, repeat } => {
- f(*initializer);
- f(*repeat)
+ visitor.on_expr(*initializer);
+ visitor.on_expr(*repeat)
}
},
&Expr::Assignment { target, value } => {
- self.walk_exprs_in_pat(target, &mut f);
- f(value);
+ visitor.on_pat(target);
+ visitor.on_expr(value);
}
}
}
- /// Walks the immediate children expressions and calls `f` for each child expression but does
- /// not walk expressions within patterns.
- ///
- /// Note that this does not walk const blocks.
- pub fn walk_child_exprs_without_pats(&self, expr_id: ExprId, mut f: impl FnMut(ExprId)) {
- let expr = &self[expr_id];
- match expr {
- Expr::Continue { .. }
- | Expr::Const(_)
- | Expr::Missing
- | Expr::Path(_)
- | Expr::OffsetOf(_)
- | Expr::Literal(_)
- | Expr::Underscore => {}
- Expr::InlineAsm(it) => it.operands.iter().for_each(|(_, op)| match op {
- AsmOperand::In { expr, .. }
- | AsmOperand::Out { expr: Some(expr), .. }
- | AsmOperand::InOut { expr, .. }
- | AsmOperand::Const(expr)
- | AsmOperand::Label(expr) => f(*expr),
- AsmOperand::SplitInOut { in_expr, out_expr, .. } => {
- f(*in_expr);
- if let Some(out_expr) = out_expr {
- f(*out_expr);
- }
- }
- AsmOperand::Out { expr: None, .. } | AsmOperand::Sym(_) => (),
- }),
- Expr::If { condition, then_branch, else_branch } => {
- f(*condition);
- f(*then_branch);
- if let &Some(else_branch) = else_branch {
- f(else_branch);
- }
- }
- Expr::Let { expr, .. } => {
- f(*expr);
- }
- Expr::Block { statements, tail, .. } | Expr::Unsafe { statements, tail, .. } => {
- for stmt in statements.iter() {
- match stmt {
- Statement::Let { initializer, else_branch, .. } => {
- if let &Some(expr) = initializer {
- f(expr);
- }
- if let &Some(expr) = else_branch {
- f(expr);
- }
- }
- Statement::Expr { expr: expression, .. } => f(*expression),
- Statement::Item(_) => (),
- }
- }
- if let &Some(expr) = tail {
- f(expr);
- }
- }
- Expr::Loop { body, .. } => f(*body),
- Expr::Call { callee, args, .. } => {
- f(*callee);
- args.iter().copied().for_each(f);
- }
- Expr::MethodCall { receiver, args, .. } => {
- f(*receiver);
- args.iter().copied().for_each(f);
- }
- Expr::Match { expr, arms } => {
- f(*expr);
- arms.iter().map(|arm| arm.expr).for_each(f);
- }
- Expr::Break { expr, .. }
- | Expr::Return { expr }
- | Expr::Yield { expr }
- | Expr::Yeet { expr } => {
- if let &Some(expr) = expr {
- f(expr);
- }
- }
- Expr::Become { expr } => f(*expr),
- Expr::RecordLit { fields, spread, .. } => {
- for field in fields.iter() {
- f(field.expr);
- }
- if let RecordSpread::Expr(expr) = spread {
- f(*expr);
- }
- }
- Expr::Closure { body, .. } => {
- f(*body);
- }
- Expr::BinaryOp { lhs, rhs, .. } => {
- f(*lhs);
- f(*rhs);
- }
- Expr::Range { lhs, rhs, .. } => {
- if let &Some(lhs) = rhs {
- f(lhs);
- }
- if let &Some(rhs) = lhs {
- f(rhs);
- }
+ /// Walks the immediate children expressions and calls `f` for each child expression.
+ pub fn walk_child_exprs(&self, expr_id: ExprId, callback: impl FnMut(ExprId)) {
+ return self.walk_child_exprs_impl(expr_id, Visitor { callback, store: self });
+
+ struct Visitor<'a, F> {
+ callback: F,
+ store: &'a ExpressionStore,
+ }
+
+ impl<F: FnMut(ExprId)> ExprVisitor for Visitor<'_, F> {
+ fn on_expr(&mut self, expr: ExprId) {
+ (self.callback)(expr);
}
- Expr::Index { base, index, .. } => {
- f(*base);
- f(*index);
+
+ fn on_pat(&mut self, pat: PatId) {
+ self.store.walk_exprs_in_pat(pat, &mut self.callback);
}
- Expr::Field { expr, .. }
- | Expr::Await { expr }
- | Expr::Cast { expr, .. }
- | Expr::Ref { expr, .. }
- | Expr::UnaryOp { expr, .. }
- | Expr::Box { expr } => {
- f(*expr);
+ }
+ }
+
+ /// Walks the immediate children expressions and calls `f` for each child expression but does
+ /// not walk expressions within patterns.
+ pub fn walk_child_exprs_without_pats(&self, expr_id: ExprId, callback: impl FnMut(ExprId)) {
+ return self.walk_child_exprs_impl(expr_id, Visitor { callback });
+
+ struct Visitor<F> {
+ callback: F,
+ }
+
+ impl<F: FnMut(ExprId)> ExprVisitor for Visitor<F> {
+ fn on_expr(&mut self, expr: ExprId) {
+ (self.callback)(expr);
}
- Expr::Tuple { exprs, .. } => exprs.iter().copied().for_each(f),
- Expr::Array(a) => match a {
- Array::ElementList { elements, .. } => elements.iter().copied().for_each(f),
- Array::Repeat { initializer, repeat } => {
- f(*initializer);
- f(*repeat)
- }
- },
- &Expr::Assignment { target: _, value } => f(value),
+
+ fn on_pat(&mut self, _pat: PatId) {}
}
}
- pub fn walk_exprs_in_pat(&self, pat_id: PatId, f: &mut impl FnMut(ExprId)) {
- self.walk_pats(pat_id, &mut |pat| {
- if let Pat::Expr(expr) | Pat::ConstBlock(expr) = self[pat] {
+ pub fn walk_exprs_in_pat(&self, pat_id: PatId, mut f: impl FnMut(ExprId)) {
+ self.walk_pats(pat_id, &mut |pat| match self[pat] {
+ Pat::Expr(expr) | Pat::ConstBlock(expr) | Pat::Lit(expr) => {
f(expr);
}
+ Pat::Range { start, end, range_type: _ } => {
+ if let Some(start) = start {
+ f(start);
+ }
+ if let Some(end) = end {
+ f(end);
+ }
+ }
+ Pat::Missing
+ | Pat::Rest
+ | Pat::Wild
+ | Pat::Tuple { .. }
+ | Pat::Or(_)
+ | Pat::Record { .. }
+ | Pat::Slice { .. }
+ | Pat::Path(_)
+ | Pat::Bind { .. }
+ | Pat::TupleStruct { .. }
+ | Pat::Ref { .. }
+ | Pat::Box { .. } => {}
});
}
@@ -940,6 +871,22 @@ impl ExpressionStore {
}
}
+trait ExprVisitor {
+ fn on_expr(&mut self, expr: ExprId);
+ fn on_pat(&mut self, pat: PatId);
+ fn on_expr_opt(&mut self, expr: Option<ExprId>) {
+ if let Some(expr) = expr {
+ self.on_expr(expr);
+ }
+ }
+ fn on_exprs(&mut self, exprs: impl IntoIterator<Item: Borrow<ExprId>>) {
+ exprs.into_iter().for_each(|expr| self.on_expr(*expr.borrow()));
+ }
+ fn on_pats(&mut self, exprs: impl IntoIterator<Item: Borrow<PatId>>) {
+ exprs.into_iter().for_each(|expr| self.on_pat(*expr.borrow()));
+ }
+}
+
impl Index<ExprId> for ExpressionStore {
type Output = Expr;
diff --git a/crates/hir-ty/src/diagnostics/unsafe_check.rs b/crates/hir-ty/src/diagnostics/unsafe_check.rs
index 4893d72a5c..ad9909a204 100644
--- a/crates/hir-ty/src/diagnostics/unsafe_check.rs
+++ b/crates/hir-ty/src/diagnostics/unsafe_check.rs
@@ -424,7 +424,6 @@ impl<'db> UnsafeVisitor<'db> {
Expr::Closure { args, .. } => {
self.walk_pats_top(args.iter().copied(), current);
}
- Expr::Const(e) => self.walk_expr(*e),
_ => {}
}
diff --git a/crates/hir/src/source_analyzer.rs b/crates/hir/src/source_analyzer.rs
index 1aec56a0e0..783faa9ac8 100644
--- a/crates/hir/src/source_analyzer.rs
+++ b/crates/hir/src/source_analyzer.rs
@@ -1455,9 +1455,7 @@ impl<'db> SourceAnalyzer<'db> {
};
match expanded_expr {
ExprOrPatId::ExprId(expanded_expr) => walk_expr(expanded_expr),
- ExprOrPatId::PatId(expanded_pat) => {
- body.walk_exprs_in_pat(expanded_pat, &mut walk_expr)
- }
+ ExprOrPatId::PatId(expanded_pat) => body.walk_exprs_in_pat(expanded_pat, walk_expr),
}
return is_unsafe;
}
diff --git a/crates/ide-diagnostics/src/handlers/type_must_be_known.rs b/crates/ide-diagnostics/src/handlers/type_must_be_known.rs
index 5363f4a5ce..08bcc738cb 100644
--- a/crates/ide-diagnostics/src/handlers/type_must_be_known.rs
+++ b/crates/ide-diagnostics/src/handlers/type_must_be_known.rs
@@ -103,4 +103,17 @@ fn foo() {
"#,
);
}
+
+ #[test]
+ fn const_block_does_not_cause_error() {
+ check_diagnostics(
+ r#"
+fn bar<T>(_inner: fn() -> *const T) {}
+
+fn foo() {
+ bar(const { || 0 as *const i32 })
+}
+ "#,
+ );
+ }
}