jxl-render: AVX2 version of EPF #87

Merged: 2 commits, Sep 25, 2023
251 changes: 208 additions & 43 deletions crates/jxl-grid/src/simd.rs
@@ -7,47 +7,89 @@ pub trait SimdVector: Copy {
fn available() -> bool;

/// Initialize a SIMD vector with zeroes.
fn zero() -> Self;
///
/// # Safety
/// The CPU must support the vector type.
unsafe fn zero() -> Self;
/// Initialize a SIMD vector with the given floats.
fn set<const N: usize>(val: [f32; N]) -> Self;
///
/// # Safety
/// The CPU must support the vector type.
unsafe fn set<const N: usize>(val: [f32; N]) -> Self;
/// Initialize a SIMD vector filled with the given float.
fn splat_f32(val: f32) -> Self;
///
/// # Safety
/// The CPU must support the vector type.
unsafe fn splat_f32(val: f32) -> Self;
/// Load a SIMD vector from memory.
///
/// The pointer doesn't need to be aligned.
///
/// # Safety
/// The given pointer must be valid.
/// The CPU must support the vector type, and the given pointer must be valid.
unsafe fn load(ptr: *const f32) -> Self;
/// Load a SIMD vector from memory with aligned pointer.
///
/// # Safety
/// The given pointer must be valid and properly aligned.
/// The CPU must support the vector type, and the given pointer must be valid and
/// properly aligned.
unsafe fn load_aligned(ptr: *const f32) -> Self;

/// Extract a single element from the SIMD vector.
fn extract_f32<const N: i32>(self) -> f32;
///
/// # Safety
/// The CPU must support the vector type.
unsafe fn extract_f32<const N: i32>(self) -> f32;
/// Store the SIMD vector to memory.
///
/// The pointer doesn't need to be aligned.
///
/// # Safety
/// The given pointer must be valid.
/// The CPU must support the vector type, and the given pointer must be valid.
unsafe fn store(self, ptr: *mut f32);
/// Store the SIMD vector to memory with aligned pointer.
///
/// # Safety
/// The given pointer must be valid and properly aligned.
/// The CPU must support the vector type, and the given pointer must be valid and
/// properly aligned.
unsafe fn store_aligned(self, ptr: *mut f32);

fn add(self, lhs: Self) -> Self;
fn sub(self, lhs: Self) -> Self;
fn mul(self, lhs: Self) -> Self;
fn div(self, lhs: Self) -> Self;
fn abs(self) -> Self;
/// Add two vectors element-wise.
///
/// # Safety
/// The CPU must support the vector type.
unsafe fn add(self, lhs: Self) -> Self;
/// Subtract two vectors element-wise.
///
/// # Safety
/// The CPU must support the vector type.
unsafe fn sub(self, lhs: Self) -> Self;
/// Multiply two vectors element-wise.
///
/// # Safety
/// The CPU must support the vector type.
unsafe fn mul(self, lhs: Self) -> Self;
/// Divide two vectors element-wise.
///
/// # Safety
/// The CPU must support the vector type.
unsafe fn div(self, lhs: Self) -> Self;
/// Compute the absolute value for each element of the vector.
///
/// # Safety
/// The CPU must support the vector type.
unsafe fn abs(self) -> Self;

fn muladd(self, mul: Self, add: Self) -> Self;
fn mulsub(self, mul: Self, sub: Self) -> Self;
/// Compute `self * mul + add` element-wise.
///
/// # Safety
/// The CPU must support the vector type.
unsafe fn muladd(self, mul: Self, add: Self) -> Self;
/// Compute `self * mul - sub` element-wise.
///
/// # Safety
/// The CPU must support the vector type.
unsafe fn mulsub(self, mul: Self, sub: Self) -> Self;
}
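How a caller is meant to drive this trait (a minimal sketch, not code from this PR; the `jxl_grid::SimdVector` import path and the `scale_row`/`scale` names are assumptions): check `available()` once, then enter `unsafe` code that relies on the vector type being supported.

use jxl_grid::SimdVector; // assumed export path for the trait above

/// Hypothetical worker; the caller must have verified `V::available()`.
unsafe fn scale_row<V: SimdVector>(row: &mut [f32], factor: f32) {
    let factor_v = V::splat_f32(factor);
    let mut chunks = row.chunks_exact_mut(V::SIZE);
    for chunk in &mut chunks {
        V::load(chunk.as_ptr()).mul(factor_v).store(chunk.as_mut_ptr());
    }
    // Scalar tail for lengths that are not a multiple of V::SIZE.
    for x in chunks.into_remainder() {
        *x *= factor;
    }
}

fn scale(row: &mut [f32], factor: f32) {
    #[cfg(target_arch = "x86_64")]
    if std::arch::x86_64::__m256::available() {
        // SAFETY: AVX2 support was verified just above.
        return unsafe { scale_row::<std::arch::x86_64::__m256>(row, factor) };
    }
    for x in row.iter_mut() {
        *x *= factor;
    }
}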

#[cfg(target_arch = "x86_64")]
@@ -61,19 +103,19 @@ impl SimdVector for std::arch::x86_64::__m128 {
}

#[inline]
fn zero() -> Self {
unsafe { std::arch::x86_64::_mm_setzero_ps() }
unsafe fn zero() -> Self {
std::arch::x86_64::_mm_setzero_ps()
}

#[inline]
fn set<const N: usize>(val: [f32; N]) -> Self {
unsafe fn set<const N: usize>(val: [f32; N]) -> Self {
assert_eq!(N, Self::SIZE);
unsafe { std::arch::x86_64::_mm_set_ps(val[3], val[2], val[1], val[0]) }
std::arch::x86_64::_mm_set_ps(val[3], val[2], val[1], val[0])
}

#[inline]
fn splat_f32(val: f32) -> Self {
unsafe { std::arch::x86_64::_mm_set1_ps(val) }
unsafe fn splat_f32(val: f32) -> Self {
std::arch::x86_64::_mm_set1_ps(val)
}

#[inline]
@@ -87,9 +129,9 @@ impl SimdVector for std::arch::x86_64::__m128 {
}

#[inline]
fn extract_f32<const N: i32>(self) -> f32 {
unsafe fn extract_f32<const N: i32>(self) -> f32 {
assert!((N as usize) < Self::SIZE);
let bits = unsafe { std::arch::x86_64::_mm_extract_ps::<N>(self) };
// `_mm_extract_ps` returns the lane's bit pattern as an i32; reinterpret it as f32.
let bits = std::arch::x86_64::_mm_extract_ps::<N>(self);
f32::from_bits(bits as u32)
}

@@ -104,56 +146,179 @@ impl SimdVector for std::arch::x86_64::__m128 {
}

#[inline]
fn add(self, lhs: Self) -> Self {
unsafe { std::arch::x86_64::_mm_add_ps(self, lhs) }
unsafe fn add(self, lhs: Self) -> Self {
std::arch::x86_64::_mm_add_ps(self, lhs)
}

#[inline]
unsafe fn sub(self, lhs: Self) -> Self {
std::arch::x86_64::_mm_sub_ps(self, lhs)
}

#[inline]
unsafe fn mul(self, lhs: Self) -> Self {
std::arch::x86_64::_mm_mul_ps(self, lhs)
}

#[inline]
unsafe fn div(self, lhs: Self) -> Self {
std::arch::x86_64::_mm_div_ps(self, lhs)
}

#[inline]
unsafe fn abs(self) -> Self {
// `cmpeq(x, x)` is all ones even for an undefined register; a logical right
// shift by one then leaves the mask 0x7fff_ffff in every lane.
let x = std::arch::x86_64::_mm_undefined_si128();
let mask = std::arch::x86_64::_mm_srli_epi32::<1>(
std::arch::x86_64::_mm_cmpeq_epi32(x, x)
);
// Clearing the IEEE-754 sign bit yields the absolute value.
std::arch::x86_64::_mm_and_ps(
std::arch::x86_64::_mm_castsi128_ps(mask),
self,
)
}

#[inline]
#[cfg(target_feature = "fma")]
unsafe fn muladd(self, mul: Self, add: Self) -> Self {
std::arch::x86_64::_mm_fmadd_ps(self, mul, add)
}

#[inline]
#[cfg(target_feature = "fma")]
unsafe fn mulsub(self, mul: Self, sub: Self) -> Self {
std::arch::x86_64::_mm_fmsub_ps(self, mul, sub)
}

#[inline]
#[cfg(not(target_feature = "fma"))]
unsafe fn muladd(self, mul: Self, add: Self) -> Self {
self.mul(mul).add(add)
}

#[inline]
#[cfg(not(target_feature = "fma"))]
unsafe fn mulsub(self, mul: Self, sub: Self) -> Self {
self.mul(mul).sub(sub)
}
}
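One design note: the FMA variants above are selected at compile time via `cfg(target_feature = "fma")`, so a build without `-C target-feature=+fma` always takes the mul-then-add path even on FMA-capable CPUs (the two paths can differ in the last bit, since a fused multiply-add rounds only once). A runtime-dispatched alternative, sketched here purely for illustration and not part of this PR, would look like:

use std::arch::x86_64::{__m128, _mm_add_ps, _mm_fmadd_ps, _mm_mul_ps};

#[target_feature(enable = "fma")]
unsafe fn muladd_fma(a: __m128, mul: __m128, add: __m128) -> __m128 {
    _mm_fmadd_ps(a, mul, add)
}

fn muladd_runtime(a: __m128, mul: __m128, add: __m128) -> __m128 {
    if is_x86_feature_detected!("fma") {
        // SAFETY: FMA support was just detected.
        unsafe { muladd_fma(a, mul, add) }
    } else {
        // SAFETY: SSE is baseline on x86_64.
        unsafe { _mm_add_ps(_mm_mul_ps(a, mul), add) }
    }
}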

#[cfg(target_arch = "x86_64")]
impl SimdVector for std::arch::x86_64::__m256 {
const SIZE: usize = 8;

#[inline]
fn available() -> bool {
is_x86_feature_detected!("avx2")
}

#[inline]
unsafe fn zero() -> Self {
std::arch::x86_64::_mm256_setzero_ps()
}

#[inline]
unsafe fn set<const N: usize>(val: [f32; N]) -> Self {
assert_eq!(N, Self::SIZE);
std::arch::x86_64::_mm256_set_ps(val[7], val[6], val[5], val[4], val[3], val[2], val[1], val[0])
}

#[inline]
fn sub(self, lhs: Self) -> Self {
unsafe { std::arch::x86_64::_mm_sub_ps(self, lhs) }
unsafe fn splat_f32(val: f32) -> Self {
std::arch::x86_64::_mm256_set1_ps(val)
}

#[inline]
fn mul(self, lhs: Self) -> Self {
unsafe { std::arch::x86_64::_mm_mul_ps(self, lhs) }
unsafe fn load(ptr: *const f32) -> Self {
std::arch::x86_64::_mm256_loadu_ps(ptr)
}

#[inline]
fn div(self, lhs: Self) -> Self {
unsafe { std::arch::x86_64::_mm_div_ps(self, lhs) }
unsafe fn load_aligned(ptr: *const f32) -> Self {
std::arch::x86_64::_mm256_load_ps(ptr)
}

#[inline]
fn abs(self) -> Self {
unsafe {
std::arch::x86_64::_mm_andnot_ps(
Self::splat_f32(f32::from_bits(0x80000000)),
self,
)
unsafe fn extract_f32<const N: i32>(self) -> f32 {
// Select the 128-bit half with HI, then the lane within that half with LO.
unsafe fn inner<const HI: i32, const LO: i32>(val: std::arch::x86_64::__m256) -> f32 {
std::arch::x86_64::_mm256_extractf128_ps::<HI>(val).extract_f32::<LO>()
}

assert!((N as usize) < Self::SIZE);
match N {
0..=3 => inner::<0, N>(self),
4 => inner::<1, 0>(self),
5 => inner::<1, 1>(self),
6 => inner::<1, 2>(self),
7 => inner::<1, 3>(self),
// SAFETY: 0 <= N < 8 by assertion.
_ => std::hint::unreachable_unchecked(),
}
}
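A worked check of the lane mapping (a self-contained sketch, not part of the PR): lane 6 lives in the upper 128-bit half (HI = 1), at position 2 within that half (LO = 2).

use std::arch::x86_64::*;

fn main() {
    if !is_x86_feature_detected!("avx2") { return; }
    unsafe {
        // _mm256_set_ps takes lanes in high-to-low order, so lane i holds i as f32.
        let v = _mm256_set_ps(7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0, 0.0);
        let hi = _mm256_extractf128_ps::<1>(v); // lanes 4..=7
        let bits = _mm_extract_ps::<2>(hi);     // lane 6 of the original vector
        assert_eq!(f32::from_bits(bits as u32), 6.0);
    }
}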

#[inline]
unsafe fn store(self, ptr: *mut f32) {
std::arch::x86_64::_mm256_storeu_ps(ptr, self);
}

#[inline]
unsafe fn store_aligned(self, ptr: *mut f32) {
std::arch::x86_64::_mm256_store_ps(ptr, self);
}

#[inline]
unsafe fn add(self, lhs: Self) -> Self {
std::arch::x86_64::_mm256_add_ps(self, lhs)
}

#[inline]
unsafe fn sub(self, lhs: Self) -> Self {
std::arch::x86_64::_mm256_sub_ps(self, lhs)
}

#[inline]
unsafe fn mul(self, lhs: Self) -> Self {
std::arch::x86_64::_mm256_mul_ps(self, lhs)
}

#[inline]
unsafe fn div(self, lhs: Self) -> Self {
std::arch::x86_64::_mm256_div_ps(self, lhs)
}

#[inline]
#[target_feature(enable = "avx2")]
unsafe fn abs(self) -> Self {
let x = std::arch::x86_64::_mm256_undefined_si256();
let mask = std::arch::x86_64::_mm256_srli_epi32::<1>(
std::arch::x86_64::_mm256_cmpeq_epi32(x, x)
);
std::arch::x86_64::_mm256_and_ps(
std::arch::x86_64::_mm256_castsi256_ps(mask),
self,
)
}
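This is the vector form of the usual sign-bit clear: comparing a register with itself gives all ones, and a logical right shift by one leaves `0x7fff_ffff` in every lane. The scalar equivalent, for reference:

/// Scalar sketch of the same trick: clearing the IEEE-754 sign bit yields |x|.
fn abs_via_bits(x: f32) -> f32 {
    f32::from_bits(x.to_bits() & 0x7fff_ffff)
}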

#[inline]
#[cfg(target_feature = "fma")]
fn muladd(self, mul: Self, add: Self) -> Self {
unsafe { std::arch::x86_64::_mm_fmadd_ps(self, mul, add) }
unsafe fn muladd(self, mul: Self, add: Self) -> Self {
std::arch::x86_64::_mm256_fmadd_ps(self, mul, add)
}

#[inline]
#[cfg(target_feature = "fma")]
fn mulsub(self, mul: Self, sub: Self) -> Self {
unsafe { std::arch::x86_64::_mm_fmadd_ps(self, mul, sub) }
unsafe fn mulsub(self, mul: Self, sub: Self) -> Self {
std::arch::x86_64::_mm256_fmsub_ps(self, mul, sub)
}

#[inline]
#[cfg(not(target_feature = "fma"))]
fn muladd(self, mul: Self, add: Self) -> Self {
unsafe fn muladd(self, mul: Self, add: Self) -> Self {
self.mul(mul).add(add)
}

#[inline]
#[cfg(not(target_feature = "fma"))]
fn mulsub(self, mul: Self, sub: Self) -> Self {
unsafe fn mulsub(self, mul: Self, sub: Self) -> Self {
self.mul(mul).sub(sub)
}
}
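In the EPF kernel this PR targets, `muladd` is the inner-loop workhorse. A hypothetical single step that folds eight weighted samples into an accumulator (the name and signature are illustrative, not from this PR):

use std::arch::x86_64::__m256;
use jxl_grid::SimdVector; // assumed export path

/// Hypothetical: returns `acc + weight * samples[0..8]`.
///
/// # Safety
/// AVX2 must be available, and `samples` must be valid for 8 reads.
unsafe fn accumulate_weighted(acc: __m256, samples: *const f32, weight: f32) -> __m256 {
    __m256::load(samples).muladd(__m256::splat_f32(weight), acc)
}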