quick arrays
Diffstat (limited to 'src/lib.rs')
-rw-r--r--src/lib.rs245
1 files changed, 135 insertions, 110 deletions
diff --git a/src/lib.rs b/src/lib.rs
index 7b1d278..08fb046 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -1,125 +1,137 @@
-use std::collections::{hash_map::Entry, HashMap};
-
-use proc_macro::TokenStream;
-use quote::quote;
+use itertools::Itertools;
+use proc_macro2::TokenStream;
+use quote::{ToTokens, quote};
use syn::{
+ Error, Expr, Lit, Pat, PatConst, Stmt, Token,
parse::{self, Parse, ParseStream},
parse_macro_input,
punctuated::Punctuated,
spanned::Spanned,
- Error, Expr, Lit, Pat, Token,
};
#[derive(Clone)]
struct Index {
- indices: Vec<usize>,
+ indices: Vec<Expr>,
value: Expr,
}
-
-impl Parse for Index {
- fn parse(input: ParseStream<'_>) -> parse::Result<Index> {
- let index = Pat::parse_multi(input)?;
- match index {
- Pat::Lit(v) => match v.lit {
- Lit::Int(v) => {
- input.parse::<Token![=>]>()?;
- Ok(Index {
- indices: vec![v.base10_parse()?],
- value: input.parse()?,
- })
- }
- _ => Err(Error::new_spanned(v, "must be numeric literal"))?,
- },
- Pat::Or(v) => {
- let mut index = Vec::with_capacity(v.cases.len());
- for p in v.cases {
- match p {
- Pat::Lit(v) => match v.lit {
- Lit::Int(v) => index.push(v.base10_parse()?),
- _ => Err(Error::new_spanned(v, "must be numeric literal"))?,
- },
- _ => Err(Error::new_spanned(
- p,
- "pattern must include only literal ints",
- ))?,
- }
- }
- input.parse::<Token![=>]>()?;
- Ok(Index {
- indices: index,
- value: input.parse()?,
- })
- }
- Pat::Range(r) => {
- let s = r.span();
- let begin = match *r.start.ok_or(Error::new(s, "range must be bounded"))? {
- Expr::Lit(v) => match v.lit {
- Lit::Int(v) => v.base10_parse()?,
- _ => Err(Error::new_spanned(
- v,
- "range start bound must be integer literal",
- ))?,
- },
- e => Err(Error::new_spanned(
- e,
- "range start bound must include only literal ints",
+fn indices(index: &Pat) -> syn::Result<Vec<Expr>> {
+ match index {
+ Pat::Lit(v) => match &v.lit {
+ Lit::Int(_) => Ok(vec![v.clone().into()]),
+ _ => Err(Error::new_spanned(v, "must be numeric literal"))?,
+ },
+ Pat::Or(v) => v.cases.iter().map(indices).flatten_ok().collect(),
+ Pat::Range(r) => {
+ let s = r.span();
+ let r = r.clone();
+ let begin = match *r.start.ok_or(Error::new(s, "range must be bounded"))? {
+ Expr::Lit(v) => match v.lit {
+ Lit::Int(v) => v.base10_parse()?,
+ _ => Err(Error::new_spanned(
+ v,
+ "range start bound must be integer literal",
))?,
- };
- let end = match *r.end.ok_or(Error::new(s, "range must be bounded"))? {
- Expr::Lit(v) => match v.lit {
- Lit::Int(v) => v.base10_parse()?,
- _ => Err(Error::new_spanned(
- v,
- "range end bound must be integer literal",
- ))?,
- },
- e => Err(Error::new_spanned(
- e,
- "range end bound must include only literal ints",
+ },
+ e => Err(Error::new_spanned(
+ e,
+ "range start bound must include only literal ints",
+ ))?,
+ };
+ let end = match *r.end.ok_or(Error::new(s, "range must be bounded"))? {
+ Expr::Lit(v) => match v.lit {
+ Lit::Int(v) => v.base10_parse()?,
+ _ => Err(Error::new_spanned(
+ v,
+ "range end bound must be integer literal",
))?,
- };
- input.parse::<Token![=>]>()?;
- match r.limits {
- syn::RangeLimits::Closed(..) => Ok(Index {
- indices: (begin..=end).collect(),
- value: input.parse()?,
- }),
- syn::RangeLimits::HalfOpen(..) => Ok(Index {
- indices: (begin..end).collect(),
- value: input.parse()?,
- }),
- }
+ },
+ e => Err(Error::new_spanned(
+ e,
+ "range end bound must include only literal ints",
+ ))?,
+ };
+
+ match r.limits {
+ syn::RangeLimits::Closed(..) => Ok((begin..=end)
+ .map(|x: usize| syn::parse::<Expr>(x.to_token_stream().into()).unwrap())
+ .collect()),
+ syn::RangeLimits::HalfOpen(..) => Ok((begin..end)
+ .map(|x: usize| syn::parse::<Expr>(x.to_token_stream().into()).unwrap())
+ .collect()),
}
- _ => Err(input.error("pattern must be literal(5) | or(5 | 4) | range(4..5)"))?,
}
+ Pat::Const(PatConst { block, .. }) => {
+ Ok(vec![if let [Stmt::Expr(x, None)] = &block.stmts[..] {
+ x.clone()
+ } else {
+ Expr::Block(syn::ExprBlock {
+ attrs: vec![],
+ label: None,
+ block: block.clone(),
+ })
+ }])
+ }
+ _ => Err(Error::new(
+ index.span(),
+ "pattern must be literal(5) | or(5 | 4) | range(4..5) | const { .. }",
+ ))?,
+ }
+}
+
+impl Parse for Index {
+ fn parse(input: ParseStream<'_>) -> parse::Result<Index> {
+ let index = Pat::parse_multi(input)?;
+ let indices = indices(&index)?;
+ input.parse::<Token![=>]>()?;
+ Ok(Index {
+ indices,
+ value: input.parse()?,
+ })
}
}
-struct Map(Vec<Option<Expr>>);
+struct Map(Punctuated<Index, Token![,]>);
impl Parse for Map {
fn parse(input: ParseStream) -> syn::Result<Self> {
let parsed = Punctuated::<Index, Token![,]>::parse_terminated(input)?;
if parsed.is_empty() {
return Err(input.error("no keys"));
}
- let mut flat = HashMap::new();
- let mut largest = 0;
- for Index { value, indices } in parsed.into_iter() {
- for index in indices {
- if index > largest {
- largest = index;
- }
- match flat.entry(index) {
- Entry::Occupied(_) => Err(input.error("duplicate key"))?,
- Entry::Vacant(v) => v.insert(value.clone()),
- };
- }
- }
- let mut out = vec![None; largest + 1];
- for (index, expr) in flat.into_iter() {
- out[index] = Some(expr)
- }
- Ok(Map(out))
+ Ok(Map(parsed))
+ }
+}
+
+impl Map {
+ fn into(self, d: TokenStream, f: impl Fn(&Expr) -> TokenStream + Copy) -> TokenStream {
+ let map = self
+ .0
+ .into_iter()
+ .zip(1..)
+ .flat_map(|(Index { indices, value }, i)| {
+ indices.into_iter().map(move |x| {
+ let s = format!(
+ "duplicate / overlapping key @ pattern `{}` (#{i})",
+ x.to_token_stream()
+ .to_string()
+ .replace('{', "{{")
+ .replace('}', "}}")
+ );
+ let value = f(&value);
+ quote! {{
+ let (index, value) = { let (__ඞඞ, __set) = ((), ()); (#x, #value) };
+ assert!(!__set[index], #s);
+ __set[index] = true;
+ __ඞඞ[index] = value;
+ }}
+ })
+ });
+ quote! {{
+ let mut __ඞඞ = [#d; _];
+ const fn steal<const N:usize, T>(_: &[T; N]) -> [bool; N] { [false; N] }
+ let mut __set = steal(&__ඞඞ);
+ #(#map)*
+ __ඞඞ
+ }}
}
}
@@ -138,19 +150,32 @@ impl Parse for Map {
/// 2..=25 => Y::A,
/// 26 | 32 => Y::C,
/// 27..32 => Y::D,
-/// 45 => Y::B,
+/// 44 => Y::B,
/// };
-/// assert_eq!(X[45].as_ref().unwrap(), &Y::B);
+/// assert_eq!(X[44].as_ref().unwrap(), &Y::B);
/// ```
#[proc_macro]
-pub fn amap(input: TokenStream) -> TokenStream {
- let map = parse_macro_input!(input as Map);
- let map = map.0.iter().map(|index| match index {
- Some(v) => quote!(::core::option::Option::Some(#v)),
- None => quote!(::core::option::Option::None),
- });
- quote! {
- [#(#map), *]
- }
- .into()
+pub fn amap(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
+ parse_macro_input!(input as Map)
+ .into(quote! { const { None } }, |x| quote! { Some(#x)})
+ .into()
+}
+
+#[proc_macro]
+/// This method uses default instead of Option<T>. Nightly required for use in const.
+/// ```
+/// # use amap::amap_d;
+/// let x: [u8; 42] = amap_d! {
+/// 4 => 2,
+/// 16..25 => 4,
+/// };
+/// assert_eq!(x[17], 4);
+/// ```
+pub fn amap_d(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
+ parse_macro_input!(input as Map)
+ .into(
+ quote! { ::core::default::Default::default() },
+ |x| quote! { #x },
+ )
+ .into()
}