Unnamed repository; edit this file 'description' to name the repository.
-rw-r--r--Cargo.lock146
-rw-r--r--Cargo.toml10
-rw-r--r--crates/base-db/Cargo.toml2
-rw-r--r--crates/salsa/Cargo.toml35
-rw-r--r--crates/salsa/FAQ.md34
-rw-r--r--crates/salsa/LICENSE-APACHE201
-rw-r--r--crates/salsa/LICENSE-MIT23
-rw-r--r--crates/salsa/README.md42
-rw-r--r--crates/salsa/salsa-macros/Cargo.toml23
-rw-r--r--crates/salsa/salsa-macros/LICENSE-APACHE1
-rw-r--r--crates/salsa/salsa-macros/LICENSE-MIT1
-rw-r--r--crates/salsa/salsa-macros/README.md1
-rw-r--r--crates/salsa/salsa-macros/src/database_storage.rs250
-rw-r--r--crates/salsa/salsa-macros/src/lib.rs146
-rw-r--r--crates/salsa/salsa-macros/src/parenthesized.rs13
-rw-r--r--crates/salsa/salsa-macros/src/query_group.rs734
-rw-r--r--crates/salsa/src/debug.rs66
-rw-r--r--crates/salsa/src/derived.rs233
-rw-r--r--crates/salsa/src/derived/slot.rs833
-rw-r--r--crates/salsa/src/doctest.rs115
-rw-r--r--crates/salsa/src/durability.rs50
-rw-r--r--crates/salsa/src/hash.rs4
-rw-r--r--crates/salsa/src/input.rs240
-rw-r--r--crates/salsa/src/intern_id.rs131
-rw-r--r--crates/salsa/src/interned.rs409
-rw-r--r--crates/salsa/src/lib.rs742
-rw-r--r--crates/salsa/src/lru.rs325
-rw-r--r--crates/salsa/src/plumbing.rs238
-rw-r--r--crates/salsa/src/revision.rs67
-rw-r--r--crates/salsa/src/runtime.rs667
-rw-r--r--crates/salsa/src/runtime/dependency_graph.rs251
-rw-r--r--crates/salsa/src/runtime/local_state.rs214
-rw-r--r--crates/salsa/src/storage.rs54
-rw-r--r--crates/salsa/tests/cycles.rs493
-rw-r--r--crates/salsa/tests/dyn_trait.rs28
-rw-r--r--crates/salsa/tests/incremental/constants.rs145
-rw-r--r--crates/salsa/tests/incremental/counter.rs14
-rw-r--r--crates/salsa/tests/incremental/implementation.rs59
-rw-r--r--crates/salsa/tests/incremental/log.rs16
-rw-r--r--crates/salsa/tests/incremental/main.rs9
-rw-r--r--crates/salsa/tests/incremental/memoized_dep_inputs.rs60
-rw-r--r--crates/salsa/tests/incremental/memoized_inputs.rs76
-rw-r--r--crates/salsa/tests/incremental/memoized_volatile.rs77
-rw-r--r--crates/salsa/tests/interned.rs90
-rw-r--r--crates/salsa/tests/lru.rs102
-rw-r--r--crates/salsa/tests/macros.rs11
-rw-r--r--crates/salsa/tests/no_send_sync.rs31
-rw-r--r--crates/salsa/tests/on_demand_inputs.rs147
-rw-r--r--crates/salsa/tests/panic_safely.rs93
-rw-r--r--crates/salsa/tests/parallel/cancellation.rs132
-rw-r--r--crates/salsa/tests/parallel/frozen.rs57
-rw-r--r--crates/salsa/tests/parallel/independent.rs29
-rw-r--r--crates/salsa/tests/parallel/main.rs13
-rw-r--r--crates/salsa/tests/parallel/parallel_cycle_all_recover.rs110
-rw-r--r--crates/salsa/tests/parallel/parallel_cycle_mid_recover.rs110
-rw-r--r--crates/salsa/tests/parallel/parallel_cycle_none_recover.rs69
-rw-r--r--crates/salsa/tests/parallel/parallel_cycle_one_recovers.rs95
-rw-r--r--crates/salsa/tests/parallel/race.rs37
-rw-r--r--crates/salsa/tests/parallel/setup.rs197
-rw-r--r--crates/salsa/tests/parallel/signal.rs40
-rw-r--r--crates/salsa/tests/parallel/stress.rs168
-rw-r--r--crates/salsa/tests/parallel/true_parallel.rs125
-rw-r--r--crates/salsa/tests/storage_varieties/implementation.rs19
-rw-r--r--crates/salsa/tests/storage_varieties/main.rs5
-rw-r--r--crates/salsa/tests/storage_varieties/queries.rs22
-rw-r--r--crates/salsa/tests/storage_varieties/tests.rs49
-rw-r--r--crates/salsa/tests/transparent.rs39
-rw-r--r--crates/salsa/tests/variadic.rs51
-rw-r--r--crates/span/Cargo.toml2
69 files changed, 9083 insertions, 38 deletions
diff --git a/Cargo.lock b/Cargo.lock
index 068baaecc4..dc2bf3a769 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -72,8 +72,8 @@ dependencies = [
"cfg",
"la-arena 0.3.1 (registry+https://github.com/rust-lang/crates.io-index)",
"profile",
- "rust-analyzer-salsa",
"rustc-hash",
+ "salsa",
"semver",
"span",
"stdx",
@@ -358,6 +358,15 @@ dependencies = [
]
[[package]]
+name = "env_logger"
+version = "0.10.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "4cd405aab171cb85d6735e5c8d9db038c17d3ca007a4d2c25f337935c3d90580"
+dependencies = [
+ "log",
+]
+
+[[package]]
name = "equivalent"
version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -442,6 +451,17 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7ab85b9b05e3978cc9a9cf8fea7f01b494e1a09ed3037e16ba39edc7a29eb61a"
[[package]]
+name = "getrandom"
+version = "0.2.12"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "190092ea657667030ac6a35e305e62fc4dd69fd98ac98631e5d3a2b1575a12b5"
+dependencies = [
+ "cfg-if",
+ "libc",
+ "wasi",
+]
+
+[[package]]
name = "gimli"
version = "0.27.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -919,6 +939,12 @@ dependencies = [
]
[[package]]
+name = "linked-hash-map"
+version = "0.5.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "0717cef1bc8b636c6e1c1bbdefc09e6322da8a9321966e8928ef80d20f7f770f"
+
+[[package]]
name = "load-cargo"
version = "0.0.0"
dependencies = [
@@ -1262,6 +1288,12 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e0a7ae3ac2f1173085d398531c705756c94a4c56843785df85a60c1a0afac116"
[[package]]
+name = "ppv-lite86"
+version = "0.2.17"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de"
+
+[[package]]
name = "proc-macro-api"
version = "0.0.0"
dependencies = [
@@ -1505,6 +1537,36 @@ dependencies = [
]
[[package]]
+name = "rand"
+version = "0.8.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404"
+dependencies = [
+ "libc",
+ "rand_chacha",
+ "rand_core",
+]
+
+[[package]]
+name = "rand_chacha"
+version = "0.3.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88"
+dependencies = [
+ "ppv-lite86",
+ "rand_core",
+]
+
+[[package]]
+name = "rand_core"
+version = "0.6.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c"
+dependencies = [
+ "getrandom",
+]
+
+[[package]]
name = "rayon"
version = "1.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -1612,35 +1674,6 @@ dependencies = [
]
[[package]]
-name = "rust-analyzer-salsa"
-version = "0.17.0-pre.6"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "719825638c59fd26a55412a24561c7c5bcf54364c88b9a7a04ba08a6eafaba8d"
-dependencies = [
- "indexmap",
- "lock_api",
- "oorandom",
- "parking_lot",
- "rust-analyzer-salsa-macros",
- "rustc-hash",
- "smallvec",
- "tracing",
- "triomphe",
-]
-
-[[package]]
-name = "rust-analyzer-salsa-macros"
-version = "0.17.0-pre.6"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "4d96498e9684848c6676c399032ebc37c52da95ecbefa83d71ccc53b9f8a4a8e"
-dependencies = [
- "heck",
- "proc-macro2",
- "quote",
- "syn",
-]
-
-[[package]]
name = "rustc-demangle"
version = "0.1.23"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -1669,6 +1702,36 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f91339c0467de62360649f8d3e185ca8de4224ff281f66000de5eb2a77a79041"
[[package]]
+name = "salsa"
+version = "0.0.0"
+dependencies = [
+ "dissimilar",
+ "expect-test",
+ "indexmap",
+ "linked-hash-map",
+ "lock_api",
+ "oorandom",
+ "parking_lot",
+ "rand",
+ "rustc-hash",
+ "salsa-macros",
+ "smallvec",
+ "test-log",
+ "tracing",
+ "triomphe",
+]
+
+[[package]]
+name = "salsa-macros"
+version = "0.0.0"
+dependencies = [
+ "heck",
+ "proc-macro2",
+ "quote",
+ "syn",
+]
+
+[[package]]
name = "same-file"
version = "1.0.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -1792,7 +1855,7 @@ name = "span"
version = "0.0.0"
dependencies = [
"la-arena 0.3.1 (registry+https://github.com/rust-lang/crates.io-index)",
- "rust-analyzer-salsa",
+ "salsa",
"stdx",
"syntax",
"vfs",
@@ -1890,6 +1953,27 @@ dependencies = [
]
[[package]]
+name = "test-log"
+version = "0.2.14"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "6159ab4116165c99fc88cce31f99fa2c9dbe08d3691cb38da02fc3b45f357d2b"
+dependencies = [
+ "env_logger",
+ "test-log-macros",
+]
+
+[[package]]
+name = "test-log-macros"
+version = "0.2.14"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "7ba277e77219e9eea169e8508942db1bf5d8a41ff2db9b20aab5a5aadc9fa25d"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "syn",
+]
+
+[[package]]
name = "test-utils"
version = "0.0.0"
dependencies = [
diff --git a/Cargo.toml b/Cargo.toml
index 2cc3b0a0bc..f40156b99e 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -70,6 +70,7 @@ proc-macro-srv = { path = "./crates/proc-macro-srv", version = "0.0.0" }
proc-macro-srv-cli = { path = "./crates/proc-macro-srv-cli", version = "0.0.0" }
profile = { path = "./crates/profile", version = "0.0.0" }
project-model = { path = "./crates/project-model", version = "0.0.0" }
+salsa = { path = "./crates/salsa", version = "0.0.0" }
span = { path = "./crates/span", version = "0.0.0" }
stdx = { path = "./crates/stdx", version = "0.0.0" }
syntax = { path = "./crates/syntax", version = "0.0.0" }
@@ -106,22 +107,21 @@ dissimilar = "1.0.7"
either = "1.9.0"
expect-test = "1.4.0"
hashbrown = { version = "0.14", features = [
- "inline-more",
+ "inline-more",
], default-features = false }
indexmap = "2.1.0"
itertools = "0.12.0"
libc = "0.2.150"
nohash-hasher = "0.2.0"
rayon = "1.8.0"
-rust-analyzer-salsa = "0.17.0-pre.6"
rustc-hash = "1.1.0"
semver = "1.0.14"
serde = { version = "1.0.192", features = ["derive"] }
serde_json = "1.0.108"
smallvec = { version = "1.10.0", features = [
- "const_new",
- "union",
- "const_generics",
+ "const_new",
+ "union",
+ "const_generics",
] }
smol_str = "0.2.1"
text-size = "1.1.1"
diff --git a/crates/base-db/Cargo.toml b/crates/base-db/Cargo.toml
index 485ba78846..801ba2d1f6 100644
--- a/crates/base-db/Cargo.toml
+++ b/crates/base-db/Cargo.toml
@@ -13,7 +13,7 @@ doctest = false
[dependencies]
la-arena.workspace = true
-rust-analyzer-salsa.workspace = true
+salsa.workspace = true
rustc-hash.workspace = true
triomphe.workspace = true
semver.workspace = true
diff --git a/crates/salsa/Cargo.toml b/crates/salsa/Cargo.toml
new file mode 100644
index 0000000000..4ccbc3de84
--- /dev/null
+++ b/crates/salsa/Cargo.toml
@@ -0,0 +1,35 @@
+[package]
+name = "salsa"
+version = "0.0.0"
+authors = ["Salsa developers"]
+edition = "2021"
+license = "Apache-2.0 OR MIT"
+repository = "https://github.com/salsa-rs/salsa"
+description = "A generic framework for on-demand, incrementalized computation (experimental)"
+
+rust-version.workspace = true
+
+[lib]
+name = "salsa"
+
+[dependencies]
+indexmap = "2.1.0"
+lock_api = "0.4"
+tracing = "0.1"
+parking_lot = "0.12.1"
+rustc-hash = "1.0"
+smallvec = "1.0.0"
+oorandom = "11"
+triomphe = "0.1.11"
+
+salsa-macros = { version = "0.0.0", path = "salsa-macros" }
+
+[dev-dependencies]
+linked-hash-map = "0.5.6"
+rand = "0.8.5"
+test-log = "0.2.14"
+expect-test = "1.4.0"
+dissimilar = "1.0.7"
+
+[lints]
+workspace = true
diff --git a/crates/salsa/FAQ.md b/crates/salsa/FAQ.md
new file mode 100644
index 0000000000..9c9f6f92da
--- /dev/null
+++ b/crates/salsa/FAQ.md
@@ -0,0 +1,34 @@
+# Frequently asked questions
+
+## Why is it called salsa?
+
+I like salsa! Don't you?! Well, ok, there's a bit more to it. The
+underlying algorithm for figuring out which bits of code need to be
+re-executed after any given change is based on the algorithm used in
+rustc. Michael Woerister and I first described the rustc algorithm in
+terms of two colors, red and green, and hence we called it the
+"red-green algorithm". This made me think of the New Mexico State
+Question --- ["Red or green?"][nm] --- which refers to chile
+(salsa). Although this version no longer uses colors (we borrowed
+revision counters from Glimmer, instead), I still like the name.
+
+[nm]: https://www.sos.state.nm.us/about-new-mexico/state-question/
+
+## What is the relationship between salsa and an Entity-Component System (ECS)?
+
+You may have noticed that Salsa "feels" a lot like an ECS in some
+ways. That's true -- Salsa's queries are a bit like *components* (and
+the keys to the queries are a bit like *entities*). But there is one
+big difference: **ECS is -- at its heart -- a mutable system**. You
+can get or set a component of some entity whenever you like. In
+contrast, salsa's queries **define "derived values" via pure
+computations**.
+
+Partly as a consequence, ECS doesn't handle incremental updates for
+you. When you update some component of some entity, you have to ensure
+that other entities' components are updated appropriately.
+
+Finally, ECS offers interesting metadata and "aspect-like" facilities,
+such as iterating over all entities that share certain components.
+Salsa has no analogue to that.
+
diff --git a/crates/salsa/LICENSE-APACHE b/crates/salsa/LICENSE-APACHE
new file mode 100644
index 0000000000..16fe87b06e
--- /dev/null
+++ b/crates/salsa/LICENSE-APACHE
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+END OF TERMS AND CONDITIONS
+
+APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+Copyright [yyyy] [name of copyright owner]
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
diff --git a/crates/salsa/LICENSE-MIT b/crates/salsa/LICENSE-MIT
new file mode 100644
index 0000000000..31aa79387f
--- /dev/null
+++ b/crates/salsa/LICENSE-MIT
@@ -0,0 +1,23 @@
+Permission is hereby granted, free of charge, to any
+person obtaining a copy of this software and associated
+documentation files (the "Software"), to deal in the
+Software without restriction, including without
+limitation the rights to use, copy, modify, merge,
+publish, distribute, sublicense, and/or sell copies of
+the Software, and to permit persons to whom the Software
+is furnished to do so, subject to the following
+conditions:
+
+The above copyright notice and this permission notice
+shall be included in all copies or substantial portions
+of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF
+ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED
+TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
+PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT
+SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
+CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
+OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR
+IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
+DEALINGS IN THE SOFTWARE.
diff --git a/crates/salsa/README.md b/crates/salsa/README.md
new file mode 100644
index 0000000000..4a8d9f8c73
--- /dev/null
+++ b/crates/salsa/README.md
@@ -0,0 +1,42 @@
+# salsa
+
+*A generic framework for on-demand, incrementalized computation.*
+
+## Obligatory warning
+
+This is a fork of https://github.com/salsa-rs/salsa/ adjusted to rust-analyzer's needs.
+
+## Credits
+
+This system is heavily inspired by [adapton](http://adapton.org/), [glimmer](https://github.com/glimmerjs/glimmer-vm), and rustc's query
+system. So credit goes to Eduard-Mihai Burtescu, Matthew Hammer,
+Yehuda Katz, and Michael Woerister.
+
+## Key idea
+
+The key idea of `salsa` is that you define your program as a set of
+**queries**. Every query is used like function `K -> V` that maps from
+some key of type `K` to a value of type `V`. Queries come in two basic
+varieties:
+
+- **Inputs**: the base inputs to your system. You can change these
+ whenever you like.
+- **Functions**: pure functions (no side effects) that transform your
+ inputs into other values. The results of queries is memoized to
+ avoid recomputing them a lot. When you make changes to the inputs,
+ we'll figure out (fairly intelligently) when we can re-use these
+ memoized values and when we have to recompute them.
+
+## Want to learn more?
+
+To learn more about Salsa, try one of the following:
+
+- read the [heavily commented `hello_world` example](https://github.com/salsa-rs/salsa/blob/master/examples/hello_world/main.rs);
+- check out the [Salsa book](https://salsa-rs.github.io/salsa);
+- watch one of our [videos](https://salsa-rs.github.io/salsa/videos.html).
+
+## Getting in touch
+
+The bulk of the discussion happens in the [issues](https://github.com/salsa-rs/salsa/issues)
+and [pull requests](https://github.com/salsa-rs/salsa/pulls),
+but we have a [zulip chat](https://salsa.zulipchat.com/) as well.
diff --git a/crates/salsa/salsa-macros/Cargo.toml b/crates/salsa/salsa-macros/Cargo.toml
new file mode 100644
index 0000000000..791d2f6e9f
--- /dev/null
+++ b/crates/salsa/salsa-macros/Cargo.toml
@@ -0,0 +1,23 @@
+[package]
+name = "salsa-macros"
+version = "0.0.0"
+authors = ["Salsa developers"]
+edition = "2021"
+license = "Apache-2.0 OR MIT"
+repository = "https://github.com/salsa-rs/salsa"
+description = "Procedural macros for the salsa crate"
+
+rust-version.workspace = true
+
+[lib]
+proc-macro = true
+name = "salsa_macros"
+
+[dependencies]
+heck = "0.4"
+proc-macro2 = "1.0"
+quote = "1.0"
+syn = { version = "2.0", features = ["full", "extra-traits"] }
+
+[lints]
+workspace = true
diff --git a/crates/salsa/salsa-macros/LICENSE-APACHE b/crates/salsa/salsa-macros/LICENSE-APACHE
new file mode 100644
index 0000000000..0bf2cad648
--- /dev/null
+++ b/crates/salsa/salsa-macros/LICENSE-APACHE
@@ -0,0 +1 @@
+../LICENSE-APACHE
diff --git a/crates/salsa/salsa-macros/LICENSE-MIT b/crates/salsa/salsa-macros/LICENSE-MIT
new file mode 100644
index 0000000000..d99cce5f72
--- /dev/null
+++ b/crates/salsa/salsa-macros/LICENSE-MIT
@@ -0,0 +1 @@
+../LICENSE-MIT
diff --git a/crates/salsa/salsa-macros/README.md b/crates/salsa/salsa-macros/README.md
new file mode 100644
index 0000000000..94389aee61
--- /dev/null
+++ b/crates/salsa/salsa-macros/README.md
@@ -0,0 +1 @@
+../README.md
diff --git a/crates/salsa/salsa-macros/src/database_storage.rs b/crates/salsa/salsa-macros/src/database_storage.rs
new file mode 100644
index 0000000000..0ec75bb043
--- /dev/null
+++ b/crates/salsa/salsa-macros/src/database_storage.rs
@@ -0,0 +1,250 @@
+//!
+use heck::ToSnakeCase;
+use proc_macro::TokenStream;
+use syn::parse::{Parse, ParseStream};
+use syn::punctuated::Punctuated;
+use syn::{Ident, ItemStruct, Path, Token};
+
+type PunctuatedQueryGroups = Punctuated<QueryGroup, Token![,]>;
+
+pub(crate) fn database(args: TokenStream, input: TokenStream) -> TokenStream {
+ let args = syn::parse_macro_input!(args as QueryGroupList);
+ let input = syn::parse_macro_input!(input as ItemStruct);
+
+ let query_groups = &args.query_groups;
+ let database_name = &input.ident;
+ let visibility = &input.vis;
+ let db_storage_field = quote! { storage };
+
+ let mut output = proc_macro2::TokenStream::new();
+ output.extend(quote! { #input });
+
+ let query_group_names_snake: Vec<_> = query_groups
+ .iter()
+ .map(|query_group| {
+ let group_name = query_group.name();
+ Ident::new(&group_name.to_string().to_snake_case(), group_name.span())
+ })
+ .collect();
+
+ let query_group_storage_names: Vec<_> = query_groups
+ .iter()
+ .map(|QueryGroup { group_path }| {
+ quote! {
+ <#group_path as salsa::plumbing::QueryGroup>::GroupStorage
+ }
+ })
+ .collect();
+
+ // For each query group `foo::MyGroup` create a link to its
+ // `foo::MyGroupGroupStorage`
+ let mut storage_fields = proc_macro2::TokenStream::new();
+ let mut storage_initializers = proc_macro2::TokenStream::new();
+ let mut has_group_impls = proc_macro2::TokenStream::new();
+ for (((query_group, group_name_snake), group_storage), group_index) in query_groups
+ .iter()
+ .zip(&query_group_names_snake)
+ .zip(&query_group_storage_names)
+ .zip(0_u16..)
+ {
+ let group_path = &query_group.group_path;
+
+ // rewrite the last identifier (`MyGroup`, above) to
+ // (e.g.) `MyGroupGroupStorage`.
+ storage_fields.extend(quote! {
+ #group_name_snake: #group_storage,
+ });
+
+ // rewrite the last identifier (`MyGroup`, above) to
+ // (e.g.) `MyGroupGroupStorage`.
+ storage_initializers.extend(quote! {
+ #group_name_snake: #group_storage::new(#group_index),
+ });
+
+ // ANCHOR:HasQueryGroup
+ has_group_impls.extend(quote! {
+ impl salsa::plumbing::HasQueryGroup<#group_path> for #database_name {
+ fn group_storage(&self) -> &#group_storage {
+ &self.#db_storage_field.query_store().#group_name_snake
+ }
+
+ fn group_storage_mut(&mut self) -> (&#group_storage, &mut salsa::Runtime) {
+ let (query_store_mut, runtime) = self.#db_storage_field.query_store_mut();
+ (&query_store_mut.#group_name_snake, runtime)
+ }
+ }
+ });
+ // ANCHOR_END:HasQueryGroup
+ }
+
+ // create group storage wrapper struct
+ output.extend(quote! {
+ #[doc(hidden)]
+ #visibility struct __SalsaDatabaseStorage {
+ #storage_fields
+ }
+
+ impl Default for __SalsaDatabaseStorage {
+ fn default() -> Self {
+ Self {
+ #storage_initializers
+ }
+ }
+ }
+ });
+
+ // Create a tuple (D1, D2, ...) where Di is the data for a given query group.
+ let mut database_data = vec![];
+ for QueryGroup { group_path } in query_groups {
+ database_data.push(quote! {
+ <#group_path as salsa::plumbing::QueryGroup>::GroupData
+ });
+ }
+
+ // ANCHOR:DatabaseStorageTypes
+ output.extend(quote! {
+ impl salsa::plumbing::DatabaseStorageTypes for #database_name {
+ type DatabaseStorage = __SalsaDatabaseStorage;
+ }
+ });
+ // ANCHOR_END:DatabaseStorageTypes
+
+ // ANCHOR:DatabaseOps
+ let mut fmt_ops = proc_macro2::TokenStream::new();
+ let mut maybe_changed_ops = proc_macro2::TokenStream::new();
+ let mut cycle_recovery_strategy_ops = proc_macro2::TokenStream::new();
+ let mut for_each_ops = proc_macro2::TokenStream::new();
+ for ((QueryGroup { group_path }, group_storage), group_index) in
+ query_groups.iter().zip(&query_group_storage_names).zip(0_u16..)
+ {
+ fmt_ops.extend(quote! {
+ #group_index => {
+ let storage: &#group_storage =
+ <Self as salsa::plumbing::HasQueryGroup<#group_path>>::group_storage(self);
+ storage.fmt_index(self, input, fmt)
+ }
+ });
+ maybe_changed_ops.extend(quote! {
+ #group_index => {
+ let storage: &#group_storage =
+ <Self as salsa::plumbing::HasQueryGroup<#group_path>>::group_storage(self);
+ storage.maybe_changed_after(self, input, revision)
+ }
+ });
+ cycle_recovery_strategy_ops.extend(quote! {
+ #group_index => {
+ let storage: &#group_storage =
+ <Self as salsa::plumbing::HasQueryGroup<#group_path>>::group_storage(self);
+ storage.cycle_recovery_strategy(self, input)
+ }
+ });
+ for_each_ops.extend(quote! {
+ let storage: &#group_storage =
+ <Self as salsa::plumbing::HasQueryGroup<#group_path>>::group_storage(self);
+ storage.for_each_query(runtime, &mut op);
+ });
+ }
+ output.extend(quote! {
+ impl salsa::plumbing::DatabaseOps for #database_name {
+ fn ops_database(&self) -> &dyn salsa::Database {
+ self
+ }
+
+ fn ops_salsa_runtime(&self) -> &salsa::Runtime {
+ self.#db_storage_field.salsa_runtime()
+ }
+
+ fn ops_salsa_runtime_mut(&mut self) -> &mut salsa::Runtime {
+ self.#db_storage_field.salsa_runtime_mut()
+ }
+
+ fn fmt_index(
+ &self,
+ input: salsa::DatabaseKeyIndex,
+ fmt: &mut std::fmt::Formatter<'_>,
+ ) -> std::fmt::Result {
+ match input.group_index() {
+ #fmt_ops
+ i => panic!("salsa: invalid group index {}", i)
+ }
+ }
+
+ fn maybe_changed_after(
+ &self,
+ input: salsa::DatabaseKeyIndex,
+ revision: salsa::Revision
+ ) -> bool {
+ match input.group_index() {
+ #maybe_changed_ops
+ i => panic!("salsa: invalid group index {}", i)
+ }
+ }
+
+ fn cycle_recovery_strategy(
+ &self,
+ input: salsa::DatabaseKeyIndex,
+ ) -> salsa::plumbing::CycleRecoveryStrategy {
+ match input.group_index() {
+ #cycle_recovery_strategy_ops
+ i => panic!("salsa: invalid group index {}", i)
+ }
+ }
+
+ fn for_each_query(
+ &self,
+ mut op: &mut dyn FnMut(&dyn salsa::plumbing::QueryStorageMassOps),
+ ) {
+ let runtime = salsa::Database::salsa_runtime(self);
+ #for_each_ops
+ }
+ }
+ });
+ // ANCHOR_END:DatabaseOps
+
+ output.extend(has_group_impls);
+
+ output.into()
+}
+
+#[derive(Clone, Debug)]
+struct QueryGroupList {
+ query_groups: PunctuatedQueryGroups,
+}
+
+impl Parse for QueryGroupList {
+ fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
+ let query_groups: PunctuatedQueryGroups =
+ input.parse_terminated(QueryGroup::parse, Token![,])?;
+ Ok(QueryGroupList { query_groups })
+ }
+}
+
+#[derive(Clone, Debug)]
+struct QueryGroup {
+ group_path: Path,
+}
+
+impl QueryGroup {
+ /// The name of the query group trait.
+ fn name(&self) -> Ident {
+ self.group_path.segments.last().unwrap().ident.clone()
+ }
+}
+
+impl Parse for QueryGroup {
+ /// ```ignore
+ /// impl HelloWorldDatabase;
+ /// ```
+ fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
+ let group_path: Path = input.parse()?;
+ Ok(QueryGroup { group_path })
+ }
+}
+
+struct Nothing;
+
+impl Parse for Nothing {
+ fn parse(_input: ParseStream<'_>) -> syn::Result<Self> {
+ Ok(Nothing)
+ }
+}
diff --git a/crates/salsa/salsa-macros/src/lib.rs b/crates/salsa/salsa-macros/src/lib.rs
new file mode 100644
index 0000000000..8af48b1e3f
--- /dev/null
+++ b/crates/salsa/salsa-macros/src/lib.rs
@@ -0,0 +1,146 @@
+//! This crate provides salsa's macros and attributes.
+
+#![recursion_limit = "256"]
+
+#[macro_use]
+extern crate quote;
+
+use proc_macro::TokenStream;
+
+mod database_storage;
+mod parenthesized;
+mod query_group;
+
+/// The decorator that defines a salsa "query group" trait. This is a
+/// trait that defines everything that a block of queries need to
+/// execute, as well as defining the queries themselves that are
+/// exported for others to use.
+///
+/// This macro declares the "prototype" for a group of queries. It will
+/// expand into a trait and a set of structs, one per query.
+///
+/// For each query, you give the name of the accessor method to invoke
+/// the query (e.g., `my_query`, below), as well as its parameter
+/// types and the output type. You also give the name for a query type
+/// (e.g., `MyQuery`, below) that represents the query, and optionally
+/// other details, such as its storage.
+///
+/// # Examples
+///
+/// The simplest example is something like this:
+///
+/// ```ignore
+/// #[salsa::query_group]
+/// trait TypeckDatabase {
+/// #[salsa::input] // see below for other legal attributes
+/// fn my_query(&self, input: u32) -> u64;
+///
+/// /// Queries can have any number of inputs (including zero); if there
+/// /// is not exactly one input, then the key type will be
+/// /// a tuple of the input types, so in this case `(u32, f32)`.
+/// fn other_query(&self, input1: u32, input2: f32) -> u64;
+/// }
+/// ```
+///
+/// Here is a list of legal `salsa::XXX` attributes:
+///
+/// - Storage attributes: control how the query data is stored and set. These
+/// are described in detail in the section below.
+/// - `#[salsa::input]`
+/// - `#[salsa::memoized]`
+/// - `#[salsa::dependencies]`
+/// - Query execution:
+/// - `#[salsa::invoke(path::to::my_fn)]` -- for a non-input, this
+/// indicates the function to call when a query must be
+/// recomputed. The default is to call a function in the same
+/// module with the same name as the query.
+/// - `#[query_type(MyQueryTypeName)]` specifies the name of the
+/// dummy struct created for the query. Default is the name of the
+/// query, in camel case, plus the word "Query" (e.g.,
+/// `MyQueryQuery` and `OtherQueryQuery` in the examples above).
+///
+/// # Storage attributes
+///
+/// Here are the possible storage values for each query. The default
+/// is `storage memoized`.
+///
+/// ## Input queries
+///
+/// Specifying `storage input` will give you an **input
+/// query**. Unlike derived queries, whose value is given by a
+/// function, input queries are explicitly set by doing
+/// `db.query(QueryType).set(key, value)` (where `QueryType` is the
+/// `type` specified for the query). Accessing a value that has not
+/// yet been set will panic. Each time you invoke `set`, we assume the
+/// value has changed, and so we will potentially re-execute derived
+/// queries that read (transitively) from this input.
+///
+/// ## Derived queries
+///
+/// Derived queries are specified by a function.
+///
+/// - `#[salsa::memoized]` (the default) -- The result is memoized
+/// between calls. If the inputs have changed, we will recompute
+/// the value, but then compare against the old memoized value,
+/// which can significantly reduce the amount of recomputation
+/// required in new revisions. This does require that the value
+/// implements `Eq`.
+/// - `#[salsa::dependencies]` -- does not cache the value, so it will
+/// be recomputed every time it is needed. We do track the inputs, however,
+/// so if they have not changed, then things that rely on this query
+/// may be known not to have changed.
+///
+/// ## Attribute combinations
+///
+/// Some attributes are mutually exclusive. For example, it is an error to add
+/// multiple storage specifiers:
+///
+/// ```compile_fail
+/// # use salsa_macros as salsa;
+/// #[salsa::query_group]
+/// trait CodegenDatabase {
+/// #[salsa::input]
+/// #[salsa::memoized]
+/// fn my_query(&self, input: u32) -> u64;
+/// }
+/// ```
+///
+/// It is also an error to annotate a function to `invoke` on an `input` query:
+///
+/// ```compile_fail
+/// # use salsa_macros as salsa;
+/// #[salsa::query_group]
+/// trait CodegenDatabase {
+/// #[salsa::input]
+/// #[salsa::invoke(typeck::my_query)]
+/// fn my_query(&self, input: u32) -> u64;
+/// }
+/// ```
+#[proc_macro_attribute]
+pub fn query_group(args: TokenStream, input: TokenStream) -> TokenStream {
+ query_group::query_group(args, input)
+}
+
+/// This attribute is placed on your database struct. It takes a list of the
+/// query groups that your database supports. The format looks like so:
+///
+/// ```rust,ignore
+/// #[salsa::database(MyQueryGroup1, MyQueryGroup2)]
+/// struct MyDatabase {
+/// runtime: salsa::Runtime<MyDatabase>, // <-- your database will need this field, too
+/// }
+/// ```
+///
+/// Here, the struct `MyDatabase` would support the two query groups
+/// `MyQueryGroup1` and `MyQueryGroup2`. In addition to the `database`
+/// attribute, the struct needs to have a `runtime` field (of type
+/// [`salsa::Runtime`]) and to implement the `salsa::Database` trait.
+///
+/// See [the `hello_world` example][hw] for more details.
+///
+/// [`salsa::Runtime`]: struct.Runtime.html
+/// [hw]: https://github.com/salsa-rs/salsa/tree/master/examples/hello_world
+#[proc_macro_attribute]
+pub fn database(args: TokenStream, input: TokenStream) -> TokenStream {
+ database_storage::database(args, input)
+}
diff --git a/crates/salsa/salsa-macros/src/parenthesized.rs b/crates/salsa/salsa-macros/src/parenthesized.rs
new file mode 100644
index 0000000000..9df41e03c1
--- /dev/null
+++ b/crates/salsa/salsa-macros/src/parenthesized.rs
@@ -0,0 +1,13 @@
+//!
+pub(crate) struct Parenthesized<T>(pub(crate) T);
+
+impl<T> syn::parse::Parse for Parenthesized<T>
+where
+ T: syn::parse::Parse,
+{
+ fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
+ let content;
+ syn::parenthesized!(content in input);
+ content.parse::<T>().map(Parenthesized)
+ }
+}
diff --git a/crates/salsa/salsa-macros/src/query_group.rs b/crates/salsa/salsa-macros/src/query_group.rs
new file mode 100644
index 0000000000..7d0eac9f5d
--- /dev/null
+++ b/crates/salsa/salsa-macros/src/query_group.rs
@@ -0,0 +1,734 @@
+//!
+use std::{convert::TryFrom, iter::FromIterator};
+
+use crate::parenthesized::Parenthesized;
+use heck::ToUpperCamelCase;
+use proc_macro::TokenStream;
+use proc_macro2::Span;
+use quote::ToTokens;
+use syn::{
+ parse_macro_input, parse_quote, spanned::Spanned, Attribute, Error, FnArg, Ident, ItemTrait,
+ ReturnType, TraitItem, Type,
+};
+
+/// Implementation for `[salsa::query_group]` decorator.
+pub(crate) fn query_group(args: TokenStream, input: TokenStream) -> TokenStream {
+ let group_struct = parse_macro_input!(args as Ident);
+ let input: ItemTrait = parse_macro_input!(input as ItemTrait);
+ // println!("args: {:#?}", args);
+ // println!("input: {:#?}", input);
+
+ let input_span = input.span();
+ let (trait_attrs, salsa_attrs) = filter_attrs(input.attrs);
+ if !salsa_attrs.is_empty() {
+ return Error::new(input_span, format!("unsupported attributes: {:?}", salsa_attrs))
+ .to_compile_error()
+ .into();
+ }
+
+ let trait_vis = input.vis;
+ let trait_name = input.ident;
+ let _generics = input.generics.clone();
+ let dyn_db = quote! { dyn #trait_name };
+
+ // Decompose the trait into the corresponding queries.
+ let mut queries = vec![];
+ for item in input.items {
+ if let TraitItem::Fn(method) = item {
+ let query_name = method.sig.ident.to_string();
+
+ let mut storage = QueryStorage::Memoized;
+ let mut cycle = None;
+ let mut invoke = None;
+
+ let mut query_type =
+ format_ident!("{}Query", query_name.to_string().to_upper_camel_case());
+ let mut num_storages = 0;
+
+ // Extract attributes.
+ let (attrs, salsa_attrs) = filter_attrs(method.attrs);
+ for SalsaAttr { name, tts, span } in salsa_attrs {
+ match name.as_str() {
+ "memoized" => {
+ storage = QueryStorage::Memoized;
+ num_storages += 1;
+ }
+ "dependencies" => {
+ storage = QueryStorage::Dependencies;
+ num_storages += 1;
+ }
+ "input" => {
+ storage = QueryStorage::Input;
+ num_storages += 1;
+ }
+ "interned" => {
+ storage = QueryStorage::Interned;
+ num_storages += 1;
+ }
+ "cycle" => {
+ cycle = Some(parse_macro_input!(tts as Parenthesized<syn::Path>).0);
+ }
+ "invoke" => {
+ invoke = Some(parse_macro_input!(tts as Parenthesized<syn::Path>).0);
+ }
+ "query_type" => {
+ query_type = parse_macro_input!(tts as Parenthesized<Ident>).0;
+ }
+ "transparent" => {
+ storage = QueryStorage::Transparent;
+ num_storages += 1;
+ }
+ _ => {
+ return Error::new(span, format!("unknown salsa attribute `{}`", name))
+ .to_compile_error()
+ .into();
+ }
+ }
+ }
+
+ let sig_span = method.sig.span();
+ // Check attribute combinations.
+ if num_storages > 1 {
+ return Error::new(sig_span, "multiple storage attributes specified")
+ .to_compile_error()
+ .into();
+ }
+ match &invoke {
+ Some(invoke) if storage == QueryStorage::Input => {
+ return Error::new(
+ invoke.span(),
+ "#[salsa::invoke] cannot be set on #[salsa::input] queries",
+ )
+ .to_compile_error()
+ .into();
+ }
+ _ => {}
+ }
+
+ // Extract keys.
+ let mut iter = method.sig.inputs.iter();
+ let self_receiver = match iter.next() {
+ Some(FnArg::Receiver(sr)) if sr.mutability.is_none() => sr,
+ _ => {
+ return Error::new(
+ sig_span,
+ format!("first argument of query `{}` must be `&self`", query_name),
+ )
+ .to_compile_error()
+ .into();
+ }
+ };
+ let mut keys: Vec<(Ident, Type)> = vec![];
+ for (idx, arg) in iter.enumerate() {
+ match arg {
+ FnArg::Typed(syn::PatType { pat, ty, .. }) => keys.push((
+ match pat.as_ref() {
+ syn::Pat::Ident(ident_pat) => ident_pat.ident.clone(),
+ _ => format_ident!("key{}", idx),
+ },
+ Type::clone(ty),
+ )),
+ arg => {
+ return Error::new(
+ arg.span(),
+ format!("unsupported argument `{:?}` of `{}`", arg, query_name,),
+ )
+ .to_compile_error()
+ .into();
+ }
+ }
+ }
+
+ // Extract value.
+ let value = match method.sig.output {
+ ReturnType::Type(_, ref ty) => ty.as_ref().clone(),
+ ref ret => {
+ return Error::new(
+ ret.span(),
+ format!("unsupported return type `{:?}` of `{}`", ret, query_name),
+ )
+ .to_compile_error()
+ .into();
+ }
+ };
+
+ // For `#[salsa::interned]` keys, we create a "lookup key" automatically.
+ //
+ // For a query like:
+ //
+ // fn foo(&self, x: Key1, y: Key2) -> u32
+ //
+ // we would create
+ //
+ // fn lookup_foo(&self, x: u32) -> (Key1, Key2)
+ let lookup_query = if let QueryStorage::Interned = storage {
+ let lookup_query_type =
+ format_ident!("{}LookupQuery", query_name.to_string().to_upper_camel_case());
+ let lookup_fn_name = format_ident!("lookup_{}", query_name);
+ let keys = keys.iter().map(|(_, ty)| ty);
+ let lookup_value: Type = parse_quote!((#(#keys),*));
+ let lookup_keys = vec![(parse_quote! { key }, value.clone())];
+ Some(Query {
+ query_type: lookup_query_type,
+ query_name: format!("{}", lookup_fn_name),
+ fn_name: lookup_fn_name,
+ receiver: self_receiver.clone(),
+ attrs: vec![], // FIXME -- some automatically generated docs on this method?
+ storage: QueryStorage::InternedLookup { intern_query_type: query_type.clone() },
+ keys: lookup_keys,
+ value: lookup_value,
+ invoke: None,
+ cycle: cycle.clone(),
+ })
+ } else {
+ None
+ };
+
+ queries.push(Query {
+ query_type,
+ query_name,
+ fn_name: method.sig.ident,
+ receiver: self_receiver.clone(),
+ attrs,
+ storage,
+ keys,
+ value,
+ invoke,
+ cycle,
+ });
+
+ queries.extend(lookup_query);
+ }
+ }
+
+ let group_storage = format_ident!("{}GroupStorage__", trait_name, span = Span::call_site());
+
+ let mut query_fn_declarations = proc_macro2::TokenStream::new();
+ let mut query_fn_definitions = proc_macro2::TokenStream::new();
+ let mut storage_fields = proc_macro2::TokenStream::new();
+ let mut queries_with_storage = vec![];
+ for query in &queries {
+ #[allow(clippy::map_identity)]
+ // clippy is incorrect here, this is not the identity function due to match ergonomics
+ let (key_names, keys): (Vec<_>, Vec<_>) = query.keys.iter().map(|(a, b)| (a, b)).unzip();
+ let value = &query.value;
+ let fn_name = &query.fn_name;
+ let qt = &query.query_type;
+ let attrs = &query.attrs;
+ let self_receiver = &query.receiver;
+
+ query_fn_declarations.extend(quote! {
+ #(#attrs)*
+ fn #fn_name(#self_receiver, #(#key_names: #keys),*) -> #value;
+ });
+
+ // Special case: transparent queries don't create actual storage,
+ // just inline the definition
+ if let QueryStorage::Transparent = query.storage {
+ let invoke = query.invoke_tt();
+ query_fn_definitions.extend(quote! {
+ fn #fn_name(&self, #(#key_names: #keys),*) -> #value {
+ #invoke(self, #(#key_names),*)
+ }
+ });
+ continue;
+ }
+
+ queries_with_storage.push(fn_name);
+
+ query_fn_definitions.extend(quote! {
+ fn #fn_name(&self, #(#key_names: #keys),*) -> #value {
+ // Create a shim to force the code to be monomorphized in the
+ // query crate. Our experiments revealed that this makes a big
+ // difference in total compilation time in rust-analyzer, though
+ // it's not totally obvious why that should be.
+ fn __shim(db: &(dyn #trait_name + '_), #(#key_names: #keys),*) -> #value {
+ salsa::plumbing::get_query_table::<#qt>(db).get((#(#key_names),*))
+ }
+ __shim(self, #(#key_names),*)
+
+ }
+ });
+
+ // For input queries, we need `set_foo` etc
+ if let QueryStorage::Input = query.storage {
+ let set_fn_name = format_ident!("set_{}", fn_name);
+ let set_with_durability_fn_name = format_ident!("set_{}_with_durability", fn_name);
+
+ let set_fn_docs = format!(
+ "
+ Set the value of the `{fn_name}` input.
+
+ See `{fn_name}` for details.
+
+ *Note:* Setting values will trigger cancellation
+ of any ongoing queries; this method blocks until
+ those queries have been cancelled.
+ ",
+ fn_name = fn_name
+ );
+
+ let set_constant_fn_docs = format!(
+ "
+ Set the value of the `{fn_name}` input with a
+ specific durability instead of the default of
+ `Durability::LOW`. You can use `Durability::MAX`
+ to promise that its value will never change again.
+
+ See `{fn_name}` for details.
+
+ *Note:* Setting values will trigger cancellation
+ of any ongoing queries; this method blocks until
+ those queries have been cancelled.
+ ",
+ fn_name = fn_name
+ );
+
+ query_fn_declarations.extend(quote! {
+ # [doc = #set_fn_docs]
+ fn #set_fn_name(&mut self, #(#key_names: #keys,)* value__: #value);
+
+
+ # [doc = #set_constant_fn_docs]
+ fn #set_with_durability_fn_name(&mut self, #(#key_names: #keys,)* value__: #value, durability__: salsa::Durability);
+ });
+
+ query_fn_definitions.extend(quote! {
+ fn #set_fn_name(&mut self, #(#key_names: #keys,)* value__: #value) {
+ fn __shim(db: &mut dyn #trait_name, #(#key_names: #keys,)* value__: #value) {
+ salsa::plumbing::get_query_table_mut::<#qt>(db).set((#(#key_names),*), value__)
+ }
+ __shim(self, #(#key_names,)* value__)
+ }
+
+ fn #set_with_durability_fn_name(&mut self, #(#key_names: #keys,)* value__: #value, durability__: salsa::Durability) {
+ fn __shim(db: &mut dyn #trait_name, #(#key_names: #keys,)* value__: #value, durability__: salsa::Durability) {
+ salsa::plumbing::get_query_table_mut::<#qt>(db).set_with_durability((#(#key_names),*), value__, durability__)
+ }
+ __shim(self, #(#key_names,)* value__ ,durability__)
+ }
+ });
+ }
+
+ // A field for the storage struct
+ storage_fields.extend(quote! {
+ #fn_name: std::sync::Arc<<#qt as salsa::Query>::Storage>,
+ });
+ }
+
+ // Emit the trait itself.
+ let mut output = {
+ let bounds = &input.supertraits;
+ quote! {
+ #(#trait_attrs)*
+ #trait_vis trait #trait_name :
+ salsa::Database +
+ salsa::plumbing::HasQueryGroup<#group_struct> +
+ #bounds
+ {
+ #query_fn_declarations
+ }
+ }
+ };
+
+ // Emit the query group struct and impl of `QueryGroup`.
+ output.extend(quote! {
+ /// Representative struct for the query group.
+ #trait_vis struct #group_struct { }
+
+ impl salsa::plumbing::QueryGroup for #group_struct
+ {
+ type DynDb = #dyn_db;
+ type GroupStorage = #group_storage;
+ }
+ });
+
+ // Emit an impl of the trait
+ output.extend({
+ let bounds = input.supertraits;
+ quote! {
+ impl<DB> #trait_name for DB
+ where
+ DB: #bounds,
+ DB: salsa::Database,
+ DB: salsa::plumbing::HasQueryGroup<#group_struct>,
+ {
+ #query_fn_definitions
+ }
+ }
+ });
+
+ let non_transparent_queries =
+ || queries.iter().filter(|q| !matches!(q.storage, QueryStorage::Transparent));
+
+ // Emit the query types.
+ for (query, query_index) in non_transparent_queries().zip(0_u16..) {
+ let fn_name = &query.fn_name;
+ let qt = &query.query_type;
+
+ let storage = match &query.storage {
+ QueryStorage::Memoized => quote!(salsa::plumbing::MemoizedStorage<Self>),
+ QueryStorage::Dependencies => {
+ quote!(salsa::plumbing::DependencyStorage<Self>)
+ }
+ QueryStorage::Input => quote!(salsa::plumbing::InputStorage<Self>),
+ QueryStorage::Interned => quote!(salsa::plumbing::InternedStorage<Self>),
+ QueryStorage::InternedLookup { intern_query_type } => {
+ quote!(salsa::plumbing::LookupInternedStorage<Self, #intern_query_type>)
+ }
+ QueryStorage::Transparent => panic!("should have been filtered"),
+ };
+ let keys = query.keys.iter().map(|(_, ty)| ty);
+ let value = &query.value;
+ let query_name = &query.query_name;
+
+ // Emit the query struct and implement the Query trait on it.
+ output.extend(quote! {
+ #[derive(Default, Debug)]
+ #trait_vis struct #qt;
+ });
+
+ output.extend(quote! {
+ impl #qt {
+ /// Get access to extra methods pertaining to this query.
+ /// You can also use it to invoke this query.
+ #trait_vis fn in_db(self, db: &#dyn_db) -> salsa::QueryTable<'_, Self>
+ {
+ salsa::plumbing::get_query_table::<#qt>(db)
+ }
+ }
+ });
+
+ output.extend(quote! {
+ impl #qt {
+ /// Like `in_db`, but gives access to methods for setting the
+ /// value of an input. Not applicable to derived queries.
+ ///
+ /// # Threads, cancellation, and blocking
+ ///
+ /// Mutating the value of a query cannot be done while there are
+ /// still other queries executing. If you are using your database
+ /// within a single thread, this is not a problem: you only have
+ /// `&self` access to the database, but this method requires `&mut
+ /// self`.
+ ///
+ /// However, if you have used `snapshot` to create other threads,
+ /// then attempts to `set` will **block the current thread** until
+ /// those snapshots are dropped (usually when those threads
+ /// complete). This also implies that if you create a snapshot but
+ /// do not send it to another thread, then invoking `set` will
+ /// deadlock.
+ ///
+ /// Before blocking, the thread that is attempting to `set` will
+ /// also set a cancellation flag. This will cause any query
+ /// invocations in other threads to unwind with a `Cancelled`
+ /// sentinel value and eventually let the `set` succeed once all
+ /// threads have unwound past the salsa invocation.
+ ///
+ /// If your query implementations are performing expensive
+ /// operations without invoking another query, you can also use
+ /// the `Runtime::unwind_if_cancelled` method to check for an
+ /// ongoing cancellation and bring those operations to a close,
+ /// thus allowing the `set` to succeed. Otherwise, long-running
+ /// computations may lead to "starvation", meaning that the
+ /// thread attempting to `set` has to wait a long, long time. =)
+ #trait_vis fn in_db_mut(self, db: &mut #dyn_db) -> salsa::QueryTableMut<'_, Self>
+ {
+ salsa::plumbing::get_query_table_mut::<#qt>(db)
+ }
+ }
+
+ impl<'d> salsa::QueryDb<'d> for #qt
+ {
+ type DynDb = #dyn_db + 'd;
+ type Group = #group_struct;
+ type GroupStorage = #group_storage;
+ }
+
+ // ANCHOR:Query_impl
+ impl salsa::Query for #qt
+ {
+ type Key = (#(#keys),*);
+ type Value = #value;
+ type Storage = #storage;
+
+ const QUERY_INDEX: u16 = #query_index;
+
+ const QUERY_NAME: &'static str = #query_name;
+
+ fn query_storage<'a>(
+ group_storage: &'a <Self as salsa::QueryDb<'_>>::GroupStorage,
+ ) -> &'a std::sync::Arc<Self::Storage> {
+ &group_storage.#fn_name
+ }
+
+ fn query_storage_mut<'a>(
+ group_storage: &'a <Self as salsa::QueryDb<'_>>::GroupStorage,
+ ) -> &'a std::sync::Arc<Self::Storage> {
+ &group_storage.#fn_name
+ }
+ }
+ // ANCHOR_END:Query_impl
+ });
+
+ // Implement the QueryFunction trait for queries which need it.
+ if query.storage.needs_query_function() {
+ let span = query.fn_name.span();
+
+ let key_names: Vec<_> = query.keys.iter().map(|(pat, _)| pat).collect();
+ let key_pattern = if query.keys.len() == 1 {
+ quote! { #(#key_names),* }
+ } else {
+ quote! { (#(#key_names),*) }
+ };
+ let invoke = query.invoke_tt();
+
+ let recover = if let Some(cycle_recovery_fn) = &query.cycle {
+ quote! {
+ const CYCLE_STRATEGY: salsa::plumbing::CycleRecoveryStrategy =
+ salsa::plumbing::CycleRecoveryStrategy::Fallback;
+ fn cycle_fallback(db: &<Self as salsa::QueryDb<'_>>::DynDb, cycle: &salsa::Cycle, #key_pattern: &<Self as salsa::Query>::Key)
+ -> <Self as salsa::Query>::Value {
+ #cycle_recovery_fn(
+ db,
+ cycle,
+ #(#key_names),*
+ )
+ }
+ }
+ } else {
+ quote! {
+ const CYCLE_STRATEGY: salsa::plumbing::CycleRecoveryStrategy =
+ salsa::plumbing::CycleRecoveryStrategy::Panic;
+ }
+ };
+
+ output.extend(quote_spanned! {span=>
+ // ANCHOR:QueryFunction_impl
+ impl salsa::plumbing::QueryFunction for #qt
+ {
+ fn execute(db: &<Self as salsa::QueryDb<'_>>::DynDb, #key_pattern: <Self as salsa::Query>::Key)
+ -> <Self as salsa::Query>::Value {
+ #invoke(db, #(#key_names),*)
+ }
+
+ #recover
+ }
+ // ANCHOR_END:QueryFunction_impl
+ });
+ }
+ }
+
+ let mut fmt_ops = proc_macro2::TokenStream::new();
+ for (Query { fn_name, .. }, query_index) in non_transparent_queries().zip(0_u16..) {
+ fmt_ops.extend(quote! {
+ #query_index => {
+ salsa::plumbing::QueryStorageOps::fmt_index(
+ &*self.#fn_name, db, input, fmt,
+ )
+ }
+ });
+ }
+
+ let mut maybe_changed_ops = proc_macro2::TokenStream::new();
+ for (Query { fn_name, .. }, query_index) in non_transparent_queries().zip(0_u16..) {
+ maybe_changed_ops.extend(quote! {
+ #query_index => {
+ salsa::plumbing::QueryStorageOps::maybe_changed_after(
+ &*self.#fn_name, db, input, revision
+ )
+ }
+ });
+ }
+
+ let mut cycle_recovery_strategy_ops = proc_macro2::TokenStream::new();
+ for (Query { fn_name, .. }, query_index) in non_transparent_queries().zip(0_u16..) {
+ cycle_recovery_strategy_ops.extend(quote! {
+ #query_index => {
+ salsa::plumbing::QueryStorageOps::cycle_recovery_strategy(
+ &*self.#fn_name
+ )
+ }
+ });
+ }
+
+ let mut for_each_ops = proc_macro2::TokenStream::new();
+ for Query { fn_name, .. } in non_transparent_queries() {
+ for_each_ops.extend(quote! {
+ op(&*self.#fn_name);
+ });
+ }
+
+ // Emit query group storage struct
+ output.extend(quote! {
+ #trait_vis struct #group_storage {
+ #storage_fields
+ }
+
+ // ANCHOR:group_storage_new
+ impl #group_storage {
+ #trait_vis fn new(group_index: u16) -> Self {
+ #group_storage {
+ #(
+ #queries_with_storage:
+ std::sync::Arc::new(salsa::plumbing::QueryStorageOps::new(group_index)),
+ )*
+ }
+ }
+ }
+ // ANCHOR_END:group_storage_new
+
+ // ANCHOR:group_storage_methods
+ impl #group_storage {
+ #trait_vis fn fmt_index(
+ &self,
+ db: &(#dyn_db + '_),
+ input: salsa::DatabaseKeyIndex,
+ fmt: &mut std::fmt::Formatter<'_>,
+ ) -> std::fmt::Result {
+ match input.query_index() {
+ #fmt_ops
+ i => panic!("salsa: impossible query index {}", i),
+ }
+ }
+
+ #trait_vis fn maybe_changed_after(
+ &self,
+ db: &(#dyn_db + '_),
+ input: salsa::DatabaseKeyIndex,
+ revision: salsa::Revision,
+ ) -> bool {
+ match input.query_index() {
+ #maybe_changed_ops
+ i => panic!("salsa: impossible query index {}", i),
+ }
+ }
+
+ #trait_vis fn cycle_recovery_strategy(
+ &self,
+ db: &(#dyn_db + '_),
+ input: salsa::DatabaseKeyIndex,
+ ) -> salsa::plumbing::CycleRecoveryStrategy {
+ match input.query_index() {
+ #cycle_recovery_strategy_ops
+ i => panic!("salsa: impossible query index {}", i),
+ }
+ }
+
+ #trait_vis fn for_each_query(
+ &self,
+ _runtime: &salsa::Runtime,
+ mut op: &mut dyn FnMut(&dyn salsa::plumbing::QueryStorageMassOps),
+ ) {
+ #for_each_ops
+ }
+ }
+ // ANCHOR_END:group_storage_methods
+ });
+ output.into()
+}
+
+struct SalsaAttr {
+ name: String,
+ tts: TokenStream,
+ span: Span,
+}
+
+impl std::fmt::Debug for SalsaAttr {
+ fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ write!(fmt, "{:?}", self.name)
+ }
+}
+
+impl TryFrom<syn::Attribute> for SalsaAttr {
+ type Error = syn::Attribute;
+
+ fn try_from(attr: syn::Attribute) -> Result<SalsaAttr, syn::Attribute> {
+ if is_not_salsa_attr_path(attr.path()) {
+ return Err(attr);
+ }
+
+ let span = attr.span();
+ let name = attr.path().segments[1].ident.to_string();
+ let tts = match attr.meta {
+ syn::Meta::Path(path) => path.into_token_stream(),
+ syn::Meta::List(ref list) => {
+ let tts = list
+ .into_token_stream()
+ .into_iter()
+ .skip(attr.path().to_token_stream().into_iter().count());
+ proc_macro2::TokenStream::from_iter(tts)
+ }
+ syn::Meta::NameValue(nv) => nv.into_token_stream(),
+ }
+ .into();
+
+ Ok(SalsaAttr { name, tts, span })
+ }
+}
+
+fn is_not_salsa_attr_path(path: &syn::Path) -> bool {
+ path.segments.first().map(|s| s.ident != "salsa").unwrap_or(true) || path.segments.len() != 2
+}
+
+fn filter_attrs(attrs: Vec<Attribute>) -> (Vec<Attribute>, Vec<SalsaAttr>) {
+ let mut other = vec![];
+ let mut salsa = vec![];
+ // Leave non-salsa attributes untouched. These are
+ // attributes that don't start with `salsa::` or don't have
+ // exactly two segments in their path.
+ // Keep the salsa attributes around.
+ for attr in attrs {
+ match SalsaAttr::try_from(attr) {
+ Ok(it) => salsa.push(it),
+ Err(it) => other.push(it),
+ }
+ }
+ (other, salsa)
+}
+
+#[derive(Debug)]
+struct Query {
+ fn_name: Ident,
+ receiver: syn::Receiver,
+ query_name: String,
+ attrs: Vec<syn::Attribute>,
+ query_type: Ident,
+ storage: QueryStorage,
+ keys: Vec<(Ident, syn::Type)>,
+ value: syn::Type,
+ invoke: Option<syn::Path>,
+ cycle: Option<syn::Path>,
+}
+
+impl Query {
+ fn invoke_tt(&self) -> proc_macro2::TokenStream {
+ match &self.invoke {
+ Some(i) => i.into_token_stream(),
+ None => self.fn_name.clone().into_token_stream(),
+ }
+ }
+}
+
+#[derive(Debug, Clone, PartialEq, Eq)]
+enum QueryStorage {
+ Memoized,
+ Dependencies,
+ Input,
+ Interned,
+ InternedLookup { intern_query_type: Ident },
+ Transparent,
+}
+
+impl QueryStorage {
+ /// Do we need a `QueryFunction` impl for this type of query?
+ fn needs_query_function(&self) -> bool {
+ match self {
+ QueryStorage::Input
+ | QueryStorage::Interned
+ | QueryStorage::InternedLookup { .. }
+ | QueryStorage::Transparent => false,
+ QueryStorage::Memoized | QueryStorage::Dependencies => true,
+ }
+ }
+}
diff --git a/crates/salsa/src/debug.rs b/crates/salsa/src/debug.rs
new file mode 100644
index 0000000000..0925ddb3d8
--- /dev/null
+++ b/crates/salsa/src/debug.rs
@@ -0,0 +1,66 @@
+//! Debugging APIs: these are meant for use when unit-testing or
+//! debugging your application but aren't ordinarily needed.
+
+use crate::durability::Durability;
+use crate::plumbing::QueryStorageOps;
+use crate::Query;
+use crate::QueryTable;
+use std::iter::FromIterator;
+
+/// Additional methods on queries that can be used to "peek into"
+/// their current state. These methods are meant for debugging and
+/// observing the effects of garbage collection etc.
+pub trait DebugQueryTable {
+ /// Key of this query.
+ type Key;
+
+ /// Value of this query.
+ type Value;
+
+ /// Returns a lower bound on the durability for the given key.
+ /// This is typically the minimum durability of all values that
+ /// the query accessed, but we may return a lower durability in
+ /// some cases.
+ fn durability(&self, key: Self::Key) -> Durability;
+
+ /// Get the (current) set of the entries in the query table.
+ fn entries<C>(&self) -> C
+ where
+ C: FromIterator<TableEntry<Self::Key, Self::Value>>;
+}
+
+/// An entry from a query table, for debugging and inspecting the table state.
+#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
+#[non_exhaustive]
+pub struct TableEntry<K, V> {
+ /// key of the query
+ pub key: K,
+ /// value of the query, if it is stored
+ pub value: Option<V>,
+}
+
+impl<K, V> TableEntry<K, V> {
+ pub(crate) fn new(key: K, value: Option<V>) -> TableEntry<K, V> {
+ TableEntry { key, value }
+ }
+}
+
+impl<Q> DebugQueryTable for QueryTable<'_, Q>
+where
+ Q: Query,
+ Q::Storage: QueryStorageOps<Q>,
+{
+ type Key = Q::Key;
+ type Value = Q::Value;
+
+ fn durability(&self, key: Q::Key) -> Durability {
+ self.storage.durability(self.db, &key)
+ }
+
+ fn entries<C>(&self) -> C
+ where
+ C: FromIterator<TableEntry<Self::Key, Self::Value>>,
+ {
+ self.storage.entries(self.db)
+ }
+}
diff --git a/crates/salsa/src/derived.rs b/crates/salsa/src/derived.rs
new file mode 100644
index 0000000000..c381e66e08
--- /dev/null
+++ b/crates/salsa/src/derived.rs
@@ -0,0 +1,233 @@
+//!
+use crate::debug::TableEntry;
+use crate::durability::Durability;
+use crate::hash::FxIndexMap;
+use crate::lru::Lru;
+use crate::plumbing::DerivedQueryStorageOps;
+use crate::plumbing::LruQueryStorageOps;
+use crate::plumbing::QueryFunction;
+use crate::plumbing::QueryStorageMassOps;
+use crate::plumbing::QueryStorageOps;
+use crate::runtime::StampedValue;
+use crate::Runtime;
+use crate::{Database, DatabaseKeyIndex, QueryDb, Revision};
+use parking_lot::RwLock;
+use std::borrow::Borrow;
+use std::convert::TryFrom;
+use std::hash::Hash;
+use std::marker::PhantomData;
+use triomphe::Arc;
+
+mod slot;
+use slot::Slot;
+
+/// Memoized queries store the result plus a list of the other queries
+/// that they invoked. This means we can avoid recomputing them when
+/// none of those inputs have changed.
+pub type MemoizedStorage<Q> = DerivedStorage<Q, AlwaysMemoizeValue>;
+
+/// "Dependency" queries just track their dependencies and not the
+/// actual value (which they produce on demand). This lessens the
+/// storage requirements.
+pub type DependencyStorage<Q> = DerivedStorage<Q, NeverMemoizeValue>;
+
+/// Handles storage where the value is 'derived' by executing a
+/// function (in contrast to "inputs").
+pub struct DerivedStorage<Q, MP>
+where
+ Q: QueryFunction,
+ MP: MemoizationPolicy<Q>,
+{
+ group_index: u16,
+ lru_list: Lru<Slot<Q, MP>>,
+ slot_map: RwLock<FxIndexMap<Q::Key, Arc<Slot<Q, MP>>>>,
+ policy: PhantomData<MP>,
+}
+
+impl<Q, MP> std::panic::RefUnwindSafe for DerivedStorage<Q, MP>
+where
+ Q: QueryFunction,
+ MP: MemoizationPolicy<Q>,
+ Q::Key: std::panic::RefUnwindSafe,
+ Q::Value: std::panic::RefUnwindSafe,
+{
+}
+
+pub trait MemoizationPolicy<Q>: Send + Sync
+where
+ Q: QueryFunction,
+{
+ fn should_memoize_value(key: &Q::Key) -> bool;
+
+ fn memoized_value_eq(old_value: &Q::Value, new_value: &Q::Value) -> bool;
+}
+
+pub enum AlwaysMemoizeValue {}
+impl<Q> MemoizationPolicy<Q> for AlwaysMemoizeValue
+where
+ Q: QueryFunction,
+ Q::Value: Eq,
+{
+ fn should_memoize_value(_key: &Q::Key) -> bool {
+ true
+ }
+
+ fn memoized_value_eq(old_value: &Q::Value, new_value: &Q::Value) -> bool {
+ old_value == new_value
+ }
+}
+
+pub enum NeverMemoizeValue {}
+impl<Q> MemoizationPolicy<Q> for NeverMemoizeValue
+where
+ Q: QueryFunction,
+{
+ fn should_memoize_value(_key: &Q::Key) -> bool {
+ false
+ }
+
+ fn memoized_value_eq(_old_value: &Q::Value, _new_value: &Q::Value) -> bool {
+ panic!("cannot reach since we never memoize")
+ }
+}
+
+impl<Q, MP> DerivedStorage<Q, MP>
+where
+ Q: QueryFunction,
+ MP: MemoizationPolicy<Q>,
+{
+ fn slot(&self, key: &Q::Key) -> Arc<Slot<Q, MP>> {
+ if let Some(v) = self.slot_map.read().get(key) {
+ return v.clone();
+ }
+
+ let mut write = self.slot_map.write();
+ let entry = write.entry(key.clone());
+ let key_index = u32::try_from(entry.index()).unwrap();
+ let database_key_index = DatabaseKeyIndex {
+ group_index: self.group_index,
+ query_index: Q::QUERY_INDEX,
+ key_index,
+ };
+ entry.or_insert_with(|| Arc::new(Slot::new(key.clone(), database_key_index))).clone()
+ }
+}
+
+impl<Q, MP> QueryStorageOps<Q> for DerivedStorage<Q, MP>
+where
+ Q: QueryFunction,
+ MP: MemoizationPolicy<Q>,
+{
+ const CYCLE_STRATEGY: crate::plumbing::CycleRecoveryStrategy = Q::CYCLE_STRATEGY;
+
+ fn new(group_index: u16) -> Self {
+ DerivedStorage {
+ group_index,
+ slot_map: RwLock::new(FxIndexMap::default()),
+ lru_list: Default::default(),
+ policy: PhantomData,
+ }
+ }
+
+ fn fmt_index(
+ &self,
+ _db: &<Q as QueryDb<'_>>::DynDb,
+ index: DatabaseKeyIndex,
+ fmt: &mut std::fmt::Formatter<'_>,
+ ) -> std::fmt::Result {
+ assert_eq!(index.group_index, self.group_index);
+ assert_eq!(index.query_index, Q::QUERY_INDEX);
+ let slot_map = self.slot_map.read();
+ let key = slot_map.get_index(index.key_index as usize).unwrap().0;
+ write!(fmt, "{}({:?})", Q::QUERY_NAME, key)
+ }
+
+ fn maybe_changed_after(
+ &self,
+ db: &<Q as QueryDb<'_>>::DynDb,
+ input: DatabaseKeyIndex,
+ revision: Revision,
+ ) -> bool {
+ assert_eq!(input.group_index, self.group_index);
+ assert_eq!(input.query_index, Q::QUERY_INDEX);
+ debug_assert!(revision < db.salsa_runtime().current_revision());
+ let slot = self.slot_map.read().get_index(input.key_index as usize).unwrap().1.clone();
+ slot.maybe_changed_after(db, revision)
+ }
+
+ fn fetch(&self, db: &<Q as QueryDb<'_>>::DynDb, key: &Q::Key) -> Q::Value {
+ db.unwind_if_cancelled();
+
+ let slot = self.slot(key);
+ let StampedValue { value, durability, changed_at } = slot.read(db);
+
+ if let Some(evicted) = self.lru_list.record_use(&slot) {
+ evicted.evict();
+ }
+
+ db.salsa_runtime().report_query_read_and_unwind_if_cycle_resulted(
+ slot.database_key_index(),
+ durability,
+ changed_at,
+ );
+
+ value
+ }
+
+ fn durability(&self, db: &<Q as QueryDb<'_>>::DynDb, key: &Q::Key) -> Durability {
+ self.slot(key).durability(db)
+ }
+
+ fn entries<C>(&self, _db: &<Q as QueryDb<'_>>::DynDb) -> C
+ where
+ C: std::iter::FromIterator<TableEntry<Q::Key, Q::Value>>,
+ {
+ let slot_map = self.slot_map.read();
+ slot_map.values().filter_map(|slot| slot.as_table_entry()).collect()
+ }
+}
+
+impl<Q, MP> QueryStorageMassOps for DerivedStorage<Q, MP>
+where
+ Q: QueryFunction,
+ MP: MemoizationPolicy<Q>,
+{
+ fn purge(&self) {
+ self.lru_list.purge();
+ *self.slot_map.write() = Default::default();
+ }
+}
+
+impl<Q, MP> LruQueryStorageOps for DerivedStorage<Q, MP>
+where
+ Q: QueryFunction,
+ MP: MemoizationPolicy<Q>,
+{
+ fn set_lru_capacity(&self, new_capacity: usize) {
+ self.lru_list.set_lru_capacity(new_capacity);
+ }
+}
+
+impl<Q, MP> DerivedQueryStorageOps<Q> for DerivedStorage<Q, MP>
+where
+ Q: QueryFunction,
+ MP: MemoizationPolicy<Q>,
+{
+ fn invalidate<S>(&self, runtime: &mut Runtime, key: &S)
+ where
+ S: Eq + Hash,
+ Q::Key: Borrow<S>,
+ {
+ runtime.with_incremented_revision(|new_revision| {
+ let map_read = self.slot_map.read();
+
+ if let Some(slot) = map_read.get(key) {
+ if let Some(durability) = slot.invalidate(new_revision) {
+ return Some(durability);
+ }
+ }
+
+ None
+ })
+ }
+}
diff --git a/crates/salsa/src/derived/slot.rs b/crates/salsa/src/derived/slot.rs
new file mode 100644
index 0000000000..4fad791a26
--- /dev/null
+++ b/crates/salsa/src/derived/slot.rs
@@ -0,0 +1,833 @@
+//!
+use crate::debug::TableEntry;
+use crate::derived::MemoizationPolicy;
+use crate::durability::Durability;
+use crate::lru::LruIndex;
+use crate::lru::LruNode;
+use crate::plumbing::{DatabaseOps, QueryFunction};
+use crate::revision::Revision;
+use crate::runtime::local_state::ActiveQueryGuard;
+use crate::runtime::local_state::QueryInputs;
+use crate::runtime::local_state::QueryRevisions;
+use crate::runtime::Runtime;
+use crate::runtime::RuntimeId;
+use crate::runtime::StampedValue;
+use crate::runtime::WaitResult;
+use crate::Cycle;
+use crate::{Database, DatabaseKeyIndex, Event, EventKind, QueryDb};
+use parking_lot::{RawRwLock, RwLock};
+use std::marker::PhantomData;
+use std::ops::Deref;
+use std::sync::atomic::{AtomicBool, Ordering};
+use tracing::{debug, info};
+
+pub(super) struct Slot<Q, MP>
+where
+ Q: QueryFunction,
+ MP: MemoizationPolicy<Q>,
+{
+ key: Q::Key,
+ database_key_index: DatabaseKeyIndex,
+ state: RwLock<QueryState<Q>>,
+ policy: PhantomData<MP>,
+ lru_index: LruIndex,
+}
+
+/// Defines the "current state" of query's memoized results.
+enum QueryState<Q>
+where
+ Q: QueryFunction,
+{
+ NotComputed,
+
+ /// The runtime with the given id is currently computing the
+ /// result of this query.
+ InProgress {
+ id: RuntimeId,
+
+ /// Set to true if any other queries are blocked,
+ /// waiting for this query to complete.
+ anyone_waiting: AtomicBool,
+ },
+
+ /// We have computed the query already, and here is the result.
+ Memoized(Memo<Q::Value>),
+}
+
+struct Memo<V> {
+ /// The result of the query, if we decide to memoize it.
+ value: Option<V>,
+
+ /// Last revision when this memo was verified; this begins
+ /// as the current revision.
+ pub(crate) verified_at: Revision,
+
+ /// Revision information
+ revisions: QueryRevisions,
+}
+
+/// Return value of `probe` helper.
+enum ProbeState<V, G> {
+ /// Another thread was active but has completed.
+ /// Try again!
+ Retry,
+
+ /// No entry for this key at all.
+ NotComputed(G),
+
+ /// There is an entry, but its contents have not been
+ /// verified in this revision.
+ Stale(G),
+
+ /// There is an entry, and it has been verified
+ /// in this revision, but it has no cached
+ /// value. The `Revision` is the revision where the
+ /// value last changed (if we were to recompute it).
+ NoValue(G, Revision),
+
+ /// There is an entry which has been verified,
+ /// and it has the following value-- or, we blocked
+ /// on another thread, and that resulted in a cycle.
+ UpToDate(V),
+}
+
+/// Return value of `maybe_changed_after_probe` helper.
+enum MaybeChangedSinceProbeState<G> {
+ /// Another thread was active but has completed.
+ /// Try again!
+ Retry,
+
+ /// Value may have changed in the given revision.
+ ChangedAt(Revision),
+
+ /// There is a stale cache entry that has not been
+ /// verified in this revision, so we can't say.
+ Stale(G),
+}
+
+impl<Q, MP> Slot<Q, MP>
+where
+ Q: QueryFunction,
+ MP: MemoizationPolicy<Q>,
+{
+ pub(super) fn new(key: Q::Key, database_key_index: DatabaseKeyIndex) -> Self {
+ Self {
+ key,
+ database_key_index,
+ state: RwLock::new(QueryState::NotComputed),
+ lru_index: LruIndex::default(),
+ policy: PhantomData,
+ }
+ }
+
+ pub(super) fn database_key_index(&self) -> DatabaseKeyIndex {
+ self.database_key_index
+ }
+
+ pub(super) fn read(&self, db: &<Q as QueryDb<'_>>::DynDb) -> StampedValue<Q::Value> {
+ let runtime = db.salsa_runtime();
+
+ // NB: We don't need to worry about people modifying the
+ // revision out from under our feet. Either `db` is a frozen
+ // database, in which case there is a lock, or the mutator
+ // thread is the current thread, and it will be prevented from
+ // doing any `set` invocations while the query function runs.
+ let revision_now = runtime.current_revision();
+
+ info!("{:?}: invoked at {:?}", self, revision_now,);
+
+ // First, do a check with a read-lock.
+ loop {
+ match self.probe(db, self.state.read(), runtime, revision_now) {
+ ProbeState::UpToDate(v) => return v,
+ ProbeState::Stale(..) | ProbeState::NoValue(..) | ProbeState::NotComputed(..) => {
+ break
+ }
+ ProbeState::Retry => continue,
+ }
+ }
+
+ self.read_upgrade(db, revision_now)
+ }
+
+ /// Second phase of a read operation: acquires an upgradable-read
+ /// and -- if needed -- validates whether inputs have changed,
+ /// recomputes value, etc. This is invoked after our initial probe
+ /// shows a potentially out of date value.
+ fn read_upgrade(
+ &self,
+ db: &<Q as QueryDb<'_>>::DynDb,
+ revision_now: Revision,
+ ) -> StampedValue<Q::Value> {
+ let runtime = db.salsa_runtime();
+
+ debug!("{:?}: read_upgrade(revision_now={:?})", self, revision_now,);
+
+ // Check with an upgradable read to see if there is a value
+ // already. (This permits other readers but prevents anyone
+ // else from running `read_upgrade` at the same time.)
+ let mut old_memo = loop {
+ match self.probe(db, self.state.upgradable_read(), runtime, revision_now) {
+ ProbeState::UpToDate(v) => return v,
+ ProbeState::Stale(state)
+ | ProbeState::NotComputed(state)
+ | ProbeState::NoValue(state, _) => {
+ type RwLockUpgradableReadGuard<'a, T> =
+ lock_api::RwLockUpgradableReadGuard<'a, RawRwLock, T>;
+
+ let mut state = RwLockUpgradableReadGuard::upgrade(state);
+ match std::mem::replace(&mut *state, QueryState::in_progress(runtime.id())) {
+ QueryState::Memoized(old_memo) => break Some(old_memo),
+ QueryState::InProgress { .. } => unreachable!(),
+ QueryState::NotComputed => break None,
+ }
+ }
+ ProbeState::Retry => continue,
+ }
+ };
+
+ let panic_guard = PanicGuard::new(self.database_key_index, self, runtime);
+ let active_query = runtime.push_query(self.database_key_index);
+
+ // If we have an old-value, it *may* now be stale, since there
+ // has been a new revision since the last time we checked. So,
+ // first things first, let's walk over each of our previous
+ // inputs and check whether they are out of date.
+ if let Some(memo) = &mut old_memo {
+ if let Some(value) = memo.verify_value(db.ops_database(), revision_now, &active_query) {
+ info!("{:?}: validated old memoized value", self,);
+
+ db.salsa_event(Event {
+ runtime_id: runtime.id(),
+ kind: EventKind::DidValidateMemoizedValue {
+ database_key: self.database_key_index,
+ },
+ });
+
+ panic_guard.proceed(old_memo);
+
+ return value;
+ }
+ }
+
+ self.execute(db, runtime, revision_now, active_query, panic_guard, old_memo)
+ }
+
+ fn execute(
+ &self,
+ db: &<Q as QueryDb<'_>>::DynDb,
+ runtime: &Runtime,
+ revision_now: Revision,
+ active_query: ActiveQueryGuard<'_>,
+ panic_guard: PanicGuard<'_, Q, MP>,
+ old_memo: Option<Memo<Q::Value>>,
+ ) -> StampedValue<Q::Value> {
+ tracing::info!("{:?}: executing query", self.database_key_index.debug(db));
+
+ db.salsa_event(Event {
+ runtime_id: db.salsa_runtime().id(),
+ kind: EventKind::WillExecute { database_key: self.database_key_index },
+ });
+
+ // Query was not previously executed, or value is potentially
+ // stale, or value is absent. Let's execute!
+ let value = match Cycle::catch(|| Q::execute(db, self.key.clone())) {
+ Ok(v) => v,
+ Err(cycle) => {
+ tracing::debug!(
+ "{:?}: caught cycle {:?}, have strategy {:?}",
+ self.database_key_index.debug(db),
+ cycle,
+ Q::CYCLE_STRATEGY,
+ );
+ match Q::CYCLE_STRATEGY {
+ crate::plumbing::CycleRecoveryStrategy::Panic => {
+ panic_guard.proceed(None);
+ cycle.throw()
+ }
+ crate::plumbing::CycleRecoveryStrategy::Fallback => {
+ if let Some(c) = active_query.take_cycle() {
+ assert!(c.is(&cycle));
+ Q::cycle_fallback(db, &cycle, &self.key)
+ } else {
+ // we are not a participant in this cycle
+ debug_assert!(!cycle
+ .participant_keys()
+ .any(|k| k == self.database_key_index));
+ cycle.throw()
+ }
+ }
+ }
+ }
+ };
+
+ let mut revisions = active_query.pop();
+
+ // We assume that query is side-effect free -- that is, does
+ // not mutate the "inputs" to the query system. Sanity check
+ // that assumption here, at least to the best of our ability.
+ assert_eq!(
+ runtime.current_revision(),
+ revision_now,
+ "revision altered during query execution",
+ );
+
+ // If the new value is equal to the old one, then it didn't
+ // really change, even if some of its inputs have. So we can
+ // "backdate" its `changed_at` revision to be the same as the
+ // old value.
+ if let Some(old_memo) = &old_memo {
+ if let Some(old_value) = &old_memo.value {
+ // Careful: if the value became less durable than it
+ // used to be, that is a "breaking change" that our
+ // consumers must be aware of. Becoming *more* durable
+ // is not. See the test `constant_to_non_constant`.
+ if revisions.durability >= old_memo.revisions.durability
+ && MP::memoized_value_eq(old_value, &value)
+ {
+ debug!(
+ "read_upgrade({:?}): value is equal, back-dating to {:?}",
+ self, old_memo.revisions.changed_at,
+ );
+
+ assert!(old_memo.revisions.changed_at <= revisions.changed_at);
+ revisions.changed_at = old_memo.revisions.changed_at;
+ }
+ }
+ }
+
+ let new_value = StampedValue {
+ value,
+ durability: revisions.durability,
+ changed_at: revisions.changed_at,
+ };
+
+ let memo_value =
+ if self.should_memoize_value(&self.key) { Some(new_value.value.clone()) } else { None };
+
+ debug!("read_upgrade({:?}): result.revisions = {:#?}", self, revisions,);
+
+ panic_guard.proceed(Some(Memo { value: memo_value, verified_at: revision_now, revisions }));
+
+ new_value
+ }
+
+ /// Helper for `read` that does a shallow check (not recursive) if we have an up-to-date value.
+ ///
+ /// Invoked with the guard `state` corresponding to the `QueryState` of some `Slot` (the guard
+ /// can be either read or write). Returns a suitable `ProbeState`:
+ ///
+ /// - `ProbeState::UpToDate(r)` if the table has an up-to-date value (or we blocked on another
+ /// thread that produced such a value).
+ /// - `ProbeState::StaleOrAbsent(g)` if either (a) there is no memo for this key, (b) the memo
+ /// has no value; or (c) the memo has not been verified at the current revision.
+ ///
+ /// Note that in case `ProbeState::UpToDate`, the lock will have been released.
+ fn probe<StateGuard>(
+ &self,
+ db: &<Q as QueryDb<'_>>::DynDb,
+ state: StateGuard,
+ runtime: &Runtime,
+ revision_now: Revision,
+ ) -> ProbeState<StampedValue<Q::Value>, StateGuard>
+ where
+ StateGuard: Deref<Target = QueryState<Q>>,
+ {
+ match &*state {
+ QueryState::NotComputed => ProbeState::NotComputed(state),
+
+ QueryState::InProgress { id, anyone_waiting } => {
+ let other_id = *id;
+
+ // NB: `Ordering::Relaxed` is sufficient here,
+ // as there are no loads that are "gated" on this
+ // value. Everything that is written is also protected
+ // by a lock that must be acquired. The role of this
+ // boolean is to decide *whether* to acquire the lock,
+ // not to gate future atomic reads.
+ anyone_waiting.store(true, Ordering::Relaxed);
+
+ self.block_on_or_unwind(db, runtime, other_id, state);
+
+ // Other thread completely normally, so our value may be available now.
+ ProbeState::Retry
+ }
+
+ QueryState::Memoized(memo) => {
+ debug!(
+ "{:?}: found memoized value, verified_at={:?}, changed_at={:?}",
+ self, memo.verified_at, memo.revisions.changed_at,
+ );
+
+ if memo.verified_at < revision_now {
+ return ProbeState::Stale(state);
+ }
+
+ if let Some(value) = &memo.value {
+ let value = StampedValue {
+ durability: memo.revisions.durability,
+ changed_at: memo.revisions.changed_at,
+ value: value.clone(),
+ };
+
+ info!("{:?}: returning memoized value changed at {:?}", self, value.changed_at);
+
+ ProbeState::UpToDate(value)
+ } else {
+ let changed_at = memo.revisions.changed_at;
+ ProbeState::NoValue(state, changed_at)
+ }
+ }
+ }
+ }
+
+ pub(super) fn durability(&self, db: &<Q as QueryDb<'_>>::DynDb) -> Durability {
+ match &*self.state.read() {
+ QueryState::NotComputed => Durability::LOW,
+ QueryState::InProgress { .. } => panic!("query in progress"),
+ QueryState::Memoized(memo) => {
+ if memo.check_durability(db.salsa_runtime()) {
+ memo.revisions.durability
+ } else {
+ Durability::LOW
+ }
+ }
+ }
+ }
+
+ pub(super) fn as_table_entry(&self) -> Option<TableEntry<Q::Key, Q::Value>> {
+ match &*self.state.read() {
+ QueryState::NotComputed => None,
+ QueryState::InProgress { .. } => Some(TableEntry::new(self.key.clone(), None)),
+ QueryState::Memoized(memo) => {
+ Some(TableEntry::new(self.key.clone(), memo.value.clone()))
+ }
+ }
+ }
+
+ pub(super) fn evict(&self) {
+ let mut state = self.state.write();
+ if let QueryState::Memoized(memo) = &mut *state {
+ // Evicting a value with an untracked input could
+ // lead to inconsistencies. Note that we can't check
+ // `has_untracked_input` when we add the value to the cache,
+ // because inputs can become untracked in the next revision.
+ if memo.has_untracked_input() {
+ return;
+ }
+ memo.value = None;
+ }
+ }
+
+ pub(super) fn invalidate(&self, new_revision: Revision) -> Option<Durability> {
+ tracing::debug!("Slot::invalidate(new_revision = {:?})", new_revision);
+ match &mut *self.state.write() {
+ QueryState::Memoized(memo) => {
+ memo.revisions.inputs = QueryInputs::Untracked;
+ memo.revisions.changed_at = new_revision;
+ Some(memo.revisions.durability)
+ }
+ QueryState::NotComputed => None,
+ QueryState::InProgress { .. } => unreachable!(),
+ }
+ }
+
+ pub(super) fn maybe_changed_after(
+ &self,
+ db: &<Q as QueryDb<'_>>::DynDb,
+ revision: Revision,
+ ) -> bool {
+ let runtime = db.salsa_runtime();
+ let revision_now = runtime.current_revision();
+
+ db.unwind_if_cancelled();
+
+ debug!(
+ "maybe_changed_after({:?}) called with revision={:?}, revision_now={:?}",
+ self, revision, revision_now,
+ );
+
+ // Do an initial probe with just the read-lock.
+ //
+ // If we find that a cache entry for the value is present
+ // but hasn't been verified in this revision, we'll have to
+ // do more.
+ loop {
+ match self.maybe_changed_after_probe(db, self.state.read(), runtime, revision_now) {
+ MaybeChangedSinceProbeState::Retry => continue,
+ MaybeChangedSinceProbeState::ChangedAt(changed_at) => return changed_at > revision,
+ MaybeChangedSinceProbeState::Stale(state) => {
+ drop(state);
+ return self.maybe_changed_after_upgrade(db, revision);
+ }
+ }
+ }
+ }
+
+ fn maybe_changed_after_probe<StateGuard>(
+ &self,
+ db: &<Q as QueryDb<'_>>::DynDb,
+ state: StateGuard,
+ runtime: &Runtime,
+ revision_now: Revision,
+ ) -> MaybeChangedSinceProbeState<StateGuard>
+ where
+ StateGuard: Deref<Target = QueryState<Q>>,
+ {
+ match self.probe(db, state, runtime, revision_now) {
+ ProbeState::Retry => MaybeChangedSinceProbeState::Retry,
+
+ ProbeState::Stale(state) => MaybeChangedSinceProbeState::Stale(state),
+
+ // If we know when value last changed, we can return right away.
+ // Note that we don't need the actual value to be available.
+ ProbeState::NoValue(_, changed_at)
+ | ProbeState::UpToDate(StampedValue { value: _, durability: _, changed_at }) => {
+ MaybeChangedSinceProbeState::ChangedAt(changed_at)
+ }
+
+ // If we have nothing cached, then value may have changed.
+ ProbeState::NotComputed(_) => MaybeChangedSinceProbeState::ChangedAt(revision_now),
+ }
+ }
+
+ fn maybe_changed_after_upgrade(
+ &self,
+ db: &<Q as QueryDb<'_>>::DynDb,
+ revision: Revision,
+ ) -> bool {
+ let runtime = db.salsa_runtime();
+ let revision_now = runtime.current_revision();
+
+ // Get an upgradable read lock, which permits other reads but no writers.
+ // Probe again. If the value is stale (needs to be verified), then upgrade
+ // to a write lock and swap it with InProgress while we work.
+ let mut old_memo = match self.maybe_changed_after_probe(
+ db,
+ self.state.upgradable_read(),
+ runtime,
+ revision_now,
+ ) {
+ MaybeChangedSinceProbeState::ChangedAt(changed_at) => return changed_at > revision,
+
+ // If another thread was active, then the cache line is going to be
+ // either verified or cleared out. Just recurse to figure out which.
+ // Note that we don't need an upgradable read.
+ MaybeChangedSinceProbeState::Retry => return self.maybe_changed_after(db, revision),
+
+ MaybeChangedSinceProbeState::Stale(state) => {
+ type RwLockUpgradableReadGuard<'a, T> =
+ lock_api::RwLockUpgradableReadGuard<'a, RawRwLock, T>;
+
+ let mut state = RwLockUpgradableReadGuard::upgrade(state);
+ match std::mem::replace(&mut *state, QueryState::in_progress(runtime.id())) {
+ QueryState::Memoized(old_memo) => old_memo,
+ QueryState::NotComputed | QueryState::InProgress { .. } => unreachable!(),
+ }
+ }
+ };
+
+ let panic_guard = PanicGuard::new(self.database_key_index, self, runtime);
+ let active_query = runtime.push_query(self.database_key_index);
+
+ if old_memo.verify_revisions(db.ops_database(), revision_now, &active_query) {
+ let maybe_changed = old_memo.revisions.changed_at > revision;
+ panic_guard.proceed(Some(old_memo));
+ maybe_changed
+ } else if old_memo.value.is_some() {
+ // We found that this memoized value may have changed
+ // but we have an old value. We can re-run the code and
+ // actually *check* if it has changed.
+ let StampedValue { changed_at, .. } =
+ self.execute(db, runtime, revision_now, active_query, panic_guard, Some(old_memo));
+ changed_at > revision
+ } else {
+ // We found that inputs to this memoized value may have chanced
+ // but we don't have an old value to compare against or re-use.
+ // No choice but to drop the memo and say that its value may have changed.
+ panic_guard.proceed(None);
+ true
+ }
+ }
+
+ /// Helper: see [`Runtime::try_block_on_or_unwind`].
+ fn block_on_or_unwind<MutexGuard>(
+ &self,
+ db: &<Q as QueryDb<'_>>::DynDb,
+ runtime: &Runtime,
+ other_id: RuntimeId,
+ mutex_guard: MutexGuard,
+ ) {
+ runtime.block_on_or_unwind(
+ db.ops_database(),
+ self.database_key_index,
+ other_id,
+ mutex_guard,
+ )
+ }
+
+ fn should_memoize_value(&self, key: &Q::Key) -> bool {
+ MP::should_memoize_value(key)
+ }
+}
+
+impl<Q> QueryState<Q>
+where
+ Q: QueryFunction,
+{
+ fn in_progress(id: RuntimeId) -> Self {
+ QueryState::InProgress { id, anyone_waiting: Default::default() }
+ }
+}
+
+struct PanicGuard<'me, Q, MP>
+where
+ Q: QueryFunction,
+ MP: MemoizationPolicy<Q>,
+{
+ database_key_index: DatabaseKeyIndex,
+ slot: &'me Slot<Q, MP>,
+ runtime: &'me Runtime,
+}
+
+impl<'me, Q, MP> PanicGuard<'me, Q, MP>
+where
+ Q: QueryFunction,
+ MP: MemoizationPolicy<Q>,
+{
+ fn new(
+ database_key_index: DatabaseKeyIndex,
+ slot: &'me Slot<Q, MP>,
+ runtime: &'me Runtime,
+ ) -> Self {
+ Self { database_key_index, slot, runtime }
+ }
+
+ /// Indicates that we have concluded normally (without panicking).
+ /// If `opt_memo` is some, then this memo is installed as the new
+ /// memoized value. If `opt_memo` is `None`, then the slot is cleared
+ /// and has no value.
+ fn proceed(mut self, opt_memo: Option<Memo<Q::Value>>) {
+ self.overwrite_placeholder(WaitResult::Completed, opt_memo);
+ std::mem::forget(self)
+ }
+
+ /// Overwrites the `InProgress` placeholder for `key` that we
+ /// inserted; if others were blocked, waiting for us to finish,
+ /// then notify them.
+ fn overwrite_placeholder(&mut self, wait_result: WaitResult, opt_memo: Option<Memo<Q::Value>>) {
+ let mut write = self.slot.state.write();
+
+ let old_value = match opt_memo {
+ // Replace the `InProgress` marker that we installed with the new
+ // memo, thus releasing our unique access to this key.
+ Some(memo) => std::mem::replace(&mut *write, QueryState::Memoized(memo)),
+
+ // We had installed an `InProgress` marker, but we panicked before
+ // it could be removed. At this point, we therefore "own" unique
+ // access to our slot, so we can just remove the key.
+ None => std::mem::replace(&mut *write, QueryState::NotComputed),
+ };
+
+ match old_value {
+ QueryState::InProgress { id, anyone_waiting } => {
+ assert_eq!(id, self.runtime.id());
+
+ // NB: As noted on the `store`, `Ordering::Relaxed` is
+ // sufficient here. This boolean signals us on whether to
+ // acquire a mutex; the mutex will guarantee that all writes
+ // we are interested in are visible.
+ if anyone_waiting.load(Ordering::Relaxed) {
+ self.runtime.unblock_queries_blocked_on(self.database_key_index, wait_result);
+ }
+ }
+ _ => panic!(
+ "\
+Unexpected panic during query evaluation, aborting the process.
+
+Please report this bug to https://github.com/salsa-rs/salsa/issues."
+ ),
+ }
+ }
+}
+
+impl<'me, Q, MP> Drop for PanicGuard<'me, Q, MP>
+where
+ Q: QueryFunction,
+ MP: MemoizationPolicy<Q>,
+{
+ fn drop(&mut self) {
+ if std::thread::panicking() {
+ // We panicked before we could proceed and need to remove `key`.
+ self.overwrite_placeholder(WaitResult::Panicked, None)
+ } else {
+ // If no panic occurred, then panic guard ought to be
+ // "forgotten" and so this Drop code should never run.
+ panic!(".forget() was not called")
+ }
+ }
+}
+
+impl<V> Memo<V>
+where
+ V: Clone,
+{
+ /// Determines whether the value stored in this memo (if any) is still
+ /// valid in the current revision. If so, returns a stamped value.
+ ///
+ /// If needed, this will walk each dependency and
+ /// recursively invoke `maybe_changed_after`, which may in turn
+ /// re-execute the dependency. This can cause cycles to occur,
+ /// so the current query must be pushed onto the
+ /// stack to permit cycle detection and recovery: therefore,
+ /// takes the `active_query` argument as evidence.
+ fn verify_value(
+ &mut self,
+ db: &dyn Database,
+ revision_now: Revision,
+ active_query: &ActiveQueryGuard<'_>,
+ ) -> Option<StampedValue<V>> {
+ // If we don't have a memoized value, nothing to validate.
+ if self.value.is_none() {
+ return None;
+ }
+ if self.verify_revisions(db, revision_now, active_query) {
+ Some(StampedValue {
+ durability: self.revisions.durability,
+ changed_at: self.revisions.changed_at,
+ value: self.value.as_ref().unwrap().clone(),
+ })
+ } else {
+ None
+ }
+ }
+
+ /// Determines whether the value represented by this memo is still
+ /// valid in the current revision; note that the value itself is
+ /// not needed for this check. If needed, this will walk each
+ /// dependency and recursively invoke `maybe_changed_after`, which
+ /// may in turn re-execute the dependency. This can cause cycles to occur,
+ /// so the current query must be pushed onto the
+ /// stack to permit cycle detection and recovery: therefore,
+ /// takes the `active_query` argument as evidence.
+ fn verify_revisions(
+ &mut self,
+ db: &dyn Database,
+ revision_now: Revision,
+ _active_query: &ActiveQueryGuard<'_>,
+ ) -> bool {
+ assert!(self.verified_at != revision_now);
+ let verified_at = self.verified_at;
+
+ debug!(
+ "verify_revisions: verified_at={:?}, revision_now={:?}, inputs={:#?}",
+ verified_at, revision_now, self.revisions.inputs
+ );
+
+ if self.check_durability(db.salsa_runtime()) {
+ return self.mark_value_as_verified(revision_now);
+ }
+
+ match &self.revisions.inputs {
+ // We can't validate values that had untracked inputs; just have to
+ // re-execute.
+ QueryInputs::Untracked => {
+ return false;
+ }
+
+ QueryInputs::NoInputs => {}
+
+ // Check whether any of our inputs changed since the
+ // **last point where we were verified** (not since we
+ // last changed). This is important: if we have
+ // memoized values, then an input may have changed in
+ // revision R2, but we found that *our* value was the
+ // same regardless, so our change date is still
+ // R1. But our *verification* date will be R2, and we
+ // are only interested in finding out whether the
+ // input changed *again*.
+ QueryInputs::Tracked { inputs } => {
+ let changed_input =
+ inputs.iter().find(|&&input| db.maybe_changed_after(input, verified_at));
+ if let Some(input) = changed_input {
+ debug!("validate_memoized_value: `{:?}` may have changed", input);
+
+ return false;
+ }
+ }
+ };
+
+ self.mark_value_as_verified(revision_now)
+ }
+
+ /// True if this memo is known not to have changed based on its durability.
+ fn check_durability(&self, runtime: &Runtime) -> bool {
+ let last_changed = runtime.last_changed_revision(self.revisions.durability);
+ debug!(
+ "check_durability(last_changed={:?} <= verified_at={:?}) = {:?}",
+ last_changed,
+ self.verified_at,
+ last_changed <= self.verified_at,
+ );
+ last_changed <= self.verified_at
+ }
+
+ fn mark_value_as_verified(&mut self, revision_now: Revision) -> bool {
+ self.verified_at = revision_now;
+ true
+ }
+
+ fn has_untracked_input(&self) -> bool {
+ matches!(self.revisions.inputs, QueryInputs::Untracked)
+ }
+}
+
+impl<Q, MP> std::fmt::Debug for Slot<Q, MP>
+where
+ Q: QueryFunction,
+ MP: MemoizationPolicy<Q>,
+{
+ fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ write!(fmt, "{:?}({:?})", Q::default(), self.key)
+ }
+}
+
+impl<Q, MP> LruNode for Slot<Q, MP>
+where
+ Q: QueryFunction,
+ MP: MemoizationPolicy<Q>,
+{
+ fn lru_index(&self) -> &LruIndex {
+ &self.lru_index
+ }
+}
+
+/// Check that `Slot<Q, MP>: Send + Sync` as long as
+/// `DB::DatabaseData: Send + Sync`, which in turn implies that
+/// `Q::Key: Send + Sync`, `Q::Value: Send + Sync`.
+#[allow(dead_code)]
+fn check_send_sync<Q, MP>()
+where
+ Q: QueryFunction,
+ MP: MemoizationPolicy<Q>,
+ Q::Key: Send + Sync,
+ Q::Value: Send + Sync,
+{
+ fn is_send_sync<T: Send + Sync>() {}
+ is_send_sync::<Slot<Q, MP>>();
+}
+
+/// Check that `Slot<Q, MP>: 'static` as long as
+/// `DB::DatabaseData: 'static`, which in turn implies that
+/// `Q::Key: 'static`, `Q::Value: 'static`.
+#[allow(dead_code)]
+fn check_static<Q, MP>()
+where
+ Q: QueryFunction + 'static,
+ MP: MemoizationPolicy<Q> + 'static,
+ Q::Key: 'static,
+ Q::Value: 'static,
+{
+ fn is_static<T: 'static>() {}
+ is_static::<Slot<Q, MP>>();
+}
diff --git a/crates/salsa/src/doctest.rs b/crates/salsa/src/doctest.rs
new file mode 100644
index 0000000000..29a8066356
--- /dev/null
+++ b/crates/salsa/src/doctest.rs
@@ -0,0 +1,115 @@
+//!
+#![allow(dead_code)]
+
+/// Test that a database with a key/value that is not `Send` will,
+/// indeed, not be `Send`.
+///
+/// ```compile_fail,E0277
+/// use std::rc::Rc;
+///
+/// #[salsa::query_group(NoSendSyncStorage)]
+/// trait NoSendSyncDatabase: salsa::Database {
+/// fn no_send_sync_value(&self, key: bool) -> Rc<bool>;
+/// fn no_send_sync_key(&self, key: Rc<bool>) -> bool;
+/// }
+///
+/// fn no_send_sync_value(_db: &dyn NoSendSyncDatabase, key: bool) -> Rc<bool> {
+/// Rc::new(key)
+/// }
+///
+/// fn no_send_sync_key(_db: &dyn NoSendSyncDatabase, key: Rc<bool>) -> bool {
+/// *key
+/// }
+///
+/// #[salsa::database(NoSendSyncStorage)]
+/// #[derive(Default)]
+/// struct DatabaseImpl {
+/// storage: salsa::Storage<Self>,
+/// }
+///
+/// impl salsa::Database for DatabaseImpl {
+/// }
+///
+/// fn is_send<T: Send>(_: T) { }
+///
+/// fn assert_send() {
+/// is_send(DatabaseImpl::default());
+/// }
+/// ```
+fn test_key_not_send_db_not_send() {}
+
+/// Test that a database with a key/value that is not `Sync` will not
+/// be `Send`.
+///
+/// ```compile_fail,E0277
+/// use std::rc::Rc;
+/// use std::cell::Cell;
+///
+/// #[salsa::query_group(NoSendSyncStorage)]
+/// trait NoSendSyncDatabase: salsa::Database {
+/// fn no_send_sync_value(&self, key: bool) -> Cell<bool>;
+/// fn no_send_sync_key(&self, key: Cell<bool>) -> bool;
+/// }
+///
+/// fn no_send_sync_value(_db: &dyn NoSendSyncDatabase, key: bool) -> Cell<bool> {
+/// Cell::new(key)
+/// }
+///
+/// fn no_send_sync_key(_db: &dyn NoSendSyncDatabase, key: Cell<bool>) -> bool {
+/// *key
+/// }
+///
+/// #[salsa::database(NoSendSyncStorage)]
+/// #[derive(Default)]
+/// struct DatabaseImpl {
+/// runtime: salsa::Storage<Self>,
+/// }
+///
+/// impl salsa::Database for DatabaseImpl {
+/// }
+///
+/// fn is_send<T: Send>(_: T) { }
+///
+/// fn assert_send() {
+/// is_send(DatabaseImpl::default());
+/// }
+/// ```
+fn test_key_not_sync_db_not_send() {}
+
+/// Test that a database with a key/value that is not `Sync` will
+/// not be `Sync`.
+///
+/// ```compile_fail,E0277
+/// use std::cell::Cell;
+/// use std::rc::Rc;
+///
+/// #[salsa::query_group(NoSendSyncStorage)]
+/// trait NoSendSyncDatabase: salsa::Database {
+/// fn no_send_sync_value(&self, key: bool) -> Cell<bool>;
+/// fn no_send_sync_key(&self, key: Cell<bool>) -> bool;
+/// }
+///
+/// fn no_send_sync_value(_db: &dyn NoSendSyncDatabase, key: bool) -> Cell<bool> {
+/// Cell::new(key)
+/// }
+///
+/// fn no_send_sync_key(_db: &dyn NoSendSyncDatabase, key: Cell<bool>) -> bool {
+/// *key
+/// }
+///
+/// #[salsa::database(NoSendSyncStorage)]
+/// #[derive(Default)]
+/// struct DatabaseImpl {
+/// runtime: salsa::Storage<Self>,
+/// }
+///
+/// impl salsa::Database for DatabaseImpl {
+/// }
+///
+/// fn is_sync<T: Sync>(_: T) { }
+///
+/// fn assert_send() {
+/// is_sync(DatabaseImpl::default());
+/// }
+/// ```
+fn test_key_not_sync_db_not_sync() {}
diff --git a/crates/salsa/src/durability.rs b/crates/salsa/src/durability.rs
new file mode 100644
index 0000000000..0c82f6345a
--- /dev/null
+++ b/crates/salsa/src/durability.rs
@@ -0,0 +1,50 @@
+//!
+/// Describes how likely a value is to change -- how "durable" it is.
+/// By default, inputs have `Durability::LOW` and interned values have
+/// `Durability::HIGH`. But inputs can be explicitly set with other
+/// durabilities.
+///
+/// We use durabilities to optimize the work of "revalidating" a query
+/// after some input has changed. Ordinarily, in a new revision,
+/// queries have to trace all their inputs back to the base inputs to
+/// determine if any of those inputs have changed. But if we know that
+/// the only changes were to inputs of low durability (the common
+/// case), and we know that the query only used inputs of medium
+/// durability or higher, then we can skip that enumeration.
+///
+/// Typically, one assigns low durabilites to inputs that the user is
+/// frequently editing. Medium or high durabilities are used for
+/// configuration, the source from library crates, or other things
+/// that are unlikely to be edited.
+#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
+pub struct Durability(u8);
+
+impl Durability {
+ /// Low durability: things that change frequently.
+ ///
+ /// Example: part of the crate being edited
+ pub const LOW: Durability = Durability(0);
+
+ /// Medium durability: things that change sometimes, but rarely.
+ ///
+ /// Example: a Cargo.toml file
+ pub const MEDIUM: Durability = Durability(1);
+
+ /// High durability: things that are not expected to change under
+ /// common usage.
+ ///
+ /// Example: the standard library or something from crates.io
+ pub const HIGH: Durability = Durability(2);
+
+ /// The maximum possible durability; equivalent to HIGH but
+ /// "conceptually" distinct (i.e., if we add more durability
+ /// levels, this could change).
+ pub(crate) const MAX: Durability = Self::HIGH;
+
+ /// Number of durability levels.
+ pub(crate) const LEN: usize = 3;
+
+ pub(crate) fn index(self) -> usize {
+ self.0 as usize
+ }
+}
diff --git a/crates/salsa/src/hash.rs b/crates/salsa/src/hash.rs
new file mode 100644
index 0000000000..47a2dd1ce0
--- /dev/null
+++ b/crates/salsa/src/hash.rs
@@ -0,0 +1,4 @@
+//!
+pub(crate) type FxHasher = std::hash::BuildHasherDefault<rustc_hash::FxHasher>;
+pub(crate) type FxIndexSet<K> = indexmap::IndexSet<K, FxHasher>;
+pub(crate) type FxIndexMap<K, V> = indexmap::IndexMap<K, V, FxHasher>;
diff --git a/crates/salsa/src/input.rs b/crates/salsa/src/input.rs
new file mode 100644
index 0000000000..f6188d4a84
--- /dev/null
+++ b/crates/salsa/src/input.rs
@@ -0,0 +1,240 @@
+//!
+use crate::debug::TableEntry;
+use crate::durability::Durability;
+use crate::hash::FxIndexMap;
+use crate::plumbing::CycleRecoveryStrategy;
+use crate::plumbing::InputQueryStorageOps;
+use crate::plumbing::QueryStorageMassOps;
+use crate::plumbing::QueryStorageOps;
+use crate::revision::Revision;
+use crate::runtime::StampedValue;
+use crate::Database;
+use crate::Query;
+use crate::Runtime;
+use crate::{DatabaseKeyIndex, QueryDb};
+use indexmap::map::Entry;
+use parking_lot::RwLock;
+use std::convert::TryFrom;
+use tracing::debug;
+
+/// Input queries store the result plus a list of the other queries
+/// that they invoked. This means we can avoid recomputing them when
+/// none of those inputs have changed.
+pub struct InputStorage<Q>
+where
+ Q: Query,
+{
+ group_index: u16,
+ slots: RwLock<FxIndexMap<Q::Key, Slot<Q>>>,
+}
+
+struct Slot<Q>
+where
+ Q: Query,
+{
+ database_key_index: DatabaseKeyIndex,
+ stamped_value: RwLock<StampedValue<Q::Value>>,
+}
+
+impl<Q> std::panic::RefUnwindSafe for InputStorage<Q>
+where
+ Q: Query,
+ Q::Key: std::panic::RefUnwindSafe,
+ Q::Value: std::panic::RefUnwindSafe,
+{
+}
+
+impl<Q> QueryStorageOps<Q> for InputStorage<Q>
+where
+ Q: Query,
+{
+ const CYCLE_STRATEGY: crate::plumbing::CycleRecoveryStrategy = CycleRecoveryStrategy::Panic;
+
+ fn new(group_index: u16) -> Self {
+ InputStorage { group_index, slots: Default::default() }
+ }
+
+ fn fmt_index(
+ &self,
+ _db: &<Q as QueryDb<'_>>::DynDb,
+ index: DatabaseKeyIndex,
+ fmt: &mut std::fmt::Formatter<'_>,
+ ) -> std::fmt::Result {
+ assert_eq!(index.group_index, self.group_index);
+ assert_eq!(index.query_index, Q::QUERY_INDEX);
+ let slot_map = self.slots.read();
+ let key = slot_map.get_index(index.key_index as usize).unwrap().0;
+ write!(fmt, "{}({:?})", Q::QUERY_NAME, key)
+ }
+
+ fn maybe_changed_after(
+ &self,
+ db: &<Q as QueryDb<'_>>::DynDb,
+ input: DatabaseKeyIndex,
+ revision: Revision,
+ ) -> bool {
+ assert_eq!(input.group_index, self.group_index);
+ assert_eq!(input.query_index, Q::QUERY_INDEX);
+ debug_assert!(revision < db.salsa_runtime().current_revision());
+ let slots = &self.slots.read();
+ let slot = slots.get_index(input.key_index as usize).unwrap().1;
+ slot.maybe_changed_after(db, revision)
+ }
+
+ fn fetch(&self, db: &<Q as QueryDb<'_>>::DynDb, key: &Q::Key) -> Q::Value {
+ db.unwind_if_cancelled();
+
+ let slots = &self.slots.read();
+ let slot = slots
+ .get(key)
+ .unwrap_or_else(|| panic!("no value set for {:?}({:?})", Q::default(), key));
+
+ let StampedValue { value, durability, changed_at } = slot.stamped_value.read().clone();
+
+ db.salsa_runtime().report_query_read_and_unwind_if_cycle_resulted(
+ slot.database_key_index,
+ durability,
+ changed_at,
+ );
+
+ value
+ }
+
+ fn durability(&self, _db: &<Q as QueryDb<'_>>::DynDb, key: &Q::Key) -> Durability {
+ match self.slots.read().get(key) {
+ Some(slot) => slot.stamped_value.read().durability,
+ None => panic!("no value set for {:?}({:?})", Q::default(), key),
+ }
+ }
+
+ fn entries<C>(&self, _db: &<Q as QueryDb<'_>>::DynDb) -> C
+ where
+ C: std::iter::FromIterator<TableEntry<Q::Key, Q::Value>>,
+ {
+ let slots = self.slots.read();
+ slots
+ .iter()
+ .map(|(key, slot)| {
+ TableEntry::new(key.clone(), Some(slot.stamped_value.read().value.clone()))
+ })
+ .collect()
+ }
+}
+
+impl<Q> Slot<Q>
+where
+ Q: Query,
+{
+ fn maybe_changed_after(&self, _db: &<Q as QueryDb<'_>>::DynDb, revision: Revision) -> bool {
+ debug!("maybe_changed_after(slot={:?}, revision={:?})", self, revision,);
+
+ let changed_at = self.stamped_value.read().changed_at;
+
+ debug!("maybe_changed_after: changed_at = {:?}", changed_at);
+
+ changed_at > revision
+ }
+}
+
+impl<Q> QueryStorageMassOps for InputStorage<Q>
+where
+ Q: Query,
+{
+ fn purge(&self) {
+ *self.slots.write() = Default::default();
+ }
+}
+
+impl<Q> InputQueryStorageOps<Q> for InputStorage<Q>
+where
+ Q: Query,
+{
+ fn set(&self, runtime: &mut Runtime, key: &Q::Key, value: Q::Value, durability: Durability) {
+ tracing::debug!("{:?}({:?}) = {:?} ({:?})", Q::default(), key, value, durability);
+
+ // The value is changing, so we need a new revision (*). We also
+ // need to update the 'last changed' revision by invoking
+ // `guard.mark_durability_as_changed`.
+ //
+ // CAREFUL: This will block until the global revision lock can
+ // be acquired. If there are still queries executing, they may
+ // need to read from this input. Therefore, we wait to acquire
+ // the lock on `map` until we also hold the global query write
+ // lock.
+ //
+ // (*) Technically, since you can't presently access an input
+ // for a non-existent key, and you can't enumerate the set of
+ // keys, we only need a new revision if the key used to
+ // exist. But we may add such methods in the future and this
+ // case doesn't generally seem worth optimizing for.
+ runtime.with_incremented_revision(|next_revision| {
+ let mut slots = self.slots.write();
+
+ // Do this *after* we acquire the lock, so that we are not
+ // racing with somebody else to modify this same cell.
+ // (Otherwise, someone else might write a *newer* revision
+ // into the same cell while we block on the lock.)
+ let stamped_value = StampedValue { value, durability, changed_at: next_revision };
+
+ match slots.entry(key.clone()) {
+ Entry::Occupied(entry) => {
+ let mut slot_stamped_value = entry.get().stamped_value.write();
+ let old_durability = slot_stamped_value.durability;
+ *slot_stamped_value = stamped_value;
+ Some(old_durability)
+ }
+
+ Entry::Vacant(entry) => {
+ let key_index = u32::try_from(entry.index()).unwrap();
+ let database_key_index = DatabaseKeyIndex {
+ group_index: self.group_index,
+ query_index: Q::QUERY_INDEX,
+ key_index,
+ };
+ entry.insert(Slot {
+ database_key_index,
+ stamped_value: RwLock::new(stamped_value),
+ });
+ None
+ }
+ }
+ });
+ }
+}
+
+/// Check that `Slot<Q, MP>: Send + Sync` as long as
+/// `DB::DatabaseData: Send + Sync`, which in turn implies that
+/// `Q::Key: Send + Sync`, `Q::Value: Send + Sync`.
+#[allow(dead_code)]
+fn check_send_sync<Q>()
+where
+ Q: Query,
+ Q::Key: Send + Sync,
+ Q::Value: Send + Sync,
+{
+ fn is_send_sync<T: Send + Sync>() {}
+ is_send_sync::<Slot<Q>>();
+}
+
+/// Check that `Slot<Q, MP>: 'static` as long as
+/// `DB::DatabaseData: 'static`, which in turn implies that
+/// `Q::Key: 'static`, `Q::Value: 'static`.
+#[allow(dead_code)]
+fn check_static<Q>()
+where
+ Q: Query + 'static,
+ Q::Key: 'static,
+ Q::Value: 'static,
+{
+ fn is_static<T: 'static>() {}
+ is_static::<Slot<Q>>();
+}
+
+impl<Q> std::fmt::Debug for Slot<Q>
+where
+ Q: Query,
+{
+ fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ write!(fmt, "{:?}", Q::default())
+ }
+}
diff --git a/crates/salsa/src/intern_id.rs b/crates/salsa/src/intern_id.rs
new file mode 100644
index 0000000000..a7bbc088f9
--- /dev/null
+++ b/crates/salsa/src/intern_id.rs
@@ -0,0 +1,131 @@
+//!
+use std::fmt;
+use std::num::NonZeroU32;
+
+/// The "raw-id" is used for interned keys in salsa -- it is basically
+/// a newtype'd u32. Typically, it is wrapped in a type of your own
+/// devising. For more information about interned keys, see [the
+/// interned key RFC][rfc].
+///
+/// # Creating a `InternId`
+//
+/// InternId values can be constructed using the `From` impls,
+/// which are implemented for `u32` and `usize`:
+///
+/// ```
+/// # use salsa::InternId;
+/// let intern_id1 = InternId::from(22_u32);
+/// let intern_id2 = InternId::from(22_usize);
+/// assert_eq!(intern_id1, intern_id2);
+/// ```
+///
+/// # Converting to a u32 or usize
+///
+/// Normally, there should be no need to access the underlying integer
+/// in a `InternId`. But if you do need to do so, you can convert to a
+/// `usize` using the `as_u32` or `as_usize` methods or the `From` impls.
+///
+/// ```
+/// # use salsa::InternId;
+/// let intern_id = InternId::from(22_u32);
+/// let value = u32::from(intern_id);
+/// assert_eq!(value, 22);
+/// ```
+///
+/// ## Illegal values
+///
+/// Be warned, however, that `InternId` values cannot be created from
+/// *arbitrary* values -- in particular large values greater than
+/// `InternId::MAX` will panic. Those large values are reserved so that
+/// the Rust compiler can use them as sentinel values, which means
+/// that (for example) `Option<InternId>` is represented in a single
+/// word.
+///
+/// ```should_panic
+/// # use salsa::InternId;
+/// InternId::from(InternId::MAX);
+/// ```
+///
+/// [rfc]: https://github.com/salsa-rs/salsa-rfcs/pull/2
+#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
+pub struct InternId {
+ value: NonZeroU32,
+}
+
+impl InternId {
+ /// The maximum allowed `InternId`. This value can grow between
+ /// releases without affecting semver.
+ pub const MAX: u32 = 0xFFFF_FF00;
+
+ /// Creates a new InternId.
+ ///
+ /// # Safety
+ ///
+ /// `value` must be less than `MAX`
+ pub const unsafe fn new_unchecked(value: u32) -> Self {
+ debug_assert!(value < InternId::MAX);
+ InternId { value: NonZeroU32::new_unchecked(value + 1) }
+ }
+
+ /// Convert this raw-id into a u32 value.
+ ///
+ /// ```
+ /// # use salsa::InternId;
+ /// let intern_id = InternId::from(22_u32);
+ /// let value = intern_id.as_usize();
+ /// assert_eq!(value, 22);
+ /// ```
+ pub fn as_u32(self) -> u32 {
+ self.value.get() - 1
+ }
+
+ /// Convert this raw-id into a usize value.
+ ///
+ /// ```
+ /// # use salsa::InternId;
+ /// let intern_id = InternId::from(22_u32);
+ /// let value = intern_id.as_usize();
+ /// assert_eq!(value, 22);
+ /// ```
+ pub fn as_usize(self) -> usize {
+ self.as_u32() as usize
+ }
+}
+
+impl From<InternId> for u32 {
+ fn from(raw: InternId) -> u32 {
+ raw.as_u32()
+ }
+}
+
+impl From<InternId> for usize {
+ fn from(raw: InternId) -> usize {
+ raw.as_usize()
+ }
+}
+
+impl From<u32> for InternId {
+ fn from(id: u32) -> InternId {
+ assert!(id < InternId::MAX);
+ unsafe { InternId::new_unchecked(id) }
+ }
+}
+
+impl From<usize> for InternId {
+ fn from(id: usize) -> InternId {
+ assert!(id < (InternId::MAX as usize));
+ unsafe { InternId::new_unchecked(id as u32) }
+ }
+}
+
+impl fmt::Debug for InternId {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ self.as_usize().fmt(f)
+ }
+}
+
+impl fmt::Display for InternId {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ self.as_usize().fmt(f)
+ }
+}
diff --git a/crates/salsa/src/interned.rs b/crates/salsa/src/interned.rs
new file mode 100644
index 0000000000..22f22e6112
--- /dev/null
+++ b/crates/salsa/src/interned.rs
@@ -0,0 +1,409 @@
+//!
+use crate::debug::TableEntry;
+use crate::durability::Durability;
+use crate::intern_id::InternId;
+use crate::plumbing::CycleRecoveryStrategy;
+use crate::plumbing::HasQueryGroup;
+use crate::plumbing::QueryStorageMassOps;
+use crate::plumbing::QueryStorageOps;
+use crate::revision::Revision;
+use crate::Query;
+use crate::{Database, DatabaseKeyIndex, QueryDb};
+use parking_lot::RwLock;
+use rustc_hash::FxHashMap;
+use std::collections::hash_map::Entry;
+use std::convert::From;
+use std::fmt::Debug;
+use std::hash::Hash;
+use triomphe::Arc;
+
+const INTERN_DURABILITY: Durability = Durability::HIGH;
+
+/// Handles storage where the value is 'derived' by executing a
+/// function (in contrast to "inputs").
+pub struct InternedStorage<Q>
+where
+ Q: Query,
+ Q::Value: InternKey,
+{
+ group_index: u16,
+ tables: RwLock<InternTables<Q::Key>>,
+}
+
+/// Storage for the looking up interned things.
+pub struct LookupInternedStorage<Q, IQ>
+where
+ Q: Query,
+ Q::Key: InternKey,
+ Q::Value: Eq + Hash,
+{
+ phantom: std::marker::PhantomData<(Q::Key, IQ)>,
+}
+
+struct InternTables<K> {
+ /// Map from the key to the corresponding intern-index.
+ map: FxHashMap<K, InternId>,
+
+ /// For each valid intern-index, stores the interned value.
+ values: Vec<Arc<Slot<K>>>,
+}
+
+/// Trait implemented for the "key" that results from a
+/// `#[salsa::intern]` query. This is basically meant to be a
+/// "newtype"'d `u32`.
+pub trait InternKey {
+ /// Create an instance of the intern-key from a `u32` value.
+ fn from_intern_id(v: InternId) -> Self;
+
+ /// Extract the `u32` with which the intern-key was created.
+ fn as_intern_id(&self) -> InternId;
+}
+
+impl InternKey for InternId {
+ fn from_intern_id(v: InternId) -> InternId {
+ v
+ }
+
+ fn as_intern_id(&self) -> InternId {
+ *self
+ }
+}
+
+#[derive(Debug)]
+struct Slot<K> {
+ /// DatabaseKeyIndex for this slot.
+ database_key_index: DatabaseKeyIndex,
+
+ /// Value that was interned.
+ value: K,
+
+ /// When was this intern'd?
+ ///
+ /// (This informs the "changed-at" result)
+ interned_at: Revision,
+}
+
+impl<Q> std::panic::RefUnwindSafe for InternedStorage<Q>
+where
+ Q: Query,
+ Q::Key: std::panic::RefUnwindSafe,
+ Q::Value: InternKey,
+ Q::Value: std::panic::RefUnwindSafe,
+{
+}
+
+impl<K: Debug + Hash + Eq> InternTables<K> {
+ /// Returns the slot for the given key.
+ fn slot_for_key(&self, key: &K) -> Option<(Arc<Slot<K>>, InternId)> {
+ let &index = self.map.get(key)?;
+ Some((self.slot_for_index(index), index))
+ }
+
+ /// Returns the slot at the given index.
+ fn slot_for_index(&self, index: InternId) -> Arc<Slot<K>> {
+ let slot = &self.values[index.as_usize()];
+ slot.clone()
+ }
+}
+
+impl<K> Default for InternTables<K>
+where
+ K: Eq + Hash,
+{
+ fn default() -> Self {
+ Self { map: Default::default(), values: Default::default() }
+ }
+}
+
+impl<Q> InternedStorage<Q>
+where
+ Q: Query,
+ Q::Key: Eq + Hash + Clone,
+ Q::Value: InternKey,
+{
+ /// If `key` has already been interned, returns its slot. Otherwise, creates a new slot.
+ fn intern_index(
+ &self,
+ db: &<Q as QueryDb<'_>>::DynDb,
+ key: &Q::Key,
+ ) -> (Arc<Slot<Q::Key>>, InternId) {
+ if let Some(i) = self.intern_check(key) {
+ return i;
+ }
+
+ let owned_key1 = key.to_owned();
+ let owned_key2 = owned_key1.clone();
+ let revision_now = db.salsa_runtime().current_revision();
+
+ let mut tables = self.tables.write();
+ let tables = &mut *tables;
+ let entry = match tables.map.entry(owned_key1) {
+ Entry::Vacant(entry) => entry,
+ Entry::Occupied(entry) => {
+ // Somebody inserted this key while we were waiting
+ // for the write lock. In this case, we don't need to
+ // update the `accessed_at` field because they should
+ // have already done so!
+ let index = *entry.get();
+ let slot = &tables.values[index.as_usize()];
+ debug_assert_eq!(owned_key2, slot.value);
+ return (slot.clone(), index);
+ }
+ };
+
+ let create_slot = |index: InternId| {
+ let database_key_index = DatabaseKeyIndex {
+ group_index: self.group_index,
+ query_index: Q::QUERY_INDEX,
+ key_index: index.as_u32(),
+ };
+ Arc::new(Slot { database_key_index, value: owned_key2, interned_at: revision_now })
+ };
+
+ let (slot, index);
+ index = InternId::from(tables.values.len());
+ slot = create_slot(index);
+ tables.values.push(slot.clone());
+ entry.insert(index);
+
+ (slot, index)
+ }
+
+ fn intern_check(&self, key: &Q::Key) -> Option<(Arc<Slot<Q::Key>>, InternId)> {
+ self.tables.read().slot_for_key(key)
+ }
+
+ /// Given an index, lookup and clone its value, updating the
+ /// `accessed_at` time if necessary.
+ fn lookup_value(&self, index: InternId) -> Arc<Slot<Q::Key>> {
+ self.tables.read().slot_for_index(index)
+ }
+}
+
+impl<Q> QueryStorageOps<Q> for InternedStorage<Q>
+where
+ Q: Query,
+ Q::Value: InternKey,
+{
+ const CYCLE_STRATEGY: crate::plumbing::CycleRecoveryStrategy = CycleRecoveryStrategy::Panic;
+
+ fn new(group_index: u16) -> Self {
+ InternedStorage { group_index, tables: RwLock::new(InternTables::default()) }
+ }
+
+ fn fmt_index(
+ &self,
+ _db: &<Q as QueryDb<'_>>::DynDb,
+ index: DatabaseKeyIndex,
+ fmt: &mut std::fmt::Formatter<'_>,
+ ) -> std::fmt::Result {
+ assert_eq!(index.group_index, self.group_index);
+ assert_eq!(index.query_index, Q::QUERY_INDEX);
+ let intern_id = InternId::from(index.key_index);
+ let slot = self.lookup_value(intern_id);
+ write!(fmt, "{}({:?})", Q::QUERY_NAME, slot.value)
+ }
+
+ fn maybe_changed_after(
+ &self,
+ db: &<Q as QueryDb<'_>>::DynDb,
+ input: DatabaseKeyIndex,
+ revision: Revision,
+ ) -> bool {
+ assert_eq!(input.group_index, self.group_index);
+ assert_eq!(input.query_index, Q::QUERY_INDEX);
+ debug_assert!(revision < db.salsa_runtime().current_revision());
+ let intern_id = InternId::from(input.key_index);
+ let slot = self.lookup_value(intern_id);
+ slot.maybe_changed_after(revision)
+ }
+
+ fn fetch(&self, db: &<Q as QueryDb<'_>>::DynDb, key: &Q::Key) -> Q::Value {
+ db.unwind_if_cancelled();
+ let (slot, index) = self.intern_index(db, key);
+ let changed_at = slot.interned_at;
+ db.salsa_runtime().report_query_read_and_unwind_if_cycle_resulted(
+ slot.database_key_index,
+ INTERN_DURABILITY,
+ changed_at,
+ );
+ <Q::Value>::from_intern_id(index)
+ }
+
+ fn durability(&self, _db: &<Q as QueryDb<'_>>::DynDb, _key: &Q::Key) -> Durability {
+ INTERN_DURABILITY
+ }
+
+ fn entries<C>(&self, _db: &<Q as QueryDb<'_>>::DynDb) -> C
+ where
+ C: std::iter::FromIterator<TableEntry<Q::Key, Q::Value>>,
+ {
+ let tables = self.tables.read();
+ tables
+ .map
+ .iter()
+ .map(|(key, index)| {
+ TableEntry::new(key.clone(), Some(<Q::Value>::from_intern_id(*index)))
+ })
+ .collect()
+ }
+}
+
+impl<Q> QueryStorageMassOps for InternedStorage<Q>
+where
+ Q: Query,
+ Q::Value: InternKey,
+{
+ fn purge(&self) {
+ *self.tables.write() = Default::default();
+ }
+}
+
+// Workaround for
+// ```
+// IQ: for<'d> QueryDb<
+// 'd,
+// DynDb = <Q as QueryDb<'d>>::DynDb,
+// Group = <Q as QueryDb<'d>>::Group,
+// GroupStorage = <Q as QueryDb<'d>>::GroupStorage,
+// >,
+// ```
+// not working to make rustc know DynDb, Group and GroupStorage being the same in `Q` and `IQ`
+#[doc(hidden)]
+pub trait EqualDynDb<'d, IQ>: QueryDb<'d>
+where
+ IQ: QueryDb<'d>,
+{
+ fn convert_db(d: &Self::DynDb) -> &IQ::DynDb;
+ fn convert_group_storage(d: &Self::GroupStorage) -> &IQ::GroupStorage;
+}
+
+impl<'d, IQ, Q> EqualDynDb<'d, IQ> for Q
+where
+ Q: QueryDb<'d, DynDb = IQ::DynDb, Group = IQ::Group, GroupStorage = IQ::GroupStorage>,
+ Q::DynDb: HasQueryGroup<Q::Group>,
+ IQ: QueryDb<'d>,
+{
+ fn convert_db(d: &Self::DynDb) -> &IQ::DynDb {
+ d
+ }
+ fn convert_group_storage(d: &Self::GroupStorage) -> &IQ::GroupStorage {
+ d
+ }
+}
+
+impl<Q, IQ> QueryStorageOps<Q> for LookupInternedStorage<Q, IQ>
+where
+ Q: Query,
+ Q::Key: InternKey,
+ Q::Value: Eq + Hash,
+ IQ: Query<Key = Q::Value, Value = Q::Key, Storage = InternedStorage<IQ>>,
+ for<'d> Q: EqualDynDb<'d, IQ>,
+{
+ const CYCLE_STRATEGY: CycleRecoveryStrategy = CycleRecoveryStrategy::Panic;
+
+ fn new(_group_index: u16) -> Self {
+ LookupInternedStorage { phantom: std::marker::PhantomData }
+ }
+
+ fn fmt_index(
+ &self,
+ db: &<Q as QueryDb<'_>>::DynDb,
+ index: DatabaseKeyIndex,
+ fmt: &mut std::fmt::Formatter<'_>,
+ ) -> std::fmt::Result {
+ let group_storage =
+ <<Q as QueryDb<'_>>::DynDb as HasQueryGroup<Q::Group>>::group_storage(db);
+ let interned_storage = IQ::query_storage(Q::convert_group_storage(group_storage));
+ interned_storage.fmt_index(Q::convert_db(db), index, fmt)
+ }
+
+ fn maybe_changed_after(
+ &self,
+ db: &<Q as QueryDb<'_>>::DynDb,
+ input: DatabaseKeyIndex,
+ revision: Revision,
+ ) -> bool {
+ let group_storage =
+ <<Q as QueryDb<'_>>::DynDb as HasQueryGroup<Q::Group>>::group_storage(db);
+ let interned_storage = IQ::query_storage(Q::convert_group_storage(group_storage));
+ interned_storage.maybe_changed_after(Q::convert_db(db), input, revision)
+ }
+
+ fn fetch(&self, db: &<Q as QueryDb<'_>>::DynDb, key: &Q::Key) -> Q::Value {
+ let index = key.as_intern_id();
+ let group_storage =
+ <<Q as QueryDb<'_>>::DynDb as HasQueryGroup<Q::Group>>::group_storage(db);
+ let interned_storage = IQ::query_storage(Q::convert_group_storage(group_storage));
+ let slot = interned_storage.lookup_value(index);
+ let value = slot.value.clone();
+ let interned_at = slot.interned_at;
+ db.salsa_runtime().report_query_read_and_unwind_if_cycle_resulted(
+ slot.database_key_index,
+ INTERN_DURABILITY,
+ interned_at,
+ );
+ value
+ }
+
+ fn durability(&self, _db: &<Q as QueryDb<'_>>::DynDb, _key: &Q::Key) -> Durability {
+ INTERN_DURABILITY
+ }
+
+ fn entries<C>(&self, db: &<Q as QueryDb<'_>>::DynDb) -> C
+ where
+ C: std::iter::FromIterator<TableEntry<Q::Key, Q::Value>>,
+ {
+ let group_storage =
+ <<Q as QueryDb<'_>>::DynDb as HasQueryGroup<Q::Group>>::group_storage(db);
+ let interned_storage = IQ::query_storage(Q::convert_group_storage(group_storage));
+ let tables = interned_storage.tables.read();
+ tables
+ .map
+ .iter()
+ .map(|(key, index)| {
+ TableEntry::new(<Q::Key>::from_intern_id(*index), Some(key.clone()))
+ })
+ .collect()
+ }
+}
+
+impl<Q, IQ> QueryStorageMassOps for LookupInternedStorage<Q, IQ>
+where
+ Q: Query,
+ Q::Key: InternKey,
+ Q::Value: Eq + Hash,
+ IQ: Query<Key = Q::Value, Value = Q::Key>,
+{
+ fn purge(&self) {}
+}
+
+impl<K> Slot<K> {
+ fn maybe_changed_after(&self, revision: Revision) -> bool {
+ self.interned_at > revision
+ }
+}
+
+/// Check that `Slot<Q, MP>: Send + Sync` as long as
+/// `DB::DatabaseData: Send + Sync`, which in turn implies that
+/// `Q::Key: Send + Sync`, `Q::Value: Send + Sync`.
+#[allow(dead_code)]
+fn check_send_sync<K>()
+where
+ K: Send + Sync,
+{
+ fn is_send_sync<T: Send + Sync>() {}
+ is_send_sync::<Slot<K>>();
+}
+
+/// Check that `Slot<Q, MP>: 'static` as long as
+/// `DB::DatabaseData: 'static`, which in turn implies that
+/// `Q::Key: 'static`, `Q::Value: 'static`.
+#[allow(dead_code)]
+fn check_static<K>()
+where
+ K: 'static,
+{
+ fn is_static<T: 'static>() {}
+ is_static::<Slot<K>>();
+}
diff --git a/crates/salsa/src/lib.rs b/crates/salsa/src/lib.rs
new file mode 100644
index 0000000000..575408f362
--- /dev/null
+++ b/crates/salsa/src/lib.rs
@@ -0,0 +1,742 @@
+//!
+#![allow(clippy::type_complexity)]
+#![allow(clippy::question_mark)]
+#![warn(rust_2018_idioms)]
+#![warn(missing_docs)]
+
+//! The salsa crate is a crate for incremental recomputation. It
+//! permits you to define a "database" of queries with both inputs and
+//! values derived from those inputs; as you set the inputs, you can
+//! re-execute the derived queries and it will try to re-use results
+//! from previous invocations as appropriate.
+
+mod derived;
+mod doctest;
+mod durability;
+mod hash;
+mod input;
+mod intern_id;
+mod interned;
+mod lru;
+mod revision;
+mod runtime;
+mod storage;
+
+pub mod debug;
+/// Items in this module are public for implementation reasons,
+/// and are exempt from the SemVer guarantees.
+#[doc(hidden)]
+pub mod plumbing;
+
+use crate::plumbing::CycleRecoveryStrategy;
+use crate::plumbing::DerivedQueryStorageOps;
+use crate::plumbing::InputQueryStorageOps;
+use crate::plumbing::LruQueryStorageOps;
+use crate::plumbing::QueryStorageMassOps;
+use crate::plumbing::QueryStorageOps;
+pub use crate::revision::Revision;
+use std::fmt::{self, Debug};
+use std::hash::Hash;
+use std::panic::AssertUnwindSafe;
+use std::panic::{self, UnwindSafe};
+
+pub use crate::durability::Durability;
+pub use crate::intern_id::InternId;
+pub use crate::interned::InternKey;
+pub use crate::runtime::Runtime;
+pub use crate::runtime::RuntimeId;
+pub use crate::storage::Storage;
+
+/// The base trait which your "query context" must implement. Gives
+/// access to the salsa runtime, which you must embed into your query
+/// context (along with whatever other state you may require).
+pub trait Database: plumbing::DatabaseOps {
+ /// This function is invoked at key points in the salsa
+ /// runtime. It permits the database to be customized and to
+ /// inject logging or other custom behavior.
+ fn salsa_event(&self, event_fn: Event) {
+ #![allow(unused_variables)]
+ }
+
+ /// Starts unwinding the stack if the current revision is cancelled.
+ ///
+ /// This method can be called by query implementations that perform
+ /// potentially expensive computations, in order to speed up propagation of
+ /// cancellation.
+ ///
+ /// Cancellation will automatically be triggered by salsa on any query
+ /// invocation.
+ ///
+ /// This method should not be overridden by `Database` implementors. A
+ /// `salsa_event` is emitted when this method is called, so that should be
+ /// used instead.
+ #[inline]
+ fn unwind_if_cancelled(&self) {
+ let runtime = self.salsa_runtime();
+ self.salsa_event(Event {
+ runtime_id: runtime.id(),
+ kind: EventKind::WillCheckCancellation,
+ });
+
+ let current_revision = runtime.current_revision();
+ let pending_revision = runtime.pending_revision();
+ tracing::debug!(
+ "unwind_if_cancelled: current_revision={:?}, pending_revision={:?}",
+ current_revision,
+ pending_revision
+ );
+ if pending_revision > current_revision {
+ runtime.unwind_cancelled();
+ }
+ }
+
+ /// Gives access to the underlying salsa runtime.
+ ///
+ /// This method should not be overridden by `Database` implementors.
+ fn salsa_runtime(&self) -> &Runtime {
+ self.ops_salsa_runtime()
+ }
+
+ /// Gives access to the underlying salsa runtime.
+ ///
+ /// This method should not be overridden by `Database` implementors.
+ fn salsa_runtime_mut(&mut self) -> &mut Runtime {
+ self.ops_salsa_runtime_mut()
+ }
+}
+
+/// The `Event` struct identifies various notable things that can
+/// occur during salsa execution. Instances of this struct are given
+/// to `salsa_event`.
+pub struct Event {
+ /// The id of the snapshot that triggered the event. Usually
+ /// 1-to-1 with a thread, as well.
+ pub runtime_id: RuntimeId,
+
+ /// What sort of event was it.
+ pub kind: EventKind,
+}
+
+impl Event {
+ /// Returns a type that gives a user-readable debug output.
+ /// Use like `println!("{:?}", index.debug(db))`.
+ pub fn debug<'me, D: ?Sized>(&'me self, db: &'me D) -> impl std::fmt::Debug + 'me
+ where
+ D: plumbing::DatabaseOps,
+ {
+ EventDebug { event: self, db }
+ }
+}
+
+impl fmt::Debug for Event {
+ fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
+ fmt.debug_struct("Event")
+ .field("runtime_id", &self.runtime_id)
+ .field("kind", &self.kind)
+ .finish()
+ }
+}
+
+struct EventDebug<'me, D: ?Sized>
+where
+ D: plumbing::DatabaseOps,
+{
+ event: &'me Event,
+ db: &'me D,
+}
+
+impl<'me, D: ?Sized> fmt::Debug for EventDebug<'me, D>
+where
+ D: plumbing::DatabaseOps,
+{
+ fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
+ fmt.debug_struct("Event")
+ .field("runtime_id", &self.event.runtime_id)
+ .field("kind", &self.event.kind.debug(self.db))
+ .finish()
+ }
+}
+
+/// An enum identifying the various kinds of events that can occur.
+pub enum EventKind {
+ /// Occurs when we found that all inputs to a memoized value are
+ /// up-to-date and hence the value can be re-used without
+ /// executing the closure.
+ ///
+ /// Executes before the "re-used" value is returned.
+ DidValidateMemoizedValue {
+ /// The database-key for the affected value. Implements `Debug`.
+ database_key: DatabaseKeyIndex,
+ },
+
+ /// Indicates that another thread (with id `other_runtime_id`) is processing the
+ /// given query (`database_key`), so we will block until they
+ /// finish.
+ ///
+ /// Executes after we have registered with the other thread but
+ /// before they have answered us.
+ ///
+ /// (NB: you can find the `id` of the current thread via the
+ /// `salsa_runtime`)
+ WillBlockOn {
+ /// The id of the runtime we will block on.
+ other_runtime_id: RuntimeId,
+
+ /// The database-key for the affected value. Implements `Debug`.
+ database_key: DatabaseKeyIndex,
+ },
+
+ /// Indicates that the function for this query will be executed.
+ /// This is either because it has never executed before or because
+ /// its inputs may be out of date.
+ WillExecute {
+ /// The database-key for the affected value. Implements `Debug`.
+ database_key: DatabaseKeyIndex,
+ },
+
+ /// Indicates that `unwind_if_cancelled` was called and salsa will check if
+ /// the current revision has been cancelled.
+ WillCheckCancellation,
+}
+
+impl EventKind {
+ /// Returns a type that gives a user-readable debug output.
+ /// Use like `println!("{:?}", index.debug(db))`.
+ pub fn debug<'me, D: ?Sized>(&'me self, db: &'me D) -> impl std::fmt::Debug + 'me
+ where
+ D: plumbing::DatabaseOps,
+ {
+ EventKindDebug { kind: self, db }
+ }
+}
+
+impl fmt::Debug for EventKind {
+ fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
+ match self {
+ EventKind::DidValidateMemoizedValue { database_key } => fmt
+ .debug_struct("DidValidateMemoizedValue")
+ .field("database_key", database_key)
+ .finish(),
+ EventKind::WillBlockOn { other_runtime_id, database_key } => fmt
+ .debug_struct("WillBlockOn")
+ .field("other_runtime_id", other_runtime_id)
+ .field("database_key", database_key)
+ .finish(),
+ EventKind::WillExecute { database_key } => {
+ fmt.debug_struct("WillExecute").field("database_key", database_key).finish()
+ }
+ EventKind::WillCheckCancellation => fmt.debug_struct("WillCheckCancellation").finish(),
+ }
+ }
+}
+
+struct EventKindDebug<'me, D: ?Sized>
+where
+ D: plumbing::DatabaseOps,
+{
+ kind: &'me EventKind,
+ db: &'me D,
+}
+
+impl<'me, D: ?Sized> fmt::Debug for EventKindDebug<'me, D>
+where
+ D: plumbing::DatabaseOps,
+{
+ fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
+ match self.kind {
+ EventKind::DidValidateMemoizedValue { database_key } => fmt
+ .debug_struct("DidValidateMemoizedValue")
+ .field("database_key", &database_key.debug(self.db))
+ .finish(),
+ EventKind::WillBlockOn { other_runtime_id, database_key } => fmt
+ .debug_struct("WillBlockOn")
+ .field("other_runtime_id", &other_runtime_id)
+ .field("database_key", &database_key.debug(self.db))
+ .finish(),
+ EventKind::WillExecute { database_key } => fmt
+ .debug_struct("WillExecute")
+ .field("database_key", &database_key.debug(self.db))
+ .finish(),
+ EventKind::WillCheckCancellation => fmt.debug_struct("WillCheckCancellation").finish(),
+ }
+ }
+}
+
+/// Indicates a database that also supports parallel query
+/// evaluation. All of Salsa's base query support is capable of
+/// parallel execution, but for it to work, your query key/value types
+/// must also be `Send`, as must any additional data in your database.
+pub trait ParallelDatabase: Database + Send {
+ /// Creates a second handle to the database that holds the
+ /// database fixed at a particular revision. So long as this
+ /// "frozen" handle exists, any attempt to [`set`] an input will
+ /// block.
+ ///
+ /// [`set`]: struct.QueryTable.html#method.set
+ ///
+ /// This is the method you are meant to use most of the time in a
+ /// parallel setting where modifications may arise asynchronously
+ /// (e.g., a language server). In this context, it is common to
+ /// wish to "fork off" a snapshot of the database performing some
+ /// series of queries in parallel and arranging the results. Using
+ /// this method for that purpose ensures that those queries will
+ /// see a consistent view of the database (it is also advisable
+ /// for those queries to use the [`Runtime::unwind_if_cancelled`]
+ /// method to check for cancellation).
+ ///
+ /// # Panics
+ ///
+ /// It is not permitted to create a snapshot from inside of a
+ /// query. Attepting to do so will panic.
+ ///
+ /// # Deadlock warning
+ ///
+ /// The intended pattern for snapshots is that, once created, they
+ /// are sent to another thread and used from there. As such, the
+ /// `snapshot` acquires a "read lock" on the database --
+ /// therefore, so long as the `snapshot` is not dropped, any
+ /// attempt to `set` a value in the database will block. If the
+ /// `snapshot` is owned by the same thread that is attempting to
+ /// `set`, this will cause a problem.
+ ///
+ /// # How to implement this
+ ///
+ /// Typically, this method will create a second copy of your
+ /// database type (`MyDatabaseType`, in the example below),
+ /// cloning over each of the fields from `self` into this new
+ /// copy. For the field that stores the salsa runtime, you should
+ /// use [the `Runtime::snapshot` method][rfm] to create a snapshot of the
+ /// runtime. Finally, package up the result using `Snapshot::new`,
+ /// which is a simple wrapper type that only gives `&self` access
+ /// to the database within (thus preventing the use of methods
+ /// that may mutate the inputs):
+ ///
+ /// [rfm]: struct.Runtime.html#method.snapshot
+ ///
+ /// ```rust,ignore
+ /// impl ParallelDatabase for MyDatabaseType {
+ /// fn snapshot(&self) -> Snapshot<Self> {
+ /// Snapshot::new(
+ /// MyDatabaseType {
+ /// runtime: self.runtime.snapshot(self),
+ /// other_field: self.other_field.clone(),
+ /// }
+ /// )
+ /// }
+ /// }
+ /// ```
+ fn snapshot(&self) -> Snapshot<Self>;
+}
+
+/// Simple wrapper struct that takes ownership of a database `DB` and
+/// only gives `&self` access to it. See [the `snapshot` method][fm]
+/// for more details.
+///
+/// [fm]: trait.ParallelDatabase.html#method.snapshot
+#[derive(Debug)]
+pub struct Snapshot<DB: ?Sized>
+where
+ DB: ParallelDatabase,
+{
+ db: DB,
+}
+
+impl<DB> Snapshot<DB>
+where
+ DB: ParallelDatabase,
+{
+ /// Creates a `Snapshot` that wraps the given database handle
+ /// `db`. From this point forward, only shared references to `db`
+ /// will be possible.
+ pub fn new(db: DB) -> Self {
+ Snapshot { db }
+ }
+}
+
+impl<DB> std::ops::Deref for Snapshot<DB>
+where
+ DB: ParallelDatabase,
+{
+ type Target = DB;
+
+ fn deref(&self) -> &DB {
+ &self.db
+ }
+}
+
+/// An integer that uniquely identifies a particular query instance within the
+/// database. Used to track dependencies between queries. Fully ordered and
+/// equatable but those orderings are arbitrary, and meant to be used only for
+/// inserting into maps and the like.
+#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)]
+pub struct DatabaseKeyIndex {
+ group_index: u16,
+ query_index: u16,
+ key_index: u32,
+}
+
+impl DatabaseKeyIndex {
+ /// Returns the index of the query group containing this key.
+ #[inline]
+ pub fn group_index(self) -> u16 {
+ self.group_index
+ }
+
+ /// Returns the index of the query within its query group.
+ #[inline]
+ pub fn query_index(self) -> u16 {
+ self.query_index
+ }
+
+ /// Returns the index of this particular query key within the query.
+ #[inline]
+ pub fn key_index(self) -> u32 {
+ self.key_index
+ }
+
+ /// Returns a type that gives a user-readable debug output.
+ /// Use like `println!("{:?}", index.debug(db))`.
+ pub fn debug<D: ?Sized>(self, db: &D) -> impl std::fmt::Debug + '_
+ where
+ D: plumbing::DatabaseOps,
+ {
+ DatabaseKeyIndexDebug { index: self, db }
+ }
+}
+
+/// Helper type for `DatabaseKeyIndex::debug`
+struct DatabaseKeyIndexDebug<'me, D: ?Sized>
+where
+ D: plumbing::DatabaseOps,
+{
+ index: DatabaseKeyIndex,
+ db: &'me D,
+}
+
+impl<D: ?Sized> std::fmt::Debug for DatabaseKeyIndexDebug<'_, D>
+where
+ D: plumbing::DatabaseOps,
+{
+ fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ self.db.fmt_index(self.index, fmt)
+ }
+}
+
+/// Trait implements by all of the "special types" associated with
+/// each of your queries.
+///
+/// Base trait of `Query` that has a lifetime parameter to allow the `DynDb` to be non-'static.
+pub trait QueryDb<'d>: Sized {
+ /// Dyn version of the associated trait for this query group.
+ type DynDb: ?Sized + Database + HasQueryGroup<Self::Group> + 'd;
+
+ /// Associate query group struct.
+ type Group: plumbing::QueryGroup<GroupStorage = Self::GroupStorage>;
+
+ /// Generated struct that contains storage for all queries in a group.
+ type GroupStorage;
+}
+
+/// Trait implements by all of the "special types" associated with
+/// each of your queries.
+pub trait Query: Debug + Default + Sized + for<'d> QueryDb<'d> {
+ /// Type that you you give as a parameter -- for queries with zero
+ /// or more than one input, this will be a tuple.
+ type Key: Clone + Debug + Hash + Eq;
+
+ /// What value does the query return?
+ type Value: Clone + Debug;
+
+ /// Internal struct storing the values for the query.
+ // type Storage: plumbing::QueryStorageOps<Self>;
+ type Storage;
+
+ /// A unique index identifying this query within the group.
+ const QUERY_INDEX: u16;
+
+ /// Name of the query method (e.g., `foo`)
+ const QUERY_NAME: &'static str;
+
+ /// Extact storage for this query from the storage for its group.
+ fn query_storage<'a>(
+ group_storage: &'a <Self as QueryDb<'_>>::GroupStorage,
+ ) -> &'a std::sync::Arc<Self::Storage>;
+
+ /// Extact storage for this query from the storage for its group.
+ fn query_storage_mut<'a>(
+ group_storage: &'a <Self as QueryDb<'_>>::GroupStorage,
+ ) -> &'a std::sync::Arc<Self::Storage>;
+}
+
+/// Return value from [the `query` method] on `Database`.
+/// Gives access to various less common operations on queries.
+///
+/// [the `query` method]: trait.Database.html#method.query
+pub struct QueryTable<'me, Q>
+where
+ Q: Query,
+{
+ db: &'me <Q as QueryDb<'me>>::DynDb,
+ storage: &'me Q::Storage,
+}
+
+impl<'me, Q> QueryTable<'me, Q>
+where
+ Q: Query,
+ Q::Storage: QueryStorageOps<Q>,
+{
+ /// Constructs a new `QueryTable`.
+ pub fn new(db: &'me <Q as QueryDb<'me>>::DynDb, storage: &'me Q::Storage) -> Self {
+ Self { db, storage }
+ }
+
+ /// Execute the query on a given input. Usually it's easier to
+ /// invoke the trait method directly. Note that for variadic
+ /// queries (those with no inputs, or those with more than one
+ /// input) the key will be a tuple.
+ pub fn get(&self, key: Q::Key) -> Q::Value {
+ self.storage.fetch(self.db, &key)
+ }
+
+ /// Completely clears the storage for this query.
+ ///
+ /// This method breaks internal invariants of salsa, so any further queries
+ /// might return nonsense results. It is useful only in very specific
+ /// circumstances -- for example, when one wants to observe which values
+ /// dropped together with the table
+ pub fn purge(&self)
+ where
+ Q::Storage: plumbing::QueryStorageMassOps,
+ {
+ self.storage.purge();
+ }
+}
+
+/// Return value from [the `query_mut` method] on `Database`.
+/// Gives access to the `set` method, notably, that is used to
+/// set the value of an input query.
+///
+/// [the `query_mut` method]: trait.Database.html#method.query_mut
+pub struct QueryTableMut<'me, Q>
+where
+ Q: Query + 'me,
+{
+ runtime: &'me mut Runtime,
+ storage: &'me Q::Storage,
+}
+
+impl<'me, Q> QueryTableMut<'me, Q>
+where
+ Q: Query,
+{
+ /// Constructs a new `QueryTableMut`.
+ pub fn new(runtime: &'me mut Runtime, storage: &'me Q::Storage) -> Self {
+ Self { runtime, storage }
+ }
+
+ /// Assign a value to an "input query". Must be used outside of
+ /// an active query computation.
+ ///
+ /// If you are using `snapshot`, see the notes on blocking
+ /// and cancellation on [the `query_mut` method].
+ ///
+ /// [the `query_mut` method]: trait.Database.html#method.query_mut
+ pub fn set(&mut self, key: Q::Key, value: Q::Value)
+ where
+ Q::Storage: plumbing::InputQueryStorageOps<Q>,
+ {
+ self.set_with_durability(key, value, Durability::LOW);
+ }
+
+ /// Assign a value to an "input query", with the additional
+ /// promise that this value will **never change**. Must be used
+ /// outside of an active query computation.
+ ///
+ /// If you are using `snapshot`, see the notes on blocking
+ /// and cancellation on [the `query_mut` method].
+ ///
+ /// [the `query_mut` method]: trait.Database.html#method.query_mut
+ pub fn set_with_durability(&mut self, key: Q::Key, value: Q::Value, durability: Durability)
+ where
+ Q::Storage: plumbing::InputQueryStorageOps<Q>,
+ {
+ self.storage.set(self.runtime, &key, value, durability);
+ }
+
+ /// Sets the size of LRU cache of values for this query table.
+ ///
+ /// That is, at most `cap` values will be preset in the table at the same
+ /// time. This helps with keeping maximum memory usage under control, at the
+ /// cost of potential extra recalculations of evicted values.
+ ///
+ /// If `cap` is zero, all values are preserved, this is the default.
+ pub fn set_lru_capacity(&self, cap: usize)
+ where
+ Q::Storage: plumbing::LruQueryStorageOps,
+ {
+ self.storage.set_lru_capacity(cap);
+ }
+
+ /// Marks the computed value as outdated.
+ ///
+ /// This causes salsa to re-execute the query function on the next access to
+ /// the query, even if all dependencies are up to date.
+ ///
+ /// This is most commonly used as part of the [on-demand input
+ /// pattern](https://salsa-rs.github.io/salsa/common_patterns/on_demand_inputs.html).
+ pub fn invalidate(&mut self, key: &Q::Key)
+ where
+ Q::Storage: plumbing::DerivedQueryStorageOps<Q>,
+ {
+ self.storage.invalidate(self.runtime, key)
+ }
+}
+
+/// A panic payload indicating that execution of a salsa query was cancelled.
+///
+/// This can occur for a few reasons:
+/// *
+/// *
+/// *
+#[derive(Debug)]
+#[non_exhaustive]
+pub enum Cancelled {
+ /// The query was operating on revision R, but there is a pending write to move to revision R+1.
+ #[non_exhaustive]
+ PendingWrite,
+
+ /// The query was blocked on another thread, and that thread panicked.
+ #[non_exhaustive]
+ PropagatedPanic,
+}
+
+impl Cancelled {
+ fn throw(self) -> ! {
+ // We use resume and not panic here to avoid running the panic
+ // hook (that is, to avoid collecting and printing backtrace).
+ std::panic::resume_unwind(Box::new(self));
+ }
+
+ /// Runs `f`, and catches any salsa cancellation.
+ pub fn catch<F, T>(f: F) -> Result<T, Cancelled>
+ where
+ F: FnOnce() -> T + UnwindSafe,
+ {
+ match panic::catch_unwind(f) {
+ Ok(t) => Ok(t),
+ Err(payload) => match payload.downcast() {
+ Ok(cancelled) => Err(*cancelled),
+ Err(payload) => panic::resume_unwind(payload),
+ },
+ }
+ }
+}
+
+impl std::fmt::Display for Cancelled {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ let why = match self {
+ Cancelled::PendingWrite => "pending write",
+ Cancelled::PropagatedPanic => "propagated panic",
+ };
+ f.write_str("cancelled because of ")?;
+ f.write_str(why)
+ }
+}
+
+impl std::error::Error for Cancelled {}
+
+/// Captures the participants of a cycle that occurred when executing a query.
+///
+/// This type is meant to be used to help give meaningful error messages to the
+/// user or to help salsa developers figure out why their program is resulting
+/// in a computation cycle.
+///
+/// It is used in a few ways:
+///
+/// * During [cycle recovery](https://https://salsa-rs.github.io/salsa/cycles/fallback.html),
+/// where it is given to the fallback function.
+/// * As the panic value when an unexpected cycle (i.e., a cycle where one or more participants
+/// lacks cycle recovery information) occurs.
+///
+/// You can read more about cycle handling in
+/// the [salsa book](https://https://salsa-rs.github.io/salsa/cycles.html).
+#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
+pub struct Cycle {
+ participants: plumbing::CycleParticipants,
+}
+
+impl Cycle {
+ pub(crate) fn new(participants: plumbing::CycleParticipants) -> Self {
+ Self { participants }
+ }
+
+ /// True if two `Cycle` values represent the same cycle.
+ pub(crate) fn is(&self, cycle: &Cycle) -> bool {
+ triomphe::Arc::ptr_eq(&self.participants, &cycle.participants)
+ }
+
+ pub(crate) fn throw(self) -> ! {
+ tracing::debug!("throwing cycle {:?}", self);
+ std::panic::resume_unwind(Box::new(self))
+ }
+
+ pub(crate) fn catch<T>(execute: impl FnOnce() -> T) -> Result<T, Cycle> {
+ match std::panic::catch_unwind(AssertUnwindSafe(execute)) {
+ Ok(v) => Ok(v),
+ Err(err) => match err.downcast::<Cycle>() {
+ Ok(cycle) => Err(*cycle),
+ Err(other) => std::panic::resume_unwind(other),
+ },
+ }
+ }
+
+ /// Iterate over the [`DatabaseKeyIndex`] for each query participating
+ /// in the cycle. The start point of this iteration within the cycle
+ /// is arbitrary but deterministic, but the ordering is otherwise determined
+ /// by the execution.
+ pub fn participant_keys(&self) -> impl Iterator<Item = DatabaseKeyIndex> + '_ {
+ self.participants.iter().copied()
+ }
+
+ /// Returns a vector with the debug information for
+ /// all the participants in the cycle.
+ pub fn all_participants<DB: ?Sized + Database>(&self, db: &DB) -> Vec<String> {
+ self.participant_keys().map(|d| format!("{:?}", d.debug(db))).collect()
+ }
+
+ /// Returns a vector with the debug information for
+ /// those participants in the cycle that lacked recovery
+ /// information.
+ pub fn unexpected_participants<DB: ?Sized + Database>(&self, db: &DB) -> Vec<String> {
+ self.participant_keys()
+ .filter(|&d| db.cycle_recovery_strategy(d) == CycleRecoveryStrategy::Panic)
+ .map(|d| format!("{:?}", d.debug(db)))
+ .collect()
+ }
+
+ /// Returns a "debug" view onto this strict that can be used to print out information.
+ pub fn debug<'me, DB: ?Sized + Database>(&'me self, db: &'me DB) -> impl std::fmt::Debug + 'me {
+ struct UnexpectedCycleDebug<'me> {
+ c: &'me Cycle,
+ db: &'me dyn Database,
+ }
+
+ impl<'me> std::fmt::Debug for UnexpectedCycleDebug<'me> {
+ fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ fmt.debug_struct("UnexpectedCycle")
+ .field("all_participants", &self.c.all_participants(self.db))
+ .field("unexpected_participants", &self.c.unexpected_participants(self.db))
+ .finish()
+ }
+ }
+
+ UnexpectedCycleDebug { c: self, db: db.ops_database() }
+ }
+}
+
+// Re-export the procedural macros.
+#[allow(unused_imports)]
+#[macro_use]
+extern crate salsa_macros;
+use plumbing::HasQueryGroup;
+pub use salsa_macros::*;
diff --git a/crates/salsa/src/lru.rs b/crates/salsa/src/lru.rs
new file mode 100644
index 0000000000..c6b9778f20
--- /dev/null
+++ b/crates/salsa/src/lru.rs
@@ -0,0 +1,325 @@
+//!
+use oorandom::Rand64;
+use parking_lot::Mutex;
+use std::fmt::Debug;
+use std::sync::atomic::AtomicUsize;
+use std::sync::atomic::Ordering;
+use triomphe::Arc;
+
+/// A simple and approximate concurrent lru list.
+///
+/// We assume but do not verify that each node is only used with one
+/// list. If this is not the case, it is not *unsafe*, but panics and
+/// weird results will ensue.
+///
+/// Each "node" in the list is of type `Node` and must implement
+/// `LruNode`, which is a trait that gives access to a field that
+/// stores the index in the list. This index gives us a rough idea of
+/// how recently the node has been used.
+#[derive(Debug)]
+pub(crate) struct Lru<Node>
+where
+ Node: LruNode,
+{
+ green_zone: AtomicUsize,
+ data: Mutex<LruData<Node>>,
+}
+
+#[derive(Debug)]
+struct LruData<Node> {
+ end_red_zone: usize,
+ end_yellow_zone: usize,
+ end_green_zone: usize,
+ rng: Rand64,
+ entries: Vec<Arc<Node>>,
+}
+
+pub(crate) trait LruNode: Sized + Debug {
+ fn lru_index(&self) -> &LruIndex;
+}
+
+#[derive(Debug)]
+pub(crate) struct LruIndex {
+ /// Index in the approprate LRU list, or std::usize::MAX if not a
+ /// member.
+ index: AtomicUsize,
+}
+
+impl<Node> Default for Lru<Node>
+where
+ Node: LruNode,
+{
+ fn default() -> Self {
+ Lru::new()
+ }
+}
+
+// We always use a fixed seed for our randomness so that we have
+// predictable results.
+const LRU_SEED: &str = "Hello, Rustaceans";
+
+impl<Node> Lru<Node>
+where
+ Node: LruNode,
+{
+ /// Creates a new LRU list where LRU caching is disabled.
+ pub(crate) fn new() -> Self {
+ Self::with_seed(LRU_SEED)
+ }
+
+ #[cfg_attr(not(test), allow(dead_code))]
+ fn with_seed(seed: &str) -> Self {
+ Lru { green_zone: AtomicUsize::new(0), data: Mutex::new(LruData::with_seed(seed)) }
+ }
+
+ /// Adjust the total number of nodes permitted to have a value at
+ /// once. If `len` is zero, this disables LRU caching completely.
+ pub(crate) fn set_lru_capacity(&self, len: usize) {
+ let mut data = self.data.lock();
+
+ // We require each zone to have at least 1 slot. Therefore,
+ // the length cannot be just 1 or 2.
+ if len == 0 {
+ self.green_zone.store(0, Ordering::Release);
+ data.resize(0, 0, 0);
+ } else {
+ let len = std::cmp::max(len, 3);
+
+ // Top 10% is the green zone. This must be at least length 1.
+ let green_zone = std::cmp::max(len / 10, 1);
+
+ // Next 20% is the yellow zone.
+ let yellow_zone = std::cmp::max(len / 5, 1);
+
+ // Remaining 70% is the red zone.
+ let red_zone = len - yellow_zone - green_zone;
+
+ // We need quick access to the green zone.
+ self.green_zone.store(green_zone, Ordering::Release);
+
+ // Resize existing array.
+ data.resize(green_zone, yellow_zone, red_zone);
+ }
+ }
+
+ /// Records that `node` was used. This may displace an old node (if the LRU limits are
+ pub(crate) fn record_use(&self, node: &Arc<Node>) -> Option<Arc<Node>> {
+ tracing::debug!("record_use(node={:?})", node);
+
+ // Load green zone length and check if the LRU cache is even enabled.
+ let green_zone = self.green_zone.load(Ordering::Acquire);
+ tracing::debug!("record_use: green_zone={}", green_zone);
+ if green_zone == 0 {
+ return None;
+ }
+
+ // Find current index of list (if any) and the current length
+ // of our green zone.
+ let index = node.lru_index().load();
+ tracing::debug!("record_use: index={}", index);
+
+ // Already a member of the list, and in the green zone -- nothing to do!
+ if index < green_zone {
+ return None;
+ }
+
+ self.data.lock().record_use(node)
+ }
+
+ pub(crate) fn purge(&self) {
+ self.green_zone.store(0, Ordering::SeqCst);
+ *self.data.lock() = LruData::with_seed(LRU_SEED);
+ }
+}
+
+impl<Node> LruData<Node>
+where
+ Node: LruNode,
+{
+ fn with_seed(seed_str: &str) -> Self {
+ Self::with_rng(rng_with_seed(seed_str))
+ }
+
+ fn with_rng(rng: Rand64) -> Self {
+ LruData { end_yellow_zone: 0, end_green_zone: 0, end_red_zone: 0, entries: Vec::new(), rng }
+ }
+
+ fn green_zone(&self) -> std::ops::Range<usize> {
+ 0..self.end_green_zone
+ }
+
+ fn yellow_zone(&self) -> std::ops::Range<usize> {
+ self.end_green_zone..self.end_yellow_zone
+ }
+
+ fn red_zone(&self) -> std::ops::Range<usize> {
+ self.end_yellow_zone..self.end_red_zone
+ }
+
+ fn resize(&mut self, len_green_zone: usize, len_yellow_zone: usize, len_red_zone: usize) {
+ self.end_green_zone = len_green_zone;
+ self.end_yellow_zone = self.end_green_zone + len_yellow_zone;
+ self.end_red_zone = self.end_yellow_zone + len_red_zone;
+ let entries = std::mem::replace(&mut self.entries, Vec::with_capacity(self.end_red_zone));
+
+ tracing::debug!("green_zone = {:?}", self.green_zone());
+ tracing::debug!("yellow_zone = {:?}", self.yellow_zone());
+ tracing::debug!("red_zone = {:?}", self.red_zone());
+
+ // We expect to resize when the LRU cache is basically empty.
+ // So just forget all the old LRU indices to start.
+ for entry in entries {
+ entry.lru_index().clear();
+ }
+ }
+
+ /// Records that a node was used. If it is already a member of the
+ /// LRU list, it is promoted to the green zone (unless it's
+ /// already there). Otherwise, it is added to the list first and
+ /// *then* promoted to the green zone. Adding a new node to the
+ /// list may displace an old member of the red zone, in which case
+ /// that is returned.
+ fn record_use(&mut self, node: &Arc<Node>) -> Option<Arc<Node>> {
+ tracing::debug!("record_use(node={:?})", node);
+
+ // NB: When this is invoked, we have typically already loaded
+ // the LRU index (to check if it is in green zone). But that
+ // check was done outside the lock and -- for all we know --
+ // the index may have changed since. So we always reload.
+ let index = node.lru_index().load();
+
+ if index < self.end_green_zone {
+ None
+ } else if index < self.end_yellow_zone {
+ self.promote_yellow_to_green(node, index);
+ None
+ } else if index < self.end_red_zone {
+ self.promote_red_to_green(node, index);
+ None
+ } else {
+ self.insert_new(node)
+ }
+ }
+
+ /// Inserts a node that is not yet a member of the LRU list. If
+ /// the list is at capacity, this can displace an existing member.
+ fn insert_new(&mut self, node: &Arc<Node>) -> Option<Arc<Node>> {
+ debug_assert!(!node.lru_index().is_in_lru());
+
+ // Easy case: we still have capacity. Push it, and then promote
+ // it up to the appropriate zone.
+ let len = self.entries.len();
+ if len < self.end_red_zone {
+ self.entries.push(node.clone());
+ node.lru_index().store(len);
+ tracing::debug!("inserted node {:?} at {}", node, len);
+ return self.record_use(node);
+ }
+
+ // Harder case: no capacity. Create some by evicting somebody from red
+ // zone and then promoting.
+ let victim_index = self.pick_index(self.red_zone());
+ let victim_node = std::mem::replace(&mut self.entries[victim_index], node.clone());
+ tracing::debug!("evicting red node {:?} from {}", victim_node, victim_index);
+ victim_node.lru_index().clear();
+ self.promote_red_to_green(node, victim_index);
+ Some(victim_node)
+ }
+
+ /// Promotes the node `node`, stored at `red_index` (in the red
+ /// zone), into a green index, demoting yellow/green nodes at
+ /// random.
+ ///
+ /// NB: It is not required that `node.lru_index()` is up-to-date
+ /// when entering this method.
+ fn promote_red_to_green(&mut self, node: &Arc<Node>, red_index: usize) {
+ debug_assert!(self.red_zone().contains(&red_index));
+
+ // Pick a yellow at random and switch places with it.
+ //
+ // Subtle: we do not update `node.lru_index` *yet* -- we're
+ // going to invoke `self.promote_yellow` next, and it will get
+ // updated then.
+ let yellow_index = self.pick_index(self.yellow_zone());
+ tracing::debug!(
+ "demoting yellow node {:?} from {} to red at {}",
+ self.entries[yellow_index],
+ yellow_index,
+ red_index,
+ );
+ self.entries.swap(yellow_index, red_index);
+ self.entries[red_index].lru_index().store(red_index);
+
+ // Now move ourselves up into the green zone.
+ self.promote_yellow_to_green(node, yellow_index);
+ }
+
+ /// Promotes the node `node`, stored at `yellow_index` (in the
+ /// yellow zone), into a green index, demoting a green node at
+ /// random to replace it.
+ ///
+ /// NB: It is not required that `node.lru_index()` is up-to-date
+ /// when entering this method.
+ fn promote_yellow_to_green(&mut self, node: &Arc<Node>, yellow_index: usize) {
+ debug_assert!(self.yellow_zone().contains(&yellow_index));
+
+ // Pick a yellow at random and switch places with it.
+ let green_index = self.pick_index(self.green_zone());
+ tracing::debug!(
+ "demoting green node {:?} from {} to yellow at {}",
+ self.entries[green_index],
+ green_index,
+ yellow_index
+ );
+ self.entries.swap(green_index, yellow_index);
+ self.entries[yellow_index].lru_index().store(yellow_index);
+ node.lru_index().store(green_index);
+
+ tracing::debug!("promoted {:?} to green index {}", node, green_index);
+ }
+
+ fn pick_index(&mut self, zone: std::ops::Range<usize>) -> usize {
+ let end_index = std::cmp::min(zone.end, self.entries.len());
+ self.rng.rand_range(zone.start as u64..end_index as u64) as usize
+ }
+}
+
+impl Default for LruIndex {
+ fn default() -> Self {
+ Self { index: AtomicUsize::new(std::usize::MAX) }
+ }
+}
+
+impl LruIndex {
+ fn load(&self) -> usize {
+ self.index.load(Ordering::Acquire) // see note on ordering below
+ }
+
+ fn store(&self, value: usize) {
+ self.index.store(value, Ordering::Release) // see note on ordering below
+ }
+
+ fn clear(&self) {
+ self.store(std::usize::MAX);
+ }
+
+ fn is_in_lru(&self) -> bool {
+ self.load() != std::usize::MAX
+ }
+}
+
+fn rng_with_seed(seed_str: &str) -> Rand64 {
+ let mut seed: [u8; 16] = [0; 16];
+ for (i, &b) in seed_str.as_bytes().iter().take(16).enumerate() {
+ seed[i] = b;
+ }
+ Rand64::new(u128::from_le_bytes(seed))
+}
+
+// A note on ordering:
+//
+// I chose to use AcqRel for the ordering but I don't think it's
+// strictly needed. All writes occur under a lock, so they should be
+// ordered w/r/t one another. As for the reads, they can occur
+// outside the lock, but they don't themselves enable dependent reads
+// -- if the reads are out of bounds, we would acquire a lock.
diff --git a/crates/salsa/src/plumbing.rs b/crates/salsa/src/plumbing.rs
new file mode 100644
index 0000000000..560a9b8315
--- /dev/null
+++ b/crates/salsa/src/plumbing.rs
@@ -0,0 +1,238 @@
+//!
+#![allow(missing_docs)]
+
+use crate::debug::TableEntry;
+use crate::durability::Durability;
+use crate::Cycle;
+use crate::Database;
+use crate::Query;
+use crate::QueryTable;
+use crate::QueryTableMut;
+use std::borrow::Borrow;
+use std::fmt::Debug;
+use std::hash::Hash;
+use triomphe::Arc;
+
+pub use crate::derived::DependencyStorage;
+pub use crate::derived::MemoizedStorage;
+pub use crate::input::InputStorage;
+pub use crate::interned::InternedStorage;
+pub use crate::interned::LookupInternedStorage;
+pub use crate::{revision::Revision, DatabaseKeyIndex, QueryDb, Runtime};
+
+/// Defines various associated types. An impl of this
+/// should be generated for your query-context type automatically by
+/// the `database_storage` macro, so you shouldn't need to mess
+/// with this trait directly.
+pub trait DatabaseStorageTypes: Database {
+ /// Defines the "storage type", where all the query data is kept.
+ /// This type is defined by the `database_storage` macro.
+ type DatabaseStorage: Default;
+}
+
+/// Internal operations that the runtime uses to operate on the database.
+pub trait DatabaseOps {
+ /// Upcast this type to a `dyn Database`.
+ fn ops_database(&self) -> &dyn Database;
+
+ /// Gives access to the underlying salsa runtime.
+ fn ops_salsa_runtime(&self) -> &Runtime;
+
+ /// Gives access to the underlying salsa runtime.
+ fn ops_salsa_runtime_mut(&mut self) -> &mut Runtime;
+
+ /// Formats a database key index in a human readable fashion.
+ fn fmt_index(
+ &self,
+ index: DatabaseKeyIndex,
+ fmt: &mut std::fmt::Formatter<'_>,
+ ) -> std::fmt::Result;
+
+ /// True if the computed value for `input` may have changed since `revision`.
+ fn maybe_changed_after(&self, input: DatabaseKeyIndex, revision: Revision) -> bool;
+
+ /// Find the `CycleRecoveryStrategy` for a given input.
+ fn cycle_recovery_strategy(&self, input: DatabaseKeyIndex) -> CycleRecoveryStrategy;
+
+ /// Executes the callback for each kind of query.
+ fn for_each_query(&self, op: &mut dyn FnMut(&dyn QueryStorageMassOps));
+}
+
+/// Internal operations performed on the query storage as a whole
+/// (note that these ops do not need to know the identity of the
+/// query, unlike `QueryStorageOps`).
+pub trait QueryStorageMassOps {
+ fn purge(&self);
+}
+
+pub trait DatabaseKey: Clone + Debug + Eq + Hash {}
+
+pub trait QueryFunction: Query {
+ /// See `CycleRecoveryStrategy`
+ const CYCLE_STRATEGY: CycleRecoveryStrategy;
+
+ fn execute(db: &<Self as QueryDb<'_>>::DynDb, key: Self::Key) -> Self::Value;
+
+ fn cycle_fallback(
+ db: &<Self as QueryDb<'_>>::DynDb,
+ cycle: &Cycle,
+ key: &Self::Key,
+ ) -> Self::Value {
+ let _ = (db, cycle, key);
+ panic!("query `{:?}` doesn't support cycle fallback", Self::default())
+ }
+}
+
+/// Cycle recovery strategy: Is this query capable of recovering from
+/// a cycle that results from executing the function? If so, how?
+#[derive(Copy, Clone, Debug, PartialEq, Eq)]
+pub enum CycleRecoveryStrategy {
+ /// Cannot recover from cycles: panic.
+ ///
+ /// This is the default. It is also what happens if a cycle
+ /// occurs and the queries involved have different recovery
+ /// strategies.
+ ///
+ /// In the case of a failure due to a cycle, the panic
+ /// value will be XXX (FIXME).
+ Panic,
+
+ /// Recovers from cycles by storing a sentinel value.
+ ///
+ /// This value is computed by the `QueryFunction::cycle_fallback`
+ /// function.
+ Fallback,
+}
+
+/// Create a query table, which has access to the storage for the query
+/// and offers methods like `get`.
+pub fn get_query_table<'me, Q>(db: &'me <Q as QueryDb<'me>>::DynDb) -> QueryTable<'me, Q>
+where
+ Q: Query + 'me,
+ Q::Storage: QueryStorageOps<Q>,
+{
+ let group_storage: &Q::GroupStorage = HasQueryGroup::group_storage(db);
+ let query_storage: &Q::Storage = Q::query_storage(group_storage);
+ QueryTable::new(db, query_storage)
+}
+
+/// Create a mutable query table, which has access to the storage
+/// for the query and offers methods like `set`.
+pub fn get_query_table_mut<'me, Q>(db: &'me mut <Q as QueryDb<'me>>::DynDb) -> QueryTableMut<'me, Q>
+where
+ Q: Query,
+{
+ let (group_storage, runtime) = HasQueryGroup::group_storage_mut(db);
+ let query_storage = Q::query_storage_mut(group_storage);
+ QueryTableMut::new(runtime, &**query_storage)
+}
+
+pub trait QueryGroup: Sized {
+ type GroupStorage;
+
+ /// Dyn version of the associated database trait.
+ type DynDb: ?Sized + Database + HasQueryGroup<Self>;
+}
+
+/// Trait implemented by a database for each group that it supports.
+/// `S` and `K` are the types for *group storage* and *group key*, respectively.
+pub trait HasQueryGroup<G>: Database
+where
+ G: QueryGroup,
+{
+ /// Access the group storage struct from the database.
+ fn group_storage(&self) -> &G::GroupStorage;
+
+ /// Access the group storage struct from the database.
+ /// Also returns a ref to the `Runtime`, since otherwise
+ /// the database is borrowed and one cannot get access to it.
+ fn group_storage_mut(&mut self) -> (&G::GroupStorage, &mut Runtime);
+}
+
+// ANCHOR:QueryStorageOps
+pub trait QueryStorageOps<Q>
+where
+ Self: QueryStorageMassOps,
+ Q: Query,
+{
+ // ANCHOR_END:QueryStorageOps
+
+ /// See CycleRecoveryStrategy
+ const CYCLE_STRATEGY: CycleRecoveryStrategy;
+
+ fn new(group_index: u16) -> Self;
+
+ /// Format a database key index in a suitable way.
+ fn fmt_index(
+ &self,
+ db: &<Q as QueryDb<'_>>::DynDb,
+ index: DatabaseKeyIndex,
+ fmt: &mut std::fmt::Formatter<'_>,
+ ) -> std::fmt::Result;
+
+ // ANCHOR:maybe_changed_after
+ /// True if the value of `input`, which must be from this query, may have
+ /// changed after the given revision ended.
+ ///
+ /// This function should only be invoked with a revision less than the current
+ /// revision.
+ fn maybe_changed_after(
+ &self,
+ db: &<Q as QueryDb<'_>>::DynDb,
+ input: DatabaseKeyIndex,
+ revision: Revision,
+ ) -> bool;
+ // ANCHOR_END:maybe_changed_after
+
+ fn cycle_recovery_strategy(&self) -> CycleRecoveryStrategy {
+ Self::CYCLE_STRATEGY
+ }
+
+ // ANCHOR:fetch
+ /// Execute the query, returning the result (often, the result
+ /// will be memoized). This is the "main method" for
+ /// queries.
+ ///
+ /// Returns `Err` in the event of a cycle, meaning that computing
+ /// the value for this `key` is recursively attempting to fetch
+ /// itself.
+ fn fetch(&self, db: &<Q as QueryDb<'_>>::DynDb, key: &Q::Key) -> Q::Value;
+ // ANCHOR_END:fetch
+
+ /// Returns the durability associated with a given key.
+ fn durability(&self, db: &<Q as QueryDb<'_>>::DynDb, key: &Q::Key) -> Durability;
+
+ /// Get the (current) set of the entries in the query storage
+ fn entries<C>(&self, db: &<Q as QueryDb<'_>>::DynDb) -> C
+ where
+ C: std::iter::FromIterator<TableEntry<Q::Key, Q::Value>>;
+}
+
+/// An optional trait that is implemented for "user mutable" storage:
+/// that is, storage whose value is not derived from other storage but
+/// is set independently.
+pub trait InputQueryStorageOps<Q>
+where
+ Q: Query,
+{
+ fn set(&self, runtime: &mut Runtime, key: &Q::Key, new_value: Q::Value, durability: Durability);
+}
+
+/// An optional trait that is implemented for "user mutable" storage:
+/// that is, storage whose value is not derived from other storage but
+/// is set independently.
+pub trait LruQueryStorageOps {
+ fn set_lru_capacity(&self, new_capacity: usize);
+}
+
+pub trait DerivedQueryStorageOps<Q>
+where
+ Q: Query,
+{
+ fn invalidate<S>(&self, runtime: &mut Runtime, key: &S)
+ where
+ S: Eq + Hash,
+ Q::Key: Borrow<S>;
+}
+
+pub type CycleParticipants = Arc<Vec<DatabaseKeyIndex>>;
diff --git a/crates/salsa/src/revision.rs b/crates/salsa/src/revision.rs
new file mode 100644
index 0000000000..d97aaf9deb
--- /dev/null
+++ b/crates/salsa/src/revision.rs
@@ -0,0 +1,67 @@
+//!
+use std::num::NonZeroU32;
+use std::sync::atomic::{AtomicU32, Ordering};
+
+/// Value of the initial revision, as a u32. We don't use 0
+/// because we want to use a `NonZeroU32`.
+const START: u32 = 1;
+
+/// A unique identifier for the current version of the database; each
+/// time an input is changed, the revision number is incremented.
+/// `Revision` is used internally to track which values may need to be
+/// recomputed, but is not something you should have to interact with
+/// directly as a user of salsa.
+#[derive(Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
+pub struct Revision {
+ generation: NonZeroU32,
+}
+
+impl Revision {
+ pub(crate) fn start() -> Self {
+ Self::from(START)
+ }
+
+ pub(crate) fn from(g: u32) -> Self {
+ Self { generation: NonZeroU32::new(g).unwrap() }
+ }
+
+ pub(crate) fn next(self) -> Revision {
+ Self::from(self.generation.get() + 1)
+ }
+
+ fn as_u32(self) -> u32 {
+ self.generation.get()
+ }
+}
+
+impl std::fmt::Debug for Revision {
+ fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ write!(fmt, "R{}", self.generation)
+ }
+}
+
+#[derive(Debug)]
+pub(crate) struct AtomicRevision {
+ data: AtomicU32,
+}
+
+impl AtomicRevision {
+ pub(crate) fn start() -> Self {
+ Self { data: AtomicU32::new(START) }
+ }
+
+ pub(crate) fn load(&self) -> Revision {
+ Revision::from(self.data.load(Ordering::SeqCst))
+ }
+
+ pub(crate) fn store(&self, r: Revision) {
+ self.data.store(r.as_u32(), Ordering::SeqCst);
+ }
+
+ /// Increment by 1, returning previous value.
+ pub(crate) fn fetch_then_increment(&self) -> Revision {
+ let v = self.data.fetch_add(1, Ordering::SeqCst);
+ assert!(v != u32::max_value(), "revision overflow");
+ Revision::from(v)
+ }
+}
diff --git a/crates/salsa/src/runtime.rs b/crates/salsa/src/runtime.rs
new file mode 100644
index 0000000000..40b8856991
--- /dev/null
+++ b/crates/salsa/src/runtime.rs
@@ -0,0 +1,667 @@
+//!
+use crate::durability::Durability;
+use crate::hash::FxIndexSet;
+use crate::plumbing::CycleRecoveryStrategy;
+use crate::revision::{AtomicRevision, Revision};
+use crate::{Cancelled, Cycle, Database, DatabaseKeyIndex, Event, EventKind};
+use parking_lot::lock_api::{RawRwLock, RawRwLockRecursive};
+use parking_lot::{Mutex, RwLock};
+use std::hash::Hash;
+use std::panic::panic_any;
+use std::sync::atomic::{AtomicUsize, Ordering};
+use tracing::debug;
+use triomphe::Arc;
+
+mod dependency_graph;
+use dependency_graph::DependencyGraph;
+
+pub(crate) mod local_state;
+use local_state::LocalState;
+
+use self::local_state::{ActiveQueryGuard, QueryInputs, QueryRevisions};
+
+/// The salsa runtime stores the storage for all queries as well as
+/// tracking the query stack and dependencies between cycles.
+///
+/// Each new runtime you create (e.g., via `Runtime::new` or
+/// `Runtime::default`) will have an independent set of query storage
+/// associated with it. Normally, therefore, you only do this once, at
+/// the start of your application.
+pub struct Runtime {
+ /// Our unique runtime id.
+ id: RuntimeId,
+
+ /// If this is a "forked" runtime, then the `revision_guard` will
+ /// be `Some`; this guard holds a read-lock on the global query
+ /// lock.
+ revision_guard: Option<RevisionGuard>,
+
+ /// Local state that is specific to this runtime (thread).
+ local_state: LocalState,
+
+ /// Shared state that is accessible via all runtimes.
+ shared_state: Arc<SharedState>,
+}
+
+#[derive(Clone, Debug)]
+pub(crate) enum WaitResult {
+ Completed,
+ Panicked,
+ Cycle(Cycle),
+}
+
+impl Default for Runtime {
+ fn default() -> Self {
+ Runtime {
+ id: RuntimeId { counter: 0 },
+ revision_guard: None,
+ shared_state: Default::default(),
+ local_state: Default::default(),
+ }
+ }
+}
+
+impl std::fmt::Debug for Runtime {
+ fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ fmt.debug_struct("Runtime")
+ .field("id", &self.id())
+ .field("forked", &self.revision_guard.is_some())
+ .field("shared_state", &self.shared_state)
+ .finish()
+ }
+}
+
+impl Runtime {
+ /// Create a new runtime; equivalent to `Self::default`. This is
+ /// used when creating a new database.
+ pub fn new() -> Self {
+ Self::default()
+ }
+
+ /// See [`crate::storage::Storage::snapshot`].
+ pub(crate) fn snapshot(&self) -> Self {
+ if self.local_state.query_in_progress() {
+ panic!("it is not legal to `snapshot` during a query (see salsa-rs/salsa#80)");
+ }
+
+ let revision_guard = RevisionGuard::new(&self.shared_state);
+
+ let id = RuntimeId { counter: self.shared_state.next_id.fetch_add(1, Ordering::SeqCst) };
+
+ Runtime {
+ id,
+ revision_guard: Some(revision_guard),
+ shared_state: self.shared_state.clone(),
+ local_state: Default::default(),
+ }
+ }
+
+ /// A "synthetic write" causes the system to act *as though* some
+ /// input of durability `durability` has changed. This is mostly
+ /// useful for profiling scenarios.
+ ///
+ /// **WARNING:** Just like an ordinary write, this method triggers
+ /// cancellation. If you invoke it while a snapshot exists, it
+ /// will block until that snapshot is dropped -- if that snapshot
+ /// is owned by the current thread, this could trigger deadlock.
+ pub fn synthetic_write(&mut self, durability: Durability) {
+ self.with_incremented_revision(|_next_revision| Some(durability));
+ }
+
+ /// The unique identifier attached to this `SalsaRuntime`. Each
+ /// snapshotted runtime has a distinct identifier.
+ #[inline]
+ pub fn id(&self) -> RuntimeId {
+ self.id
+ }
+
+ /// Returns the database-key for the query that this thread is
+ /// actively executing (if any).
+ pub fn active_query(&self) -> Option<DatabaseKeyIndex> {
+ self.local_state.active_query()
+ }
+
+ /// Read current value of the revision counter.
+ #[inline]
+ pub(crate) fn current_revision(&self) -> Revision {
+ self.shared_state.revisions[0].load()
+ }
+
+ /// The revision in which values with durability `d` may have last
+ /// changed. For D0, this is just the current revision. But for
+ /// higher levels of durability, this value may lag behind the
+ /// current revision. If we encounter a value of durability Di,
+ /// then, we can check this function to get a "bound" on when the
+ /// value may have changed, which allows us to skip walking its
+ /// dependencies.
+ #[inline]
+ pub(crate) fn last_changed_revision(&self, d: Durability) -> Revision {
+ self.shared_state.revisions[d.index()].load()
+ }
+
+ /// Read current value of the revision counter.
+ #[inline]
+ pub(crate) fn pending_revision(&self) -> Revision {
+ self.shared_state.pending_revision.load()
+ }
+
+ #[cold]
+ pub(crate) fn unwind_cancelled(&self) {
+ self.report_untracked_read();
+ Cancelled::PendingWrite.throw();
+ }
+
+ /// Acquires the **global query write lock** (ensuring that no queries are
+ /// executing) and then increments the current revision counter; invokes
+ /// `op` with the global query write lock still held.
+ ///
+ /// While we wait to acquire the global query write lock, this method will
+ /// also increment `pending_revision_increments`, thus signalling to queries
+ /// that their results are "cancelled" and they should abort as expeditiously
+ /// as possible.
+ ///
+ /// The `op` closure should actually perform the writes needed. It is given
+ /// the new revision as an argument, and its return value indicates whether
+ /// any pre-existing value was modified:
+ ///
+ /// - returning `None` means that no pre-existing value was modified (this
+ /// could occur e.g. when setting some key on an input that was never set
+ /// before)
+ /// - returning `Some(d)` indicates that a pre-existing value was modified
+ /// and it had the durability `d`. This will update the records for when
+ /// values with each durability were modified.
+ ///
+ /// Note that, given our writer model, we can assume that only one thread is
+ /// attempting to increment the global revision at a time.
+ pub(crate) fn with_incremented_revision<F>(&mut self, op: F)
+ where
+ F: FnOnce(Revision) -> Option<Durability>,
+ {
+ tracing::debug!("increment_revision()");
+
+ if !self.permits_increment() {
+ panic!("increment_revision invoked during a query computation");
+ }
+
+ // Set the `pending_revision` field so that people
+ // know current revision is cancelled.
+ let current_revision = self.shared_state.pending_revision.fetch_then_increment();
+
+ // To modify the revision, we need the lock.
+ let shared_state = self.shared_state.clone();
+ let _lock = shared_state.query_lock.write();
+
+ let old_revision = self.shared_state.revisions[0].fetch_then_increment();
+ assert_eq!(current_revision, old_revision);
+
+ let new_revision = current_revision.next();
+
+ debug!("increment_revision: incremented to {:?}", new_revision);
+
+ if let Some(d) = op(new_revision) {
+ for rev in &self.shared_state.revisions[1..=d.index()] {
+ rev.store(new_revision);
+ }
+ }
+ }
+
+ pub(crate) fn permits_increment(&self) -> bool {
+ self.revision_guard.is_none() && !self.local_state.query_in_progress()
+ }
+
+ #[inline]
+ pub(crate) fn push_query(&self, database_key_index: DatabaseKeyIndex) -> ActiveQueryGuard<'_> {
+ self.local_state.push_query(database_key_index)
+ }
+
+ /// Reports that the currently active query read the result from
+ /// another query.
+ ///
+ /// Also checks whether the "cycle participant" flag is set on
+ /// the current stack frame -- if so, panics with `CycleParticipant`
+ /// value, which should be caught by the code executing the query.
+ ///
+ /// # Parameters
+ ///
+ /// - `database_key`: the query whose result was read
+ /// - `changed_revision`: the last revision in which the result of that
+ /// query had changed
+ pub(crate) fn report_query_read_and_unwind_if_cycle_resulted(
+ &self,
+ input: DatabaseKeyIndex,
+ durability: Durability,
+ changed_at: Revision,
+ ) {
+ self.local_state
+ .report_query_read_and_unwind_if_cycle_resulted(input, durability, changed_at);
+ }
+
+ /// Reports that the query depends on some state unknown to salsa.
+ ///
+ /// Queries which report untracked reads will be re-executed in the next
+ /// revision.
+ pub fn report_untracked_read(&self) {
+ self.local_state.report_untracked_read(self.current_revision());
+ }
+
+ /// Acts as though the current query had read an input with the given durability; this will force the current query's durability to be at most `durability`.
+ ///
+ /// This is mostly useful to control the durability level for [on-demand inputs](https://salsa-rs.github.io/salsa/common_patterns/on_demand_inputs.html).
+ pub fn report_synthetic_read(&self, durability: Durability) {
+ let changed_at = self.last_changed_revision(durability);
+ self.local_state.report_synthetic_read(durability, changed_at);
+ }
+
+ /// Handles a cycle in the dependency graph that was detected when the
+ /// current thread tried to block on `database_key_index` which is being
+ /// executed by `to_id`. If this function returns, then `to_id` no longer
+ /// depends on the current thread, and so we should continue executing
+ /// as normal. Otherwise, the function will throw a `Cycle` which is expected
+ /// to be caught by some frame on our stack. This occurs either if there is
+ /// a frame on our stack with cycle recovery (possibly the top one!) or if there
+ /// is no cycle recovery at all.
+ fn unblock_cycle_and_maybe_throw(
+ &self,
+ db: &dyn Database,
+ dg: &mut DependencyGraph,
+ database_key_index: DatabaseKeyIndex,
+ to_id: RuntimeId,
+ ) {
+ debug!("unblock_cycle_and_maybe_throw(database_key={:?})", database_key_index);
+
+ let mut from_stack = self.local_state.take_query_stack();
+ let from_id = self.id();
+
+ // Make a "dummy stack frame". As we iterate through the cycle, we will collect the
+ // inputs from each participant. Then, if we are participating in cycle recovery, we
+ // will propagate those results to all participants.
+ let mut cycle_query = ActiveQuery::new(database_key_index);
+
+ // Identify the cycle participants:
+ let cycle = {
+ let mut v = vec![];
+ dg.for_each_cycle_participant(
+ from_id,
+ &mut from_stack,
+ database_key_index,
+ to_id,
+ |aqs| {
+ aqs.iter_mut().for_each(|aq| {
+ cycle_query.add_from(aq);
+ v.push(aq.database_key_index);
+ });
+ },
+ );
+
+ // We want to give the participants in a deterministic order
+ // (at least for this execution, not necessarily across executions),
+ // no matter where it started on the stack. Find the minimum
+ // key and rotate it to the front.
+ let min = v.iter().min().unwrap();
+ let index = v.iter().position(|p| p == min).unwrap();
+ v.rotate_left(index);
+
+ // No need to store extra memory.
+ v.shrink_to_fit();
+
+ Cycle::new(Arc::new(v))
+ };
+ debug!("cycle {:?}, cycle_query {:#?}", cycle.debug(db), cycle_query,);
+
+ // We can remove the cycle participants from the list of dependencies;
+ // they are a strongly connected component (SCC) and we only care about
+ // dependencies to things outside the SCC that control whether it will
+ // form again.
+ cycle_query.remove_cycle_participants(&cycle);
+
+ // Mark each cycle participant that has recovery set, along with
+ // any frames that come after them on the same thread. Those frames
+ // are going to be unwound so that fallback can occur.
+ dg.for_each_cycle_participant(from_id, &mut from_stack, database_key_index, to_id, |aqs| {
+ aqs.iter_mut()
+ .skip_while(|aq| match db.cycle_recovery_strategy(aq.database_key_index) {
+ CycleRecoveryStrategy::Panic => true,
+ CycleRecoveryStrategy::Fallback => false,
+ })
+ .for_each(|aq| {
+ debug!("marking {:?} for fallback", aq.database_key_index.debug(db));
+ aq.take_inputs_from(&cycle_query);
+ assert!(aq.cycle.is_none());
+ aq.cycle = Some(cycle.clone());
+ });
+ });
+
+ // Unblock every thread that has cycle recovery with a `WaitResult::Cycle`.
+ // They will throw the cycle, which will be caught by the frame that has
+ // cycle recovery so that it can execute that recovery.
+ let (me_recovered, others_recovered) =
+ dg.maybe_unblock_runtimes_in_cycle(from_id, &from_stack, database_key_index, to_id);
+
+ self.local_state.restore_query_stack(from_stack);
+
+ if me_recovered {
+ // If the current thread has recovery, we want to throw
+ // so that it can begin.
+ cycle.throw()
+ } else if others_recovered {
+ // If other threads have recovery but we didn't: return and we will block on them.
+ } else {
+ // if nobody has recover, then we panic
+ panic_any(cycle);
+ }
+ }
+
+ /// Block until `other_id` completes executing `database_key`;
+ /// panic or unwind in the case of a cycle.
+ ///
+ /// `query_mutex_guard` is the guard for the current query's state;
+ /// it will be dropped after we have successfully registered the
+ /// dependency.
+ ///
+ /// # Propagating panics
+ ///
+ /// If the thread `other_id` panics, then our thread is considered
+ /// cancelled, so this function will panic with a `Cancelled` value.
+ ///
+ /// # Cycle handling
+ ///
+ /// If the thread `other_id` already depends on the current thread,
+ /// and hence there is a cycle in the query graph, then this function
+ /// will unwind instead of returning normally. The method of unwinding
+ /// depends on the [`Self::mutual_cycle_recovery_strategy`]
+ /// of the cycle participants:
+ ///
+ /// * [`CycleRecoveryStrategy::Panic`]: panic with the [`Cycle`] as the value.
+ /// * [`CycleRecoveryStrategy::Fallback`]: initiate unwinding with [`CycleParticipant::unwind`].
+ pub(crate) fn block_on_or_unwind<QueryMutexGuard>(
+ &self,
+ db: &dyn Database,
+ database_key: DatabaseKeyIndex,
+ other_id: RuntimeId,
+ query_mutex_guard: QueryMutexGuard,
+ ) {
+ let mut dg = self.shared_state.dependency_graph.lock();
+
+ if dg.depends_on(other_id, self.id()) {
+ self.unblock_cycle_and_maybe_throw(db, &mut dg, database_key, other_id);
+
+ // If the above fn returns, then (via cycle recovery) it has unblocked the
+ // cycle, so we can continue.
+ assert!(!dg.depends_on(other_id, self.id()));
+ }
+
+ db.salsa_event(Event {
+ runtime_id: self.id(),
+ kind: EventKind::WillBlockOn { other_runtime_id: other_id, database_key },
+ });
+
+ let stack = self.local_state.take_query_stack();
+
+ let (stack, result) = DependencyGraph::block_on(
+ dg,
+ self.id(),
+ database_key,
+ other_id,
+ stack,
+ query_mutex_guard,
+ );
+
+ self.local_state.restore_query_stack(stack);
+
+ match result {
+ WaitResult::Completed => (),
+
+ // If the other thread panicked, then we consider this thread
+ // cancelled. The assumption is that the panic will be detected
+ // by the other thread and responded to appropriately.
+ WaitResult::Panicked => Cancelled::PropagatedPanic.throw(),
+
+ WaitResult::Cycle(c) => c.throw(),
+ }
+ }
+
+ /// Invoked when this runtime completed computing `database_key` with
+ /// the given result `wait_result` (`wait_result` should be `None` if
+ /// computing `database_key` panicked and could not complete).
+ /// This function unblocks any dependent queries and allows them
+ /// to continue executing.
+ pub(crate) fn unblock_queries_blocked_on(
+ &self,
+ database_key: DatabaseKeyIndex,
+ wait_result: WaitResult,
+ ) {
+ self.shared_state
+ .dependency_graph
+ .lock()
+ .unblock_runtimes_blocked_on(database_key, wait_result);
+ }
+}
+
+/// State that will be common to all threads (when we support multiple threads)
+struct SharedState {
+ /// Stores the next id to use for a snapshotted runtime (starts at 1).
+ next_id: AtomicUsize,
+
+ /// Whenever derived queries are executing, they acquire this lock
+ /// in read mode. Mutating inputs (and thus creating a new
+ /// revision) requires a write lock (thus guaranteeing that no
+ /// derived queries are in progress). Note that this is not needed
+ /// to prevent **race conditions** -- the revision counter itself
+ /// is stored in an `AtomicUsize` so it can be cheaply read
+ /// without acquiring the lock. Rather, the `query_lock` is used
+ /// to ensure a higher-level consistency property.
+ query_lock: RwLock<()>,
+
+ /// This is typically equal to `revision` -- set to `revision+1`
+ /// when a new revision is pending (which implies that the current
+ /// revision is cancelled).
+ pending_revision: AtomicRevision,
+
+ /// Stores the "last change" revision for values of each duration.
+ /// This vector is always of length at least 1 (for Durability 0)
+ /// but its total length depends on the number of durations. The
+ /// element at index 0 is special as it represents the "current
+ /// revision". In general, we have the invariant that revisions
+ /// in here are *declining* -- that is, `revisions[i] >=
+ /// revisions[i + 1]`, for all `i`. This is because when you
+ /// modify a value with durability D, that implies that values
+ /// with durability less than D may have changed too.
+ revisions: Vec<AtomicRevision>,
+
+ /// The dependency graph tracks which runtimes are blocked on one
+ /// another, waiting for queries to terminate.
+ dependency_graph: Mutex<DependencyGraph>,
+}
+
+impl SharedState {
+ fn with_durabilities(durabilities: usize) -> Self {
+ SharedState {
+ next_id: AtomicUsize::new(1),
+ query_lock: Default::default(),
+ revisions: (0..durabilities).map(|_| AtomicRevision::start()).collect(),
+ pending_revision: AtomicRevision::start(),
+ dependency_graph: Default::default(),
+ }
+ }
+}
+
+impl std::panic::RefUnwindSafe for SharedState {}
+
+impl Default for SharedState {
+ fn default() -> Self {
+ Self::with_durabilities(Durability::LEN)
+ }
+}
+
+impl std::fmt::Debug for SharedState {
+ fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ let query_lock = if self.query_lock.try_write().is_some() {
+ "<unlocked>"
+ } else if self.query_lock.try_read().is_some() {
+ "<rlocked>"
+ } else {
+ "<wlocked>"
+ };
+ fmt.debug_struct("SharedState")
+ .field("query_lock", &query_lock)
+ .field("revisions", &self.revisions)
+ .field("pending_revision", &self.pending_revision)
+ .finish()
+ }
+}
+
+#[derive(Debug)]
+struct ActiveQuery {
+ /// What query is executing
+ database_key_index: DatabaseKeyIndex,
+
+ /// Minimum durability of inputs observed so far.
+ durability: Durability,
+
+ /// Maximum revision of all inputs observed. If we observe an
+ /// untracked read, this will be set to the most recent revision.
+ changed_at: Revision,
+
+ /// Set of subqueries that were accessed thus far, or `None` if
+ /// there was an untracked the read.
+ dependencies: Option<FxIndexSet<DatabaseKeyIndex>>,
+
+ /// Stores the entire cycle, if one is found and this query is part of it.
+ cycle: Option<Cycle>,
+}
+
+impl ActiveQuery {
+ fn new(database_key_index: DatabaseKeyIndex) -> Self {
+ ActiveQuery {
+ database_key_index,
+ durability: Durability::MAX,
+ changed_at: Revision::start(),
+ dependencies: Some(FxIndexSet::default()),
+ cycle: None,
+ }
+ }
+
+ fn add_read(&mut self, input: DatabaseKeyIndex, durability: Durability, revision: Revision) {
+ if let Some(set) = &mut self.dependencies {
+ set.insert(input);
+ }
+
+ self.durability = self.durability.min(durability);
+ self.changed_at = self.changed_at.max(revision);
+ }
+
+ fn add_untracked_read(&mut self, changed_at: Revision) {
+ self.dependencies = None;
+ self.durability = Durability::LOW;
+ self.changed_at = changed_at;
+ }
+
+ fn add_synthetic_read(&mut self, durability: Durability, revision: Revision) {
+ self.dependencies = None;
+ self.durability = self.durability.min(durability);
+ self.changed_at = self.changed_at.max(revision);
+ }
+
+ pub(crate) fn revisions(&self) -> QueryRevisions {
+ let inputs = match &self.dependencies {
+ None => QueryInputs::Untracked,
+
+ Some(dependencies) => {
+ if dependencies.is_empty() {
+ QueryInputs::NoInputs
+ } else {
+ QueryInputs::Tracked { inputs: dependencies.iter().copied().collect() }
+ }
+ }
+ };
+
+ QueryRevisions { changed_at: self.changed_at, inputs, durability: self.durability }
+ }
+
+ /// Adds any dependencies from `other` into `self`.
+ /// Used during cycle recovery, see [`Runtime::create_cycle_error`].
+ fn add_from(&mut self, other: &ActiveQuery) {
+ self.changed_at = self.changed_at.max(other.changed_at);
+ self.durability = self.durability.min(other.durability);
+ if let Some(other_dependencies) = &other.dependencies {
+ if let Some(my_dependencies) = &mut self.dependencies {
+ my_dependencies.extend(other_dependencies.iter().copied());
+ }
+ } else {
+ self.dependencies = None;
+ }
+ }
+
+ /// Removes the participants in `cycle` from my dependencies.
+ /// Used during cycle recovery, see [`Runtime::create_cycle_error`].
+ fn remove_cycle_participants(&mut self, cycle: &Cycle) {
+ if let Some(my_dependencies) = &mut self.dependencies {
+ for p in cycle.participant_keys() {
+ my_dependencies.remove(&p);
+ }
+ }
+ }
+
+ /// Copy the changed-at, durability, and dependencies from `cycle_query`.
+ /// Used during cycle recovery, see [`Runtime::create_cycle_error`].
+ pub(crate) fn take_inputs_from(&mut self, cycle_query: &ActiveQuery) {
+ self.changed_at = cycle_query.changed_at;
+ self.durability = cycle_query.durability;
+ self.dependencies = cycle_query.dependencies.clone();
+ }
+}
+
+/// A unique identifier for a particular runtime. Each time you create
+/// a snapshot, a fresh `RuntimeId` is generated. Once a snapshot is
+/// complete, its `RuntimeId` may potentially be re-used.
+#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
+pub struct RuntimeId {
+ counter: usize,
+}
+
+#[derive(Clone, Debug)]
+pub(crate) struct StampedValue<V> {
+ pub(crate) value: V,
+ pub(crate) durability: Durability,
+ pub(crate) changed_at: Revision,
+}
+
+struct RevisionGuard {
+ shared_state: Arc<SharedState>,
+}
+
+impl RevisionGuard {
+ fn new(shared_state: &Arc<SharedState>) -> Self {
+ // Subtle: we use a "recursive" lock here so that it is not an
+ // error to acquire a read-lock when one is already held (this
+ // happens when a query uses `snapshot` to spawn off parallel
+ // workers, for example).
+ //
+ // This has the side-effect that we are responsible to ensure
+ // that people contending for the write lock do not starve,
+ // but this is what we achieve via the cancellation mechanism.
+ //
+ // (In particular, since we only ever have one "mutating
+ // handle" to the database, the only contention for the global
+ // query lock occurs when there are "futures" evaluating
+ // queries in parallel, and those futures hold a read-lock
+ // already, so the starvation problem is more about them bring
+ // themselves to a close, versus preventing other people from
+ // *starting* work).
+ unsafe {
+ shared_state.query_lock.raw().lock_shared_recursive();
+ }
+
+ Self { shared_state: shared_state.clone() }
+ }
+}
+
+impl Drop for RevisionGuard {
+ fn drop(&mut self) {
+ // Release our read-lock without using RAII. As documented in
+ // `Snapshot::new` above, this requires the unsafe keyword.
+ unsafe {
+ self.shared_state.query_lock.raw().unlock_shared();
+ }
+ }
+}
diff --git a/crates/salsa/src/runtime/dependency_graph.rs b/crates/salsa/src/runtime/dependency_graph.rs
new file mode 100644
index 0000000000..e41eb280de
--- /dev/null
+++ b/crates/salsa/src/runtime/dependency_graph.rs
@@ -0,0 +1,251 @@
+//!
+use triomphe::Arc;
+
+use crate::{DatabaseKeyIndex, RuntimeId};
+use parking_lot::{Condvar, MutexGuard};
+use rustc_hash::FxHashMap;
+use smallvec::SmallVec;
+
+use super::{ActiveQuery, WaitResult};
+
+type QueryStack = Vec<ActiveQuery>;
+
+#[derive(Debug, Default)]
+pub(super) struct DependencyGraph {
+ /// A `(K -> V)` pair in this map indicates that the the runtime
+ /// `K` is blocked on some query executing in the runtime `V`.
+ /// This encodes a graph that must be acyclic (or else deadlock
+ /// will result).
+ edges: FxHashMap<RuntimeId, Edge>,
+
+ /// Encodes the `RuntimeId` that are blocked waiting for the result
+ /// of a given query.
+ query_dependents: FxHashMap<DatabaseKeyIndex, SmallVec<[RuntimeId; 4]>>,
+
+ /// When a key K completes which had dependent queries Qs blocked on it,
+ /// it stores its `WaitResult` here. As they wake up, each query Q in Qs will
+ /// come here to fetch their results.
+ wait_results: FxHashMap<RuntimeId, (QueryStack, WaitResult)>,
+}
+
+#[derive(Debug)]
+struct Edge {
+ blocked_on_id: RuntimeId,
+ blocked_on_key: DatabaseKeyIndex,
+ stack: QueryStack,
+
+ /// Signalled whenever a query with dependents completes.
+ /// Allows those dependents to check if they are ready to unblock.
+ condvar: Arc<parking_lot::Condvar>,
+}
+
+impl DependencyGraph {
+ /// True if `from_id` depends on `to_id`.
+ ///
+ /// (i.e., there is a path from `from_id` to `to_id` in the graph.)
+ pub(super) fn depends_on(&mut self, from_id: RuntimeId, to_id: RuntimeId) -> bool {
+ let mut p = from_id;
+ while let Some(q) = self.edges.get(&p).map(|edge| edge.blocked_on_id) {
+ if q == to_id {
+ return true;
+ }
+
+ p = q;
+ }
+ p == to_id
+ }
+
+ /// Invokes `closure` with a `&mut ActiveQuery` for each query that participates in the cycle.
+ /// The cycle runs as follows:
+ ///
+ /// 1. The runtime `from_id`, which has the stack `from_stack`, would like to invoke `database_key`...
+ /// 2. ...but `database_key` is already being executed by `to_id`...
+ /// 3. ...and `to_id` is transitively dependent on something which is present on `from_stack`.
+ pub(super) fn for_each_cycle_participant(
+ &mut self,
+ from_id: RuntimeId,
+ from_stack: &mut QueryStack,
+ database_key: DatabaseKeyIndex,
+ to_id: RuntimeId,
+ mut closure: impl FnMut(&mut [ActiveQuery]),
+ ) {
+ debug_assert!(self.depends_on(to_id, from_id));
+
+ // To understand this algorithm, consider this [drawing](https://is.gd/TGLI9v):
+ //
+ // database_key = QB2
+ // from_id = A
+ // to_id = B
+ // from_stack = [QA1, QA2, QA3]
+ //
+ // self.edges[B] = { C, QC2, [QB1..QB3] }
+ // self.edges[C] = { A, QA2, [QC1..QC3] }
+ //
+ // The cyclic
+ // edge we have
+ // failed to add.
+ // :
+ // A : B C
+ // :
+ // QA1 v QB1 QC1
+ // ┌► QA2 ┌──► QB2 ┌─► QC2
+ // │ QA3 ───┘ QB3 ──┘ QC3 ───┐
+ // │ │
+ // └───────────────────────────────┘
+ //
+ // Final output: [QB2, QB3, QC2, QC3, QA2, QA3]
+
+ let mut id = to_id;
+ let mut key = database_key;
+ while id != from_id {
+ // Looking at the diagram above, the idea is to
+ // take the edge from `to_id` starting at `key`
+ // (inclusive) and down to the end. We can then
+ // load up the next thread (i.e., we start at B/QB2,
+ // and then load up the dependency on C/QC2).
+ let edge = self.edges.get_mut(&id).unwrap();
+ let prefix = edge.stack.iter_mut().take_while(|p| p.database_key_index != key).count();
+ closure(&mut edge.stack[prefix..]);
+ id = edge.blocked_on_id;
+ key = edge.blocked_on_key;
+ }
+
+ // Finally, we copy in the results from `from_stack`.
+ let prefix = from_stack.iter_mut().take_while(|p| p.database_key_index != key).count();
+ closure(&mut from_stack[prefix..]);
+ }
+
+ /// Unblock each blocked runtime (excluding the current one) if some
+ /// query executing in that runtime is participating in cycle fallback.
+ ///
+ /// Returns a boolean (Current, Others) where:
+ /// * Current is true if the current runtime has cycle participants
+ /// with fallback;
+ /// * Others is true if other runtimes were unblocked.
+ pub(super) fn maybe_unblock_runtimes_in_cycle(
+ &mut self,
+ from_id: RuntimeId,
+ from_stack: &QueryStack,
+ database_key: DatabaseKeyIndex,
+ to_id: RuntimeId,
+ ) -> (bool, bool) {
+ // See diagram in `for_each_cycle_participant`.
+ let mut id = to_id;
+ let mut key = database_key;
+ let mut others_unblocked = false;
+ while id != from_id {
+ let edge = self.edges.get(&id).unwrap();
+ let prefix = edge.stack.iter().take_while(|p| p.database_key_index != key).count();
+ let next_id = edge.blocked_on_id;
+ let next_key = edge.blocked_on_key;
+
+ if let Some(cycle) = edge.stack[prefix..].iter().rev().find_map(|aq| aq.cycle.clone()) {
+ // Remove `id` from the list of runtimes blocked on `next_key`:
+ self.query_dependents.get_mut(&next_key).unwrap().retain(|r| *r != id);
+
+ // Unblock runtime so that it can resume execution once lock is released:
+ self.unblock_runtime(id, WaitResult::Cycle(cycle));
+
+ others_unblocked = true;
+ }
+
+ id = next_id;
+ key = next_key;
+ }
+
+ let prefix = from_stack.iter().take_while(|p| p.database_key_index != key).count();
+ let this_unblocked = from_stack[prefix..].iter().any(|aq| aq.cycle.is_some());
+
+ (this_unblocked, others_unblocked)
+ }
+
+ /// Modifies the graph so that `from_id` is blocked
+ /// on `database_key`, which is being computed by
+ /// `to_id`.
+ ///
+ /// For this to be reasonable, the lock on the
+ /// results table for `database_key` must be held.
+ /// This ensures that computing `database_key` doesn't
+ /// complete before `block_on` executes.
+ ///
+ /// Preconditions:
+ /// * No path from `to_id` to `from_id`
+ /// (i.e., `me.depends_on(to_id, from_id)` is false)
+ /// * `held_mutex` is a read lock (or stronger) on `database_key`
+ pub(super) fn block_on<QueryMutexGuard>(
+ mut me: MutexGuard<'_, Self>,
+ from_id: RuntimeId,
+ database_key: DatabaseKeyIndex,
+ to_id: RuntimeId,
+ from_stack: QueryStack,
+ query_mutex_guard: QueryMutexGuard,
+ ) -> (QueryStack, WaitResult) {
+ let condvar = me.add_edge(from_id, database_key, to_id, from_stack);
+
+ // Release the mutex that prevents `database_key`
+ // from completing, now that the edge has been added.
+ drop(query_mutex_guard);
+
+ loop {
+ if let Some(stack_and_result) = me.wait_results.remove(&from_id) {
+ debug_assert!(!me.edges.contains_key(&from_id));
+ return stack_and_result;
+ }
+ condvar.wait(&mut me);
+ }
+ }
+
+ /// Helper for `block_on`: performs actual graph modification
+ /// to add a dependency edge from `from_id` to `to_id`, which is
+ /// computing `database_key`.
+ fn add_edge(
+ &mut self,
+ from_id: RuntimeId,
+ database_key: DatabaseKeyIndex,
+ to_id: RuntimeId,
+ from_stack: QueryStack,
+ ) -> Arc<parking_lot::Condvar> {
+ assert_ne!(from_id, to_id);
+ debug_assert!(!self.edges.contains_key(&from_id));
+ debug_assert!(!self.depends_on(to_id, from_id));
+
+ let condvar = Arc::new(Condvar::new());
+ self.edges.insert(
+ from_id,
+ Edge {
+ blocked_on_id: to_id,
+ blocked_on_key: database_key,
+ stack: from_stack,
+ condvar: condvar.clone(),
+ },
+ );
+ self.query_dependents.entry(database_key).or_default().push(from_id);
+ condvar
+ }
+
+ /// Invoked when runtime `to_id` completes executing
+ /// `database_key`.
+ pub(super) fn unblock_runtimes_blocked_on(
+ &mut self,
+ database_key: DatabaseKeyIndex,
+ wait_result: WaitResult,
+ ) {
+ let dependents = self.query_dependents.remove(&database_key).unwrap_or_default();
+
+ for from_id in dependents {
+ self.unblock_runtime(from_id, wait_result.clone());
+ }
+ }
+
+ /// Unblock the runtime with the given id with the given wait-result.
+ /// This will cause it resume execution (though it will have to grab
+ /// the lock on this data structure first, to recover the wait result).
+ fn unblock_runtime(&mut self, id: RuntimeId, wait_result: WaitResult) {
+ let edge = self.edges.remove(&id).expect("not blocked");
+ self.wait_results.insert(id, (edge.stack, wait_result));
+
+ // Now that we have inserted the `wait_results`,
+ // notify the thread.
+ edge.condvar.notify_one();
+ }
+}
diff --git a/crates/salsa/src/runtime/local_state.rs b/crates/salsa/src/runtime/local_state.rs
new file mode 100644
index 0000000000..91b95dffe7
--- /dev/null
+++ b/crates/salsa/src/runtime/local_state.rs
@@ -0,0 +1,214 @@
+//!
+use tracing::debug;
+
+use crate::durability::Durability;
+use crate::runtime::ActiveQuery;
+use crate::runtime::Revision;
+use crate::Cycle;
+use crate::DatabaseKeyIndex;
+use std::cell::RefCell;
+use triomphe::Arc;
+
+/// State that is specific to a single execution thread.
+///
+/// Internally, this type uses ref-cells.
+///
+/// **Note also that all mutations to the database handle (and hence
+/// to the local-state) must be undone during unwinding.**
+pub(super) struct LocalState {
+ /// Vector of active queries.
+ ///
+ /// This is normally `Some`, but it is set to `None`
+ /// while the query is blocked waiting for a result.
+ ///
+ /// Unwinding note: pushes onto this vector must be popped -- even
+ /// during unwinding.
+ query_stack: RefCell<Option<Vec<ActiveQuery>>>,
+}
+
+/// Summarizes "all the inputs that a query used"
+#[derive(Debug, Clone)]
+pub(crate) struct QueryRevisions {
+ /// The most revision in which some input changed.
+ pub(crate) changed_at: Revision,
+
+ /// Minimum durability of the inputs to this query.
+ pub(crate) durability: Durability,
+
+ /// The inputs that went into our query, if we are tracking them.
+ pub(crate) inputs: QueryInputs,
+}
+
+/// Every input.
+#[derive(Debug, Clone)]
+pub(crate) enum QueryInputs {
+ /// Non-empty set of inputs, fully known
+ Tracked { inputs: Arc<[DatabaseKeyIndex]> },
+
+ /// Empty set of inputs, fully known.
+ NoInputs,
+
+ /// Unknown quantity of inputs
+ Untracked,
+}
+
+impl Default for LocalState {
+ fn default() -> Self {
+ LocalState { query_stack: RefCell::new(Some(Vec::new())) }
+ }
+}
+
+impl LocalState {
+ #[inline]
+ pub(super) fn push_query(&self, database_key_index: DatabaseKeyIndex) -> ActiveQueryGuard<'_> {
+ let mut query_stack = self.query_stack.borrow_mut();
+ let query_stack = query_stack.as_mut().expect("local stack taken");
+ query_stack.push(ActiveQuery::new(database_key_index));
+ ActiveQueryGuard { local_state: self, database_key_index, push_len: query_stack.len() }
+ }
+
+ fn with_query_stack<R>(&self, c: impl FnOnce(&mut Vec<ActiveQuery>) -> R) -> R {
+ c(self.query_stack.borrow_mut().as_mut().expect("query stack taken"))
+ }
+
+ pub(super) fn query_in_progress(&self) -> bool {
+ self.with_query_stack(|stack| !stack.is_empty())
+ }
+
+ pub(super) fn active_query(&self) -> Option<DatabaseKeyIndex> {
+ self.with_query_stack(|stack| {
+ stack.last().map(|active_query| active_query.database_key_index)
+ })
+ }
+
+ pub(super) fn report_query_read_and_unwind_if_cycle_resulted(
+ &self,
+ input: DatabaseKeyIndex,
+ durability: Durability,
+ changed_at: Revision,
+ ) {
+ debug!(
+ "report_query_read_and_unwind_if_cycle_resulted(input={:?}, durability={:?}, changed_at={:?})",
+ input, durability, changed_at
+ );
+ self.with_query_stack(|stack| {
+ if let Some(top_query) = stack.last_mut() {
+ top_query.add_read(input, durability, changed_at);
+
+ // We are a cycle participant:
+ //
+ // C0 --> ... --> Ci --> Ci+1 -> ... -> Cn --> C0
+ // ^ ^
+ // : |
+ // This edge -----+ |
+ // |
+ // |
+ // N0
+ //
+ // In this case, the value we have just read from `Ci+1`
+ // is actually the cycle fallback value and not especially
+ // interesting. We unwind now with `CycleParticipant` to avoid
+ // executing the rest of our query function. This unwinding
+ // will be caught and our own fallback value will be used.
+ //
+ // Note that `Ci+1` may` have *other* callers who are not
+ // participants in the cycle (e.g., N0 in the graph above).
+ // They will not have the `cycle` marker set in their
+ // stack frames, so they will just read the fallback value
+ // from `Ci+1` and continue on their merry way.
+ if let Some(cycle) = &top_query.cycle {
+ cycle.clone().throw()
+ }
+ }
+ })
+ }
+
+ pub(super) fn report_untracked_read(&self, current_revision: Revision) {
+ self.with_query_stack(|stack| {
+ if let Some(top_query) = stack.last_mut() {
+ top_query.add_untracked_read(current_revision);
+ }
+ })
+ }
+
+ /// Update the top query on the stack to act as though it read a value
+ /// of durability `durability` which changed in `revision`.
+ pub(super) fn report_synthetic_read(&self, durability: Durability, revision: Revision) {
+ self.with_query_stack(|stack| {
+ if let Some(top_query) = stack.last_mut() {
+ top_query.add_synthetic_read(durability, revision);
+ }
+ })
+ }
+
+ /// Takes the query stack and returns it. This is used when
+ /// the current thread is blocking. The stack must be restored
+ /// with [`Self::restore_query_stack`] when the thread unblocks.
+ pub(super) fn take_query_stack(&self) -> Vec<ActiveQuery> {
+ assert!(self.query_stack.borrow().is_some(), "query stack already taken");
+ self.query_stack.take().unwrap()
+ }
+
+ /// Restores a query stack taken with [`Self::take_query_stack`] once
+ /// the thread unblocks.
+ pub(super) fn restore_query_stack(&self, stack: Vec<ActiveQuery>) {
+ assert!(self.query_stack.borrow().is_none(), "query stack not taken");
+ self.query_stack.replace(Some(stack));
+ }
+}
+
+impl std::panic::RefUnwindSafe for LocalState {}
+
+/// When a query is pushed onto the `active_query` stack, this guard
+/// is returned to represent its slot. The guard can be used to pop
+/// the query from the stack -- in the case of unwinding, the guard's
+/// destructor will also remove the query.
+pub(crate) struct ActiveQueryGuard<'me> {
+ local_state: &'me LocalState,
+ push_len: usize,
+ database_key_index: DatabaseKeyIndex,
+}
+
+impl ActiveQueryGuard<'_> {
+ fn pop_helper(&self) -> ActiveQuery {
+ self.local_state.with_query_stack(|stack| {
+ // Sanity check: pushes and pops should be balanced.
+ assert_eq!(stack.len(), self.push_len);
+ debug_assert_eq!(stack.last().unwrap().database_key_index, self.database_key_index);
+ stack.pop().unwrap()
+ })
+ }
+
+ /// Invoked when the query has successfully completed execution.
+ pub(super) fn complete(self) -> ActiveQuery {
+ let query = self.pop_helper();
+ std::mem::forget(self);
+ query
+ }
+
+ /// Pops an active query from the stack. Returns the [`QueryRevisions`]
+ /// which summarizes the other queries that were accessed during this
+ /// query's execution.
+ #[inline]
+ pub(crate) fn pop(self) -> QueryRevisions {
+ // Extract accumulated inputs.
+ let popped_query = self.complete();
+
+ // If this frame were a cycle participant, it would have unwound.
+ assert!(popped_query.cycle.is_none());
+
+ popped_query.revisions()
+ }
+
+ /// If the active query is registered as a cycle participant, remove and
+ /// return that cycle.
+ pub(crate) fn take_cycle(&self) -> Option<Cycle> {
+ self.local_state.with_query_stack(|stack| stack.last_mut()?.cycle.take())
+ }
+}
+
+impl Drop for ActiveQueryGuard<'_> {
+ fn drop(&mut self) {
+ self.pop_helper();
+ }
+}
diff --git a/crates/salsa/src/storage.rs b/crates/salsa/src/storage.rs
new file mode 100644
index 0000000000..c0e6416f4a
--- /dev/null
+++ b/crates/salsa/src/storage.rs
@@ -0,0 +1,54 @@
+//!
+use crate::{plumbing::DatabaseStorageTypes, Runtime};
+use triomphe::Arc;
+
+/// Stores the cached results and dependency information for all the queries
+/// defined on your salsa database. Also embeds a [`Runtime`] which is used to
+/// manage query execution. Every database must include a `storage:
+/// Storage<Self>` field.
+pub struct Storage<DB: DatabaseStorageTypes> {
+ query_store: Arc<DB::DatabaseStorage>,
+ runtime: Runtime,
+}
+
+impl<DB: DatabaseStorageTypes> Default for Storage<DB> {
+ fn default() -> Self {
+ Self { query_store: Default::default(), runtime: Default::default() }
+ }
+}
+
+impl<DB: DatabaseStorageTypes> Storage<DB> {
+ /// Gives access to the underlying salsa runtime.
+ pub fn salsa_runtime(&self) -> &Runtime {
+ &self.runtime
+ }
+
+ /// Gives access to the underlying salsa runtime.
+ pub fn salsa_runtime_mut(&mut self) -> &mut Runtime {
+ &mut self.runtime
+ }
+
+ /// Access the query storage tables. Not meant to be used directly by end
+ /// users.
+ pub fn query_store(&self) -> &DB::DatabaseStorage {
+ &self.query_store
+ }
+
+ /// Access the query storage tables. Not meant to be used directly by end
+ /// users.
+ pub fn query_store_mut(&mut self) -> (&DB::DatabaseStorage, &mut Runtime) {
+ (&self.query_store, &mut self.runtime)
+ }
+
+ /// Returns a "snapshotted" storage, suitable for use in a forked database.
+ /// This snapshot hold a read-lock on the global state, which means that any
+ /// attempt to `set` an input will block until the forked runtime is
+ /// dropped. See `ParallelDatabase::snapshot` for more information.
+ ///
+ /// **Warning.** This second handle is intended to be used from a separate
+ /// thread. Using two database handles from the **same thread** can lead to
+ /// deadlock.
+ pub fn snapshot(&self) -> Self {
+ Storage { query_store: self.query_store.clone(), runtime: self.runtime.snapshot() }
+ }
+}
diff --git a/crates/salsa/tests/cycles.rs b/crates/salsa/tests/cycles.rs
new file mode 100644
index 0000000000..00ca533244
--- /dev/null
+++ b/crates/salsa/tests/cycles.rs
@@ -0,0 +1,493 @@
+use std::panic::UnwindSafe;
+
+use expect_test::expect;
+use salsa::{Durability, ParallelDatabase, Snapshot};
+use test_log::test;
+
+// Axes:
+//
+// Threading
+// * Intra-thread
+// * Cross-thread -- part of cycle is on one thread, part on another
+//
+// Recovery strategies:
+// * Panic
+// * Fallback
+// * Mixed -- multiple strategies within cycle participants
+//
+// Across revisions:
+// * N/A -- only one revision
+// * Present in new revision, not old
+// * Present in old revision, not new
+// * Present in both revisions
+//
+// Dependencies
+// * Tracked
+// * Untracked -- cycle participant(s) contain untracked reads
+//
+// Layers
+// * Direct -- cycle participant is directly invoked from test
+// * Indirect -- invoked a query that invokes the cycle
+//
+//
+// | Thread | Recovery | Old, New | Dep style | Layers | Test Name |
+// | ------ | -------- | -------- | --------- | ------ | --------- |
+// | Intra | Panic | N/A | Tracked | direct | cycle_memoized |
+// | Intra | Panic | N/A | Untracked | direct | cycle_volatile |
+// | Intra | Fallback | N/A | Tracked | direct | cycle_cycle |
+// | Intra | Fallback | N/A | Tracked | indirect | inner_cycle |
+// | Intra | Fallback | Both | Tracked | direct | cycle_revalidate |
+// | Intra | Fallback | New | Tracked | direct | cycle_appears |
+// | Intra | Fallback | Old | Tracked | direct | cycle_disappears |
+// | Intra | Fallback | Old | Tracked | direct | cycle_disappears_durability |
+// | Intra | Mixed | N/A | Tracked | direct | cycle_mixed_1 |
+// | Intra | Mixed | N/A | Tracked | direct | cycle_mixed_2 |
+// | Cross | Fallback | N/A | Tracked | both | parallel/cycles.rs: recover_parallel_cycle |
+// | Cross | Panic | N/A | Tracked | both | parallel/cycles.rs: panic_parallel_cycle |
+
+#[derive(PartialEq, Eq, Hash, Clone, Debug)]
+struct Error {
+ cycle: Vec<String>,
+}
+
+#[salsa::database(GroupStruct)]
+#[derive(Default)]
+struct DatabaseImpl {
+ storage: salsa::Storage<Self>,
+}
+
+impl salsa::Database for DatabaseImpl {}
+
+impl ParallelDatabase for DatabaseImpl {
+ fn snapshot(&self) -> Snapshot<Self> {
+ Snapshot::new(DatabaseImpl { storage: self.storage.snapshot() })
+ }
+}
+
+/// The queries A, B, and C in `Database` can be configured
+/// to invoke one another in arbitrary ways using this
+/// enum.
+#[derive(Debug, Copy, Clone, PartialEq, Eq)]
+enum CycleQuery {
+ None,
+ A,
+ B,
+ C,
+ AthenC,
+}
+
+#[salsa::query_group(GroupStruct)]
+trait Database: salsa::Database {
+ // `a` and `b` depend on each other and form a cycle
+ fn memoized_a(&self) -> ();
+ fn memoized_b(&self) -> ();
+ fn volatile_a(&self) -> ();
+ fn volatile_b(&self) -> ();
+
+ #[salsa::input]
+ fn a_invokes(&self) -> CycleQuery;
+
+ #[salsa::input]
+ fn b_invokes(&self) -> CycleQuery;
+
+ #[salsa::input]
+ fn c_invokes(&self) -> CycleQuery;
+
+ #[salsa::cycle(recover_a)]
+ fn cycle_a(&self) -> Result<(), Error>;
+
+ #[salsa::cycle(recover_b)]
+ fn cycle_b(&self) -> Result<(), Error>;
+
+ fn cycle_c(&self) -> Result<(), Error>;
+}
+
+fn recover_a(db: &dyn Database, cycle: &salsa::Cycle) -> Result<(), Error> {
+ Err(Error { cycle: cycle.all_participants(db) })
+}
+
+fn recover_b(db: &dyn Database, cycle: &salsa::Cycle) -> Result<(), Error> {
+ Err(Error { cycle: cycle.all_participants(db) })
+}
+
+fn memoized_a(db: &dyn Database) {
+ db.memoized_b()
+}
+
+fn memoized_b(db: &dyn Database) {
+ db.memoized_a()
+}
+
+fn volatile_a(db: &dyn Database) {
+ db.salsa_runtime().report_untracked_read();
+ db.volatile_b()
+}
+
+fn volatile_b(db: &dyn Database) {
+ db.salsa_runtime().report_untracked_read();
+ db.volatile_a()
+}
+
+impl CycleQuery {
+ fn invoke(self, db: &dyn Database) -> Result<(), Error> {
+ match self {
+ CycleQuery::A => db.cycle_a(),
+ CycleQuery::B => db.cycle_b(),
+ CycleQuery::C => db.cycle_c(),
+ CycleQuery::AthenC => {
+ let _ = db.cycle_a();
+ db.cycle_c()
+ }
+ CycleQuery::None => Ok(()),
+ }
+ }
+}
+
+fn cycle_a(db: &dyn Database) -> Result<(), Error> {
+ db.a_invokes().invoke(db)
+}
+
+fn cycle_b(db: &dyn Database) -> Result<(), Error> {
+ db.b_invokes().invoke(db)
+}
+
+fn cycle_c(db: &dyn Database) -> Result<(), Error> {
+ db.c_invokes().invoke(db)
+}
+
+#[track_caller]
+fn extract_cycle(f: impl FnOnce() + UnwindSafe) -> salsa::Cycle {
+ let v = std::panic::catch_unwind(f);
+ if let Err(d) = &v {
+ if let Some(cycle) = d.downcast_ref::<salsa::Cycle>() {
+ return cycle.clone();
+ }
+ }
+ panic!("unexpected value: {:?}", v)
+}
+
+#[test]
+fn cycle_memoized() {
+ let db = DatabaseImpl::default();
+ let cycle = extract_cycle(|| db.memoized_a());
+ expect![[r#"
+ [
+ "memoized_a(())",
+ "memoized_b(())",
+ ]
+ "#]]
+ .assert_debug_eq(&cycle.unexpected_participants(&db));
+}
+
+#[test]
+fn cycle_volatile() {
+ let db = DatabaseImpl::default();
+ let cycle = extract_cycle(|| db.volatile_a());
+ expect![[r#"
+ [
+ "volatile_a(())",
+ "volatile_b(())",
+ ]
+ "#]]
+ .assert_debug_eq(&cycle.unexpected_participants(&db));
+}
+
+#[test]
+fn cycle_cycle() {
+ let mut query = DatabaseImpl::default();
+
+ // A --> B
+ // ^ |
+ // +-----+
+
+ query.set_a_invokes(CycleQuery::B);
+ query.set_b_invokes(CycleQuery::A);
+
+ assert!(query.cycle_a().is_err());
+}
+
+#[test]
+fn inner_cycle() {
+ let mut query = DatabaseImpl::default();
+
+ // A --> B <-- C
+ // ^ |
+ // +-----+
+
+ query.set_a_invokes(CycleQuery::B);
+ query.set_b_invokes(CycleQuery::A);
+ query.set_c_invokes(CycleQuery::B);
+
+ let err = query.cycle_c();
+ assert!(err.is_err());
+ let cycle = err.unwrap_err().cycle;
+ expect![[r#"
+ [
+ "cycle_a(())",
+ "cycle_b(())",
+ ]
+ "#]]
+ .assert_debug_eq(&cycle);
+}
+
+#[test]
+fn cycle_revalidate() {
+ let mut db = DatabaseImpl::default();
+
+ // A --> B
+ // ^ |
+ // +-----+
+ db.set_a_invokes(CycleQuery::B);
+ db.set_b_invokes(CycleQuery::A);
+
+ assert!(db.cycle_a().is_err());
+ db.set_b_invokes(CycleQuery::A); // same value as default
+ assert!(db.cycle_a().is_err());
+}
+
+#[test]
+fn cycle_revalidate_unchanged_twice() {
+ let mut db = DatabaseImpl::default();
+
+ // A --> B
+ // ^ |
+ // +-----+
+ db.set_a_invokes(CycleQuery::B);
+ db.set_b_invokes(CycleQuery::A);
+
+ assert!(db.cycle_a().is_err());
+ db.set_c_invokes(CycleQuery::A); // force new revisi5on
+
+ // on this run
+ expect![[r#"
+ Err(
+ Error {
+ cycle: [
+ "cycle_a(())",
+ "cycle_b(())",
+ ],
+ },
+ )
+ "#]]
+ .assert_debug_eq(&db.cycle_a());
+}
+
+#[test]
+fn cycle_appears() {
+ let mut db = DatabaseImpl::default();
+
+ // A --> B
+ db.set_a_invokes(CycleQuery::B);
+ db.set_b_invokes(CycleQuery::None);
+ assert!(db.cycle_a().is_ok());
+
+ // A --> B
+ // ^ |
+ // +-----+
+ db.set_b_invokes(CycleQuery::A);
+ tracing::debug!("Set Cycle Leaf");
+ assert!(db.cycle_a().is_err());
+}
+
+#[test]
+fn cycle_disappears() {
+ let mut db = DatabaseImpl::default();
+
+ // A --> B
+ // ^ |
+ // +-----+
+ db.set_a_invokes(CycleQuery::B);
+ db.set_b_invokes(CycleQuery::A);
+ assert!(db.cycle_a().is_err());
+
+ // A --> B
+ db.set_b_invokes(CycleQuery::None);
+ assert!(db.cycle_a().is_ok());
+}
+
+/// A variant on `cycle_disappears` in which the values of
+/// `a_invokes` and `b_invokes` are set with durability values.
+/// If we are not careful, this could cause us to overlook
+/// the fact that the cycle will no longer occur.
+#[test]
+fn cycle_disappears_durability() {
+ let mut db = DatabaseImpl::default();
+ db.set_a_invokes_with_durability(CycleQuery::B, Durability::LOW);
+ db.set_b_invokes_with_durability(CycleQuery::A, Durability::HIGH);
+
+ let res = db.cycle_a();
+ assert!(res.is_err());
+
+ // At this point, `a` read `LOW` input, and `b` read `HIGH` input. However,
+ // because `b` participates in the same cycle as `a`, its final durability
+ // should be `LOW`.
+ //
+ // Check that setting a `LOW` input causes us to re-execute `b` query, and
+ // observe that the cycle goes away.
+ db.set_a_invokes_with_durability(CycleQuery::None, Durability::LOW);
+
+ let res = db.cycle_b();
+ assert!(res.is_ok());
+}
+
+#[test]
+fn cycle_mixed_1() {
+ let mut db = DatabaseImpl::default();
+ // A --> B <-- C
+ // | ^
+ // +-----+
+ db.set_a_invokes(CycleQuery::B);
+ db.set_b_invokes(CycleQuery::C);
+ db.set_c_invokes(CycleQuery::B);
+
+ let u = db.cycle_c();
+ expect![[r#"
+ Err(
+ Error {
+ cycle: [
+ "cycle_b(())",
+ "cycle_c(())",
+ ],
+ },
+ )
+ "#]]
+ .assert_debug_eq(&u);
+}
+
+#[test]
+fn cycle_mixed_2() {
+ let mut db = DatabaseImpl::default();
+
+ // Configuration:
+ //
+ // A --> B --> C
+ // ^ |
+ // +-----------+
+ db.set_a_invokes(CycleQuery::B);
+ db.set_b_invokes(CycleQuery::C);
+ db.set_c_invokes(CycleQuery::A);
+
+ let u = db.cycle_a();
+ expect![[r#"
+ Err(
+ Error {
+ cycle: [
+ "cycle_a(())",
+ "cycle_b(())",
+ "cycle_c(())",
+ ],
+ },
+ )
+ "#]]
+ .assert_debug_eq(&u);
+}
+
+#[test]
+fn cycle_deterministic_order() {
+ // No matter whether we start from A or B, we get the same set of participants:
+ let db = || {
+ let mut db = DatabaseImpl::default();
+ // A --> B
+ // ^ |
+ // +-----+
+ db.set_a_invokes(CycleQuery::B);
+ db.set_b_invokes(CycleQuery::A);
+ db
+ };
+ let a = db().cycle_a();
+ let b = db().cycle_b();
+ expect![[r#"
+ (
+ Err(
+ Error {
+ cycle: [
+ "cycle_a(())",
+ "cycle_b(())",
+ ],
+ },
+ ),
+ Err(
+ Error {
+ cycle: [
+ "cycle_a(())",
+ "cycle_b(())",
+ ],
+ },
+ ),
+ )
+ "#]]
+ .assert_debug_eq(&(a, b));
+}
+
+#[test]
+fn cycle_multiple() {
+ // No matter whether we start from A or B, we get the same set of participants:
+ let mut db = DatabaseImpl::default();
+
+ // Configuration:
+ //
+ // A --> B <-- C
+ // ^ | ^
+ // +-----+ |
+ // | |
+ // +-----+
+ //
+ // Here, conceptually, B encounters a cycle with A and then
+ // recovers.
+ db.set_a_invokes(CycleQuery::B);
+ db.set_b_invokes(CycleQuery::AthenC);
+ db.set_c_invokes(CycleQuery::B);
+
+ let c = db.cycle_c();
+ let b = db.cycle_b();
+ let a = db.cycle_a();
+ expect![[r#"
+ (
+ Err(
+ Error {
+ cycle: [
+ "cycle_a(())",
+ "cycle_b(())",
+ ],
+ },
+ ),
+ Err(
+ Error {
+ cycle: [
+ "cycle_a(())",
+ "cycle_b(())",
+ ],
+ },
+ ),
+ Err(
+ Error {
+ cycle: [
+ "cycle_a(())",
+ "cycle_b(())",
+ ],
+ },
+ ),
+ )
+ "#]]
+ .assert_debug_eq(&(a, b, c));
+}
+
+#[test]
+fn cycle_recovery_set_but_not_participating() {
+ let mut db = DatabaseImpl::default();
+
+ // A --> C -+
+ // ^ |
+ // +--+
+ db.set_a_invokes(CycleQuery::C);
+ db.set_c_invokes(CycleQuery::C);
+
+ // Here we expect C to panic and A not to recover:
+ let r = extract_cycle(|| drop(db.cycle_a()));
+ expect![[r#"
+ [
+ "cycle_c(())",
+ ]
+ "#]]
+ .assert_debug_eq(&r.all_participants(&db));
+}
diff --git a/crates/salsa/tests/dyn_trait.rs b/crates/salsa/tests/dyn_trait.rs
new file mode 100644
index 0000000000..09ebc5c4ce
--- /dev/null
+++ b/crates/salsa/tests/dyn_trait.rs
@@ -0,0 +1,28 @@
+//! Test that you can implement a query using a `dyn Trait` setup.
+
+#[salsa::database(DynTraitStorage)]
+#[derive(Default)]
+struct DynTraitDatabase {
+ storage: salsa::Storage<Self>,
+}
+
+impl salsa::Database for DynTraitDatabase {}
+
+#[salsa::query_group(DynTraitStorage)]
+trait DynTrait {
+ #[salsa::input]
+ fn input(&self, x: u32) -> u32;
+
+ fn output(&self, x: u32) -> u32;
+}
+
+fn output(db: &dyn DynTrait, x: u32) -> u32 {
+ db.input(x) * 2
+}
+
+#[test]
+fn dyn_trait() {
+ let mut query = DynTraitDatabase::default();
+ query.set_input(22, 23);
+ assert_eq!(query.output(22), 46);
+}
diff --git a/crates/salsa/tests/incremental/constants.rs b/crates/salsa/tests/incremental/constants.rs
new file mode 100644
index 0000000000..ea0eb81978
--- /dev/null
+++ b/crates/salsa/tests/incremental/constants.rs
@@ -0,0 +1,145 @@
+use crate::implementation::{TestContext, TestContextImpl};
+use salsa::debug::DebugQueryTable;
+use salsa::Durability;
+
+#[salsa::query_group(Constants)]
+pub(crate) trait ConstantsDatabase: TestContext {
+ #[salsa::input]
+ fn input(&self, key: char) -> usize;
+
+ fn add(&self, key1: char, key2: char) -> usize;
+
+ fn add3(&self, key1: char, key2: char, key3: char) -> usize;
+}
+
+fn add(db: &dyn ConstantsDatabase, key1: char, key2: char) -> usize {
+ db.log().add(format!("add({}, {})", key1, key2));
+ db.input(key1) + db.input(key2)
+}
+
+fn add3(db: &dyn ConstantsDatabase, key1: char, key2: char, key3: char) -> usize {
+ db.log().add(format!("add3({}, {}, {})", key1, key2, key3));
+ db.add(key1, key2) + db.input(key3)
+}
+
+// Test we can assign a constant and things will be correctly
+// recomputed afterwards.
+#[test]
+fn invalidate_constant() {
+ let db = &mut TestContextImpl::default();
+ db.set_input_with_durability('a', 44, Durability::HIGH);
+ db.set_input_with_durability('b', 22, Durability::HIGH);
+ assert_eq!(db.add('a', 'b'), 66);
+
+ db.set_input_with_durability('a', 66, Durability::HIGH);
+ assert_eq!(db.add('a', 'b'), 88);
+}
+
+#[test]
+fn invalidate_constant_1() {
+ let db = &mut TestContextImpl::default();
+
+ // Not constant:
+ db.set_input('a', 44);
+ assert_eq!(db.add('a', 'a'), 88);
+
+ // Becomes constant:
+ db.set_input_with_durability('a', 44, Durability::HIGH);
+ assert_eq!(db.add('a', 'a'), 88);
+
+ // Invalidates:
+ db.set_input_with_durability('a', 33, Durability::HIGH);
+ assert_eq!(db.add('a', 'a'), 66);
+}
+
+// Test cases where we assign same value to 'a' after declaring it a
+// constant.
+#[test]
+fn set_after_constant_same_value() {
+ let db = &mut TestContextImpl::default();
+ db.set_input_with_durability('a', 44, Durability::HIGH);
+ db.set_input_with_durability('a', 44, Durability::HIGH);
+ db.set_input('a', 44);
+}
+
+#[test]
+fn not_constant() {
+ let mut db = TestContextImpl::default();
+
+ db.set_input('a', 22);
+ db.set_input('b', 44);
+ assert_eq!(db.add('a', 'b'), 66);
+ assert_eq!(Durability::LOW, AddQuery.in_db(&db).durability(('a', 'b')));
+}
+
+#[test]
+fn durability() {
+ let mut db = TestContextImpl::default();
+
+ db.set_input_with_durability('a', 22, Durability::HIGH);
+ db.set_input_with_durability('b', 44, Durability::HIGH);
+ assert_eq!(db.add('a', 'b'), 66);
+ assert_eq!(Durability::HIGH, AddQuery.in_db(&db).durability(('a', 'b')));
+}
+
+#[test]
+fn mixed_constant() {
+ let mut db = TestContextImpl::default();
+
+ db.set_input_with_durability('a', 22, Durability::HIGH);
+ db.set_input('b', 44);
+ assert_eq!(db.add('a', 'b'), 66);
+ assert_eq!(Durability::LOW, AddQuery.in_db(&db).durability(('a', 'b')));
+}
+
+#[test]
+fn becomes_constant_with_change() {
+ let mut db = TestContextImpl::default();
+
+ db.set_input('a', 22);
+ db.set_input('b', 44);
+ assert_eq!(db.add('a', 'b'), 66);
+ assert_eq!(Durability::LOW, AddQuery.in_db(&db).durability(('a', 'b')));
+
+ db.set_input_with_durability('a', 23, Durability::HIGH);
+ assert_eq!(db.add('a', 'b'), 67);
+ assert_eq!(Durability::LOW, AddQuery.in_db(&db).durability(('a', 'b')));
+
+ db.set_input_with_durability('b', 45, Durability::HIGH);
+ assert_eq!(db.add('a', 'b'), 68);
+ assert_eq!(Durability::HIGH, AddQuery.in_db(&db).durability(('a', 'b')));
+
+ db.set_input_with_durability('b', 45, Durability::MEDIUM);
+ assert_eq!(db.add('a', 'b'), 68);
+ assert_eq!(Durability::MEDIUM, AddQuery.in_db(&db).durability(('a', 'b')));
+}
+
+// Test a subtle case in which an input changes from constant to
+// non-constant, but its value doesn't change. If we're not careful,
+// this can cause us to incorrectly consider derived values as still
+// being constant.
+#[test]
+fn constant_to_non_constant() {
+ let mut db = TestContextImpl::default();
+
+ db.set_input_with_durability('a', 11, Durability::HIGH);
+ db.set_input_with_durability('b', 22, Durability::HIGH);
+ db.set_input_with_durability('c', 33, Durability::HIGH);
+
+ // Here, `add3` invokes `add`, which yields 33. Both calls are
+ // constant.
+ assert_eq!(db.add3('a', 'b', 'c'), 66);
+
+ db.set_input('a', 11);
+
+ // Here, `add3` invokes `add`, which *still* yields 33, but which
+ // is no longer constant. Since value didn't change, we might
+ // preserve `add3` unchanged, not noticing that it is no longer
+ // constant.
+ assert_eq!(db.add3('a', 'b', 'c'), 66);
+
+ // In that case, we would not get the correct result here, when
+ // 'a' changes *again*.
+ db.set_input('a', 22);
+ assert_eq!(db.add3('a', 'b', 'c'), 77);
+}
diff --git a/crates/salsa/tests/incremental/counter.rs b/crates/salsa/tests/incremental/counter.rs
new file mode 100644
index 0000000000..c04857e24c
--- /dev/null
+++ b/crates/salsa/tests/incremental/counter.rs
@@ -0,0 +1,14 @@
+use std::cell::Cell;
+
+#[derive(Default)]
+pub(crate) struct Counter {
+ value: Cell<usize>,
+}
+
+impl Counter {
+ pub(crate) fn increment(&self) -> usize {
+ let v = self.value.get();
+ self.value.set(v + 1);
+ v
+ }
+}
diff --git a/crates/salsa/tests/incremental/implementation.rs b/crates/salsa/tests/incremental/implementation.rs
new file mode 100644
index 0000000000..19752bba00
--- /dev/null
+++ b/crates/salsa/tests/incremental/implementation.rs
@@ -0,0 +1,59 @@
+use crate::constants;
+use crate::counter::Counter;
+use crate::log::Log;
+use crate::memoized_dep_inputs;
+use crate::memoized_inputs;
+use crate::memoized_volatile;
+
+pub(crate) trait TestContext: salsa::Database {
+ fn clock(&self) -> &Counter;
+ fn log(&self) -> &Log;
+}
+
+#[salsa::database(
+ constants::Constants,
+ memoized_dep_inputs::MemoizedDepInputs,
+ memoized_inputs::MemoizedInputs,
+ memoized_volatile::MemoizedVolatile
+)]
+#[derive(Default)]
+pub(crate) struct TestContextImpl {
+ storage: salsa::Storage<TestContextImpl>,
+ clock: Counter,
+ log: Log,
+}
+
+impl TestContextImpl {
+ #[track_caller]
+ pub(crate) fn assert_log(&self, expected_log: &[&str]) {
+ let expected_text = &format!("{:#?}", expected_log);
+ let actual_text = &format!("{:#?}", self.log().take());
+
+ if expected_text == actual_text {
+ return;
+ }
+
+ #[allow(clippy::print_stdout)]
+ for diff in dissimilar::diff(expected_text, actual_text) {
+ match diff {
+ dissimilar::Chunk::Delete(l) => println!("-{}", l),
+ dissimilar::Chunk::Equal(l) => println!(" {}", l),
+ dissimilar::Chunk::Insert(r) => println!("+{}", r),
+ }
+ }
+
+ panic!("incorrect log results");
+ }
+}
+
+impl TestContext for TestContextImpl {
+ fn clock(&self) -> &Counter {
+ &self.clock
+ }
+
+ fn log(&self) -> &Log {
+ &self.log
+ }
+}
+
+impl salsa::Database for TestContextImpl {}
diff --git a/crates/salsa/tests/incremental/log.rs b/crates/salsa/tests/incremental/log.rs
new file mode 100644
index 0000000000..1ee57fe667
--- /dev/null
+++ b/crates/salsa/tests/incremental/log.rs
@@ -0,0 +1,16 @@
+use std::cell::RefCell;
+
+#[derive(Default)]
+pub(crate) struct Log {
+ data: RefCell<Vec<String>>,
+}
+
+impl Log {
+ pub(crate) fn add(&self, text: impl Into<String>) {
+ self.data.borrow_mut().push(text.into());
+ }
+
+ pub(crate) fn take(&self) -> Vec<String> {
+ self.data.take()
+ }
+}
diff --git a/crates/salsa/tests/incremental/main.rs b/crates/salsa/tests/incremental/main.rs
new file mode 100644
index 0000000000..bcd13c75f7
--- /dev/null
+++ b/crates/salsa/tests/incremental/main.rs
@@ -0,0 +1,9 @@
+mod constants;
+mod counter;
+mod implementation;
+mod log;
+mod memoized_dep_inputs;
+mod memoized_inputs;
+mod memoized_volatile;
+
+fn main() {}
diff --git a/crates/salsa/tests/incremental/memoized_dep_inputs.rs b/crates/salsa/tests/incremental/memoized_dep_inputs.rs
new file mode 100644
index 0000000000..4ea33e0c1a
--- /dev/null
+++ b/crates/salsa/tests/incremental/memoized_dep_inputs.rs
@@ -0,0 +1,60 @@
+use crate::implementation::{TestContext, TestContextImpl};
+
+#[salsa::query_group(MemoizedDepInputs)]
+pub(crate) trait MemoizedDepInputsContext: TestContext {
+ fn dep_memoized2(&self) -> usize;
+ fn dep_memoized1(&self) -> usize;
+ #[salsa::dependencies]
+ fn dep_derived1(&self) -> usize;
+ #[salsa::input]
+ fn dep_input1(&self) -> usize;
+ #[salsa::input]
+ fn dep_input2(&self) -> usize;
+}
+
+fn dep_memoized2(db: &dyn MemoizedDepInputsContext) -> usize {
+ db.log().add("Memoized2 invoked");
+ db.dep_memoized1()
+}
+
+fn dep_memoized1(db: &dyn MemoizedDepInputsContext) -> usize {
+ db.log().add("Memoized1 invoked");
+ db.dep_derived1() * 2
+}
+
+fn dep_derived1(db: &dyn MemoizedDepInputsContext) -> usize {
+ db.log().add("Derived1 invoked");
+ db.dep_input1() / 2
+}
+
+#[test]
+fn revalidate() {
+ let db = &mut TestContextImpl::default();
+
+ db.set_dep_input1(0);
+
+ // Initial run starts from Memoized2:
+ let v = db.dep_memoized2();
+ assert_eq!(v, 0);
+ db.assert_log(&["Memoized2 invoked", "Memoized1 invoked", "Derived1 invoked"]);
+
+ // After that, we first try to validate Memoized1 but wind up
+ // running Memoized2. Note that we don't try to validate
+ // Derived1, so it is invoked by Memoized1.
+ db.set_dep_input1(44);
+ let v = db.dep_memoized2();
+ assert_eq!(v, 44);
+ db.assert_log(&["Memoized1 invoked", "Derived1 invoked", "Memoized2 invoked"]);
+
+ // Here validation of Memoized1 succeeds so Memoized2 never runs.
+ db.set_dep_input1(45);
+ let v = db.dep_memoized2();
+ assert_eq!(v, 44);
+ db.assert_log(&["Memoized1 invoked", "Derived1 invoked"]);
+
+ // Here, a change to input2 doesn't affect us, so nothing runs.
+ db.set_dep_input2(45);
+ let v = db.dep_memoized2();
+ assert_eq!(v, 44);
+ db.assert_log(&[]);
+}
diff --git a/crates/salsa/tests/incremental/memoized_inputs.rs b/crates/salsa/tests/incremental/memoized_inputs.rs
new file mode 100644
index 0000000000..53d2ace887
--- /dev/null
+++ b/crates/salsa/tests/incremental/memoized_inputs.rs
@@ -0,0 +1,76 @@
+use crate::implementation::{TestContext, TestContextImpl};
+
+#[salsa::query_group(MemoizedInputs)]
+pub(crate) trait MemoizedInputsContext: TestContext {
+ fn max(&self) -> usize;
+ #[salsa::input]
+ fn input1(&self) -> usize;
+ #[salsa::input]
+ fn input2(&self) -> usize;
+}
+
+fn max(db: &dyn MemoizedInputsContext) -> usize {
+ db.log().add("Max invoked");
+ std::cmp::max(db.input1(), db.input2())
+}
+
+#[test]
+fn revalidate() {
+ let db = &mut TestContextImpl::default();
+
+ db.set_input1(0);
+ db.set_input2(0);
+
+ let v = db.max();
+ assert_eq!(v, 0);
+ db.assert_log(&["Max invoked"]);
+
+ let v = db.max();
+ assert_eq!(v, 0);
+ db.assert_log(&[]);
+
+ db.set_input1(44);
+ db.assert_log(&[]);
+
+ let v = db.max();
+ assert_eq!(v, 44);
+ db.assert_log(&["Max invoked"]);
+
+ let v = db.max();
+ assert_eq!(v, 44);
+ db.assert_log(&[]);
+
+ db.set_input1(44);
+ db.assert_log(&[]);
+ db.set_input2(66);
+ db.assert_log(&[]);
+ db.set_input1(64);
+ db.assert_log(&[]);
+
+ let v = db.max();
+ assert_eq!(v, 66);
+ db.assert_log(&["Max invoked"]);
+
+ let v = db.max();
+ assert_eq!(v, 66);
+ db.assert_log(&[]);
+}
+
+/// Test that invoking `set` on an input with the same value still
+/// triggers a new revision.
+#[test]
+fn set_after_no_change() {
+ let db = &mut TestContextImpl::default();
+
+ db.set_input2(0);
+
+ db.set_input1(44);
+ let v = db.max();
+ assert_eq!(v, 44);
+ db.assert_log(&["Max invoked"]);
+
+ db.set_input1(44);
+ let v = db.max();
+ assert_eq!(v, 44);
+ db.assert_log(&["Max invoked"]);
+}
diff --git a/crates/salsa/tests/incremental/memoized_volatile.rs b/crates/salsa/tests/incremental/memoized_volatile.rs
new file mode 100644
index 0000000000..6dc5030063
--- /dev/null
+++ b/crates/salsa/tests/incremental/memoized_volatile.rs
@@ -0,0 +1,77 @@
+use crate::implementation::{TestContext, TestContextImpl};
+use salsa::{Database, Durability};
+
+#[salsa::query_group(MemoizedVolatile)]
+pub(crate) trait MemoizedVolatileContext: TestContext {
+ // Queries for testing a "volatile" value wrapped by
+ // memoization.
+ fn memoized2(&self) -> usize;
+ fn memoized1(&self) -> usize;
+ fn volatile(&self) -> usize;
+}
+
+fn memoized2(db: &dyn MemoizedVolatileContext) -> usize {
+ db.log().add("Memoized2 invoked");
+ db.memoized1()
+}
+
+fn memoized1(db: &dyn MemoizedVolatileContext) -> usize {
+ db.log().add("Memoized1 invoked");
+ let v = db.volatile();
+ v / 2
+}
+
+fn volatile(db: &dyn MemoizedVolatileContext) -> usize {
+ db.log().add("Volatile invoked");
+ db.salsa_runtime().report_untracked_read();
+ db.clock().increment()
+}
+
+#[test]
+fn volatile_x2() {
+ let query = TestContextImpl::default();
+
+ // Invoking volatile twice doesn't execute twice, because volatile
+ // queries are memoized by default.
+ query.volatile();
+ query.volatile();
+ query.assert_log(&["Volatile invoked"]);
+}
+
+/// Test that:
+///
+/// - On the first run of R0, we recompute everything.
+/// - On the second run of R1, we recompute nothing.
+/// - On the first run of R1, we recompute Memoized1 but not Memoized2 (since Memoized1 result
+/// did not change).
+/// - On the second run of R1, we recompute nothing.
+/// - On the first run of R2, we recompute everything (since Memoized1 result *did* change).
+#[test]
+fn revalidate() {
+ let mut query = TestContextImpl::default();
+
+ query.memoized2();
+ query.assert_log(&["Memoized2 invoked", "Memoized1 invoked", "Volatile invoked"]);
+
+ query.memoized2();
+ query.assert_log(&[]);
+
+ // Second generation: volatile will change (to 1) but memoized1
+ // will not (still 0, as 1/2 = 0)
+ query.salsa_runtime_mut().synthetic_write(Durability::LOW);
+ query.memoized2();
+ query.assert_log(&["Volatile invoked", "Memoized1 invoked"]);
+ query.memoized2();
+ query.assert_log(&[]);
+
+ // Third generation: volatile will change (to 2) and memoized1
+ // will too (to 1). Therefore, after validating that Memoized1
+ // changed, we now invoke Memoized2.
+ query.salsa_runtime_mut().synthetic_write(Durability::LOW);
+
+ query.memoized2();
+ query.assert_log(&["Volatile invoked", "Memoized1 invoked", "Memoized2 invoked"]);
+
+ query.memoized2();
+ query.assert_log(&[]);
+}
diff --git a/crates/salsa/tests/interned.rs b/crates/salsa/tests/interned.rs
new file mode 100644
index 0000000000..b9b916d19a
--- /dev/null
+++ b/crates/salsa/tests/interned.rs
@@ -0,0 +1,90 @@
+//! Test that you can implement a query using a `dyn Trait` setup.
+
+use salsa::InternId;
+
+#[salsa::database(InternStorage)]
+#[derive(Default)]
+struct Database {
+ storage: salsa::Storage<Self>,
+}
+
+impl salsa::Database for Database {}
+
+impl salsa::ParallelDatabase for Database {
+ fn snapshot(&self) -> salsa::Snapshot<Self> {
+ salsa::Snapshot::new(Database { storage: self.storage.snapshot() })
+ }
+}
+
+#[salsa::query_group(InternStorage)]
+trait Intern {
+ #[salsa::interned]
+ fn intern1(&self, x: String) -> InternId;
+
+ #[salsa::interned]
+ fn intern2(&self, x: String, y: String) -> InternId;
+
+ #[salsa::interned]
+ fn intern_key(&self, x: String) -> InternKey;
+}
+
+#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
+pub struct InternKey(InternId);
+
+impl salsa::InternKey for InternKey {
+ fn from_intern_id(v: InternId) -> Self {
+ InternKey(v)
+ }
+
+ fn as_intern_id(&self) -> InternId {
+ self.0
+ }
+}
+
+#[test]
+fn test_intern1() {
+ let db = Database::default();
+ let foo0 = db.intern1("foo".to_string());
+ let bar0 = db.intern1("bar".to_string());
+ let foo1 = db.intern1("foo".to_string());
+ let bar1 = db.intern1("bar".to_string());
+
+ assert_eq!(foo0, foo1);
+ assert_eq!(bar0, bar1);
+ assert_ne!(foo0, bar0);
+
+ assert_eq!("foo".to_string(), db.lookup_intern1(foo0));
+ assert_eq!("bar".to_string(), db.lookup_intern1(bar0));
+}
+
+#[test]
+fn test_intern2() {
+ let db = Database::default();
+ let foo0 = db.intern2("x".to_string(), "foo".to_string());
+ let bar0 = db.intern2("x".to_string(), "bar".to_string());
+ let foo1 = db.intern2("x".to_string(), "foo".to_string());
+ let bar1 = db.intern2("x".to_string(), "bar".to_string());
+
+ assert_eq!(foo0, foo1);
+ assert_eq!(bar0, bar1);
+ assert_ne!(foo0, bar0);
+
+ assert_eq!(("x".to_string(), "foo".to_string()), db.lookup_intern2(foo0));
+ assert_eq!(("x".to_string(), "bar".to_string()), db.lookup_intern2(bar0));
+}
+
+#[test]
+fn test_intern_key() {
+ let db = Database::default();
+ let foo0 = db.intern_key("foo".to_string());
+ let bar0 = db.intern_key("bar".to_string());
+ let foo1 = db.intern_key("foo".to_string());
+ let bar1 = db.intern_key("bar".to_string());
+
+ assert_eq!(foo0, foo1);
+ assert_eq!(bar0, bar1);
+ assert_ne!(foo0, bar0);
+
+ assert_eq!("foo".to_string(), db.lookup_intern_key(foo0));
+ assert_eq!("bar".to_string(), db.lookup_intern_key(bar0));
+}
diff --git a/crates/salsa/tests/lru.rs b/crates/salsa/tests/lru.rs
new file mode 100644
index 0000000000..3da8519b08
--- /dev/null
+++ b/crates/salsa/tests/lru.rs
@@ -0,0 +1,102 @@
+//! Test setting LRU actually limits the number of things in the database;
+use std::sync::{
+ atomic::{AtomicUsize, Ordering},
+ Arc,
+};
+
+#[derive(Debug, PartialEq, Eq)]
+struct HotPotato(u32);
+
+static N_POTATOES: AtomicUsize = AtomicUsize::new(0);
+
+impl HotPotato {
+ fn new(id: u32) -> HotPotato {
+ N_POTATOES.fetch_add(1, Ordering::SeqCst);
+ HotPotato(id)
+ }
+}
+
+impl Drop for HotPotato {
+ fn drop(&mut self) {
+ N_POTATOES.fetch_sub(1, Ordering::SeqCst);
+ }
+}
+
+#[salsa::query_group(QueryGroupStorage)]
+trait QueryGroup: salsa::Database {
+ fn get(&self, x: u32) -> Arc<HotPotato>;
+ fn get_volatile(&self, x: u32) -> usize;
+}
+
+fn get(_db: &dyn QueryGroup, x: u32) -> Arc<HotPotato> {
+ Arc::new(HotPotato::new(x))
+}
+
+fn get_volatile(db: &dyn QueryGroup, _x: u32) -> usize {
+ static COUNTER: AtomicUsize = AtomicUsize::new(0);
+ db.salsa_runtime().report_untracked_read();
+ COUNTER.fetch_add(1, Ordering::SeqCst)
+}
+
+#[salsa::database(QueryGroupStorage)]
+#[derive(Default)]
+struct Database {
+ storage: salsa::Storage<Self>,
+}
+
+impl salsa::Database for Database {}
+
+#[test]
+fn lru_works() {
+ let mut db = Database::default();
+ GetQuery.in_db_mut(&mut db).set_lru_capacity(32);
+ assert_eq!(N_POTATOES.load(Ordering::SeqCst), 0);
+
+ for i in 0..128u32 {
+ let p = db.get(i);
+ assert_eq!(p.0, i)
+ }
+ assert_eq!(N_POTATOES.load(Ordering::SeqCst), 32);
+
+ for i in 0..128u32 {
+ let p = db.get(i);
+ assert_eq!(p.0, i)
+ }
+ assert_eq!(N_POTATOES.load(Ordering::SeqCst), 32);
+
+ GetQuery.in_db_mut(&mut db).set_lru_capacity(32);
+ assert_eq!(N_POTATOES.load(Ordering::SeqCst), 32);
+
+ GetQuery.in_db_mut(&mut db).set_lru_capacity(64);
+ assert_eq!(N_POTATOES.load(Ordering::SeqCst), 32);
+ for i in 0..128u32 {
+ let p = db.get(i);
+ assert_eq!(p.0, i)
+ }
+ assert_eq!(N_POTATOES.load(Ordering::SeqCst), 64);
+
+ // Special case: setting capacity to zero disables LRU
+ GetQuery.in_db_mut(&mut db).set_lru_capacity(0);
+ assert_eq!(N_POTATOES.load(Ordering::SeqCst), 64);
+ for i in 0..128u32 {
+ let p = db.get(i);
+ assert_eq!(p.0, i)
+ }
+ assert_eq!(N_POTATOES.load(Ordering::SeqCst), 128);
+
+ drop(db);
+ assert_eq!(N_POTATOES.load(Ordering::SeqCst), 0);
+}
+
+#[test]
+fn lru_doesnt_break_volatile_queries() {
+ let mut db = Database::default();
+ GetVolatileQuery.in_db_mut(&mut db).set_lru_capacity(32);
+ // Here, we check that we execute each volatile query at most once, despite
+ // LRU. That does mean that we have more values in DB than the LRU capacity,
+ // but it's much better than inconsistent results from volatile queries!
+ for i in (0..3).flat_map(|_| 0..128usize) {
+ let x = db.get_volatile(i as u32);
+ assert_eq!(x, i)
+ }
+}
diff --git a/crates/salsa/tests/macros.rs b/crates/salsa/tests/macros.rs
new file mode 100644
index 0000000000..3d818e53c8
--- /dev/null
+++ b/crates/salsa/tests/macros.rs
@@ -0,0 +1,11 @@
+#[salsa::query_group(MyStruct)]
+trait MyDatabase: salsa::Database {
+ #[salsa::invoke(another_module::another_name)]
+ fn my_query(&self, key: ()) -> ();
+}
+
+mod another_module {
+ pub(crate) fn another_name(_: &dyn crate::MyDatabase, (): ()) {}
+}
+
+fn main() {}
diff --git a/crates/salsa/tests/no_send_sync.rs b/crates/salsa/tests/no_send_sync.rs
new file mode 100644
index 0000000000..2a25c437c3
--- /dev/null
+++ b/crates/salsa/tests/no_send_sync.rs
@@ -0,0 +1,31 @@
+use std::rc::Rc;
+
+#[salsa::query_group(NoSendSyncStorage)]
+trait NoSendSyncDatabase: salsa::Database {
+ fn no_send_sync_value(&self, key: bool) -> Rc<bool>;
+ fn no_send_sync_key(&self, key: Rc<bool>) -> bool;
+}
+
+fn no_send_sync_value(_db: &dyn NoSendSyncDatabase, key: bool) -> Rc<bool> {
+ Rc::new(key)
+}
+
+fn no_send_sync_key(_db: &dyn NoSendSyncDatabase, key: Rc<bool>) -> bool {
+ *key
+}
+
+#[salsa::database(NoSendSyncStorage)]
+#[derive(Default)]
+struct DatabaseImpl {
+ storage: salsa::Storage<Self>,
+}
+
+impl salsa::Database for DatabaseImpl {}
+
+#[test]
+fn no_send_sync() {
+ let db = DatabaseImpl::default();
+
+ assert_eq!(db.no_send_sync_value(true), Rc::new(true));
+ assert!(!db.no_send_sync_key(Rc::new(false)));
+}
diff --git a/crates/salsa/tests/on_demand_inputs.rs b/crates/salsa/tests/on_demand_inputs.rs
new file mode 100644
index 0000000000..5d0e486644
--- /dev/null
+++ b/crates/salsa/tests/on_demand_inputs.rs
@@ -0,0 +1,147 @@
+//! Test that "on-demand" input pattern works.
+//!
+//! On-demand inputs are inputs computed lazily on the fly. They are simulated
+//! via a b query with zero inputs, which uses `add_synthetic_read` to
+//! tweak durability and `invalidate` to clear the input.
+
+#![allow(clippy::disallowed_types, clippy::type_complexity)]
+
+use std::{cell::RefCell, collections::HashMap, rc::Rc};
+
+use salsa::{Database as _, Durability, EventKind};
+
+#[salsa::query_group(QueryGroupStorage)]
+trait QueryGroup: salsa::Database + AsRef<HashMap<u32, u32>> {
+ fn a(&self, x: u32) -> u32;
+ fn b(&self, x: u32) -> u32;
+ fn c(&self, x: u32) -> u32;
+}
+
+fn a(db: &dyn QueryGroup, x: u32) -> u32 {
+ let durability = if x % 2 == 0 { Durability::LOW } else { Durability::HIGH };
+ db.salsa_runtime().report_synthetic_read(durability);
+ let external_state: &HashMap<u32, u32> = db.as_ref();
+ external_state[&x]
+}
+
+fn b(db: &dyn QueryGroup, x: u32) -> u32 {
+ db.a(x)
+}
+
+fn c(db: &dyn QueryGroup, x: u32) -> u32 {
+ db.b(x)
+}
+
+#[salsa::database(QueryGroupStorage)]
+#[derive(Default)]
+struct Database {
+ storage: salsa::Storage<Self>,
+ external_state: HashMap<u32, u32>,
+ on_event: Option<Box<dyn Fn(&Database, salsa::Event)>>,
+}
+
+impl salsa::Database for Database {
+ fn salsa_event(&self, event: salsa::Event) {
+ if let Some(cb) = &self.on_event {
+ cb(self, event)
+ }
+ }
+}
+
+impl AsRef<HashMap<u32, u32>> for Database {
+ fn as_ref(&self) -> &HashMap<u32, u32> {
+ &self.external_state
+ }
+}
+
+#[test]
+fn on_demand_input_works() {
+ let mut db = Database::default();
+
+ db.external_state.insert(1, 10);
+ assert_eq!(db.b(1), 10);
+ assert_eq!(db.a(1), 10);
+
+ // We changed external state, but haven't signaled about this yet,
+ // so we expect to see the old answer
+ db.external_state.insert(1, 92);
+ assert_eq!(db.b(1), 10);
+ assert_eq!(db.a(1), 10);
+
+ AQuery.in_db_mut(&mut db).invalidate(&1);
+ assert_eq!(db.b(1), 92);
+ assert_eq!(db.a(1), 92);
+
+ // Downstream queries should also be rerun if we call `a` first.
+ db.external_state.insert(1, 50);
+ AQuery.in_db_mut(&mut db).invalidate(&1);
+ assert_eq!(db.a(1), 50);
+ assert_eq!(db.b(1), 50);
+}
+
+#[test]
+fn on_demand_input_durability() {
+ let mut db = Database::default();
+
+ let events = Rc::new(RefCell::new(vec![]));
+ db.on_event = Some(Box::new({
+ let events = events.clone();
+ move |db, event| {
+ if let EventKind::WillCheckCancellation = event.kind {
+ // these events are not interesting
+ } else {
+ events.borrow_mut().push(format!("{:?}", event.debug(db)))
+ }
+ }
+ }));
+
+ events.replace(vec![]);
+ db.external_state.insert(1, 10);
+ db.external_state.insert(2, 20);
+ assert_eq!(db.b(1), 10);
+ assert_eq!(db.b(2), 20);
+ expect_test::expect![[r#"
+ RefCell {
+ value: [
+ "Event { runtime_id: RuntimeId { counter: 0 }, kind: WillExecute { database_key: b(1) } }",
+ "Event { runtime_id: RuntimeId { counter: 0 }, kind: WillExecute { database_key: a(1) } }",
+ "Event { runtime_id: RuntimeId { counter: 0 }, kind: WillExecute { database_key: b(2) } }",
+ "Event { runtime_id: RuntimeId { counter: 0 }, kind: WillExecute { database_key: a(2) } }",
+ ],
+ }
+ "#]].assert_debug_eq(&events);
+
+ db.salsa_runtime_mut().synthetic_write(Durability::LOW);
+ events.replace(vec![]);
+ assert_eq!(db.c(1), 10);
+ assert_eq!(db.c(2), 20);
+ // Re-execute `a(2)` because that has low durability, but not `a(1)`
+ expect_test::expect![[r#"
+ RefCell {
+ value: [
+ "Event { runtime_id: RuntimeId { counter: 0 }, kind: WillExecute { database_key: c(1) } }",
+ "Event { runtime_id: RuntimeId { counter: 0 }, kind: DidValidateMemoizedValue { database_key: b(1) } }",
+ "Event { runtime_id: RuntimeId { counter: 0 }, kind: WillExecute { database_key: c(2) } }",
+ "Event { runtime_id: RuntimeId { counter: 0 }, kind: WillExecute { database_key: a(2) } }",
+ "Event { runtime_id: RuntimeId { counter: 0 }, kind: DidValidateMemoizedValue { database_key: b(2) } }",
+ ],
+ }
+ "#]].assert_debug_eq(&events);
+
+ db.salsa_runtime_mut().synthetic_write(Durability::HIGH);
+ events.replace(vec![]);
+ assert_eq!(db.c(1), 10);
+ assert_eq!(db.c(2), 20);
+ // Re-execute both `a(1)` and `a(2)`, but we don't re-execute any `b` queries as the
+ // result didn't actually change.
+ expect_test::expect![[r#"
+ RefCell {
+ value: [
+ "Event { runtime_id: RuntimeId { counter: 0 }, kind: WillExecute { database_key: a(1) } }",
+ "Event { runtime_id: RuntimeId { counter: 0 }, kind: DidValidateMemoizedValue { database_key: c(1) } }",
+ "Event { runtime_id: RuntimeId { counter: 0 }, kind: WillExecute { database_key: a(2) } }",
+ "Event { runtime_id: RuntimeId { counter: 0 }, kind: DidValidateMemoizedValue { database_key: c(2) } }",
+ ],
+ }
+ "#]].assert_debug_eq(&events);
+}
diff --git a/crates/salsa/tests/panic_safely.rs b/crates/salsa/tests/panic_safely.rs
new file mode 100644
index 0000000000..c11ae9c214
--- /dev/null
+++ b/crates/salsa/tests/panic_safely.rs
@@ -0,0 +1,93 @@
+use salsa::{Database, ParallelDatabase, Snapshot};
+use std::panic::{self, AssertUnwindSafe};
+use std::sync::atomic::{AtomicU32, Ordering::SeqCst};
+
+#[salsa::query_group(PanicSafelyStruct)]
+trait PanicSafelyDatabase: salsa::Database {
+ #[salsa::input]
+ fn one(&self) -> usize;
+
+ fn panic_safely(&self) -> ();
+
+ fn outer(&self) -> ();
+}
+
+fn panic_safely(db: &dyn PanicSafelyDatabase) {
+ assert_eq!(db.one(), 1);
+}
+
+static OUTER_CALLS: AtomicU32 = AtomicU32::new(0);
+
+fn outer(db: &dyn PanicSafelyDatabase) {
+ OUTER_CALLS.fetch_add(1, SeqCst);
+ db.panic_safely();
+}
+
+#[salsa::database(PanicSafelyStruct)]
+#[derive(Default)]
+struct DatabaseStruct {
+ storage: salsa::Storage<Self>,
+}
+
+impl salsa::Database for DatabaseStruct {}
+
+impl salsa::ParallelDatabase for DatabaseStruct {
+ fn snapshot(&self) -> Snapshot<Self> {
+ Snapshot::new(DatabaseStruct { storage: self.storage.snapshot() })
+ }
+}
+
+#[test]
+fn should_panic_safely() {
+ let mut db = DatabaseStruct::default();
+ db.set_one(0);
+
+ // Invoke `db.panic_safely() without having set `db.one`. `db.one` will
+ // return 0 and we should catch the panic.
+ let result = panic::catch_unwind(AssertUnwindSafe({
+ let db = db.snapshot();
+ move || db.panic_safely()
+ }));
+ assert!(result.is_err());
+
+ // Set `db.one` to 1 and assert ok
+ db.set_one(1);
+ let result = panic::catch_unwind(AssertUnwindSafe(|| db.panic_safely()));
+ assert!(result.is_ok());
+
+ // Check, that memoized outer is not invalidated by a panic
+ {
+ assert_eq!(OUTER_CALLS.load(SeqCst), 0);
+ db.outer();
+ assert_eq!(OUTER_CALLS.load(SeqCst), 1);
+
+ db.set_one(0);
+ let result = panic::catch_unwind(AssertUnwindSafe(|| db.outer()));
+ assert!(result.is_err());
+ assert_eq!(OUTER_CALLS.load(SeqCst), 1);
+
+ db.set_one(1);
+ db.outer();
+ assert_eq!(OUTER_CALLS.load(SeqCst), 2);
+ }
+}
+
+#[test]
+fn storages_are_unwind_safe() {
+ fn check_unwind_safe<T: std::panic::UnwindSafe>() {}
+ check_unwind_safe::<&DatabaseStruct>();
+}
+
+#[test]
+fn panics_clear_query_stack() {
+ let db = DatabaseStruct::default();
+
+ // Invoke `db.panic_if_not_one() without having set `db.input`. `db.input`
+ // will default to 0 and we should catch the panic.
+ let result = panic::catch_unwind(AssertUnwindSafe(|| db.panic_safely()));
+ assert!(result.is_err());
+
+ // The database has been poisoned and any attempt to increment the
+ // revision should panic.
+ assert_eq!(db.salsa_runtime().active_query(), None);
+}
diff --git a/crates/salsa/tests/parallel/cancellation.rs b/crates/salsa/tests/parallel/cancellation.rs
new file mode 100644
index 0000000000..9a92e5cc1f
--- /dev/null
+++ b/crates/salsa/tests/parallel/cancellation.rs
@@ -0,0 +1,132 @@
+use crate::setup::{CancellationFlag, Knobs, ParDatabase, ParDatabaseImpl, WithValue};
+use salsa::{Cancelled, ParallelDatabase};
+
+macro_rules! assert_cancelled {
+ ($thread:expr) => {
+ match $thread.join() {
+ Ok(value) => panic!("expected cancellation, got {:?}", value),
+ Err(payload) => match payload.downcast::<Cancelled>() {
+ Ok(_) => {}
+ Err(payload) => ::std::panic::resume_unwind(payload),
+ },
+ }
+ };
+}
+
+/// Add test where a call to `sum` is cancelled by a simultaneous
+/// write. Check that we recompute the result in next revision, even
+/// though none of the inputs have changed.
+#[test]
+fn in_par_get_set_cancellation_immediate() {
+ let mut db = ParDatabaseImpl::default();
+
+ db.set_input('a', 100);
+ db.set_input('b', 10);
+ db.set_input('c', 1);
+ db.set_input('d', 0);
+
+ let thread1 = std::thread::spawn({
+ let db = db.snapshot();
+ move || {
+ // This will not return until it sees cancellation is
+ // signaled.
+ db.knobs().sum_signal_on_entry.with_value(1, || {
+ db.knobs()
+ .sum_wait_for_cancellation
+ .with_value(CancellationFlag::Panic, || db.sum("abc"))
+ })
+ }
+ });
+
+ // Wait until we have entered `sum` in the other thread.
+ db.wait_for(1);
+
+ // Try to set the input. This will signal cancellation.
+ db.set_input('d', 1000);
+
+ // This should re-compute the value (even though no input has changed).
+ let thread2 = std::thread::spawn({
+ let db = db.snapshot();
+ move || db.sum("abc")
+ });
+
+ assert_eq!(db.sum("d"), 1000);
+ assert_cancelled!(thread1);
+ assert_eq!(thread2.join().unwrap(), 111);
+}
+
+/// Here, we check that `sum`'s cancellation is propagated
+/// to `sum2` properly.
+#[test]
+fn in_par_get_set_cancellation_transitive() {
+ let mut db = ParDatabaseImpl::default();
+
+ db.set_input('a', 100);
+ db.set_input('b', 10);
+ db.set_input('c', 1);
+ db.set_input('d', 0);
+
+ let thread1 = std::thread::spawn({
+ let db = db.snapshot();
+ move || {
+ // This will not return until it sees cancellation is
+ // signaled.
+ db.knobs().sum_signal_on_entry.with_value(1, || {
+ db.knobs()
+ .sum_wait_for_cancellation
+ .with_value(CancellationFlag::Panic, || db.sum2("abc"))
+ })
+ }
+ });
+
+ // Wait until we have entered `sum` in the other thread.
+ db.wait_for(1);
+
+ // Try to set the input. This will signal cancellation.
+ db.set_input('d', 1000);
+
+ // This should re-compute the value (even though no input has changed).
+ let thread2 = std::thread::spawn({
+ let db = db.snapshot();
+ move || db.sum2("abc")
+ });
+
+ assert_eq!(db.sum2("d"), 1000);
+ assert_cancelled!(thread1);
+ assert_eq!(thread2.join().unwrap(), 111);
+}
+
+/// https://github.com/salsa-rs/salsa/issues/66
+#[test]
+fn no_back_dating_in_cancellation() {
+ let mut db = ParDatabaseImpl::default();
+
+ db.set_input('a', 1);
+ let thread1 = std::thread::spawn({
+ let db = db.snapshot();
+ move || {
+ // Here we compute a long-chain of queries,
+ // but the last one gets cancelled.
+ db.knobs().sum_signal_on_entry.with_value(1, || {
+ db.knobs()
+ .sum_wait_for_cancellation
+ .with_value(CancellationFlag::Panic, || db.sum3("a"))
+ })
+ }
+ });
+
+ db.wait_for(1);
+
+ // Set unrelated input to bump revision
+ db.set_input('b', 2);
+
+ // Here we should recompuet the whole chain again, clearing the cancellation
+ // state. If we get `usize::max()` here, it is a bug!
+ assert_eq!(db.sum3("a"), 1);
+
+ assert_cancelled!(thread1);
+
+ db.set_input('a', 3);
+ db.set_input('a', 4);
+ assert_eq!(db.sum3("ab"), 6);
+}
diff --git a/crates/salsa/tests/parallel/frozen.rs b/crates/salsa/tests/parallel/frozen.rs
new file mode 100644
index 0000000000..5359a8820e
--- /dev/null
+++ b/crates/salsa/tests/parallel/frozen.rs
@@ -0,0 +1,57 @@
+use crate::setup::{ParDatabase, ParDatabaseImpl};
+use crate::signal::Signal;
+use salsa::{Database, ParallelDatabase};
+use std::{
+ panic::{catch_unwind, AssertUnwindSafe},
+ sync::Arc,
+};
+
+/// Add test where a call to `sum` is cancelled by a simultaneous
+/// write. Check that we recompute the result in next revision, even
+/// though none of the inputs have changed.
+#[test]
+fn in_par_get_set_cancellation() {
+ let mut db = ParDatabaseImpl::default();
+
+ db.set_input('a', 1);
+
+ let signal = Arc::new(Signal::default());
+
+ let thread1 = std::thread::spawn({
+ let db = db.snapshot();
+ let signal = signal.clone();
+ move || {
+ // Check that cancellation flag is not yet set, because
+ // `set` cannot have been called yet.
+ catch_unwind(AssertUnwindSafe(|| db.unwind_if_cancelled())).unwrap();
+
+ // Signal other thread to proceed.
+ signal.signal(1);
+
+ // Wait for other thread to signal cancellation
+ catch_unwind(AssertUnwindSafe(|| loop {
+ db.unwind_if_cancelled();
+ std::thread::yield_now();
+ }))
+ .unwrap_err();
+ }
+ });
+
+ let thread2 = std::thread::spawn({
+ move || {
+ // Wait until thread 1 has asserted that they are not cancelled
+ // before we invoke `set.`
+ signal.wait_for(1);
+
+ // This will block until thread1 drops the revision lock.
+ db.set_input('a', 2);
+
+ db.input('a')
+ }
+ });
+
+ thread1.join().unwrap();
+
+ let c = thread2.join().unwrap();
+ assert_eq!(c, 2);
+}
diff --git a/crates/salsa/tests/parallel/independent.rs b/crates/salsa/tests/parallel/independent.rs
new file mode 100644
index 0000000000..bd6ba3bf93
--- /dev/null
+++ b/crates/salsa/tests/parallel/independent.rs
@@ -0,0 +1,29 @@
+use crate::setup::{ParDatabase, ParDatabaseImpl};
+use salsa::ParallelDatabase;
+
+/// Test two `sum` queries (on distinct keys) executing in different
+/// threads. Really just a test that `snapshot` etc compiles.
+#[test]
+fn in_par_two_independent_queries() {
+ let mut db = ParDatabaseImpl::default();
+
+ db.set_input('a', 100);
+ db.set_input('b', 10);
+ db.set_input('c', 1);
+ db.set_input('d', 200);
+ db.set_input('e', 20);
+ db.set_input('f', 2);
+
+ let thread1 = std::thread::spawn({
+ let db = db.snapshot();
+ move || db.sum("abc")
+ });
+
+ let thread2 = std::thread::spawn({
+ let db = db.snapshot();
+ move || db.sum("def")
+ });
+
+ assert_eq!(thread1.join().unwrap(), 111);
+ assert_eq!(thread2.join().unwrap(), 222);
+}
diff --git a/crates/salsa/tests/parallel/main.rs b/crates/salsa/tests/parallel/main.rs
new file mode 100644
index 0000000000..31c0da1837
--- /dev/null
+++ b/crates/salsa/tests/parallel/main.rs
@@ -0,0 +1,13 @@
+mod setup;
+
+mod cancellation;
+mod frozen;
+mod independent;
+mod parallel_cycle_all_recover;
+mod parallel_cycle_mid_recover;
+mod parallel_cycle_none_recover;
+mod parallel_cycle_one_recovers;
+mod race;
+mod signal;
+mod stress;
+mod true_parallel;
diff --git a/crates/salsa/tests/parallel/parallel_cycle_all_recover.rs b/crates/salsa/tests/parallel/parallel_cycle_all_recover.rs
new file mode 100644
index 0000000000..cee51b4db7
--- /dev/null
+++ b/crates/salsa/tests/parallel/parallel_cycle_all_recover.rs
@@ -0,0 +1,110 @@
+//! Test for cycle recover spread across two threads.
+//! See `../cycles.rs` for a complete listing of cycle tests,
+//! both intra and cross thread.
+
+use crate::setup::{Knobs, ParDatabaseImpl};
+use salsa::ParallelDatabase;
+use test_log::test;
+
+// Recover cycle test:
+//
+// The pattern is as follows.
+//
+// Thread A Thread B
+// -------- --------
+// a1 b1
+// | wait for stage 1 (blocks)
+// signal stage 1 |
+// wait for stage 2 (blocks) (unblocked)
+// | signal stage 2
+// (unblocked) wait for stage 3 (blocks)
+// a2 |
+// b1 (blocks -> stage 3) |
+// | (unblocked)
+// | b2
+// | a1 (cycle detected, recovers)
+// | b2 completes, recovers
+// | b1 completes, recovers
+// a2 sees cycle, recovers
+// a1 completes, recovers
+
+#[test]
+fn parallel_cycle_all_recover() {
+ let db = ParDatabaseImpl::default();
+ db.knobs().signal_on_will_block.set(3);
+
+ let thread_a = std::thread::spawn({
+ let db = db.snapshot();
+ move || db.a1(1)
+ });
+
+ let thread_b = std::thread::spawn({
+ let db = db.snapshot();
+ move || db.b1(1)
+ });
+
+ assert_eq!(thread_a.join().unwrap(), 11);
+ assert_eq!(thread_b.join().unwrap(), 21);
+}
+
+#[salsa::query_group(ParallelCycleAllRecover)]
+pub(crate) trait TestDatabase: Knobs {
+ #[salsa::cycle(recover_a1)]
+ fn a1(&self, key: i32) -> i32;
+
+ #[salsa::cycle(recover_a2)]
+ fn a2(&self, key: i32) -> i32;
+
+ #[salsa::cycle(recover_b1)]
+ fn b1(&self, key: i32) -> i32;
+
+ #[salsa::cycle(recover_b2)]
+ fn b2(&self, key: i32) -> i32;
+}
+
+fn recover_a1(_db: &dyn TestDatabase, _cycle: &salsa::Cycle, key: &i32) -> i32 {
+ tracing::debug!("recover_a1");
+ key * 10 + 1
+}
+
+fn recover_a2(_db: &dyn TestDatabase, _cycle: &salsa::Cycle, key: &i32) -> i32 {
+ tracing::debug!("recover_a2");
+ key * 10 + 2
+}
+
+fn recover_b1(_db: &dyn TestDatabase, _cycle: &salsa::Cycle, key: &i32) -> i32 {
+ tracing::debug!("recover_b1");
+ key * 20 + 1
+}
+
+fn recover_b2(_db: &dyn TestDatabase, _cycle: &salsa::Cycle, key: &i32) -> i32 {
+ tracing::debug!("recover_b2");
+ key * 20 + 2
+}
+
+fn a1(db: &dyn TestDatabase, key: i32) -> i32 {
+ // Wait to create the cycle until both threads have entered
+ db.signal(1);
+ db.wait_for(2);
+
+ db.a2(key)
+}
+
+fn a2(db: &dyn TestDatabase, key: i32) -> i32 {
+ db.b1(key)
+}
+
+fn b1(db: &dyn TestDatabase, key: i32) -> i32 {
+ // Wait to create the cycle until both threads have entered
+ db.wait_for(1);
+ db.signal(2);
+
+ // Wait for thread A to block on this thread
+ db.wait_for(3);
+
+ db.b2(key)
+}
+
+fn b2(db: &dyn TestDatabase, key: i32) -> i32 {
+ db.a1(key)
+}
diff --git a/crates/salsa/tests/parallel/parallel_cycle_mid_recover.rs b/crates/salsa/tests/parallel/parallel_cycle_mid_recover.rs
new file mode 100644
index 0000000000..f78c05c559
--- /dev/null
+++ b/crates/salsa/tests/parallel/parallel_cycle_mid_recover.rs
@@ -0,0 +1,110 @@
+//! Test for cycle recover spread across two threads.
+//! See `../cycles.rs` for a complete listing of cycle tests,
+//! both intra and cross thread.
+
+use crate::setup::{Knobs, ParDatabaseImpl};
+use salsa::ParallelDatabase;
+use test_log::test;
+
+// Recover cycle test:
+//
+// The pattern is as follows.
+//
+// Thread A Thread B
+// -------- --------
+// a1 b1
+// | wait for stage 1 (blocks)
+// signal stage 1 |
+// wait for stage 2 (blocks) (unblocked)
+// | |
+// | b2
+// | b3
+// | a1 (blocks -> stage 2)
+// (unblocked) |
+// a2 (cycle detected) |
+// b3 recovers
+// b2 resumes
+// b1 panics because bug
+
+#[test]
+fn parallel_cycle_mid_recovers() {
+ let db = ParDatabaseImpl::default();
+ db.knobs().signal_on_will_block.set(2);
+
+ let thread_a = std::thread::spawn({
+ let db = db.snapshot();
+ move || db.a1(1)
+ });
+
+ let thread_b = std::thread::spawn({
+ let db = db.snapshot();
+ move || db.b1(1)
+ });
+
+ // We expect that the recovery function yields
+ // `1 * 20 + 2`, which is returned (and forwarded)
+ // to b1, and from there to a2 and a1.
+ assert_eq!(thread_a.join().unwrap(), 22);
+ assert_eq!(thread_b.join().unwrap(), 22);
+}
+
+#[salsa::query_group(ParallelCycleMidRecovers)]
+pub(crate) trait TestDatabase: Knobs {
+ fn a1(&self, key: i32) -> i32;
+
+ fn a2(&self, key: i32) -> i32;
+
+ #[salsa::cycle(recover_b1)]
+ fn b1(&self, key: i32) -> i32;
+
+ fn b2(&self, key: i32) -> i32;
+
+ #[salsa::cycle(recover_b3)]
+ fn b3(&self, key: i32) -> i32;
+}
+
+fn recover_b1(_db: &dyn TestDatabase, _cycle: &salsa::Cycle, key: &i32) -> i32 {
+ tracing::debug!("recover_b1");
+ key * 20 + 2
+}
+
+fn recover_b3(_db: &dyn TestDatabase, _cycle: &salsa::Cycle, key: &i32) -> i32 {
+ tracing::debug!("recover_b1");
+ key * 200 + 2
+}
+
+fn a1(db: &dyn TestDatabase, key: i32) -> i32 {
+ // tell thread b we have started
+ db.signal(1);
+
+ // wait for thread b to block on a1
+ db.wait_for(2);
+
+ db.a2(key)
+}
+
+fn a2(db: &dyn TestDatabase, key: i32) -> i32 {
+ // create the cycle
+ db.b1(key)
+}
+
+fn b1(db: &dyn TestDatabase, key: i32) -> i32 {
+ // wait for thread a to have started
+ db.wait_for(1);
+
+ db.b2(key);
+
+ 0
+}
+
+fn b2(db: &dyn TestDatabase, key: i32) -> i32 {
+ // will encounter a cycle but recover
+ db.b3(key);
+ db.b1(key); // hasn't recovered yet
+ 0
+}
+
+fn b3(db: &dyn TestDatabase, key: i32) -> i32 {
+ // will block on thread a, signaling stage 2
+ db.a1(key)
+}
diff --git a/crates/salsa/tests/parallel/parallel_cycle_none_recover.rs b/crates/salsa/tests/parallel/parallel_cycle_none_recover.rs
new file mode 100644
index 0000000000..35fe379118
--- /dev/null
+++ b/crates/salsa/tests/parallel/parallel_cycle_none_recover.rs
@@ -0,0 +1,69 @@
+//! Test a cycle where no queries recover that occurs across threads.
+//! See the `../cycles.rs` for a complete listing of cycle tests,
+//! both intra and cross thread.
+
+use crate::setup::{Knobs, ParDatabaseImpl};
+use expect_test::expect;
+use salsa::ParallelDatabase;
+use test_log::test;
+
+#[test]
+fn parallel_cycle_none_recover() {
+ let db = ParDatabaseImpl::default();
+ db.knobs().signal_on_will_block.set(3);
+
+ let thread_a = std::thread::spawn({
+ let db = db.snapshot();
+ move || db.a(-1)
+ });
+
+ let thread_b = std::thread::spawn({
+ let db = db.snapshot();
+ move || db.b(-1)
+ });
+
+ // We expect B to panic because it detects a cycle (it is the one that calls A, ultimately).
+ // Right now, it panics with a string.
+ let err_b = thread_b.join().unwrap_err();
+ if let Some(c) = err_b.downcast_ref::<salsa::Cycle>() {
+ expect![[r#"
+ [
+ "a(-1)",
+ "b(-1)",
+ ]
+ "#]]
+ .assert_debug_eq(&c.unexpected_participants(&db));
+ } else {
+ panic!("b failed in an unexpected way: {:?}", err_b);
+ }
+
+ // We expect A to propagate a panic, which causes us to use the sentinel
+ // type `Canceled`.
+ assert!(thread_a.join().unwrap_err().downcast_ref::<salsa::Cycle>().is_some());
+}
+
+#[salsa::query_group(ParallelCycleNoneRecover)]
+pub(crate) trait TestDatabase: Knobs {
+ fn a(&self, key: i32) -> i32;
+ fn b(&self, key: i32) -> i32;
+}
+
+fn a(db: &dyn TestDatabase, key: i32) -> i32 {
+ // Wait to create the cycle until both threads have entered
+ db.signal(1);
+ db.wait_for(2);
+
+ db.b(key)
+}
+
+fn b(db: &dyn TestDatabase, key: i32) -> i32 {
+ // Wait to create the cycle until both threads have entered
+ db.wait_for(1);
+ db.signal(2);
+
+ // Wait for thread A to block on this thread
+ db.wait_for(3);
+
+ // Now try to execute A
+ db.a(key)
+}
diff --git a/crates/salsa/tests/parallel/parallel_cycle_one_recovers.rs b/crates/salsa/tests/parallel/parallel_cycle_one_recovers.rs
new file mode 100644
index 0000000000..7d3944714a
--- /dev/null
+++ b/crates/salsa/tests/parallel/parallel_cycle_one_recovers.rs
@@ -0,0 +1,95 @@
+//! Test for cycle recover spread across two threads.
+//! See `../cycles.rs` for a complete listing of cycle tests,
+//! both intra and cross thread.
+
+use crate::setup::{Knobs, ParDatabaseImpl};
+use salsa::ParallelDatabase;
+use test_log::test;
+
+// Recover cycle test:
+//
+// The pattern is as follows.
+//
+// Thread A Thread B
+// -------- --------
+// a1 b1
+// | wait for stage 1 (blocks)
+// signal stage 1 |
+// wait for stage 2 (blocks) (unblocked)
+// | signal stage 2
+// (unblocked) wait for stage 3 (blocks)
+// a2 |
+// b1 (blocks -> stage 3) |
+// | (unblocked)
+// | b2
+// | a1 (cycle detected)
+// a2 recovery fn executes |
+// a1 completes normally |
+// b2 completes, recovers
+// b1 completes, recovers
+
+#[test]
+fn parallel_cycle_one_recovers() {
+ let db = ParDatabaseImpl::default();
+ db.knobs().signal_on_will_block.set(3);
+
+ let thread_a = std::thread::spawn({
+ let db = db.snapshot();
+ move || db.a1(1)
+ });
+
+ let thread_b = std::thread::spawn({
+ let db = db.snapshot();
+ move || db.b1(1)
+ });
+
+ // We expect that the recovery function yields
+ // `1 * 20 + 2`, which is returned (and forwarded)
+ // to b1, and from there to a2 and a1.
+ assert_eq!(thread_a.join().unwrap(), 22);
+ assert_eq!(thread_b.join().unwrap(), 22);
+}
+
+#[salsa::query_group(ParallelCycleOneRecovers)]
+pub(crate) trait TestDatabase: Knobs {
+ fn a1(&self, key: i32) -> i32;
+
+ #[salsa::cycle(recover)]
+ fn a2(&self, key: i32) -> i32;
+
+ fn b1(&self, key: i32) -> i32;
+
+ fn b2(&self, key: i32) -> i32;
+}
+
+fn recover(_db: &dyn TestDatabase, _cycle: &salsa::Cycle, key: &i32) -> i32 {
+ tracing::debug!("recover");
+ key * 20 + 2
+}
+
+fn a1(db: &dyn TestDatabase, key: i32) -> i32 {
+ // Wait to create the cycle until both threads have entered
+ db.signal(1);
+ db.wait_for(2);
+
+ db.a2(key)
+}
+
+fn a2(db: &dyn TestDatabase, key: i32) -> i32 {
+ db.b1(key)
+}
+
+fn b1(db: &dyn TestDatabase, key: i32) -> i32 {
+ // Wait to create the cycle until both threads have entered
+ db.wait_for(1);
+ db.signal(2);
+
+ // Wait for thread A to block on this thread
+ db.wait_for(3);
+
+ db.b2(key)
+}
+
+fn b2(db: &dyn TestDatabase, key: i32) -> i32 {
+ db.a1(key)
+}
diff --git a/crates/salsa/tests/parallel/race.rs b/crates/salsa/tests/parallel/race.rs
new file mode 100644
index 0000000000..e875de998f
--- /dev/null
+++ b/crates/salsa/tests/parallel/race.rs
@@ -0,0 +1,37 @@
+use std::panic::AssertUnwindSafe;
+
+use crate::setup::{ParDatabase, ParDatabaseImpl};
+use salsa::{Cancelled, ParallelDatabase};
+
+/// Test where a read and a set are racing with one another.
+/// Should be atomic.
+#[test]
+fn in_par_get_set_race() {
+ let mut db = ParDatabaseImpl::default();
+
+ db.set_input('a', 100);
+ db.set_input('b', 10);
+ db.set_input('c', 1);
+
+ let thread1 = std::thread::spawn({
+ let db = db.snapshot();
+ move || Cancelled::catch(AssertUnwindSafe(|| db.sum("abc")))
+ });
+
+ let thread2 = std::thread::spawn(move || {
+ db.set_input('a', 1000);
+ db.sum("a")
+ });
+
+ // If the 1st thread runs first, you get 111, otherwise you get
+ // 1011; if they run concurrently and the 1st thread observes the
+ // cancellation, it'll unwind.
+ let result1 = thread1.join().unwrap();
+ if let Ok(value1) = result1 {
+ assert!(value1 == 111 || value1 == 1011, "illegal result {}", value1);
+ }
+
+ // thread2 can not observe a cancellation because it performs a
+ // database write before running any other queries.
+ assert_eq!(thread2.join().unwrap(), 1000);
+}
diff --git a/crates/salsa/tests/parallel/setup.rs b/crates/salsa/tests/parallel/setup.rs
new file mode 100644
index 0000000000..0a35902b43
--- /dev/null
+++ b/crates/salsa/tests/parallel/setup.rs
@@ -0,0 +1,197 @@
+use crate::signal::Signal;
+use salsa::Database;
+use salsa::ParallelDatabase;
+use salsa::Snapshot;
+use std::sync::Arc;
+use std::{
+ cell::Cell,
+ panic::{catch_unwind, resume_unwind, AssertUnwindSafe},
+};
+
+#[salsa::query_group(Par)]
+pub(crate) trait ParDatabase: Knobs {
+ #[salsa::input]
+ fn input(&self, key: char) -> usize;
+
+ fn sum(&self, key: &'static str) -> usize;
+
+ /// Invokes `sum`
+ fn sum2(&self, key: &'static str) -> usize;
+
+ /// Invokes `sum` but doesn't really care about the result.
+ fn sum2_drop_sum(&self, key: &'static str) -> usize;
+
+ /// Invokes `sum2`
+ fn sum3(&self, key: &'static str) -> usize;
+
+ /// Invokes `sum2_drop_sum`
+ fn sum3_drop_sum(&self, key: &'static str) -> usize;
+}
+
+/// Various "knobs" and utilities used by tests to force
+/// a certain behavior.
+pub(crate) trait Knobs {
+ fn knobs(&self) -> &KnobsStruct;
+
+ fn signal(&self, stage: usize);
+
+ fn wait_for(&self, stage: usize);
+}
+
+pub(crate) trait WithValue<T> {
+ fn with_value<R>(&self, value: T, closure: impl FnOnce() -> R) -> R;
+}
+
+impl<T> WithValue<T> for Cell<T> {
+ fn with_value<R>(&self, value: T, closure: impl FnOnce() -> R) -> R {
+ let old_value = self.replace(value);
+
+ let result = catch_unwind(AssertUnwindSafe(closure));
+
+ self.set(old_value);
+
+ match result {
+ Ok(r) => r,
+ Err(payload) => resume_unwind(payload),
+ }
+ }
+}
+
+#[derive(Default, Clone, Copy, PartialEq, Eq)]
+pub(crate) enum CancellationFlag {
+ #[default]
+ Down,
+ Panic,
+}
+
+/// Various "knobs" that can be used to customize how the queries
+/// behave on one specific thread. Note that this state is
+/// intentionally thread-local (apart from `signal`).
+#[derive(Clone, Default)]
+pub(crate) struct KnobsStruct {
+ /// A kind of flexible barrier used to coordinate execution across
+ /// threads to ensure we reach various weird states.
+ pub(crate) signal: Arc<Signal>,
+
+ /// When this database is about to block, send a signal.
+ pub(crate) signal_on_will_block: Cell<usize>,
+
+ /// Invocations of `sum` will signal this stage on entry.
+ pub(crate) sum_signal_on_entry: Cell<usize>,
+
+ /// Invocations of `sum` will wait for this stage on entry.
+ pub(crate) sum_wait_for_on_entry: Cell<usize>,
+
+ /// If true, invocations of `sum` will panic before they exit.
+ pub(crate) sum_should_panic: Cell<bool>,
+
+ /// If true, invocations of `sum` will wait for cancellation before
+ /// they exit.
+ pub(crate) sum_wait_for_cancellation: Cell<CancellationFlag>,
+
+ /// Invocations of `sum` will wait for this stage prior to exiting.
+ pub(crate) sum_wait_for_on_exit: Cell<usize>,
+
+ /// Invocations of `sum` will signal this stage prior to exiting.
+ pub(crate) sum_signal_on_exit: Cell<usize>,
+
+ /// Invocations of `sum3_drop_sum` will panic unconditionally
+ pub(crate) sum3_drop_sum_should_panic: Cell<bool>,
+}
+
+fn sum(db: &dyn ParDatabase, key: &'static str) -> usize {
+ let mut sum = 0;
+
+ db.signal(db.knobs().sum_signal_on_entry.get());
+
+ db.wait_for(db.knobs().sum_wait_for_on_entry.get());
+
+ if db.knobs().sum_should_panic.get() {
+ panic!("query set to panic before exit")
+ }
+
+ for ch in key.chars() {
+ sum += db.input(ch);
+ }
+
+ match db.knobs().sum_wait_for_cancellation.get() {
+ CancellationFlag::Down => (),
+ CancellationFlag::Panic => {
+ tracing::debug!("waiting for cancellation");
+ loop {
+ db.unwind_if_cancelled();
+ std::thread::yield_now();
+ }
+ }
+ }
+
+ db.wait_for(db.knobs().sum_wait_for_on_exit.get());
+
+ db.signal(db.knobs().sum_signal_on_exit.get());
+
+ sum
+}
+
+fn sum2(db: &dyn ParDatabase, key: &'static str) -> usize {
+ db.sum(key)
+}
+
+fn sum2_drop_sum(db: &dyn ParDatabase, key: &'static str) -> usize {
+ let _ = db.sum(key);
+ 22
+}
+
+fn sum3(db: &dyn ParDatabase, key: &'static str) -> usize {
+ db.sum2(key)
+}
+
+fn sum3_drop_sum(db: &dyn ParDatabase, key: &'static str) -> usize {
+ if db.knobs().sum3_drop_sum_should_panic.get() {
+ panic!("sum3_drop_sum executed")
+ }
+ db.sum2_drop_sum(key)
+}
+
+#[salsa::database(
+ Par,
+ crate::parallel_cycle_all_recover::ParallelCycleAllRecover,
+ crate::parallel_cycle_none_recover::ParallelCycleNoneRecover,
+ crate::parallel_cycle_mid_recover::ParallelCycleMidRecovers,
+ crate::parallel_cycle_one_recovers::ParallelCycleOneRecovers
+)]
+#[derive(Default)]
+pub(crate) struct ParDatabaseImpl {
+ storage: salsa::Storage<Self>,
+ knobs: KnobsStruct,
+}
+
+impl Database for ParDatabaseImpl {
+ fn salsa_event(&self, event: salsa::Event) {
+ if let salsa::EventKind::WillBlockOn { .. } = event.kind {
+ self.signal(self.knobs().signal_on_will_block.get());
+ }
+ }
+}
+
+impl ParallelDatabase for ParDatabaseImpl {
+ fn snapshot(&self) -> Snapshot<Self> {
+ Snapshot::new(ParDatabaseImpl {
+ storage: self.storage.snapshot(),
+ knobs: self.knobs.clone(),
+ })
+ }
+}
+
+impl Knobs for ParDatabaseImpl {
+ fn knobs(&self) -> &KnobsStruct {
+ &self.knobs
+ }
+
+ fn signal(&self, stage: usize) {
+ self.knobs.signal.signal(stage);
+ }
+
+ fn wait_for(&self, stage: usize) {
+ self.knobs.signal.wait_for(stage);
+ }
+}
diff --git a/crates/salsa/tests/parallel/signal.rs b/crates/salsa/tests/parallel/signal.rs
new file mode 100644
index 0000000000..0af7b66e48
--- /dev/null
+++ b/crates/salsa/tests/parallel/signal.rs
@@ -0,0 +1,40 @@
+use parking_lot::{Condvar, Mutex};
+
+#[derive(Default)]
+pub(crate) struct Signal {
+ value: Mutex<usize>,
+ cond_var: Condvar,
+}
+
+impl Signal {
+ pub(crate) fn signal(&self, stage: usize) {
+ tracing::debug!("signal({})", stage);
+
+ // This check avoids acquiring the lock for things that will
+ // clearly be a no-op. Not *necessary* but helps to ensure we
+ // are more likely to encounter weird race conditions;
+ // otherwise calls to `sum` will tend to be unnecessarily
+ // synchronous.
+ if stage > 0 {
+ let mut v = self.value.lock();
+ if stage > *v {
+ *v = stage;
+ self.cond_var.notify_all();
+ }
+ }
+ }
+
+ /// Waits until the given condition is true; the fn is invoked
+ /// with the current stage.
+ pub(crate) fn wait_for(&self, stage: usize) {
+ tracing::debug!("wait_for({})", stage);
+
+ // As above, avoid lock if clearly a no-op.
+ if stage > 0 {
+ let mut v = self.value.lock();
+ while *v < stage {
+ self.cond_var.wait(&mut v);
+ }
+ }
+ }
+}
diff --git a/crates/salsa/tests/parallel/stress.rs b/crates/salsa/tests/parallel/stress.rs
new file mode 100644
index 0000000000..2fa317b2b9
--- /dev/null
+++ b/crates/salsa/tests/parallel/stress.rs
@@ -0,0 +1,168 @@
+use rand::seq::SliceRandom;
+use rand::Rng;
+
+use salsa::ParallelDatabase;
+use salsa::Snapshot;
+use salsa::{Cancelled, Database};
+
+// Number of operations a reader performs
+const N_MUTATOR_OPS: usize = 100;
+const N_READER_OPS: usize = 100;
+
+#[salsa::query_group(Stress)]
+trait StressDatabase: salsa::Database {
+ #[salsa::input]
+ fn a(&self, key: usize) -> usize;
+
+ fn b(&self, key: usize) -> usize;
+
+ fn c(&self, key: usize) -> usize;
+}
+
+fn b(db: &dyn StressDatabase, key: usize) -> usize {
+ db.unwind_if_cancelled();
+ db.a(key)
+}
+
+fn c(db: &dyn StressDatabase, key: usize) -> usize {
+ db.b(key)
+}
+
+#[salsa::database(Stress)]
+#[derive(Default)]
+struct StressDatabaseImpl {
+ storage: salsa::Storage<Self>,
+}
+
+impl salsa::Database for StressDatabaseImpl {}
+
+impl salsa::ParallelDatabase for StressDatabaseImpl {
+ fn snapshot(&self) -> Snapshot<StressDatabaseImpl> {
+ Snapshot::new(StressDatabaseImpl { storage: self.storage.snapshot() })
+ }
+}
+
+#[derive(Clone, Copy, Debug)]
+enum Query {
+ A,
+ B,
+ C,
+}
+
+enum MutatorOp {
+ WriteOp(WriteOp),
+ LaunchReader { ops: Vec<ReadOp>, check_cancellation: bool },
+}
+
+#[derive(Debug)]
+enum WriteOp {
+ SetA(usize, usize),
+}
+
+#[derive(Debug)]
+enum ReadOp {
+ Get(Query, usize),
+}
+
+impl rand::distributions::Distribution<Query> for rand::distributions::Standard {
+ fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> Query {
+ *[Query::A, Query::B, Query::C].choose(rng).unwrap()
+ }
+}
+
+impl rand::distributions::Distribution<MutatorOp> for rand::distributions::Standard {
+ fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> MutatorOp {
+ if rng.gen_bool(0.5) {
+ MutatorOp::WriteOp(rng.gen())
+ } else {
+ MutatorOp::LaunchReader {
+ ops: (0..N_READER_OPS).map(|_| rng.gen()).collect(),
+ check_cancellation: rng.gen(),
+ }
+ }
+ }
+}
+
+impl rand::distributions::Distribution<WriteOp> for rand::distributions::Standard {
+ fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> WriteOp {
+ let key = rng.gen::<usize>() % 10;
+ let value = rng.gen::<usize>() % 10;
+ WriteOp::SetA(key, value)
+ }
+}
+
+impl rand::distributions::Distribution<ReadOp> for rand::distributions::Standard {
+ fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> ReadOp {
+ let query = rng.gen::<Query>();
+ let key = rng.gen::<usize>() % 10;
+ ReadOp::Get(query, key)
+ }
+}
+
+fn db_reader_thread(db: &StressDatabaseImpl, ops: Vec<ReadOp>, check_cancellation: bool) {
+ for op in ops {
+ if check_cancellation {
+ db.unwind_if_cancelled();
+ }
+ op.execute(db);
+ }
+}
+
+impl WriteOp {
+ fn execute(self, db: &mut StressDatabaseImpl) {
+ match self {
+ WriteOp::SetA(key, value) => {
+ db.set_a(key, value);
+ }
+ }
+ }
+}
+
+impl ReadOp {
+ fn execute(self, db: &StressDatabaseImpl) {
+ match self {
+ ReadOp::Get(query, key) => match query {
+ Query::A => {
+ db.a(key);
+ }
+ Query::B => {
+ let _ = db.b(key);
+ }
+ Query::C => {
+ let _ = db.c(key);
+ }
+ },
+ }
+ }
+}
+
+#[test]
+fn stress_test() {
+ let mut db = StressDatabaseImpl::default();
+ for i in 0..10 {
+ db.set_a(i, i);
+ }
+
+ let mut rng = rand::thread_rng();
+
+ // generate the ops that the mutator thread will perform
+ let write_ops: Vec<MutatorOp> = (0..N_MUTATOR_OPS).map(|_| rng.gen()).collect();
+
+ // execute the "main thread", which sometimes snapshots off other threads
+ let mut all_threads = vec![];
+ for op in write_ops {
+ match op {
+ MutatorOp::WriteOp(w) => w.execute(&mut db),
+ MutatorOp::LaunchReader { ops, check_cancellation } => {
+ all_threads.push(std::thread::spawn({
+ let db = db.snapshot();
+ move || Cancelled::catch(|| db_reader_thread(&db, ops, check_cancellation))
+ }))
+ }
+ }
+ }
+
+ for thread in all_threads {
+ thread.join().unwrap().ok();
+ }
+}
diff --git a/crates/salsa/tests/parallel/true_parallel.rs b/crates/salsa/tests/parallel/true_parallel.rs
new file mode 100644
index 0000000000..d0e58efd1a
--- /dev/null
+++ b/crates/salsa/tests/parallel/true_parallel.rs
@@ -0,0 +1,125 @@
+use crate::setup::{Knobs, ParDatabase, ParDatabaseImpl, WithValue};
+use salsa::ParallelDatabase;
+use std::panic::{self, AssertUnwindSafe};
+
+/// Test where two threads are executing sum. We show that they can
+/// both be executing sum in parallel by having thread1 wait for
+/// thread2 to send a signal before it leaves (similarly, thread2
+/// waits for thread1 to send a signal before it enters).
+#[test]
+fn true_parallel_different_keys() {
+ let mut db = ParDatabaseImpl::default();
+
+ db.set_input('a', 100);
+ db.set_input('b', 10);
+ db.set_input('c', 1);
+
+ // Thread 1 will signal stage 1 when it enters and wait for stage 2.
+ let thread1 = std::thread::spawn({
+ let db = db.snapshot();
+ move || {
+ let v = db
+ .knobs()
+ .sum_signal_on_entry
+ .with_value(1, || db.knobs().sum_wait_for_on_exit.with_value(2, || db.sum("a")));
+ v
+ }
+ });
+
+ // Thread 2 will wait_for stage 1 when it enters and signal stage 2
+ // when it leaves.
+ let thread2 = std::thread::spawn({
+ let db = db.snapshot();
+ move || {
+ let v = db
+ .knobs()
+ .sum_wait_for_on_entry
+ .with_value(1, || db.knobs().sum_signal_on_exit.with_value(2, || db.sum("b")));
+ v
+ }
+ });
+
+ assert_eq!(thread1.join().unwrap(), 100);
+ assert_eq!(thread2.join().unwrap(), 10);
+}
+
+/// Add a test that tries to trigger a conflict, where we fetch
+/// `sum("abc")` from two threads simultaneously, and of them
+/// therefore has to block.
+#[test]
+fn true_parallel_same_keys() {
+ let mut db = ParDatabaseImpl::default();
+
+ db.set_input('a', 100);
+ db.set_input('b', 10);
+ db.set_input('c', 1);
+
+ // Thread 1 will wait_for a barrier in the start of `sum`
+ let thread1 = std::thread::spawn({
+ let db = db.snapshot();
+ move || {
+ let v = db
+ .knobs()
+ .sum_signal_on_entry
+ .with_value(1, || db.knobs().sum_wait_for_on_entry.with_value(2, || db.sum("abc")));
+ v
+ }
+ });
+
+ // Thread 2 will wait until Thread 1 has entered sum and then --
+ // once it has set itself to block -- signal Thread 1 to
+ // continue. This way, we test out the mechanism of one thread
+ // blocking on another.
+ let thread2 = std::thread::spawn({
+ let db = db.snapshot();
+ move || {
+ db.knobs().signal.wait_for(1);
+ db.knobs().signal_on_will_block.set(2);
+ db.sum("abc")
+ }
+ });
+
+ assert_eq!(thread1.join().unwrap(), 111);
+ assert_eq!(thread2.join().unwrap(), 111);
+}
+
+/// Add a test that tries to trigger a conflict, where we fetch `sum("a")`
+/// from two threads simultaneously. After `thread2` begins blocking,
+/// we force `thread1` to panic and should see that propagate to `thread2`.
+#[test]
+fn true_parallel_propagate_panic() {
+ let mut db = ParDatabaseImpl::default();
+
+ db.set_input('a', 1);
+
+ // `thread1` will wait_for a barrier in the start of `sum`. Once it can
+ // continue, it will panic.
+ let thread1 = std::thread::spawn({
+ let db = db.snapshot();
+ move || {
+ let v = db.knobs().sum_signal_on_entry.with_value(1, || {
+ db.knobs()
+ .sum_wait_for_on_entry
+ .with_value(2, || db.knobs().sum_should_panic.with_value(true, || db.sum("a")))
+ });
+ v
+ }
+ });
+
+ // `thread2` will wait until `thread1` has entered sum and then -- once it
+ // has set itself to block -- signal `thread1` to continue.
+ let thread2 = std::thread::spawn({
+ let db = db.snapshot();
+ move || {
+ db.knobs().signal.wait_for(1);
+ db.knobs().signal_on_will_block.set(2);
+ db.sum("a")
+ }
+ });
+
+ let result1 = panic::catch_unwind(AssertUnwindSafe(|| thread1.join().unwrap()));
+ let result2 = panic::catch_unwind(AssertUnwindSafe(|| thread2.join().unwrap()));
+
+ assert!(result1.is_err());
+ assert!(result2.is_err());
+}
diff --git a/crates/salsa/tests/storage_varieties/implementation.rs b/crates/salsa/tests/storage_varieties/implementation.rs
new file mode 100644
index 0000000000..2843660f15
--- /dev/null
+++ b/crates/salsa/tests/storage_varieties/implementation.rs
@@ -0,0 +1,19 @@
+use crate::queries;
+use std::cell::Cell;
+
+#[salsa::database(queries::GroupStruct)]
+#[derive(Default)]
+pub(crate) struct DatabaseImpl {
+ storage: salsa::Storage<Self>,
+ counter: Cell<usize>,
+}
+
+impl queries::Counter for DatabaseImpl {
+ fn increment(&self) -> usize {
+ let v = self.counter.get();
+ self.counter.set(v + 1);
+ v
+ }
+}
+
+impl salsa::Database for DatabaseImpl {}
diff --git a/crates/salsa/tests/storage_varieties/main.rs b/crates/salsa/tests/storage_varieties/main.rs
new file mode 100644
index 0000000000..e92c61740e
--- /dev/null
+++ b/crates/salsa/tests/storage_varieties/main.rs
@@ -0,0 +1,5 @@
+mod implementation;
+mod queries;
+mod tests;
+
+fn main() {}
diff --git a/crates/salsa/tests/storage_varieties/queries.rs b/crates/salsa/tests/storage_varieties/queries.rs
new file mode 100644
index 0000000000..0847fadefb
--- /dev/null
+++ b/crates/salsa/tests/storage_varieties/queries.rs
@@ -0,0 +1,22 @@
+pub(crate) trait Counter: salsa::Database {
+ fn increment(&self) -> usize;
+}
+
+#[salsa::query_group(GroupStruct)]
+pub(crate) trait Database: Counter {
+ fn memoized(&self) -> usize;
+ fn volatile(&self) -> usize;
+}
+
+/// Because this query is memoized, we only increment the counter
+/// the first time it is invoked.
+fn memoized(db: &dyn Database) -> usize {
+ db.volatile()
+}
+
+/// Because this query is volatile, each time it is invoked,
+/// we will increment the counter.
+fn volatile(db: &dyn Database) -> usize {
+ db.salsa_runtime().report_untracked_read();
+ db.increment()
+}
diff --git a/crates/salsa/tests/storage_varieties/tests.rs b/crates/salsa/tests/storage_varieties/tests.rs
new file mode 100644
index 0000000000..f75c7c142f
--- /dev/null
+++ b/crates/salsa/tests/storage_varieties/tests.rs
@@ -0,0 +1,49 @@
+#![cfg(test)]
+
+use crate::implementation::DatabaseImpl;
+use crate::queries::Database;
+use salsa::Database as _Database;
+use salsa::Durability;
+
+#[test]
+fn memoized_twice() {
+ let db = DatabaseImpl::default();
+ let v1 = db.memoized();
+ let v2 = db.memoized();
+ assert_eq!(v1, v2);
+}
+
+#[test]
+fn volatile_twice() {
+ let mut db = DatabaseImpl::default();
+ let v1 = db.volatile();
+ let v2 = db.volatile(); // volatiles are cached, so 2nd read returns the same
+ assert_eq!(v1, v2);
+
+ db.salsa_runtime_mut().synthetic_write(Durability::LOW); // clears volatile caches
+
+ let v3 = db.volatile(); // will re-increment the counter
+ let v4 = db.volatile(); // second call will be cached
+ assert_eq!(v1 + 1, v3);
+ assert_eq!(v3, v4);
+}
+
+#[test]
+fn intermingled() {
+ let mut db = DatabaseImpl::default();
+ let v1 = db.volatile();
+ let v2 = db.memoized();
+ let v3 = db.volatile(); // cached
+ let v4 = db.memoized(); // cached
+
+ assert_eq!(v1, v2);
+ assert_eq!(v1, v3);
+ assert_eq!(v2, v4);
+
+ db.salsa_runtime_mut().synthetic_write(Durability::LOW); // clears volatile caches
+
+ let v5 = db.memoized(); // re-executes volatile, caches new result
+ let v6 = db.memoized(); // re-use cached result
+ assert_eq!(v4 + 1, v5);
+ assert_eq!(v5, v6);
+}
diff --git a/crates/salsa/tests/transparent.rs b/crates/salsa/tests/transparent.rs
new file mode 100644
index 0000000000..2e6dd4267b
--- /dev/null
+++ b/crates/salsa/tests/transparent.rs
@@ -0,0 +1,39 @@
+//! Test that transparent (uncached) queries work
+
+#[salsa::query_group(QueryGroupStorage)]
+trait QueryGroup {
+ #[salsa::input]
+ fn input(&self, x: u32) -> u32;
+ #[salsa::transparent]
+ fn wrap(&self, x: u32) -> u32;
+ fn get(&self, x: u32) -> u32;
+}
+
+fn wrap(db: &dyn QueryGroup, x: u32) -> u32 {
+ db.input(x)
+}
+
+fn get(db: &dyn QueryGroup, x: u32) -> u32 {
+ db.wrap(x)
+}
+
+#[salsa::database(QueryGroupStorage)]
+#[derive(Default)]
+struct Database {
+ storage: salsa::Storage<Self>,
+}
+
+impl salsa::Database for Database {}
+
+#[test]
+fn transparent_queries_work() {
+ let mut db = Database::default();
+
+ db.set_input(1, 10);
+ assert_eq!(db.get(1), 10);
+ assert_eq!(db.get(1), 10);
+
+ db.set_input(1, 92);
+ assert_eq!(db.get(1), 92);
+ assert_eq!(db.get(1), 92);
+}
diff --git a/crates/salsa/tests/variadic.rs b/crates/salsa/tests/variadic.rs
new file mode 100644
index 0000000000..cb857844eb
--- /dev/null
+++ b/crates/salsa/tests/variadic.rs
@@ -0,0 +1,51 @@
+#[salsa::query_group(HelloWorld)]
+trait HelloWorldDatabase: salsa::Database {
+ #[salsa::input]
+ fn input(&self, a: u32, b: u32) -> u32;
+
+ fn none(&self) -> u32;
+
+ fn one(&self, k: u32) -> u32;
+
+ fn two(&self, a: u32, b: u32) -> u32;
+
+ fn trailing(&self, a: u32, b: u32) -> u32;
+}
+
+fn none(_db: &dyn HelloWorldDatabase) -> u32 {
+ 22
+}
+
+fn one(_db: &dyn HelloWorldDatabase, k: u32) -> u32 {
+ k * 2
+}
+
+fn two(_db: &dyn HelloWorldDatabase, a: u32, b: u32) -> u32 {
+ a * b
+}
+
+fn trailing(_db: &dyn HelloWorldDatabase, a: u32, b: u32) -> u32 {
+ a - b
+}
+
+#[salsa::database(HelloWorld)]
+#[derive(Default)]
+struct DatabaseStruct {
+ storage: salsa::Storage<Self>,
+}
+
+impl salsa::Database for DatabaseStruct {}
+
+#[test]
+fn execute() {
+ let mut db = DatabaseStruct::default();
+
+ // test what happens with inputs:
+ db.set_input(1, 2, 3);
+ assert_eq!(db.input(1, 2), 3);
+
+ assert_eq!(db.none(), 22);
+ assert_eq!(db.one(11), 22);
+ assert_eq!(db.two(11, 2), 22);
+ assert_eq!(db.trailing(24, 2), 22);
+}
diff --git a/crates/span/Cargo.toml b/crates/span/Cargo.toml
index a4abba29bb..7093f3a691 100644
--- a/crates/span/Cargo.toml
+++ b/crates/span/Cargo.toml
@@ -11,7 +11,7 @@ authors.workspace = true
[dependencies]
la-arena.workspace = true
-rust-analyzer-salsa.workspace = true
+salsa.workspace = true
# local deps