Add AVX2 and AVX512 optimization #1552

Merged
merged 4 commits into from Sep 6, 2024
224 changes: 217 additions & 7 deletions src/lib/openjp2/dwt.c
@@ -52,7 +52,7 @@
#ifdef __SSSE3__
#include <tmmintrin.h>
#endif
#ifdef __AVX2__
#if (defined(__AVX2__) || defined(__AVX512F__))
#include <immintrin.h>
#endif

@@ -66,7 +66,10 @@
#define OPJ_WS(i) v->mem[(i)*2]
#define OPJ_WD(i) v->mem[(1+(i)*2)]

#ifdef __AVX2__
#if defined(__AVX512F__)
/** Number of int32 values in an AVX512 register */
#define VREG_INT_COUNT 16
#elif defined(__AVX2__)
/** Number of int32 values in an AVX2 register */
#define VREG_INT_COUNT 8
#else
@@ -331,6 +334,51 @@ static void opj_dwt_decode_1(const opj_dwt_t *v)

#endif /* STANDARD_SLOW_VERSION */

#if defined(__AVX512F__)
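/* Handles the samples left over after the 32-wide AVX512 loop using 128-bit */
/* SSE registers. Each iteration applies the same two 5/3 lifting steps as   */
/* the scalar tail code and writes 8 interleaved output values. Returns the  */
/* number of 8-sample batches processed so the caller can shrink "leftover". */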
static int32_t loop_short_sse(int32_t len, const int32_t** lf_ptr,
const int32_t** hf_ptr, int32_t** out_ptr,
int32_t* prev_even)
{
int32_t next_even;
__m128i odd, even_m1, unpack1, unpack2;
const int32_t batch = (len - 2) / 8;
const __m128i two = _mm_set1_epi32(2);

for (int32_t i = 0; i < batch; i++) {
const __m128i lf_ = _mm_loadu_si128((__m128i*)(*lf_ptr + 1));
const __m128i hf1_ = _mm_loadu_si128((__m128i*)(*hf_ptr));
const __m128i hf2_ = _mm_loadu_si128((__m128i*)(*hf_ptr + 1));

__m128i even = _mm_add_epi32(hf1_, hf2_);
even = _mm_add_epi32(even, two);
even = _mm_srai_epi32(even, 2);
even = _mm_sub_epi32(lf_, even);

next_even = _mm_extract_epi32(even, 3);
even_m1 = _mm_bslli_si128(even, 4);
even_m1 = _mm_insert_epi32(even_m1, *prev_even, 0);

/* Predict step, as in the scalar code: out[1] = hf[0] + ((out[0] + out[2]) >> 1) */
odd = _mm_add_epi32(even_m1, even);
odd = _mm_srai_epi32(odd, 1);
odd = _mm_add_epi32(odd, hf1_);

unpack1 = _mm_unpacklo_epi32(even_m1, odd);
unpack2 = _mm_unpackhi_epi32(even_m1, odd);

_mm_storeu_si128((__m128i*)(*out_ptr + 0), unpack1);
_mm_storeu_si128((__m128i*)(*out_ptr + 4), unpack2);

*prev_even = next_even;

*out_ptr += 8;
*lf_ptr += 4;
*hf_ptr += 4;
}
return batch;
}
#endif

#if !defined(STANDARD_SLOW_VERSION)
static void opj_idwt53_h_cas0(OPJ_INT32* tmp,
const OPJ_INT32 sn,
@@ -363,6 +411,145 @@ static void opj_idwt53_h_cas0(OPJ_INT32* tmp,
if (!(len & 1)) { /* if len is even */
tmp[len - 1] = in_odd[(len - 1) / 2] + tmp[len - 2];
}
#else
#if defined(__AVX512F__)
OPJ_INT32* out_ptr = tmp;
int32_t prev_even = in_even[0] - ((in_odd[0] + 1) >> 1);

const __m512i permutevar_mask = _mm512_setr_epi32(
0x10, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b,
0x0c, 0x0d, 0x0e);
const __m512i store1_perm = _mm512_setr_epi64(0x00, 0x01, 0x08, 0x09, 0x02,
0x03, 0x0a, 0x0b);
const __m512i store2_perm = _mm512_setr_epi64(0x04, 0x05, 0x0c, 0x0d, 0x06,
0x07, 0x0e, 0x0f);
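/* The masks above: permutevar_mask builds [prev_even, even[0..14]], i.e. the */
/* even vector shifted right by one 32-bit lane with the carried value placed */
/* in lane 0; store1_perm/store2_perm reorder the 128-bit lanes produced by   */
/* unpacklo/unpackhi so the interleaved even/odd pairs land in memory order.  */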

const __m512i two = _mm512_set1_epi32(2);

int32_t simd_batch_512 = (len - 2) / 32;
int32_t leftover;

for (i = 0; i < simd_batch_512; i++) {
const __m512i lf_avx2 = _mm512_loadu_si512((__m512i*)(in_even + 1));
const __m512i hf1_avx2 = _mm512_loadu_si512((__m512i*)(in_odd));
const __m512i hf2_avx2 = _mm512_loadu_si512((__m512i*)(in_odd + 1));
int32_t next_even;
__m512i duplicate, even_m1, odd, unpack1, unpack2, store1, store2;

__m512i even = _mm512_add_epi32(hf1_avx2, hf2_avx2);
even = _mm512_add_epi32(even, two);
even = _mm512_srai_epi32(even, 2);
even = _mm512_sub_epi32(lf_avx2, even);
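/* even now holds the update lifting step, matching the scalar tail below: */
/* even[i] = lf[i + 1] - ((hf[i] + hf[i + 1] + 2) >> 2)                     */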

next_even = _mm_extract_epi32(_mm512_extracti32x4_epi32(even, 3), 3);

duplicate = _mm512_set1_epi32(prev_even);
even_m1 = _mm512_permutex2var_epi32(even, permutevar_mask, duplicate);

/* Predict step, as in the scalar code: out[1] = hf[0] + ((out[0] + out[2]) >> 1) */
odd = _mm512_add_epi32(even_m1, even);
odd = _mm512_srai_epi32(odd, 1);
odd = _mm512_add_epi32(odd, hf1_avx2);

unpack1 = _mm512_unpacklo_epi32(even_m1, odd);
unpack2 = _mm512_unpackhi_epi32(even_m1, odd);

store1 = _mm512_permutex2var_epi64(unpack1, store1_perm, unpack2);
store2 = _mm512_permutex2var_epi64(unpack1, store2_perm, unpack2);

_mm512_storeu_si512(out_ptr, store1);
_mm512_storeu_si512(out_ptr + 16, store2);

prev_even = next_even;

out_ptr += 32;
in_even += 16;
in_odd += 16;
}

leftover = len - simd_batch_512 * 32;
if (leftover > 8) {
leftover -= 8 * loop_short_sse(leftover, &in_even, &in_odd, &out_ptr,
&prev_even);
}
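/* The few remaining samples (including the right border) are finished with */
/* the scalar lifting steps below.                                          */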
out_ptr[0] = prev_even;

for (j = 1; j < (leftover - 2); j += 2) {
out_ptr[2] = in_even[1] - ((in_odd[0] + (in_odd[1]) + 2) >> 2);
out_ptr[1] = in_odd[0] + ((out_ptr[0] + out_ptr[2]) >> 1);
in_even++;
in_odd++;
out_ptr += 2;
}

if (len & 1) {
out_ptr[2] = in_even[1] - ((in_odd[0] + 1) >> 1);
out_ptr[1] = in_odd[0] + ((out_ptr[0] + out_ptr[2]) >> 1);
} else { //!(len & 1)
out_ptr[1] = in_odd[0] + out_ptr[0];
}
#elif defined(__AVX2__)
OPJ_INT32* out_ptr = tmp;
int32_t prev_even = in_even[0] - ((in_odd[0] + 1) >> 1);

const __m256i reg_permutevar_mask_move_right = _mm256_setr_epi32(0x00, 0x00,
0x01, 0x02, 0x03, 0x04, 0x05, 0x06);
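/* Together with the blend below, this mask shifts the even vector right by */
/* one 32-bit lane and inserts prev_even into lane 0.                       */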
const __m256i two = _mm256_set1_epi32(2);

int32_t simd_batch = (len - 2) / 16;
int32_t next_even;
__m256i even_m1, odd, unpack1_avx2, unpack2_avx2;

for (i = 0; i < simd_batch; i++) {
const __m256i lf_avx2 = _mm256_loadu_si256((__m256i*)(in_even + 1));
const __m256i hf1_avx2 = _mm256_loadu_si256((__m256i*)(in_odd));
const __m256i hf2_avx2 = _mm256_loadu_si256((__m256i*)(in_odd + 1));

__m256i even = _mm256_add_epi32(hf1_avx2, hf2_avx2);
even = _mm256_add_epi32(even, two);
even = _mm256_srai_epi32(even, 2);
even = _mm256_sub_epi32(lf_avx2, even);
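/* even now holds lf[i + 1] - ((hf[i] + hf[i + 1] + 2) >> 2), the update */
/* lifting step, identical to out[2] in the scalar loop below.           */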

next_even = _mm_extract_epi32(_mm256_extracti128_si256(even, 1), 3);
even_m1 = _mm256_permutevar8x32_epi32(even, reg_permutevar_mask_move_right);
even_m1 = _mm256_blend_epi32(even_m1, _mm256_set1_epi32(prev_even), (1 << 0));

/* Predict step, as in the scalar code: out[1] = hf[0] + ((out[0] + out[2]) >> 1) */
odd = _mm256_add_epi32(even_m1, even);
odd = _mm256_srai_epi32(odd, 1);
odd = _mm256_add_epi32(odd, hf1_avx2);

unpack1_avx2 = _mm256_unpacklo_epi32(even_m1, odd);
unpack2_avx2 = _mm256_unpackhi_epi32(even_m1, odd);

_mm_storeu_si128((__m128i*)(out_ptr + 0), _mm256_castsi256_si128(unpack1_avx2));
_mm_storeu_si128((__m128i*)(out_ptr + 4), _mm256_castsi256_si128(unpack2_avx2));
_mm_storeu_si128((__m128i*)(out_ptr + 8), _mm256_extracti128_si256(unpack1_avx2,
0x1));
_mm_storeu_si128((__m128i*)(out_ptr + 12),
_mm256_extracti128_si256(unpack2_avx2, 0x1));
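/* The 128-bit halves are stored in low(unpack1)/low(unpack2)/high(unpack1)/ */
/* high(unpack2) order so the even/odd pairs end up interleaved in memory.   */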

prev_even = next_even;

out_ptr += 16;
in_even += 8;
in_odd += 8;
}
out_ptr[0] = prev_even;
for (j = simd_batch * 16 + 1; j < (len - 2); j += 2) {
out_ptr[2] = in_even[1] - ((in_odd[0] + in_odd[1] + 2) >> 2);
out_ptr[1] = in_odd[0] + ((out_ptr[0] + out_ptr[2]) >> 1);
in_even++;
in_odd++;
out_ptr += 2;
}

if (len & 1) {
out_ptr[2] = in_even[1] - ((in_odd[0] + 1) >> 1);
out_ptr[1] = in_odd[0] + ((out_ptr[0] + out_ptr[2]) >> 1);
} else { //!(len & 1)
out_ptr[1] = in_odd[0] + out_ptr[0];
}
#else
OPJ_INT32 d1c, d1n, s1n, s0c, s0n;

@@ -397,7 +584,8 @@ static void opj_idwt53_h_cas0(OPJ_INT32* tmp,
} else {
tmp[len - 1] = d1n + s0n;
}
#endif
#endif /*(__AVX512F__ || __AVX2__)*/
#endif /*TWO_PASS_VERSION*/
memcpy(tiledp, tmp, (OPJ_UINT32)len * sizeof(OPJ_INT32));
}

@@ -511,10 +699,20 @@ static void opj_idwt53_h(const opj_dwt_t *dwt,
#endif
}

#if (defined(__SSE2__) || defined(__AVX2__)) && !defined(STANDARD_SLOW_VERSION)
#if (defined(__SSE2__) || defined(__AVX2__) || defined(__AVX512F__)) && !defined(STANDARD_SLOW_VERSION)

/* Convenience macros to improve the readability of the formulas */
#if __AVX2__
#if defined(__AVX512F__)
#define VREG __m512i
#define LOAD_CST(x) _mm512_set1_epi32(x)
#define LOAD(x) _mm512_loadu_si512((const VREG*)(x))
#define LOADU(x) _mm512_loadu_si512((const VREG*)(x))
#define STORE(x,y) _mm512_storeu_si512((VREG*)(x),(y))
#define STOREU(x,y) _mm512_storeu_si512((VREG*)(x),(y))
#define ADD(x,y) _mm512_add_epi32((x),(y))
#define SUB(x,y) _mm512_sub_epi32((x),(y))
#define SAR(x,y) _mm512_srai_epi32((x),(y))
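/* Note: for AVX512 both the aligned and unaligned macros map to the unaligned */
/* intrinsics; see the alignment note in the column functions below.           */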
#elif defined(__AVX2__)
#define VREG __m256i
#define LOAD_CST(x) _mm256_set1_epi32(x)
#define LOAD(x) _mm256_load_si256((const VREG*)(x))
@@ -576,18 +774,24 @@ static void opj_idwt53_v_cas0_mcols_SSE2_OR_AVX2(
const VREG two = LOAD_CST(2);

assert(len > 1);
#if __AVX2__
#if defined(__AVX512F__)
assert(PARALLEL_COLS_53 == 32);
assert(VREG_INT_COUNT == 16);
#elif defined(__AVX2__)
assert(PARALLEL_COLS_53 == 16);
assert(VREG_INT_COUNT == 8);
#else
assert(PARALLEL_COLS_53 == 8);
assert(VREG_INT_COUNT == 4);
#endif

/* For the AVX512 code, the aligned load/store macros are mapped to their unaligned equivalents. */
#if !defined(__AVX512F__)
/* Note: loads of input even/odd values must be done in a unaligned */
/* fashion. But stores in tmp can be done with aligned store, since */
/* the temporary buffer is properly aligned */
assert((OPJ_SIZE_T)tmp % (sizeof(OPJ_INT32) * VREG_INT_COUNT) == 0);
#endif

s1n_0 = LOADU(in_even + 0);
s1n_1 = LOADU(in_even + VREG_INT_COUNT);
@@ -678,18 +882,24 @@ static void opj_idwt53_v_cas1_mcols_SSE2_OR_AVX2(
const OPJ_INT32* in_odd = &tiledp_col[0];

assert(len > 2);
#if __AVX2__
#if defined(__AVX512F__)
assert(PARALLEL_COLS_53 == 32);
assert(VREG_INT_COUNT == 16);
#elif defined(__AVX2__)
assert(PARALLEL_COLS_53 == 16);
assert(VREG_INT_COUNT == 8);
#else
assert(PARALLEL_COLS_53 == 8);
assert(VREG_INT_COUNT == 4);
#endif

/* For the AVX512 code, the aligned load/store macros are mapped to their unaligned equivalents. */
#if !defined(__AVX512F__)
/* Note: loads of input even/odd values must be done in a unaligned */
/* fashion. But stores in tmp can be done with aligned store, since */
/* the temporary buffer is properly aligned */
assert((OPJ_SIZE_T)tmp % (sizeof(OPJ_INT32) * VREG_INT_COUNT) == 0);
#endif

s1_0 = LOADU(in_even + stride);
/* in_odd[0] - ((in_even[0] + s1 + 2) >> 2); */