Unnamed repository; edit this file 'description' to name the repository.
Add missing exprs to visiting
There was a note in the docs that const blocks aren't visited, but most people don't read the docs and expect them to be visited. If you don't want them to, you can handle them specially.
Chayim Refael Friedman 3 weeks ago
parent 8534be9 · commit 5fe0582
-rw-r--r--crates/hir-def/src/expr_store.rs407
-rw-r--r--crates/hir-ty/src/diagnostics/unsafe_check.rs1
-rw-r--r--crates/hir/src/source_analyzer.rs4
-rw-r--r--crates/ide-diagnostics/src/handlers/type_must_be_known.rs13
4 files changed, 191 insertions, 234 deletions
diff --git a/crates/hir-def/src/expr_store.rs b/crates/hir-def/src/expr_store.rs
index f72419a3ae..2c2b477115 100644
--- a/crates/hir-def/src/expr_store.rs
+++ b/crates/hir-def/src/expr_store.rs
@@ -9,7 +9,10 @@ pub mod scope;
#[cfg(test)]
mod tests;
-use std::ops::{Deref, Index};
+use std::{
+ borrow::Borrow,
+ ops::{Deref, Index},
+};
use cfg::{CfgExpr, CfgOptions};
use either::Either;
@@ -27,8 +30,8 @@ use crate::{
db::DefDatabase,
expr_store::path::Path,
hir::{
- Array, AsmOperand, Binding, BindingId, Expr, ExprId, ExprOrPatId, Label, LabelId, Pat,
- PatId, RecordFieldPat, RecordSpread, Statement,
+ Array, AsmOperand, Binding, BindingId, Expr, ExprId, ExprOrPatId, InlineAsm, Label,
+ LabelId, MatchArm, Pat, PatId, RecordFieldPat, RecordLitField, RecordSpread, Statement,
},
nameres::{DefMap, block_def_map},
signatures::VariantFields,
@@ -554,31 +557,35 @@ impl ExpressionStore {
}
pub fn walk_pats_shallow(&self, pat_id: PatId, mut f: impl FnMut(PatId)) {
+ // Do not use `..` patterns or field accesses here, only destructuring, to ensure we cover all cases
+ // (we've had multiple bugs with this in the past).
let pat = &self[pat_id];
match pat {
- Pat::Range { .. }
- | Pat::Lit(..)
- | Pat::Path(..)
- | Pat::ConstBlock(..)
+ Pat::Range { start: _, end: _, range_type: _ }
+ | Pat::Lit(_)
+ | Pat::Path(_)
+ | Pat::ConstBlock(_)
| Pat::Wild
| Pat::Missing
| Pat::Rest
| Pat::Expr(_) => {}
- &Pat::Bind { subpat, .. } => {
+ &Pat::Bind { subpat, id: _ } => {
if let Some(subpat) = subpat {
f(subpat);
}
}
- Pat::Or(args) | Pat::Tuple { args, .. } | Pat::TupleStruct { args, .. } => {
+ Pat::Or(args)
+ | Pat::Tuple { args, ellipsis: _ }
+ | Pat::TupleStruct { args, ellipsis: _, path: _ } => {
args.iter().copied().for_each(f);
}
- Pat::Ref { pat, .. } => f(*pat),
+ Pat::Ref { pat, mutability: _ } => f(*pat),
Pat::Slice { prefix, slice, suffix } => {
let total_iter = prefix.iter().chain(slice.iter()).chain(suffix.iter());
total_iter.copied().for_each(f);
}
- Pat::Record { args, .. } => {
- args.iter().for_each(|RecordFieldPat { pat, .. }| f(*pat));
+ Pat::Record { args, ellipsis: _, path: _ } => {
+ args.iter().for_each(|RecordFieldPat { pat, name: _ }| f(*pat));
}
Pat::Box { inner } => f(*inner),
}
@@ -606,273 +613,197 @@ impl ExpressionStore {
self.expr_only.as_ref()?.binding_owners.get(&id).copied()
}
- /// Walks the immediate children expressions and calls `f` for each child expression.
- ///
- /// Note that this does not walk const blocks.
- pub fn walk_child_exprs(&self, expr_id: ExprId, mut f: impl FnMut(ExprId)) {
- let expr = &self[expr_id];
- match expr {
- Expr::Continue { .. }
- | Expr::Const(_)
+ fn walk_child_exprs_impl(&self, expr_id: ExprId, mut visitor: impl ExprVisitor) {
+ // Do not use `..` patterns or field accesses here, only destructuring, to ensure we cover all cases
+ // (we've had multiple bugs with this in the past).
+ match &self[expr_id] {
+ Expr::Continue { label: _ }
| Expr::Missing
| Expr::Path(_)
| Expr::OffsetOf(_)
| Expr::Literal(_)
| Expr::Underscore => {}
- Expr::InlineAsm(it) => it.operands.iter().for_each(|(_, op)| match op {
- AsmOperand::In { expr, .. }
- | AsmOperand::Out { expr: Some(expr), .. }
- | AsmOperand::InOut { expr, .. }
- | AsmOperand::Const(expr)
- | AsmOperand::Label(expr) => f(*expr),
- AsmOperand::SplitInOut { in_expr, out_expr, .. } => {
- f(*in_expr);
- if let Some(out_expr) = out_expr {
- f(*out_expr);
+ Expr::InlineAsm(InlineAsm { operands, options: _, kind: _ }) => {
+ operands.iter().for_each(|(_, op)| match op {
+ AsmOperand::In { expr, reg: _ }
+ | AsmOperand::Out { expr: Some(expr), late: _, reg: _ }
+ | AsmOperand::InOut { expr, late: _, reg: _ }
+ | AsmOperand::Const(expr)
+ | AsmOperand::Label(expr) => visitor.on_expr(*expr),
+ AsmOperand::SplitInOut { in_expr, out_expr, late: _, reg: _ } => {
+ visitor.on_expr(*in_expr);
+ visitor.on_expr_opt(*out_expr);
}
- }
- AsmOperand::Out { expr: None, .. } | AsmOperand::Sym(_) => (),
- }),
+ AsmOperand::Out { expr: None, late: _, reg: _ } | AsmOperand::Sym(_) => (),
+ })
+ }
Expr::If { condition, then_branch, else_branch } => {
- f(*condition);
- f(*then_branch);
- if let &Some(else_branch) = else_branch {
- f(else_branch);
- }
+ visitor.on_expr(*condition);
+ visitor.on_expr(*then_branch);
+ visitor.on_expr_opt(*else_branch);
}
Expr::Let { expr, pat } => {
- self.walk_exprs_in_pat(*pat, &mut f);
- f(*expr);
+ visitor.on_pat(*pat);
+ visitor.on_expr(*expr);
}
- Expr::Block { statements, tail, .. } | Expr::Unsafe { statements, tail, .. } => {
- for stmt in statements.iter() {
+ Expr::Block { statements, tail, id: _, label: _ }
+ | Expr::Unsafe { statements, tail, id: _ } => {
+ for stmt in statements {
match stmt {
- Statement::Let { initializer, else_branch, pat, .. } => {
- if let &Some(expr) = initializer {
- f(expr);
- }
- if let &Some(expr) = else_branch {
- f(expr);
- }
- self.walk_exprs_in_pat(*pat, &mut f);
+ Statement::Let { initializer, else_branch, pat, type_ref: _ } => {
+ visitor.on_expr_opt(*initializer);
+ visitor.on_expr_opt(*else_branch);
+ visitor.on_pat(*pat);
+ }
+ Statement::Expr { expr: expression, has_semi: _ } => {
+ visitor.on_expr(*expression)
}
- Statement::Expr { expr: expression, .. } => f(*expression),
Statement::Item(_) => (),
}
}
- if let &Some(expr) = tail {
- f(expr);
- }
+ visitor.on_expr_opt(*tail);
}
- Expr::Loop { body, .. } => f(*body),
- Expr::Call { callee, args, .. } => {
- f(*callee);
- args.iter().copied().for_each(f);
+ Expr::Loop { body, label: _ } => visitor.on_expr(*body),
+ Expr::Call { callee, args } => {
+ visitor.on_expr(*callee);
+ visitor.on_exprs(args);
}
- Expr::MethodCall { receiver, args, .. } => {
- f(*receiver);
- args.iter().copied().for_each(f);
+ Expr::MethodCall { receiver, args, generic_args: _, method_name: _ } => {
+ visitor.on_expr(*receiver);
+ visitor.on_exprs(args);
}
Expr::Match { expr, arms } => {
- f(*expr);
- arms.iter().for_each(|arm| {
- f(arm.expr);
- if let Some(guard) = arm.guard {
- f(guard);
- }
- self.walk_exprs_in_pat(arm.pat, &mut f);
+ visitor.on_expr(*expr);
+ arms.iter().for_each(|MatchArm { pat, guard, expr }| {
+ visitor.on_expr(*expr);
+ visitor.on_expr_opt(*guard);
+ visitor.on_pat(*pat);
});
}
- Expr::Break { expr, .. }
+ Expr::Break { expr, label: _ }
| Expr::Return { expr }
| Expr::Yield { expr }
- | Expr::Yeet { expr } => {
- if let &Some(expr) = expr {
- f(expr);
+ | Expr::Yeet { expr } => visitor.on_expr_opt(*expr),
+ Expr::Become { expr } => visitor.on_expr(*expr),
+ Expr::RecordLit { fields, spread, path: _ } => {
+ for RecordLitField { name: _, expr } in fields.iter() {
+ visitor.on_expr(*expr);
}
- }
- Expr::Become { expr } => f(*expr),
- Expr::RecordLit { fields, spread, .. } => {
- for field in fields.iter() {
- f(field.expr);
- }
- if let RecordSpread::Expr(expr) = spread {
- f(*expr);
+ match spread {
+ RecordSpread::Expr(expr) => visitor.on_expr(*expr),
+ RecordSpread::None | RecordSpread::FieldDefaults => {}
}
}
- Expr::Closure { body, .. } => {
- f(*body);
+ Expr::Closure {
+ body,
+ args,
+ arg_types: _,
+ capture_by: _,
+ closure_kind: _,
+ ret_type: _,
+ } => {
+ visitor.on_expr(*body);
+ visitor.on_pats(args);
}
- Expr::BinaryOp { lhs, rhs, .. } => {
- f(*lhs);
- f(*rhs);
+ Expr::BinaryOp { lhs, rhs, op: _ } => {
+ visitor.on_expr(*lhs);
+ visitor.on_expr(*rhs);
}
- Expr::Range { lhs, rhs, .. } => {
- if let &Some(lhs) = rhs {
- f(lhs);
- }
- if let &Some(rhs) = lhs {
- f(rhs);
- }
+ Expr::Range { lhs, rhs, range_type: _ } => {
+ visitor.on_expr_opt(*lhs);
+ visitor.on_expr_opt(*rhs);
}
- Expr::Index { base, index, .. } => {
- f(*base);
- f(*index);
+ Expr::Index { base, index } => {
+ visitor.on_expr(*base);
+ visitor.on_expr(*index);
}
- Expr::Field { expr, .. }
+ Expr::Field { expr, name: _ }
| Expr::Await { expr }
- | Expr::Cast { expr, .. }
- | Expr::Ref { expr, .. }
- | Expr::UnaryOp { expr, .. }
- | Expr::Box { expr } => {
- f(*expr);
+ | Expr::Cast { expr, type_ref: _ }
+ | Expr::Ref { expr, mutability: _, rawness: _ }
+ | Expr::UnaryOp { expr, op: _ }
+ | Expr::Box { expr }
+ | Expr::Const(expr) => {
+ visitor.on_expr(*expr);
}
- Expr::Tuple { exprs, .. } => exprs.iter().copied().for_each(f),
+ Expr::Tuple { exprs } => visitor.on_exprs(exprs),
Expr::Array(a) => match a {
- Array::ElementList { elements, .. } => elements.iter().copied().for_each(f),
+ Array::ElementList { elements } => visitor.on_exprs(elements),
Array::Repeat { initializer, repeat } => {
- f(*initializer);
- f(*repeat)
+ visitor.on_expr(*initializer);
+ visitor.on_expr(*repeat)
}
},
&Expr::Assignment { target, value } => {
- self.walk_exprs_in_pat(target, &mut f);
- f(value);
+ visitor.on_pat(target);
+ visitor.on_expr(value);
}
}
}
- /// Walks the immediate children expressions and calls `f` for each child expression but does
- /// not walk expressions within patterns.
- ///
- /// Note that this does not walk const blocks.
- pub fn walk_child_exprs_without_pats(&self, expr_id: ExprId, mut f: impl FnMut(ExprId)) {
- let expr = &self[expr_id];
- match expr {
- Expr::Continue { .. }
- | Expr::Const(_)
- | Expr::Missing
- | Expr::Path(_)
- | Expr::OffsetOf(_)
- | Expr::Literal(_)
- | Expr::Underscore => {}
- Expr::InlineAsm(it) => it.operands.iter().for_each(|(_, op)| match op {
- AsmOperand::In { expr, .. }
- | AsmOperand::Out { expr: Some(expr), .. }
- | AsmOperand::InOut { expr, .. }
- | AsmOperand::Const(expr)
- | AsmOperand::Label(expr) => f(*expr),
- AsmOperand::SplitInOut { in_expr, out_expr, .. } => {
- f(*in_expr);
- if let Some(out_expr) = out_expr {
- f(*out_expr);
- }
- }
- AsmOperand::Out { expr: None, .. } | AsmOperand::Sym(_) => (),
- }),
- Expr::If { condition, then_branch, else_branch } => {
- f(*condition);
- f(*then_branch);
- if let &Some(else_branch) = else_branch {
- f(else_branch);
- }
- }
- Expr::Let { expr, .. } => {
- f(*expr);
- }
- Expr::Block { statements, tail, .. } | Expr::Unsafe { statements, tail, .. } => {
- for stmt in statements.iter() {
- match stmt {
- Statement::Let { initializer, else_branch, .. } => {
- if let &Some(expr) = initializer {
- f(expr);
- }
- if let &Some(expr) = else_branch {
- f(expr);
- }
- }
- Statement::Expr { expr: expression, .. } => f(*expression),
- Statement::Item(_) => (),
- }
- }
- if let &Some(expr) = tail {
- f(expr);
- }
- }
- Expr::Loop { body, .. } => f(*body),
- Expr::Call { callee, args, .. } => {
- f(*callee);
- args.iter().copied().for_each(f);
- }
- Expr::MethodCall { receiver, args, .. } => {
- f(*receiver);
- args.iter().copied().for_each(f);
- }
- Expr::Match { expr, arms } => {
- f(*expr);
- arms.iter().map(|arm| arm.expr).for_each(f);
- }
- Expr::Break { expr, .. }
- | Expr::Return { expr }
- | Expr::Yield { expr }
- | Expr::Yeet { expr } => {
- if let &Some(expr) = expr {
- f(expr);
- }
- }
- Expr::Become { expr } => f(*expr),
- Expr::RecordLit { fields, spread, .. } => {
- for field in fields.iter() {
- f(field.expr);
- }
- if let RecordSpread::Expr(expr) = spread {
- f(*expr);
- }
- }
- Expr::Closure { body, .. } => {
- f(*body);
- }
- Expr::BinaryOp { lhs, rhs, .. } => {
- f(*lhs);
- f(*rhs);
- }
- Expr::Range { lhs, rhs, .. } => {
- if let &Some(lhs) = rhs {
- f(lhs);
- }
- if let &Some(rhs) = lhs {
- f(rhs);
- }
+ /// Walks the immediate children expressions and calls `f` for each child expression.
+ pub fn walk_child_exprs(&self, expr_id: ExprId, callback: impl FnMut(ExprId)) {
+ return self.walk_child_exprs_impl(expr_id, Visitor { callback, store: self });
+
+ struct Visitor<'a, F> {
+ callback: F,
+ store: &'a ExpressionStore,
+ }
+
+ impl<F: FnMut(ExprId)> ExprVisitor for Visitor<'_, F> {
+ fn on_expr(&mut self, expr: ExprId) {
+ (self.callback)(expr);
}
- Expr::Index { base, index, .. } => {
- f(*base);
- f(*index);
+
+ fn on_pat(&mut self, pat: PatId) {
+ self.store.walk_exprs_in_pat(pat, &mut self.callback);
}
- Expr::Field { expr, .. }
- | Expr::Await { expr }
- | Expr::Cast { expr, .. }
- | Expr::Ref { expr, .. }
- | Expr::UnaryOp { expr, .. }
- | Expr::Box { expr } => {
- f(*expr);
+ }
+ }
+
+ /// Walks the immediate children expressions and calls `f` for each child expression but does
+ /// not walk expressions within patterns.
+ pub fn walk_child_exprs_without_pats(&self, expr_id: ExprId, callback: impl FnMut(ExprId)) {
+ return self.walk_child_exprs_impl(expr_id, Visitor { callback });
+
+ struct Visitor<F> {
+ callback: F,
+ }
+
+ impl<F: FnMut(ExprId)> ExprVisitor for Visitor<F> {
+ fn on_expr(&mut self, expr: ExprId) {
+ (self.callback)(expr);
}
- Expr::Tuple { exprs, .. } => exprs.iter().copied().for_each(f),
- Expr::Array(a) => match a {
- Array::ElementList { elements, .. } => elements.iter().copied().for_each(f),
- Array::Repeat { initializer, repeat } => {
- f(*initializer);
- f(*repeat)
- }
- },
- &Expr::Assignment { target: _, value } => f(value),
+
+ fn on_pat(&mut self, _pat: PatId) {}
}
}
- pub fn walk_exprs_in_pat(&self, pat_id: PatId, f: &mut impl FnMut(ExprId)) {
- self.walk_pats(pat_id, &mut |pat| {
- if let Pat::Expr(expr) | Pat::ConstBlock(expr) = self[pat] {
+ pub fn walk_exprs_in_pat(&self, pat_id: PatId, mut f: impl FnMut(ExprId)) {
+ self.walk_pats(pat_id, &mut |pat| match self[pat] {
+ Pat::Expr(expr) | Pat::ConstBlock(expr) | Pat::Lit(expr) => {
f(expr);
}
+ Pat::Range { start, end, range_type: _ } => {
+ if let Some(start) = start {
+ f(start);
+ }
+ if let Some(end) = end {
+ f(end);
+ }
+ }
+ Pat::Missing
+ | Pat::Rest
+ | Pat::Wild
+ | Pat::Tuple { .. }
+ | Pat::Or(_)
+ | Pat::Record { .. }
+ | Pat::Slice { .. }
+ | Pat::Path(_)
+ | Pat::Bind { .. }
+ | Pat::TupleStruct { .. }
+ | Pat::Ref { .. }
+ | Pat::Box { .. } => {}
});
}
@@ -940,6 +871,22 @@ impl ExpressionStore {
}
}
+trait ExprVisitor {
+ fn on_expr(&mut self, expr: ExprId);
+ fn on_pat(&mut self, pat: PatId);
+ fn on_expr_opt(&mut self, expr: Option<ExprId>) {
+ if let Some(expr) = expr {
+ self.on_expr(expr);
+ }
+ }
+ fn on_exprs(&mut self, exprs: impl IntoIterator<Item: Borrow<ExprId>>) {
+ exprs.into_iter().for_each(|expr| self.on_expr(*expr.borrow()));
+ }
+ fn on_pats(&mut self, exprs: impl IntoIterator<Item: Borrow<PatId>>) {
+ exprs.into_iter().for_each(|expr| self.on_pat(*expr.borrow()));
+ }
+}
+
impl Index<ExprId> for ExpressionStore {
type Output = Expr;
diff --git a/crates/hir-ty/src/diagnostics/unsafe_check.rs b/crates/hir-ty/src/diagnostics/unsafe_check.rs
index 4893d72a5c..ad9909a204 100644
--- a/crates/hir-ty/src/diagnostics/unsafe_check.rs
+++ b/crates/hir-ty/src/diagnostics/unsafe_check.rs
@@ -424,7 +424,6 @@ impl<'db> UnsafeVisitor<'db> {
Expr::Closure { args, .. } => {
self.walk_pats_top(args.iter().copied(), current);
}
- Expr::Const(e) => self.walk_expr(*e),
_ => {}
}
diff --git a/crates/hir/src/source_analyzer.rs b/crates/hir/src/source_analyzer.rs
index 1aec56a0e0..783faa9ac8 100644
--- a/crates/hir/src/source_analyzer.rs
+++ b/crates/hir/src/source_analyzer.rs
@@ -1455,9 +1455,7 @@ impl<'db> SourceAnalyzer<'db> {
};
match expanded_expr {
ExprOrPatId::ExprId(expanded_expr) => walk_expr(expanded_expr),
- ExprOrPatId::PatId(expanded_pat) => {
- body.walk_exprs_in_pat(expanded_pat, &mut walk_expr)
- }
+ ExprOrPatId::PatId(expanded_pat) => body.walk_exprs_in_pat(expanded_pat, walk_expr),
}
return is_unsafe;
}
diff --git a/crates/ide-diagnostics/src/handlers/type_must_be_known.rs b/crates/ide-diagnostics/src/handlers/type_must_be_known.rs
index 5363f4a5ce..08bcc738cb 100644
--- a/crates/ide-diagnostics/src/handlers/type_must_be_known.rs
+++ b/crates/ide-diagnostics/src/handlers/type_must_be_known.rs
@@ -103,4 +103,17 @@ fn foo() {
"#,
);
}
+
+ #[test]
+ fn const_block_does_not_cause_error() {
+ check_diagnostics(
+ r#"
+fn bar<T>(_inner: fn() -> *const T) {}
+
+fn foo() {
+ bar(const { || 0 as *const i32 })
+}
+ "#,
+ );
+ }
}