Implement Neon SIMD
This ports the WASM algorithm over to Aarch64 / ARM Neon.
CryZe committed Jan 21, 2023
1 parent 499df81 commit eb0b039
Showing 5 changed files with 83 additions and 65 deletions.
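For context, every backend in this crate computes the same two Adler-32 state halves. A minimal scalar sketch of the recurrence the SIMD ports below must reproduce (illustrative only; the crate's real scalar path defers the modulo across NMAX-byte chunks instead of reducing per byte):

```rust
/// Reference Adler-32, reduced every byte for clarity. Sketch only —
/// not the crate's scalar implementation, which batches the `% 65521`.
fn adler32_reference(data: &[u8]) -> u32 {
  const MOD: u32 = 65521;
  let (mut a, mut b) = (1u32, 0u32);
  for &byte in data {
    a = (a + byte as u32) % MOD;
    b = (b + a) % MOD;
  }
  (b << 16) | a
}
```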
3 changes: 2 additions & 1 deletion README.md
@@ -53,7 +53,8 @@ println!("{}", hash);
| ✅ | `x86`, `x86_64` | avx2 |
| ✅ | `x86`, `x86_64` | ssse3 |
| ✅ | `x86`, `x86_64` | sse2 |
| 🚧 | `arm`, `aarch64` | neon |
| ✅ | `aarch64` | neon |
| 🚧 | `arm` | neon |
| ✅ | `wasm32` | simd128 |

**MSRV** `1.36.0`\*\*
2 changes: 2 additions & 0 deletions src/imp/mod.rs
@@ -1,5 +1,6 @@
pub mod avx2;
pub mod avx512;
pub mod neon;
pub mod scalar;
pub mod sse2;
pub mod ssse3;
@@ -18,6 +19,7 @@ pub fn get_imp() -> Adler32Imp {
.or_else(avx2::get_imp)
.or_else(ssse3::get_imp)
.or_else(sse2::get_imp)
.or_else(neon::get_imp)
.or_else(wasm::get_imp)
.unwrap_or(scalar::update)
}
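The chain above tries the fastest backend first and falls back to `scalar::update`. A hedged sketch of how the resolved `Adler32Imp` — a plain `fn(u16, u16, &[u8]) -> (u16, u16)` over the two state halves, matching the `update_imp` signatures in this diff — might be cached by a caller (the struct here is illustrative, not the crate's actual `Adler32`):

```rust
// Hypothetical caller: resolve the implementation once, reuse per write.
type Adler32Imp = fn(u16, u16, &[u8]) -> (u16, u16);

struct Hasher {
  a: u16,
  b: u16,
  update: Adler32Imp, // picked by get_imp() at construction
}

impl Hasher {
  fn write(&mut self, data: &[u8]) {
    let (a, b) = (self.update)(self.a, self.b, data);
    self.a = a;
    self.b = b;
  }

  fn finish(&self) -> u32 {
    ((self.b as u32) << 16) | self.a as u32
  }
}
```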
100 changes: 51 additions & 49 deletions src/imp/neon.rs
@@ -6,19 +6,24 @@ pub fn get_imp() -> Option<Adler32Imp> {
}

#[inline]
#[cfg(all(feature = "std", feature = "nightly", target_arch = "arm"))]
#[cfg(all(
feature = "std",
feature = "nightly",
target_arch = "arm",
target_feature = "v7"
))]
fn get_imp_inner() -> Option<Adler32Imp> {
if std::is_arm_feature_detected("neon") {
if std::is_arm_feature_detected!("neon") {
Some(imp::update)
} else {
None
}
}

#[inline]
#[cfg(all(feature = "std", feature = "nightly", target_arch = "aarch64"))]
#[cfg(all(feature = "std", target_arch = "aarch64"))]
fn get_imp_inner() -> Option<Adler32Imp> {
if std::is_aarch64_feature_detected("neon") {
if std::is_aarch64_feature_detected!("neon") {
Some(imp::update)
} else {
None
@@ -27,9 +32,14 @@ fn get_imp_inner() -> Option<Adler32Imp> {

#[inline]
#[cfg(all(
feature = "nightly",
target_feature = "neon",
not(all(feature = "std", any(target_arch = "arm", target_arch = "aarch64")))
not(all(
feature = "std",
any(
all(feature = "nightly", target_arch = "arm", target_feature = "v7"),
target_arch = "aarch64"
)
))
))]
fn get_imp_inner() -> Option<Adler32Imp> {
Some(imp::update)
@@ -40,23 +50,27 @@ fn get_imp_inner() -> Option<Adler32Imp> {
not(target_feature = "neon"),
not(all(
feature = "std",
feature = "nightly",
any(target_arch = "arm", target_arch = "aarch64")
any(
all(feature = "nightly", target_arch = "arm", target_feature = "v7"),
target_arch = "aarch64"
)
))
))]
fn get_imp_inner() -> Option<Adler32Imp> {
None
}

#[cfg(all(
feature = "nightly",
any(target_arch = "arm", target_arch = "aarch64"),
any(
all(feature = "nightly", target_arch = "arm", target_feature = "v7"),
target_arch = "aarch64"
),
any(feature = "std", target_feature = "neon")
))]
mod imp {
const MOD: u32 = 65521;
const NMAX: usize = 5552;
const BLOCK_SIZE: usize = 64;
const BLOCK_SIZE: usize = 32;
const CHUNK_SIZE: usize = NMAX / BLOCK_SIZE * BLOCK_SIZE;

#[cfg(target_arch = "aarch64")]
@@ -69,6 +83,7 @@ mod imp {
}

#[inline]
#[cfg_attr(target_arch = "arm", target_feature(enable = "v7"))]
#[target_feature(enable = "neon")]
unsafe fn update_imp(a: u16, b: u16, data: &[u8]) -> (u16, u16) {
let mut a = a as u32;
@@ -128,63 +143,50 @@ mod imp {
let blocks = chunk.chunks_exact(BLOCK_SIZE);
let blocks_remainder = blocks.remainder();

let one_v = _mm512_set1_epi16(1);
let zero_v = _mm512_setzero_si512();
let weights = get_weights();
let weight_hi_v = get_weight_hi();
let weight_lo_v = get_weight_lo();

let p_v = (*a * blocks.len() as u32) as _;
let mut p_v = _mm512_set_epi32(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, p_v);
let mut a_v = _mm512_setzero_si512();
let mut b_v = _mm512_set_epi32(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, *b as _);
let mut p_v: uint32x4_t = core::mem::transmute([*a * blocks.len() as u32, 0, 0, 0]);
let mut a_v: uint32x4_t = core::mem::transmute([0u32, 0, 0, 0]);
let mut b_v: uint32x4_t = core::mem::transmute([*b, 0, 0, 0]);

for block in blocks {
let block_ptr = block.as_ptr() as *const _;
let block = _mm512_loadu_si512(block_ptr);
let block_ptr = block.as_ptr() as *const uint8x16_t;
let v_lo = core::ptr::read_unaligned(block_ptr);
let v_hi = core::ptr::read_unaligned(block_ptr.add(1));

p_v = _mm512_add_epi32(p_v, a_v);
p_v = vaddq_u32(p_v, a_v);

a_v = _mm512_add_epi32(a_v, _mm512_sad_epu8(block, zero_v));
let mad = _mm512_maddubs_epi16(block, weights);
b_v = _mm512_add_epi32(b_v, _mm512_madd_epi16(mad, one_v));
a_v = vaddq_u32(a_v, vqaddlq_u8(v_lo));
b_v = vdotq_u32(b_v, v_lo, weight_lo_v);

a_v = vaddq_u32(a_v, vqaddlq_u8(v_hi));
b_v = vdotq_u32(b_v, v_hi, weight_hi_v);
}

b_v = _mm512_add_epi32(b_v, _mm512_slli_epi32(p_v, 6));
b_v = vaddq_u32(b_v, vshlq_n_u32(p_v, 5));

*a += reduce_add(a_v);
*b = reduce_add(b_v);
*a += vaddvq_u32(a_v);
*b = vaddvq_u32(b_v);

blocks_remainder
}

#[inline(always)]
unsafe fn reduce_add(v: __m512i) -> u32 {
let v: [__m256i; 2] = core::mem::transmute(v);

reduce_add_256(v[0]) + reduce_add_256(v[1])
unsafe fn vqaddlq_u8(a: uint8x16_t) -> uint32x4_t {
vpaddlq_u16(vpaddlq_u8(a))
}

#[inline(always)]
unsafe fn reduce_add_256(v: __m256i) -> u32 {
let v: [__m128i; 2] = core::mem::transmute(v);
let sum = _mm_add_epi32(v[0], v[1]);
let hi = _mm_unpackhi_epi64(sum, sum);

let sum = _mm_add_epi32(hi, sum);
let hi = _mm_shuffle_epi32(sum, crate::imp::_MM_SHUFFLE(2, 3, 0, 1));

let sum = _mm_add_epi32(sum, hi);
let sum = _mm_cvtsi128_si32(sum) as _;

sum
unsafe fn get_weight_lo() -> uint8x16_t {
core::mem::transmute([
32u8, 31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17,
])
}

#[inline(always)]
unsafe fn get_weights() -> __m512i {
_mm512_set_epi8(
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44,
45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64,
)
unsafe fn get_weight_hi() -> uint8x16_t {
core::mem::transmute([16u8, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1])
}
}
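The constants and weights above encode one unrolling step: over a 32-byte block, the scalar loop `a += d[i]; b += a;` expands to `a' = a + Σ d[i]` and `b' = b + 32·a + Σ (32 − i)·d[i]`, which is exactly the `vqaddlq_u8` byte-sum, the `vdotq_u32` against the 32..1 weights, and the deferred `b += 32·p` performed by `vshlq_n_u32(p_v, 5)`. (`CHUNK_SIZE` is the largest multiple of `BLOCK_SIZE` not exceeding `NMAX`, so per-chunk sums stay within `u32` before the modulo.) A scalar sketch mirroring the vector bookkeeping, shown only to make the algebra checkable — not crate code:

```rust
// Scalar mirror of the NEON block loop above (sketch, not crate code).
// `p` plays the role of p_v, `a_sum` of a_v, `b_sum` of b_v.
fn update_blocks(a: &mut u32, b: &mut u32, blocks: &[[u8; 32]]) {
  let mut p = *a * blocks.len() as u32; // p_v's initial lane: a * block count
  let mut a_sum = 0u32;
  let mut b_sum = *b;

  for block in blocks {
    p += a_sum; // accumulate each block's starting byte-sum
    a_sum += block.iter().map(|&d| d as u32).sum::<u32>();
    // weights 32, 31, ..., 1 — get_weight_lo followed by get_weight_hi
    b_sum += block
      .iter()
      .enumerate()
      .map(|(i, &d)| (32 - i as u32) * d as u32)
      .sum::<u32>();
  }

  *a += a_sum;
  *b = b_sum + (p << 5); // fold in 32 * a_before_block for every block
}
```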

31 changes: 20 additions & 11 deletions src/imp/wasm.rs
@@ -100,50 +100,59 @@ mod imp {

for block in blocks {
let block_ptr = block.as_ptr() as *const v128;
let v_lo = unsafe { *(block_ptr) };
let v_hi = unsafe { *(block_ptr.add(1)) };
let v_lo = unsafe { core::ptr::read_unaligned(block_ptr) };
let v_hi = unsafe { core::ptr::read_unaligned(block_ptr.add(1)) };

p_v = u32x4_add(p_v, a_v);

a_v = u32x4_add(a_v, u32x4_extadd_quarters_u8x16(v_lo));
let mad = i32x4_dot_i8x16(v_lo, weight_lo_v);
let mad = u32x4_dot_u8x16(v_lo, weight_lo_v);
b_v = u32x4_add(b_v, mad);

a_v = u32x4_add(a_v, u32x4_extadd_quarters_u8x16(v_hi));
let mad = i32x4_dot_i8x16(v_hi, weight_hi_v);
let mad = u32x4_dot_u8x16(v_hi, weight_hi_v);
b_v = u32x4_add(b_v, mad);
}

b_v = u32x4_add(b_v, u32x4_shl(p_v, 5));

*a += reduce_add(a_v);
*b = reduce_add(b_v);
*a += u32x4_horizontal_add(a_v);
*b = u32x4_horizontal_add(b_v);

blocks_remainder
}

#[inline(always)]
fn i32x4_dot_i8x16(a: v128, b: v128) -> v128 {
fn u32x4_dot_u8x16(a: v128, b: v128) -> v128 {
let a_lo = u16x8_extend_low_u8x16(a);
let a_hi = u16x8_extend_high_u8x16(a);

let b_lo = u16x8_extend_low_u8x16(b);
let b_hi = u16x8_extend_high_u8x16(b);

// There is no unsigned version for dot. However since we just zero
// extended, we know that the upper bytes are 0 and thus the sign extension
// involved here before multiplying isn't going to mess with the result.
let lo = i32x4_dot_i16x8(a_lo, b_lo);
let hi = i32x4_dot_i16x8(a_hi, b_hi);

i32x4_add(lo, hi)
u32x4_add(lo, hi)
}

#[inline(always)]
fn u32x4_extadd_quarters_u8x16(a: v128) -> v128 {
u32x4_extadd_pairwise_u16x8(u16x8_extadd_pairwise_u8x16(a))
// Technically we want this to be unsigned, but u32x4_extadd_pairwise_u16x8
// lowers to pxor, pmaddwd, paddd on x86 whereas i32x4_extadd_pairwise_i16x8
// can just be lowered to pmaddwd. However since u16x8_extadd_pairwise_u8x16
// already ensures the top byte is mostly zero except for carry, we can just
// use the sign extending version instead of the zero extending version
// after to ensure better codegen on x86.
i32x4_extadd_pairwise_i16x8(u16x8_extadd_pairwise_u8x16(a))
}

#[inline(always)]
fn reduce_add(v: v128) -> u32 {
let arr: [u32; 4] = unsafe { std::mem::transmute(v) };
fn u32x4_horizontal_add(v: v128) -> u32 {
let arr: [u32; 4] = unsafe { core::mem::transmute(v) };
let mut sum = 0u32;
for val in arr {
sum = sum.wrapping_add(val);
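The sign-extension argument in the `u32x4_dot_u8x16` comment can be checked in scalar form: after zero-extending a byte to 16 bits every lane is in `0..=255`, so its `i16` and `u16` interpretations agree, and each pair of products is at most `2 · 255 · 255`, nowhere near `i32` overflow. A scalar sketch of the invariant (illustrative, not crate code):

```rust
// Why the signed 16-bit dot product is safe on zero-extended bytes.
fn dot_pair(a: [u8; 2], b: [u8; 2]) -> u32 {
  // Signed route, as i32x4_dot_i16x8 computes it per lane pair.
  let signed = (a[0] as i16 as i32) * (b[0] as i16 as i32)
    + (a[1] as i16 as i32) * (b[1] as i16 as i32);
  // Unsigned route, the result we actually want.
  let unsigned = (a[0] as u32) * (b[0] as u32) + (a[1] as u32) * (b[1] as u32);
  assert_eq!(signed as u32, unsigned); // holds for all u8 inputs
  unsigned
}
```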
12 changes: 8 additions & 4 deletions src/lib.rs
@@ -57,8 +57,9 @@
//! | ✅ | `x86`, `x86_64` | avx2 |
//! | ✅ | `x86`, `x86_64` | ssse3 |
//! | ✅ | `x86`, `x86_64` | sse2 |
//! | 🚧 | `arm`, `aarch64` | neon |
//! | | `wasm32` | simd128 |
//! | ✅ | `aarch64` | neon |
//! | 🚧 | `arm` | neon |
//! | ✅ | `wasm32` | simd128 |
//!
//! **MSRV** `1.36.0`\*\*
//!
@@ -69,15 +70,18 @@
//! ## CPU Feature Detection
//! simd-adler32 supports both runtime and compile time CPU feature detection using the
//! `std::is_x86_feature_detected` macro when the `Adler32` struct is instantiated with
//! the `new` fn.
//!
//! Without `std` feature enabled simd-adler32 falls back to compile time feature detection
//! using `target-feature` or `target-cpu` flags supplied to rustc. See [https://rust-lang.github.io/packed_simd/perf-guide/target-feature/rustflags.html](https://rust-lang.github.io/packed_simd/perf-guide/target-feature/rustflags.html)
//! for more information.
//!
//! Feature detection tries to use the fastest supported feature first.
#![cfg_attr(not(feature = "std"), no_std)]
#![cfg_attr(feature = "nightly", feature(stdsimd, avx512_target_feature))]
#![cfg_attr(
feature = "nightly",
feature(stdsimd, avx512_target_feature, arm_target_feature)
)]

#[doc(hidden)]
pub mod hash;
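For the no_std path described in the doc comment above, feature selection collapses to `cfg` checks resolved at compile time. A condensed sketch of the pattern the `src/imp` modules use (structure simplified for illustration; the real cfg chains, as in `neon.rs` above, also account for the `nightly` feature and `arm` vs `aarch64`):

```rust
// Simplified shape of the detection fallback (sketch, not crate code).
#[cfg(all(feature = "std", target_arch = "aarch64"))]
fn backend() -> &'static str {
  // With std: runtime detection, as in neon.rs above.
  if std::is_aarch64_feature_detected!("neon") {
    "neon (runtime-detected)"
  } else {
    "scalar"
  }
}

#[cfg(all(not(feature = "std"), target_feature = "neon"))]
fn backend() -> &'static str {
  // Without std: trust -C target-feature=+neon / -C target-cpu flags.
  "neon (compile-time)"
}

#[cfg(all(not(feature = "std"), not(target_feature = "neon")))]
fn backend() -> &'static str {
  "scalar"
}
```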
