diff --git a/crates/jxl-grid/src/simd.rs b/crates/jxl-grid/src/simd.rs
index 74815982..333b2839 100644
--- a/crates/jxl-grid/src/simd.rs
+++ b/crates/jxl-grid/src/simd.rs
@@ -7,47 +7,89 @@ pub trait SimdVector: Copy {
     fn available() -> bool;
 
     /// Initialize a SIMD vector with zeroes.
-    fn zero() -> Self;
+    ///
+    /// # Safety
+    /// The CPU must support the vector type.
+    unsafe fn zero() -> Self;
 
     /// Initialize a SIMD vector with given floats.
-    fn set<const N: usize>(val: [f32; N]) -> Self;
+    ///
+    /// # Safety
+    /// The CPU must support the vector type.
+    unsafe fn set<const N: usize>(val: [f32; N]) -> Self;
 
     /// Initialize a SIMD vector filled with given float.
-    fn splat_f32(val: f32) -> Self;
+    ///
+    /// # Safety
+    /// The CPU must support the vector type.
+    unsafe fn splat_f32(val: f32) -> Self;
 
     /// Load a SIMD vector from memory.
     ///
     /// The pointer doesn't need to be aligned.
     ///
     /// # Safety
-    /// The given pointer must be valid.
+    /// The CPU must support the vector type, and the given pointer must be valid.
     unsafe fn load(ptr: *const f32) -> Self;
 
     /// Load a SIMD vector from memory with aligned pointer.
     ///
     /// # Safety
-    /// The given pointer must be valid and properly aligned.
+    /// The CPU must support the vector type, and the given pointer must be valid and properly
+    /// aligned.
    unsafe fn load_aligned(ptr: *const f32) -> Self;
 
     /// Extract a single element from the SIMD vector.
-    fn extract_f32<const N: i32>(self) -> f32;
+    ///
+    /// # Safety
+    /// The CPU must support the vector type.
+    unsafe fn extract_f32<const N: i32>(self) -> f32;
 
     /// Store the SIMD vector to memory.
     ///
     /// The pointer doesn't need to be aligned.
     ///
     /// # Safety
-    /// The given pointer must be valid.
+    /// The CPU must support the vector type, and the given pointer must be valid.
     unsafe fn store(self, ptr: *mut f32);
 
     /// Store the SIMD vector to memory with aligned pointer.
     ///
     /// # Safety
-    /// The given pointer must be valid and properly aligned.
+    /// The CPU must support the vector type, and the given pointer must be valid and properly
+    /// aligned.
     unsafe fn store_aligned(self, ptr: *mut f32);
 
-    fn add(self, lhs: Self) -> Self;
-    fn sub(self, lhs: Self) -> Self;
-    fn mul(self, lhs: Self) -> Self;
-    fn div(self, lhs: Self) -> Self;
-    fn abs(self) -> Self;
+    /// Add two vectors element-wise.
+    ///
+    /// # Safety
+    /// The CPU must support the vector type.
+    unsafe fn add(self, lhs: Self) -> Self;
+    /// Subtract two vectors element-wise.
+    ///
+    /// # Safety
+    /// The CPU must support the vector type.
+    unsafe fn sub(self, lhs: Self) -> Self;
+    /// Multiply two vectors element-wise.
+    ///
+    /// # Safety
+    /// The CPU must support the vector type.
+    unsafe fn mul(self, lhs: Self) -> Self;
+    /// Divide two vectors element-wise.
+    ///
+    /// # Safety
+    /// The CPU must support the vector type.
+    unsafe fn div(self, lhs: Self) -> Self;
+    /// Compute the absolute value for each element of the vector.
+    ///
+    /// # Safety
+    /// The CPU must support the vector type.
+    unsafe fn abs(self) -> Self;
 
-    fn muladd(self, mul: Self, add: Self) -> Self;
-    fn mulsub(self, mul: Self, sub: Self) -> Self;
+    /// Compute `self * mul + add` element-wise.
+    ///
+    /// # Safety
+    /// The CPU must support the vector type.
+    unsafe fn muladd(self, mul: Self, add: Self) -> Self;
+    /// Compute `self * mul - sub` element-wise.
+    ///
+    /// # Safety
+    /// The CPU must support the vector type.
+    unsafe fn mulsub(self, mul: Self, sub: Self) -> Self;
 }
 
 #[cfg(target_arch = "x86_64")]
@@ -61,19 +103,19 @@ impl SimdVector for std::arch::x86_64::__m128 {
     }
 
     #[inline]
-    fn zero() -> Self {
-        unsafe { std::arch::x86_64::_mm_setzero_ps() }
+    unsafe fn zero() -> Self {
+        std::arch::x86_64::_mm_setzero_ps()
     }
 
     #[inline]
-    fn set<const N: usize>(val: [f32; N]) -> Self {
+    unsafe fn set<const N: usize>(val: [f32; N]) -> Self {
         assert_eq!(N, Self::SIZE);
-        unsafe { std::arch::x86_64::_mm_set_ps(val[3], val[2], val[1], val[0]) }
+        std::arch::x86_64::_mm_set_ps(val[3], val[2], val[1], val[0])
     }
 
     #[inline]
-    fn splat_f32(val: f32) -> Self {
-        unsafe { std::arch::x86_64::_mm_set1_ps(val) }
+    unsafe fn splat_f32(val: f32) -> Self {
+        std::arch::x86_64::_mm_set1_ps(val)
     }
 
     #[inline]
@@ -87,9 +129,9 @@ impl SimdVector for std::arch::x86_64::__m128 {
     }
 
     #[inline]
-    fn extract_f32<const N: i32>(self) -> f32 {
+    unsafe fn extract_f32<const N: i32>(self) -> f32 {
         assert!((N as usize) < Self::SIZE);
-        let bits = unsafe { std::arch::x86_64::_mm_extract_ps::<N>(self) };
+        let bits = std::arch::x86_64::_mm_extract_ps::<N>(self);
         f32::from_bits(bits as u32)
     }
 
@@ -104,56 +146,179 @@ impl SimdVector for std::arch::x86_64::__m128 {
     }
 
     #[inline]
-    fn add(self, lhs: Self) -> Self {
-        unsafe { std::arch::x86_64::_mm_add_ps(self, lhs) }
+    unsafe fn add(self, lhs: Self) -> Self {
+        std::arch::x86_64::_mm_add_ps(self, lhs)
+    }
+
+    #[inline]
+    unsafe fn sub(self, lhs: Self) -> Self {
+        std::arch::x86_64::_mm_sub_ps(self, lhs)
+    }
+
+    #[inline]
+    unsafe fn mul(self, lhs: Self) -> Self {
+        std::arch::x86_64::_mm_mul_ps(self, lhs)
+    }
+
+    #[inline]
+    unsafe fn div(self, lhs: Self) -> Self {
+        std::arch::x86_64::_mm_div_ps(self, lhs)
+    }
+
+    #[inline]
+    unsafe fn abs(self) -> Self {
+        let x = std::arch::x86_64::_mm_undefined_si128();
+        let mask = std::arch::x86_64::_mm_srli_epi32::<1>(
+            std::arch::x86_64::_mm_cmpeq_epi32(x, x)
+        );
+        std::arch::x86_64::_mm_and_ps(
+            std::arch::x86_64::_mm_castsi128_ps(mask),
+            self,
+        )
+    }
+
+    #[inline]
+    #[cfg(target_feature = "fma")]
+    unsafe fn muladd(self, mul: Self, add: Self) -> Self {
+        std::arch::x86_64::_mm_fmadd_ps(self, mul, add)
+    }
+
+    #[inline]
+    #[cfg(target_feature = "fma")]
+    unsafe fn mulsub(self, mul: Self, sub: Self) -> Self {
+        std::arch::x86_64::_mm_fmsub_ps(self, mul, sub)
+    }
+
+    #[inline]
+    #[cfg(not(target_feature = "fma"))]
+    unsafe fn muladd(self, mul: Self, add: Self) -> Self {
+        self.mul(mul).add(add)
+    }
+
+    #[inline]
+    #[cfg(not(target_feature = "fma"))]
+    unsafe fn mulsub(self, mul: Self, sub: Self) -> Self {
+        self.mul(mul).sub(sub)
+    }
+}
+
+#[cfg(target_arch = "x86_64")]
+impl SimdVector for std::arch::x86_64::__m256 {
+    const SIZE: usize = 8;
+
+    #[inline]
+    fn available() -> bool {
+        is_x86_feature_detected!("avx2")
+    }
+
+    #[inline]
+    unsafe fn zero() -> Self {
+        std::arch::x86_64::_mm256_setzero_ps()
+    }
+
+    #[inline]
+    unsafe fn set<const N: usize>(val: [f32; N]) -> Self {
+        assert_eq!(N, Self::SIZE);
+        std::arch::x86_64::_mm256_set_ps(val[7], val[6], val[5], val[4], val[3], val[2], val[1], val[0])
     }
 
     #[inline]
-    fn sub(self, lhs: Self) -> Self {
-        unsafe { std::arch::x86_64::_mm_sub_ps(self, lhs) }
+    unsafe fn splat_f32(val: f32) -> Self {
+        std::arch::x86_64::_mm256_set1_ps(val)
     }
 
     #[inline]
-    fn mul(self, lhs: Self) -> Self {
-        unsafe { std::arch::x86_64::_mm_mul_ps(self, lhs) }
+    unsafe fn load(ptr: *const f32) -> Self {
+        std::arch::x86_64::_mm256_loadu_ps(ptr)
     }
 
     #[inline]
-    fn div(self, lhs: Self) -> Self {
-        unsafe { std::arch::x86_64::_mm_div_ps(self, lhs) }
+    unsafe fn load_aligned(ptr: *const f32) -> Self {
+        std::arch::x86_64::_mm256_load_ps(ptr)
     }
 
     #[inline]
-    fn abs(self) -> Self {
-        unsafe {
-            std::arch::x86_64::_mm_andnot_ps(
-                Self::splat_f32(f32::from_bits(0x80000000)),
-                self,
-            )
+    unsafe fn extract_f32<const N: i32>(self) -> f32 {
+        unsafe fn inner<const I: i32, const N: i32>(val: std::arch::x86_64::__m256) -> f32 {
+            std::arch::x86_64::_mm256_extractf128_ps::<I>(val).extract_f32::<N>()
+        }
+
+        assert!((N as usize) < Self::SIZE);
+        match N {
+            0..=3 => inner::<0, N>(self),
+            4 => inner::<1, 0>(self),
+            5 => inner::<1, 1>(self),
+            6 => inner::<1, 2>(self),
+            7 => inner::<1, 3>(self),
+            // SAFETY: 0 <= N < 8 by assertion.
+            _ => std::hint::unreachable_unchecked(),
         }
     }
 
+    #[inline]
+    unsafe fn store(self, ptr: *mut f32) {
+        std::arch::x86_64::_mm256_storeu_ps(ptr, self);
+    }
+
+    #[inline]
+    unsafe fn store_aligned(self, ptr: *mut f32) {
+        std::arch::x86_64::_mm256_store_ps(ptr, self);
+    }
+
+    #[inline]
+    unsafe fn add(self, lhs: Self) -> Self {
+        std::arch::x86_64::_mm256_add_ps(self, lhs)
+    }
+
+    #[inline]
+    unsafe fn sub(self, lhs: Self) -> Self {
+        std::arch::x86_64::_mm256_sub_ps(self, lhs)
+    }
+
+    #[inline]
+    unsafe fn mul(self, lhs: Self) -> Self {
+        std::arch::x86_64::_mm256_mul_ps(self, lhs)
+    }
+
+    #[inline]
+    unsafe fn div(self, lhs: Self) -> Self {
+        std::arch::x86_64::_mm256_div_ps(self, lhs)
+    }
+
+    #[inline]
+    #[target_feature(enable = "avx2")]
+    unsafe fn abs(self) -> Self {
+        let x = std::arch::x86_64::_mm256_undefined_si256();
+        let mask = std::arch::x86_64::_mm256_srli_epi32::<1>(
+            std::arch::x86_64::_mm256_cmpeq_epi32(x, x)
+        );
+        std::arch::x86_64::_mm256_and_ps(
+            std::arch::x86_64::_mm256_castsi256_ps(mask),
+            self,
+        )
+    }
+
     #[inline]
     #[cfg(target_feature = "fma")]
-    fn muladd(self, mul: Self, add: Self) -> Self {
-        unsafe { std::arch::x86_64::_mm_fmadd_ps(self, mul, add) }
+    unsafe fn muladd(self, mul: Self, add: Self) -> Self {
+        std::arch::x86_64::_mm256_fmadd_ps(self, mul, add)
     }
 
     #[inline]
     #[cfg(target_feature = "fma")]
-    fn mulsub(self, mul: Self, sub: Self) -> Self {
-        unsafe { std::arch::x86_64::_mm_fmadd_ps(self, mul, sub) }
+    unsafe fn mulsub(self, mul: Self, sub: Self) -> Self {
+        std::arch::x86_64::_mm256_fmsub_ps(self, mul, sub)
     }
 
     #[inline]
     #[cfg(not(target_feature = "fma"))]
-    fn muladd(self, mul: Self, add: Self) -> Self {
+    unsafe fn muladd(self, mul: Self, add: Self) -> Self {
         self.mul(mul).add(add)
     }
 
     #[inline]
     #[cfg(not(target_feature = "fma"))]
-    fn mulsub(self, mul: Self, sub: Self) -> Self {
+    unsafe fn mulsub(self, mul: Self, sub: Self) -> Self {
         self.mul(mul).sub(sub)
     }
 }
diff --git a/crates/jxl-render/src/dct/x86_64/mod.rs b/crates/jxl-render/src/dct/x86_64/mod.rs
index da79563d..bbace83f 100644
--- a/crates/jxl-render/src/dct/x86_64/mod.rs
+++ b/crates/jxl-render/src/dct/x86_64/mod.rs
@@ -22,7 +22,9 @@ pub fn dct_2d(io: &mut CutGrid<'_>, direction: DctDirection) {
     };
 
     if io.width() == 2 && io.height() == 8 {
-        return dct8x8(&mut io, direction);
+        unsafe {
+            return dct8x8(&mut io, direction);
+        }
     }
 
     dct_2d_lane(&mut io, direction);
@@ -37,7 +39,7 @@ fn dct_2d_lane(io: &mut CutGrid<'_, Lane>, direction: DctDirection) {
     }
 }
 
-fn dct4_vec_forward(v: Lane) -> Lane {
+unsafe fn dct4_vec_forward(v: Lane) -> Lane {
     const SEC0: f32 = 0.5411961;
     const SEC1: f32 = 1.306563;
 
@@ -62,7 +64,7 @@ fn dct4_vec_forward(v: Lane) -> Lane {
     a.muladd(mul_a, b.mul(mul_b))
 }
 
-fn dct4_vec_inverse(v: Lane) -> Lane {
+unsafe fn dct4_vec_inverse(v: Lane) -> Lane {
     const SEC0: f32 = 0.5411961;
     const SEC1: f32 = 1.306563;
 
@@ -87,7 +89,7 @@ fn dct4_vec_inverse(v: Lane) -> Lane {
     tmp_b.muladd(mul, tmp_a)
 }
 
-fn dct8_vec_forward(vl: Lane, vr: Lane) -> (Lane, Lane) {
+unsafe fn dct8_vec_forward(vl: Lane, vr: Lane) -> (Lane, Lane) {
     #[allow(clippy::excessive_precision)]
     let sec_vec = Lane::set([
         0.2548977895520796,
@@ -109,7 +111,7 @@ fn dct8_vec_forward(vl: Lane, vr: Lane) -> (Lane, Lane) {
     )
 }
 
-fn dct8_vec_inverse(vl: Lane, vr: Lane) -> (Lane, Lane) {
+unsafe fn dct8_vec_inverse(vl: Lane, vr: Lane) -> (Lane, Lane) {
     #[allow(clippy::excessive_precision)]
     let sec_vec = Lane::set([
         0.5097955791041592,
@@ -132,7 +134,7 @@ fn dct8_vec_inverse(vl: Lane, vr: Lane) -> (Lane, Lane) {
     )
 }
 
-fn dct8x8(io: &mut CutGrid<'_, Lane>, direction: DctDirection) {
+unsafe fn dct8x8(io: &mut CutGrid<'_, Lane>, direction: DctDirection) {
     let (mut col0, mut col1) = io.split_horizontal(1);
 
     if direction == DctDirection::Forward {
@@ -156,7 +158,7 @@ fn dct8x8(io: &mut CutGrid<'_, Lane>, direction: DctDirection) {
     }
 }
 
-fn column_dct_lane(
+unsafe fn column_dct_lane(
     io: &mut CutGrid<'_, Lane>,
     scratch: &mut [Lane],
     direction: DctDirection,
@@ -178,7 +180,7 @@ fn column_dct_lane(
     }
 }
 
-fn row_dct_lane(
+unsafe fn row_dct_lane(
     io: &mut CutGrid<'_, Lane>,
     scratch: &mut [Lane],
     direction: DctDirection,
@@ -202,7 +204,7 @@ fn row_dct_lane(
     }
 }
 
-fn dct4_forward(input: [Lane; 4]) -> [Lane; 4] {
+unsafe fn dct4_forward(input: [Lane; 4]) -> [Lane; 4] {
     let sec0 = Lane::splat_f32(0.5411961 / 4.0);
     let sec1 = Lane::splat_f32(1.306563 / 4.0);
     let quarter = Lane::splat_f32(0.25);
@@ -223,7 +225,7 @@ fn dct4_forward(input: [Lane; 4]) -> [Lane; 4] {
     ]
 }
 
-fn dct4_inverse(input: [Lane; 4]) -> [Lane; 4] {
+unsafe fn dct4_inverse(input: [Lane; 4]) -> [Lane; 4] {
     let sec0 = Lane::splat_f32(0.5411961);
     let sec1 = Lane::splat_f32(1.306563);
     let sqrt2 = Lane::splat_f32(std::f32::consts::SQRT_2);
@@ -243,7 +245,7 @@ fn dct4_inverse(input: [Lane; 4]) -> [Lane; 4] {
     ]
 }
 
-fn dct8_forward(io: &mut CutGrid<'_, Lane>) {
+unsafe fn dct8_forward(io: &mut CutGrid<'_, Lane>) {
     assert!(io.height() == 8);
     let half = Lane::splat_f32(0.5);
     let sqrt2 = Lane::splat_f32(std::f32::consts::SQRT_2);
@@ -273,7 +275,7 @@ fn dct8_forward(io: &mut CutGrid<'_, Lane>) {
     *io.get_mut(0, 7) = output1[3];
 }
 
-fn dct8_inverse(io: &mut CutGrid<'_, Lane>) {
+unsafe fn dct8_inverse(io: &mut CutGrid<'_, Lane>) {
     assert!(io.height() == 8);
     let sqrt2 = Lane::splat_f32(std::f32::consts::SQRT_2);
     let sec = consts::sec_half_small(8);
@@ -294,7 +296,7 @@ fn dct8_inverse(io: &mut CutGrid<'_, Lane>) {
     }
 }
 
-fn dct(io: &mut [Lane], scratch: &mut [Lane], direction: DctDirection) {
+unsafe fn dct(io: &mut [Lane], scratch: &mut [Lane], direction: DctDirection) {
     let n = io.len();
     assert!(scratch.len() == n);
 
diff --git a/crates/jxl-render/src/filter/epf.rs b/crates/jxl-render/src/filter/epf.rs
index f43fe2d4..e9d2a15b 100644
--- a/crates/jxl-render/src/filter/epf.rs
+++ b/crates/jxl-render/src/filter/epf.rs
@@ -119,33 +119,14 @@ pub fn apply_epf(
         }
     }
 
-        #[cfg(target_arch = "x86_64")]
-        {
-            super::x86_64::epf_step0_sse2(
-                &fb_in,
-                &mut fb_out,
-                sigma_grid,
-                channel_scale,
-                sigma.border_sad_mul,
-                sigma.pass0_sigma_scale,
-            );
-        }
-        #[cfg(not(target_arch = "x86_64"))]
-        {
-            super::generic::epf_step(
-                &fb_in,
-                &mut fb_out,
-                sigma_grid,
-                channel_scale,
-                sigma.border_sad_mul,
-                sigma.pass0_sigma_scale,
-                &[
-                    (0, -1), (-1, 0), (1, 0), (0, 1),
-                    (0, -2), (-1, -1), (1, -1), (-2, 0), (2, 0), (-1, 1), (1, 1), (0, 2),
-                ],
-                &[(0, 0), (0, -1), (-1, 0), (1, 0), (0, 1)],
-            );
-        }
+        super::impls::epf_step0(
+            &fb_in,
+            &mut fb_out,
+            sigma_grid,
+            channel_scale,
+            sigma.border_sad_mul,
+            sigma.pass0_sigma_scale,
+        );
 
         std::mem::swap(&mut fb_in, &mut fb_out);
     }
@@ -176,30 +157,14 @@ pub fn apply_epf(
         }
     }
 
-        #[cfg(target_arch = "x86_64")]
-        {
-            super::x86_64::epf_step1_sse2(
-                &fb_in,
-                &mut fb_out,
-                sigma_grid,
-                channel_scale,
-                sigma.border_sad_mul,
-                1.0,
-            );
-        }
-        #[cfg(not(target_arch = "x86_64"))]
-        {
-            super::generic::epf_step(
-                &fb_in,
-                &mut fb_out,
-                sigma_grid,
-                channel_scale,
-                sigma.border_sad_mul,
-                1.0,
-                &[(0, -1), (-1, 0), (1, 0), (0, 1)],
-                &[(0, 0), (0, -1), (-1, 0), (1, 0), (0, 1)],
-            );
-        }
+        super::impls::epf_step1(
+            &fb_in,
+            &mut fb_out,
+            sigma_grid,
+            channel_scale,
+            sigma.border_sad_mul,
+            1.0,
+        );
 
         std::mem::swap(&mut fb_in, &mut fb_out);
     }
@@ -230,30 +195,14 @@ pub fn apply_epf(
         }
     }
 
-        #[cfg(target_arch = "x86_64")]
-        {
-            super::x86_64::epf_step2_sse2(
-                &fb_in,
-                &mut fb_out,
-                sigma_grid,
-                channel_scale,
-                sigma.border_sad_mul,
-                sigma.pass2_sigma_scale,
-            );
-        }
-        #[cfg(not(target_arch = "x86_64"))]
-        {
-            super::generic::epf_step(
-                &fb_in,
-                &mut fb_out,
-                sigma_grid,
-                channel_scale,
-                sigma.border_sad_mul,
-                sigma.pass2_sigma_scale,
-                &[(0, -1), (-1, 0), (1, 0), (0, 1)],
-                &[(0, 0)],
-            );
-        }
+        super::impls::epf_step2(
+            &fb_in,
+            &mut fb_out,
+            sigma_grid,
+            channel_scale,
+            sigma.border_sad_mul,
+            sigma.pass2_sigma_scale,
+        );
 
         std::mem::swap(&mut fb_in, &mut fb_out);
     }
diff --git a/crates/jxl-render/src/filter/impls.rs b/crates/jxl-render/src/filter/impls.rs
new file mode 100644
index 00000000..9f99b301
--- /dev/null
+++ b/crates/jxl-render/src/filter/impls.rs
@@ -0,0 +1,10 @@
+#[cfg(not(target_arch = "x86_64"))]
+mod generic;
+#[cfg(target_arch = "x86_64")]
+mod x86_64;
+
+#[cfg(not(target_arch = "x86_64"))]
+pub use generic::*;
+
+#[cfg(target_arch = "x86_64")]
+pub use x86_64::*;
diff --git a/crates/jxl-render/src/filter/generic.rs b/crates/jxl-render/src/filter/impls/generic.rs
similarity index 66%
rename from crates/jxl-render/src/filter/generic.rs
rename to crates/jxl-render/src/filter/impls/generic.rs
index 8e00d932..8aa6a280 100644
--- a/crates/jxl-render/src/filter/generic.rs
+++ b/crates/jxl-render/src/filter/impls/generic.rs
@@ -7,7 +7,7 @@ fn weight(scaled_distance: f32, sigma: f32, step_multiplier: f32) -> f32 {
 }
 
 #[allow(clippy::too_many_arguments)]
-pub fn epf_step(
+fn epf_step(
     input: &[SimpleGrid<f32>; 3],
     output: &mut [SimpleGrid<f32>; 3],
     sigma_grid: &SimpleGrid<f32>,
     channel_scale: [f32; 3],
@@ -85,3 +85,66 @@ pub fn epf_step(
         }
     }
 }
+
+pub fn epf_step0(
+    input: &[SimpleGrid<f32>; 3],
+    output: &mut [SimpleGrid<f32>; 3],
+    sigma_grid: &SimpleGrid<f32>,
+    channel_scale: [f32; 3],
+    border_sad_mul: f32,
+    step_multiplier: f32,
+) {
+    epf_step(
+        input,
+        output,
+        sigma_grid,
+        channel_scale,
+        border_sad_mul,
+        step_multiplier,
+        &[
+            (0, -1), (-1, 0), (1, 0), (0, 1),
+            (0, -2), (-1, -1), (1, -1), (-2, 0), (2, 0), (-1, 1), (1, 1), (0, 2),
+        ],
+        &[(0, 0), (0, -1), (-1, 0), (1, 0), (0, 1)],
+    );
+}
+
+pub fn epf_step1(
+    input: &[SimpleGrid<f32>; 3],
+    output: &mut [SimpleGrid<f32>; 3],
+    sigma_grid: &SimpleGrid<f32>,
+    channel_scale: [f32; 3],
+    border_sad_mul: f32,
+    step_multiplier: f32,
+) {
+    epf_step(
+        input,
+        output,
+        sigma_grid,
+        channel_scale,
+        border_sad_mul,
+        step_multiplier,
+        &[(0, -1), (-1, 0), (1, 0), (0, 1)],
+        &[(0, 0), (0, -1), (-1, 0), (1, 0), (0, 1)],
+    );
+}
+
+pub fn epf_step2(
+    input: &[SimpleGrid<f32>; 3],
+    output: &mut [SimpleGrid<f32>; 3],
+    sigma_grid: &SimpleGrid<f32>,
+    channel_scale: [f32; 3],
+    border_sad_mul: f32,
+    step_multiplier: f32,
+) {
+    epf_step(
+        input,
+        output,
+        sigma_grid,
+        channel_scale,
+        border_sad_mul,
+        step_multiplier,
+        &[(0, -1), (-1, 0), (1, 0), (0, 1)],
+        &[(0, 0)],
+    );
+}
diff --git a/crates/jxl-render/src/filter/impls/x86_64.rs b/crates/jxl-render/src/filter/impls/x86_64.rs
new file mode 100644
index 00000000..6d5f602b
--- /dev/null
+++ b/crates/jxl-render/src/filter/impls/x86_64.rs
@@ -0,0 +1,109 @@
+use jxl_grid::SimpleGrid;
+
+mod epf_sse2;
+mod epf_avx2;
+
+pub fn epf_step0(
+    input: &[SimpleGrid<f32>; 3],
+    output: &mut [SimpleGrid<f32>; 3],
+    sigma_grid: &SimpleGrid<f32>,
+    channel_scale: [f32; 3],
+    border_sad_mul: f32,
+    step_multiplier: f32,
+) {
+    if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
+        // SAFETY: Features are checked above.
+        unsafe {
+            return epf_avx2::epf_step0_avx2(
+                input,
+                output,
+                sigma_grid,
+                channel_scale,
+                border_sad_mul,
+                step_multiplier,
+            );
+        }
+    }
+
+    // SAFETY: x86_64 always supports SSE2.
+    unsafe {
+        epf_sse2::epf_step0_sse2(
+            input,
+            output,
+            sigma_grid,
+            channel_scale,
+            border_sad_mul,
+            step_multiplier,
+        )
+    }
+}
+
+pub fn epf_step1(
+    input: &[SimpleGrid<f32>; 3],
+    output: &mut [SimpleGrid<f32>; 3],
+    sigma_grid: &SimpleGrid<f32>,
+    channel_scale: [f32; 3],
+    border_sad_mul: f32,
+    step_multiplier: f32,
+) {
+    if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
+        // SAFETY: Features are checked above.
+        unsafe {
+            return epf_avx2::epf_step1_avx2(
+                input,
+                output,
+                sigma_grid,
+                channel_scale,
+                border_sad_mul,
+                step_multiplier,
+            );
+        }
+    }
+
+    // SAFETY: x86_64 always supports SSE2.
+    unsafe {
+        epf_sse2::epf_step1_sse2(
+            input,
+            output,
+            sigma_grid,
+            channel_scale,
+            border_sad_mul,
+            step_multiplier,
+        )
+    }
+}
+
+pub fn epf_step2(
+    input: &[SimpleGrid<f32>; 3],
+    output: &mut [SimpleGrid<f32>; 3],
+    sigma_grid: &SimpleGrid<f32>,
+    channel_scale: [f32; 3],
+    border_sad_mul: f32,
+    step_multiplier: f32,
+) {
+    if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
+        // SAFETY: Features are checked above.
+        unsafe {
+            return epf_avx2::epf_step2_avx2(
+                input,
+                output,
+                sigma_grid,
+                channel_scale,
+                border_sad_mul,
+                step_multiplier,
+            );
+        }
+    }
+
+    // SAFETY: x86_64 always supports SSE2.
+    unsafe {
+        epf_sse2::epf_step2_sse2(
+            input,
+            output,
+            sigma_grid,
+            channel_scale,
+            border_sad_mul,
+            step_multiplier,
+        )
+    }
+}
diff --git a/crates/jxl-render/src/filter/impls/x86_64/epf_avx2.rs b/crates/jxl-render/src/filter/impls/x86_64/epf_avx2.rs
new file mode 100644
index 00000000..2662d26a
--- /dev/null
+++ b/crates/jxl-render/src/filter/impls/x86_64/epf_avx2.rs
@@ -0,0 +1,148 @@
+use std::arch::x86_64::*;
+
+use jxl_grid::SimpleGrid;
+use jxl_grid::SimdVector;
+
+type Vector = __m256;
+
+#[inline]
+#[target_feature(enable = "avx2")]
+#[target_feature(enable = "fma")]
+unsafe fn weight_avx2(scaled_distance: Vector, sigma: f32, step_multiplier: Vector) -> Vector {
+    let neg_inv_sigma = Vector::splat_f32(6.6 * (std::f32::consts::FRAC_1_SQRT_2 - 1.0) / sigma).mul(step_multiplier);
+    let result = _mm256_fmadd_ps(
+        scaled_distance,
+        neg_inv_sigma,
+        Vector::splat_f32(1.0),
+    );
+    _mm256_max_ps(result, Vector::zero())
+}
+
+macro_rules! define_epf_avx2 {
+    { $($v:vis unsafe fn $name:ident ($width:ident, $kernel_diff:expr, $dist_diff:expr $(,)?); )* } => {
+        $(
+            #[target_feature(enable = "avx2")]
+            #[target_feature(enable = "fma")]
+            $v unsafe fn $name(
+                input: &[SimpleGrid<f32>; 3],
+                output: &mut [SimpleGrid<f32>; 3],
+                sigma_grid: &SimpleGrid<f32>,
+                channel_scale: [f32; 3],
+                border_sad_mul: f32,
+                step_multiplier: f32,
+            ) {
+                let width = input[0].width();
+                let $width = width as isize;
+                let height = input[0].height();
+                assert_eq!(input[1].width(), width);
+                assert_eq!(input[2].width(), width);
+                assert_eq!(input[1].height(), height);
+                assert_eq!(input[2].height(), height);
+                assert_eq!(output[0].width(), width);
+                assert_eq!(output[1].width(), width);
+                assert_eq!(output[2].width(), width);
+                assert_eq!(output[0].height(), height);
+                assert_eq!(output[1].height(), height);
+                assert_eq!(output[2].height(), height);
+
+                let input_buf = [input[0].buf(), input[1].buf(), input[2].buf()];
+                let mut output_buf = {
+                    let [a, b, c] = output;
+                    [a.buf_mut(), b.buf_mut(), c.buf_mut()]
+                };
+
+                let sm_y_edge = Vector::splat_f32(border_sad_mul * step_multiplier);
+                let sm = Vector::set([
+                    border_sad_mul * step_multiplier, step_multiplier, step_multiplier, step_multiplier,
+                    step_multiplier, step_multiplier, step_multiplier, border_sad_mul * step_multiplier,
+                ]);
+
+                for y in 3..height - 4 {
+                    let sigma_row = &sigma_grid.buf()[(y - 3) / 8 * sigma_grid.width()..][..(width + 1) / 8];
+                    for (vx, &sigma) in sigma_row.iter().enumerate() {
+                        let x = 3 + vx * 8;
+                        let idx_base = y * width + x;
+
+                        // SAFETY: Indexing doesn't go out of bounds since we have padding after image region.
+                        let mut sum_weights = Vector::splat_f32(1.0);
+                        let mut sum_channels = input_buf.map(|buf| {
+                            Vector::load(buf.as_ptr().add(idx_base))
+                        });
+
+                        if sigma < 0.3 {
+                            for (buf, sum) in output_buf.iter_mut().zip(sum_channels) {
+                                sum.store(buf.as_mut_ptr().add(idx_base));
+                            }
+                            continue;
+                        }
+
+                        for kdiff in $kernel_diff {
+                            let kernel_base = idx_base.wrapping_add_signed(kdiff);
+                            let mut dist = Vector::zero();
+                            for (buf, scale) in input_buf.into_iter().zip(channel_scale) {
+                                let scale = Vector::splat_f32(scale);
+                                for diff in $dist_diff {
+                                    dist = _mm256_fmadd_ps(
+                                        scale,
+                                        Vector::load(buf.as_ptr().add(idx_base.wrapping_add_signed(diff))).sub(
+                                            Vector::load(buf.as_ptr().add(kernel_base.wrapping_add_signed(diff)))
+                                        ).abs(),
+                                        dist,
+                                    );
+                                }
+                            }
+
+                            let weight = weight_avx2(
+                                dist,
+                                sigma,
+                                if (y - 3) % 8 == 0 || (y - 3) % 8 == 7 {
+                                    sm_y_edge
+                                } else {
+                                    sm
+                                },
+                            );
+                            sum_weights = sum_weights.add(weight);
+
+                            for (sum, buf) in sum_channels.iter_mut().zip(input_buf) {
+                                *sum = _mm256_fmadd_ps(
+                                    weight,
+                                    Vector::load(buf.as_ptr().add(kernel_base)),
+                                    *sum,
+                                );
+                            }
+                        }
+
+                        for (buf, sum) in output_buf.iter_mut().zip(sum_channels) {
+                            let val = sum.div(sum_weights);
+                            val.store(buf.as_mut_ptr().add(idx_base));
+                        }
+                    }
+                }
+            }
+        )*
+    };
+}
+
+define_epf_avx2! {
+    pub unsafe fn epf_step0_avx2(
+        width,
+        [
+            -2 * width,
+            -1 - width, -width, 1 - width,
+            -2, -1, 1, 2,
+            width - 1, width, width + 1,
+            2 * width,
+        ],
+        [-width, -1, 0, 1, width],
+    );
+    pub unsafe fn epf_step1_avx2(
+        width,
+        [-width, -1, 1, width],
+        [-width, -1, 0, 1, width],
+    );
+    pub unsafe fn epf_step2_avx2(
+        width,
+        [-width, -1, 1, width],
+        [0isize],
+    );
+}
diff --git a/crates/jxl-render/src/filter/x86_64.rs b/crates/jxl-render/src/filter/impls/x86_64/epf_sse2.rs
similarity index 77%
rename from crates/jxl-render/src/filter/x86_64.rs
rename to crates/jxl-render/src/filter/impls/x86_64/epf_sse2.rs
index 06933c9a..44f97abf 100644
--- a/crates/jxl-render/src/filter/x86_64.rs
+++ b/crates/jxl-render/src/filter/impls/x86_64/epf_sse2.rs
@@ -6,16 +6,16 @@ use jxl_grid::SimdVector;
 type Vector = __m128;
 
 #[inline]
-fn weight_sse2(scaled_distance: Vector, sigma: f32, step_multiplier: Vector) -> Vector {
+unsafe fn weight_sse2(scaled_distance: Vector, sigma: f32, step_multiplier: Vector) -> Vector {
     let neg_inv_sigma = Vector::splat_f32(6.6 * (std::f32::consts::FRAC_1_SQRT_2 - 1.0) / sigma).mul(step_multiplier);
     let result = scaled_distance.muladd(neg_inv_sigma, Vector::splat_f32(1.0));
-    unsafe { _mm_max_ps(result, Vector::zero()) }
+    _mm_max_ps(result, Vector::zero())
 }
 
 macro_rules! define_epf_sse2 {
-    { $($v:vis fn $name:ident ($width:ident, $kernel_diff:expr, $dist_diff:expr $(,)?); )* } => {
+    { $($v:vis unsafe fn $name:ident ($width:ident, $kernel_diff:expr, $dist_diff:expr $(,)?); )* } => {
         $(
-            $v fn $name(
+            $v unsafe fn $name(
                 input: &[SimpleGrid<f32>; 3],
                 output: &mut [SimpleGrid<f32>; 3],
                 sigma_grid: &SimpleGrid<f32>,
@@ -52,33 +52,31 @@ macro_rules! define_epf_sse2 {
                         let sigma = *sigma_grid.get((x - 3) / 8, (y - 3) / 8).unwrap();
                         let idx_base = y * width + x;
 
-                        if sigma < 0.3 {
-                            for (input, output) in input_buf.into_iter().zip(&mut output_buf) {
-                                output[idx_base..][..Vector::SIZE].copy_from_slice(&input[idx_base..][..Vector::SIZE]);
-                            }
-                            continue;
-                        }
-
                         // SAFETY: Indexing doesn't go out of bounds since we have padding after image region.
                         let mut sum_weights = Vector::splat_f32(1.0);
                         let mut sum_channels = input_buf.map(|buf| {
-                            unsafe { Vector::load(buf.as_ptr().add(idx_base)) }
+                            Vector::load(buf.as_ptr().add(idx_base))
                        });
 
+                        if sigma < 0.3 {
+                            for (buf, sum) in output_buf.iter_mut().zip(sum_channels) {
+                                sum.store(buf.as_mut_ptr().add(idx_base));
+                            }
+                            continue;
+                        }
+
                         for kdiff in $kernel_diff {
                             let kernel_base = idx_base.wrapping_add_signed(kdiff);
                             let mut dist = Vector::zero();
                             for (buf, scale) in input_buf.into_iter().zip(channel_scale) {
                                 let scale = Vector::splat_f32(scale);
                                 for diff in $dist_diff {
-                                    unsafe {
-                                        dist = scale.muladd(
-                                            Vector::load(buf.as_ptr().add(idx_base.wrapping_add_signed(diff))).sub(
-                                                Vector::load(buf.as_ptr().add(kernel_base.wrapping_add_signed(diff)))
-                                            ).abs(),
-                                            dist,
-                                        );
-                                    }
+                                    dist = scale.muladd(
+                                        Vector::load(buf.as_ptr().add(idx_base.wrapping_add_signed(diff))).sub(
+                                            Vector::load(buf.as_ptr().add(kernel_base.wrapping_add_signed(diff)))
+                                        ).abs(),
+                                        dist,
+                                    );
                                 }
                             }
 
@@ -96,13 +94,13 @@ macro_rules! define_epf_sse2 {
                             sum_weights = sum_weights.add(weight);
 
                             for (sum, buf) in sum_channels.iter_mut().zip(input_buf) {
-                                *sum = weight.muladd(unsafe { Vector::load(buf.as_ptr().add(kernel_base)) }, *sum);
+                                *sum = weight.muladd(Vector::load(buf.as_ptr().add(kernel_base)), *sum);
                             }
                         }
 
                         for (buf, sum) in output_buf.iter_mut().zip(sum_channels) {
                             let val = sum.div(sum_weights);
-                            unsafe { val.store(buf.as_mut_ptr().add(idx_base)); }
+                            val.store(buf.as_mut_ptr().add(idx_base));
                         }
                     }
                 }
@@ -112,7 +110,7 @@ macro_rules! define_epf_sse2 {
 }
 
 define_epf_sse2! {
-    pub fn epf_step0_sse2(
+    pub unsafe fn epf_step0_sse2(
         width,
         [
             -2 * width,
@@ -123,12 +121,12 @@ define_epf_sse2! {
         ],
         [-width, -1, 0, 1, width],
     );
-    pub fn epf_step1_sse2(
+    pub unsafe fn epf_step1_sse2(
        width,
        [-width, -1, 1, width],
        [-width, -1, 0, 1, width],
    );
-    pub fn epf_step2_sse2(
+    pub unsafe fn epf_step2_sse2(
        width,
        [-width, -1, 1, width],
        [0isize],
diff --git a/crates/jxl-render/src/filter/mod.rs b/crates/jxl-render/src/filter/mod.rs
index c07188f5..29195211 100644
--- a/crates/jxl-render/src/filter/mod.rs
+++ b/crates/jxl-render/src/filter/mod.rs
@@ -1,7 +1,4 @@
-#[cfg(not(target_arch = "x86_64"))]
-mod generic;
-#[cfg(target_arch = "x86_64")]
-mod x86_64;
+mod impls;
 
 mod epf;
 mod gabor;