Unnamed repository; edit this file 'description' to name the repository.
Have an `upvars_mentioned()` query that only computes what upvars a closure captures
It's required for coercion, where we now use it, as well as for other things.
| -rw-r--r-- | crates/hir-ty/src/infer/coerce.rs | 39 | ||||
| -rw-r--r-- | crates/hir-ty/src/lib.rs | 1 | ||||
| -rw-r--r-- | crates/hir-ty/src/tests/simple.rs | 23 | ||||
| -rw-r--r-- | crates/hir-ty/src/upvars.rs | 319 |
4 files changed, 359 insertions, 23 deletions
diff --git a/crates/hir-ty/src/infer/coerce.rs b/crates/hir-ty/src/infer/coerce.rs index ec7dad0fa0..e79868f4ae 100644 --- a/crates/hir-ty/src/infer/coerce.rs +++ b/crates/hir-ty/src/infer/coerce.rs @@ -56,7 +56,7 @@ use tracing::{debug, instrument}; use crate::{ Adjust, Adjustment, AutoBorrow, ParamEnvAndCrate, PointerCast, TargetFeatures, autoderef::Autoderef, - db::{HirDatabase, InternedClosureId}, + db::{HirDatabase, InternedClosure, InternedClosureId}, infer::{ AllowTwoPhase, AutoBorrowMutability, InferenceContext, TypeMismatch, expr::ExprIsRead, }, @@ -74,6 +74,7 @@ use crate::{ }, obligation_ctxt::ObligationCtxt, }, + upvars::upvars_mentioned, utils::TargetFeatureIsSafeInTarget, }; @@ -896,7 +897,7 @@ where fn coerce_closure_to_fn( &mut self, a: Ty<'db>, - _closure_def_id_a: InternedClosureId, + closure_def_id_a: InternedClosureId, args_a: GenericArgs<'db>, b: Ty<'db>, ) -> CoerceResult<'db> { @@ -904,19 +905,7 @@ where debug_assert!(self.infcx().shallow_resolve(b) == b); match b.kind() { - // FIXME: We need to have an `upvars_mentioned()` query: - // At this point we haven't done capture analysis, which means - // that the ClosureArgs just contains an inference variable instead - // of tuple of captured types. - // - // All we care here is if any variable is being captured and not the exact paths, - // so we check `upvars_mentioned` for root variables being captured. - TyKind::FnPtr(_, hdr) => - // if self - // .db - // .upvars_mentioned(closure_def_id_a.expect_local()) - // .is_none_or(|u| u.is_empty()) => - { + TyKind::FnPtr(_, hdr) if !is_capturing_closure(self.db(), closure_def_id_a) => { // We coerce the closure, which has fn type // `extern "rust-call" fn((arg0,arg1,...)) -> _` // to @@ -1089,14 +1078,12 @@ impl<'db> InferenceContext<'_, 'db> { // Special-case that coercion alone cannot handle: // Function items or non-capturing closures of differing IDs or GenericArgs. let (a_sig, b_sig) = { - let is_capturing_closure = |_ty: Ty<'db>| { - // FIXME: - // if let TyKind::Closure(closure_def_id, _args) = ty.kind() { - // self.db.upvars_mentioned(closure_def_id.expect_local()).is_some() - // } else { - // false - // } - false + let is_capturing_closure = |ty: Ty<'db>| { + if let TyKind::Closure(closure_def_id, _args) = ty.kind() { + is_capturing_closure(self.db, closure_def_id.0) + } else { + false + } }; if is_capturing_closure(prev_ty) || is_capturing_closure(new_ty) { (None, None) @@ -1728,3 +1715,9 @@ fn coerce<'db>( .collect(); Ok((adjustments, ty)) } + +fn is_capturing_closure(db: &dyn HirDatabase, closure: InternedClosureId) -> bool { + let InternedClosure(owner, expr) = closure.loc(db); + upvars_mentioned(db, owner) + .is_some_and(|upvars| upvars.get(&expr).is_some_and(|upvars| !upvars.is_empty())) +} diff --git a/crates/hir-ty/src/lib.rs b/crates/hir-ty/src/lib.rs index 373862229b..41c381220c 100644 --- a/crates/hir-ty/src/lib.rs +++ b/crates/hir-ty/src/lib.rs @@ -50,6 +50,7 @@ pub mod method_resolution; pub mod mir; pub mod primitive; pub mod traits; +pub mod upvars; #[cfg(test)] mod test_db; diff --git a/crates/hir-ty/src/tests/simple.rs b/crates/hir-ty/src/tests/simple.rs index 2e107b2c59..db557b7507 100644 --- a/crates/hir-ty/src/tests/simple.rs +++ b/crates/hir-ty/src/tests/simple.rs @@ -1,5 +1,7 @@ use expect_test::expect; +use crate::tests::check_infer_with_mismatches; + use super::{check, check_infer, check_no_mismatches, check_types}; #[test] @@ -3956,3 +3958,24 @@ fn bar() { "#, ); } + +#[test] +fn cannot_coerce_capturing_closure_to_fn_ptr() { + check_infer_with_mismatches( + r#" +fn foo() { + let a = 1; + let _: fn() -> i32 = || a; +} + "#, + expect![[r#" + 9..58 '{ ...| a; }': () + 19..20 'a': i32 + 23..24 '1': i32 + 34..35 '_': fn() -> i32 + 51..55 '|| a': impl Fn() -> i32 + 54..55 'a': i32 + 51..55: expected fn() -> i32, got impl Fn() -> i32 + "#]], + ); +} diff --git a/crates/hir-ty/src/upvars.rs b/crates/hir-ty/src/upvars.rs new file mode 100644 index 0000000000..ee864ab068 --- /dev/null +++ b/crates/hir-ty/src/upvars.rs @@ -0,0 +1,319 @@ +//! A simple query to collect tall locals (upvars) a closure use. + +use hir_def::{ + DefWithBodyId, + expr_store::{Body, path::Path}, + hir::{BindingId, Expr, ExprId, ExprOrPatId, Pat}, + resolver::{HasResolver, Resolver, ValueNs}, +}; +use hir_expand::mod_path::PathKind; +use rustc_hash::{FxHashMap, FxHashSet}; + +use crate::db::HirDatabase; + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +// Kept sorted. +pub struct Upvars(Box<[BindingId]>); + +impl Upvars { + fn new(upvars: &FxHashSet<BindingId>) -> Upvars { + let mut upvars = upvars.iter().copied().collect::<Box<[_]>>(); + upvars.sort_unstable(); + Upvars(upvars) + } + + #[inline] + pub fn contains(&self, local: BindingId) -> bool { + self.0.binary_search(&local).is_ok() + } + + #[inline] + pub fn iter(&self) -> impl ExactSizeIterator<Item = BindingId> { + self.0.iter().copied() + } + + #[inline] + pub fn is_empty(&self) -> bool { + self.0.is_empty() + } +} + +/// Returns a map from `Expr::Closure` to its upvars. +#[salsa::tracked(returns(as_deref))] +pub fn upvars_mentioned( + db: &dyn HirDatabase, + owner: DefWithBodyId, +) -> Option<Box<FxHashMap<ExprId, Upvars>>> { + let body = db.body(owner); + let mut resolver = owner.resolver(db); + let mut result = FxHashMap::default(); + handle_expr_outside_closure(db, &mut resolver, owner, &body, body.body_expr, &mut result); + return if result.is_empty() { + None + } else { + result.shrink_to_fit(); + Some(Box::new(result)) + }; + + fn handle_expr_outside_closure<'db>( + db: &'db dyn HirDatabase, + resolver: &mut Resolver<'db>, + owner: DefWithBodyId, + body: &Body, + expr: ExprId, + closures_map: &mut FxHashMap<ExprId, Upvars>, + ) { + match &body[expr] { + &Expr::Closure { body: body_expr, .. } => { + let mut upvars = FxHashSet::default(); + handle_expr_inside_closure( + db, + resolver, + owner, + body, + expr, + body_expr, + &mut upvars, + closures_map, + ); + if !upvars.is_empty() { + closures_map.insert(expr, Upvars::new(&upvars)); + } + } + _ => body.walk_child_exprs(expr, |expr| { + handle_expr_outside_closure(db, resolver, owner, body, expr, closures_map) + }), + } + } + + fn handle_expr_inside_closure<'db>( + db: &'db dyn HirDatabase, + resolver: &mut Resolver<'db>, + owner: DefWithBodyId, + body: &Body, + current_closure: ExprId, + expr: ExprId, + upvars: &mut FxHashSet<BindingId>, + closures_map: &mut FxHashMap<ExprId, Upvars>, + ) { + match &body[expr] { + Expr::Path(path) => { + resolve_maybe_upvar( + db, + resolver, + owner, + body, + current_closure, + expr, + expr.into(), + upvars, + path, + ); + } + &Expr::Assignment { target, .. } => { + body.walk_pats(target, &mut |pat| { + let Pat::Path(path) = &body[pat] else { return }; + resolve_maybe_upvar( + db, + resolver, + owner, + body, + current_closure, + expr, + pat.into(), + upvars, + path, + ); + }); + } + &Expr::Closure { body: body_expr, .. } => { + let mut closure_upvars = FxHashSet::default(); + handle_expr_inside_closure( + db, + resolver, + owner, + body, + expr, + body_expr, + &mut closure_upvars, + closures_map, + ); + if !closure_upvars.is_empty() { + closures_map.insert(expr, Upvars::new(&closure_upvars)); + // All nested closure's upvars are also upvars of the parent closure. + upvars.extend( + closure_upvars + .iter() + .copied() + .filter(|local| body.binding_owner(*local) != Some(current_closure)), + ); + } + return; + } + _ => {} + } + body.walk_child_exprs(expr, |expr| { + handle_expr_inside_closure( + db, + resolver, + owner, + body, + current_closure, + expr, + upvars, + closures_map, + ) + }); + } +} + +fn resolve_maybe_upvar<'db>( + db: &'db dyn HirDatabase, + resolver: &mut Resolver<'db>, + owner: DefWithBodyId, + body: &Body, + current_closure: ExprId, + expr: ExprId, + id: ExprOrPatId, + upvars: &mut FxHashSet<BindingId>, + path: &Path, +) { + if let Path::BarePath(mod_path) = path + && matches!(mod_path.kind, PathKind::Plain) + && mod_path.segments().len() == 1 + { + // Could be a variable. + let guard = resolver.update_to_inner_scope(db, owner, expr); + let resolution = + resolver.resolve_path_in_value_ns_fully(db, path, body.expr_or_pat_path_hygiene(id)); + if let Some(ValueNs::LocalBinding(local)) = resolution + && body.binding_owner(local) != Some(current_closure) + { + upvars.insert(local); + } + resolver.reset_to_guard(guard); + } +} + +#[cfg(test)] +mod tests { + use expect_test::{Expect, expect}; + use hir_def::{ModuleDefId, db::DefDatabase, nameres::crate_def_map}; + use itertools::Itertools; + use span::Edition; + use test_fixture::WithFixture; + + use crate::{test_db::TestDB, upvars::upvars_mentioned}; + + #[track_caller] + fn check(#[rust_analyzer::rust_fixture] ra_fixture: &str, expectation: Expect) { + let db = TestDB::with_files(ra_fixture); + crate::attach_db(&db, || { + let def_map = crate_def_map(&db, db.test_crate()); + let func = def_map + .modules() + .flat_map(|(_, module)| module.scope.declarations()) + .filter_map(|decl| match decl { + ModuleDefId::FunctionId(func) => Some(func), + _ => None, + }) + .exactly_one() + .unwrap_or_else(|_| panic!("expected one function")); + let (body, source_map) = db.body_with_source_map(func.into()); + let Some(upvars) = upvars_mentioned(&db, func.into()) else { + expectation.assert_eq(""); + return; + }; + let mut closures = Vec::new(); + for (&closure, upvars) in upvars { + let closure_range = source_map.expr_syntax(closure).unwrap().value.text_range(); + let upvars = upvars + .iter() + .map(|local| body[local].name.display(&db, Edition::CURRENT)) + .join(", "); + closures.push((closure_range, upvars)); + } + closures.sort_unstable_by_key(|(range, _)| (range.start(), range.end())); + let closures = closures + .into_iter() + .map(|(range, upvars)| format!("{range:?}: {upvars}")) + .join("\n"); + expectation.assert_eq(&closures); + }); + } + + #[test] + fn simple() { + check( + r#" +struct foo; +fn foo(param: i32) { + let local = "boo"; + || { param; foo }; + || local; + || { param; local; param; local; }; + || 0xDEAFBEAF; +} + "#, + expect![[r#" + 60..77: param + 83..91: local + 97..131: param, local"#]], + ); + } + + #[test] + fn nested() { + check( + r#" +fn foo() { + let (a, b); + || { + || a; + || b; + }; +} + "#, + expect![[r#" + 31..69: a, b + 44..48: a + 58..62: b"#]], + ); + } + + #[test] + fn closure_var() { + check( + r#" +fn foo() { + let upvar = 1; + |closure_param: i32| { + let closure_local = closure_param; + closure_local + upvar + }; +} + "#, + expect!["34..135: upvar"], + ); + } + + #[test] + fn closure_var_nested() { + check( + r#" +fn foo() { + let a = 1; + |b: i32| { + || { + let c = 123; + a + b + c + } + }; +} + "#, + expect![[r#" + 30..116: a + 49..110: a, b"#]], + ); + } +} |