Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement Neon SIMD #6

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,8 @@ println!("{}", hash);
| ✅ | `x86`, `x86_64` | avx2 |
| ✅ | `x86`, `x86_64` | ssse3 |
| ✅ | `x86`, `x86_64` | sse2 |
| 🚧 | `arm`, `aarch64` | neon |
| ✅ | `aarch64` | neon |
| 🚧 | `arm` | neon |
| ✅ | `wasm32` | simd128 |

**MSRV** `1.36.0`\*\*
Expand Down
2 changes: 2 additions & 0 deletions src/imp/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
pub mod avx2;
pub mod avx512;
pub mod neon;
pub mod scalar;
pub mod sse2;
pub mod ssse3;
Expand All @@ -18,6 +19,7 @@ pub fn get_imp() -> Adler32Imp {
.or_else(avx2::get_imp)
.or_else(ssse3::get_imp)
.or_else(sse2::get_imp)
.or_else(neon::get_imp)
.or_else(wasm::get_imp)
.unwrap_or(scalar::update)
}
100 changes: 51 additions & 49 deletions src/imp/neon.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,24 @@ pub fn get_imp() -> Option<Adler32Imp> {
}

#[inline]
#[cfg(all(feature = "std", feature = "nightly", target_arch = "arm"))]
#[cfg(all(
feature = "std",
feature = "nightly",
target_arch = "arm",
target_feature = "v7"
))]
fn get_imp_inner() -> Option<Adler32Imp> {
if std::is_arm_feature_detected("neon") {
if std::is_arm_feature_detected!("neon") {
Some(imp::update)
} else {
None
}
}

#[inline]
#[cfg(all(feature = "std", feature = "nightly", target_arch = "aarch64"))]
#[cfg(all(feature = "std", target_arch = "aarch64"))]
fn get_imp_inner() -> Option<Adler32Imp> {
if std::is_aarch64_feature_detected("neon") {
if std::is_aarch64_feature_detected!("neon") {
Some(imp::update)
} else {
None
Expand All @@ -27,9 +32,14 @@ fn get_imp_inner() -> Option<Adler32Imp> {

#[inline]
#[cfg(all(
feature = "nightly",
target_feature = "neon",
not(all(feature = "std", any(target_arch = "arm", target_arch = "aarch64")))
not(all(
feature = "std",
any(
all(feature = "nightly", target_arch = "arm", target_feature = "v7"),
target_arch = "aarch64"
)
))
))]
fn get_imp_inner() -> Option<Adler32Imp> {
Some(imp::update)
Expand All @@ -40,23 +50,27 @@ fn get_imp_inner() -> Option<Adler32Imp> {
not(target_feature = "neon"),
not(all(
feature = "std",
feature = "nightly",
any(target_arch = "arm", target_arch = "aarch64")
any(
all(feature = "nightly", target_arch = "arm", target_feature = "v7"),
target_arch = "aarch64"
)
))
))]
fn get_imp_inner() -> Option<Adler32Imp> {
None
}

#[cfg(all(
feature = "nightly",
any(target_arch = "arm", target_arch = "aarch64"),
any(
all(feature = "nightly", target_arch = "arm", target_feature = "v7"),
target_arch = "aarch64"
),
any(feature = "std", target_feature = "neon")
))]
mod imp {
const MOD: u32 = 65521;
const NMAX: usize = 5552;
const BLOCK_SIZE: usize = 64;
const BLOCK_SIZE: usize = 32;
const CHUNK_SIZE: usize = NMAX / BLOCK_SIZE * BLOCK_SIZE;

#[cfg(target_arch = "aarch64")]
Expand All @@ -69,6 +83,7 @@ mod imp {
}

#[inline]
#[cfg_attr(target_arch = "arm", target_feature(enable = "v7"))]
#[target_feature(enable = "neon")]
unsafe fn update_imp(a: u16, b: u16, data: &[u8]) -> (u16, u16) {
let mut a = a as u32;
Expand Down Expand Up @@ -128,63 +143,50 @@ mod imp {
let blocks = chunk.chunks_exact(BLOCK_SIZE);
let blocks_remainder = blocks.remainder();

let one_v = _mm512_set1_epi16(1);
let zero_v = _mm512_setzero_si512();
let weights = get_weights();
let weight_hi_v = get_weight_hi();
let weight_lo_v = get_weight_lo();

let p_v = (*a * blocks.len() as u32) as _;
let mut p_v = _mm512_set_epi32(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, p_v);
let mut a_v = _mm512_setzero_si512();
let mut b_v = _mm512_set_epi32(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, *b as _);
let mut p_v: uint32x4_t = core::mem::transmute([*a * blocks.len() as u32, 0, 0, 0]);
let mut a_v: uint32x4_t = core::mem::transmute([0u32, 0, 0, 0]);
let mut b_v: uint32x4_t = core::mem::transmute([*b, 0, 0, 0]);

for block in blocks {
let block_ptr = block.as_ptr() as *const _;
let block = _mm512_loadu_si512(block_ptr);
let block_ptr = block.as_ptr() as *const uint8x16_t;
let v_lo = core::ptr::read_unaligned(block_ptr);
let v_hi = core::ptr::read_unaligned(block_ptr.add(1));

p_v = _mm512_add_epi32(p_v, a_v);
p_v = vaddq_u32(p_v, a_v);

a_v = _mm512_add_epi32(a_v, _mm512_sad_epu8(block, zero_v));
let mad = _mm512_maddubs_epi16(block, weights);
b_v = _mm512_add_epi32(b_v, _mm512_madd_epi16(mad, one_v));
a_v = vaddq_u32(a_v, vqaddlq_u8(v_lo));
b_v = vdotq_u32(b_v, v_lo, weight_lo_v);

a_v = vaddq_u32(a_v, vqaddlq_u8(v_hi));
b_v = vdotq_u32(b_v, v_hi, weight_hi_v);
}

b_v = _mm512_add_epi32(b_v, _mm512_slli_epi32(p_v, 6));
b_v = vaddq_u32(b_v, vshlq_n_u32(p_v, 5));

*a += reduce_add(a_v);
*b = reduce_add(b_v);
*a += vaddvq_u32(a_v);
*b = vaddvq_u32(b_v);

blocks_remainder
}

#[inline(always)]
unsafe fn reduce_add(v: __m512i) -> u32 {
let v: [__m256i; 2] = core::mem::transmute(v);

reduce_add_256(v[0]) + reduce_add_256(v[1])
/// Widening horizontal reduction of a 16-byte vector: `vpaddlq_u8` sums
/// adjacent u8 lanes into u16x8, then `vpaddlq_u16` sums adjacent u16 lanes
/// into u32x4, so each output lane holds the sum of 4 consecutive input bytes.
///
/// No saturation can occur despite the `q` in the name: two u8 values always
/// fit in a u16, and two such u16 sums always fit in a u32, so the widening
/// pairwise adds are exact. Used to accumulate the Adler-32 `a` sum per block.
///
/// NOTE(review): this is a local helper named after an intrinsic-style
/// convention, not an actual `core::arch` intrinsic — confirm the name does
/// not shadow a real intrinsic on future toolchains.
unsafe fn vqaddlq_u8(a: uint8x16_t) -> uint32x4_t {
  vpaddlq_u16(vpaddlq_u8(a))
}

#[inline(always)]
unsafe fn reduce_add_256(v: __m256i) -> u32 {
let v: [__m128i; 2] = core::mem::transmute(v);
let sum = _mm_add_epi32(v[0], v[1]);
let hi = _mm_unpackhi_epi64(sum, sum);

let sum = _mm_add_epi32(hi, sum);
let hi = _mm_shuffle_epi32(sum, crate::imp::_MM_SHUFFLE(2, 3, 0, 1));

let sum = _mm_add_epi32(sum, hi);
let sum = _mm_cvtsi128_si32(sum) as _;

sum
/// Byte weights for the first (low) 16 bytes of a 32-byte block, as consumed
/// by the `vdotq_u32(b_v, v_lo, weight_lo_v)` call in the block loop.
///
/// In the Adler-32 `b` sum, the byte at offset `i` of an N-byte block
/// contributes with weight `N - i`; for BLOCK_SIZE = 32 the low half gets
/// weights 32 down to 17. The transmute only reinterprets a `[u8; 16]` array
/// as the 128-bit SIMD type — same size, no value change.
unsafe fn get_weight_lo() -> uint8x16_t {
  core::mem::transmute([
    32u8, 31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17,
  ])
}

#[inline(always)]
unsafe fn get_weights() -> __m512i {
_mm512_set_epi8(
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44,
45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64,
)
/// Byte weights for the second (high) 16 bytes of a 32-byte block, paired
/// with `vdotq_u32(b_v, v_hi, weight_hi_v)` in the block loop.
///
/// Continues the per-byte weights of the low half: offsets 16..31 of a
/// 32-byte block contribute to the Adler-32 `b` sum with weights 16 down
/// to 1. The transmute reinterprets `[u8; 16]` as the 128-bit SIMD type.
unsafe fn get_weight_hi() -> uint8x16_t {
  core::mem::transmute([16u8, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1])
}
}

Expand Down
31 changes: 20 additions & 11 deletions src/imp/wasm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,50 +100,59 @@ mod imp {

for block in blocks {
let block_ptr = block.as_ptr() as *const v128;
let v_lo = unsafe { *(block_ptr) };
let v_hi = unsafe { *(block_ptr.add(1)) };
let v_lo = unsafe { core::ptr::read_unaligned(block_ptr) };
let v_hi = unsafe { core::ptr::read_unaligned(block_ptr.add(1)) };

p_v = u32x4_add(p_v, a_v);

a_v = u32x4_add(a_v, u32x4_extadd_quarters_u8x16(v_lo));
let mad = i32x4_dot_i8x16(v_lo, weight_lo_v);
let mad = u32x4_dot_u8x16(v_lo, weight_lo_v);
b_v = u32x4_add(b_v, mad);

a_v = u32x4_add(a_v, u32x4_extadd_quarters_u8x16(v_hi));
let mad = i32x4_dot_i8x16(v_hi, weight_hi_v);
let mad = u32x4_dot_u8x16(v_hi, weight_hi_v);
b_v = u32x4_add(b_v, mad);
}

b_v = u32x4_add(b_v, u32x4_shl(p_v, 5));

*a += reduce_add(a_v);
*b = reduce_add(b_v);
*a += u32x4_horizontal_add(a_v);
*b = u32x4_horizontal_add(b_v);

blocks_remainder
}

#[inline(always)]
fn i32x4_dot_i8x16(a: v128, b: v128) -> v128 {
fn u32x4_dot_u8x16(a: v128, b: v128) -> v128 {
let a_lo = u16x8_extend_low_u8x16(a);
let a_hi = u16x8_extend_high_u8x16(a);

let b_lo = u16x8_extend_low_u8x16(b);
let b_hi = u16x8_extend_high_u8x16(b);

// There is no unsigned version for dot. However since we just zero
// extended, we know that the upper bytes are 0 and thus the sign extension
// involved here before multiplying isn't going to mess with the result.
let lo = i32x4_dot_i16x8(a_lo, b_lo);
let hi = i32x4_dot_i16x8(a_hi, b_hi);

i32x4_add(lo, hi)
u32x4_add(lo, hi)
}

#[inline(always)]
fn u32x4_extadd_quarters_u8x16(a: v128) -> v128 {
u32x4_extadd_pairwise_u16x8(u16x8_extadd_pairwise_u8x16(a))
// Technically we want this to be unsigned, but u32x4_extadd_pairwise_u16x8
// lowers to pxor, pmaddwd, paddd on x86 whereas i32x4_extadd_pairwise_i16x8
// can just be lowered to pmaddwd. However since u16x8_extadd_pairwise_u8x16
// already ensures the top byte is mostly zero except for carry, we can just
// use the sign extending version instead of the zero extending version
// after to ensure better codegen on x86.
i32x4_extadd_pairwise_i16x8(u16x8_extadd_pairwise_u8x16(a))
}

#[inline(always)]
fn reduce_add(v: v128) -> u32 {
let arr: [u32; 4] = unsafe { std::mem::transmute(v) };
fn u32x4_horizontal_add(v: v128) -> u32 {
let arr: [u32; 4] = unsafe { core::mem::transmute(v) };
let mut sum = 0u32;
for val in arr {
sum = sum.wrapping_add(val);
Expand Down
12 changes: 8 additions & 4 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,9 @@
//! | ✅ | `x86`, `x86_64` | avx2 |
//! | ✅ | `x86`, `x86_64` | ssse3 |
//! | ✅ | `x86`, `x86_64` | sse2 |
//! | 🚧 | `arm`, `aarch64` | neon |
//! | | `wasm32` | simd128 |
//! | ✅ | `aarch64` | neon |
//! | 🚧 | `arm` | neon |
//! | ✅ | `wasm32` | simd128 |
//!
//! **MSRV** `1.36.0`\*\*
//!
Expand All @@ -69,15 +70,18 @@
//! ## CPU Feature Detection
//! simd-adler32 supports both runtime and compile time CPU feature detection using the
//! `std::is_x86_feature_detected` macro when the `Adler32` struct is instantiated with
//! the `new` fn.
//! the `new` fn.
//!
//! Without `std` feature enabled simd-adler32 falls back to compile time feature detection
//! using `target-feature` or `target-cpu` flags supplied to rustc. See [https://rust-lang.github.io/packed_simd/perf-guide/target-feature/rustflags.html](https://rust-lang.github.io/packed_simd/perf-guide/target-feature/rustflags.html)
//! for more information.
//!
//! Feature detection tries to use the fastest supported feature first.
#![cfg_attr(not(feature = "std"), no_std)]
#![cfg_attr(feature = "nightly", feature(stdsimd, avx512_target_feature))]
#![cfg_attr(
feature = "nightly",
feature(stdsimd, avx512_target_feature, arm_target_feature)
)]

#[doc(hidden)]
pub mod hash;
Expand Down