Unnamed repository; edit this file 'description' to name the repository.
Diffstat (limited to 'crates/salsa/salsa-macros/src/database_storage.rs')
-rw-r--r--crates/salsa/salsa-macros/src/database_storage.rs250
1 files changed, 250 insertions, 0 deletions
diff --git a/crates/salsa/salsa-macros/src/database_storage.rs b/crates/salsa/salsa-macros/src/database_storage.rs
new file mode 100644
index 0000000000..0ec75bb043
--- /dev/null
+++ b/crates/salsa/salsa-macros/src/database_storage.rs
@@ -0,0 +1,250 @@
+//!
+use heck::ToSnakeCase;
+use proc_macro::TokenStream;
+use syn::parse::{Parse, ParseStream};
+use syn::punctuated::Punctuated;
+use syn::{Ident, ItemStruct, Path, Token};
+
+type PunctuatedQueryGroups = Punctuated<QueryGroup, Token![,]>;
+
+pub(crate) fn database(args: TokenStream, input: TokenStream) -> TokenStream {
+ let args = syn::parse_macro_input!(args as QueryGroupList);
+ let input = syn::parse_macro_input!(input as ItemStruct);
+
+ let query_groups = &args.query_groups;
+ let database_name = &input.ident;
+ let visibility = &input.vis;
+ let db_storage_field = quote! { storage };
+
+ let mut output = proc_macro2::TokenStream::new();
+ output.extend(quote! { #input });
+
+ let query_group_names_snake: Vec<_> = query_groups
+ .iter()
+ .map(|query_group| {
+ let group_name = query_group.name();
+ Ident::new(&group_name.to_string().to_snake_case(), group_name.span())
+ })
+ .collect();
+
+ let query_group_storage_names: Vec<_> = query_groups
+ .iter()
+ .map(|QueryGroup { group_path }| {
+ quote! {
+ <#group_path as salsa::plumbing::QueryGroup>::GroupStorage
+ }
+ })
+ .collect();
+
+ // For each query group `foo::MyGroup` create a link to its
+ // `foo::MyGroupGroupStorage`
+ let mut storage_fields = proc_macro2::TokenStream::new();
+ let mut storage_initializers = proc_macro2::TokenStream::new();
+ let mut has_group_impls = proc_macro2::TokenStream::new();
+ for (((query_group, group_name_snake), group_storage), group_index) in query_groups
+ .iter()
+ .zip(&query_group_names_snake)
+ .zip(&query_group_storage_names)
+ .zip(0_u16..)
+ {
+ let group_path = &query_group.group_path;
+
+ // rewrite the last identifier (`MyGroup`, above) to
+ // (e.g.) `MyGroupGroupStorage`.
+ storage_fields.extend(quote! {
+ #group_name_snake: #group_storage,
+ });
+
+ // rewrite the last identifier (`MyGroup`, above) to
+ // (e.g.) `MyGroupGroupStorage`.
+ storage_initializers.extend(quote! {
+ #group_name_snake: #group_storage::new(#group_index),
+ });
+
+ // ANCHOR:HasQueryGroup
+ has_group_impls.extend(quote! {
+ impl salsa::plumbing::HasQueryGroup<#group_path> for #database_name {
+ fn group_storage(&self) -> &#group_storage {
+ &self.#db_storage_field.query_store().#group_name_snake
+ }
+
+ fn group_storage_mut(&mut self) -> (&#group_storage, &mut salsa::Runtime) {
+ let (query_store_mut, runtime) = self.#db_storage_field.query_store_mut();
+ (&query_store_mut.#group_name_snake, runtime)
+ }
+ }
+ });
+ // ANCHOR_END:HasQueryGroup
+ }
+
+ // create group storage wrapper struct
+ output.extend(quote! {
+ #[doc(hidden)]
+ #visibility struct __SalsaDatabaseStorage {
+ #storage_fields
+ }
+
+ impl Default for __SalsaDatabaseStorage {
+ fn default() -> Self {
+ Self {
+ #storage_initializers
+ }
+ }
+ }
+ });
+
+ // Create a tuple (D1, D2, ...) where Di is the data for a given query group.
+ let mut database_data = vec![];
+ for QueryGroup { group_path } in query_groups {
+ database_data.push(quote! {
+ <#group_path as salsa::plumbing::QueryGroup>::GroupData
+ });
+ }
+
+ // ANCHOR:DatabaseStorageTypes
+ output.extend(quote! {
+ impl salsa::plumbing::DatabaseStorageTypes for #database_name {
+ type DatabaseStorage = __SalsaDatabaseStorage;
+ }
+ });
+ // ANCHOR_END:DatabaseStorageTypes
+
+ // ANCHOR:DatabaseOps
+ let mut fmt_ops = proc_macro2::TokenStream::new();
+ let mut maybe_changed_ops = proc_macro2::TokenStream::new();
+ let mut cycle_recovery_strategy_ops = proc_macro2::TokenStream::new();
+ let mut for_each_ops = proc_macro2::TokenStream::new();
+ for ((QueryGroup { group_path }, group_storage), group_index) in
+ query_groups.iter().zip(&query_group_storage_names).zip(0_u16..)
+ {
+ fmt_ops.extend(quote! {
+ #group_index => {
+ let storage: &#group_storage =
+ <Self as salsa::plumbing::HasQueryGroup<#group_path>>::group_storage(self);
+ storage.fmt_index(self, input, fmt)
+ }
+ });
+ maybe_changed_ops.extend(quote! {
+ #group_index => {
+ let storage: &#group_storage =
+ <Self as salsa::plumbing::HasQueryGroup<#group_path>>::group_storage(self);
+ storage.maybe_changed_after(self, input, revision)
+ }
+ });
+ cycle_recovery_strategy_ops.extend(quote! {
+ #group_index => {
+ let storage: &#group_storage =
+ <Self as salsa::plumbing::HasQueryGroup<#group_path>>::group_storage(self);
+ storage.cycle_recovery_strategy(self, input)
+ }
+ });
+ for_each_ops.extend(quote! {
+ let storage: &#group_storage =
+ <Self as salsa::plumbing::HasQueryGroup<#group_path>>::group_storage(self);
+ storage.for_each_query(runtime, &mut op);
+ });
+ }
+ output.extend(quote! {
+ impl salsa::plumbing::DatabaseOps for #database_name {
+ fn ops_database(&self) -> &dyn salsa::Database {
+ self
+ }
+
+ fn ops_salsa_runtime(&self) -> &salsa::Runtime {
+ self.#db_storage_field.salsa_runtime()
+ }
+
+ fn ops_salsa_runtime_mut(&mut self) -> &mut salsa::Runtime {
+ self.#db_storage_field.salsa_runtime_mut()
+ }
+
+ fn fmt_index(
+ &self,
+ input: salsa::DatabaseKeyIndex,
+ fmt: &mut std::fmt::Formatter<'_>,
+ ) -> std::fmt::Result {
+ match input.group_index() {
+ #fmt_ops
+ i => panic!("salsa: invalid group index {}", i)
+ }
+ }
+
+ fn maybe_changed_after(
+ &self,
+ input: salsa::DatabaseKeyIndex,
+ revision: salsa::Revision
+ ) -> bool {
+ match input.group_index() {
+ #maybe_changed_ops
+ i => panic!("salsa: invalid group index {}", i)
+ }
+ }
+
+ fn cycle_recovery_strategy(
+ &self,
+ input: salsa::DatabaseKeyIndex,
+ ) -> salsa::plumbing::CycleRecoveryStrategy {
+ match input.group_index() {
+ #cycle_recovery_strategy_ops
+ i => panic!("salsa: invalid group index {}", i)
+ }
+ }
+
+ fn for_each_query(
+ &self,
+ mut op: &mut dyn FnMut(&dyn salsa::plumbing::QueryStorageMassOps),
+ ) {
+ let runtime = salsa::Database::salsa_runtime(self);
+ #for_each_ops
+ }
+ }
+ });
+ // ANCHOR_END:DatabaseOps
+
+ output.extend(has_group_impls);
+
+ output.into()
+}
+
+#[derive(Clone, Debug)]
+struct QueryGroupList {
+ query_groups: PunctuatedQueryGroups,
+}
+
+impl Parse for QueryGroupList {
+ fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
+ let query_groups: PunctuatedQueryGroups =
+ input.parse_terminated(QueryGroup::parse, Token![,])?;
+ Ok(QueryGroupList { query_groups })
+ }
+}
+
+#[derive(Clone, Debug)]
+struct QueryGroup {
+ group_path: Path,
+}
+
+impl QueryGroup {
+ /// The name of the query group trait.
+ fn name(&self) -> Ident {
+ self.group_path.segments.last().unwrap().ident.clone()
+ }
+}
+
+impl Parse for QueryGroup {
+ /// ```ignore
+ /// impl HelloWorldDatabase;
+ /// ```
+ fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
+ let group_path: Path = input.parse()?;
+ Ok(QueryGroup { group_path })
+ }
+}
+
+struct Nothing;
+
+impl Parse for Nothing {
+ fn parse(_input: ParseStream<'_>) -> syn::Result<Self> {
+ Ok(Nothing)
+ }
+}