Unnamed repository; edit this file 'description' to name the repository.
Diffstat (limited to 'crates/salsa/src/derived/slot.rs')
-rw-r--r--crates/salsa/src/derived/slot.rs185
1 files changed, 56 insertions, 129 deletions
diff --git a/crates/salsa/src/derived/slot.rs b/crates/salsa/src/derived/slot.rs
index cfafa40ce3..de7a397607 100644
--- a/crates/salsa/src/derived/slot.rs
+++ b/crates/salsa/src/derived/slot.rs
@@ -1,12 +1,8 @@
use crate::debug::TableEntry;
-use crate::derived::MemoizationPolicy;
use crate::durability::Durability;
-use crate::lru::LruIndex;
-use crate::lru::LruNode;
use crate::plumbing::{DatabaseOps, QueryFunction};
use crate::revision::Revision;
use crate::runtime::local_state::ActiveQueryGuard;
-use crate::runtime::local_state::QueryInputs;
use crate::runtime::local_state::QueryRevisions;
use crate::runtime::Runtime;
use crate::runtime::RuntimeId;
@@ -15,21 +11,18 @@ use crate::runtime::WaitResult;
use crate::Cycle;
use crate::{Database, DatabaseKeyIndex, Event, EventKind, QueryDb};
use parking_lot::{RawRwLock, RwLock};
-use std::marker::PhantomData;
use std::ops::Deref;
use std::sync::atomic::{AtomicBool, Ordering};
use tracing::{debug, info};
-pub(super) struct Slot<Q, MP>
+pub(super) struct Slot<Q>
where
Q: QueryFunction,
- MP: MemoizationPolicy<Q>,
{
key_index: u32,
+ // FIXME: Yeet this
group_index: u16,
state: RwLock<QueryState<Q>>,
- policy: PhantomData<MP>,
- lru_index: LruIndex,
}
/// Defines the "current state" of query's memoized results.
@@ -55,7 +48,7 @@ where
struct Memo<V> {
/// The result of the query, if we decide to memoize it.
- value: Option<V>,
+ value: V,
/// Last revision when this memo was verified; this begins
/// as the current revision.
@@ -78,12 +71,6 @@ enum ProbeState<V, G> {
/// verified in this revision.
Stale(G),
- /// There is an entry, and it has been verified
- /// in this revision, but it has no cached
- /// value. The `Revision` is the revision where the
- /// value last changed (if we were to recompute it).
- NoValue(G, Revision),
-
/// There is an entry which has been verified,
/// and it has the following value-- or, we blocked
/// on another thread, and that resulted in a cycle.
@@ -104,18 +91,16 @@ enum MaybeChangedSinceProbeState<G> {
Stale(G),
}
-impl<Q, MP> Slot<Q, MP>
+impl<Q> Slot<Q>
where
Q: QueryFunction,
- MP: MemoizationPolicy<Q>,
+ Q::Value: Eq,
{
pub(super) fn new(database_key_index: DatabaseKeyIndex) -> Self {
Self {
key_index: database_key_index.key_index,
group_index: database_key_index.group_index,
state: RwLock::new(QueryState::NotComputed),
- lru_index: LruIndex::default(),
- policy: PhantomData,
}
}
@@ -147,9 +132,7 @@ where
loop {
match self.probe(db, self.state.read(), runtime, revision_now) {
ProbeState::UpToDate(v) => return v,
- ProbeState::Stale(..) | ProbeState::NoValue(..) | ProbeState::NotComputed(..) => {
- break
- }
+ ProbeState::Stale(..) | ProbeState::NotComputed(..) => break,
ProbeState::Retry => continue,
}
}
@@ -177,9 +160,7 @@ where
let mut old_memo = loop {
match self.probe(db, self.state.upgradable_read(), runtime, revision_now) {
ProbeState::UpToDate(v) => return v,
- ProbeState::Stale(state)
- | ProbeState::NotComputed(state)
- | ProbeState::NoValue(state, _) => {
+ ProbeState::Stale(state) | ProbeState::NotComputed(state) => {
type RwLockUpgradableReadGuard<'a, T> =
lock_api::RwLockUpgradableReadGuard<'a, RawRwLock, T>;
@@ -227,7 +208,7 @@ where
runtime: &Runtime,
revision_now: Revision,
active_query: ActiveQueryGuard<'_>,
- panic_guard: PanicGuard<'_, Q, MP>,
+ panic_guard: PanicGuard<'_, Q>,
old_memo: Option<Memo<Q::Value>>,
key: &Q::Key,
) -> StampedValue<Q::Value> {
@@ -286,22 +267,18 @@ where
// "backdate" its `changed_at` revision to be the same as the
// old value.
if let Some(old_memo) = &old_memo {
- if let Some(old_value) = &old_memo.value {
- // Careful: if the value became less durable than it
- // used to be, that is a "breaking change" that our
- // consumers must be aware of. Becoming *more* durable
- // is not. See the test `constant_to_non_constant`.
- if revisions.durability >= old_memo.revisions.durability
- && MP::memoized_value_eq(old_value, &value)
- {
- debug!(
- "read_upgrade({:?}): value is equal, back-dating to {:?}",
- self, old_memo.revisions.changed_at,
- );
-
- assert!(old_memo.revisions.changed_at <= revisions.changed_at);
- revisions.changed_at = old_memo.revisions.changed_at;
- }
+ // Careful: if the value became less durable than it
+ // used to be, that is a "breaking change" that our
+ // consumers must be aware of. Becoming *more* durable
+ // is not. See the test `constant_to_non_constant`.
+ if revisions.durability >= old_memo.revisions.durability && old_memo.value == value {
+ debug!(
+ "read_upgrade({:?}): value is equal, back-dating to {:?}",
+ self, old_memo.revisions.changed_at,
+ );
+
+ assert!(old_memo.revisions.changed_at <= revisions.changed_at);
+ revisions.changed_at = old_memo.revisions.changed_at;
}
}
@@ -311,8 +288,7 @@ where
changed_at: revisions.changed_at,
};
- let memo_value =
- if self.should_memoize_value(key) { Some(new_value.value.clone()) } else { None };
+ let memo_value = new_value.value.clone();
debug!("read_upgrade({:?}): result.revisions = {:#?}", self, revisions,);
@@ -372,20 +348,16 @@ where
return ProbeState::Stale(state);
}
- if let Some(value) = &memo.value {
- let value = StampedValue {
- durability: memo.revisions.durability,
- changed_at: memo.revisions.changed_at,
- value: value.clone(),
- };
+ let value = &memo.value;
+ let value = StampedValue {
+ durability: memo.revisions.durability,
+ changed_at: memo.revisions.changed_at,
+ value: value.clone(),
+ };
- info!("{:?}: returning memoized value changed at {:?}", self, value.changed_at);
+ info!("{:?}: returning memoized value changed at {:?}", self, value.changed_at);
- ProbeState::UpToDate(value)
- } else {
- let changed_at = memo.revisions.changed_at;
- ProbeState::NoValue(state, changed_at)
- }
+ ProbeState::UpToDate(value)
}
}
}
@@ -408,21 +380,9 @@ where
match &*self.state.read() {
QueryState::NotComputed => None,
QueryState::InProgress { .. } => Some(TableEntry::new(key.clone(), None)),
- QueryState::Memoized(memo) => Some(TableEntry::new(key.clone(), memo.value.clone())),
- }
- }
-
- pub(super) fn evict(&self) {
- let mut state = self.state.write();
- if let QueryState::Memoized(memo) = &mut *state {
- // Evicting a value with an untracked input could
- // lead to inconsistencies. Note that we can't check
- // `has_untracked_input` when we add the value to the cache,
- // because inputs can become untracked in the next revision.
- if memo.has_untracked_input() {
- return;
+ QueryState::Memoized(memo) => {
+ Some(TableEntry::new(key.clone(), Some(memo.value.clone())))
}
- memo.value = None;
}
}
@@ -430,7 +390,8 @@ where
tracing::debug!("Slot::invalidate(new_revision = {:?})", new_revision);
match &mut *self.state.write() {
QueryState::Memoized(memo) => {
- memo.revisions.inputs = QueryInputs::Untracked;
+ memo.revisions.untracked = true;
+ memo.revisions.inputs = None;
memo.revisions.changed_at = new_revision;
Some(memo.revisions.durability)
}
@@ -489,8 +450,7 @@ where
// If we know when value last changed, we can return right away.
// Note that we don't need the actual value to be available.
- ProbeState::NoValue(_, changed_at)
- | ProbeState::UpToDate(StampedValue { value: _, durability: _, changed_at }) => {
+ ProbeState::UpToDate(StampedValue { value: _, durability: _, changed_at }) => {
MaybeChangedSinceProbeState::ChangedAt(changed_at)
}
@@ -545,7 +505,7 @@ where
let maybe_changed = old_memo.revisions.changed_at > revision;
panic_guard.proceed(Some(old_memo));
maybe_changed
- } else if old_memo.value.is_some() {
+ } else {
// We found that this memoized value may have changed
// but we have an old value. We can re-run the code and
// actually *check* if it has changed.
@@ -559,12 +519,6 @@ where
key,
);
changed_at > revision
- } else {
- // We found that inputs to this memoized value may have chanced
- // but we don't have an old value to compare against or re-use.
- // No choice but to drop the memo and say that its value may have changed.
- panic_guard.proceed(None);
- true
}
}
@@ -583,10 +537,6 @@ where
mutex_guard,
)
}
-
- fn should_memoize_value(&self, key: &Q::Key) -> bool {
- MP::should_memoize_value(key)
- }
}
impl<Q> QueryState<Q>
@@ -598,21 +548,21 @@ where
}
}
-struct PanicGuard<'me, Q, MP>
+struct PanicGuard<'me, Q>
where
Q: QueryFunction,
- MP: MemoizationPolicy<Q>,
+ Q::Value: Eq,
{
- slot: &'me Slot<Q, MP>,
+ slot: &'me Slot<Q>,
runtime: &'me Runtime,
}
-impl<'me, Q, MP> PanicGuard<'me, Q, MP>
+impl<'me, Q> PanicGuard<'me, Q>
where
Q: QueryFunction,
- MP: MemoizationPolicy<Q>,
+ Q::Value: Eq,
{
- fn new(slot: &'me Slot<Q, MP>, runtime: &'me Runtime) -> Self {
+ fn new(slot: &'me Slot<Q>, runtime: &'me Runtime) -> Self {
Self { slot, runtime }
}
@@ -666,10 +616,10 @@ Please report this bug to https://github.com/salsa-rs/salsa/issues."
}
}
-impl<'me, Q, MP> Drop for PanicGuard<'me, Q, MP>
+impl<'me, Q> Drop for PanicGuard<'me, Q>
where
Q: QueryFunction,
- MP: MemoizationPolicy<Q>,
+ Q::Value: Eq,
{
fn drop(&mut self) {
if std::thread::panicking() {
@@ -702,15 +652,11 @@ where
revision_now: Revision,
active_query: &ActiveQueryGuard<'_>,
) -> Option<StampedValue<V>> {
- // If we don't have a memoized value, nothing to validate.
- if self.value.is_none() {
- return None;
- }
if self.verify_revisions(db, revision_now, active_query) {
- self.value.clone().map(|value| StampedValue {
+ Some(StampedValue {
durability: self.revisions.durability,
changed_at: self.revisions.changed_at,
- value,
+ value: self.value.clone(),
})
} else {
None
@@ -746,11 +692,8 @@ where
match &self.revisions.inputs {
// We can't validate values that had untracked inputs; just have to
// re-execute.
- QueryInputs::Untracked => {
- return false;
- }
-
- QueryInputs::NoInputs => {}
+ None if self.revisions.untracked => return false,
+ None => {}
// Check whether any of our inputs changed since the
// **last point where we were verified** (not since we
@@ -761,7 +704,7 @@ where
// R1. But our *verification* date will be R2, and we
// are only interested in finding out whether the
// input changed *again*.
- QueryInputs::Tracked { inputs } => {
+ Some(inputs) => {
let changed_input =
inputs.slice.iter().find(|&&input| db.maybe_changed_after(input, verified_at));
if let Some(input) = changed_input {
@@ -791,58 +734,42 @@ where
self.verified_at = revision_now;
true
}
-
- fn has_untracked_input(&self) -> bool {
- matches!(self.revisions.inputs, QueryInputs::Untracked)
- }
}
-impl<Q, MP> std::fmt::Debug for Slot<Q, MP>
+impl<Q> std::fmt::Debug for Slot<Q>
where
Q: QueryFunction,
- MP: MemoizationPolicy<Q>,
{
fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(fmt, "{:?}", Q::default())
}
}
-impl<Q, MP> LruNode for Slot<Q, MP>
-where
- Q: QueryFunction,
- MP: MemoizationPolicy<Q>,
-{
- fn lru_index(&self) -> &LruIndex {
- &self.lru_index
- }
-}
-
-/// Check that `Slot<Q, MP>: Send + Sync` as long as
+/// Check that `Slot<Q, >: Send + Sync` as long as
/// `DB::DatabaseData: Send + Sync`, which in turn implies that
/// `Q::Key: Send + Sync`, `Q::Value: Send + Sync`.
#[allow(dead_code)]
-fn check_send_sync<Q, MP>()
+fn check_send_sync<Q>()
where
Q: QueryFunction,
- MP: MemoizationPolicy<Q>,
+
Q::Key: Send + Sync,
Q::Value: Send + Sync,
{
fn is_send_sync<T: Send + Sync>() {}
- is_send_sync::<Slot<Q, MP>>();
+ is_send_sync::<Slot<Q>>();
}
-/// Check that `Slot<Q, MP>: 'static` as long as
+/// Check that `Slot<Q, >: 'static` as long as
/// `DB::DatabaseData: 'static`, which in turn implies that
/// `Q::Key: 'static`, `Q::Value: 'static`.
#[allow(dead_code)]
-fn check_static<Q, MP>()
+fn check_static<Q>()
where
Q: QueryFunction + 'static,
- MP: MemoizationPolicy<Q> + 'static,
Q::Key: 'static,
Q::Value: 'static,
{
fn is_static<T: 'static>() {}
- is_static::<Slot<Q, MP>>();
+ is_static::<Slot<Q>>();
}