quick arrays
Diffstat (limited to 'src/lib.rs')
-rw-r--r--src/lib.rs143
1 files changed, 105 insertions, 38 deletions
diff --git a/src/lib.rs b/src/lib.rs
index 78ba724..da0672d 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -1,55 +1,124 @@
+use std::collections::{hash_map::Entry, HashMap};
+
use proc_macro::TokenStream;
use quote::quote;
use syn::{
parse::{self, Parse, ParseStream},
parse_macro_input,
punctuated::Punctuated,
- Error, Expr, LitInt, Token,
+ spanned::Spanned,
+ Error, Expr, Lit, Pat, Token,
};
#[derive(Clone)]
struct Index {
- index: usize,
+ indices: Vec<usize>,
value: Expr,
}
-impl std::fmt::Debug for Index {
- fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
- write!(f, "{}", self.index)
- }
-}
-
impl Parse for Index {
fn parse(input: ParseStream<'_>) -> parse::Result<Index> {
- let index = input.parse::<LitInt>()?;
- let index = index.base10_parse()?;
- input.parse::<Token![=>]>()?;
- let value = input.parse()?;
- Ok(Index { index, value })
+ let index = Pat::parse_multi(input)?;
+ match index {
+ Pat::Lit(v) => match v.lit {
+ Lit::Int(v) => {
+ input.parse::<Token![=>]>()?;
+ Ok(Index {
+ indices: vec![v.base10_parse()?],
+ value: input.parse()?,
+ })
+ }
+ _ => Err(Error::new_spanned(v, "must be numeric literal"))?,
+ },
+ Pat::Or(v) => {
+ let mut index = Vec::with_capacity(v.cases.len());
+ for p in v.cases {
+ match p {
+ Pat::Lit(v) => match v.lit {
+ Lit::Int(v) => index.push(v.base10_parse()?),
+ _ => Err(Error::new_spanned(v, "must be numeric literal"))?,
+ },
+ _ => Err(Error::new_spanned(
+ p,
+ "pattern must include only literal ints",
+ ))?,
+ }
+ }
+ input.parse::<Token![=>]>()?;
+ Ok(Index {
+ indices: index,
+ value: input.parse()?,
+ })
+ }
+ Pat::Range(r) => {
+ let s = r.span();
+ let begin = match *r.start.ok_or(Error::new(s, "range must be bounded"))? {
+ Expr::Lit(v) => match v.lit {
+ Lit::Int(v) => v.base10_parse()?,
+ _ => Err(Error::new_spanned(
+ v,
+ "range start bound must be integer literal",
+ ))?,
+ },
+ e => Err(Error::new_spanned(
+ e,
+ "range start bound must include only literal ints",
+ ))?,
+ };
+ let end = match *r.end.ok_or(Error::new(s, "range must be bounded"))? {
+ Expr::Lit(v) => match v.lit {
+ Lit::Int(v) => v.base10_parse()?,
+ _ => Err(Error::new_spanned(
+ v,
+ "range end bound must be integer literal",
+ ))?,
+ },
+ e => Err(Error::new_spanned(
+ e,
+ "range end bound must include only literal ints",
+ ))?,
+ };
+ input.parse::<Token![=>]>()?;
+ match r.limits {
+ syn::RangeLimits::Closed(..) => Ok(Index {
+ indices: (begin..=end).collect(),
+ value: input.parse()?,
+ }),
+ syn::RangeLimits::HalfOpen(..) => Ok(Index {
+ indices: (begin..end).collect(),
+ value: input.parse()?,
+ }),
+ }
+ }
+ _ => Err(input.error("pattern must be literal(5) | or(5 | 4) | range(4..5)"))?,
+ }
}
}
-struct Map(Vec<Option<Index>>);
+struct Map(Vec<Option<Expr>>);
impl Parse for Map {
fn parse(input: ParseStream) -> syn::Result<Self> {
let parsed = Punctuated::<Index, Token![,]>::parse_terminated(input)?;
- let mut all = parsed.into_iter().collect::<Vec<_>>();
- if all.len() == 0 {
+ if parsed.is_empty() {
return Err(input.error("no keys"));
}
- all.sort_unstable_by(|a, b| a.index.cmp(&b.index));
- let max = all[all.len() - 1].index;
- let mut out: Vec<Option<Index>> = vec![None; max + 1];
- for Index { value, index } in all {
- let o = out.get_mut(index).unwrap();
- match o {
- Some(_) => {
- // err.combine(Error::new_spanned(&v.value, "other duplicate key"));
- return Err(Error::new_spanned(&value, "duplicate keys"));
+ let mut flat = HashMap::new();
+ let mut largest = 0;
+ for Index { value, indices } in parsed.into_iter() {
+ for index in indices {
+ if index > largest {
+ largest = index;
}
- None => *o = Some(Index { value, index }),
+ match flat.entry(index) {
+ Entry::Occupied(_) => Err(input.error("duplicate key"))?,
+ Entry::Vacant(v) => v.insert(value.clone()),
+ };
}
}
+ let mut out = vec![None; largest + 1];
+ for (index, expr) in flat.into_iter() {
+ out[index] = Some(expr)
+ }
Ok(Map(out))
}
}
@@ -60,13 +129,15 @@ impl Parse for Map {
/// # use amap::amap;
/// #[derive(Debug, PartialEq)]
/// enum Y {
-/// A,
-/// B,
-/// C,
+/// A,
+/// B,
+/// C,
+/// D,
/// }
/// static X: [Option<Y>; 46] = amap! {
-/// 2 => Y::A,
-/// 5 => Y::C,
+/// 2..=25 => Y::A,
+/// 26 | 32 => Y::C,
+/// 27..32 => Y::D,
/// 45 => Y::B,
/// };
/// assert_eq!(X[45].as_ref().unwrap(), &Y::B);
@@ -74,13 +145,9 @@ impl Parse for Map {
#[proc_macro]
pub fn amap(input: TokenStream) -> TokenStream {
let map = parse_macro_input!(input as Map);
- let map = map.0.iter().map(|index| {
- if let Some(index) = index {
- let v = &index.value;
- quote!(Some(#v))
- } else {
- quote!(None)
- }
+ let map = map.0.iter().map(|index| match index {
+ Some(v) => quote!(Some(#v)),
+ None => quote!(None),
});
quote! {
[#(#map), *]