Skip to content

Commit

Permalink
Implement Neon SIMD
Browse files Browse the repository at this point in the history
This ports the WASM algorithm over to Aarch64 / ARM Neon. Rust itself
isn't quite ready yet, but this should start compiling soon.
  • Loading branch information
CryZe committed Oct 13, 2021
1 parent 499df81 commit fccf53d
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 52 deletions.
2 changes: 2 additions & 0 deletions src/imp/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
pub mod avx2;
pub mod avx512;
pub mod neon;
pub mod scalar;
pub mod sse2;
pub mod ssse3;
Expand All @@ -18,6 +19,7 @@ pub fn get_imp() -> Adler32Imp {
.or_else(avx2::get_imp)
.or_else(ssse3::get_imp)
.or_else(sse2::get_imp)
.or_else(neon::get_imp)
.or_else(wasm::get_imp)
.unwrap_or(scalar::update)
}
94 changes: 49 additions & 45 deletions src/imp/neon.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,14 @@ pub fn get_imp() -> Option<Adler32Imp> {
}

#[inline]
#[cfg(all(feature = "std", feature = "nightly", target_arch = "arm"))]
#[cfg(all(
feature = "std",
feature = "nightly",
target_arch = "arm",
target_feature = "v7"
))]
fn get_imp_inner() -> Option<Adler32Imp> {
if std::is_arm_feature_detected("neon") {
if std::is_arm_feature_detected!("neon") {
Some(imp::update)
} else {
None
Expand All @@ -18,7 +23,7 @@ fn get_imp_inner() -> Option<Adler32Imp> {
#[inline]
#[cfg(all(feature = "std", feature = "nightly", target_arch = "aarch64"))]
fn get_imp_inner() -> Option<Adler32Imp> {
if std::is_aarch64_feature_detected("neon") {
if std::is_aarch64_feature_detected!("neon") {
Some(imp::update)
} else {
None
Expand All @@ -29,7 +34,13 @@ fn get_imp_inner() -> Option<Adler32Imp> {
#[cfg(all(
feature = "nightly",
target_feature = "neon",
not(all(feature = "std", any(target_arch = "arm", target_arch = "aarch64")))
not(all(
feature = "std",
any(
all(target_arch = "arm", target_feature = "v7"),
target_arch = "aarch64"
)
))
))]
fn get_imp_inner() -> Option<Adler32Imp> {
Some(imp::update)
Expand All @@ -41,7 +52,10 @@ fn get_imp_inner() -> Option<Adler32Imp> {
not(all(
feature = "std",
feature = "nightly",
any(target_arch = "arm", target_arch = "aarch64")
any(
all(target_arch = "arm", target_feature = "v7"),
target_arch = "aarch64"
)
))
))]
fn get_imp_inner() -> Option<Adler32Imp> {
Expand All @@ -50,13 +64,16 @@ fn get_imp_inner() -> Option<Adler32Imp> {

#[cfg(all(
feature = "nightly",
any(target_arch = "arm", target_arch = "aarch64"),
any(
all(target_arch = "arm", target_feature = "v7"),
target_arch = "aarch64"
),
any(feature = "std", target_feature = "neon")
))]
mod imp {
const MOD: u32 = 65521;
const NMAX: usize = 5552;
const BLOCK_SIZE: usize = 64;
const BLOCK_SIZE: usize = 32;
const CHUNK_SIZE: usize = NMAX / BLOCK_SIZE * BLOCK_SIZE;

#[cfg(target_arch = "aarch64")]
Expand Down Expand Up @@ -128,63 +145,50 @@ mod imp {
let blocks = chunk.chunks_exact(BLOCK_SIZE);
let blocks_remainder = blocks.remainder();

let one_v = _mm512_set1_epi16(1);
let zero_v = _mm512_setzero_si512();
let weights = get_weights();
let weight_hi_v = get_weight_hi();
let weight_lo_v = get_weight_lo();

let p_v = (*a * blocks.len() as u32) as _;
let mut p_v = _mm512_set_epi32(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, p_v);
let mut a_v = _mm512_setzero_si512();
let mut b_v = _mm512_set_epi32(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, *b as _);
let mut p_v: uint32x4_t = std::mem::transmute([*a * blocks.len() as u32, 0, 0, 0]);
let mut a_v: uint32x4_t = std::mem::transmute([0u32, 0, 0, 0]);
let mut b_v: uint32x4_t = std::mem::transmute([*b, 0, 0, 0]);

for block in blocks {
let block_ptr = block.as_ptr() as *const _;
let block = _mm512_loadu_si512(block_ptr);
let block_ptr = block.as_ptr() as *const uint8x16_t;
let v_lo = *(block_ptr);
let v_hi = *(block_ptr.add(1));

p_v = vaddq_u32(p_v, a_v);

p_v = _mm512_add_epi32(p_v, a_v);
a_v = vaddq_u32(a_v, vqaddlq_u8(v_lo));
b_v = vdotq_u32(b_v, v_lo, weight_lo_v);

a_v = _mm512_add_epi32(a_v, _mm512_sad_epu8(block, zero_v));
let mad = _mm512_maddubs_epi16(block, weights);
b_v = _mm512_add_epi32(b_v, _mm512_madd_epi16(mad, one_v));
a_v = vaddq_u32(a_v, vqaddlq_u8(v_hi));
b_v = vdotq_u32(b_v, v_hi, weight_hi_v);
}

b_v = _mm512_add_epi32(b_v, _mm512_slli_epi32(p_v, 6));
b_v = vaddq_u32(b_v, vshlq_n_u32(p_v, 5));

*a += reduce_add(a_v);
*b = reduce_add(b_v);
*a += vaddvq_u32(a_v);
*b = vaddvq_u32(b_v);

blocks_remainder
}

#[inline(always)]
unsafe fn reduce_add(v: __m512i) -> u32 {
let v: [__m256i; 2] = core::mem::transmute(v);

reduce_add_256(v[0]) + reduce_add_256(v[1])
unsafe fn vqaddlq_u8(a: uint8x16_t) -> uint32x4_t {
vpaddlq_u16(vpaddlq_u8(a))
}

#[inline(always)]
unsafe fn reduce_add_256(v: __m256i) -> u32 {
let v: [__m128i; 2] = core::mem::transmute(v);
let sum = _mm_add_epi32(v[0], v[1]);
let hi = _mm_unpackhi_epi64(sum, sum);

let sum = _mm_add_epi32(hi, sum);
let hi = _mm_shuffle_epi32(sum, crate::imp::_MM_SHUFFLE(2, 3, 0, 1));

let sum = _mm_add_epi32(sum, hi);
let sum = _mm_cvtsi128_si32(sum) as _;

sum
unsafe fn get_weight_lo() -> uint8x16_t {
std::mem::transmute([
32u8, 31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17,
])
}

#[inline(always)]
unsafe fn get_weights() -> __m512i {
_mm512_set_epi8(
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44,
45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64,
)
unsafe fn get_weight_hi() -> uint8x16_t {
std::mem::transmute([16u8, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1])
}
}

Expand Down
19 changes: 14 additions & 5 deletions src/imp/wasm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,11 +106,11 @@ mod imp {
p_v = u32x4_add(p_v, a_v);

a_v = u32x4_add(a_v, u32x4_extadd_quarters_u8x16(v_lo));
let mad = i32x4_dot_i8x16(v_lo, weight_lo_v);
let mad = u32x4_dot_u8x16(v_lo, weight_lo_v);
b_v = u32x4_add(b_v, mad);

a_v = u32x4_add(a_v, u32x4_extadd_quarters_u8x16(v_hi));
let mad = i32x4_dot_i8x16(v_hi, weight_hi_v);
let mad = u32x4_dot_u8x16(v_hi, weight_hi_v);
b_v = u32x4_add(b_v, mad);
}

Expand All @@ -123,22 +123,31 @@ mod imp {
}

#[inline(always)]
fn i32x4_dot_i8x16(a: v128, b: v128) -> v128 {
fn u32x4_dot_u8x16(a: v128, b: v128) -> v128 {
let a_lo = u16x8_extend_low_u8x16(a);
let a_hi = u16x8_extend_high_u8x16(a);

let b_lo = u16x8_extend_low_u8x16(b);
let b_hi = u16x8_extend_high_u8x16(b);

// There is no unsigned version for dot. However since we just zero
// extended, we know that the upper bytes are 0 and thus the sign extension
// involved here before multiplying isn't going to mess with the result.
let lo = i32x4_dot_i16x8(a_lo, b_lo);
let hi = i32x4_dot_i16x8(a_hi, b_hi);

i32x4_add(lo, hi)
u32x4_add(lo, hi)
}

#[inline(always)]
fn u32x4_extadd_quarters_u8x16(a: v128) -> v128 {
u32x4_extadd_pairwise_u16x8(u16x8_extadd_pairwise_u8x16(a))
// Technically we want this to be unsigned, but u32x4_extadd_pairwise_u16x8
// lowers to pxor, pmaddwd, paddd on x86 whereas i32x4_extadd_pairwise_i16x8
// can just be lowered to pmaddwd. However since u16x8_extadd_pairwise_u8x16
// already ensures the top byte is mostly zero except for carry, we can just
// use the sign extending version instead of the zero extending version
// after to ensure better codegen on x86.
i32x4_extadd_pairwise_i16x8(u16x8_extadd_pairwise_u8x16(a))
}

#[inline(always)]
Expand Down
12 changes: 10 additions & 2 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,15 +69,23 @@
//! ## CPU Feature Detection
//! simd-adler32 supports both runtime and compile time CPU feature detection using the
//! `std::is_x86_feature_detected` macro when the `Adler32` struct is instantiated with
//! the `new` fn.
//! the `new` fn.
//!
//! Without `std` feature enabled simd-adler32 falls back to compile time feature detection
//! using `target-feature` or `target-cpu` flags supplied to rustc. See [https://rust-lang.github.io/packed_simd/perf-guide/target-feature/rustflags.html](https://rust-lang.github.io/packed_simd/perf-guide/target-feature/rustflags.html)
//! for more information.
//!
//! Feature detection tries to use the fastest supported feature first.
#![cfg_attr(not(feature = "std"), no_std)]
#![cfg_attr(feature = "nightly", feature(stdsimd, avx512_target_feature))]
#![cfg_attr(
feature = "nightly",
feature(
stdsimd,
avx512_target_feature,
aarch64_target_feature,
arm_target_feature
)
)]

#[doc(hidden)]
pub mod hash;
Expand Down

0 comments on commit fccf53d

Please sign in to comment.