From 99f666c1b605494b5c44e6e6bde9e1aa8d7f6f14 Mon Sep 17 00:00:00 2001
From: netrunnereve <139727413+netrunnereve@users.noreply.github.com>
Date: Sat, 15 Jun 2024 21:34:02 -0400
Subject: [PATCH] iq3_s

---
 ggml-quants.c | 55 +++++++++++++++++++++++++++++----------------------
 1 file changed, 31 insertions(+), 24 deletions(-)

diff --git a/ggml-quants.c b/ggml-quants.c
index e71d224d9fc32..f12c2cace47e3 100644
--- a/ggml-quants.c
+++ b/ggml-quants.c
@@ -9452,8 +9452,8 @@ void ggml_vec_dot_iq2_xs_q8_K(int n, float * restrict s, size_t bs, const void *
         const __m128i q2_4_0 = _mm_set_epi64x(iq2xs_grid[gindex[13]], iq2xs_grid[gindex[12]]);
         const __m128i q2_4_1 = _mm_set_epi64x(iq2xs_grid[gindex[15]], iq2xs_grid[gindex[14]]);
 
-        // AVX2 full_signs_1 is full_sign_bits_0 here
-        // AVX2 full_signs_2 is full_sign_bits_1 here
+        // AVX2 full_signs_1 is full_sign_bits_0 here
+        // AVX2 full_signs_2 is full_sign_bits_1 here
         __m128i signs_0, signs_1;
         signs_0 = _mm_shuffle_epi8(full_sign_bits_0, block_sign_shuffle_1_0);
         signs_1 = _mm_shuffle_epi8(full_sign_bits_0, block_sign_shuffle_1_1);
@@ -9492,16 +9492,16 @@ void ggml_vec_dot_iq2_xs_q8_K(int n, float * restrict s, size_t bs, const void *
         const __m128i dot4_0 = _mm_maddubs_epi16(q2_4_0, q8s_4_0);
         const __m128i dot4_1 = _mm_maddubs_epi16(q2_4_1, q8s_4_1);
 
-        __m128i sc_tmp = _mm_shuffle_epi8(scales, get_scale_shuffle(ib32+0));
+        __m128i sc_tmp = _mm_shuffle_epi8(scales, get_scale_shuffle(ib32+0));
         const __m128i sc1_0 = _mm_cvtepi8_epi16(sc_tmp);
         const __m128i sc1_1 = _mm_cvtepi8_epi16(_mm_srli_si128(sc_tmp, 8));
-        sc_tmp = _mm_shuffle_epi8(scales, get_scale_shuffle(ib32+1));
+        sc_tmp = _mm_shuffle_epi8(scales, get_scale_shuffle(ib32+1));
         const __m128i sc2_0 = _mm_cvtepi8_epi16(sc_tmp);
         const __m128i sc2_1 = _mm_cvtepi8_epi16(_mm_srli_si128(sc_tmp, 8));
-        sc_tmp = _mm_shuffle_epi8(scales, get_scale_shuffle(ib32+2));
+        sc_tmp = _mm_shuffle_epi8(scales, get_scale_shuffle(ib32+2));
         const __m128i sc3_0 = _mm_cvtepi8_epi16(sc_tmp);
         const __m128i sc3_1 = _mm_cvtepi8_epi16(_mm_srli_si128(sc_tmp, 8));
-        sc_tmp = _mm_shuffle_epi8(scales, get_scale_shuffle(ib32+3));
+        sc_tmp = _mm_shuffle_epi8(scales, get_scale_shuffle(ib32+3));
         const __m128i sc4_0 = _mm_cvtepi8_epi16(sc_tmp);
         const __m128i sc4_1 = _mm_cvtepi8_epi16(_mm_srli_si128(sc_tmp, 8));
 
@@ -10798,12 +10798,10 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * restrict s, size_t bs, const void *
     const __m128i mask2_0 = _mm_loadu_si128((const __m128i*)k_mask2);
     const __m128i mask2_1 = _mm_loadu_si128((const __m128i*)k_mask2 + 1);
 
-    const __m256i idx_shift = _mm256_set_epi32(1, 2, 3, 4, 5, 6, 7, 8);
-    const __m256i idx_mask = _mm256_set1_epi32(256);
+    const __m128i idx_mask = _mm_set1_epi32(256);
 
     typedef union {
-        __m128i  vec_0[2];
-        __m128i  vec_1[2];
+        __m128i  vec[4];
         uint32_t index[16];
     } index_t;
 
@@ -10825,27 +10823,36 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * restrict s, size_t bs, const void *
         const __m128i q8_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
         const __m128i q8_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
         const __m128i q8_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
-        const __m128i qs_tmp = _mm_loadu_si128((const __m128i *)qs);
+        const __m128i qs_tmp = _mm_loadu_si128((const __m128i *)qs);
         const __m128i idx_l_0 = _mm_cvtepu8_epi16(qs_tmp);
         const __m128i idx_l_1 = _mm_cvtepu8_epi16(_mm_srli_si128(qs_tmp, 8)); qs += 16;
-        idx.vec_0[0] = _mm_set1_epi32(qh[ib32+0]);
-        idx.vec_1[0] = _mm_set1_epi32(qh[ib32+0]);
-        idx.vec_0[1] = _mm_set1_epi32(qh[ib32+1]);
-        idx.vec_1[1] = _mm_set1_epi32(qh[ib32+1]);
+        idx.vec[0] = _mm_set1_epi32(qh[ib32+0]);
+        idx.vec[1] = idx.vec[0];
+        idx.vec[2] = _mm_set1_epi32(qh[ib32+1]);
+        idx.vec[3] = idx.vec[2];
 
-        // TODO this section
-        idx.vec[0] = _mm256_and_si256(_mm256_sllv_epi32(idx.vec[0], idx_shift), idx_mask);
-        idx.vec[1] = _mm256_and_si256(_mm256_sllv_epi32(idx.vec[1], idx_shift), idx_mask);
+        // AVX has no sllv so we have to do this
+        for (int j = 0; j <= 2; j += 2) {
+            for (int k = 0; k < 8; ++k) {
+                int32_t * curn = (int32_t *) &idx.vec[j] + k;
+                *curn = *curn << (8 - k);
+            }
+        }
 
-        idx.vec_0[0] = _mm_or_si128(idx.vec_0[0], _mm_cvtepi16_epi32(idx_l_0));
-        idx.vec_1[0] = _mm_or_si128(idx.vec_1[0], _mm_cvtepi16_epi32(_mm_srli_si128(idx_l_0, 8)));
-        idx.vec_0[1] = _mm_or_si128(idx.vec_0[1], _mm_cvtepi16_epi32(idxl_1));
-        idx.vec_1[1] = _mm_or_si128(idx.vec_1[1], _mm_cvtepi16_epi32(_mm_srli_si128(idx_l_1, 8)));
+        idx.vec[0] = _mm_and_si128(idx.vec[0], idx_mask);
+        idx.vec[1] = _mm_and_si128(idx.vec[1], idx_mask);
+        idx.vec[2] = _mm_and_si128(idx.vec[2], idx_mask);
+        idx.vec[3] = _mm_and_si128(idx.vec[3], idx_mask);
+
+        idx.vec[0] = _mm_or_si128(idx.vec[0], _mm_cvtepi16_epi32(idx_l_0));
+        idx.vec[1] = _mm_or_si128(idx.vec[1], _mm_cvtepi16_epi32(_mm_srli_si128(idx_l_0, 8)));
+        idx.vec[2] = _mm_or_si128(idx.vec[2], _mm_cvtepi16_epi32(idx_l_1));
+        idx.vec[3] = _mm_or_si128(idx.vec[3], _mm_cvtepi16_epi32(_mm_srli_si128(idx_l_1, 8)));
 
-        const __m128i q2_1_1 = _mm_set_epi32(iq3s_grid[idx.index[7]], iq3s_grid[idx.index[6]], iq3s_grid[idx.index[5]], iq3s_grid[idx.index[4]]);
         const __m128i q2_1_0 = _mm_set_epi32(iq3s_grid[idx.index[3]], iq3s_grid[idx.index[2]], iq3s_grid[idx.index[1]], iq3s_grid[idx.index[0]]);
+        const __m128i q2_1_1 = _mm_set_epi32(iq3s_grid[idx.index[7]], iq3s_grid[idx.index[6]], iq3s_grid[idx.index[5]], iq3s_grid[idx.index[4]]);
+        const __m128i q2_2_0 = _mm_set_epi32(iq3s_grid[idx.index[11]], iq3s_grid[idx.index[10]], iq3s_grid[idx.index[9]], iq3s_grid[idx.index[8]]);
         const __m128i q2_2_1 = _mm_set_epi32(iq3s_grid[idx.index[15]], iq3s_grid[idx.index[14]], iq3s_grid[idx.index[13]], iq3s_grid[idx.index[12]]);
-        const __m128i q2_2_0 = _mm_set_epi32(iq3s_grid[idx.index[11]], iq3s_grid[idx.index[10]], iq3s_grid[idx.index[ 9]], iq3s_grid[idx.index[ 8]]);
 
         __m128i aux128_0 = _mm_set1_epi32(signs[0] | (signs[1] << 16));
         __m128i aux128_1 = aux128_0;
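
The key change in this patch is the variable-shift emulation: plain AVX (v1) has no _mm256_sllv_epi32, so the broadcast qh values are laid out across adjacent __m128i entries of the index_t union and shifted one 32-bit lane at a time through a scalar pointer, with j stepping by 2 so each pass covers a 256-bit pair. The following is a minimal standalone sketch of that technique; the names lanes_t, sllv_emulated, and the main driver are illustrative only and not part of ggml-quants.c.

/*
 * Sketch of the sllv emulation used above: eight 32-bit lanes live in two
 * adjacent __m128i halves of a union and are shifted per lane in scalar code.
 */
#include <stdint.h>
#include <stdio.h>
#include <immintrin.h>

typedef union {
    __m128i vec[2];   // two adjacent 128-bit halves = 8 x int32
    int32_t lane[8];
} lanes_t;

// Emulates _mm256_sllv_epi32(v, _mm256_set_epi32(1, 2, 3, 4, 5, 6, 7, 8)):
// lane 0 is shifted left by 8, lane 1 by 7, ..., lane 7 by 1.
static void sllv_emulated(lanes_t * v) {
    for (int k = 0; k < 8; ++k) {
        v->lane[k] <<= (8 - k);
    }
}

int main(void) {
    lanes_t v;
    v.vec[0] = _mm_set1_epi32(1);
    v.vec[1] = _mm_set1_epi32(1);
    sllv_emulated(&v);
    for (int k = 0; k < 8; ++k) {
        printf("lane %d: %d\n", k, v.lane[k]); // prints 256, 128, ..., 2
    }
    return 0;
}

Compiled on any SSE2-capable target, this prints 256 down to 2, matching the per-lane shift counts of the AVX2 sllv call the patch replaces.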