Unnamed repository; edit this file 'description' to name the repository.
Diffstat (limited to 'crates/salsa/src/derived/slot.rs')
| -rw-r--r-- | crates/salsa/src/derived/slot.rs | 185 |
1 files changed, 56 insertions, 129 deletions
diff --git a/crates/salsa/src/derived/slot.rs b/crates/salsa/src/derived/slot.rs index cfafa40ce3..de7a397607 100644 --- a/crates/salsa/src/derived/slot.rs +++ b/crates/salsa/src/derived/slot.rs @@ -1,12 +1,8 @@ use crate::debug::TableEntry; -use crate::derived::MemoizationPolicy; use crate::durability::Durability; -use crate::lru::LruIndex; -use crate::lru::LruNode; use crate::plumbing::{DatabaseOps, QueryFunction}; use crate::revision::Revision; use crate::runtime::local_state::ActiveQueryGuard; -use crate::runtime::local_state::QueryInputs; use crate::runtime::local_state::QueryRevisions; use crate::runtime::Runtime; use crate::runtime::RuntimeId; @@ -15,21 +11,18 @@ use crate::runtime::WaitResult; use crate::Cycle; use crate::{Database, DatabaseKeyIndex, Event, EventKind, QueryDb}; use parking_lot::{RawRwLock, RwLock}; -use std::marker::PhantomData; use std::ops::Deref; use std::sync::atomic::{AtomicBool, Ordering}; use tracing::{debug, info}; -pub(super) struct Slot<Q, MP> +pub(super) struct Slot<Q> where Q: QueryFunction, - MP: MemoizationPolicy<Q>, { key_index: u32, + // FIXME: Yeet this group_index: u16, state: RwLock<QueryState<Q>>, - policy: PhantomData<MP>, - lru_index: LruIndex, } /// Defines the "current state" of query's memoized results. @@ -55,7 +48,7 @@ where struct Memo<V> { /// The result of the query, if we decide to memoize it. - value: Option<V>, + value: V, /// Last revision when this memo was verified; this begins /// as the current revision. @@ -78,12 +71,6 @@ enum ProbeState<V, G> { /// verified in this revision. Stale(G), - /// There is an entry, and it has been verified - /// in this revision, but it has no cached - /// value. The `Revision` is the revision where the - /// value last changed (if we were to recompute it). - NoValue(G, Revision), - /// There is an entry which has been verified, /// and it has the following value-- or, we blocked /// on another thread, and that resulted in a cycle. @@ -104,18 +91,16 @@ enum MaybeChangedSinceProbeState<G> { Stale(G), } -impl<Q, MP> Slot<Q, MP> +impl<Q> Slot<Q> where Q: QueryFunction, - MP: MemoizationPolicy<Q>, + Q::Value: Eq, { pub(super) fn new(database_key_index: DatabaseKeyIndex) -> Self { Self { key_index: database_key_index.key_index, group_index: database_key_index.group_index, state: RwLock::new(QueryState::NotComputed), - lru_index: LruIndex::default(), - policy: PhantomData, } } @@ -147,9 +132,7 @@ where loop { match self.probe(db, self.state.read(), runtime, revision_now) { ProbeState::UpToDate(v) => return v, - ProbeState::Stale(..) | ProbeState::NoValue(..) | ProbeState::NotComputed(..) => { - break - } + ProbeState::Stale(..) | ProbeState::NotComputed(..) => break, ProbeState::Retry => continue, } } @@ -177,9 +160,7 @@ where let mut old_memo = loop { match self.probe(db, self.state.upgradable_read(), runtime, revision_now) { ProbeState::UpToDate(v) => return v, - ProbeState::Stale(state) - | ProbeState::NotComputed(state) - | ProbeState::NoValue(state, _) => { + ProbeState::Stale(state) | ProbeState::NotComputed(state) => { type RwLockUpgradableReadGuard<'a, T> = lock_api::RwLockUpgradableReadGuard<'a, RawRwLock, T>; @@ -227,7 +208,7 @@ where runtime: &Runtime, revision_now: Revision, active_query: ActiveQueryGuard<'_>, - panic_guard: PanicGuard<'_, Q, MP>, + panic_guard: PanicGuard<'_, Q>, old_memo: Option<Memo<Q::Value>>, key: &Q::Key, ) -> StampedValue<Q::Value> { @@ -286,22 +267,18 @@ where // "backdate" its `changed_at` revision to be the same as the // old value. if let Some(old_memo) = &old_memo { - if let Some(old_value) = &old_memo.value { - // Careful: if the value became less durable than it - // used to be, that is a "breaking change" that our - // consumers must be aware of. Becoming *more* durable - // is not. See the test `constant_to_non_constant`. - if revisions.durability >= old_memo.revisions.durability - && MP::memoized_value_eq(old_value, &value) - { - debug!( - "read_upgrade({:?}): value is equal, back-dating to {:?}", - self, old_memo.revisions.changed_at, - ); - - assert!(old_memo.revisions.changed_at <= revisions.changed_at); - revisions.changed_at = old_memo.revisions.changed_at; - } + // Careful: if the value became less durable than it + // used to be, that is a "breaking change" that our + // consumers must be aware of. Becoming *more* durable + // is not. See the test `constant_to_non_constant`. + if revisions.durability >= old_memo.revisions.durability && old_memo.value == value { + debug!( + "read_upgrade({:?}): value is equal, back-dating to {:?}", + self, old_memo.revisions.changed_at, + ); + + assert!(old_memo.revisions.changed_at <= revisions.changed_at); + revisions.changed_at = old_memo.revisions.changed_at; } } @@ -311,8 +288,7 @@ where changed_at: revisions.changed_at, }; - let memo_value = - if self.should_memoize_value(key) { Some(new_value.value.clone()) } else { None }; + let memo_value = new_value.value.clone(); debug!("read_upgrade({:?}): result.revisions = {:#?}", self, revisions,); @@ -372,20 +348,16 @@ where return ProbeState::Stale(state); } - if let Some(value) = &memo.value { - let value = StampedValue { - durability: memo.revisions.durability, - changed_at: memo.revisions.changed_at, - value: value.clone(), - }; + let value = &memo.value; + let value = StampedValue { + durability: memo.revisions.durability, + changed_at: memo.revisions.changed_at, + value: value.clone(), + }; - info!("{:?}: returning memoized value changed at {:?}", self, value.changed_at); + info!("{:?}: returning memoized value changed at {:?}", self, value.changed_at); - ProbeState::UpToDate(value) - } else { - let changed_at = memo.revisions.changed_at; - ProbeState::NoValue(state, changed_at) - } + ProbeState::UpToDate(value) } } } @@ -408,21 +380,9 @@ where match &*self.state.read() { QueryState::NotComputed => None, QueryState::InProgress { .. } => Some(TableEntry::new(key.clone(), None)), - QueryState::Memoized(memo) => Some(TableEntry::new(key.clone(), memo.value.clone())), - } - } - - pub(super) fn evict(&self) { - let mut state = self.state.write(); - if let QueryState::Memoized(memo) = &mut *state { - // Evicting a value with an untracked input could - // lead to inconsistencies. Note that we can't check - // `has_untracked_input` when we add the value to the cache, - // because inputs can become untracked in the next revision. - if memo.has_untracked_input() { - return; + QueryState::Memoized(memo) => { + Some(TableEntry::new(key.clone(), Some(memo.value.clone()))) } - memo.value = None; } } @@ -430,7 +390,8 @@ where tracing::debug!("Slot::invalidate(new_revision = {:?})", new_revision); match &mut *self.state.write() { QueryState::Memoized(memo) => { - memo.revisions.inputs = QueryInputs::Untracked; + memo.revisions.untracked = true; + memo.revisions.inputs = None; memo.revisions.changed_at = new_revision; Some(memo.revisions.durability) } @@ -489,8 +450,7 @@ where // If we know when value last changed, we can return right away. // Note that we don't need the actual value to be available. - ProbeState::NoValue(_, changed_at) - | ProbeState::UpToDate(StampedValue { value: _, durability: _, changed_at }) => { + ProbeState::UpToDate(StampedValue { value: _, durability: _, changed_at }) => { MaybeChangedSinceProbeState::ChangedAt(changed_at) } @@ -545,7 +505,7 @@ where let maybe_changed = old_memo.revisions.changed_at > revision; panic_guard.proceed(Some(old_memo)); maybe_changed - } else if old_memo.value.is_some() { + } else { // We found that this memoized value may have changed // but we have an old value. We can re-run the code and // actually *check* if it has changed. @@ -559,12 +519,6 @@ where key, ); changed_at > revision - } else { - // We found that inputs to this memoized value may have chanced - // but we don't have an old value to compare against or re-use. - // No choice but to drop the memo and say that its value may have changed. - panic_guard.proceed(None); - true } } @@ -583,10 +537,6 @@ where mutex_guard, ) } - - fn should_memoize_value(&self, key: &Q::Key) -> bool { - MP::should_memoize_value(key) - } } impl<Q> QueryState<Q> @@ -598,21 +548,21 @@ where } } -struct PanicGuard<'me, Q, MP> +struct PanicGuard<'me, Q> where Q: QueryFunction, - MP: MemoizationPolicy<Q>, + Q::Value: Eq, { - slot: &'me Slot<Q, MP>, + slot: &'me Slot<Q>, runtime: &'me Runtime, } -impl<'me, Q, MP> PanicGuard<'me, Q, MP> +impl<'me, Q> PanicGuard<'me, Q> where Q: QueryFunction, - MP: MemoizationPolicy<Q>, + Q::Value: Eq, { - fn new(slot: &'me Slot<Q, MP>, runtime: &'me Runtime) -> Self { + fn new(slot: &'me Slot<Q>, runtime: &'me Runtime) -> Self { Self { slot, runtime } } @@ -666,10 +616,10 @@ Please report this bug to https://github.com/salsa-rs/salsa/issues." } } -impl<'me, Q, MP> Drop for PanicGuard<'me, Q, MP> +impl<'me, Q> Drop for PanicGuard<'me, Q> where Q: QueryFunction, - MP: MemoizationPolicy<Q>, + Q::Value: Eq, { fn drop(&mut self) { if std::thread::panicking() { @@ -702,15 +652,11 @@ where revision_now: Revision, active_query: &ActiveQueryGuard<'_>, ) -> Option<StampedValue<V>> { - // If we don't have a memoized value, nothing to validate. - if self.value.is_none() { - return None; - } if self.verify_revisions(db, revision_now, active_query) { - self.value.clone().map(|value| StampedValue { + Some(StampedValue { durability: self.revisions.durability, changed_at: self.revisions.changed_at, - value, + value: self.value.clone(), }) } else { None @@ -746,11 +692,8 @@ where match &self.revisions.inputs { // We can't validate values that had untracked inputs; just have to // re-execute. - QueryInputs::Untracked => { - return false; - } - - QueryInputs::NoInputs => {} + None if self.revisions.untracked => return false, + None => {} // Check whether any of our inputs changed since the // **last point where we were verified** (not since we @@ -761,7 +704,7 @@ where // R1. But our *verification* date will be R2, and we // are only interested in finding out whether the // input changed *again*. - QueryInputs::Tracked { inputs } => { + Some(inputs) => { let changed_input = inputs.slice.iter().find(|&&input| db.maybe_changed_after(input, verified_at)); if let Some(input) = changed_input { @@ -791,58 +734,42 @@ where self.verified_at = revision_now; true } - - fn has_untracked_input(&self) -> bool { - matches!(self.revisions.inputs, QueryInputs::Untracked) - } } -impl<Q, MP> std::fmt::Debug for Slot<Q, MP> +impl<Q> std::fmt::Debug for Slot<Q> where Q: QueryFunction, - MP: MemoizationPolicy<Q>, { fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(fmt, "{:?}", Q::default()) } } -impl<Q, MP> LruNode for Slot<Q, MP> -where - Q: QueryFunction, - MP: MemoizationPolicy<Q>, -{ - fn lru_index(&self) -> &LruIndex { - &self.lru_index - } -} - -/// Check that `Slot<Q, MP>: Send + Sync` as long as +/// Check that `Slot<Q, >: Send + Sync` as long as /// `DB::DatabaseData: Send + Sync`, which in turn implies that /// `Q::Key: Send + Sync`, `Q::Value: Send + Sync`. #[allow(dead_code)] -fn check_send_sync<Q, MP>() +fn check_send_sync<Q>() where Q: QueryFunction, - MP: MemoizationPolicy<Q>, + Q::Key: Send + Sync, Q::Value: Send + Sync, { fn is_send_sync<T: Send + Sync>() {} - is_send_sync::<Slot<Q, MP>>(); + is_send_sync::<Slot<Q>>(); } -/// Check that `Slot<Q, MP>: 'static` as long as +/// Check that `Slot<Q, >: 'static` as long as /// `DB::DatabaseData: 'static`, which in turn implies that /// `Q::Key: 'static`, `Q::Value: 'static`. #[allow(dead_code)] -fn check_static<Q, MP>() +fn check_static<Q>() where Q: QueryFunction + 'static, - MP: MemoizationPolicy<Q> + 'static, Q::Key: 'static, Q::Value: 'static, { fn is_static<T: 'static>() {} - is_static::<Slot<Q, MP>>(); + is_static::<Slot<Q>>(); } |