Unnamed repository; edit this file 'description' to name the repository.
Diffstat (limited to 'crates/query-group-macro/src/lib.rs')
| -rw-r--r-- | crates/query-group-macro/src/lib.rs | 522 |
1 files changed, 522 insertions, 0 deletions
diff --git a/crates/query-group-macro/src/lib.rs b/crates/query-group-macro/src/lib.rs new file mode 100644 index 0000000000..f4f316c1ac --- /dev/null +++ b/crates/query-group-macro/src/lib.rs @@ -0,0 +1,522 @@ +//! A macro that mimics the old Salsa-style `#[query_group]` macro. + +use core::fmt; +use std::vec; + +use proc_macro::TokenStream; +use proc_macro2::Span; +use queries::{ + GeneratedInputStruct, InputQuery, InputSetter, InputSetterWithDurability, Intern, Lookup, + Queries, SetterKind, TrackedQuery, Transparent, +}; +use quote::{ToTokens, format_ident, quote}; +use syn::parse::{Parse, ParseStream}; +use syn::punctuated::Punctuated; +use syn::spanned::Spanned; +use syn::visit_mut::VisitMut; +use syn::{ + Attribute, FnArg, ItemTrait, Path, Token, TraitItem, TraitItemFn, parse_quote, + parse_quote_spanned, +}; + +mod queries; + +#[proc_macro_attribute] +pub fn query_group(args: TokenStream, input: TokenStream) -> TokenStream { + match query_group_impl(args, input.clone()) { + Ok(tokens) => tokens, + Err(e) => token_stream_with_error(input, e), + } +} + +#[derive(Debug)] +struct InputStructField { + name: proc_macro2::TokenStream, + ty: proc_macro2::TokenStream, +} + +impl fmt::Display for InputStructField { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.name) + } +} + +struct SalsaAttr { + name: String, + tts: TokenStream, + span: Span, +} + +impl std::fmt::Debug for SalsaAttr { + fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(fmt, "{:?}", self.name) + } +} + +impl TryFrom<syn::Attribute> for SalsaAttr { + type Error = syn::Attribute; + + fn try_from(attr: syn::Attribute) -> Result<SalsaAttr, syn::Attribute> { + if is_not_salsa_attr_path(attr.path()) { + return Err(attr); + } + + let span = attr.span(); + + let name = attr.path().segments[1].ident.to_string(); + let tts = match attr.meta { + syn::Meta::Path(path) => path.into_token_stream(), + syn::Meta::List(ref list) => { + let tts = list + .into_token_stream() + .into_iter() + .skip(attr.path().to_token_stream().into_iter().count()); + proc_macro2::TokenStream::from_iter(tts) + } + syn::Meta::NameValue(nv) => nv.into_token_stream(), + } + .into(); + + Ok(SalsaAttr { name, tts, span }) + } +} + +fn is_not_salsa_attr_path(path: &syn::Path) -> bool { + path.segments.first().map(|s| s.ident != "salsa").unwrap_or(true) || path.segments.len() != 2 +} + +fn filter_attrs(attrs: Vec<Attribute>) -> (Vec<Attribute>, Vec<SalsaAttr>) { + let mut other = vec![]; + let mut salsa = vec![]; + // Leave non-salsa attributes untouched. These are + // attributes that don't start with `salsa::` or don't have + // exactly two segments in their path. + for attr in attrs { + match SalsaAttr::try_from(attr) { + Ok(it) => salsa.push(it), + Err(it) => other.push(it), + } + } + (other, salsa) +} + +#[derive(Debug, Clone, PartialEq, Eq)] +enum QueryKind { + Input, + Tracked, + TrackedWithSalsaStruct, + Transparent, + Interned, +} + +#[derive(Default, Debug, Clone)] +struct Cycle { + cycle_fn: Option<(syn::Ident, Path)>, + cycle_initial: Option<(syn::Ident, Path)>, + cycle_result: Option<(syn::Ident, Path)>, +} + +impl Parse for Cycle { + fn parse(input: ParseStream<'_>) -> syn::Result<Self> { + let options = Punctuated::<Option, Token![,]>::parse_terminated(input)?; + let mut cycle_fn = None; + let mut cycle_initial = None; + let mut cycle_result = None; + for option in options { + let name = option.name.to_string(); + match &*name { + "cycle_fn" => { + if cycle_fn.is_some() { + return Err(syn::Error::new_spanned(&option.name, "duplicate option")); + } + cycle_fn = Some((option.name, option.value)); + } + "cycle_initial" => { + if cycle_initial.is_some() { + return Err(syn::Error::new_spanned(&option.name, "duplicate option")); + } + cycle_initial = Some((option.name, option.value)); + } + "cycle_result" => { + if cycle_result.is_some() { + return Err(syn::Error::new_spanned(&option.name, "duplicate option")); + } + cycle_result = Some((option.name, option.value)); + } + _ => { + return Err(syn::Error::new_spanned( + &option.name, + "unknown cycle option. Accepted values: `cycle_result`, `cycle_fn`, `cycle_initial`", + )); + } + } + } + return Ok(Self { cycle_fn, cycle_initial, cycle_result }); + + struct Option { + name: syn::Ident, + value: Path, + } + + impl Parse for Option { + fn parse(input: ParseStream<'_>) -> syn::Result<Self> { + let name = input.parse()?; + input.parse::<Token![=]>()?; + let value = input.parse()?; + Ok(Self { name, value }) + } + } + } +} + +pub(crate) fn query_group_impl( + _args: proc_macro::TokenStream, + input: proc_macro::TokenStream, +) -> Result<proc_macro::TokenStream, syn::Error> { + let mut item_trait = syn::parse::<ItemTrait>(input)?; + + let supertraits = &item_trait.supertraits; + + let db_attr: Attribute = parse_quote! { + #[salsa::db] + }; + item_trait.attrs.push(db_attr); + + let trait_name_ident = &item_trait.ident.clone(); + let input_struct_name = format_ident!("{}Data", trait_name_ident); + let create_data_ident = format_ident!("create_data_{}", trait_name_ident); + + let mut input_struct_fields: Vec<InputStructField> = vec![]; + let mut trait_methods = vec![]; + let mut setter_trait_methods = vec![]; + let mut lookup_signatures = vec![]; + let mut lookup_methods = vec![]; + + for item in &mut item_trait.items { + if let syn::TraitItem::Fn(method) = item { + let method_name = &method.sig.ident; + let signature = &method.sig; + + let (_attrs, salsa_attrs) = filter_attrs(method.attrs.clone()); + + let mut query_kind = QueryKind::TrackedWithSalsaStruct; + let mut invoke = None; + let mut cycle = None; + let mut interned_struct_path = None; + let mut lru = None; + + let params: Vec<FnArg> = signature.inputs.clone().into_iter().collect(); + let pat_and_tys = params + .into_iter() + .filter(|fn_arg| matches!(fn_arg, FnArg::Typed(_))) + .map(|fn_arg| match fn_arg { + FnArg::Typed(pat_type) => pat_type.clone(), + FnArg::Receiver(_) => unreachable!("this should have been filtered out"), + }) + .collect::<Vec<syn::PatType>>(); + + for SalsaAttr { name, tts, span } in salsa_attrs { + match name.as_str() { + "cycle" => { + let c = syn::parse::<Parenthesized<Cycle>>(tts)?; + cycle = Some(c.0); + } + "input" => { + if !pat_and_tys.is_empty() { + return Err(syn::Error::new( + span, + "input methods cannot have a parameter", + )); + } + query_kind = QueryKind::Input; + } + "interned" => { + let syn::ReturnType::Type(_, ty) = &signature.output else { + return Err(syn::Error::new( + span, + "interned queries must have return type", + )); + }; + let syn::Type::Path(path) = &**ty else { + return Err(syn::Error::new( + span, + "interned queries must have return type", + )); + }; + interned_struct_path = Some(path.path.clone()); + query_kind = QueryKind::Interned; + } + "invoke_interned" => { + let path = syn::parse::<Parenthesized<Path>>(tts)?; + invoke = Some(path.0.clone()); + query_kind = QueryKind::Tracked; + } + "invoke" => { + let path = syn::parse::<Parenthesized<Path>>(tts)?; + invoke = Some(path.0.clone()); + if query_kind != QueryKind::Transparent { + query_kind = QueryKind::TrackedWithSalsaStruct; + } + } + "tracked" if method.default.is_some() => { + query_kind = QueryKind::TrackedWithSalsaStruct; + } + "lru" => { + let lru_count = syn::parse::<Parenthesized<syn::LitInt>>(tts)?; + let lru_count = lru_count.0.base10_parse::<u32>()?; + + lru = Some(lru_count); + } + "transparent" => { + query_kind = QueryKind::Transparent; + } + _ => return Err(syn::Error::new(span, format!("unknown attribute `{name}`"))), + } + } + + let syn::ReturnType::Type(_, return_ty) = signature.output.clone() else { + return Err(syn::Error::new(signature.span(), "Queries must have a return type")); + }; + + if let syn::Type::Path(ref ty_path) = *return_ty { + if matches!(query_kind, QueryKind::Input) { + let field = InputStructField { + name: method_name.to_token_stream(), + ty: ty_path.path.to_token_stream(), + }; + + input_struct_fields.push(field); + } + } + + if let Some(block) = &mut method.default { + SelfToDbRewriter.visit_block_mut(block); + } + + match (query_kind, invoke) { + // input + (QueryKind::Input, None) => { + let query = InputQuery { + signature: method.sig.clone(), + create_data_ident: create_data_ident.clone(), + }; + let value = Queries::InputQuery(query); + trait_methods.push(value); + + let setter = InputSetter { + signature: method.sig.clone(), + return_type: *return_ty.clone(), + create_data_ident: create_data_ident.clone(), + }; + setter_trait_methods.push(SetterKind::Plain(setter)); + + let setter = InputSetterWithDurability { + signature: method.sig.clone(), + return_type: *return_ty.clone(), + create_data_ident: create_data_ident.clone(), + }; + setter_trait_methods.push(SetterKind::WithDurability(setter)); + } + (QueryKind::Interned, None) => { + let interned_struct_path = interned_struct_path.unwrap(); + let method = Intern { + signature: signature.clone(), + pat_and_tys: pat_and_tys.clone(), + interned_struct_path: interned_struct_path.clone(), + }; + + trait_methods.push(Queries::Intern(method)); + + let mut method = Lookup { + signature: signature.clone(), + pat_and_tys: pat_and_tys.clone(), + return_ty: *return_ty, + interned_struct_path, + }; + method.prepare_signature(); + + lookup_signatures + .push(TraitItem::Fn(make_trait_method(method.signature.clone()))); + lookup_methods.push(method); + } + // tracked function. it might have an invoke, or might not. + (QueryKind::Tracked, invoke) => { + let method = TrackedQuery { + trait_name: trait_name_ident.clone(), + generated_struct: Some(GeneratedInputStruct { + input_struct_name: input_struct_name.clone(), + create_data_ident: create_data_ident.clone(), + }), + signature: signature.clone(), + pat_and_tys: pat_and_tys.clone(), + invoke, + cycle, + lru, + default: method.default.take(), + }; + + trait_methods.push(Queries::TrackedQuery(method)); + } + (QueryKind::TrackedWithSalsaStruct, invoke) => { + let method = TrackedQuery { + trait_name: trait_name_ident.clone(), + generated_struct: None, + signature: signature.clone(), + pat_and_tys: pat_and_tys.clone(), + invoke, + cycle, + lru, + default: method.default.take(), + }; + + trait_methods.push(Queries::TrackedQuery(method)) + } + (QueryKind::Transparent, invoke) => { + let method = Transparent { + signature: method.sig.clone(), + pat_and_tys: pat_and_tys.clone(), + invoke, + default: method.default.take(), + }; + trait_methods.push(Queries::Transparent(method)); + } + // error/invalid constructions + (QueryKind::Interned, Some(path)) => { + return Err(syn::Error::new( + path.span(), + "Interned queries cannot be used with an `#[invoke]`".to_string(), + )); + } + (QueryKind::Input, Some(path)) => { + return Err(syn::Error::new( + path.span(), + "Inputs cannot be used with an `#[invoke]`".to_string(), + )); + } + } + } + } + + let fields = input_struct_fields + .into_iter() + .map(|input| { + let name = input.name; + let ret = input.ty; + quote! { #name: Option<#ret> } + }) + .collect::<Vec<proc_macro2::TokenStream>>(); + + let input_struct = quote! { + #[salsa::input] + pub(crate) struct #input_struct_name { + #(#fields),* + } + }; + + let field_params = std::iter::repeat_n(quote! { None }, fields.len()) + .collect::<Vec<proc_macro2::TokenStream>>(); + + let create_data_method = quote! { + #[allow(non_snake_case)] + #[salsa::tracked] + fn #create_data_ident(db: &dyn #trait_name_ident) -> #input_struct_name { + #input_struct_name::new(db, #(#field_params),*) + } + }; + + let mut setter_signatures = vec![]; + let mut setter_methods = vec![]; + for trait_item in setter_trait_methods + .iter() + .map(|method| method.to_token_stream()) + .map(|tokens| syn::parse2::<syn::TraitItemFn>(tokens).unwrap()) + { + let mut methods_sans_body = trait_item.clone(); + methods_sans_body.default = None; + methods_sans_body.semi_token = Some(syn::Token)); + + setter_signatures.push(TraitItem::Fn(methods_sans_body)); + setter_methods.push(TraitItem::Fn(trait_item)); + } + + item_trait.items.append(&mut setter_signatures); + item_trait.items.append(&mut lookup_signatures); + + let trait_impl = quote! { + #[salsa::db] + impl<DB> #trait_name_ident for DB + where + DB: #supertraits, + { + #(#trait_methods)* + + #(#setter_methods)* + + #(#lookup_methods)* + } + }; + RemoveAttrsFromTraitMethods.visit_item_trait_mut(&mut item_trait); + + let out = quote! { + #item_trait + + #trait_impl + + #input_struct + + #create_data_method + } + .into(); + + Ok(out) +} + +/// Parenthesis helper +pub(crate) struct Parenthesized<T>(pub(crate) T); + +impl<T> syn::parse::Parse for Parenthesized<T> +where + T: syn::parse::Parse, +{ + fn parse(input: ParseStream<'_>) -> syn::Result<Self> { + let content; + syn::parenthesized!(content in input); + content.parse::<T>().map(Parenthesized) + } +} + +fn make_trait_method(sig: syn::Signature) -> TraitItemFn { + TraitItemFn { + attrs: vec![], + sig: sig.clone(), + semi_token: Some(syn::Token)), + default: None, + } +} + +struct RemoveAttrsFromTraitMethods; + +impl VisitMut for RemoveAttrsFromTraitMethods { + fn visit_item_trait_mut(&mut self, i: &mut syn::ItemTrait) { + for item in &mut i.items { + if let TraitItem::Fn(trait_item_fn) = item { + trait_item_fn.attrs = vec![]; + } + } + } +} + +pub(crate) fn token_stream_with_error(mut tokens: TokenStream, error: syn::Error) -> TokenStream { + tokens.extend(TokenStream::from(error.into_compile_error())); + tokens +} + +struct SelfToDbRewriter; + +impl VisitMut for SelfToDbRewriter { + fn visit_expr_path_mut(&mut self, i: &mut syn::ExprPath) { + if i.path.is_ident("self") { + i.path = parse_quote_spanned!(i.path.span() => db); + } + } +} |