heh
Diffstat (limited to 'src/util.rs')
| -rw-r--r-- | src/util.rs | 55 |
1 files changed, 52 insertions, 3 deletions
diff --git a/src/util.rs b/src/util.rs index e0beaa6..e697233 100644 --- a/src/util.rs +++ b/src/util.rs @@ -9,7 +9,7 @@ pub mod prelude { pub use super::{ even, gcd, lcm, pa, GreekTools, IntoCombinations, IntoLines, IterͶ, NumTupleIterTools, ParseIter, Printable, Skip, TakeLine, TupleIterTools, TupleUtils, UnifiedTupleUtils, Widen, - Ͷ, Α, Κ, Λ, Μ, + 読む, Ͷ, Α, Κ, Λ, Μ, }; pub use itertools::izip; pub use itertools::Itertools; @@ -24,7 +24,7 @@ pub mod prelude { ops::Range, }; #[allow(unused_imports)] - pub(crate) use {super::bits, super::dang, super::leek, super::mat, super::読む}; + pub(crate) use {super::bits, super::dang, super::leek, super::mat}; } macro_rules! dang { @@ -48,6 +48,50 @@ macro_rules! mat { } pub(crate) use mat; +#[cfg(target_feature = "avx2")] +unsafe fn count_avx<const N: usize>(hay: &[u8; N], needle: u8) -> usize { + use std::arch::x86_64::*; + let find = _mm256_set1_epi8(needle as i8); + let mut counts = _mm256_setzero_si256(); + for i in 0..(N / 32) { + counts = _mm256_sub_epi8( + counts, + _mm256_cmpeq_epi8( + _mm256_loadu_si256(hay.as_ptr().add(i * 32) as *const _), + find, + ), + ); + } + const MASK: [u8; 64] = [ + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ]; + counts = _mm256_sub_epi8( + counts, + _mm256_and_si256( + _mm256_cmpeq_epi8( + _mm256_loadu_si256(hay.as_ptr().add(N - 32) as *const _), + find, + ), + _mm256_loadu_si256(MASK.as_ptr().add(N % 32) as *const _), + ), + ); + + let sums = _mm256_sad_epu8(counts, _mm256_setzero_si256()); + (_mm256_extract_epi64(sums, 0) + + _mm256_extract_epi64(sums, 1) + + _mm256_extract_epi64(sums, 2) + + _mm256_extract_epi64(sums, 3)) as usize +} + +pub fn count<const N: usize>(hay: &[u8; N], what: u8) -> usize { + #[cfg(target_feature = "avx2")] + return count_avx(hay, what); + #[cfg(not(target_feature = "avx2"))] + hay.iter().filter(|&&x| x == what).count() +} + pub fn lcm(n: impl IntoIterator<Item = u64>) -> u64 { let mut x = n.into_iter(); let mut lcm = x.by_ref().next().expect("cannot compute LCM of 0 numbers"); @@ -299,6 +343,7 @@ pub trait GreekTools<T>: Iterator { Self: Ι<T, N>; fn Ν<const N: usize>(&mut self) -> [T; N]; fn ν<const N: usize>(&mut self, into: &mut [T; N]) -> usize; + fn Θ(&mut self); } pub trait ParseIter { @@ -425,6 +470,10 @@ impl<T, I: Iterator<Item = T>> GreekTools<T> for I { { self.ι1() } + + fn Θ(&mut self) { + for _ in self {} + } } pub trait TupleUtils<T, U> { @@ -566,7 +615,7 @@ impl std::fmt::Display for PrintU8s<'_> { } } -impl Printable for &[u8] { +impl Printable for [u8] { fn p(&self) -> impl std::fmt::Display { PrintU8s(self) } |