desugars operator overloading
Diffstat (limited to 'src/lib.rs')
| -rw-r--r-- | src/lib.rs | 203 |
1 files changed, 143 insertions, 60 deletions
@@ -11,52 +11,107 @@ macro_rules! quote_with { quote!($($tt)+) }}; } +trait Sub { + fn sub_bin(op: BinOp, left: TokenStream, right: TokenStream) -> TokenStream; + fn sub_unop(op: UnOp, x: TokenStream) -> TokenStream; +} + +struct Basic; +impl Sub for Basic { + fn sub_bin(op: BinOp, left: TokenStream, right: TokenStream) -> TokenStream { + use syn::BinOp::*; + match op { + Add(_) => quote!((#left).add(#right)), + Sub(_) => quote!((#left).sub(#right)), + Mul(_) => quote!((#left).mul(#right)), + Div(_) => quote!((#left).div(#right)), + Rem(_) => quote!((#left).rem(#right)), + And(_) => quote!((#left).and(#right)), + Or(_) => quote!((#left).or(#right)), + BitXor(_) => quote!((#left).bitxor(#right)), + BitAnd(_) => quote!((#left).bitand(#right)), + BitOr(_) => quote!((#left).bitor(#right)), + Shl(_) => quote!((#left).shl(#right)), + Shr(_) => quote!((#left).shr(#right)), + Eq(_) => quote!((#left).eq(#right)), + Lt(_) => quote!((#left).lt(#right)), + Le(_) => quote!((#left).le(#right)), + Ne(_) => quote!((#left).ne(#right)), + Ge(_) => quote!((#left).ge(#right)), + Gt(_) => quote!((#left).gt(#right)), + // don't support assigning ops + e => { + Error::new(e.span(), format!("{}", quote!(op #e not supported))).to_compile_error() + } + } + } + + fn sub_unop(op: UnOp, x: TokenStream) -> TokenStream { + match op { + UnOp::Deref(_) => quote!((#x).deref()), + UnOp::Not(_) => quote!((#x).not()), + UnOp::Neg(_) => quote!((#x).neg()), + e => Error::new( + e.span(), + "it would appear a new operation has been added! please tell me.", + ) + .to_compile_error(), + } + } +} + +struct Algebraic; +impl Sub for Algebraic { + fn sub_bin(op: BinOp, left: TokenStream, right: TokenStream) -> TokenStream { + use syn::BinOp::*; + match op { + Add(_) => quote!(core::intrinsics::fadd_algebraic(#left, #right)), + Sub(_) => quote!(core::intrinsics::fsub_algebraic(#left, #right)), + Mul(_) => quote!(core::intrinsics::fmul_algebraic(#left, #right)), + Div(_) => quote!(core::intrinsics::fdiv_algebraic(#left, #right)), + Rem(_) => quote!(core::intrinsics::frem_algebraic(#left, #right)), + And(_) => quote!(core::intrinsics::fand_algebraic(#left, #right)), + _ => quote!((#left) #op (#right)), + } + } + + fn sub_unop(op: UnOp, x: TokenStream) -> TokenStream { + quote!(#op #x) + } +} -fn walk(e: Expr) -> TokenStream { +struct Fast; +impl Sub for Fast { + fn sub_bin(op: BinOp, left: TokenStream, right: TokenStream) -> TokenStream { + use syn::BinOp::*; + match op { + Add(_) => quote!(core::intrinsics::fadd_fast(#left, #right)), + Sub(_) => quote!(core::intrinsics::fsub_fast(#left, #right)), + Mul(_) => quote!(core::intrinsics::fmul_fast(#left, #right)), + Div(_) => quote!(core::intrinsics::fdiv_fast(#left, #right)), + Rem(_) => quote!(core::intrinsics::frem_fast(#left, #right)), + And(_) => quote!(core::intrinsics::fand_fast(#left, #right)), + Eq(_) => quote!(/* eq */ ((#left) + 0.0).to_bits() == ((#right) + 0.0).to_bits()), + _ => quote!((#left) #op (#right)), + } + } + + fn sub_unop(op: UnOp, x: TokenStream) -> TokenStream { + quote!(#op #x) + } +} + +fn walk<T: Sub>(e: Expr) -> TokenStream { + let walk = walk::<T>; match e { Expr::Binary(ExprBinary { left, op, right, .. }) => { let left = walk(*left); let right = walk(*right); - use syn::BinOp::*; - match op { - Add(_) => quote!((#left).add(#right)), - Sub(_) => quote!((#left).sub(#right)), - Mul(_) => quote!((#left).mul(#right)), - Div(_) => quote!((#left).div(#right)), - Rem(_) => quote!((#left).rem(#right)), - And(_) => quote!((#left).and(#right)), - Or(_) => quote!((#left).or(#right)), - BitXor(_) => quote!((#left).bitxor(#right)), - BitAnd(_) => quote!((#left).bitand(#right)), - BitOr(_) => quote!((#left).bitor(#right)), - Shl(_) => quote!((#left).shl(#right)), - Shr(_) => quote!((#left).shr(#right)), - Eq(_) => quote!((#left).eq(#right)), - Lt(_) => quote!((#left).lt(#right)), - Le(_) => quote!((#left).le(#right)), - Ne(_) => quote!((#left).ne(#right)), - Ge(_) => quote!((#left).ge(#right)), - Gt(_) => quote!((#left).gt(#right)), - // don't support assigning ops - e => Error::new(e.span(), format!("{}", quote!(op #e not supported))) - .to_compile_error(), - } - } - Expr::Unary(ExprUnary { op, expr, .. }) => { - let x = walk(*expr); - match op { - UnOp::Deref(_) => quote!((#x).deref()), - UnOp::Not(_) => quote!((#x).not()), - UnOp::Neg(_) => quote!((#x).neg()), - e => Error::new( - e.span(), - "it would appear a new operation has been added! please tell me.", - ) - .to_compile_error(), - } + T::sub_bin(op, left, right) } + Expr::Unary(ExprUnary { op, expr, .. }) => T::sub_unop(op, walk(*expr)), Expr::Break(ExprBreak { label, expr: Some(expr), @@ -91,14 +146,14 @@ fn walk(e: Expr) -> TokenStream { body, .. }) => { - let (expr, body) = (walk(*expr), map_block(body)); + let (expr, body) = (walk(*expr), map_block::<T>(body)); quote!(#label for #pat in #expr #body) } Expr::Let(ExprLet { pat, expr, .. }) => { quote_with!(expr = walk(*expr) => let #pat = #expr) } Expr::Const(ExprConst { block, .. }) => { - quote_with!(block = map_block(block) => const #block) + quote_with!(block = map_block::<T>(block) => const #block) } Expr::Range(ExprRange { start, limits, end, .. @@ -115,19 +170,19 @@ fn walk(e: Expr) -> TokenStream { quote!(#expr ?) } Expr::TryBlock(ExprTryBlock { block, .. }) => { - let block = map_block(block); + let block = map_block::<T>(block); quote!(try #block) } Expr::Unsafe(ExprUnsafe { block, .. }) => { - quote_with!(block = map_block(block) => unsafe #block) + quote_with!(block = map_block::<T>(block) => unsafe #block) } Expr::While(ExprWhile { label, cond, body, .. }) => { - quote_with!(cond = walk(*cond); body = map_block(body) => #label while #cond #body) + quote_with!(cond = walk(*cond); body = map_block::<T>(body) => #label while #cond #body) } Expr::Loop(ExprLoop { label, body, .. }) => { - quote_with!(body = map_block(body) => #label loop #body) + quote_with!(body = map_block::<T>(body) => #label loop #body) } Expr::Reference(ExprReference { mutability, expr, .. @@ -153,13 +208,13 @@ fn walk(e: Expr) -> TokenStream { .. }) => { let (cond, then_branch, else_branch) = - (walk(*cond), map_block(then_branch), walk(*else_branch)); + (walk(*cond), map_block::<T>(then_branch), walk(*else_branch)); quote!(if #cond #then_branch else #else_branch) } Expr::If(ExprIf { cond, then_branch, .. }) => { - let (cond, then_branch) = (walk(*cond), map_block(then_branch)); + let (cond, then_branch) = (walk(*cond), map_block::<T>(then_branch)); quote!(if #cond #then_branch) } Expr::Async(ExprAsync { @@ -168,7 +223,7 @@ fn walk(e: Expr) -> TokenStream { block, .. }) => { - let block = map_block(block); + let block = map_block::<T>(block); quote!(#(#attrs)* async #capture #block) } Expr::Await(ExprAwait { base, .. }) => { @@ -201,20 +256,21 @@ fn walk(e: Expr) -> TokenStream { label: Some(label), .. }) => { - let b = map_block(block); + let b = map_block::<T>(block); quote! { #label: #b } } - Expr::Block(ExprBlock { block, .. }) => map_block(block), + Expr::Block(ExprBlock { block, .. }) => map_block::<T>(block), e => quote!(#e), } } -fn map_block(Block { stmts, .. }: Block) -> TokenStream { - let stmts = stmts.into_iter().map(walk_stmt); +fn map_block<T: Sub>(Block { stmts, .. }: Block) -> TokenStream { + let stmts = stmts.into_iter().map(walk_stmt::<T>); quote! { { #(#stmts)* } } } -fn walk_stmt(x: Stmt) -> TokenStream { +fn walk_stmt<T: Sub>(x: Stmt) -> TokenStream { + let walk = walk::<T>; match x { Stmt::Local(Local { pat, @@ -238,7 +294,7 @@ fn walk_stmt(x: Stmt) -> TokenStream { let expr = walk(*expr); quote!(let #pat = #expr;) } - Stmt::Item(x) => walk_item(x), + Stmt::Item(x) => walk_item::<T>(x), Stmt::Expr(e, t) => { let e = walk(e); quote!(#e #t) @@ -247,7 +303,8 @@ fn walk_stmt(x: Stmt) -> TokenStream { } } -fn walk_item(x: Item) -> TokenStream { +fn walk_item<T: Sub>(x: Item) -> TokenStream { + let walk = walk::<T>; match x { Item::Const(ItemConst { vis, @@ -265,7 +322,7 @@ fn walk_item(x: Item) -> TokenStream { sig, block, }) => { - let block = map_block(*block); + let block = map_block::<T>(*block); quote!( #(#attrs)* #vis #sig #block) } Item::Impl(ItemImpl { @@ -298,7 +355,7 @@ fn walk_item(x: Item) -> TokenStream { sig, block, }) => { - let block = map_block(block); + let block = map_block::<T>(block); quote!(#(#attrs)* #vis #defaultness #sig #block) } e => quote!(#e), @@ -313,7 +370,7 @@ fn walk_item(x: Item) -> TokenStream { content: Some((_, content)), .. }) => { - let content = content.into_iter().map(walk_item); + let content = content.into_iter().map(walk_item::<T>); quote!(#(#attrs)* #vis mod #ident { #(#content)* }) } Item::Static(ItemStatic { @@ -336,18 +393,44 @@ fn walk_item(x: Item) -> TokenStream { /// ``` /// # use std::ops::*; /// let [a, b, c] = [5i32, 6, 7]; -/// assert_eq!(lower::math! { a * *&b + -c }, a * *&b + -c); +/// assert_eq!(lower_macros::math! { a * *&b + -c }, a * *&b + -c); /// // expands to /// // a.mul((&b).deref()).add(c.neg()) /// ``` #[proc_macro] pub fn math(input: proc_macro::TokenStream) -> proc_macro::TokenStream { match parse::<Expr>(input.clone()) - .map(walk) + .map(walk::<Basic>) + .map_err(|x| x.to_compile_error().into_token_stream()) + { + Ok(x) => x, + Err(e) => parse::<Stmt>(input).map(walk_stmt::<Basic>).unwrap_or(e), + } + .into() +} + +#[proc_macro] +pub fn algebraic(input: proc_macro::TokenStream) -> proc_macro::TokenStream { + match parse::<Expr>(input.clone()) + .map(walk::<Algebraic>) + .map_err(|x| x.to_compile_error().into_token_stream()) + { + Ok(x) => x, + Err(e) => parse::<Stmt>(input) + .map(walk_stmt::<Algebraic>) + .unwrap_or(e), + } + .into() +} + +#[proc_macro] +pub fn fast(input: proc_macro::TokenStream) -> proc_macro::TokenStream { + match parse::<Expr>(input.clone()) + .map(walk::<Fast>) .map_err(|x| x.to_compile_error().into_token_stream()) { Ok(x) => x, - Err(e) => parse::<Stmt>(input).map(walk_stmt).unwrap_or(e), + Err(e) => parse::<Stmt>(input).map(walk_stmt::<Fast>).unwrap_or(e), } .into() } |