Skip to content

Commit

Permalink
iq3_s
Browse files Browse the repository at this point in the history
  • Loading branch information
netrunnereve committed Jun 16, 2024
1 parent 39e816e commit 99f666c
Showing 1 changed file with 31 additions and 24 deletions.
55 changes: 31 additions & 24 deletions ggml-quants.c
Original file line number Diff line number Diff line change
Expand Up @@ -9452,8 +9452,8 @@ void ggml_vec_dot_iq2_xs_q8_K(int n, float * restrict s, size_t bs, const void *
const __m128i q2_4_0 = _mm_set_epi64x(iq2xs_grid[gindex[13]], iq2xs_grid[gindex[12]]);
const __m128i q2_4_1 = _mm_set_epi64x(iq2xs_grid[gindex[15]], iq2xs_grid[gindex[14]]);

// AVX2 full_signs_1 is full_sign_bits_0 here
// AVX2 full_signs_2 is full_sign_bits_1 here
// AVX2 full_signs_1 is full_sign_bits_0 here
// AVX2 full_signs_2 is full_sign_bits_1 here
__m128i signs_0, signs_1;
signs_0 = _mm_shuffle_epi8(full_sign_bits_0, block_sign_shuffle_1_0);
signs_1 = _mm_shuffle_epi8(full_sign_bits_0, block_sign_shuffle_1_1);
Expand Down Expand Up @@ -9492,16 +9492,16 @@ void ggml_vec_dot_iq2_xs_q8_K(int n, float * restrict s, size_t bs, const void *
const __m128i dot4_0 = _mm_maddubs_epi16(q2_4_0, q8s_4_0);
const __m128i dot4_1 = _mm_maddubs_epi16(q2_4_1, q8s_4_1);

__m128i sc_tmp = _mm_shuffle_epi8(scales, get_scale_shuffle(ib32+0));
__m128i sc_tmp = _mm_shuffle_epi8(scales, get_scale_shuffle(ib32+0));
const __m128i sc1_0 = _mm_cvtepi8_epi16(sc_tmp);
const __m128i sc1_1 = _mm_cvtepi8_epi16(_mm_srli_si128(sc_tmp, 8));
sc_tmp = _mm_shuffle_epi8(scales, get_scale_shuffle(ib32+1));
sc_tmp = _mm_shuffle_epi8(scales, get_scale_shuffle(ib32+1));
const __m128i sc2_0 = _mm_cvtepi8_epi16(sc_tmp);
const __m128i sc2_1 = _mm_cvtepi8_epi16(_mm_srli_si128(sc_tmp, 8));
sc_tmp = _mm_shuffle_epi8(scales, get_scale_shuffle(ib32+2));
sc_tmp = _mm_shuffle_epi8(scales, get_scale_shuffle(ib32+2));
const __m128i sc3_0 = _mm_cvtepi8_epi16(sc_tmp);
const __m128i sc3_1 = _mm_cvtepi8_epi16(_mm_srli_si128(sc_tmp, 8));
sc_tmp = _mm_shuffle_epi8(scales, get_scale_shuffle(ib32+3));
sc_tmp = _mm_shuffle_epi8(scales, get_scale_shuffle(ib32+3));
const __m128i sc4_0 = _mm_cvtepi8_epi16(sc_tmp);
const __m128i sc4_1 = _mm_cvtepi8_epi16(_mm_srli_si128(sc_tmp, 8));

Expand Down Expand Up @@ -10798,12 +10798,10 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * restrict s, size_t bs, const void *
const __m128i mask2_0 = _mm_loadu_si128((const __m128i*)k_mask2);
const __m128i mask2_1 = _mm_loadu_si128((const __m128i*)k_mask2 + 1);

const __m256i idx_shift = _mm256_set_epi32(1, 2, 3, 4, 5, 6, 7, 8);
const __m256i idx_mask = _mm256_set1_epi32(256);
const __m128i idx_mask = _mm_set1_epi32(256);

typedef union {
__m128i vec_0[2];
__m128i vec_1[2];
__m128i vec[4];
uint32_t index[16];
} index_t;

Expand All @@ -10825,27 +10823,36 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * restrict s, size_t bs, const void *
const __m128i q8_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
const __m128i q8_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
const __m128i q8_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
const __m128i qs_tmp = _mm_loadu_si128((const __m128i *)qs);
const __m128i qs_tmp = _mm_loadu_si128((const __m128i *)qs);
const __m128i idx_l_0 = _mm_cvtepu8_epi16(qs_tmp);
const __m128i idx_l_1 = _mm_cvtepu8_epi16(_mm_srli_si128(qs_tmp, 8)); qs += 16;
idx.vec_0[0] = _mm_set1_epi32(qh[ib32+0]);
idx.vec_1[0] = _mm_set1_epi32(qh[ib32+0]);
idx.vec_0[1] = _mm_set1_epi32(qh[ib32+1]);
idx.vec_1[1] = _mm_set1_epi32(qh[ib32+1]);
idx.vec[0] = _mm_set1_epi32(qh[ib32+0]);
idx.vec[1] = idx.vec[0];
idx.vec[2] = _mm_set1_epi32(qh[ib32+1]);
idx.vec[3] = idx.vec[2];

// TODO this section
idx.vec[0] = _mm256_and_si256(_mm256_sllv_epi32(idx.vec[0], idx_shift), idx_mask);
idx.vec[1] = _mm256_and_si256(_mm256_sllv_epi32(idx.vec[1], idx_shift), idx_mask);
// AVX has no sllv so we have to do this
for (int j = 0; j <= 2; j += 2) {
for (int k = 0; k < 8; ++k) {
int32_t * curn = (int32_t *) &idx.vec[j] + k;
*curn = *curn << (8 - k);
}
}

idx.vec_0[0] = _mm_or_si128(idx.vec_0[0], _mm_cvtepi16_epi32(idx_l_0));
idx.vec_1[0] = _mm_or_si128(idx.vec_1[0], _mm_cvtepi16_epi32(_mm_srli_si128(idx_l_0, 8)));
idx.vec_0[1] = _mm_or_si128(idx.vec_0[1], _mm_cvtepi16_epi32(idxl_1));
idx.vec_1[1] = _mm_or_si128(idx.vec_1[1], _mm_cvtepi16_epi32(_mm_srli_si128(idx_l_1, 8)));
idx.vec[0] = _mm_and_si128(idx.vec[0], idx_mask);
idx.vec[1] = _mm_and_si128(idx.vec[1], idx_mask);
idx.vec[2] = _mm_and_si128(idx.vec[2], idx_mask);
idx.vec[3] = _mm_and_si128(idx.vec[3], idx_mask);

idx.vec[0] = _mm_or_si128(idx.vec[0], _mm_cvtepi16_epi32(idx_l_0));
idx.vec[1] = _mm_or_si128(idx.vec[1], _mm_cvtepi16_epi32(_mm_srli_si128(idx_l_0, 8)));
idx.vec[2] = _mm_or_si128(idx.vec[2], _mm_cvtepi16_epi32(idx_l_1));
idx.vec[3] = _mm_or_si128(idx.vec[3], _mm_cvtepi16_epi32(_mm_srli_si128(idx_l_1, 8)));

const __m128i q2_1_1 = _mm_set_epi32(iq3s_grid[idx.index[7]], iq3s_grid[idx.index[6]], iq3s_grid[idx.index[5]], iq3s_grid[idx.index[4]]);
const __m128i q2_1_0 = _mm_set_epi32(iq3s_grid[idx.index[3]], iq3s_grid[idx.index[2]], iq3s_grid[idx.index[1]], iq3s_grid[idx.index[0]]);
const __m128i q2_1_1 = _mm_set_epi32(iq3s_grid[idx.index[7]], iq3s_grid[idx.index[6]], iq3s_grid[idx.index[5]], iq3s_grid[idx.index[4]]);
const __m128i q2_2_0 = _mm_set_epi32(iq3s_grid[idx.index[11]], iq3s_grid[idx.index[10]], iq3s_grid[idx.index[9]], iq3s_grid[idx.index[8]]);
const __m128i q2_2_1 = _mm_set_epi32(iq3s_grid[idx.index[15]], iq3s_grid[idx.index[14]], iq3s_grid[idx.index[13]], iq3s_grid[idx.index[12]]);
const __m128i q2_2_0 = _mm_set_epi32(iq3s_grid[idx.index[11]], iq3s_grid[idx.index[10]], iq3s_grid[idx.index[ 9]], iq3s_grid[idx.index[ 8]]);

__m128i aux128_0 = _mm_set1_epi32(signs[0] | (signs[1] << 16));
__m128i aux128_1 = aux128_0;
Expand Down

0 comments on commit 99f666c

Please sign in to comment.