use itertools::Itertools; use proc_macro2::TokenStream; use quote::{ToTokens, quote}; use syn::{ Error, Expr, Lit, Pat, PatConst, PatWild, Stmt, Token, parse::{self, Parse, ParseStream}, parse_macro_input, punctuated::Punctuated, spanned::Spanned, }; #[derive(Clone)] struct Index { indices: Indices, value: Expr, } #[derive(Clone)] enum Indices { Normal(Vec), Wild(PatWild), } fn indices(index: &Pat) -> syn::Result { match index { Pat::Lit(v) => match &v.lit { Lit::Int(_) => Ok(vec![v.clone().into()]), _ => Err(Error::new_spanned(v, "must be numeric literal"))?, }, Pat::Or(v) => v .cases .iter() .map(indices) .map(|x| { x.and_then(|x| match x { Indices::Normal(x) => Ok(x), Indices::Wild(x) => Err(Error::new_spanned(x, "cant have _ in |")), }) }) .flatten_ok() .collect(), Pat::Range(r) => { let s = r.span(); let r = r.clone(); let begin = match *r.start.ok_or(Error::new(s, "range must be bounded"))? { Expr::Lit(v) => match v.lit { Lit::Int(v) => v.base10_parse()?, _ => Err(Error::new_spanned( v, "range start bound must be integer literal", ))?, }, e => Err(Error::new_spanned( e, "range start bound must include only literal ints", ))?, }; let end = match *r.end.ok_or(Error::new(s, "range must be bounded"))? { Expr::Lit(v) => match v.lit { Lit::Int(v) => v.base10_parse()?, _ => Err(Error::new_spanned( v, "range end bound must be integer literal", ))?, }, e => Err(Error::new_spanned( e, "range end bound must include only literal ints", ))?, }; match r.limits { syn::RangeLimits::Closed(..) => Ok((begin..=end) .map(|x: usize| syn::parse::(x.to_token_stream().into()).unwrap()) .collect()), syn::RangeLimits::HalfOpen(..) => Ok((begin..end) .map(|x: usize| syn::parse::(x.to_token_stream().into()).unwrap()) .collect()), } } Pat::Const(PatConst { block, .. }) => { Ok(vec![if let [Stmt::Expr(x, None)] = &block.stmts[..] { x.clone() } else { Expr::Block(syn::ExprBlock { attrs: vec![], label: None, block: block.clone(), }) }]) } Pat::Wild(x) => return Ok(Indices::Wild(x.clone())), _ => Err(Error::new( index.span(), "pattern must be literal(5) | or(5 | 4) | range(4..5) | const { .. } | _", ))?, } .map(Indices::Normal) } impl Parse for Index { fn parse(input: ParseStream<'_>) -> parse::Result { let index = Pat::parse_multi(input)?; let indices = indices(&index)?; input.parse::]>()?; Ok(Index { indices, value: input.parse()?, }) } } struct Map(Punctuated); impl Parse for Map { fn parse(input: ParseStream) -> syn::Result { let parsed = Punctuated::::parse_terminated(input)?; if parsed.is_empty() { return Err(input.error("no keys")); } Ok(Map(parsed)) } } impl Map { fn into(self, d: TokenStream, f: impl Fn(&Expr) -> TokenStream + Copy) -> TokenStream { let wild = self.0.iter().find_map(|x| match x.indices { Indices::Normal(_) => None, Indices::Wild(_) => Some(x.value.to_token_stream()), }); let w = wild.is_some(); let map = self .0 .into_iter() .zip(1..) .flat_map(|(Index { indices, value }, i)| { match indices { Indices::Normal(x) => x, _ => vec![], } .into_iter() .map({ move |x| { let s = format!( "duplicate / overlapping key @ pattern `{}` (#{i})", x.to_token_stream() .to_string() .replace('{', "{{") .replace('}', "}}") ); let value = if w { value.to_token_stream() } else { f(&value) }; quote! {{ let (index, value) = { let (__ඞඞ, __set) = ((), ()); (#x, #value) }; assert!(!__set[index], #s); __set[index] = true; __ඞඞ[index] = value; }} } }) }); let d = wild.unwrap_or(d); quote! {{ let mut __ඞඞ = [#d; _]; const fn steal(_: &[T; N]) -> [bool; N] { [false; N] } let mut __set = steal(&__ඞඞ); #(#map)* __ඞඞ }} } } /// Easily make a `[Option; N]` /// /// ``` /// # use amap::amap; /// #[derive(Debug, PartialEq)] /// enum Y { /// A, /// B, /// C, /// D, /// } /// static X: [Option; 46] = amap! { /// 2..=25 => Y::A, /// 26 | 32 => Y::C, /// 27..32 => Y::D, /// 44 => Y::B, /// }; /// assert_eq!(X[44].as_ref().unwrap(), &Y::B); /// ``` /// /// Produces a `[T; N]` if a `_` branch is included. #[proc_macro] pub fn amap(input: proc_macro::TokenStream) -> proc_macro::TokenStream { parse_macro_input!(input as Map) .into(quote! { const { None } }, |x| quote! { Some(#x)}) .into() } #[proc_macro] /// This method uses default instead of Option. Nightly required for use in const. /// ``` /// # use amap::amap_d; /// let x: [u8; 42] = amap_d! { /// 4 => 2, /// 16..25 => 4, /// }; /// assert_eq!(x[17], 4); /// ``` pub fn amap_d(input: proc_macro::TokenStream) -> proc_macro::TokenStream { parse_macro_input!(input as Map) .into( quote! { ::core::default::Default::default() }, |x| quote! { #x }, ) .into() }