Unnamed repository; edit this file 'description' to name the repository.
Diffstat (limited to 'crates/ra-salsa/ra-salsa-macros/src/database_storage.rs')
| -rw-r--r-- | crates/ra-salsa/ra-salsa-macros/src/database_storage.rs | 243 |
1 files changed, 243 insertions, 0 deletions
diff --git a/crates/ra-salsa/ra-salsa-macros/src/database_storage.rs b/crates/ra-salsa/ra-salsa-macros/src/database_storage.rs new file mode 100644 index 0000000000..63ab84a621 --- /dev/null +++ b/crates/ra-salsa/ra-salsa-macros/src/database_storage.rs @@ -0,0 +1,243 @@ +//! Implementation for `[ra_salsa::database]` decorator. + +use heck::ToSnakeCase; +use proc_macro::TokenStream; +use syn::parse::{Parse, ParseStream}; +use syn::punctuated::Punctuated; +use syn::{Ident, ItemStruct, Path, Token}; + +type PunctuatedQueryGroups = Punctuated<QueryGroup, Token![,]>; + +pub(crate) fn database(args: TokenStream, input: TokenStream) -> TokenStream { + let args = syn::parse_macro_input!(args as QueryGroupList); + let input = syn::parse_macro_input!(input as ItemStruct); + + let query_groups = &args.query_groups; + let database_name = &input.ident; + let visibility = &input.vis; + let db_storage_field = quote! { storage }; + + let mut output = proc_macro2::TokenStream::new(); + output.extend(quote! { #input }); + + let query_group_names_snake: Vec<_> = query_groups + .iter() + .map(|query_group| { + let group_name = query_group.name(); + Ident::new(&group_name.to_string().to_snake_case(), group_name.span()) + }) + .collect(); + + let query_group_storage_names: Vec<_> = query_groups + .iter() + .map(|QueryGroup { group_path }| { + quote! { + <#group_path as ra_salsa::plumbing::QueryGroup>::GroupStorage + } + }) + .collect(); + + // For each query group `foo::MyGroup` create a link to its + // `foo::MyGroupGroupStorage` + let mut storage_fields = proc_macro2::TokenStream::new(); + let mut storage_initializers = proc_macro2::TokenStream::new(); + let mut has_group_impls = proc_macro2::TokenStream::new(); + for (((query_group, group_name_snake), group_storage), group_index) in query_groups + .iter() + .zip(&query_group_names_snake) + .zip(&query_group_storage_names) + .zip(0_u16..) + { + let group_path = &query_group.group_path; + + // rewrite the last identifier (`MyGroup`, above) to + // (e.g.) `MyGroupGroupStorage`. + storage_fields.extend(quote! { + #group_name_snake: #group_storage, + }); + + // rewrite the last identifier (`MyGroup`, above) to + // (e.g.) `MyGroupGroupStorage`. + storage_initializers.extend(quote! { + #group_name_snake: #group_storage::new(#group_index), + }); + + // ANCHOR:HasQueryGroup + has_group_impls.extend(quote! { + impl ra_salsa::plumbing::HasQueryGroup<#group_path> for #database_name { + fn group_storage(&self) -> &#group_storage { + &self.#db_storage_field.query_store().#group_name_snake + } + + fn group_storage_mut(&mut self) -> (&#group_storage, &mut ra_salsa::Runtime) { + let (query_store_mut, runtime) = self.#db_storage_field.query_store_mut(); + (&query_store_mut.#group_name_snake, runtime) + } + } + }); + // ANCHOR_END:HasQueryGroup + } + + // create group storage wrapper struct + output.extend(quote! { + #[doc(hidden)] + #visibility struct __SalsaDatabaseStorage { + #storage_fields + } + + impl Default for __SalsaDatabaseStorage { + fn default() -> Self { + Self { + #storage_initializers + } + } + } + }); + + // Create a tuple (D1, D2, ...) where Di is the data for a given query group. + let mut database_data = vec![]; + for QueryGroup { group_path } in query_groups { + database_data.push(quote! { + <#group_path as ra_salsa::plumbing::QueryGroup>::GroupData + }); + } + + // ANCHOR:DatabaseStorageTypes + output.extend(quote! { + impl ra_salsa::plumbing::DatabaseStorageTypes for #database_name { + type DatabaseStorage = __SalsaDatabaseStorage; + } + }); + // ANCHOR_END:DatabaseStorageTypes + + // ANCHOR:DatabaseOps + let mut fmt_ops = proc_macro2::TokenStream::new(); + let mut maybe_changed_ops = proc_macro2::TokenStream::new(); + let mut cycle_recovery_strategy_ops = proc_macro2::TokenStream::new(); + let mut for_each_ops = proc_macro2::TokenStream::new(); + for ((QueryGroup { group_path }, group_storage), group_index) in + query_groups.iter().zip(&query_group_storage_names).zip(0_u16..) + { + fmt_ops.extend(quote! { + #group_index => { + let storage: &#group_storage = + <Self as ra_salsa::plumbing::HasQueryGroup<#group_path>>::group_storage(self); + storage.fmt_index(self, input, fmt) + } + }); + maybe_changed_ops.extend(quote! { + #group_index => { + let storage: &#group_storage = + <Self as ra_salsa::plumbing::HasQueryGroup<#group_path>>::group_storage(self); + storage.maybe_changed_after(self, input, revision) + } + }); + cycle_recovery_strategy_ops.extend(quote! { + #group_index => { + let storage: &#group_storage = + <Self as ra_salsa::plumbing::HasQueryGroup<#group_path>>::group_storage(self); + storage.cycle_recovery_strategy(self, input) + } + }); + for_each_ops.extend(quote! { + let storage: &#group_storage = + <Self as ra_salsa::plumbing::HasQueryGroup<#group_path>>::group_storage(self); + storage.for_each_query(runtime, &mut op); + }); + } + output.extend(quote! { + impl ra_salsa::plumbing::DatabaseOps for #database_name { + fn ops_database(&self) -> &dyn ra_salsa::Database { + self + } + + fn ops_salsa_runtime(&self) -> &ra_salsa::Runtime { + self.#db_storage_field.salsa_runtime() + } + + fn synthetic_write(&mut self, durability: ra_salsa::Durability) { + self.#db_storage_field.salsa_runtime_mut().synthetic_write(durability) + } + + fn fmt_index( + &self, + input: ra_salsa::DatabaseKeyIndex, + fmt: &mut std::fmt::Formatter<'_>, + ) -> std::fmt::Result { + match input.group_index() { + #fmt_ops + i => panic!("ra_salsa: invalid group index {}", i) + } + } + + fn maybe_changed_after( + &self, + input: ra_salsa::DatabaseKeyIndex, + revision: ra_salsa::Revision + ) -> bool { + match input.group_index() { + #maybe_changed_ops + i => panic!("ra_salsa: invalid group index {}", i) + } + } + + fn cycle_recovery_strategy( + &self, + input: ra_salsa::DatabaseKeyIndex, + ) -> ra_salsa::plumbing::CycleRecoveryStrategy { + match input.group_index() { + #cycle_recovery_strategy_ops + i => panic!("ra_salsa: invalid group index {}", i) + } + } + + fn for_each_query( + &self, + mut op: &mut dyn FnMut(&dyn ra_salsa::plumbing::QueryStorageMassOps), + ) { + let runtime = ra_salsa::Database::salsa_runtime(self); + #for_each_ops + } + } + }); + // ANCHOR_END:DatabaseOps + + output.extend(has_group_impls); + + output.into() +} + +#[derive(Clone, Debug)] +struct QueryGroupList { + query_groups: PunctuatedQueryGroups, +} + +impl Parse for QueryGroupList { + fn parse(input: ParseStream<'_>) -> syn::Result<Self> { + let query_groups: PunctuatedQueryGroups = + input.parse_terminated(QueryGroup::parse, Token![,])?; + Ok(QueryGroupList { query_groups }) + } +} + +#[derive(Clone, Debug)] +struct QueryGroup { + group_path: Path, +} + +impl QueryGroup { + /// The name of the query group trait. + fn name(&self) -> Ident { + self.group_path.segments.last().unwrap().ident.clone() + } +} + +impl Parse for QueryGroup { + /// ```ignore + /// impl HelloWorldDatabase; + /// ``` + fn parse(input: ParseStream<'_>) -> syn::Result<Self> { + let group_path: Path = input.parse()?; + Ok(QueryGroup { group_path }) + } +} |