Skip to content

Commit

Permalink
jxl-render: AVX2 version of EPF (#87)
Browse files Browse the repository at this point in the history
* jxl-render: AVX2 version of EPF

* jxl-grid: mark `__m256::abs` with `enable = "avx2"`
  • Loading branch information
tirr-c authored Sep 25, 2023
1 parent b7afe10 commit a94905e
Show file tree
Hide file tree
Showing 9 changed files with 602 additions and 161 deletions.
251 changes: 208 additions & 43 deletions crates/jxl-grid/src/simd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,47 +7,89 @@ pub trait SimdVector: Copy {
fn available() -> bool;

/// Initialize a SIMD vector with zeroes.
fn zero() -> Self;
///
/// # Safety
/// CPU should support the vector type.
unsafe fn zero() -> Self;
/// Initialize a SIMD vector with given floats.
fn set<const N: usize>(val: [f32; N]) -> Self;
///
/// # Safety
/// CPU should support the vector type.
unsafe fn set<const N: usize>(val: [f32; N]) -> Self;
/// Initialize a SIMD vector filled with given float.
fn splat_f32(val: f32) -> Self;
///
/// # Safety
/// CPU should support the vector type.
unsafe fn splat_f32(val: f32) -> Self;
/// Load a SIMD vector from memory.
///
/// The pointer doesn't need to be aligned.
///
/// # Safety
/// The given pointer must be valid.
/// CPU should support the vector type, and the given pointer must be valid.
unsafe fn load(ptr: *const f32) -> Self;
/// Load a SIMD vector from memory with aligned pointer.
///
/// # Safety
/// The given pointer must be valid and properly aligned.
/// CPU should support the vector type, and the given pointer must be valid and properly
/// aligned.
unsafe fn load_aligned(ptr: *const f32) -> Self;

/// Extract a single element from the SIMD vector.
fn extract_f32<const N: i32>(self) -> f32;
///
/// # Safety
/// CPU should support the vector type.
unsafe fn extract_f32<const N: i32>(self) -> f32;
/// Store the SIMD vector to memory.
///
/// The pointer doesn't need to be aligned.
///
/// # Safety
/// The given pointer must be valid.
/// CPU should support the vector type, and the given pointer must be valid.
unsafe fn store(self, ptr: *mut f32);
/// Store the SIMD vector to memory with aligned pointer.
///
/// # Safety
/// The given pointer must be valid and properly aligned.
/// CPU should support the vector type, and the given pointer must be valid and properly
/// aligned.
unsafe fn store_aligned(self, ptr: *mut f32);

fn add(self, lhs: Self) -> Self;
fn sub(self, lhs: Self) -> Self;
fn mul(self, lhs: Self) -> Self;
fn div(self, lhs: Self) -> Self;
fn abs(self) -> Self;
/// Add two vectors element-wise.
///
/// # Safety
/// CPU should support the vector type.
unsafe fn add(self, lhs: Self) -> Self;
/// Subtract two vectors element-wise.
///
/// # Safety
/// CPU should support the vector type.
unsafe fn sub(self, lhs: Self) -> Self;
/// Multiply two vectors element-wise.
///
/// # Safety
/// CPU should support the vector type.
unsafe fn mul(self, lhs: Self) -> Self;
/// Divide two vectors element-wise.
///
/// # Safety
/// CPU should support the vector type.
unsafe fn div(self, lhs: Self) -> Self;
/// Compute the absolute value for each element of the vector.
///
/// # Safety
/// CPU should support the vector type.
unsafe fn abs(self) -> Self;

fn muladd(self, mul: Self, add: Self) -> Self;
fn mulsub(self, mul: Self, sub: Self) -> Self;
/// Computes `self * mul + add` element-wise.
///
/// # Safety
/// CPU should support the vector type.
unsafe fn muladd(self, mul: Self, add: Self) -> Self;
/// Computes `self * mul - sub` element-wise.
///
/// # Safety
/// CPU should support the vector type.
unsafe fn mulsub(self, mul: Self, sub: Self) -> Self;
}

#[cfg(target_arch = "x86_64")]
Expand All @@ -61,19 +103,19 @@ impl SimdVector for std::arch::x86_64::__m128 {
}

#[inline]
fn zero() -> Self {
unsafe { std::arch::x86_64::_mm_setzero_ps() }
unsafe fn zero() -> Self {
std::arch::x86_64::_mm_setzero_ps()
}

#[inline]
fn set<const N: usize>(val: [f32; N]) -> Self {
unsafe fn set<const N: usize>(val: [f32; N]) -> Self {
assert_eq!(N, Self::SIZE);
unsafe { std::arch::x86_64::_mm_set_ps(val[3], val[2], val[1], val[0]) }
std::arch::x86_64::_mm_set_ps(val[3], val[2], val[1], val[0])
}

#[inline]
fn splat_f32(val: f32) -> Self {
unsafe { std::arch::x86_64::_mm_set1_ps(val) }
unsafe fn splat_f32(val: f32) -> Self {
std::arch::x86_64::_mm_set1_ps(val)
}

#[inline]
Expand All @@ -87,9 +129,9 @@ impl SimdVector for std::arch::x86_64::__m128 {
}

#[inline]
fn extract_f32<const N: i32>(self) -> f32 {
unsafe fn extract_f32<const N: i32>(self) -> f32 {
assert!((N as usize) < Self::SIZE);
let bits = unsafe { std::arch::x86_64::_mm_extract_ps::<N>(self) };
let bits = std::arch::x86_64::_mm_extract_ps::<N>(self);
f32::from_bits(bits as u32)
}

Expand All @@ -104,56 +146,179 @@ impl SimdVector for std::arch::x86_64::__m128 {
}

#[inline]
fn add(self, lhs: Self) -> Self {
unsafe { std::arch::x86_64::_mm_add_ps(self, lhs) }
unsafe fn add(self, lhs: Self) -> Self {
std::arch::x86_64::_mm_add_ps(self, lhs)
}

/// Lane-wise subtraction: yields `self - lhs` for each `f32` lane via `_mm_sub_ps`.
#[inline]
unsafe fn sub(self, lhs: Self) -> Self {
    use std::arch::x86_64::_mm_sub_ps;
    _mm_sub_ps(self, lhs)
}

/// Lane-wise multiplication: yields `self * lhs` for each `f32` lane via `_mm_mul_ps`.
#[inline]
unsafe fn mul(self, lhs: Self) -> Self {
    use std::arch::x86_64::_mm_mul_ps;
    _mm_mul_ps(self, lhs)
}

/// Lane-wise division: yields `self / lhs` for each `f32` lane via `_mm_div_ps`.
#[inline]
unsafe fn div(self, lhs: Self) -> Self {
    use std::arch::x86_64::_mm_div_ps;
    _mm_div_ps(self, lhs)
}

/// Clears the sign bit of every `f32` lane, producing the lane-wise absolute value.
#[inline]
unsafe fn abs(self) -> Self {
    use std::arch::x86_64::{
        _mm_and_ps, _mm_castsi128_ps, _mm_cmpeq_epi32, _mm_srli_epi32, _mm_undefined_si128,
    };

    // Comparing a register against itself sets every bit, whatever it held before,
    // so the starting contents do not matter.
    let scratch = _mm_undefined_si128();
    let all_ones = _mm_cmpeq_epi32(scratch, scratch);
    // A logical right shift by one turns each 32-bit lane into 0x7fff_ffff,
    // i.e. everything except the sign bit.
    let keep_magnitude = _mm_castsi128_ps(_mm_srli_epi32::<1>(all_ones));
    _mm_and_ps(keep_magnitude, self)
}

/// Computes `self * mul + add` per lane with the fused `_mm_fmadd_ps` instruction.
///
/// Only compiled when the `fma` target feature is enabled at build time.
#[inline]
#[cfg(target_feature = "fma")]
unsafe fn muladd(self, mul: Self, add: Self) -> Self {
    use std::arch::x86_64::_mm_fmadd_ps;
    _mm_fmadd_ps(self, mul, add)
}

/// Computes `self * mul - sub` per lane with the fused `_mm_fmsub_ps` instruction.
///
/// Only compiled when the `fma` target feature is enabled at build time.
#[inline]
#[cfg(target_feature = "fma")]
unsafe fn mulsub(self, mul: Self, sub: Self) -> Self {
    use std::arch::x86_64::_mm_fmsub_ps;
    _mm_fmsub_ps(self, mul, sub)
}

/// Computes `self * mul + add` per lane as a separate multiply followed by an add.
///
/// Fallback used when the `fma` target feature is not enabled at build time.
#[inline]
#[cfg(not(target_feature = "fma"))]
unsafe fn muladd(self, mul: Self, add: Self) -> Self {
    let product = self.mul(mul);
    product.add(add)
}

/// Computes `self * mul - sub` per lane as a separate multiply followed by a subtract.
///
/// Fallback used when the `fma` target feature is not enabled at build time.
#[inline]
#[cfg(not(target_feature = "fma"))]
unsafe fn mulsub(self, mul: Self, sub: Self) -> Self {
    let product = self.mul(mul);
    product.sub(sub)
}
}

#[cfg(target_arch = "x86_64")]
impl SimdVector for std::arch::x86_64::__m256 {
const SIZE: usize = 8;

/// Reports whether the running CPU supports AVX2, checked at runtime.
#[inline]
fn available() -> bool {
    let has_avx2 = std::arch::is_x86_feature_detected!("avx2");
    has_avx2
}

/// Returns a 256-bit vector with all eight `f32` lanes set to zero.
#[inline]
unsafe fn zero() -> Self {
    use std::arch::x86_64::_mm256_setzero_ps;
    _mm256_setzero_ps()
}

/// Builds a vector whose lane `i` holds `val[i]`.
///
/// Panics if `N` is not `Self::SIZE`.
#[inline]
unsafe fn set<const N: usize>(val: [f32; N]) -> Self {
    use std::arch::x86_64::_mm256_setr_ps;

    assert_eq!(N, Self::SIZE);
    // `_mm256_setr_ps` takes its arguments lowest lane first, so the array can be
    // passed in ascending order.
    _mm256_setr_ps(
        val[0], val[1], val[2], val[3], val[4], val[5], val[6], val[7],
    )
}

#[inline]
fn sub(self, lhs: Self) -> Self {
unsafe { std::arch::x86_64::_mm_sub_ps(self, lhs) }
unsafe fn splat_f32(val: f32) -> Self {
std::arch::x86_64::_mm256_set1_ps(val)
}

#[inline]
fn mul(self, lhs: Self) -> Self {
unsafe { std::arch::x86_64::_mm_mul_ps(self, lhs) }
unsafe fn load(ptr: *const f32) -> Self {
std::arch::x86_64::_mm256_loadu_ps(ptr)
}

#[inline]
fn div(self, lhs: Self) -> Self {
unsafe { std::arch::x86_64::_mm_div_ps(self, lhs) }
unsafe fn load_aligned(ptr: *const f32) -> Self {
std::arch::x86_64::_mm256_load_ps(ptr)
}

#[inline]
fn abs(self) -> Self {
unsafe {
std::arch::x86_64::_mm_andnot_ps(
Self::splat_f32(f32::from_bits(0x80000000)),
self,
)
unsafe fn extract_f32<const N: i32>(self) -> f32 {
unsafe fn inner<const HI: i32, const LO: i32>(val: std::arch::x86_64::__m256) -> f32 {
std::arch::x86_64::_mm256_extractf128_ps::<HI>(val).extract_f32::<LO>()
}

assert!((N as usize) < Self::SIZE);
match N {
0..=3 => inner::<0, N>(self),
4 => inner::<1, 0>(self),
5 => inner::<1, 1>(self),
6 => inner::<1, 2>(self),
7 => inner::<1, 3>(self),
// SAFETY: 0 <= N < 8 by assertion.
_ => std::hint::unreachable_unchecked(),
}
}

/// Writes the eight `f32` lanes to `ptr` with `_mm256_storeu_ps`; `ptr` need not be aligned.
#[inline]
unsafe fn store(self, ptr: *mut f32) {
    use std::arch::x86_64::_mm256_storeu_ps;
    _mm256_storeu_ps(ptr, self);
}

/// Writes the eight `f32` lanes to `ptr` with `_mm256_store_ps`.
///
/// `ptr` must satisfy the alignment that `_mm256_store_ps` requires (32 bytes).
#[inline]
unsafe fn store_aligned(self, ptr: *mut f32) {
    use std::arch::x86_64::_mm256_store_ps;
    _mm256_store_ps(ptr, self);
}

/// Lane-wise addition: yields `self + lhs` for each `f32` lane via `_mm256_add_ps`.
#[inline]
unsafe fn add(self, lhs: Self) -> Self {
    use std::arch::x86_64::_mm256_add_ps;
    _mm256_add_ps(self, lhs)
}

/// Lane-wise subtraction: yields `self - lhs` for each `f32` lane via `_mm256_sub_ps`.
#[inline]
unsafe fn sub(self, lhs: Self) -> Self {
    use std::arch::x86_64::_mm256_sub_ps;
    _mm256_sub_ps(self, lhs)
}

/// Lane-wise multiplication: yields `self * lhs` for each `f32` lane via `_mm256_mul_ps`.
#[inline]
unsafe fn mul(self, lhs: Self) -> Self {
    use std::arch::x86_64::_mm256_mul_ps;
    _mm256_mul_ps(self, lhs)
}

/// Lane-wise division: yields `self / lhs` for each `f32` lane via `_mm256_div_ps`.
#[inline]
unsafe fn div(self, lhs: Self) -> Self {
    use std::arch::x86_64::_mm256_div_ps;
    _mm256_div_ps(self, lhs)
}

/// Clears the sign bit of every `f32` lane, producing the lane-wise absolute value.
///
/// Marked `avx2` because the mask is built with 256-bit integer intrinsics.
#[inline]
#[target_feature(enable = "avx2")]
unsafe fn abs(self) -> Self {
    use std::arch::x86_64::{
        _mm256_and_ps, _mm256_castsi256_ps, _mm256_cmpeq_epi32, _mm256_srli_epi32,
        _mm256_undefined_si256,
    };

    // Comparing a register against itself sets every bit, whatever it held before,
    // so the starting contents do not matter.
    let scratch = _mm256_undefined_si256();
    let all_ones = _mm256_cmpeq_epi32(scratch, scratch);
    // A logical right shift by one turns each 32-bit lane into 0x7fff_ffff,
    // i.e. everything except the sign bit.
    let keep_magnitude = _mm256_castsi256_ps(_mm256_srli_epi32::<1>(all_ones));
    _mm256_and_ps(keep_magnitude, self)
}

#[inline]
#[cfg(target_feature = "fma")]
fn muladd(self, mul: Self, add: Self) -> Self {
unsafe { std::arch::x86_64::_mm_fmadd_ps(self, mul, add) }
unsafe fn muladd(self, mul: Self, add: Self) -> Self {
std::arch::x86_64::_mm256_fmadd_ps(self, mul, add)
}

#[inline]
#[cfg(target_feature = "fma")]
fn mulsub(self, mul: Self, sub: Self) -> Self {
unsafe { std::arch::x86_64::_mm_fmadd_ps(self, mul, sub) }
unsafe fn mulsub(self, mul: Self, sub: Self) -> Self {
std::arch::x86_64::_mm256_fmsub_ps(self, mul, sub)
}

#[inline]
#[cfg(not(target_feature = "fma"))]
fn muladd(self, mul: Self, add: Self) -> Self {
unsafe fn muladd(self, mul: Self, add: Self) -> Self {
self.mul(mul).add(add)
}

#[inline]
#[cfg(not(target_feature = "fma"))]
fn mulsub(self, mul: Self, sub: Self) -> Self {
unsafe fn mulsub(self, mul: Self, sub: Self) -> Self {
self.mul(mul).sub(sub)
}
}
Loading

0 comments on commit a94905e

Please sign in to comment.