Unnamed repository; edit this file 'description' to name the repository.
Auto merge of #13475 - lowr:fix/lookup-impl-method-trait-ref, r=flodiebold
fix: Test all generic args for trait when finding matching impl
Addresses https://github.com/rust-lang/rust-analyzer/pull/13463#issuecomment-1287816680
When finding matching impl for a trait method, we've been testing the unifiability of self type. However, there can be multiple impl of a trait for the same type with different generic arguments for the trait. This patch takes it into account and tests the unifiability of all type arguments for the trait (the first being the self type) thus enables rust-analyzer to find the correct impl even in such cases.
| -rw-r--r-- | crates/hir-ty/src/infer/unify.rs | 12 | ||||
| -rw-r--r-- | crates/hir-ty/src/method_resolution.rs | 84 | ||||
| -rw-r--r-- | crates/hir/src/source_analyzer.rs | 65 | ||||
| -rw-r--r-- | crates/ide/src/goto_definition.rs | 82 |
4 files changed, 170 insertions, 73 deletions
diff --git a/crates/hir-ty/src/infer/unify.rs b/crates/hir-ty/src/infer/unify.rs index b00e3216b2..12f45f00f9 100644 --- a/crates/hir-ty/src/infer/unify.rs +++ b/crates/hir-ty/src/infer/unify.rs @@ -340,8 +340,8 @@ impl<'a> InferenceTable<'a> { self.resolve_with_fallback(t, &|_, _, d, _| d) } - /// Unify two types and register new trait goals that arise from that. - pub(crate) fn unify(&mut self, ty1: &Ty, ty2: &Ty) -> bool { + /// Unify two relatable values (e.g. `Ty`) and register new trait goals that arise from that. + pub(crate) fn unify<T: ?Sized + Zip<Interner>>(&mut self, ty1: &T, ty2: &T) -> bool { let result = match self.try_unify(ty1, ty2) { Ok(r) => r, Err(_) => return false, @@ -350,9 +350,13 @@ impl<'a> InferenceTable<'a> { true } - /// Unify two types and return new trait goals arising from it, so the + /// Unify two relatable values (e.g. `Ty`) and return new trait goals arising from it, so the /// caller needs to deal with them. - pub(crate) fn try_unify<T: Zip<Interner>>(&mut self, t1: &T, t2: &T) -> InferResult<()> { + pub(crate) fn try_unify<T: ?Sized + Zip<Interner>>( + &mut self, + t1: &T, + t2: &T, + ) -> InferResult<()> { match self.var_unification_table.relate( Interner, &self.db, diff --git a/crates/hir-ty/src/method_resolution.rs b/crates/hir-ty/src/method_resolution.rs index 5998680dcd..b1178ba0d2 100644 --- a/crates/hir-ty/src/method_resolution.rs +++ b/crates/hir-ty/src/method_resolution.rs @@ -22,10 +22,10 @@ use crate::{ from_foreign_def_id, infer::{unify::InferenceTable, Adjust, Adjustment, AutoBorrow, OverloadedDeref, PointerCast}, primitive::{FloatTy, IntTy, UintTy}, - static_lifetime, + static_lifetime, to_chalk_trait_id, utils::all_super_traits, AdtId, Canonical, CanonicalVarKinds, DebruijnIndex, ForeignDefId, InEnvironment, Interner, - Scalar, TraitEnvironment, TraitRefExt, Ty, TyBuilder, TyExt, TyKind, + Scalar, Substitution, TraitEnvironment, TraitRef, TraitRefExt, Ty, TyBuilder, TyExt, TyKind, }; /// This is used as a key for indexing impls. @@ -624,52 +624,76 @@ pub(crate) fn iterate_method_candidates<T>( slot } +/// Looks up the impl method that actually runs for the trait method `func`. +/// +/// Returns `func` if it's not a method defined in a trait or the lookup failed. pub fn lookup_impl_method( - self_ty: &Ty, db: &dyn HirDatabase, env: Arc<TraitEnvironment>, - trait_: TraitId, + func: FunctionId, + fn_subst: Substitution, +) -> FunctionId { + let trait_id = match func.lookup(db.upcast()).container { + ItemContainerId::TraitId(id) => id, + _ => return func, + }; + let trait_params = db.generic_params(trait_id.into()).type_or_consts.len(); + let fn_params = fn_subst.len(Interner) - trait_params; + let trait_ref = TraitRef { + trait_id: to_chalk_trait_id(trait_id), + substitution: Substitution::from_iter(Interner, fn_subst.iter(Interner).skip(fn_params)), + }; + + let name = &db.function_data(func).name; + lookup_impl_method_for_trait_ref(trait_ref, db, env, name).unwrap_or(func) +} + +fn lookup_impl_method_for_trait_ref( + trait_ref: TraitRef, + db: &dyn HirDatabase, + env: Arc<TraitEnvironment>, name: &Name, ) -> Option<FunctionId> { - let self_ty_fp = TyFingerprint::for_trait_impl(self_ty)?; - let trait_impls = db.trait_impls_in_deps(env.krate); - let impls = trait_impls.for_trait_and_self_ty(trait_, self_ty_fp); - let mut table = InferenceTable::new(db, env.clone()); - find_matching_impl(impls, &mut table, &self_ty).and_then(|data| { - data.items.iter().find_map(|it| match it { - AssocItemId::FunctionId(f) => (db.function_data(*f).name == *name).then(|| *f), - _ => None, - }) + let self_ty = trait_ref.self_type_parameter(Interner); + let self_ty_fp = TyFingerprint::for_trait_impl(&self_ty)?; + let impls = db.trait_impls_in_deps(env.krate); + let impls = impls.for_trait_and_self_ty(trait_ref.hir_trait_id(), self_ty_fp); + + let table = InferenceTable::new(db, env); + + let impl_data = find_matching_impl(impls, table, trait_ref)?; + impl_data.items.iter().find_map(|it| match it { + AssocItemId::FunctionId(f) => (db.function_data(*f).name == *name).then(|| *f), + _ => None, }) } fn find_matching_impl( mut impls: impl Iterator<Item = ImplId>, - table: &mut InferenceTable<'_>, - self_ty: &Ty, + mut table: InferenceTable<'_>, + actual_trait_ref: TraitRef, ) -> Option<Arc<ImplData>> { let db = table.db; loop { let impl_ = impls.next()?; let r = table.run_in_snapshot(|table| { let impl_data = db.impl_data(impl_); - let substs = + let impl_substs = TyBuilder::subst_for_def(db, impl_, None).fill_with_inference_vars(table).build(); - let impl_ty = db.impl_self_ty(impl_).substitute(Interner, &substs); - - table - .unify(self_ty, &impl_ty) - .then(|| { - let wh_goals = - crate::chalk_db::convert_where_clauses(db, impl_.into(), &substs) - .into_iter() - .map(|b| b.cast(Interner)); + let trait_ref = db + .impl_trait(impl_) + .expect("non-trait method in find_matching_impl") + .substitute(Interner, &impl_substs); - let goal = crate::Goal::all(Interner, wh_goals); + if !table.unify(&trait_ref, &actual_trait_ref) { + return None; + } - table.try_obligation(goal).map(|_| impl_data) - }) - .flatten() + let wcs = crate::chalk_db::convert_where_clauses(db, impl_.into(), &impl_substs) + .into_iter() + .map(|b| b.cast(Interner)); + let goal = crate::Goal::all(Interner, wcs); + table.try_obligation(goal).map(|_| impl_data) }); if r.is_some() { break r; @@ -1214,7 +1238,7 @@ fn is_valid_fn_candidate( let expected_receiver = sig.map(|s| s.params()[0].clone()).substitute(Interner, &fn_subst); - check_that!(table.unify(&receiver_ty, &expected_receiver)); + check_that!(table.unify(receiver_ty, &expected_receiver)); } if let ItemContainerId::ImplId(impl_id) = container { diff --git a/crates/hir/src/source_analyzer.rs b/crates/hir/src/source_analyzer.rs index 07bae2b38c..f86c571005 100644 --- a/crates/hir/src/source_analyzer.rs +++ b/crates/hir/src/source_analyzer.rs @@ -270,7 +270,7 @@ impl SourceAnalyzer { let expr_id = self.expr_id(db, &call.clone().into())?; let (f_in_trait, substs) = self.infer.as_ref()?.method_resolution(expr_id)?; - Some(self.resolve_impl_method_or_trait_def(db, f_in_trait, &substs)) + Some(self.resolve_impl_method_or_trait_def(db, f_in_trait, substs)) } pub(crate) fn resolve_await_to_poll( @@ -311,7 +311,7 @@ impl SourceAnalyzer { // HACK: subst for `poll()` coincides with that for `Future` because `poll()` itself // doesn't have any generic parameters, so we skip building another subst for `poll()`. let substs = hir_ty::TyBuilder::subst_for_def(db, future_trait, None).push(ty).build(); - Some(self.resolve_impl_method_or_trait_def(db, poll_fn, &substs)) + Some(self.resolve_impl_method_or_trait_def(db, poll_fn, substs)) } pub(crate) fn resolve_prefix_expr( @@ -331,7 +331,7 @@ impl SourceAnalyzer { // don't have any generic parameters, so we skip building another subst for the methods. let substs = hir_ty::TyBuilder::subst_for_def(db, op_trait, None).push(ty.clone()).build(); - Some(self.resolve_impl_method_or_trait_def(db, op_fn, &substs)) + Some(self.resolve_impl_method_or_trait_def(db, op_fn, substs)) } pub(crate) fn resolve_index_expr( @@ -351,7 +351,7 @@ impl SourceAnalyzer { .push(base_ty.clone()) .push(index_ty.clone()) .build(); - Some(self.resolve_impl_method_or_trait_def(db, op_fn, &substs)) + Some(self.resolve_impl_method_or_trait_def(db, op_fn, substs)) } pub(crate) fn resolve_bin_expr( @@ -372,7 +372,7 @@ impl SourceAnalyzer { .push(rhs.clone()) .build(); - Some(self.resolve_impl_method_or_trait_def(db, op_fn, &substs)) + Some(self.resolve_impl_method_or_trait_def(db, op_fn, substs)) } pub(crate) fn resolve_try_expr( @@ -392,7 +392,7 @@ impl SourceAnalyzer { // doesn't have any generic parameters, so we skip building another subst for `branch()`. let substs = hir_ty::TyBuilder::subst_for_def(db, op_trait, None).push(ty.clone()).build(); - Some(self.resolve_impl_method_or_trait_def(db, op_fn, &substs)) + Some(self.resolve_impl_method_or_trait_def(db, op_fn, substs)) } pub(crate) fn resolve_field( @@ -487,9 +487,9 @@ impl SourceAnalyzer { let mut prefer_value_ns = false; let resolved = (|| { + let infer = self.infer.as_deref()?; if let Some(path_expr) = parent().and_then(ast::PathExpr::cast) { let expr_id = self.expr_id(db, &path_expr.into())?; - let infer = self.infer.as_ref()?; if let Some(assoc) = infer.assoc_resolutions_for_expr(expr_id) { let assoc = match assoc { AssocItemId::FunctionId(f_in_trait) => { @@ -497,9 +497,12 @@ impl SourceAnalyzer { None => assoc, Some(func_ty) => { if let TyKind::FnDef(_fn_def, subs) = func_ty.kind(Interner) { - self.resolve_impl_method(db, f_in_trait, subs) - .map(AssocItemId::FunctionId) - .unwrap_or(assoc) + self.resolve_impl_method_or_trait_def( + db, + f_in_trait, + subs.clone(), + ) + .into() } else { assoc } @@ -520,18 +523,18 @@ impl SourceAnalyzer { prefer_value_ns = true; } else if let Some(path_pat) = parent().and_then(ast::PathPat::cast) { let pat_id = self.pat_id(&path_pat.into())?; - if let Some(assoc) = self.infer.as_ref()?.assoc_resolutions_for_pat(pat_id) { + if let Some(assoc) = infer.assoc_resolutions_for_pat(pat_id) { return Some(PathResolution::Def(AssocItem::from(assoc).into())); } if let Some(VariantId::EnumVariantId(variant)) = - self.infer.as_ref()?.variant_resolution_for_pat(pat_id) + infer.variant_resolution_for_pat(pat_id) { return Some(PathResolution::Def(ModuleDef::Variant(variant.into()))); } } else if let Some(rec_lit) = parent().and_then(ast::RecordExpr::cast) { let expr_id = self.expr_id(db, &rec_lit.into())?; if let Some(VariantId::EnumVariantId(variant)) = - self.infer.as_ref()?.variant_resolution_for_expr(expr_id) + infer.variant_resolution_for_expr(expr_id) { return Some(PathResolution::Def(ModuleDef::Variant(variant.into()))); } @@ -541,8 +544,7 @@ impl SourceAnalyzer { || parent().and_then(ast::TupleStructPat::cast).map(ast::Pat::from); if let Some(pat) = record_pat.or_else(tuple_struct_pat) { let pat_id = self.pat_id(&pat)?; - let variant_res_for_pat = - self.infer.as_ref()?.variant_resolution_for_pat(pat_id); + let variant_res_for_pat = infer.variant_resolution_for_pat(pat_id); if let Some(VariantId::EnumVariantId(variant)) = variant_res_for_pat { return Some(PathResolution::Def(ModuleDef::Variant(variant.into()))); } @@ -780,37 +782,22 @@ impl SourceAnalyzer { false } - fn resolve_impl_method( + fn resolve_impl_method_or_trait_def( &self, db: &dyn HirDatabase, func: FunctionId, - substs: &Substitution, - ) -> Option<FunctionId> { - let impled_trait = match func.lookup(db.upcast()).container { - ItemContainerId::TraitId(trait_id) => trait_id, - _ => return None, - }; - if substs.is_empty(Interner) { - return None; - } - let self_ty = substs.at(Interner, 0).ty(Interner)?; + substs: Substitution, + ) -> FunctionId { let krate = self.resolver.krate(); - let trait_env = self.resolver.body_owner()?.as_generic_def_id().map_or_else( + let owner = match self.resolver.body_owner() { + Some(it) => it, + None => return func, + }; + let env = owner.as_generic_def_id().map_or_else( || Arc::new(hir_ty::TraitEnvironment::empty(krate)), |d| db.trait_environment(d), ); - - let fun_data = db.function_data(func); - method_resolution::lookup_impl_method(self_ty, db, trait_env, impled_trait, &fun_data.name) - } - - fn resolve_impl_method_or_trait_def( - &self, - db: &dyn HirDatabase, - func: FunctionId, - substs: &Substitution, - ) -> FunctionId { - self.resolve_impl_method(db, func, substs).unwrap_or(func) + method_resolution::lookup_impl_method(db, env, func, substs) } fn lang_trait_fn( diff --git a/crates/ide/src/goto_definition.rs b/crates/ide/src/goto_definition.rs index d0be1b3f40..f97c67b144 100644 --- a/crates/ide/src/goto_definition.rs +++ b/crates/ide/src/goto_definition.rs @@ -1834,4 +1834,86 @@ fn f() { "#, ); } + + #[test] + fn goto_bin_op_multiple_impl() { + check( + r#" +//- minicore: add +struct S; +impl core::ops::Add for S { + fn add( + //^^^ + ) {} +} +impl core::ops::Add<usize> for S { + fn add( + ) {} +} + +fn f() { + S +$0 S +} +"#, + ); + + check( + r#" +//- minicore: add +struct S; +impl core::ops::Add for S { + fn add( + ) {} +} +impl core::ops::Add<usize> for S { + fn add( + //^^^ + ) {} +} + +fn f() { + S +$0 0usize +} +"#, + ); + } + + #[test] + fn path_call_multiple_trait_impl() { + check( + r#" +trait Trait<T> { + fn f(_: T); +} +impl Trait<i32> for usize { + fn f(_: i32) {} + //^ +} +impl Trait<i64> for usize { + fn f(_: i64) {} +} +fn main() { + usize::f$0(0i32); +} +"#, + ); + + check( + r#" +trait Trait<T> { + fn f(_: T); +} +impl Trait<i32> for usize { + fn f(_: i32) {} +} +impl Trait<i64> for usize { + fn f(_: i64) {} + //^ +} +fn main() { + usize::f$0(0i64); +} +"#, + ) + } } |