Unnamed repository; edit this file 'description' to name the repository.
27 files changed, 514 insertions, 297 deletions
diff --git a/crates/hir-def/src/body.rs b/crates/hir-def/src/body.rs index 8fd9255b8b..545d2bebf5 100644 --- a/crates/hir-def/src/body.rs +++ b/crates/hir-def/src/body.rs @@ -24,7 +24,7 @@ use syntax::{ast, AstPtr, SyntaxNode, SyntaxNodePtr}; use crate::{ attr::Attrs, db::DefDatabase, - expr::{dummy_expr_id, Expr, ExprId, Label, LabelId, Pat, PatId}, + expr::{dummy_expr_id, Binding, BindingId, Expr, ExprId, Label, LabelId, Pat, PatId}, item_scope::BuiltinShadowMode, macro_id_to_def_id, nameres::DefMap, @@ -270,6 +270,7 @@ pub struct Mark { pub struct Body { pub exprs: Arena<Expr>, pub pats: Arena<Pat>, + pub bindings: Arena<Binding>, pub or_pats: FxHashMap<PatId, Arc<[PatId]>>, pub labels: Arena<Label>, /// The patterns for the function's parameters. While the parameter types are @@ -435,13 +436,24 @@ impl Body { } fn shrink_to_fit(&mut self) { - let Self { _c: _, body_expr: _, block_scopes, or_pats, exprs, labels, params, pats } = self; + let Self { + _c: _, + body_expr: _, + block_scopes, + or_pats, + exprs, + labels, + params, + pats, + bindings, + } = self; block_scopes.shrink_to_fit(); or_pats.shrink_to_fit(); exprs.shrink_to_fit(); labels.shrink_to_fit(); params.shrink_to_fit(); pats.shrink_to_fit(); + bindings.shrink_to_fit(); } } @@ -451,6 +463,7 @@ impl Default for Body { body_expr: dummy_expr_id(), exprs: Default::default(), pats: Default::default(), + bindings: Default::default(), or_pats: Default::default(), labels: Default::default(), params: Default::default(), @@ -484,6 +497,14 @@ impl Index<LabelId> for Body { } } +impl Index<BindingId> for Body { + type Output = Binding; + + fn index(&self, b: BindingId) -> &Binding { + &self.bindings[b] + } +} + // FIXME: Change `node_` prefix to something more reasonable. // Perhaps `expr_syntax` and `expr_id`? impl BodySourceMap { diff --git a/crates/hir-def/src/body/lower.rs b/crates/hir-def/src/body/lower.rs index 3164a5f4c2..b7458bfb8a 100644 --- a/crates/hir-def/src/body/lower.rs +++ b/crates/hir-def/src/body/lower.rs @@ -15,6 +15,7 @@ use la_arena::Arena; use once_cell::unsync::OnceCell; use profile::Count; use rustc_hash::FxHashMap; +use smallvec::SmallVec; use syntax::{ ast::{ self, ArrayExprKind, AstChildren, HasArgList, HasLoopBody, HasName, LiteralKind, @@ -30,9 +31,9 @@ use crate::{ builtin_type::{BuiltinFloat, BuiltinInt, BuiltinUint}, db::DefDatabase, expr::{ - dummy_expr_id, Array, BindingAnnotation, ClosureKind, Expr, ExprId, FloatTypeWrapper, - Label, LabelId, Literal, MatchArm, Movability, Pat, PatId, RecordFieldPat, RecordLitField, - Statement, + dummy_expr_id, Array, Binding, BindingAnnotation, BindingId, ClosureKind, Expr, ExprId, + FloatTypeWrapper, Label, LabelId, Literal, MatchArm, Movability, Pat, PatId, + RecordFieldPat, RecordLitField, Statement, }, item_scope::BuiltinShadowMode, path::{GenericArgs, Path}, @@ -87,6 +88,7 @@ pub(super) fn lower( body: Body { exprs: Arena::default(), pats: Arena::default(), + bindings: Arena::default(), labels: Arena::default(), params: Vec::new(), body_expr: dummy_expr_id(), @@ -116,6 +118,22 @@ struct ExprCollector<'a> { is_lowering_generator: bool, } +#[derive(Debug, Default)] +struct BindingList { + map: FxHashMap<Name, BindingId>, +} + +impl BindingList { + fn find( + &mut self, + ec: &mut ExprCollector<'_>, + name: Name, + mode: BindingAnnotation, + ) -> BindingId { + *self.map.entry(name).or_insert_with_key(|n| ec.alloc_binding(n.clone(), mode)) + } +} + impl ExprCollector<'_> { fn collect( mut self, @@ -127,17 +145,16 @@ impl ExprCollector<'_> { param_list.self_param().filter(|_| attr_enabled.next().unwrap_or(false)) { let ptr = AstPtr::new(&self_param); - let param_pat = self.alloc_pat( - Pat::Bind { - name: name![self], - mode: BindingAnnotation::new( - self_param.mut_token().is_some() && self_param.amp_token().is_none(), - false, - ), - subpat: None, - }, - Either::Right(ptr), + let binding_id = self.alloc_binding( + name![self], + BindingAnnotation::new( + self_param.mut_token().is_some() && self_param.amp_token().is_none(), + false, + ), ); + let param_pat = + self.alloc_pat(Pat::Bind { id: binding_id, subpat: None }, Either::Right(ptr)); + self.add_definition_to_binding(binding_id, param_pat); self.body.params.push(param_pat); } @@ -179,6 +196,9 @@ impl ExprCollector<'_> { id } + fn alloc_binding(&mut self, name: Name, mode: BindingAnnotation) -> BindingId { + self.body.bindings.alloc(Binding { name, mode, definitions: SmallVec::new() }) + } fn alloc_pat(&mut self, pat: Pat, ptr: PatPtr) -> PatId { let src = self.expander.to_source(ptr); let id = self.make_pat(pat, src.clone()); @@ -804,7 +824,7 @@ impl ExprCollector<'_> { } fn collect_pat(&mut self, pat: ast::Pat) -> PatId { - let pat_id = self.collect_pat_(pat); + let pat_id = self.collect_pat_(pat, &mut BindingList::default()); for (_, pats) in self.name_to_pat_grouping.drain() { let pats = Arc::<[_]>::from(pats); self.body.or_pats.extend(pats.iter().map(|&pat| (pat, pats.clone()))); @@ -820,7 +840,7 @@ impl ExprCollector<'_> { } } - fn collect_pat_(&mut self, pat: ast::Pat) -> PatId { + fn collect_pat_(&mut self, pat: ast::Pat, binding_list: &mut BindingList) -> PatId { let pattern = match &pat { ast::Pat::IdentPat(bp) => { let name = bp.name().map(|nr| nr.as_name()).unwrap_or_else(Name::missing); @@ -828,8 +848,10 @@ impl ExprCollector<'_> { let key = self.is_lowering_inside_or_pat.then(|| name.clone()); let annotation = BindingAnnotation::new(bp.mut_token().is_some(), bp.ref_token().is_some()); - let subpat = bp.pat().map(|subpat| self.collect_pat_(subpat)); - let pattern = if annotation == BindingAnnotation::Unannotated && subpat.is_none() { + let subpat = bp.pat().map(|subpat| self.collect_pat_(subpat, binding_list)); + let (binding, pattern) = if annotation == BindingAnnotation::Unannotated + && subpat.is_none() + { // This could also be a single-segment path pattern. To // decide that, we need to try resolving the name. let (resolved, _) = self.expander.def_map.resolve_path( @@ -839,12 +861,12 @@ impl ExprCollector<'_> { BuiltinShadowMode::Other, ); match resolved.take_values() { - Some(ModuleDefId::ConstId(_)) => Pat::Path(name.into()), + Some(ModuleDefId::ConstId(_)) => (None, Pat::Path(name.into())), Some(ModuleDefId::EnumVariantId(_)) => { // this is only really valid for unit variants, but // shadowing other enum variants with a pattern is // an error anyway - Pat::Path(name.into()) + (None, Pat::Path(name.into())) } Some(ModuleDefId::AdtId(AdtId::StructId(s))) if self.db.struct_data(s).variant_data.kind() != StructKind::Record => @@ -852,17 +874,24 @@ impl ExprCollector<'_> { // Funnily enough, record structs *can* be shadowed // by pattern bindings (but unit or tuple structs // can't). - Pat::Path(name.into()) + (None, Pat::Path(name.into())) } // shadowing statics is an error as well, so we just ignore that case here - _ => Pat::Bind { name, mode: annotation, subpat }, + _ => { + let id = binding_list.find(self, name, annotation); + (Some(id), Pat::Bind { id, subpat }) + } } } else { - Pat::Bind { name, mode: annotation, subpat } + let id = binding_list.find(self, name, annotation); + (Some(id), Pat::Bind { id, subpat }) }; let ptr = AstPtr::new(&pat); let pat = self.alloc_pat(pattern, Either::Left(ptr)); + if let Some(binding_id) = binding { + self.add_definition_to_binding(binding_id, pat); + } if let Some(key) = key { self.name_to_pat_grouping.entry(key).or_default().push(pat); } @@ -871,11 +900,11 @@ impl ExprCollector<'_> { ast::Pat::TupleStructPat(p) => { let path = p.path().and_then(|path| self.expander.parse_path(self.db, path)).map(Box::new); - let (args, ellipsis) = self.collect_tuple_pat(p.fields()); + let (args, ellipsis) = self.collect_tuple_pat(p.fields(), binding_list); Pat::TupleStruct { path, args, ellipsis } } ast::Pat::RefPat(p) => { - let pat = self.collect_pat_opt(p.pat()); + let pat = self.collect_pat_opt_(p.pat(), binding_list); let mutability = Mutability::from_mutable(p.mut_token().is_some()); Pat::Ref { pat, mutability } } @@ -886,12 +915,12 @@ impl ExprCollector<'_> { } ast::Pat::OrPat(p) => { self.is_lowering_inside_or_pat = true; - let pats = p.pats().map(|p| self.collect_pat_(p)).collect(); + let pats = p.pats().map(|p| self.collect_pat_(p, binding_list)).collect(); Pat::Or(pats) } - ast::Pat::ParenPat(p) => return self.collect_pat_opt_(p.pat()), + ast::Pat::ParenPat(p) => return self.collect_pat_opt_(p.pat(), binding_list), ast::Pat::TuplePat(p) => { - let (args, ellipsis) = self.collect_tuple_pat(p.fields()); + let (args, ellipsis) = self.collect_tuple_pat(p.fields(), binding_list); Pat::Tuple { args, ellipsis } } ast::Pat::WildcardPat(_) => Pat::Wild, @@ -904,7 +933,7 @@ impl ExprCollector<'_> { .fields() .filter_map(|f| { let ast_pat = f.pat()?; - let pat = self.collect_pat_(ast_pat); + let pat = self.collect_pat_(ast_pat, binding_list); let name = f.field_name()?.as_name(); Some(RecordFieldPat { name, pat }) }) @@ -923,9 +952,15 @@ impl ExprCollector<'_> { // FIXME properly handle `RestPat` Pat::Slice { - prefix: prefix.into_iter().map(|p| self.collect_pat_(p)).collect(), - slice: slice.map(|p| self.collect_pat_(p)), - suffix: suffix.into_iter().map(|p| self.collect_pat_(p)).collect(), + prefix: prefix + .into_iter() + .map(|p| self.collect_pat_(p, binding_list)) + .collect(), + slice: slice.map(|p| self.collect_pat_(p, binding_list)), + suffix: suffix + .into_iter() + .map(|p| self.collect_pat_(p, binding_list)) + .collect(), } } ast::Pat::LiteralPat(lit) => { @@ -948,7 +983,7 @@ impl ExprCollector<'_> { Pat::Missing } ast::Pat::BoxPat(boxpat) => { - let inner = self.collect_pat_opt_(boxpat.pat()); + let inner = self.collect_pat_opt_(boxpat.pat(), binding_list); Pat::Box { inner } } ast::Pat::ConstBlockPat(const_block_pat) => { @@ -965,7 +1000,7 @@ impl ExprCollector<'_> { let src = self.expander.to_source(Either::Left(AstPtr::new(&pat))); let pat = self.collect_macro_call(call, macro_ptr, true, |this, expanded_pat| { - this.collect_pat_opt_(expanded_pat) + this.collect_pat_opt_(expanded_pat, binding_list) }); self.source_map.pat_map.insert(src, pat); return pat; @@ -979,21 +1014,25 @@ impl ExprCollector<'_> { self.alloc_pat(pattern, Either::Left(ptr)) } - fn collect_pat_opt_(&mut self, pat: Option<ast::Pat>) -> PatId { + fn collect_pat_opt_(&mut self, pat: Option<ast::Pat>, binding_list: &mut BindingList) -> PatId { match pat { - Some(pat) => self.collect_pat_(pat), + Some(pat) => self.collect_pat_(pat, binding_list), None => self.missing_pat(), } } - fn collect_tuple_pat(&mut self, args: AstChildren<ast::Pat>) -> (Box<[PatId]>, Option<usize>) { + fn collect_tuple_pat( + &mut self, + args: AstChildren<ast::Pat>, + binding_list: &mut BindingList, + ) -> (Box<[PatId]>, Option<usize>) { // Find the location of the `..`, if there is one. Note that we do not // consider the possibility of there being multiple `..` here. let ellipsis = args.clone().position(|p| matches!(p, ast::Pat::RestPat(_))); // We want to skip the `..` pattern here, since we account for it above. let args = args .filter(|p| !matches!(p, ast::Pat::RestPat(_))) - .map(|p| self.collect_pat_(p)) + .map(|p| self.collect_pat_(p, binding_list)) .collect(); (args, ellipsis) @@ -1022,6 +1061,10 @@ impl ExprCollector<'_> { None => Some(()), } } + + fn add_definition_to_binding(&mut self, binding_id: BindingId, pat_id: PatId) { + self.body.bindings[binding_id].definitions.push(pat_id); + } } impl From<ast::LiteralKind> for Literal { diff --git a/crates/hir-def/src/body/pretty.rs b/crates/hir-def/src/body/pretty.rs index 622756ee8a..f8b159797e 100644 --- a/crates/hir-def/src/body/pretty.rs +++ b/crates/hir-def/src/body/pretty.rs @@ -5,7 +5,7 @@ use std::fmt::{self, Write}; use syntax::ast::HasName; use crate::{ - expr::{Array, BindingAnnotation, ClosureKind, Literal, Movability, Statement}, + expr::{Array, BindingAnnotation, BindingId, ClosureKind, Literal, Movability, Statement}, pretty::{print_generic_args, print_path, print_type_ref}, type_ref::TypeRef, }; @@ -524,14 +524,8 @@ impl<'a> Printer<'a> { } Pat::Path(path) => self.print_path(path), Pat::Lit(expr) => self.print_expr(*expr), - Pat::Bind { mode, name, subpat } => { - let mode = match mode { - BindingAnnotation::Unannotated => "", - BindingAnnotation::Mutable => "mut ", - BindingAnnotation::Ref => "ref ", - BindingAnnotation::RefMut => "ref mut ", - }; - w!(self, "{}{}", mode, name); + Pat::Bind { id, subpat } => { + self.print_binding(*id); if let Some(pat) = subpat { self.whitespace(); self.print_pat(*pat); @@ -635,4 +629,15 @@ impl<'a> Printer<'a> { fn print_path(&mut self, path: &Path) { print_path(path, self).unwrap(); } + + fn print_binding(&mut self, id: BindingId) { + let Binding { name, mode, .. } = &self.body.bindings[id]; + let mode = match mode { + BindingAnnotation::Unannotated => "", + BindingAnnotation::Mutable => "mut ", + BindingAnnotation::Ref => "ref ", + BindingAnnotation::RefMut => "ref mut ", + }; + w!(self, "{}{}", mode, name); + } } diff --git a/crates/hir-def/src/body/scope.rs b/crates/hir-def/src/body/scope.rs index cab657b807..12fc1f116d 100644 --- a/crates/hir-def/src/body/scope.rs +++ b/crates/hir-def/src/body/scope.rs @@ -8,7 +8,7 @@ use rustc_hash::FxHashMap; use crate::{ body::Body, db::DefDatabase, - expr::{Expr, ExprId, LabelId, Pat, PatId, Statement}, + expr::{Binding, BindingId, Expr, ExprId, LabelId, Pat, PatId, Statement}, BlockId, DefWithBodyId, }; @@ -23,7 +23,7 @@ pub struct ExprScopes { #[derive(Debug, PartialEq, Eq)] pub struct ScopeEntry { name: Name, - pat: PatId, + binding: BindingId, } impl ScopeEntry { @@ -31,8 +31,8 @@ impl ScopeEntry { &self.name } - pub fn pat(&self) -> PatId { - self.pat + pub fn binding(&self) -> BindingId { + self.binding } } @@ -126,18 +126,23 @@ impl ExprScopes { }) } - fn add_bindings(&mut self, body: &Body, scope: ScopeId, pat: PatId) { + fn add_bindings(&mut self, body: &Body, scope: ScopeId, binding: BindingId) { + let Binding { name, .. } = &body.bindings[binding]; + let entry = ScopeEntry { name: name.clone(), binding }; + self.scopes[scope].entries.push(entry); + } + + fn add_pat_bindings(&mut self, body: &Body, scope: ScopeId, pat: PatId) { let pattern = &body[pat]; - if let Pat::Bind { name, .. } = pattern { - let entry = ScopeEntry { name: name.clone(), pat }; - self.scopes[scope].entries.push(entry); + if let Pat::Bind { id, .. } = pattern { + self.add_bindings(body, scope, *id); } - pattern.walk_child_pats(|pat| self.add_bindings(body, scope, pat)); + pattern.walk_child_pats(|pat| self.add_pat_bindings(body, scope, pat)); } fn add_params_bindings(&mut self, body: &Body, scope: ScopeId, params: &[PatId]) { - params.iter().for_each(|pat| self.add_bindings(body, scope, *pat)); + params.iter().for_each(|pat| self.add_pat_bindings(body, scope, *pat)); } fn set_scope(&mut self, node: ExprId, scope: ScopeId) { @@ -170,7 +175,7 @@ fn compute_block_scopes( } *scope = scopes.new_scope(*scope); - scopes.add_bindings(body, *scope, *pat); + scopes.add_pat_bindings(body, *scope, *pat); } Statement::Expr { expr, .. } => { compute_expr_scopes(*expr, body, scopes, scope); @@ -208,7 +213,7 @@ fn compute_expr_scopes(expr: ExprId, body: &Body, scopes: &mut ExprScopes, scope Expr::For { iterable, pat, body: body_expr, label } => { compute_expr_scopes(*iterable, body, scopes, scope); let mut scope = scopes.new_labeled_scope(*scope, make_label(label)); - scopes.add_bindings(body, scope, *pat); + scopes.add_pat_bindings(body, scope, *pat); compute_expr_scopes(*body_expr, body, scopes, &mut scope); } Expr::While { condition, body: body_expr, label } => { @@ -229,7 +234,7 @@ fn compute_expr_scopes(expr: ExprId, body: &Body, scopes: &mut ExprScopes, scope compute_expr_scopes(*expr, body, scopes, scope); for arm in arms.iter() { let mut scope = scopes.new_scope(*scope); - scopes.add_bindings(body, scope, arm.pat); + scopes.add_pat_bindings(body, scope, arm.pat); if let Some(guard) = arm.guard { scope = scopes.new_scope(scope); compute_expr_scopes(guard, body, scopes, &mut scope); @@ -248,7 +253,7 @@ fn compute_expr_scopes(expr: ExprId, body: &Body, scopes: &mut ExprScopes, scope &Expr::Let { pat, expr } => { compute_expr_scopes(expr, body, scopes, scope); *scope = scopes.new_scope(*scope); - scopes.add_bindings(body, *scope, pat); + scopes.add_pat_bindings(body, *scope, pat); } e => e.walk_child_exprs(|e| compute_expr_scopes(e, body, scopes, scope)), }; @@ -450,7 +455,7 @@ fn foo() { let function = find_function(&db, file_id); let scopes = db.expr_scopes(function.into()); - let (_body, source_map) = db.body_with_source_map(function.into()); + let (body, source_map) = db.body_with_source_map(function.into()); let expr_scope = { let expr_ast = name_ref.syntax().ancestors().find_map(ast::Expr::cast).unwrap(); @@ -460,7 +465,9 @@ fn foo() { }; let resolved = scopes.resolve_name_in_scope(expr_scope, &name_ref.as_name()).unwrap(); - let pat_src = source_map.pat_syntax(resolved.pat()).unwrap(); + let pat_src = source_map + .pat_syntax(*body.bindings[resolved.binding()].definitions.first().unwrap()) + .unwrap(); let local_name = pat_src.value.either( |it| it.syntax_node_ptr().to_node(file.syntax()), diff --git a/crates/hir-def/src/expr.rs b/crates/hir-def/src/expr.rs index 78a2f86123..bbea608c55 100644 --- a/crates/hir-def/src/expr.rs +++ b/crates/hir-def/src/expr.rs @@ -17,6 +17,7 @@ use std::fmt; use hir_expand::name::Name; use intern::Interned; use la_arena::{Idx, RawIdx}; +use smallvec::SmallVec; use crate::{ builtin_type::{BuiltinFloat, BuiltinInt, BuiltinUint}, @@ -29,6 +30,8 @@ pub use syntax::ast::{ArithOp, BinaryOp, CmpOp, LogicOp, Ordering, RangeOp, Unar pub type ExprId = Idx<Expr>; +pub type BindingId = Idx<Binding>; + /// FIXME: this is a hacky function which should be removed pub(crate) fn dummy_expr_id() -> ExprId { ExprId::from_raw(RawIdx::from(u32::MAX)) @@ -434,6 +437,13 @@ impl BindingAnnotation { } #[derive(Debug, Clone, Eq, PartialEq)] +pub struct Binding { + pub name: Name, + pub mode: BindingAnnotation, + pub definitions: SmallVec<[PatId; 1]>, +} + +#[derive(Debug, Clone, Eq, PartialEq)] pub struct RecordFieldPat { pub name: Name, pub pat: PatId, @@ -451,7 +461,7 @@ pub enum Pat { Slice { prefix: Box<[PatId]>, slice: Option<PatId>, suffix: Box<[PatId]> }, Path(Box<Path>), Lit(ExprId), - Bind { mode: BindingAnnotation, name: Name, subpat: Option<PatId> }, + Bind { id: BindingId, subpat: Option<PatId> }, TupleStruct { path: Option<Box<Path>>, args: Box<[PatId]>, ellipsis: Option<usize> }, Ref { pat: PatId, mutability: Mutability }, Box { inner: PatId }, diff --git a/crates/hir-def/src/resolver.rs b/crates/hir-def/src/resolver.rs index eea837ddd2..61e64fc103 100644 --- a/crates/hir-def/src/resolver.rs +++ b/crates/hir-def/src/resolver.rs @@ -12,7 +12,7 @@ use crate::{ body::scope::{ExprScopes, ScopeId}, builtin_type::BuiltinType, db::DefDatabase, - expr::{ExprId, LabelId, PatId}, + expr::{BindingId, ExprId, LabelId}, generics::{GenericParams, TypeOrConstParamData}, item_scope::{BuiltinShadowMode, BUILTIN_SCOPE}, nameres::DefMap, @@ -105,7 +105,7 @@ pub enum ResolveValueResult { #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] pub enum ValueNs { ImplSelf(ImplId), - LocalBinding(PatId), + LocalBinding(BindingId), FunctionId(FunctionId), ConstId(ConstId), StaticId(StaticId), @@ -267,7 +267,7 @@ impl Resolver { if let Some(e) = entry { return Some(ResolveValueResult::ValueNs(ValueNs::LocalBinding( - e.pat(), + e.binding(), ))); } } @@ -617,7 +617,7 @@ pub enum ScopeDef { ImplSelfType(ImplId), AdtSelfType(AdtId), GenericParam(GenericParamId), - Local(PatId), + Local(BindingId), Label(LabelId), } @@ -669,7 +669,7 @@ impl Scope { acc.add(&name, ScopeDef::Label(label)) } scope.expr_scopes.entries(scope.scope_id).iter().for_each(|e| { - acc.add_local(e.name(), e.pat()); + acc.add_local(e.name(), e.binding()); }); } } @@ -859,7 +859,7 @@ impl ScopeNames { self.add(name, ScopeDef::Unknown) } } - fn add_local(&mut self, name: &Name, pat: PatId) { + fn add_local(&mut self, name: &Name, binding: BindingId) { let set = self.map.entry(name.clone()).or_default(); // XXX: hack, account for local (and only local) shadowing. // @@ -870,7 +870,7 @@ impl ScopeNames { cov_mark::hit!(shadowing_shows_single_completion); return; } - set.push(ScopeDef::Local(pat)) + set.push(ScopeDef::Local(binding)) } } diff --git a/crates/hir-ty/src/consteval/tests.rs b/crates/hir-ty/src/consteval/tests.rs index 19145b2d98..b7a466c389 100644 --- a/crates/hir-ty/src/consteval/tests.rs +++ b/crates/hir-ty/src/consteval/tests.rs @@ -546,6 +546,49 @@ fn let_else() { } #[test] +fn function_param_patterns() { + check_number( + r#" + const fn f((a, b): &(u8, u8)) -> u8 { + *a + *b + } + const GOAL: u8 = f(&(2, 3)); + "#, + 5, + ); + check_number( + r#" + const fn f(c @ (a, b): &(u8, u8)) -> u8 { + *a + *b + (*c).1 + } + const GOAL: u8 = f(&(2, 3)); + "#, + 8, + ); + check_number( + r#" + const fn f(ref a: u8) -> u8 { + *a + } + const GOAL: u8 = f(2); + "#, + 2, + ); + check_number( + r#" + struct Foo(u8); + impl Foo { + const fn f(&self, (a, b): &(u8, u8)) -> u8 { + self.0 + *a + *b + } + } + const GOAL: u8 = Foo(4).f(&(2, 3)); + "#, + 9, + ); +} + +#[test] fn options() { check_number( r#" diff --git a/crates/hir-ty/src/diagnostics/decl_check.rs b/crates/hir-ty/src/diagnostics/decl_check.rs index f4d1013ceb..d36b93e3bd 100644 --- a/crates/hir-ty/src/diagnostics/decl_check.rs +++ b/crates/hir-ty/src/diagnostics/decl_check.rs @@ -235,8 +235,8 @@ impl<'a> DeclValidator<'a> { let pats_replacements = body .pats .iter() - .filter_map(|(id, pat)| match pat { - Pat::Bind { name, .. } => Some((id, name)), + .filter_map(|(pat_id, pat)| match pat { + Pat::Bind { id, .. } => Some((pat_id, &body.bindings[*id].name)), _ => None, }) .filter_map(|(id, bind_name)| { diff --git a/crates/hir-ty/src/diagnostics/match_check.rs b/crates/hir-ty/src/diagnostics/match_check.rs index 8b0f051b46..859a37804a 100644 --- a/crates/hir-ty/src/diagnostics/match_check.rs +++ b/crates/hir-ty/src/diagnostics/match_check.rs @@ -146,8 +146,9 @@ impl<'a> PatCtxt<'a> { PatKind::Leaf { subpatterns } } - hir_def::expr::Pat::Bind { ref name, subpat, .. } => { + hir_def::expr::Pat::Bind { id, subpat, .. } => { let bm = self.infer.pat_binding_modes[&pat]; + let name = &self.body.bindings[id].name; match (bm, ty.kind(Interner)) { (BindingMode::Ref(_), TyKind::Ref(.., rty)) => ty = rty, (BindingMode::Ref(_), _) => { diff --git a/crates/hir-ty/src/infer.rs b/crates/hir-ty/src/infer.rs index 869b39ab37..bac733d988 100644 --- a/crates/hir-ty/src/infer.rs +++ b/crates/hir-ty/src/infer.rs @@ -22,7 +22,7 @@ use hir_def::{ body::Body, builtin_type::{BuiltinInt, BuiltinType, BuiltinUint}, data::{ConstData, StaticData}, - expr::{BindingAnnotation, ExprId, ExprOrPatId, PatId}, + expr::{BindingAnnotation, BindingId, ExprId, ExprOrPatId, PatId}, lang_item::{LangItem, LangItemTarget}, layout::Integer, path::Path, @@ -352,6 +352,7 @@ pub struct InferenceResult { /// **Note**: When a pattern type is resolved it may still contain /// unresolved or missing subpatterns or subpatterns of mismatched types. pub type_of_pat: ArenaMap<PatId, Ty>, + pub type_of_binding: ArenaMap<BindingId, Ty>, pub type_of_rpit: ArenaMap<RpitId, Ty>, type_mismatches: FxHashMap<ExprOrPatId, TypeMismatch>, /// Interned common types to return references to. @@ -414,6 +415,14 @@ impl Index<PatId> for InferenceResult { } } +impl Index<BindingId> for InferenceResult { + type Output = Ty; + + fn index(&self, b: BindingId) -> &Ty { + self.type_of_binding.get(b).unwrap_or(&self.standard_types.unknown) + } +} + /// The inference context contains all information needed during type inference. #[derive(Clone, Debug)] pub(crate) struct InferenceContext<'a> { @@ -534,7 +543,10 @@ impl<'a> InferenceContext<'a> { for ty in result.type_of_pat.values_mut() { *ty = table.resolve_completely(ty.clone()); } - for ty in result.type_of_rpit.iter_mut().map(|x| x.1) { + for ty in result.type_of_binding.values_mut() { + *ty = table.resolve_completely(ty.clone()); + } + for ty in result.type_of_rpit.values_mut() { *ty = table.resolve_completely(ty.clone()); } for mismatch in result.type_mismatches.values_mut() { @@ -704,6 +716,10 @@ impl<'a> InferenceContext<'a> { self.result.type_of_pat.insert(pat, ty); } + fn write_binding_ty(&mut self, id: BindingId, ty: Ty) { + self.result.type_of_binding.insert(id, ty); + } + fn push_diagnostic(&mut self, diagnostic: InferenceDiagnostic) { self.result.diagnostics.push(diagnostic); } diff --git a/crates/hir-ty/src/infer/pat.rs b/crates/hir-ty/src/infer/pat.rs index a7bd009e34..566ed298c5 100644 --- a/crates/hir-ty/src/infer/pat.rs +++ b/crates/hir-ty/src/infer/pat.rs @@ -5,7 +5,7 @@ use std::iter::repeat_with; use chalk_ir::Mutability; use hir_def::{ body::Body, - expr::{BindingAnnotation, Expr, ExprId, ExprOrPatId, Literal, Pat, PatId, RecordFieldPat}, + expr::{Binding, BindingAnnotation, Expr, ExprId, ExprOrPatId, Literal, Pat, PatId, RecordFieldPat, BindingId}, path::Path, }; use hir_expand::name::Name; @@ -248,8 +248,8 @@ impl<'a> InferenceContext<'a> { // FIXME update resolver for the surrounding expression self.infer_path(path, pat.into()).unwrap_or_else(|| self.err_ty()) } - Pat::Bind { mode, name: _, subpat } => { - return self.infer_bind_pat(pat, *mode, default_bm, *subpat, &expected); + Pat::Bind { id, subpat } => { + return self.infer_bind_pat(pat, *id, default_bm, *subpat, &expected); } Pat::Slice { prefix, slice, suffix } => { self.infer_slice_pat(&expected, prefix, slice, suffix, default_bm) @@ -320,11 +320,12 @@ impl<'a> InferenceContext<'a> { fn infer_bind_pat( &mut self, pat: PatId, - mode: BindingAnnotation, + binding: BindingId, default_bm: BindingMode, subpat: Option<PatId>, expected: &Ty, ) -> Ty { + let Binding { mode, .. } = self.body.bindings[binding]; let mode = if mode == BindingAnnotation::Unannotated { default_bm } else { @@ -344,7 +345,8 @@ impl<'a> InferenceContext<'a> { } BindingMode::Move => inner_ty.clone(), }; - self.write_pat_ty(pat, bound_ty); + self.write_pat_ty(pat, bound_ty.clone()); + self.write_binding_ty(binding, bound_ty); return inner_ty; } @@ -420,11 +422,14 @@ fn is_non_ref_pat(body: &hir_def::body::Body, pat: PatId) -> bool { Pat::Lit(expr) => { !matches!(body[*expr], Expr::Literal(Literal::String(..) | Literal::ByteString(..))) } - Pat::Bind { - mode: BindingAnnotation::Mutable | BindingAnnotation::Unannotated, - subpat: Some(subpat), - .. - } => is_non_ref_pat(body, *subpat), + Pat::Bind { id, subpat: Some(subpat), .. } + if matches!( + body.bindings[*id].mode, + BindingAnnotation::Mutable | BindingAnnotation::Unannotated + ) => + { + is_non_ref_pat(body, *subpat) + } Pat::Wild | Pat::Bind { .. } | Pat::Ref { .. } | Pat::Box { .. } | Pat::Missing => false, } } @@ -432,7 +437,7 @@ fn is_non_ref_pat(body: &hir_def::body::Body, pat: PatId) -> bool { pub(super) fn contains_explicit_ref_binding(body: &Body, pat_id: PatId) -> bool { let mut res = false; walk_pats(body, pat_id, &mut |pat| { - res |= matches!(pat, Pat::Bind { mode: BindingAnnotation::Ref, .. }) + res |= matches!(pat, Pat::Bind { id, .. } if body.bindings[*id].mode == BindingAnnotation::Ref); }); res } diff --git a/crates/hir-ty/src/infer/path.rs b/crates/hir-ty/src/infer/path.rs index 891e1fab2e..2267fedaa8 100644 --- a/crates/hir-ty/src/infer/path.rs +++ b/crates/hir-ty/src/infer/path.rs @@ -50,7 +50,7 @@ impl<'a> InferenceContext<'a> { }; let typable: ValueTyDefId = match value { - ValueNs::LocalBinding(pat) => match self.result.type_of_pat.get(pat) { + ValueNs::LocalBinding(pat) => match self.result.type_of_binding.get(pat) { Some(ty) => return Some(ty.clone()), None => { never!("uninferred pattern?"); diff --git a/crates/hir-ty/src/layout/tests.rs b/crates/hir-ty/src/layout/tests.rs index 546044fc13..a8971fde3c 100644 --- a/crates/hir-ty/src/layout/tests.rs +++ b/crates/hir-ty/src/layout/tests.rs @@ -65,17 +65,9 @@ fn eval_expr(ra_fixture: &str, minicore: &str) -> Result<Layout, LayoutError> { }) .unwrap(); let hir_body = db.body(adt_id.into()); - let pat = hir_body - .pats - .iter() - .find(|x| match x.1 { - hir_def::expr::Pat::Bind { name, .. } => name.to_smol_str() == "goal", - _ => false, - }) - .unwrap() - .0; + let b = hir_body.bindings.iter().find(|x| x.1.name.to_smol_str() == "goal").unwrap().0; let infer = db.infer(adt_id.into()); - let goal_ty = infer.type_of_pat[pat].clone(); + let goal_ty = infer.type_of_binding[b].clone(); layout_of_ty(&db, &goal_ty, module_id.krate()) } diff --git a/crates/hir-ty/src/mir/lower.rs b/crates/hir-ty/src/mir/lower.rs index 936b56a021..8f28d62db0 100644 --- a/crates/hir-ty/src/mir/lower.rs +++ b/crates/hir-ty/src/mir/lower.rs @@ -6,7 +6,8 @@ use chalk_ir::{BoundVar, ConstData, DebruijnIndex, TyKind}; use hir_def::{ body::Body, expr::{ - Array, BindingAnnotation, ExprId, LabelId, Literal, MatchArm, Pat, PatId, RecordLitField, + Array, BindingAnnotation, BindingId, ExprId, LabelId, Literal, MatchArm, Pat, PatId, + RecordLitField, }, layout::LayoutError, resolver::{resolver_for_expr, ResolveValueResult, ValueNs}, @@ -30,7 +31,7 @@ struct LoopBlocks { struct MirLowerCtx<'a> { result: MirBody, owner: DefWithBodyId, - binding_locals: ArenaMap<PatId, LocalId>, + binding_locals: ArenaMap<BindingId, LocalId>, current_loop_blocks: Option<LoopBlocks>, discr_temp: Option<Place>, db: &'a dyn HirDatabase, @@ -43,7 +44,9 @@ pub enum MirLowerError { ConstEvalError(Box<ConstEvalError>), LayoutError(LayoutError), IncompleteExpr, - UnresolvedName, + UnresolvedName(String), + UnresolvedMethod, + UnresolvedField, MissingFunctionDefinition, TypeError(&'static str), NotSupported(String), @@ -222,22 +225,23 @@ impl MirLowerCtx<'_> { match &self.body.exprs[expr_id] { Expr::Missing => Err(MirLowerError::IncompleteExpr), Expr::Path(p) => { + let unresolved_name = || MirLowerError::UnresolvedName("".to_string()); let resolver = resolver_for_expr(self.db.upcast(), self.owner, expr_id); let pr = resolver .resolve_path_in_value_ns(self.db.upcast(), p.mod_path()) - .ok_or(MirLowerError::UnresolvedName)?; + .ok_or_else(unresolved_name)?; let pr = match pr { ResolveValueResult::ValueNs(v) => v, ResolveValueResult::Partial(..) => { return match self .infer .assoc_resolutions_for_expr(expr_id) - .ok_or(MirLowerError::UnresolvedName)? + .ok_or_else(unresolved_name)? .0 //.ok_or(ConstEvalError::SemanticError("unresolved assoc item"))? { hir_def::AssocItemId::ConstId(c) => self.lower_const(c, current, place), - _ => return Err(MirLowerError::UnresolvedName), + _ => return Err(unresolved_name()), }; } }; @@ -394,7 +398,7 @@ impl MirLowerCtx<'_> { } Expr::MethodCall { receiver, args, .. } => { let (func_id, generic_args) = - self.infer.method_resolution(expr_id).ok_or(MirLowerError::UnresolvedName)?; + self.infer.method_resolution(expr_id).ok_or(MirLowerError::UnresolvedMethod)?; let ty = chalk_ir::TyKind::FnDef( CallableDefId::FunctionId(func_id).to_chalk(self.db), generic_args, @@ -476,7 +480,7 @@ impl MirLowerCtx<'_> { let variant_id = self .infer .variant_resolution_for_expr(expr_id) - .ok_or(MirLowerError::UnresolvedName)?; + .ok_or_else(|| MirLowerError::UnresolvedName("".to_string()))?; let subst = match self.expr_ty(expr_id).kind(Interner) { TyKind::Adt(_, s) => s.clone(), _ => not_supported!("Non ADT record literal"), @@ -487,7 +491,7 @@ impl MirLowerCtx<'_> { let mut operands = vec![None; variant_data.fields().len()]; for RecordLitField { name, expr } in fields.iter() { let field_id = - variant_data.field(name).ok_or(MirLowerError::UnresolvedName)?; + variant_data.field(name).ok_or(MirLowerError::UnresolvedField)?; let op; (op, current) = self.lower_expr_to_some_operand(*expr, current)?; operands[u32::from(field_id.into_raw()) as usize] = Some(op); @@ -509,7 +513,7 @@ impl MirLowerCtx<'_> { not_supported!("Union record literal with more than one field"); }; let local_id = - variant_data.field(name).ok_or(MirLowerError::UnresolvedName)?; + variant_data.field(name).ok_or(MirLowerError::UnresolvedField)?; let mut place = place; place .projection @@ -529,7 +533,7 @@ impl MirLowerCtx<'_> { let field = self .infer .field_resolution(expr_id) - .ok_or(MirLowerError::UnresolvedName)?; + .ok_or(MirLowerError::UnresolvedField)?; current_place.projection.push(ProjectionElem::Field(field)); } self.push_assignment(current, place, Operand::Copy(current_place).into()); @@ -962,8 +966,9 @@ impl MirLowerCtx<'_> { } (then_target, Some(else_target)) } - Pat::Bind { mode, name: _, subpat } => { - let target_place = self.binding_locals[pattern]; + Pat::Bind { id, subpat } => { + let target_place = self.binding_locals[*id]; + let mode = self.body.bindings[*id].mode; if let Some(subpat) = subpat { (current, current_else) = self.pattern_match( current, @@ -975,7 +980,7 @@ impl MirLowerCtx<'_> { )? } if matches!(mode, BindingAnnotation::Ref | BindingAnnotation::RefMut) { - binding_mode = *mode; + binding_mode = mode; } self.push_assignment( current, @@ -1189,17 +1194,40 @@ pub fn lower_to_mir( let mut locals = Arena::new(); // 0 is return local locals.alloc(Local { mutability: Mutability::Mut, ty: infer[root_expr].clone() }); - let mut create_local_of_path = |p: PatId| { - // FIXME: mutablity is broken - locals.alloc(Local { mutability: Mutability::Not, ty: infer[p].clone() }) + let mut binding_locals: ArenaMap<BindingId, LocalId> = ArenaMap::new(); + let param_locals: ArenaMap<PatId, LocalId> = if let DefWithBodyId::FunctionId(fid) = owner { + let substs = TyBuilder::placeholder_subst(db, fid); + let callable_sig = db.callable_item_signature(fid.into()).substitute(Interner, &substs); + // 1 to param_len is for params + body.params + .iter() + .zip(callable_sig.params().iter()) + .map(|(&x, ty)| { + let local_id = locals.alloc(Local { mutability: Mutability::Not, ty: ty.clone() }); + if let Pat::Bind { id, subpat: None } = body[x] { + if matches!( + body.bindings[id].mode, + BindingAnnotation::Unannotated | BindingAnnotation::Mutable + ) { + binding_locals.insert(id, local_id); + } + } + (x, local_id) + }) + .collect() + } else { + if !body.params.is_empty() { + return Err(MirLowerError::TypeError("Unexpected parameter for non function body")); + } + ArenaMap::new() }; - // 1 to param_len is for params - let mut binding_locals: ArenaMap<PatId, LocalId> = - body.params.iter().map(|&x| (x, create_local_of_path(x))).collect(); // and then rest of bindings - for (pat_id, _) in body.pats.iter() { - if !binding_locals.contains_idx(pat_id) { - binding_locals.insert(pat_id, create_local_of_path(pat_id)); + for (id, _) in body.bindings.iter() { + if !binding_locals.contains_idx(id) { + binding_locals.insert( + id, + locals.alloc(Local { mutability: Mutability::Not, ty: infer[id].clone() }), + ); } } let mir = MirBody { basic_blocks, locals, start_block, owner, arg_count: body.params.len() }; @@ -1213,7 +1241,27 @@ pub fn lower_to_mir( current_loop_blocks: None, discr_temp: None, }; - let b = ctx.lower_expr_to_place(root_expr, return_slot().into(), start_block)?; + let mut current = start_block; + for ¶m in &body.params { + if let Pat::Bind { id, .. } = body[param] { + if param_locals[param] == ctx.binding_locals[id] { + continue; + } + } + let r = ctx.pattern_match( + current, + None, + param_locals[param].into(), + ctx.result.locals[param_locals[param]].ty.clone(), + param, + BindingAnnotation::Unannotated, + )?; + if let Some(b) = r.1 { + ctx.set_terminator(b, Terminator::Unreachable); + } + current = r.0; + } + let b = ctx.lower_expr_to_place(root_expr, return_slot().into(), current)?; ctx.result.basic_blocks[b].terminator = Some(Terminator::Return); Ok(ctx.result) } diff --git a/crates/hir/src/from_id.rs b/crates/hir/src/from_id.rs index 4327691956..aaaa7abf38 100644 --- a/crates/hir/src/from_id.rs +++ b/crates/hir/src/from_id.rs @@ -4,7 +4,7 @@ //! are splitting the hir. use hir_def::{ - expr::{LabelId, PatId}, + expr::{BindingId, LabelId}, AdtId, AssocItemId, DefWithBodyId, EnumVariantId, FieldId, GenericDefId, GenericParamId, ModuleDefId, VariantId, }; @@ -251,9 +251,9 @@ impl From<AssocItem> for GenericDefId { } } -impl From<(DefWithBodyId, PatId)> for Local { - fn from((parent, pat_id): (DefWithBodyId, PatId)) -> Self { - Local { parent, pat_id } +impl From<(DefWithBodyId, BindingId)> for Local { + fn from((parent, binding_id): (DefWithBodyId, BindingId)) -> Self { + Local { parent, binding_id } } } diff --git a/crates/hir/src/lib.rs b/crates/hir/src/lib.rs index df6484db53..b83d83b5ed 100644 --- a/crates/hir/src/lib.rs +++ b/crates/hir/src/lib.rs @@ -41,7 +41,7 @@ use either::Either; use hir_def::{ adt::VariantData, body::{BodyDiagnostic, SyntheticSyntax}, - expr::{BindingAnnotation, ExprOrPatId, LabelId, Pat, PatId}, + expr::{BindingAnnotation, BindingId, ExprOrPatId, LabelId, Pat}, generics::{LifetimeParamData, TypeOrConstParamData, TypeParamProvenance}, item_tree::ItemTreeNode, lang_item::{LangItem, LangItemTarget}, @@ -77,7 +77,7 @@ use rustc_hash::FxHashSet; use stdx::{impl_from, never}; use syntax::{ ast::{self, HasAttrs as _, HasDocComments, HasName}, - AstNode, AstPtr, SmolStr, SyntaxNodePtr, TextRange, T, + AstNode, AstPtr, SmolStr, SyntaxNode, SyntaxNodePtr, TextRange, T, }; use crate::db::{DefDatabase, HirDatabase}; @@ -1782,8 +1782,8 @@ impl Param { let parent = DefWithBodyId::FunctionId(self.func.into()); let body = db.body(parent); let pat_id = body.params[self.idx]; - if let Pat::Bind { .. } = &body[pat_id] { - Some(Local { parent, pat_id: body.params[self.idx] }) + if let Pat::Bind { id, .. } = &body[pat_id] { + Some(Local { parent, binding_id: *id }) } else { None } @@ -2460,13 +2460,42 @@ impl GenericDef { #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] pub struct Local { pub(crate) parent: DefWithBodyId, - pub(crate) pat_id: PatId, + pub(crate) binding_id: BindingId, +} + +pub struct LocalSource { + pub local: Local, + pub source: InFile<Either<ast::IdentPat, ast::SelfParam>>, +} + +impl LocalSource { + pub fn as_ident_pat(&self) -> Option<&ast::IdentPat> { + match &self.source.value { + Either::Left(x) => Some(x), + Either::Right(_) => None, + } + } + + pub fn into_ident_pat(self) -> Option<ast::IdentPat> { + match self.source.value { + Either::Left(x) => Some(x), + Either::Right(_) => None, + } + } + + pub fn original_file(&self, db: &dyn HirDatabase) -> FileId { + self.source.file_id.original_file(db.upcast()) + } + + pub fn syntax(&self) -> &SyntaxNode { + self.source.value.syntax() + } } impl Local { pub fn is_param(self, db: &dyn HirDatabase) -> bool { - let src = self.source(db); - match src.value { + let src = self.primary_source(db); + match src.source.value { Either::Left(pat) => pat .syntax() .ancestors() @@ -2486,13 +2515,7 @@ impl Local { pub fn name(self, db: &dyn HirDatabase) -> Name { let body = db.body(self.parent); - match &body[self.pat_id] { - Pat::Bind { name, .. } => name.clone(), - _ => { - stdx::never!("hir::Local is missing a name!"); - Name::missing() - } - } + body[self.binding_id].name.clone() } pub fn is_self(self, db: &dyn HirDatabase) -> bool { @@ -2501,15 +2524,12 @@ impl Local { pub fn is_mut(self, db: &dyn HirDatabase) -> bool { let body = db.body(self.parent); - matches!(&body[self.pat_id], Pat::Bind { mode: BindingAnnotation::Mutable, .. }) + body[self.binding_id].mode == BindingAnnotation::Mutable } pub fn is_ref(self, db: &dyn HirDatabase) -> bool { let body = db.body(self.parent); - matches!( - &body[self.pat_id], - Pat::Bind { mode: BindingAnnotation::Ref | BindingAnnotation::RefMut, .. } - ) + matches!(body[self.binding_id].mode, BindingAnnotation::Ref | BindingAnnotation::RefMut) } pub fn parent(self, _db: &dyn HirDatabase) -> DefWithBody { @@ -2523,34 +2543,33 @@ impl Local { pub fn ty(self, db: &dyn HirDatabase) -> Type { let def = self.parent; let infer = db.infer(def); - let ty = infer[self.pat_id].clone(); + let ty = infer[self.binding_id].clone(); Type::new(db, def, ty) } - pub fn associated_locals(self, db: &dyn HirDatabase) -> Box<[Local]> { - let body = db.body(self.parent); - body.ident_patterns_for(&self.pat_id) + /// All definitions for this local. Example: `let (a$0, _) | (_, a$0) = x;` + pub fn sources(self, db: &dyn HirDatabase) -> Vec<LocalSource> { + let (body, source_map) = db.body_with_source_map(self.parent); + body[self.binding_id] + .definitions .iter() - .map(|&pat_id| Local { parent: self.parent, pat_id }) + .map(|&definition| { + let src = source_map.pat_syntax(definition).unwrap(); // Hmm... + let root = src.file_syntax(db.upcast()); + src.map(|ast| match ast { + // Suspicious unwrap + Either::Left(it) => Either::Left(it.cast().unwrap().to_node(&root)), + Either::Right(it) => Either::Right(it.to_node(&root)), + }) + }) + .map(|source| LocalSource { local: self, source }) .collect() } - /// If this local is part of a multi-local, retrieve the representative local. - /// That is the local that references are being resolved to. - pub fn representative(self, db: &dyn HirDatabase) -> Local { - let body = db.body(self.parent); - Local { pat_id: body.pattern_representative(self.pat_id), ..self } - } - - pub fn source(self, db: &dyn HirDatabase) -> InFile<Either<ast::IdentPat, ast::SelfParam>> { - let (_body, source_map) = db.body_with_source_map(self.parent); - let src = source_map.pat_syntax(self.pat_id).unwrap(); // Hmm... - let root = src.file_syntax(db.upcast()); - src.map(|ast| match ast { - // Suspicious unwrap - Either::Left(it) => Either::Left(it.cast().unwrap().to_node(&root)), - Either::Right(it) => Either::Right(it.to_node(&root)), - }) + /// The leftmost definition for this local. Example: `let (a$0, _) | (_, a) = x;` + pub fn primary_source(self, db: &dyn HirDatabase) -> LocalSource { + let all_sources = self.sources(db); + all_sources.into_iter().next().unwrap() } } diff --git a/crates/hir/src/semantics.rs b/crates/hir/src/semantics.rs index 697f4b43bc..cbd350cea4 100644 --- a/crates/hir/src/semantics.rs +++ b/crates/hir/src/semantics.rs @@ -1654,8 +1654,8 @@ impl<'a> SemanticsScope<'a> { resolver::ScopeDef::ImplSelfType(it) => ScopeDef::ImplSelfType(it.into()), resolver::ScopeDef::AdtSelfType(it) => ScopeDef::AdtSelfType(it.into()), resolver::ScopeDef::GenericParam(id) => ScopeDef::GenericParam(id.into()), - resolver::ScopeDef::Local(pat_id) => match self.resolver.body_owner() { - Some(parent) => ScopeDef::Local(Local { parent, pat_id }), + resolver::ScopeDef::Local(binding_id) => match self.resolver.body_owner() { + Some(parent) => ScopeDef::Local(Local { parent, binding_id }), None => continue, }, resolver::ScopeDef::Label(label_id) => match self.resolver.body_owner() { diff --git a/crates/hir/src/semantics/source_to_def.rs b/crates/hir/src/semantics/source_to_def.rs index ddfec20e3f..f6f8c9a250 100644 --- a/crates/hir/src/semantics/source_to_def.rs +++ b/crates/hir/src/semantics/source_to_def.rs @@ -89,7 +89,7 @@ use base_db::FileId; use hir_def::{ child_by_source::ChildBySource, dyn_map::DynMap, - expr::{LabelId, PatId}, + expr::{BindingId, LabelId}, keys::{self, Key}, AdtId, ConstId, ConstParamId, DefWithBodyId, EnumId, EnumVariantId, FieldId, FunctionId, GenericDefId, GenericParamId, ImplId, LifetimeParamId, MacroId, ModuleId, StaticId, StructId, @@ -98,7 +98,7 @@ use hir_def::{ use hir_expand::{attrs::AttrId, name::AsName, HirFileId, MacroCallId}; use rustc_hash::FxHashMap; use smallvec::SmallVec; -use stdx::impl_from; +use stdx::{impl_from, never}; use syntax::{ ast::{self, HasName}, AstNode, SyntaxNode, @@ -216,14 +216,14 @@ impl SourceToDefCtx<'_, '_> { pub(super) fn bind_pat_to_def( &mut self, src: InFile<ast::IdentPat>, - ) -> Option<(DefWithBodyId, PatId)> { + ) -> Option<(DefWithBodyId, BindingId)> { let container = self.find_pat_or_label_container(src.syntax())?; let (body, source_map) = self.db.body_with_source_map(container); let src = src.map(ast::Pat::from); let pat_id = source_map.node_pat(src.as_ref())?; // the pattern could resolve to a constant, verify that that is not the case - if let crate::Pat::Bind { .. } = body[pat_id] { - Some((container, pat_id)) + if let crate::Pat::Bind { id, .. } = body[pat_id] { + Some((container, id)) } else { None } @@ -231,11 +231,16 @@ impl SourceToDefCtx<'_, '_> { pub(super) fn self_param_to_def( &mut self, src: InFile<ast::SelfParam>, - ) -> Option<(DefWithBodyId, PatId)> { + ) -> Option<(DefWithBodyId, BindingId)> { let container = self.find_pat_or_label_container(src.syntax())?; - let (_body, source_map) = self.db.body_with_source_map(container); + let (body, source_map) = self.db.body_with_source_map(container); let pat_id = source_map.node_self_param(src.as_ref())?; - Some((container, pat_id)) + if let crate::Pat::Bind { id, .. } = body[pat_id] { + Some((container, id)) + } else { + never!(); + None + } } pub(super) fn label_to_def( &mut self, diff --git a/crates/hir/src/source_analyzer.rs b/crates/hir/src/source_analyzer.rs index 118a7f8ea8..133fa810d6 100644 --- a/crates/hir/src/source_analyzer.rs +++ b/crates/hir/src/source_analyzer.rs @@ -422,8 +422,8 @@ impl SourceAnalyzer { // Shorthand syntax, resolve to the local let path = ModPath::from_segments(PathKind::Plain, once(local_name.clone())); match self.resolver.resolve_path_in_value_ns_fully(db.upcast(), &path) { - Some(ValueNs::LocalBinding(pat_id)) => { - Some(Local { pat_id, parent: self.resolver.body_owner()? }) + Some(ValueNs::LocalBinding(binding_id)) => { + Some(Local { binding_id, parent: self.resolver.body_owner()? }) } _ => None, } @@ -1018,8 +1018,8 @@ fn resolve_hir_path_( let values = || { resolver.resolve_path_in_value_ns_fully(db.upcast(), path.mod_path()).and_then(|val| { let res = match val { - ValueNs::LocalBinding(pat_id) => { - let var = Local { parent: body_owner?, pat_id }; + ValueNs::LocalBinding(binding_id) => { + let var = Local { parent: body_owner?, binding_id }; PathResolution::Local(var) } ValueNs::FunctionId(it) => PathResolution::Def(Function::from(it).into()), diff --git a/crates/ide-assists/src/handlers/convert_match_to_let_else.rs b/crates/ide-assists/src/handlers/convert_match_to_let_else.rs index 65c2479e9f..745a870ab6 100644 --- a/crates/ide-assists/src/handlers/convert_match_to_let_else.rs +++ b/crates/ide-assists/src/handlers/convert_match_to_let_else.rs @@ -101,7 +101,7 @@ fn find_extracted_variable(ctx: &AssistContext<'_>, arm: &ast::MatchArm) -> Opti let name_ref = path.syntax().descendants().find_map(ast::NameRef::cast)?; match NameRefClass::classify(&ctx.sema, &name_ref)? { NameRefClass::Definition(Definition::Local(local)) => { - let source = local.source(ctx.db()).value.left()?; + let source = local.primary_source(ctx.db()).into_ident_pat()?; Some(source.name()?) } _ => None, diff --git a/crates/ide-assists/src/handlers/extract_function.rs b/crates/ide-assists/src/handlers/extract_function.rs index e04a1dabb2..0b90c9ba34 100644 --- a/crates/ide-assists/src/handlers/extract_function.rs +++ b/crates/ide-assists/src/handlers/extract_function.rs @@ -3,7 +3,8 @@ use std::iter; use ast::make; use either::Either; use hir::{ - HasSource, HirDisplay, InFile, Local, ModuleDef, PathResolution, Semantics, TypeInfo, TypeParam, + HasSource, HirDisplay, InFile, Local, LocalSource, ModuleDef, PathResolution, Semantics, + TypeInfo, TypeParam, }; use ide_db::{ defs::{Definition, NameRefClass}, @@ -710,7 +711,7 @@ impl FunctionBody { ) => local_ref, _ => return, }; - let InFile { file_id, value } = local_ref.source(sema.db); + let InFile { file_id, value } = local_ref.primary_source(sema.db).source; // locals defined inside macros are not relevant to us if !file_id.is_macro() { match value { @@ -972,11 +973,11 @@ impl FunctionBody { locals: impl Iterator<Item = Local>, ) -> Vec<Param> { locals - .map(|local| (local, local.source(ctx.db()))) + .map(|local| (local, local.primary_source(ctx.db()))) .filter(|(_, src)| is_defined_outside_of_body(ctx, self, src)) - .filter_map(|(local, src)| match src.value { - Either::Left(src) => Some((local, src)), - Either::Right(_) => { + .filter_map(|(local, src)| match src.into_ident_pat() { + Some(src) => Some((local, src)), + None => { stdx::never!(false, "Local::is_self returned false, but source is SelfParam"); None } @@ -1238,17 +1239,9 @@ fn local_outlives_body( fn is_defined_outside_of_body( ctx: &AssistContext<'_>, body: &FunctionBody, - src: &hir::InFile<Either<ast::IdentPat, ast::SelfParam>>, + src: &LocalSource, ) -> bool { - src.file_id.original_file(ctx.db()) == ctx.file_id() - && !body.contains_node(either_syntax(&src.value)) -} - -fn either_syntax(value: &Either<ast::IdentPat, ast::SelfParam>) -> &SyntaxNode { - match value { - Either::Left(pat) => pat.syntax(), - Either::Right(it) => it.syntax(), - } + src.original_file(ctx.db()) == ctx.file_id() && !body.contains_node(src.syntax()) } /// find where to put extracted function definition diff --git a/crates/ide-assists/src/handlers/inline_local_variable.rs b/crates/ide-assists/src/handlers/inline_local_variable.rs index ce44100e34..e69d1a2967 100644 --- a/crates/ide-assists/src/handlers/inline_local_variable.rs +++ b/crates/ide-assists/src/handlers/inline_local_variable.rs @@ -1,4 +1,3 @@ -use either::Either; use hir::{PathResolution, Semantics}; use ide_db::{ base_db::FileId, @@ -205,12 +204,14 @@ fn inline_usage( return None; } - // FIXME: Handle multiple local definitions - let bind_pat = match local.source(sema.db).value { - Either::Left(ident) => ident, - _ => return None, + let sources = local.sources(sema.db); + let [source] = sources.as_slice() else { + // Not applicable with locals with multiple definitions (i.e. or patterns) + return None; }; + let bind_pat = source.as_ident_pat()?; + let let_stmt = ast::LetStmt::cast(bind_pat.syntax().parent()?)?; let UsageSearchResult { mut references } = Definition::Local(local).usages(sema).all(); diff --git a/crates/ide-db/src/rename.rs b/crates/ide-db/src/rename.rs index 4179f1bd4f..298c714139 100644 --- a/crates/ide-db/src/rename.rs +++ b/crates/ide-db/src/rename.rs @@ -121,14 +121,8 @@ impl Definition { Definition::Trait(it) => name_range(it, sema), Definition::TraitAlias(it) => name_range(it, sema), Definition::TypeAlias(it) => name_range(it, sema), - Definition::Local(local) => { - let src = local.source(sema.db); - let name = match &src.value { - Either::Left(bind_pat) => bind_pat.name()?, - Either::Right(_) => return None, - }; - src.with_value(name.syntax()).original_file_range_opt(sema.db) - } + // A local might be `self` or have multiple definitons like `let (a | a) = 2`, so it should be handled as a special case + Definition::Local(_) => return None, Definition::GenericParam(generic_param) => match generic_param { hir::GenericParam::LifetimeParam(lifetime_param) => { let src = lifetime_param.source(sema.db)?; @@ -302,13 +296,7 @@ fn rename_reference( source_change.insert_source_edit(file_id, edit); Ok(()) }; - match def { - Definition::Local(l) => l - .associated_locals(sema.db) - .iter() - .try_for_each(|&local| insert_def_edit(Definition::Local(local))), - def => insert_def_edit(def), - }?; + insert_def_edit(def)?; Ok(source_change) } @@ -471,59 +459,64 @@ fn source_edit_from_def( def: Definition, new_name: &str, ) -> Result<(FileId, TextEdit)> { - let FileRange { file_id, range } = def - .range_for_rename(sema) - .ok_or_else(|| format_err!("No identifier available to rename"))?; - let mut edit = TextEdit::builder(); if let Definition::Local(local) = def { - if let Either::Left(pat) = local.source(sema.db).value { - // special cases required for renaming fields/locals in Record patterns - if let Some(pat_field) = pat.syntax().parent().and_then(ast::RecordPatField::cast) { + let mut file_id = None; + for source in local.sources(sema.db) { + let source = source.source; + file_id = source.file_id.file_id(); + if let Either::Left(pat) = source.value { let name_range = pat.name().unwrap().syntax().text_range(); - if let Some(name_ref) = pat_field.name_ref() { - if new_name == name_ref.text() && pat.at_token().is_none() { - // Foo { field: ref mut local } -> Foo { ref mut field } - // ^^^^^^ delete this - // ^^^^^ replace this with `field` - cov_mark::hit!(test_rename_local_put_init_shorthand_pat); - edit.delete( - name_ref - .syntax() - .text_range() - .cover_offset(pat.syntax().text_range().start()), - ); - edit.replace(name_range, name_ref.text().to_string()); + // special cases required for renaming fields/locals in Record patterns + if let Some(pat_field) = pat.syntax().parent().and_then(ast::RecordPatField::cast) { + if let Some(name_ref) = pat_field.name_ref() { + if new_name == name_ref.text() && pat.at_token().is_none() { + // Foo { field: ref mut local } -> Foo { ref mut field } + // ^^^^^^ delete this + // ^^^^^ replace this with `field` + cov_mark::hit!(test_rename_local_put_init_shorthand_pat); + edit.delete( + name_ref + .syntax() + .text_range() + .cover_offset(pat.syntax().text_range().start()), + ); + edit.replace(name_range, name_ref.text().to_string()); + } else { + // Foo { field: ref mut local @ local 2} -> Foo { field: ref mut new_name @ local2 } + // Foo { field: ref mut local } -> Foo { field: ref mut new_name } + // ^^^^^ replace this with `new_name` + edit.replace(name_range, new_name.to_string()); + } } else { - // Foo { field: ref mut local @ local 2} -> Foo { field: ref mut new_name @ local2 } - // Foo { field: ref mut local } -> Foo { field: ref mut new_name } - // ^^^^^ replace this with `new_name` + // Foo { ref mut field } -> Foo { field: ref mut new_name } + // ^ insert `field: ` + // ^^^^^ replace this with `new_name` + edit.insert( + pat.syntax().text_range().start(), + format!("{}: ", pat_field.field_name().unwrap()), + ); edit.replace(name_range, new_name.to_string()); } } else { - // Foo { ref mut field } -> Foo { field: ref mut new_name } - // ^ insert `field: ` - // ^^^^^ replace this with `new_name` - edit.insert( - pat.syntax().text_range().start(), - format!("{}: ", pat_field.field_name().unwrap()), - ); edit.replace(name_range, new_name.to_string()); } } } + let Some(file_id) = file_id else { bail!("No file available to rename") }; + return Ok((file_id, edit.finish())); } - if edit.is_empty() { - let (range, new_name) = match def { - Definition::GenericParam(hir::GenericParam::LifetimeParam(_)) - | Definition::Label(_) => ( - TextRange::new(range.start() + syntax::TextSize::from(1), range.end()), - new_name.strip_prefix('\'').unwrap_or(new_name).to_owned(), - ), - _ => (range, new_name.to_owned()), - }; - edit.replace(range, new_name); - } + let FileRange { file_id, range } = def + .range_for_rename(sema) + .ok_or_else(|| format_err!("No identifier available to rename"))?; + let (range, new_name) = match def { + Definition::GenericParam(hir::GenericParam::LifetimeParam(_)) | Definition::Label(_) => ( + TextRange::new(range.start() + syntax::TextSize::from(1), range.end()), + new_name.strip_prefix('\'').unwrap_or(new_name).to_owned(), + ), + _ => (range, new_name.to_owned()), + }; + edit.replace(range, new_name); Ok((file_id, edit.finish())) } diff --git a/crates/ide-db/src/search.rs b/crates/ide-db/src/search.rs index bcdaac4cf8..6298ea1927 100644 --- a/crates/ide-db/src/search.rs +++ b/crates/ide-db/src/search.rs @@ -320,7 +320,7 @@ impl Definition { scope: None, include_self_kw_refs: None, local_repr: match self { - Definition::Local(local) => Some(local.representative(sema.db)), + Definition::Local(local) => Some(local), _ => None, }, search_self_mod: false, @@ -646,7 +646,7 @@ impl<'a> FindUsages<'a> { match NameRefClass::classify(self.sema, name_ref) { Some(NameRefClass::Definition(def @ Definition::Local(local))) if matches!( - self.local_repr, Some(repr) if repr == local.representative(self.sema.db) + self.local_repr, Some(repr) if repr == local ) => { let FileRange { file_id, range } = self.sema.original_range(name_ref.syntax()); @@ -707,7 +707,7 @@ impl<'a> FindUsages<'a> { Definition::Field(_) if field == self.def => { ReferenceCategory::new(&field, name_ref) } - Definition::Local(_) if matches!(self.local_repr, Some(repr) if repr == local.representative(self.sema.db)) => { + Definition::Local(_) if matches!(self.local_repr, Some(repr) if repr == local) => { ReferenceCategory::new(&Definition::Local(local), name_ref) } _ => return false, @@ -755,7 +755,7 @@ impl<'a> FindUsages<'a> { Some(NameClass::Definition(def @ Definition::Local(local))) if def != self.def => { if matches!( self.local_repr, - Some(repr) if local.representative(self.sema.db) == repr + Some(repr) if local == repr ) { let FileRange { file_id, range } = self.sema.original_range(name.syntax()); let reference = FileReference { diff --git a/crates/ide/src/highlight_related.rs b/crates/ide/src/highlight_related.rs index c889eb930f..d88ffd25c4 100644 --- a/crates/ide/src/highlight_related.rs +++ b/crates/ide/src/highlight_related.rs @@ -14,7 +14,7 @@ use syntax::{ SyntaxNode, SyntaxToken, TextRange, T, }; -use crate::{references, NavigationTarget, TryToNav}; +use crate::{navigation_target::ToNav, references, NavigationTarget, TryToNav}; #[derive(PartialEq, Eq, Hash)] pub struct HighlightedRange { @@ -98,32 +98,39 @@ fn highlight_references( category: access, }); let mut res = FxHashSet::default(); - - let mut def_to_hl_range = |def| { - let hl_range = match def { - Definition::Module(module) => { - Some(NavigationTarget::from_module_to_decl(sema.db, module)) - } - def => def.try_to_nav(sema.db), - } - .filter(|decl| decl.file_id == file_id) - .and_then(|decl| decl.focus_range) - .map(|range| { - let category = - references::decl_mutability(&def, node, range).then_some(ReferenceCategory::Write); - HighlightedRange { range, category } - }); - if let Some(hl_range) = hl_range { - res.insert(hl_range); - } - }; for &def in &defs { match def { - Definition::Local(local) => local - .associated_locals(sema.db) - .iter() - .for_each(|&local| def_to_hl_range(Definition::Local(local))), - def => def_to_hl_range(def), + Definition::Local(local) => { + let category = local.is_mut(sema.db).then_some(ReferenceCategory::Write); + local + .sources(sema.db) + .into_iter() + .map(|x| x.to_nav(sema.db)) + .filter(|decl| decl.file_id == file_id) + .filter_map(|decl| decl.focus_range) + .map(|range| HighlightedRange { range, category }) + .for_each(|x| { + res.insert(x); + }); + } + def => { + let hl_range = match def { + Definition::Module(module) => { + Some(NavigationTarget::from_module_to_decl(sema.db, module)) + } + def => def.try_to_nav(sema.db), + } + .filter(|decl| decl.file_id == file_id) + .and_then(|decl| decl.focus_range) + .map(|range| { + let category = references::decl_mutability(&def, node, range) + .then_some(ReferenceCategory::Write); + HighlightedRange { range, category } + }); + if let Some(hl_range) = hl_range { + res.insert(hl_range); + } + } } } diff --git a/crates/ide/src/hover/render.rs b/crates/ide/src/hover/render.rs index 6a29ddf59e..da725ce502 100644 --- a/crates/ide/src/hover/render.rs +++ b/crates/ide/src/hover/render.rs @@ -635,8 +635,8 @@ fn local(db: &RootDatabase, it: hir::Local) -> Option<Markup> { let ty = it.ty(db); let ty = ty.display_truncated(db, None); let is_mut = if it.is_mut(db) { "mut " } else { "" }; - let desc = match it.source(db).value { - Either::Left(ident) => { + let desc = match it.primary_source(db).into_ident_pat() { + Some(ident) => { let name = it.name(db); let let_kw = if ident .syntax() @@ -649,7 +649,7 @@ fn local(db: &RootDatabase, it: hir::Local) -> Option<Markup> { }; format!("{let_kw}{is_mut}{name}: {ty}") } - Either::Right(_) => format!("{is_mut}self: {ty}"), + None => format!("{is_mut}self: {ty}"), }; markup(None, desc, None) } diff --git a/crates/ide/src/navigation_target.rs b/crates/ide/src/navigation_target.rs index 11d10d2b85..6aae82f981 100644 --- a/crates/ide/src/navigation_target.rs +++ b/crates/ide/src/navigation_target.rs @@ -5,7 +5,7 @@ use std::fmt; use either::Either; use hir::{ symbols::FileSymbol, AssocItem, Documentation, FieldSource, HasAttrs, HasSource, HirDisplay, - InFile, ModuleSource, Semantics, + InFile, LocalSource, ModuleSource, Semantics, }; use ide_db::{ base_db::{FileId, FileRange}, @@ -387,9 +387,11 @@ impl TryToNav for hir::GenericParam { } } -impl ToNav for hir::Local { +impl ToNav for LocalSource { fn to_nav(&self, db: &RootDatabase) -> NavigationTarget { - let InFile { file_id, value } = self.source(db); + let InFile { file_id, value } = &self.source; + let file_id = *file_id; + let local = self.local; let (node, name) = match &value { Either::Left(bind_pat) => (bind_pat.syntax(), bind_pat.name()), Either::Right(it) => (it.syntax(), it.name()), @@ -398,10 +400,10 @@ impl ToNav for hir::Local { let FileRange { file_id, range: full_range } = InFile::new(file_id, node).original_file_range(db); - let name = self.name(db).to_smol_str(); - let kind = if self.is_self(db) { + let name = local.name(db).to_smol_str(); + let kind = if local.is_self(db) { SymbolKind::SelfParam - } else if self.is_param(db) { + } else if local.is_param(db) { SymbolKind::ValueParam } else { SymbolKind::Local @@ -419,6 +421,12 @@ impl ToNav for hir::Local { } } +impl ToNav for hir::Local { + fn to_nav(&self, db: &RootDatabase) -> NavigationTarget { + self.primary_source(db).to_nav(db) + } +} + impl ToNav for hir::Label { fn to_nav(&self, db: &RootDatabase) -> NavigationTarget { let InFile { file_id, value } = self.source(db); |