Unnamed repository; edit this file 'description' to name the repository.
Diffstat (limited to 'crates/ra-salsa/tests/parallel/setup.rs')
| -rw-r--r-- | crates/ra-salsa/tests/parallel/setup.rs | 197 |
1 files changed, 197 insertions, 0 deletions
diff --git a/crates/ra-salsa/tests/parallel/setup.rs b/crates/ra-salsa/tests/parallel/setup.rs new file mode 100644 index 0000000000..fd1f51326e --- /dev/null +++ b/crates/ra-salsa/tests/parallel/setup.rs @@ -0,0 +1,197 @@ +use crate::signal::Signal; +use ra_salsa::Database; +use ra_salsa::ParallelDatabase; +use ra_salsa::Snapshot; +use std::sync::Arc; +use std::{ + cell::Cell, + panic::{catch_unwind, resume_unwind, AssertUnwindSafe}, +}; + +#[ra_salsa::query_group(Par)] +pub(crate) trait ParDatabase: Knobs { + #[ra_salsa::input] + fn input(&self, key: char) -> usize; + + fn sum(&self, key: &'static str) -> usize; + + /// Invokes `sum` + fn sum2(&self, key: &'static str) -> usize; + + /// Invokes `sum` but doesn't really care about the result. + fn sum2_drop_sum(&self, key: &'static str) -> usize; + + /// Invokes `sum2` + fn sum3(&self, key: &'static str) -> usize; + + /// Invokes `sum2_drop_sum` + fn sum3_drop_sum(&self, key: &'static str) -> usize; +} + +/// Various "knobs" and utilities used by tests to force +/// a certain behavior. +pub(crate) trait Knobs { + fn knobs(&self) -> &KnobsStruct; + + fn signal(&self, stage: usize); + + fn wait_for(&self, stage: usize); +} + +pub(crate) trait WithValue<T> { + fn with_value<R>(&self, value: T, closure: impl FnOnce() -> R) -> R; +} + +impl<T> WithValue<T> for Cell<T> { + fn with_value<R>(&self, value: T, closure: impl FnOnce() -> R) -> R { + let old_value = self.replace(value); + + let result = catch_unwind(AssertUnwindSafe(closure)); + + self.set(old_value); + + match result { + Ok(r) => r, + Err(payload) => resume_unwind(payload), + } + } +} + +#[derive(Default, Clone, Copy, PartialEq, Eq)] +pub(crate) enum CancellationFlag { + #[default] + Down, + Panic, +} + +/// Various "knobs" that can be used to customize how the queries +/// behave on one specific thread. Note that this state is +/// intentionally thread-local (apart from `signal`). +#[derive(Clone, Default)] +pub(crate) struct KnobsStruct { + /// A kind of flexible barrier used to coordinate execution across + /// threads to ensure we reach various weird states. + pub(crate) signal: Arc<Signal>, + + /// When this database is about to block, send a signal. + pub(crate) signal_on_will_block: Cell<usize>, + + /// Invocations of `sum` will signal this stage on entry. + pub(crate) sum_signal_on_entry: Cell<usize>, + + /// Invocations of `sum` will wait for this stage on entry. + pub(crate) sum_wait_for_on_entry: Cell<usize>, + + /// If true, invocations of `sum` will panic before they exit. + pub(crate) sum_should_panic: Cell<bool>, + + /// If true, invocations of `sum` will wait for cancellation before + /// they exit. + pub(crate) sum_wait_for_cancellation: Cell<CancellationFlag>, + + /// Invocations of `sum` will wait for this stage prior to exiting. + pub(crate) sum_wait_for_on_exit: Cell<usize>, + + /// Invocations of `sum` will signal this stage prior to exiting. + pub(crate) sum_signal_on_exit: Cell<usize>, + + /// Invocations of `sum3_drop_sum` will panic unconditionally + pub(crate) sum3_drop_sum_should_panic: Cell<bool>, +} + +fn sum(db: &dyn ParDatabase, key: &'static str) -> usize { + let mut sum = 0; + + db.signal(db.knobs().sum_signal_on_entry.get()); + + db.wait_for(db.knobs().sum_wait_for_on_entry.get()); + + if db.knobs().sum_should_panic.get() { + panic!("query set to panic before exit") + } + + for ch in key.chars() { + sum += db.input(ch); + } + + match db.knobs().sum_wait_for_cancellation.get() { + CancellationFlag::Down => (), + CancellationFlag::Panic => { + tracing::debug!("waiting for cancellation"); + loop { + db.unwind_if_cancelled(); + std::thread::yield_now(); + } + } + } + + db.wait_for(db.knobs().sum_wait_for_on_exit.get()); + + db.signal(db.knobs().sum_signal_on_exit.get()); + + sum +} + +fn sum2(db: &dyn ParDatabase, key: &'static str) -> usize { + db.sum(key) +} + +fn sum2_drop_sum(db: &dyn ParDatabase, key: &'static str) -> usize { + let _ = db.sum(key); + 22 +} + +fn sum3(db: &dyn ParDatabase, key: &'static str) -> usize { + db.sum2(key) +} + +fn sum3_drop_sum(db: &dyn ParDatabase, key: &'static str) -> usize { + if db.knobs().sum3_drop_sum_should_panic.get() { + panic!("sum3_drop_sum executed") + } + db.sum2_drop_sum(key) +} + +#[ra_salsa::database( + Par, + crate::parallel_cycle_all_recover::ParallelCycleAllRecover, + crate::parallel_cycle_none_recover::ParallelCycleNoneRecover, + crate::parallel_cycle_mid_recover::ParallelCycleMidRecovers, + crate::parallel_cycle_one_recovers::ParallelCycleOneRecovers +)] +#[derive(Default)] +pub(crate) struct ParDatabaseImpl { + storage: ra_salsa::Storage<Self>, + knobs: KnobsStruct, +} + +impl Database for ParDatabaseImpl { + fn salsa_event(&self, event: ra_salsa::Event) { + if let ra_salsa::EventKind::WillBlockOn { .. } = event.kind { + self.signal(self.knobs().signal_on_will_block.get()); + } + } +} + +impl ParallelDatabase for ParDatabaseImpl { + fn snapshot(&self) -> Snapshot<Self> { + Snapshot::new(ParDatabaseImpl { + storage: self.storage.snapshot(), + knobs: self.knobs.clone(), + }) + } +} + +impl Knobs for ParDatabaseImpl { + fn knobs(&self) -> &KnobsStruct { + &self.knobs + } + + fn signal(&self, stage: usize) { + self.knobs.signal.signal(stage); + } + + fn wait_for(&self, stage: usize) { + self.knobs.signal.wait_for(stage); + } +} |