desugars operator overloading
Diffstat (limited to 'src/lib.rs')
-rw-r--r--src/lib.rs203
1 files changed, 143 insertions, 60 deletions
diff --git a/src/lib.rs b/src/lib.rs
index 41f1904..fbd0e7f 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -11,52 +11,107 @@ macro_rules! quote_with {
quote!($($tt)+)
}};
}
+trait Sub {
+ fn sub_bin(op: BinOp, left: TokenStream, right: TokenStream) -> TokenStream;
+ fn sub_unop(op: UnOp, x: TokenStream) -> TokenStream;
+}
+
+struct Basic;
+impl Sub for Basic {
+ fn sub_bin(op: BinOp, left: TokenStream, right: TokenStream) -> TokenStream {
+ use syn::BinOp::*;
+ match op {
+ Add(_) => quote!((#left).add(#right)),
+ Sub(_) => quote!((#left).sub(#right)),
+ Mul(_) => quote!((#left).mul(#right)),
+ Div(_) => quote!((#left).div(#right)),
+ Rem(_) => quote!((#left).rem(#right)),
+ And(_) => quote!((#left).and(#right)),
+ Or(_) => quote!((#left).or(#right)),
+ BitXor(_) => quote!((#left).bitxor(#right)),
+ BitAnd(_) => quote!((#left).bitand(#right)),
+ BitOr(_) => quote!((#left).bitor(#right)),
+ Shl(_) => quote!((#left).shl(#right)),
+ Shr(_) => quote!((#left).shr(#right)),
+ Eq(_) => quote!((#left).eq(#right)),
+ Lt(_) => quote!((#left).lt(#right)),
+ Le(_) => quote!((#left).le(#right)),
+ Ne(_) => quote!((#left).ne(#right)),
+ Ge(_) => quote!((#left).ge(#right)),
+ Gt(_) => quote!((#left).gt(#right)),
+ // don't support assigning ops
+ e => {
+ Error::new(e.span(), format!("{}", quote!(op #e not supported))).to_compile_error()
+ }
+ }
+ }
+
+ fn sub_unop(op: UnOp, x: TokenStream) -> TokenStream {
+ match op {
+ UnOp::Deref(_) => quote!((#x).deref()),
+ UnOp::Not(_) => quote!((#x).not()),
+ UnOp::Neg(_) => quote!((#x).neg()),
+ e => Error::new(
+ e.span(),
+ "it would appear a new operation has been added! please tell me.",
+ )
+ .to_compile_error(),
+ }
+ }
+}
+
+struct Algebraic;
+impl Sub for Algebraic {
+ fn sub_bin(op: BinOp, left: TokenStream, right: TokenStream) -> TokenStream {
+ use syn::BinOp::*;
+ match op {
+ Add(_) => quote!(core::intrinsics::fadd_algebraic(#left, #right)),
+ Sub(_) => quote!(core::intrinsics::fsub_algebraic(#left, #right)),
+ Mul(_) => quote!(core::intrinsics::fmul_algebraic(#left, #right)),
+ Div(_) => quote!(core::intrinsics::fdiv_algebraic(#left, #right)),
+ Rem(_) => quote!(core::intrinsics::frem_algebraic(#left, #right)),
+ And(_) => quote!(core::intrinsics::fand_algebraic(#left, #right)),
+ _ => quote!((#left) #op (#right)),
+ }
+ }
+
+ fn sub_unop(op: UnOp, x: TokenStream) -> TokenStream {
+ quote!(#op #x)
+ }
+}
-fn walk(e: Expr) -> TokenStream {
+struct Fast;
+impl Sub for Fast {
+ fn sub_bin(op: BinOp, left: TokenStream, right: TokenStream) -> TokenStream {
+ use syn::BinOp::*;
+ match op {
+ Add(_) => quote!(core::intrinsics::fadd_fast(#left, #right)),
+ Sub(_) => quote!(core::intrinsics::fsub_fast(#left, #right)),
+ Mul(_) => quote!(core::intrinsics::fmul_fast(#left, #right)),
+ Div(_) => quote!(core::intrinsics::fdiv_fast(#left, #right)),
+ Rem(_) => quote!(core::intrinsics::frem_fast(#left, #right)),
+ And(_) => quote!(core::intrinsics::fand_fast(#left, #right)),
+ Eq(_) => quote!(/* eq */ ((#left) + 0.0).to_bits() == ((#right) + 0.0).to_bits()),
+ _ => quote!((#left) #op (#right)),
+ }
+ }
+
+ fn sub_unop(op: UnOp, x: TokenStream) -> TokenStream {
+ quote!(#op #x)
+ }
+}
+
+fn walk<T: Sub>(e: Expr) -> TokenStream {
+ let walk = walk::<T>;
match e {
Expr::Binary(ExprBinary {
left, op, right, ..
}) => {
let left = walk(*left);
let right = walk(*right);
- use syn::BinOp::*;
- match op {
- Add(_) => quote!((#left).add(#right)),
- Sub(_) => quote!((#left).sub(#right)),
- Mul(_) => quote!((#left).mul(#right)),
- Div(_) => quote!((#left).div(#right)),
- Rem(_) => quote!((#left).rem(#right)),
- And(_) => quote!((#left).and(#right)),
- Or(_) => quote!((#left).or(#right)),
- BitXor(_) => quote!((#left).bitxor(#right)),
- BitAnd(_) => quote!((#left).bitand(#right)),
- BitOr(_) => quote!((#left).bitor(#right)),
- Shl(_) => quote!((#left).shl(#right)),
- Shr(_) => quote!((#left).shr(#right)),
- Eq(_) => quote!((#left).eq(#right)),
- Lt(_) => quote!((#left).lt(#right)),
- Le(_) => quote!((#left).le(#right)),
- Ne(_) => quote!((#left).ne(#right)),
- Ge(_) => quote!((#left).ge(#right)),
- Gt(_) => quote!((#left).gt(#right)),
- // don't support assigning ops
- e => Error::new(e.span(), format!("{}", quote!(op #e not supported)))
- .to_compile_error(),
- }
- }
- Expr::Unary(ExprUnary { op, expr, .. }) => {
- let x = walk(*expr);
- match op {
- UnOp::Deref(_) => quote!((#x).deref()),
- UnOp::Not(_) => quote!((#x).not()),
- UnOp::Neg(_) => quote!((#x).neg()),
- e => Error::new(
- e.span(),
- "it would appear a new operation has been added! please tell me.",
- )
- .to_compile_error(),
- }
+ T::sub_bin(op, left, right)
}
+ Expr::Unary(ExprUnary { op, expr, .. }) => T::sub_unop(op, walk(*expr)),
Expr::Break(ExprBreak {
label,
expr: Some(expr),
@@ -91,14 +146,14 @@ fn walk(e: Expr) -> TokenStream {
body,
..
}) => {
- let (expr, body) = (walk(*expr), map_block(body));
+ let (expr, body) = (walk(*expr), map_block::<T>(body));
quote!(#label for #pat in #expr #body)
}
Expr::Let(ExprLet { pat, expr, .. }) => {
quote_with!(expr = walk(*expr) => let #pat = #expr)
}
Expr::Const(ExprConst { block, .. }) => {
- quote_with!(block = map_block(block) => const #block)
+ quote_with!(block = map_block::<T>(block) => const #block)
}
Expr::Range(ExprRange {
start, limits, end, ..
@@ -115,19 +170,19 @@ fn walk(e: Expr) -> TokenStream {
quote!(#expr ?)
}
Expr::TryBlock(ExprTryBlock { block, .. }) => {
- let block = map_block(block);
+ let block = map_block::<T>(block);
quote!(try #block)
}
Expr::Unsafe(ExprUnsafe { block, .. }) => {
- quote_with!(block = map_block(block) => unsafe #block)
+ quote_with!(block = map_block::<T>(block) => unsafe #block)
}
Expr::While(ExprWhile {
label, cond, body, ..
}) => {
- quote_with!(cond = walk(*cond); body = map_block(body) => #label while #cond #body)
+ quote_with!(cond = walk(*cond); body = map_block::<T>(body) => #label while #cond #body)
}
Expr::Loop(ExprLoop { label, body, .. }) => {
- quote_with!(body = map_block(body) => #label loop #body)
+ quote_with!(body = map_block::<T>(body) => #label loop #body)
}
Expr::Reference(ExprReference {
mutability, expr, ..
@@ -153,13 +208,13 @@ fn walk(e: Expr) -> TokenStream {
..
}) => {
let (cond, then_branch, else_branch) =
- (walk(*cond), map_block(then_branch), walk(*else_branch));
+ (walk(*cond), map_block::<T>(then_branch), walk(*else_branch));
quote!(if #cond #then_branch else #else_branch)
}
Expr::If(ExprIf {
cond, then_branch, ..
}) => {
- let (cond, then_branch) = (walk(*cond), map_block(then_branch));
+ let (cond, then_branch) = (walk(*cond), map_block::<T>(then_branch));
quote!(if #cond #then_branch)
}
Expr::Async(ExprAsync {
@@ -168,7 +223,7 @@ fn walk(e: Expr) -> TokenStream {
block,
..
}) => {
- let block = map_block(block);
+ let block = map_block::<T>(block);
quote!(#(#attrs)* async #capture #block)
}
Expr::Await(ExprAwait { base, .. }) => {
@@ -201,20 +256,21 @@ fn walk(e: Expr) -> TokenStream {
label: Some(label),
..
}) => {
- let b = map_block(block);
+ let b = map_block::<T>(block);
quote! { #label: #b }
}
- Expr::Block(ExprBlock { block, .. }) => map_block(block),
+ Expr::Block(ExprBlock { block, .. }) => map_block::<T>(block),
e => quote!(#e),
}
}
-fn map_block(Block { stmts, .. }: Block) -> TokenStream {
- let stmts = stmts.into_iter().map(walk_stmt);
+fn map_block<T: Sub>(Block { stmts, .. }: Block) -> TokenStream {
+ let stmts = stmts.into_iter().map(walk_stmt::<T>);
quote! { { #(#stmts)* } }
}
-fn walk_stmt(x: Stmt) -> TokenStream {
+fn walk_stmt<T: Sub>(x: Stmt) -> TokenStream {
+ let walk = walk::<T>;
match x {
Stmt::Local(Local {
pat,
@@ -238,7 +294,7 @@ fn walk_stmt(x: Stmt) -> TokenStream {
let expr = walk(*expr);
quote!(let #pat = #expr;)
}
- Stmt::Item(x) => walk_item(x),
+ Stmt::Item(x) => walk_item::<T>(x),
Stmt::Expr(e, t) => {
let e = walk(e);
quote!(#e #t)
@@ -247,7 +303,8 @@ fn walk_stmt(x: Stmt) -> TokenStream {
}
}
-fn walk_item(x: Item) -> TokenStream {
+fn walk_item<T: Sub>(x: Item) -> TokenStream {
+ let walk = walk::<T>;
match x {
Item::Const(ItemConst {
vis,
@@ -265,7 +322,7 @@ fn walk_item(x: Item) -> TokenStream {
sig,
block,
}) => {
- let block = map_block(*block);
+ let block = map_block::<T>(*block);
quote!( #(#attrs)* #vis #sig #block)
}
Item::Impl(ItemImpl {
@@ -298,7 +355,7 @@ fn walk_item(x: Item) -> TokenStream {
sig,
block,
}) => {
- let block = map_block(block);
+ let block = map_block::<T>(block);
quote!(#(#attrs)* #vis #defaultness #sig #block)
}
e => quote!(#e),
@@ -313,7 +370,7 @@ fn walk_item(x: Item) -> TokenStream {
content: Some((_, content)),
..
}) => {
- let content = content.into_iter().map(walk_item);
+ let content = content.into_iter().map(walk_item::<T>);
quote!(#(#attrs)* #vis mod #ident { #(#content)* })
}
Item::Static(ItemStatic {
@@ -336,18 +393,44 @@ fn walk_item(x: Item) -> TokenStream {
/// ```
/// # use std::ops::*;
/// let [a, b, c] = [5i32, 6, 7];
-/// assert_eq!(lower::math! { a * *&b + -c }, a * *&b + -c);
+/// assert_eq!(lower_macros::math! { a * *&b + -c }, a * *&b + -c);
/// // expands to
/// // a.mul((&b).deref()).add(c.neg())
/// ```
#[proc_macro]
pub fn math(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
match parse::<Expr>(input.clone())
- .map(walk)
+ .map(walk::<Basic>)
+ .map_err(|x| x.to_compile_error().into_token_stream())
+ {
+ Ok(x) => x,
+ Err(e) => parse::<Stmt>(input).map(walk_stmt::<Basic>).unwrap_or(e),
+ }
+ .into()
+}
+
+#[proc_macro]
+pub fn algebraic(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
+ match parse::<Expr>(input.clone())
+ .map(walk::<Algebraic>)
+ .map_err(|x| x.to_compile_error().into_token_stream())
+ {
+ Ok(x) => x,
+ Err(e) => parse::<Stmt>(input)
+ .map(walk_stmt::<Algebraic>)
+ .unwrap_or(e),
+ }
+ .into()
+}
+
+#[proc_macro]
+pub fn fast(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
+ match parse::<Expr>(input.clone())
+ .map(walk::<Fast>)
.map_err(|x| x.to_compile_error().into_token_stream())
{
Ok(x) => x,
- Err(e) => parse::<Stmt>(input).map(walk_stmt).unwrap_or(e),
+ Err(e) => parse::<Stmt>(input).map(walk_stmt::<Fast>).unwrap_or(e),
}
.into()
}