quick arrays
Diffstat (limited to 'src/lib.rs')
| -rw-r--r-- | src/lib.rs | 143 |
1 files changed, 105 insertions, 38 deletions
@@ -1,55 +1,124 @@ +use std::collections::{hash_map::Entry, HashMap}; + use proc_macro::TokenStream; use quote::quote; use syn::{ parse::{self, Parse, ParseStream}, parse_macro_input, punctuated::Punctuated, - Error, Expr, LitInt, Token, + spanned::Spanned, + Error, Expr, Lit, Pat, Token, }; #[derive(Clone)] struct Index { - index: usize, + indices: Vec<usize>, value: Expr, } -impl std::fmt::Debug for Index { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", self.index) - } -} - impl Parse for Index { fn parse(input: ParseStream<'_>) -> parse::Result<Index> { - let index = input.parse::<LitInt>()?; - let index = index.base10_parse()?; - input.parse::<Token![=>]>()?; - let value = input.parse()?; - Ok(Index { index, value }) + let index = Pat::parse_multi(input)?; + match index { + Pat::Lit(v) => match v.lit { + Lit::Int(v) => { + input.parse::<Token![=>]>()?; + Ok(Index { + indices: vec![v.base10_parse()?], + value: input.parse()?, + }) + } + _ => Err(Error::new_spanned(v, "must be numeric literal"))?, + }, + Pat::Or(v) => { + let mut index = Vec::with_capacity(v.cases.len()); + for p in v.cases { + match p { + Pat::Lit(v) => match v.lit { + Lit::Int(v) => index.push(v.base10_parse()?), + _ => Err(Error::new_spanned(v, "must be numeric literal"))?, + }, + _ => Err(Error::new_spanned( + p, + "pattern must include only literal ints", + ))?, + } + } + input.parse::<Token![=>]>()?; + Ok(Index { + indices: index, + value: input.parse()?, + }) + } + Pat::Range(r) => { + let s = r.span(); + let begin = match *r.start.ok_or(Error::new(s, "range must be bounded"))? { + Expr::Lit(v) => match v.lit { + Lit::Int(v) => v.base10_parse()?, + _ => Err(Error::new_spanned( + v, + "range start bound must be integer literal", + ))?, + }, + e => Err(Error::new_spanned( + e, + "range start bound must include only literal ints", + ))?, + }; + let end = match *r.end.ok_or(Error::new(s, "range must be bounded"))? { + Expr::Lit(v) => match v.lit { + Lit::Int(v) => v.base10_parse()?, + _ => Err(Error::new_spanned( + v, + "range end bound must be integer literal", + ))?, + }, + e => Err(Error::new_spanned( + e, + "range end bound must include only literal ints", + ))?, + }; + input.parse::<Token![=>]>()?; + match r.limits { + syn::RangeLimits::Closed(..) => Ok(Index { + indices: (begin..=end).collect(), + value: input.parse()?, + }), + syn::RangeLimits::HalfOpen(..) => Ok(Index { + indices: (begin..end).collect(), + value: input.parse()?, + }), + } + } + _ => Err(input.error("pattern must be literal(5) | or(5 | 4) | range(4..5)"))?, + } } } -struct Map(Vec<Option<Index>>); +struct Map(Vec<Option<Expr>>); impl Parse for Map { fn parse(input: ParseStream) -> syn::Result<Self> { let parsed = Punctuated::<Index, Token![,]>::parse_terminated(input)?; - let mut all = parsed.into_iter().collect::<Vec<_>>(); - if all.len() == 0 { + if parsed.is_empty() { return Err(input.error("no keys")); } - all.sort_unstable_by(|a, b| a.index.cmp(&b.index)); - let max = all[all.len() - 1].index; - let mut out: Vec<Option<Index>> = vec![None; max + 1]; - for Index { value, index } in all { - let o = out.get_mut(index).unwrap(); - match o { - Some(_) => { - // err.combine(Error::new_spanned(&v.value, "other duplicate key")); - return Err(Error::new_spanned(&value, "duplicate keys")); + let mut flat = HashMap::new(); + let mut largest = 0; + for Index { value, indices } in parsed.into_iter() { + for index in indices { + if index > largest { + largest = index; } - None => *o = Some(Index { value, index }), + match flat.entry(index) { + Entry::Occupied(_) => Err(input.error("duplicate key"))?, + Entry::Vacant(v) => v.insert(value.clone()), + }; } } + let mut out = vec![None; largest + 1]; + for (index, expr) in flat.into_iter() { + out[index] = Some(expr) + } Ok(Map(out)) } } @@ -60,13 +129,15 @@ impl Parse for Map { /// # use amap::amap; /// #[derive(Debug, PartialEq)] /// enum Y { -/// A, -/// B, -/// C, +/// A, +/// B, +/// C, +/// D, /// } /// static X: [Option<Y>; 46] = amap! { -/// 2 => Y::A, -/// 5 => Y::C, +/// 2..=25 => Y::A, +/// 26 | 32 => Y::C, +/// 27..32 => Y::D, /// 45 => Y::B, /// }; /// assert_eq!(X[45].as_ref().unwrap(), &Y::B); @@ -74,13 +145,9 @@ impl Parse for Map { #[proc_macro] pub fn amap(input: TokenStream) -> TokenStream { let map = parse_macro_input!(input as Map); - let map = map.0.iter().map(|index| { - if let Some(index) = index { - let v = &index.value; - quote!(Some(#v)) - } else { - quote!(None) - } + let map = map.0.iter().map(|index| match index { + Some(v) => quote!(Some(#v)), + None => quote!(None), }); quote! { [#(#map), *] |