Unnamed repository; edit this file 'description' to name the repository.
Diffstat (limited to 'crates/salsa/tests/parallel/stress.rs')
-rw-r--r--crates/salsa/tests/parallel/stress.rs168
1 files changed, 168 insertions, 0 deletions
diff --git a/crates/salsa/tests/parallel/stress.rs b/crates/salsa/tests/parallel/stress.rs
new file mode 100644
index 0000000000..2fa317b2b9
--- /dev/null
+++ b/crates/salsa/tests/parallel/stress.rs
@@ -0,0 +1,168 @@
+use rand::seq::SliceRandom;
+use rand::Rng;
+
+use salsa::ParallelDatabase;
+use salsa::Snapshot;
+use salsa::{Cancelled, Database};
+
+// Number of operations a reader performs
+const N_MUTATOR_OPS: usize = 100;
+const N_READER_OPS: usize = 100;
+
+#[salsa::query_group(Stress)]
+trait StressDatabase: salsa::Database {
+ #[salsa::input]
+ fn a(&self, key: usize) -> usize;
+
+ fn b(&self, key: usize) -> usize;
+
+ fn c(&self, key: usize) -> usize;
+}
+
+fn b(db: &dyn StressDatabase, key: usize) -> usize {
+ db.unwind_if_cancelled();
+ db.a(key)
+}
+
+fn c(db: &dyn StressDatabase, key: usize) -> usize {
+ db.b(key)
+}
+
+#[salsa::database(Stress)]
+#[derive(Default)]
+struct StressDatabaseImpl {
+ storage: salsa::Storage<Self>,
+}
+
+impl salsa::Database for StressDatabaseImpl {}
+
+impl salsa::ParallelDatabase for StressDatabaseImpl {
+ fn snapshot(&self) -> Snapshot<StressDatabaseImpl> {
+ Snapshot::new(StressDatabaseImpl { storage: self.storage.snapshot() })
+ }
+}
+
+#[derive(Clone, Copy, Debug)]
+enum Query {
+ A,
+ B,
+ C,
+}
+
+enum MutatorOp {
+ WriteOp(WriteOp),
+ LaunchReader { ops: Vec<ReadOp>, check_cancellation: bool },
+}
+
+#[derive(Debug)]
+enum WriteOp {
+ SetA(usize, usize),
+}
+
+#[derive(Debug)]
+enum ReadOp {
+ Get(Query, usize),
+}
+
+impl rand::distributions::Distribution<Query> for rand::distributions::Standard {
+ fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> Query {
+ *[Query::A, Query::B, Query::C].choose(rng).unwrap()
+ }
+}
+
+impl rand::distributions::Distribution<MutatorOp> for rand::distributions::Standard {
+ fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> MutatorOp {
+ if rng.gen_bool(0.5) {
+ MutatorOp::WriteOp(rng.gen())
+ } else {
+ MutatorOp::LaunchReader {
+ ops: (0..N_READER_OPS).map(|_| rng.gen()).collect(),
+ check_cancellation: rng.gen(),
+ }
+ }
+ }
+}
+
+impl rand::distributions::Distribution<WriteOp> for rand::distributions::Standard {
+ fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> WriteOp {
+ let key = rng.gen::<usize>() % 10;
+ let value = rng.gen::<usize>() % 10;
+ WriteOp::SetA(key, value)
+ }
+}
+
+impl rand::distributions::Distribution<ReadOp> for rand::distributions::Standard {
+ fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> ReadOp {
+ let query = rng.gen::<Query>();
+ let key = rng.gen::<usize>() % 10;
+ ReadOp::Get(query, key)
+ }
+}
+
+fn db_reader_thread(db: &StressDatabaseImpl, ops: Vec<ReadOp>, check_cancellation: bool) {
+ for op in ops {
+ if check_cancellation {
+ db.unwind_if_cancelled();
+ }
+ op.execute(db);
+ }
+}
+
+impl WriteOp {
+ fn execute(self, db: &mut StressDatabaseImpl) {
+ match self {
+ WriteOp::SetA(key, value) => {
+ db.set_a(key, value);
+ }
+ }
+ }
+}
+
+impl ReadOp {
+ fn execute(self, db: &StressDatabaseImpl) {
+ match self {
+ ReadOp::Get(query, key) => match query {
+ Query::A => {
+ db.a(key);
+ }
+ Query::B => {
+ let _ = db.b(key);
+ }
+ Query::C => {
+ let _ = db.c(key);
+ }
+ },
+ }
+ }
+}
+
+#[test]
+fn stress_test() {
+ let mut db = StressDatabaseImpl::default();
+ for i in 0..10 {
+ db.set_a(i, i);
+ }
+
+ let mut rng = rand::thread_rng();
+
+ // generate the ops that the mutator thread will perform
+ let write_ops: Vec<MutatorOp> = (0..N_MUTATOR_OPS).map(|_| rng.gen()).collect();
+
+ // execute the "main thread", which sometimes snapshots off other threads
+ let mut all_threads = vec![];
+ for op in write_ops {
+ match op {
+ MutatorOp::WriteOp(w) => w.execute(&mut db),
+ MutatorOp::LaunchReader { ops, check_cancellation } => {
+ all_threads.push(std::thread::spawn({
+ let db = db.snapshot();
+ move || Cancelled::catch(|| db_reader_thread(&db, ops, check_cancellation))
+ }))
+ }
+ }
+ }
+
+ for thread in all_threads {
+ thread.join().unwrap().ok();
+ }
+}