From 685f263a75775314d0fbe6a4023bf5cda6aa8b87 Mon Sep 17 00:00:00 2001
From: Zhang Yi3
Date: Tue, 10 Dec 2024 00:18:42 -0800
Subject: [PATCH] [CPU]fix code style

Signed-off-by: Zhang Yi3
---
 src/plugins/intel_cpu/src/config.cpp          |   2 +-
 .../nodes/kernels/scaled_attn/attn_quant.cpp  | 102 +++++++---
 .../kernels/scaled_attn/attn_quant_kernel.hpp |   5 +-
 .../nodes/kernels/scaled_attn/executor_pa.cpp | 182 +++++++++++++-----
 .../nodes/kernels/scaled_attn/executor_pa.hpp |   2 +-
 .../intel_cpu/src/nodes/paged_attn.cpp        |   3 +-
 6 files changed, 208 insertions(+), 88 deletions(-)

diff --git a/src/plugins/intel_cpu/src/config.cpp b/src/plugins/intel_cpu/src/config.cpp
index a25401f12566fc..257dee95546e34 100644
--- a/src/plugins/intel_cpu/src/config.cpp
+++ b/src/plugins/intel_cpu/src/config.cpp
@@ -389,7 +389,7 @@ void Config::readProperties(const ov::AnyMap& prop, const ModelType modelType) {
                                 " for property key ",
                                 key,
                                 ". Expected only unsinged integer numbers");
-            }
+            }
         } else if (key == ov::cache_encryption_callbacks.name()) {
             try {
                 auto encryption_callbacks = val.as();
diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant.cpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant.cpp
index 80092fdfc7e01f..ee13118ec80d11 100644
--- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant.cpp
+++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant.cpp
@@ -131,7 +131,6 @@ static void find_minmax(const T* src, size_t n, float& min, float& max) {
         max = std::max(max, tmp);
         min = std::min(min, tmp);
     }
-
 }
 
 template 
@@ -398,7 +397,10 @@ static void attn_quant_mt(const ov::intel_cpu::PlainTensor& k_src,
     });
 }
 
-template ::type = true>
+template ::type = true>
 static void paged_attn_quant_mt(const ov::intel_cpu::PlainTensor& k_src,
                                 const ov::intel_cpu::PlainTensor& v_src,
                                 const ov::intel_cpu::PlainTensor& k_dst,
@@ -417,27 +419,48 @@ static void paged_attn_quant_mt(const ov::intel_cpu::PlainTensor& k_src,
         auto block_offset = slot % block_size;
         // The layout for per token per head:
         // |scale(f32)|zeropoint(f32)|quantized feature(u8,idx_1)|quantized feature(u8,idx_2)|...|quantized feature(u8,idx_S)|
-        for (size_t src_offset = 0, dst_offset = 0; src_offset < S; src_offset += _key_group_size, dst_offset += _key_group_size + sizeof(float) + sizeof(float)) {
-            auto p_k = reinterpret_cast(k_dst.ptr::value_type>(block_number, h, block_offset, dst_offset));
+        for (size_t src_offset = 0, dst_offset = 0; src_offset < S;
+             src_offset += _key_group_size, dst_offset += _key_group_size + sizeof(float) + sizeof(float)) {
+            auto p_k = reinterpret_cast(
+                k_dst.ptr::value_type>(block_number,
+                                       h,
+                                       block_offset,
+                                       dst_offset));
             quant_u8(k_src.ptr(b, h, m, src_offset),
-                     k_dst.ptr::value_type>(block_number, h, block_offset, dst_offset) + sizeof(float) + sizeof(float),
-                     _key_group_size,
-                     p_k[0],
-                     p_k[1]);
+                     k_dst.ptr::value_type>(block_number,
+                                            h,
+                                            block_offset,
+                                            dst_offset) +
+                         sizeof(float) + sizeof(float),
+                     _key_group_size,
+                     p_k[0],
+                     p_k[1]);
         }
-
-        for (size_t src_offset = 0, dst_offset = 0; src_offset < SV; src_offset += _value_group_size, dst_offset += _value_group_size + sizeof(float) + sizeof(float)) {
-            auto p_v = reinterpret_cast(v_dst.ptr::value_type>(block_number, h, block_offset, dst_offset));
+
+        for (size_t src_offset = 0, dst_offset = 0; src_offset < SV;
+             src_offset += _value_group_size, dst_offset += _value_group_size + sizeof(float) + sizeof(float)) {
+            auto p_v = reinterpret_cast(
+                v_dst.ptr::value_type>(block_number,
+                                       h,
+                                       block_offset,
+                                       dst_offset));
             quant_u8(v_src.ptr(b, h, m, src_offset),
-                     v_dst.ptr::value_type>(block_number, h, block_offset, dst_offset) + sizeof(float) + sizeof(float),
-                     _value_group_size,
-                     p_v[0],
-                     p_v[1]);
+                     v_dst.ptr::value_type>(block_number,
+                                            h,
+                                            block_offset,
+                                            dst_offset) +
+                         sizeof(float) + sizeof(float),
+                     _value_group_size,
+                     p_v[0],
+                     p_v[1]);
         }
     });
 }
 
-template ::type = true>
+template ::type = true>
 static void paged_attn_quant_mt(const ov::intel_cpu::PlainTensor& k_src,
                                 const ov::intel_cpu::PlainTensor& v_src,
                                 const ov::intel_cpu::PlainTensor& k_dst,
@@ -457,13 +480,22 @@ static void paged_attn_quant_mt(const ov::intel_cpu::PlainTensor& k_src,
         auto block_offset = slot % block_size;
         // The layout for per token per head:
         // |scale(f32)|zeropoint(f32)|quantized feature(u8,idx_1)|quantized feature(u8,idx_2)|...|quantized feature(u8,idx_S)|
-        for (size_t src_offset = 0, dst_offset = 0; src_offset < S; src_offset += _key_group_size, dst_offset += _key_group_size + sizeof(float) + sizeof(float)) {
-            auto p_k = reinterpret_cast(k_dst.ptr::value_type>(block_number, h, block_offset, dst_offset));
+        for (size_t src_offset = 0, dst_offset = 0; src_offset < S;
+             src_offset += _key_group_size, dst_offset += _key_group_size + sizeof(float) + sizeof(float)) {
+            auto p_k = reinterpret_cast(
+                k_dst.ptr::value_type>(block_number,
+                                       h,
+                                       block_offset,
+                                       dst_offset));
             quant_u8(k_src.ptr(b, h, m, src_offset),
-                     k_dst.ptr::value_type>(block_number, h, block_offset, dst_offset) + sizeof(float) + sizeof(float),
-                     _key_group_size,
-                     p_k[0],
-                     p_k[1]);
+                     k_dst.ptr::value_type>(block_number,
+                                            h,
+                                            block_offset,
+                                            dst_offset) +
+                         sizeof(float) + sizeof(float),
+                     _key_group_size,
+                     p_k[0],
+                     p_k[1]);
         }
 
         for (size_t src_offset = 0, dst_offset = 0; src_offset < SV; src_offset += _value_group_size,
@@ -480,7 +512,10 @@ static void paged_attn_quant_mt(const ov::intel_cpu::PlainTensor& k_src,
     });
 }
 
-template ::type = true>
+template ::type = true>
 static void paged_attn_quant_mt(const ov::intel_cpu::PlainTensor& k_src,
                                 const ov::intel_cpu::PlainTensor& v_src,
                                 const ov::intel_cpu::PlainTensor& k_dst,
@@ -500,13 +535,22 @@ static void paged_attn_quant_mt(const ov::intel_cpu::PlainTensor& k_src,
         auto block_offset = slot % block_size;
         // The layout for per token per head:
         // |scale(f32)|zeropoint(f32)|quantized feature(u8,idx_1)|quantized feature(u8,idx_2)|...|quantized feature(u8,idx_S)|
-        for (size_t src_offset = 0, dst_offset = 0; src_offset < S; src_offset += _key_group_size, dst_offset += _key_group_size + sizeof(float) + sizeof(float)) {
-            auto p_k = reinterpret_cast(k_dst.ptr::value_type>(block_number, h, block_offset, dst_offset));
+        for (size_t src_offset = 0, dst_offset = 0; src_offset < S;
+             src_offset += _key_group_size, dst_offset += _key_group_size + sizeof(float) + sizeof(float)) {
+            auto p_k = reinterpret_cast(
+                k_dst.ptr::value_type>(block_number,
+                                       h,
+                                       block_offset,
+                                       dst_offset));
             quant_u8(k_src.ptr(b, h, m, src_offset),
-                     k_dst.ptr::value_type>(block_number, h, block_offset, dst_offset) + sizeof(float) + sizeof(float),
-                     _key_group_size,
-                     p_k[0],
-                     p_k[1]);
+                     k_dst.ptr::value_type>(block_number,
+                                            h,
+                                            block_offset,
+                                            dst_offset) +
+                         sizeof(float) + sizeof(float),
+                     _key_group_size,
+                     p_k[0],
+                     p_k[1]);
         }
 
         for (size_t src_offset = 0, dst_offset = 0; src_offset < SV; src_offset += _value_group_size,
diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant_kernel.hpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant_kernel.hpp
index 120e14ab7da5df..434286766d5188 100644
--- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant_kernel.hpp
+++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant_kernel.hpp
@@ -82,7 +82,6 @@ void attn_dequant_u4_kernel(const uint8_t* src, TDST* dst, size_t n, float scale
         // (q - zp) * scale
         v_f32_low_half = _mm512_mul_ps(v_f32_low_half, v_scale);
         v_f32_high_half = _mm512_mul_ps(v_f32_high_half, v_scale);
-
         __m512i idx1 = _mm512_set_epi32(23, 7, 22, 6, 21, 5, 20, 4, 19, 3, 18, 2, 17, 1, 16, 0);
         __m512i idx2 = _mm512_set_epi32(31, 15, 30, 14, 29, 13, 28, 12, 27, 11, 26, 10, 25, 9, 24, 8);
         __m512 first_half = _mm512_permutex2var_ps(v_f32_low_half, idx1, v_f32_high_half);
@@ -106,7 +105,7 @@ void attn_dequant_u4_kernel(const uint8_t* src, TDST* dst, size_t n, float scale
         // q - zp
         v_f32_low_half = _mm256_sub_ps(v_f32_low_half, v256_zp);
         v_f32_high_half = _mm256_sub_ps(v_f32_high_half, v256_zp);
-
+
         v_f32_low_half = _mm256_mul_ps(v_f32_low_half, v256_scale);
         v_f32_high_half = _mm256_mul_ps(v_f32_high_half, v256_scale);
 
@@ -206,7 +205,7 @@ void attn_dequant_s4_kernel(const uint8_t* src, TDST* dst, size_t n, float scale
 #endif
     auto extract_half_byte = [&](uint8_t val, bool high_half) -> int8_t {
         uint8_t shift = high_half ? 0 : 4;
-        return float((val >> shift) & 0x000F);
+        return static_cast((val >> shift) & 0x000F);
     };
     for (; i < n; ++i) {
         float tmp = extract_half_byte(src_nc[i / 2], (uint8_t)(i % 2));
diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp
index f545cba7dd5097..c7d12a0818749e 100644
--- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp
+++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp
@@ -234,15 +234,23 @@ static void attn_acc_value_block(float* out, float* weight, uint8_t* v, size_t S
             size_t i = 0;
             for (; i + vec_len_f32_avx512 <= _group_size; i += vec_len_f32_avx512) {
                 auto v_out = mm512_uni_loadu_ps(out + dst_offset + i);
-                auto v0 = _mm512_sub_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(_mm_loadu_si128(reinterpret_cast<__m128i*>(v_data_ptr + i)))), zp0);
-                auto v1 = _mm512_sub_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(_mm_loadu_si128(reinterpret_cast<__m128i*>(v_data_ptr + i + src_stride)))), zp1);
-                auto v2 = _mm512_sub_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(_mm_loadu_si128(reinterpret_cast<__m128i*>(v_data_ptr + i + 2 * src_stride)))), zp2);
-                auto v3 = _mm512_sub_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(_mm_loadu_si128(reinterpret_cast<__m128i*>(v_data_ptr + i + 3 * src_stride)))), zp3);
+                auto v0 = _mm512_sub_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(
+                                            _mm_loadu_si128(reinterpret_cast<__m128i*>(v_data_ptr + i)))),
+                                        zp0);
+                auto v1 = _mm512_sub_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(
+                                            _mm_loadu_si128(reinterpret_cast<__m128i*>(v_data_ptr + i + src_stride)))),
+                                        zp1);
+                auto v2 = _mm512_sub_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(_mm_loadu_si128(
+                                            reinterpret_cast<__m128i*>(v_data_ptr + i + 2 * src_stride)))),
+                                        zp2);
+                auto v3 = _mm512_sub_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(_mm_loadu_si128(
+                                            reinterpret_cast<__m128i*>(v_data_ptr + i + 3 * src_stride)))),
+                                        zp3);
                 v_out = _mm512_fmadd_ps(attn_w_vec0, v0, v_out);
                 v_out = _mm512_fmadd_ps(attn_w_vec1, v1, v_out);
                 v_out = _mm512_fmadd_ps(attn_w_vec2, v2, v_out);
                 v_out = _mm512_fmadd_ps(attn_w_vec3, v3, v_out);
-                _mm512_storeu_ps(out + dst_offset + i, v_out);
+                _mm512_storeu_ps(out + dst_offset + i, v_out);
             }
             for (; i < _group_size; i++) {
                 out[i] += weight[0] * (v_data_ptr[i] - v_f0[1]) * v_f0[0];
@@ -275,7 +283,7 @@ static void attn_acc_value_block(float* out, float* weight, uint8_t* v, size_t S
             }
             for (; i < _group_size; i++) {
                 out[dst_offset + i] += weight[0] * (v_data_ptr[i] - v_f0[1]) * v_f0[0];
-            }
+            }
             dst_offset += _group_size;
             src_offset += _group_size + params_offset;
         }
@@ -305,7 +313,7 @@ static void attn_acc_value_block(float* out, float* weight, uint8_t* v, size_t S
             }
             for (; i < _group_size; i++) {
                 out[dst_offset + i] += weight[0] * (v_data_ptr[i] - v_f0[1]) * v_f0[0];
-            }
+            }
             dst_offset += _group_size;
             src_offset += _group_size + params_offset;
         }
@@ -411,8 +419,8 @@ static void attn_acc_value_block(float* out, float* weight, uint8_t* v, size_t S
 #endif
             for (; i < _group_size; i += 2) {
                 uint8_t data = v[i/2 + src_offset + params_offset];
-                float tmp0 = extract_half_byte(data, (bool)(i % 2));
-                float tmp1 = extract_half_byte(data, (bool)((i + 1) % 2));
+                float tmp0 = extract_half_byte(data, static_cast(i % 2));
+                float tmp1 = extract_half_byte(data, static_cast((i + 1) % 2));
                 out[dst_offset + i] += weight[j] * (tmp0 - v0[1]) * v0[0];
                 out[dst_offset + i + 1] += weight[j] * (tmp1 - v0[1]) * v0[0];
             }
@@ -495,9 +503,9 @@ static void attn_acc_value_block(float* out, float* weight, uint8_t* v, size_t S
 #endif
             for (; i < _group_size; i += 2) {
                 uint8_t data = v[i/2 + src_offset + params_offset];
-                float tmp0 = extract_half_byte(data, (bool)(i % 2));
+                float tmp0 = extract_half_byte(data, static_cast(i % 2));
                 tmp0 = tmp0 > 8 ? (tmp0 - 16) : tmp0;
-                float tmp1 = extract_half_byte(data, (bool)((i + 1) % 2));
+                float tmp1 = extract_half_byte(data, static_cast((i + 1) % 2));
                 tmp1 = tmp1 > 8 ? (tmp1 - 16) : tmp1;
                 out[dst_offset + i] += weight[j] * (tmp0) * v0[0];
                 out[dst_offset + i + 1] += weight[j] * (tmp1) * v0[0];
@@ -640,7 +648,7 @@ static void dot_product_block(TA* a, uint8_t* b, float* c, size_t n, size_t bloc
     float sum1 = 0.0f;
     float sum2 = 0.0f;
     float sum3 = 0.0f;
-    while (dst_offset < n) {
+    while (dst_offset < n) {
         auto vsum0 = _mm512_setzero_ps();
         auto vsum1 = _mm512_setzero_ps();
         auto vsum2 = _mm512_setzero_ps();
@@ -657,10 +665,18 @@ static void dot_product_block(TA* a, uint8_t* b, float* c, size_t n, size_t bloc
         uint8_t* b_data_ptr = b + src_offset + params_offset;
         for (; i + vec_len_f32_avx512 <= _group_size; i += vec_len_f32_avx512) {
             auto va = mm512_uni_loadu_ps(a + dst_offset + i);
-            auto vb0 = _mm512_sub_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(_mm_loadu_si128(reinterpret_cast<__m128i*>(b_data_ptr + i)))), v_zp0);
-            auto vb1 = _mm512_sub_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(_mm_loadu_si128(reinterpret_cast<__m128i*>(b_data_ptr + i + src_stride)))), v_zp1);
-            auto vb2 = _mm512_sub_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(_mm_loadu_si128(reinterpret_cast<__m128i*>(b_data_ptr + i + 2 * src_stride)))), v_zp2);
-            auto vb3 = _mm512_sub_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(_mm_loadu_si128(reinterpret_cast<__m128i*>(b_data_ptr + i + 3 * src_stride)))), v_zp3);
+            auto vb0 = _mm512_sub_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(
+                                         _mm_loadu_si128(reinterpret_cast<__m128i*>(b_data_ptr + i)))),
+                                     v_zp0);
+            auto vb1 = _mm512_sub_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(
+                                         _mm_loadu_si128(reinterpret_cast<__m128i*>(b_data_ptr + i + src_stride)))),
+                                     v_zp1);
+            auto vb2 = _mm512_sub_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(_mm_loadu_si128(
+                                         reinterpret_cast<__m128i*>(b_data_ptr + i + 2 * src_stride)))),
+                                     v_zp2);
+            auto vb3 = _mm512_sub_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(_mm_loadu_si128(
+                                         reinterpret_cast<__m128i*>(b_data_ptr + i + 3 * src_stride)))),
+                                     v_zp3);
 
             vsum0 = _mm512_fmadd_ps(va, vb0, vsum0);
             vsum1 = _mm512_fmadd_ps(va, vb1, vsum1);
@@ -744,10 +760,18 @@ static void dot_product_block(TA* a, uint8_t* b, float* c, size_t n, size_t bloc
         uint8_t* b_data_ptr = b + src_offset + params_offset;
         for (; i + vec_len_f32_avx2 <= _group_size; i += vec_len_f32_avx2) {
             auto va = mm256_uni_loadu_ps(a + dst_offset + i);
-            auto vb0 = _mm256_sub_ps(_mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64(reinterpret_cast<__m128i*>(b_data_ptr + i)))), v_zp0);
-            auto vb1 = _mm256_sub_ps(_mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64(reinterpret_cast<__m128i*>(b_data_ptr + i + src_stride)))), v_zp1);
-            auto vb2 = _mm256_sub_ps(_mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64(reinterpret_cast<__m128i*>(b_data_ptr + i + 2 * src_stride)))), v_zp2);
-            auto vb3 = _mm256_sub_ps(_mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64(reinterpret_cast<__m128i*>(b_data_ptr + i + 3 * src_stride)))), v_zp3);
+            auto vb0 = _mm256_sub_ps(_mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
+                                         _mm_loadl_epi64(reinterpret_cast<__m128i*>(b_data_ptr + i)))),
+                                     v_zp0);
+            auto vb1 = _mm256_sub_ps(_mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
+                                         _mm_loadl_epi64(reinterpret_cast<__m128i*>(b_data_ptr + i + src_stride)))),
+                                     v_zp1);
+            auto vb2 = _mm256_sub_ps(_mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64(
+                                         reinterpret_cast<__m128i*>(b_data_ptr + i + 2 * src_stride)))),
+                                     v_zp2);
+            auto vb3 = _mm256_sub_ps(_mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64(
+                                         reinterpret_cast<__m128i*>(b_data_ptr + i + 3 * src_stride)))),
+                                     v_zp3);
 
             vsum0 = _mm256_fmadd_ps(va, vb0, vsum0);
             vsum1 = _mm256_fmadd_ps(va, vb1, vsum1);
@@ -775,7 +799,6 @@ static void dot_product_block(TA* a, uint8_t* b, float* c, size_t n, size_t bloc
         dst_offset += _group_size;
         src_offset += _group_size + params_offset;
     }
-
     c[0] = sum0;
     c[1] = sum1;
     c[2] = sum2;
@@ -890,14 +913,31 @@ void transpose_16NxK(TDST* dst, void* src, TDST* tmp, size_t N, size_t K, size_t
     }
 }
 #if defined(HAVE_AVX512F)
-template::value), bool>::type = true>
-static void transpose_16NxK(T* dst, T* src, T* tmp, size_t N, size_t K, size_t dst_stride, size_t src_stride, size_t group_size = 0) {
+template ::value),
+                                  bool>::type = true>
+static void transpose_16NxK(T* dst,
+                            T* src,
+                            T* tmp,
+                            size_t N,
+                            size_t K,
+                            size_t dst_stride,
+                            size_t src_stride,
+                            size_t group_size = 0) {
     // will treat as uint32_t transpose
     auto s = reinterpret_cast(src);
     auto d = reinterpret_cast(dst);
-    transpose_16NxK(d, s, reinterpret_cast(0), N, K >> 1, dst_stride, src_stride >> 1);
+    transpose_16NxK(d,
+                    s,
+                    reinterpret_cast(0),
+                    N,
+                    K >> 1,
+                    dst_stride,
+                    src_stride >> 1);
 }
-#endif
+# endif
 
 template::type = true>
 void transpose_16NxK(TDST* dst, void* src, TDST* tmp, size_t N, size_t K, size_t dst_stride, size_t src_stride, size_t group_size = 0) {
@@ -1062,8 +1102,19 @@ static void pack_32xK_kernel(T* dst, T* src, size_t dst_stride, size_t src_strid
     }
 }
 
-template::value != ov::element::f32 && (SRC_PREC == ov::element::bf16 || SRC_PREC == ov::element::f16), bool>::type = true>
-static void pack_32NxK(TDST* dst, void* src, TDST* tmp, size_t N, size_t K, size_t dst_stride, size_t src_stride, size_t group_size = 0) {
+template ::value != ov::element::f32 &&
+                                      (SRC_PREC == ov::element::bf16 || SRC_PREC == ov::element::f16),
+                                  bool>::type = true>
+static void pack_32NxK(TDST* dst,
+                       void* src,
+                       TDST* tmp,
+                       size_t N,
+                       size_t K,
+                       size_t dst_stride,
+                       size_t src_stride,
+                       size_t group_size = 0) {
     auto src_ptr = reinterpret_cast::value_type*>(src);
     for (size_t n = 0; n < N; n += 32) {
         size_t k = 0;
@@ -1083,16 +1134,26 @@ static void pack_32NxK(TDST* dst, void* src, TDST* tmp, size_t N, size_t K, size
     }
 }
 
-template::value != ov::element::f32 && SRC_PREC == ov::element::u8, bool>::type = true>
-static void pack_32NxK(TDST* dst, void* src, TDST* tmp, size_t N, size_t K, size_t dst_stride, size_t src_stride, size_t group_size = 0) {
+template ::value != ov::element::f32 && SRC_PREC == ov::element::u8,
+                                  bool>::type = true>
+static void pack_32NxK(TDST* dst,
+                       void* src,
+                       TDST* tmp,
+                       size_t N,
+                       size_t K,
+                       size_t dst_stride,
+                       size_t src_stride,
+                       size_t group_size = 0) {
     // The layout for per token per head:
-    // |scale(f32)|zeropoint(f32)|quantized feature(u8,idx_1)|quantized feature(u8,idx_2)|...|quantized feature(u8,idx_S)|
-    // The quantized feature will start from 8bytes=sizeof(float)+sizeof(float)
+    // |scale(f32)|zeropoint(f32)|quantized feature(u8,idx_1)|quantized feature(u8,idx_2)|...|quantized
+    // feature(u8,idx_S)| The quantized feature will start from 8bytes=sizeof(float)+sizeof(float)
     auto s = reinterpret_cast::value_type*>(src);
     auto t = tmp;
     // if group_size not set, the whole row is used as a group
     size_t _group_size = group_size ? group_size : K;
-    for (size_t n = 0; n < N; n ++) {
+    for (size_t n = 0; n < N; n++) {
         size_t src_offset = 0;
         size_t dst_offset = 0;
         while (dst_offset < K) {
@@ -1107,17 +1168,27 @@ static void pack_32NxK(TDST* dst, void* src, TDST* tmp, size_t N, size_t K, size
     pack_32NxK::value>(dst, tmp, reinterpret_cast(0), N, K, dst_stride, src_stride);
 }
 
-template::value != ov::element::f32 && (SRC_PREC == ov::element::u4), bool>::type = true>
-static void pack_32NxK(TDST* dst, void* src, TDST* tmp, size_t N, size_t K, size_t dst_stride, size_t src_stride, size_t group_size = 0) {
+template ::value != ov::element::f32 && (SRC_PREC == ov::element::u4),
+                                  bool>::type = true>
+static void pack_32NxK(TDST* dst,
+                       void* src,
+                       TDST* tmp,
+                       size_t N,
+                       size_t K,
+                       size_t dst_stride,
+                       size_t src_stride,
+                       size_t group_size = 0) {
     // The layout for per token per head:
-    // |scale(f32)|zeropoint(f32)|quantized feature(u8,idx_1)|quantized feature(u8,idx_2)|...|quantized feature(u8,idx_S)|
-    // The quantized feature will start from 8bytes=sizeof(float)+sizeof(float)
+    // |scale(f32)|zeropoint(f32)|quantized feature(u8,idx_1)|quantized feature(u8,idx_2)|...|quantized
+    // feature(u8,idx_S)| The quantized feature will start from 8bytes=sizeof(float)+sizeof(float)
    auto s = reinterpret_cast(src);
    auto t = tmp;
    // if group_size not set, the whole row is used as a group
    const size_t sub_byte_mulitplier = 2;
    size_t _group_size = group_size ? group_size : K;
-    for (size_t n = 0; n < N; n ++) {
+    for (size_t n = 0; n < N; n++) {
         size_t src_offset = 0;
         size_t dst_offset = 0;
         while (dst_offset < K) {
@@ -1132,17 +1203,27 @@ static void pack_32NxK(TDST* dst, void* src, TDST* tmp, size_t N, size_t K, size
     pack_32NxK::value>(dst, tmp, reinterpret_cast(0), N, K, dst_stride, src_stride);
 }
 
-template::value != ov::element::f32 && (SRC_PREC == ov::element::i4), bool>::type = true>
-static void pack_32NxK(TDST* dst, void* src, TDST* tmp, size_t N, size_t K, size_t dst_stride, size_t src_stride, size_t group_size = 0) {
+template ::value != ov::element::f32 && (SRC_PREC == ov::element::i4),
+                                  bool>::type = true>
+static void pack_32NxK(TDST* dst,
+                       void* src,
+                       TDST* tmp,
+                       size_t N,
+                       size_t K,
+                       size_t dst_stride,
+                       size_t src_stride,
+                       size_t group_size = 0) {
     // The layout for per token per head:
-    // |scale(f32)|zeropoint(f32)|quantized feature(u8,idx_1)|quantized feature(u8,idx_2)|...|quantized feature(u8,idx_S)|
-    // The quantized feature will start from 8bytes=sizeof(float)+sizeof(float)
+    // |scale(f32)|zeropoint(f32)|quantized feature(u8,idx_1)|quantized feature(u8,idx_2)|...|quantized
+    // feature(u8,idx_S)| The quantized feature will start from 8bytes=sizeof(float)+sizeof(float)
     auto s = reinterpret_cast(src);
     auto t = tmp;
     // if group_size not set, the whole row is used as a group
     const size_t sub_byte_mulitplier = 2;
     size_t _group_size = group_size ? group_size : K;
-    for (size_t n = 0; n < N; n ++) {
+    for (size_t n = 0; n < N; n++) {
         size_t src_offset = 0;
         size_t dst_offset = 0;
         while (dst_offset < K) {
@@ -1156,7 +1237,7 @@ static void pack_32NxK(TDST* dst, void* src, TDST* tmp, size_t N, size_t K, size
     }
     pack_32NxK::value>(dst, tmp, reinterpret_cast(0), N, K, dst_stride, src_stride);
 }
-#endif
+# endif
 
 template::value == ov::element::f32, bool>::type = true>
 static void pack_32NxK(TDST* dst, void* src, TDST* tmp, size_t N, size_t K, size_t dst_stride, size_t src_stride, size_t group_size = 0) {
@@ -1558,7 +1639,6 @@ struct MHAHelper {
                                      std::min(_block_size, cur_kv_len - pv),
                                      _value_group_size);
                 }
-
             }
         }
     }
@@ -2155,8 +2235,6 @@ struct AttentionExecutor : public PagedAttentionExecutor {
         q = q.reshape({B_token, H, 1, S});
         k = k.reshape({B_token, Hk, 1, S});
        v = v.reshape({B_token, Hk, 1, SV});
-        printf("k_cache prec %s shape [%ld %ld %ld %ld]\n", k_cache.get_precision().to_string().c_str(), k_cache.size(0), k_cache.size(1), k_cache.size(2), k_cache.size(3));
-        printf("v_cache prec %s shape [%ld %ld %ld %ld]\n", v_cache.get_precision().to_string().c_str(), v_cache.size(0), v_cache.size(1), v_cache.size(2), v_cache.size(3));
 
         if (k_cache.m_dt == ov::element::Type_t::u8) {
             k_cache.assert_dims({0, Hk, block_size, S + key_params_size * key_group_num}, true);
@@ -2237,27 +2315,27 @@ std::shared_ptr make_pa_executor(ov::element::Type data_
 #ifdef OPENVINO_ARCH_X86_64
     if (data_type == ov::element::bf16) {
-# if defined(HAVE_AVX512F)
+#if defined(HAVE_AVX512F)
         if (key_cache_type == ov::element::u8) {
             executor = std::make_shared>(key_group_size, value_group_size);
         } else {
             OPENVINO_ASSERT(key_cache_type == ov::element::bf16, "expect kvcache type bf16, current: ", key_cache_type);
             executor = std::make_shared>();
         }
-# else
+#else
         OPENVINO_THROW("make_pa_executor: bf16 needs avx512+ hardware.");
-# endif
+#endif
     } else if (data_type == ov::element::f16) {
-# if defined(HAVE_AVX512F)
+#if defined(HAVE_AVX512F)
         if (key_cache_type == ov::element::u8) {
             executor = std::make_shared>(key_group_size, value_group_size);
         } else {
            OPENVINO_ASSERT(key_cache_type == ov::element::f16, "expect kvcache type f16, current: ", key_cache_type);
            executor = std::make_shared>();
         }
-# else
+#else
         OPENVINO_THROW("make_pa_executor: f16 needs avx512+ hardware.");
-# endif
+#endif
     } else if (data_type == ov::element::f32) {
         if (key_cache_type == ov::element::u8) {
             executor = std::make_shared>(key_group_size, value_group_size);
diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.hpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.hpp
index ae8b8aa348f1e3..d386dc9f44e321 100644
--- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.hpp
+++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.hpp
@@ -22,7 +22,7 @@ std::shared_ptr make_pa_executor(ov::element::Type data_
                                                          size_t key_group_size,
                                                          size_t value_group_size);
 
-} // namespace XARCHl
+} // namespace XARCH
 } // namespace Cpu
 } // namespace Extensions
 } // namespace ov
\ No newline at end of file
diff --git a/src/plugins/intel_cpu/src/nodes/paged_attn.cpp b/src/plugins/intel_cpu/src/nodes/paged_attn.cpp
index 41e7274953f9e6..468fc72b28296a 100644
--- a/src/plugins/intel_cpu/src/nodes/paged_attn.cpp
+++ b/src/plugins/intel_cpu/src/nodes/paged_attn.cpp
@@ -135,8 +135,7 @@ void PagedAttention::createPrimitive() {
     const auto cpuConfig = context->getConfig();
 
     size_t key_group_size = cpuConfig.keyCacheGroupSize;
-    size_t value_group_size = cpuConfig.valueCacheGroupSize;
-    std::cout << "PagedAttn|Kcache|" << kCachePrecision << "|Vcache|" << vCachePrecision << "|key_group_size|" << key_group_size << "|value_group_size|" << value_group_size << std::endl;
+    size_t value_group_size = cpuConfig.valueCacheGroupSize;
     return make_pa_executor(rtPrecision, kCachePrecision, vCachePrecision, key_group_size, value_group_size);
 #else
     return nullptr;
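For readers outside the plugin, the attn_quant.cpp hunks above only rewrap the loops that write the grouped u8 KV-cache layout described by the "|scale(f32)|zeropoint(f32)|quantized feature(u8,...)|" comments. The following is a minimal standalone sketch of that per-group layout, assuming a plain min/max asymmetric scheme; the function name and helpers are illustrative, not the plugin's API.

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstring>

// Illustrative sketch: quantize one group of `group_size` floats into the
// |scale(f32)|zeropoint(f32)|u8 payload| layout referenced in the diff comments.
static void quantize_group_u8_sketch(const float* src, uint8_t* dst, size_t group_size) {
    float min = *std::min_element(src, src + group_size);
    float max = *std::max_element(src, src + group_size);
    float scale = (max - min) / 255.0f;
    if (scale == 0.0f) {
        scale = 0.0001f;  // guard against constant groups (assumed epsilon)
    }
    float zp = -min / scale;  // zero-point expressed in quantized units
    std::memcpy(dst, &scale, sizeof(float));               // scale header
    std::memcpy(dst + sizeof(float), &zp, sizeof(float));  // zero-point header
    uint8_t* q = dst + 2 * sizeof(float);                  // quantized payload starts at 8 bytes
    for (size_t i = 0; i < group_size; ++i) {
        float qi = std::round(src[i] / scale + zp);
        qi = std::max(0.0f, std::min(255.0f, qi));         // clamp to the u8 range
        q[i] = static_cast<uint8_t>(qi);
    }
}

Under this layout a row of S values with group size G occupies S + (S / G) * 8 bytes, which is why the destination offsets in paged_attn_quant_mt advance by group_size + sizeof(float) + sizeof(float) per group.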
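The attn_quant_kernel.hpp hunk touches the scalar tail of the u4/s4 dequant kernels, where two values share one byte. A hedged sketch of the nibble extraction and the (q - zp) * scale mapping, mirroring the shift convention and the s4 wrap-around ("tmp > 8 ? tmp - 16 : tmp") visible in the diff; names are illustrative only.

#include <cstdint>

// high_half selects which 4 bits of the packed byte to take (shift convention as in the diff).
static inline uint8_t extract_half_byte_sketch(uint8_t val, bool high_half) {
    const uint8_t shift = high_half ? 0 : 4;
    return static_cast<uint8_t>((val >> shift) & 0x0F);
}

// u4: x ~= (q - zp) * scale, with q in [0, 15].
static inline float dequant_u4_sketch(uint8_t q, float scale, float zp) {
    return (static_cast<float>(q) - zp) * scale;
}

// s4: nibbles above 8 wrap into the negative range and only a scale is applied,
// matching the s4 accumulation path in executor_pa.cpp.
static inline float dequant_s4_sketch(uint8_t q, float scale) {
    float v = static_cast<float>(q);
    if (v > 8.0f) {
        v -= 16.0f;
    }
    return v * scale;
}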
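The executor_pa.cpp changes in attn_acc_value_block and dot_product_block are pure rewrapping of the AVX-512/AVX2 intrinsics; the underlying computation walks a grouped u8 row, dequantizes each group with its scale/zero-point header, and accumulates into the output. A scalar reference sketch of that walk, using an assumed helper name rather than the plugin's function:

#include <cstddef>
#include <cstdint>
#include <cstring>

// Scalar reference for what the vectorized accumulation over one grouped-u8 row does:
// each group of `group_size` bytes is preceded by a float scale and a float zero-point.
static void acc_value_row_u8_sketch(float* out, float weight, const uint8_t* v, size_t S, size_t group_size) {
    const size_t params_offset = 2 * sizeof(float);  // scale + zero-point header per group
    size_t src_offset = 0;
    size_t dst_offset = 0;
    while (dst_offset < S) {
        float scale = 0.0f;
        float zp = 0.0f;
        std::memcpy(&scale, v + src_offset, sizeof(float));
        std::memcpy(&zp, v + src_offset + sizeof(float), sizeof(float));
        const uint8_t* data = v + src_offset + params_offset;
        for (size_t i = 0; i < group_size; ++i) {
            out[dst_offset + i] += weight * (static_cast<float>(data[i]) - zp) * scale;
        }
        dst_offset += group_size;
        src_offset += group_size + params_offset;
    }
}

The dst_offset/src_offset bookkeeping here is the same pattern as the scalar tails kept as context in the diff.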
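The pack_32NxK overloads carry two details worth restating: a group_size of zero falls back to "the whole row is one group", and the u4/i4 overloads use sub_byte_mulitplier = 2 because one byte stores two elements. A small sketch of the resulting per-row byte count under those assumptions (helper name is hypothetical):

#include <cstddef>

// Bytes occupied by one quantized row of K elements, given the group size and the
// sub-byte packing factor (1 for u8, 2 for u4/i4), following the layout in the diff.
static size_t quantized_row_bytes_sketch(size_t K, size_t group_size, size_t sub_byte_multiplier) {
    const size_t params_bytes = 2 * sizeof(float);        // scale + zero-point per group
    const size_t g = group_size ? group_size : K;         // "whole row is used as a group" fallback
    const size_t groups = (K + g - 1) / g;
    return groups * (params_bytes + g / sub_byte_multiplier);
}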
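Finally, make_pa_executor and PagedAttention::createPrimitive wire the configured key/value cache group sizes into the executor and gate the bf16/f16 paths on AVX-512. A rough, hypothetical sketch of that dispatch order with placeholder types (these are not OpenVINO classes):

#include <cstddef>
#include <memory>
#include <stdexcept>

enum class Prec { f32, f16, bf16, u8 };

struct ExecutorBase {
    virtual ~ExecutorBase() = default;
};
struct QuantizedExecutor : ExecutorBase {
    QuantizedExecutor(size_t key_group_size, size_t value_group_size)
        : key_gs(key_group_size), value_gs(value_group_size) {}
    size_t key_gs;
    size_t value_gs;
};
struct PlainExecutor : ExecutorBase {};

// bf16/f16 require AVX-512; a u8 key cache selects the quantized executor with the
// group sizes read from the CPU config (keyCacheGroupSize / valueCacheGroupSize in the diff).
static std::shared_ptr<ExecutorBase> make_executor_sketch(Prec data_type,
                                                          Prec key_cache_type,
                                                          size_t key_group_size,
                                                          size_t value_group_size,
                                                          bool have_avx512) {
    if ((data_type == Prec::bf16 || data_type == Prec::f16) && !have_avx512) {
        throw std::runtime_error("bf16/f16 executors need AVX-512 hardware");
    }
    if (key_cache_type == Prec::u8) {
        return std::make_shared<QuantizedExecutor>(key_group_size, value_group_size);
    }
    return std::make_shared<PlainExecutor>();
}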