heh
Diffstat (limited to 'src/util.rs')
-rw-r--r-- src/util.rs 55
1 files changed, 52 insertions, 3 deletions
diff --git a/src/util.rs b/src/util.rs
index e0beaa6..e697233 100644
--- a/src/util.rs
+++ b/src/util.rs
@@ -9,7 +9,7 @@ pub mod prelude {
pub use super::{
even, gcd, lcm, pa, GreekTools, IntoCombinations, IntoLines, IterͶ, NumTupleIterTools,
ParseIter, Printable, Skip, TakeLine, TupleIterTools, TupleUtils, UnifiedTupleUtils, Widen,
- Ͷ, Α, Κ, Λ, Μ,
+ 読む, Ͷ, Α, Κ, Λ, Μ,
};
pub use itertools::izip;
pub use itertools::Itertools;
@@ -24,7 +24,7 @@ pub mod prelude {
ops::Range,
};
#[allow(unused_imports)]
- pub(crate) use {super::bits, super::dang, super::leek, super::mat, super::読む};
+ pub(crate) use {super::bits, super::dang, super::leek, super::mat};
}
macro_rules! dang {
@@ -48,6 +48,50 @@ macro_rules! mat {
}
pub(crate) use mat;
/// Counts the bytes of `hay` equal to `needle` using 256-bit AVX2 compares.
///
/// Full 32-byte vectors are compared one at a time and their hits accumulated
/// in per-byte counters. Those counters are only 8 bits wide, so they are
/// flushed into 64-bit totals at least every 255 vectors to avoid wrapping.
/// A trailing partial vector is handled by re-reading the last 32 bytes of the
/// array and masking off the bytes that were already counted.
///
/// Arrays shorter than one vector (`N < 32`) are counted with a scalar loop,
/// because the overlapping tail load would otherwise start before the array.
///
/// # Safety
///
/// The executing CPU must support AVX2. The function is only compiled when the
/// `avx2` target feature is statically enabled, so this holds for any binary
/// run on hardware matching its build target.
#[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
unsafe fn count_avx<const N: usize>(hay: &[u8; N], needle: u8) -> usize {
    use std::arch::x86_64::*;

    /// Bytes per 256-bit vector.
    const LANES: usize = 32;
    /// Most vectors that can be accumulated before an 8-bit counter may wrap.
    const MAX_BATCH: usize = 255;
    /// 32 zero bytes followed by 32 `0xFF` bytes. Loading 32 bytes at offset
    /// `r` yields a mask that keeps only the last `r` bytes of a vector.
    static TAIL_MASK: [u8; 64] = {
        let mut mask = [0u8; 64];
        let mut i = 32;
        while i < 64 {
            mask[i] = 0xFF;
            i += 1;
        }
        mask
    };

    if N < LANES {
        // Not even one full vector: the `N - LANES` tail load below would read
        // out of bounds, so fall back to the scalar count.
        return hay.iter().filter(|&&x| x == needle).count();
    }

    // SAFETY: AVX2 is available (see `# Safety`). Every haystack load covers
    // `[i * 32, i * 32 + 32)` with `i < N / 32`, or `[N - 32, N)` with
    // `N >= 32`, all inside `hay`. The mask load covers `[rem, rem + 32)` with
    // `rem <= 31`, inside the 64-byte `TAIL_MASK`. All loads are unaligned.
    unsafe {
        // `as i8` only reinterprets the bit pattern; the compare is bytewise.
        let find = _mm256_set1_epi8(needle as i8);
        let zero = _mm256_setzero_si256();
        let base = hay.as_ptr();

        // Four 64-bit partial sums.
        let mut totals = zero;

        let full = N / LANES;
        let mut done = 0;
        while done < full {
            let stop = (done + MAX_BATCH).min(full);
            let mut lanes = zero;
            for i in done..stop {
                let chunk = _mm256_loadu_si256(base.add(i * LANES).cast());
                // A matching byte compares to 0xFF (-1); subtracting adds 1.
                lanes = _mm256_sub_epi8(lanes, _mm256_cmpeq_epi8(chunk, find));
            }
            // Horizontally add each group of 8 byte counters into a u64 lane.
            totals = _mm256_add_epi64(totals, _mm256_sad_epu8(lanes, zero));
            done = stop;
        }

        let rem = N % LANES;
        if rem != 0 {
            // Re-read the final 32 bytes; only the last `rem` of them are new.
            let last = _mm256_loadu_si256(base.add(N - LANES).cast());
            let keep = _mm256_loadu_si256(TAIL_MASK.as_ptr().add(rem).cast());
            let hits = _mm256_and_si256(_mm256_cmpeq_epi8(last, find), keep);
            let ones = _mm256_sub_epi8(zero, hits);
            totals = _mm256_add_epi64(totals, _mm256_sad_epu8(ones, zero));
        }

        let sum = _mm256_extract_epi64::<0>(totals)
            + _mm256_extract_epi64::<1>(totals)
            + _mm256_extract_epi64::<2>(totals)
            + _mm256_extract_epi64::<3>(totals);
        // The sum is a count bounded by `N`, so it is non-negative and fits.
        sum as usize
    }
}

/// Returns how many bytes of `hay` are equal to `what`.
///
/// Uses the AVX2 implementation when the crate is compiled for x86_64 with the
/// `avx2` target feature enabled, and a plain iterator count otherwise.
pub fn count<const N: usize>(hay: &[u8; N], what: u8) -> usize {
    #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
    {
        // SAFETY: this branch only exists when AVX2 is statically enabled for
        // the build target, which is the sole requirement of `count_avx`.
        unsafe { count_avx(hay, what) }
    }
    #[cfg(not(all(target_arch = "x86_64", target_feature = "avx2")))]
    {
        hay.iter().filter(|&&x| x == what).count()
    }
}
+
pub fn lcm(n: impl IntoIterator<Item = u64>) -> u64 {
let mut x = n.into_iter();
let mut lcm = x.by_ref().next().expect("cannot compute LCM of 0 numbers");
@@ -299,6 +343,7 @@ pub trait GreekTools<T>: Iterator {
Self: Ι<T, N>;
fn Ν<const N: usize>(&mut self) -> [T; N];
fn ν<const N: usize>(&mut self, into: &mut [T; N]) -> usize;
+ fn Θ(&mut self);
}
pub trait ParseIter {
@@ -425,6 +470,10 @@ impl<T, I: Iterator<Item = T>> GreekTools<T> for I {
{
self.ι1()
}
+
+ fn Θ(&mut self) {
+ for _ in self {}
+ }
}
pub trait TupleUtils<T, U> {
@@ -566,7 +615,7 @@ impl std::fmt::Display for PrintU8s<'_> {
}
}
-impl Printable for &[u8] {
+impl Printable for [u8] {
fn p(&self) -> impl std::fmt::Display {
PrintU8s(self)
}