From 15fcdb8d1980933a30e5fd235add2c558b8c31bc Mon Sep 17 00:00:00 2001 From: Zhang Yi Date: Thu, 31 Oct 2024 09:09:19 +0800 Subject: [PATCH 01/28] [CPU]separate precisions of kv cache Signed-off-by: yi3.zhang@intel.com --- .../nodes/kernels/scaled_attn/executor_pa.cpp | 70 ++++++++++--------- .../nodes/kernels/scaled_attn/executor_pa.hpp | 4 +- .../intel_cpu/src/nodes/paged_attn.cpp | 13 ++-- .../intel_cpu/src/nodes/scaled_attn.cpp | 2 +- 4 files changed, 48 insertions(+), 41 deletions(-) diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp index 90167ac86a8e1a..ed3346386e347e 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp @@ -769,7 +769,7 @@ static void pack_32NxK(float* dst, T* src, float* tmp, size_t N, size_t K, size_ OPENVINO_THROW("pack_32NxK: should not be called."); } -template +template struct MHAHelper { // initialize once size_t _H; @@ -885,11 +885,13 @@ struct MHAHelper { if ((S % 32 == 0) && (block_size % 16 == 0) && (S <= 32 * 6)) { if (dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::amx_bf16) && precision_of::value == ov::element::bf16 && - precision_of::value == ov::element::bf16) { + precision_of::value == ov::element::bf16 && + precision_of::value == ov::element::bf16) { _fastpath_valid_prec = ov::element::bf16; } else if (dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::amx_fp16) && precision_of::value == ov::element::f16 && - precision_of::value == ov::element::f16) { + precision_of::value == ov::element::f16 && + precision_of::value == ov::element::bf16) { _fastpath_valid_prec = ov::element::f16; } } @@ -944,7 +946,7 @@ struct MHAHelper { auto q_end = std::min(q_start + _block_size, q_len); auto q_cnt = q_end - q_start; constexpr bool q_is_xf16 = one_of(precision_of::value, ov::element::bf16, ov::element::f16); - constexpr bool q_cache_is_same = precision_of::value == precision_of::value; + constexpr bool q_cache_is_same = precision_of::value == precision_of::value; auto cur_kv_len_blocks = div_up(cur_kv_len, _block_size); for (size_t h = hq_beg; h < hq_end; h++) { auto* q_ptr = query.ptr(h, q_start, 0); @@ -1073,7 +1075,7 @@ struct MHAHelper { auto block_number = block_table[i]; for (size_t pq = 0; pq < q_len; pq++) { for (size_t h = hq_beg; h < hq_end; h++) { - (*_gemv)(query.ptr(h, pq), present_key.ptr(block_number, hk), + (*_gemv)(query.ptr(h, pq), present_key.ptr(block_number, hk), _weight.ptr(ithr, h, pq) + pk); } } @@ -1084,7 +1086,7 @@ struct MHAHelper { auto block_number = block_table[i]; for (size_t pq = 0; pq < q_len; pq++) { for (size_t h = hq_beg; h < hq_end; h++) { - dot_product_block(query.ptr(h, pq), present_key.ptr(block_number, hk), + dot_product_block(query.ptr(h, pq), present_key.ptr(block_number, hk), _weight.ptr(ithr, h, pq) + pk, _S, std::min(_block_size, cur_kv_len - pk)); } } @@ -1121,7 +1123,7 @@ struct MHAHelper { memset(_output.ptr(ithr), 0, q_len * _H * _SV * sizeof(float)); for (size_t pv = 0, i = 0; pv < cur_kv_len; pv += _block_size, i++) { auto block_number = block_table[i]; - auto* v = present_value.ptr(block_number, hk); + auto* v = present_value.ptr(block_number, hk); for (size_t pq = 0; pq < q_len; pq++) { for (size_t h = hq_beg; h < hq_end; h++) { attn_acc_value_block(_output.ptr(ithr, pq, h), @@ -1203,7 +1205,7 @@ struct MHAHelper { _gemv->tile_config(); for (size_t pq = 0; pq < q_len; pq++) { for (size_t h = hq_beg; h < 
hq_end; h++) { - (*_gemv)(query.ptr(b, h, pq), present_key.ptr(block_number, hk), + (*_gemv)(query.ptr(b, h, pq), present_key.ptr(block_number, hk), _weight_bhl.ptr(b, h, pq) + pk); } } @@ -1211,7 +1213,7 @@ struct MHAHelper { } else { for (size_t pq = 0; pq < q_len; pq++) { for (size_t h = hq_beg; h < hq_end; h++) { - dot_product_block(query.ptr(b, h, pq), present_key.ptr(block_number, hk), + dot_product_block(query.ptr(b, h, pq), present_key.ptr(block_number, hk), _weight_bhl.ptr(b, h, pq) + pk, _S, std::min(_block_size, context_len - pk)); } } @@ -1279,7 +1281,7 @@ struct MHAHelper { // kv_len must be valid if (pv < context_len) { auto block_number = block_indices.ptr()[block_indices_begins.ptr()[b] + pv_in_blocks]; - auto* v = present_value.ptr(block_number, hk); + auto* v = present_value.ptr(block_number, hk); for (size_t pq = 0; pq < q_len; pq++) { for (size_t h = hq_beg; h < hq_end; h++) { attn_acc_value_block(_output_bhl.ptr(ithr, b, pq, h), @@ -1307,9 +1309,9 @@ struct MHAHelper { } }; -template +template struct MHA { - MHAHelper& _helper; + MHAHelper& _helper; struct AttnWorkItem { int32_t batch_in_reorder; // which batch in reorder buffer will be used int32_t batch_in_seq; // batch idx in sequence @@ -1407,7 +1409,7 @@ struct MHA { WorkItems _workitems; - MHA(MHAHelper& helper) : _helper(helper) {} + MHA(MHAHelper& helper) : _helper(helper) {} // one loop to handle first and second tokens void exec_loop_mixed(const PlainTensor& q, @@ -1424,7 +1426,7 @@ struct MHA { auto Hk = v_cache.m_dims[1]; constexpr bool q_is_xf16 = one_of(precision_of::value, ov::element::bf16, ov::element::f16); - constexpr bool q_cache_is_same = precision_of::value == precision_of::value; + constexpr bool q_cache_is_same = precision_of::value == precision_of::value; auto attn_work_count = _workitems.attn_work_size(); auto reorder_work_count = _workitems.reorder_work_size(); @@ -1442,8 +1444,8 @@ struct MHA { return; auto ithr = parallel_get_thread_num(); - auto* k_ptr = k_cache.ptr(block_number, hk); - auto* v_ptr = v_cache.ptr(block_number, hk); + auto* k_ptr = k_cache.ptr(block_number, hk); + auto* v_ptr = v_cache.ptr(block_number, hk); transpose_16NxK(_helper._qk_scratch_b.template ptr(batch_in_reorder, kv_block, hk), k_ptr, _helper._output.template ptr(ithr), @@ -1577,10 +1579,10 @@ struct MHA { } }; -template +template struct AttentionExecutor : public PagedAttentionExecutor { - MHAHelper _helper; - MHA _kernel; + MHAHelper _helper; + MHA _kernel; PlainTensor _slot_mapping; AttentionExecutor() : _kernel(_helper) {} @@ -1697,40 +1699,40 @@ struct AttentionExecutor : public PagedAttentionExecutor { }; #endif -std::shared_ptr make_pa_executor(ov::element::Type data_type, ov::element::Type kvcache_type) { +std::shared_ptr make_pa_executor(ov::element::Type data_type, ov::element::Type key_cache_type, ov::element::Type value_cache_type) { std::shared_ptr executor; #ifdef OPENVINO_ARCH_X86_64 if (data_type == ov::element::bf16) { #if defined(HAVE_AVX512F) - if (kvcache_type == ov::element::u8) { - executor = std::make_shared>(); + if (key_cache_type == ov::element::u8) { + executor = std::make_shared>(); } else { - OPENVINO_ASSERT(kvcache_type == ov::element::bf16, "expect kvcache type bf16, current: ", kvcache_type); - executor = std::make_shared>(); + OPENVINO_ASSERT(key_cache_type == ov::element::bf16, "expect kvcache type bf16, current: ", key_cache_type); + executor = std::make_shared>(); } #else OPENVINO_THROW("make_pa_executor: bf16 needs avx512+ hardware."); #endif } else if (data_type == 
ov::element::f16) { #if defined(HAVE_AVX512F) - if (kvcache_type == ov::element::u8) { - executor = std::make_shared>(); + if (key_cache_type == ov::element::u8) { + executor = std::make_shared>(); } else { - OPENVINO_ASSERT(kvcache_type == ov::element::f16, "expect kvcache type f16, current: ", kvcache_type); - executor = std::make_shared>(); + OPENVINO_ASSERT(key_cache_type == ov::element::f16, "expect kvcache type f16, current: ", key_cache_type); + executor = std::make_shared>(); } #else OPENVINO_THROW("make_pa_executor: f16 needs avx512+ hardware."); #endif } else if (data_type == ov::element::f32) { - if (kvcache_type == ov::element::u8) { - executor = std::make_shared>(); - } else if (kvcache_type == ov::element::f16) { - executor = std::make_shared>(); + if (key_cache_type == ov::element::u8) { + executor = std::make_shared>(); + } else if (key_cache_type == ov::element::f16) { + executor = std::make_shared>(); } else { - OPENVINO_ASSERT(kvcache_type == ov::element::f32, "expect kvcache type f32, current: ", kvcache_type); - executor = std::make_shared>(); + OPENVINO_ASSERT(key_cache_type == ov::element::f32, "expect kvcache type f32, current: ", key_cache_type); + executor = std::make_shared>(); } } else { OPENVINO_THROW("make_pa_executor: unsupported precision: ", data_type); diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.hpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.hpp index ed779dee13c96d..247fafbfac51e8 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.hpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.hpp @@ -16,7 +16,9 @@ namespace Extensions { namespace Cpu { namespace XARCH { -std::shared_ptr make_pa_executor(ov::element::Type data_type, ov::element::Type kvcache_type); +std::shared_ptr make_pa_executor(ov::element::Type data_type, + ov::element::Type key_cache_type, + ov::element::Type value_cache_type); } // namespace XARCH } // namespace Cpu diff --git a/src/plugins/intel_cpu/src/nodes/paged_attn.cpp b/src/plugins/intel_cpu/src/nodes/paged_attn.cpp index b9666388490f74..314cb4dcc42731 100644 --- a/src/plugins/intel_cpu/src/nodes/paged_attn.cpp +++ b/src/plugins/intel_cpu/src/nodes/paged_attn.cpp @@ -82,11 +82,12 @@ void PagedAttention::initSupportedPrimitiveDescriptors() { OPENVINO_ASSERT(orgInputNumber == 13, "The input number of PagedAttention should be 13."); // kvcache, float, [] - auto past_kv_input_mem_precision = getOriginalInputPrecisionAtPort(PagedAttentionExecutor::ID_KCACHE); + auto past_key_input_mem_precision = getOriginalInputPrecisionAtPort(PagedAttentionExecutor::ID_KCACHE); + auto past_value_input_mem_precision = getOriginalInputPrecisionAtPort(PagedAttentionExecutor::ID_VCACHE); config.inConfs[PagedAttentionExecutor::ID_KCACHE].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc( - past_kv_input_mem_precision, getInputShapeAtPort(PagedAttentionExecutor::ID_KCACHE))); + past_key_input_mem_precision, getInputShapeAtPort(PagedAttentionExecutor::ID_KCACHE))); config.inConfs[PagedAttentionExecutor::ID_VCACHE].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc( - past_kv_input_mem_precision, getInputShapeAtPort(PagedAttentionExecutor::ID_VCACHE))); + past_value_input_mem_precision, getInputShapeAtPort(PagedAttentionExecutor::ID_VCACHE))); // past_lens, int, [b_seq] config.inConfs[PagedAttentionExecutor::ID_PAST_LENS].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc( ov::element::i32, 
getInputShapeAtPort(PagedAttentionExecutor::ID_PAST_LENS))); @@ -128,8 +129,10 @@ void PagedAttention::createPrimitive() { auto builder = [&](const PagedAttentionKey& key) -> std::shared_ptr { #ifdef OPENVINO_ARCH_X86_64 - auto kvCachePrecision = getOriginalInputPrecisionAtPort(PagedAttentionExecutor::ID_KCACHE); - return make_pa_executor(rtPrecision, kvCachePrecision); + auto kCachePrecision = getOriginalInputPrecisionAtPort(PagedAttentionExecutor::ID_KCACHE); + auto vCachePrecision = getOriginalInputPrecisionAtPort(PagedAttentionExecutor::ID_VCACHE); + std::cout << "PagedAttn|Kcache|" << kCachePrecision << "|Vcache|" << vCachePrecision << std::endl; + return make_pa_executor(rtPrecision, kCachePrecision, vCachePrecision); #else return nullptr; #endif diff --git a/src/plugins/intel_cpu/src/nodes/scaled_attn.cpp b/src/plugins/intel_cpu/src/nodes/scaled_attn.cpp index f9f853230c4dd6..837991ba068d08 100644 --- a/src/plugins/intel_cpu/src/nodes/scaled_attn.cpp +++ b/src/plugins/intel_cpu/src/nodes/scaled_attn.cpp @@ -435,7 +435,7 @@ struct MHAKernel { v_ptr, fp32_out_ptr, wsp.data() + tid * wsp_size_per_thread, - wv_scratch_a ? &wv_scratch_a.at({tid, 0}) : nullptr); + wv_gemm_ptr->get_scratch_a_size() > 0 ? &wv_scratch_a.at({tid, 0}) : nullptr); if (is_xf16) { if (has_out_transpose) { attn_memcpy2d_kernel(&fp32_out.at({b, m_start, h, 0}), From 82f843aa60e18152f62b2dbab225ca2dc9354a14 Mon Sep 17 00:00:00 2001 From: Zhang Yi Date: Wed, 6 Nov 2024 14:27:13 +0800 Subject: [PATCH 02/28] [CPU]use element as template args --- .../nodes/kernels/scaled_attn/executor_pa.cpp | 75 ++++++++++--------- 1 file changed, 39 insertions(+), 36 deletions(-) diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp index ed3346386e347e..963a437c75a773 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp @@ -601,40 +601,40 @@ static void attn_reduce(T* dst, float* temp, size_t M, size_t S, size_t temp_str } // N must be multiple of 16 -template -void transpose_16NxK(TDST* dst, TSRC* src, TDST* tmp, size_t N, size_t K, size_t dst_stride, size_t src_stride) { +template::type = true> +void transpose_16NxK(TDST* dst, void* src, TDST* tmp, size_t N, size_t K, size_t dst_stride, size_t src_stride) { size_t k = 0; + auto* src_ptr = reinterpret_cast::value_type*>(src); for (; k + 16 <= K; k += 16) { for (size_t n = 0; n < N; n += 16) { - transpose_16x16_kernel(dst + n, src + n * src_stride, dst_stride, src_stride); + transpose_16x16_kernel(dst + n, src_ptr + n * src_stride, dst_stride, src_stride); } dst += 16 * dst_stride; - src += 16; + src_ptr += 16; } if (k < K) { for (size_t n = 0; n < N; n += 16) { - transpose_16xK_kernel(dst + n, src + n * src_stride, K - k, dst_stride, src_stride); + transpose_16xK_kernel(dst + n, src_ptr + n * src_stride, K - k, dst_stride, src_stride); } } } - #if defined(HAVE_AVX512F) -template::value || std::is_same::value), bool>::type> +template::value), bool>::type = true> static void transpose_16NxK(T* dst, T* src, T* tmp, size_t N, size_t K, size_t dst_stride, size_t src_stride) { // will treat as uint32_t transpose auto s = reinterpret_cast(src); auto d = reinterpret_cast(dst); - transpose_16NxK(d, s, reinterpret_cast(0), N, K >> 1, dst_stride, src_stride >> 1); + transpose_16NxK(d, s, reinterpret_cast(0), N, K >> 1, dst_stride, src_stride >> 1); } #endif -template -void transpose_16NxK(TDST* 
dst, uint8_t* src, TDST* tmp, size_t N, size_t K, size_t dst_stride, size_t src_stride) { +template::type = true> +void transpose_16NxK(TDST* dst, void* src, TDST* tmp, size_t N, size_t K, size_t dst_stride, size_t src_stride) { // The layout for per token per head: // |scale(f32)|zeropoint(f32)|quantized feature(u8,idx_1)|quantized feature(u8,idx_2)|...|quantized feature(u8,idx_S)| // The quantized feature will start from 8bytes=sizeof(float)+sizeof(float) - auto s = src; + auto s = reinterpret_cast::value_type*>(src); auto t = tmp; for (size_t n = 0; n < N; n ++) { auto f = reinterpret_cast(s); @@ -642,7 +642,7 @@ void transpose_16NxK(TDST* dst, uint8_t* src, TDST* tmp, size_t N, size_t K, siz s += src_stride + 2 * sizeof(float); t += src_stride; } - transpose_16NxK(dst, tmp, reinterpret_cast(0), N, K, dst_stride, src_stride); + transpose_16NxK::value>(dst, tmp, reinterpret_cast(0), N, K, dst_stride, src_stride); } // dequant f16/u8 to float @@ -726,32 +726,33 @@ static void pack_32xK_kernel(T* dst, T* src, size_t dst_stride, size_t src_strid } } -template::value || std::is_same::value), bool>::type> -static void pack_32NxK(T* dst, T* src, T* tmp, size_t N, size_t K, size_t dst_stride, size_t src_stride) { +template::value != ov::element::f32 && (SRC_PREC == ov::element::bf16 || SRC_PREC == ov::element::f16), bool>::type = true> +static void pack_32NxK(TDST* dst, void* src, TDST* tmp, size_t N, size_t K, size_t dst_stride, size_t src_stride) { + auto src_ptr = reinterpret_cast::value_type*>(src); for (size_t n = 0; n < N; n += 32) { size_t k = 0; for (; k + 32 <= K; k += 32) { - pack_32x32_kernel(dst + k * 2, src + k, dst_stride, src_stride); + pack_32x32_kernel(dst + k * 2, src_ptr + k, dst_stride, src_stride); } if (k + 16 <= K) { - pack_32x16_kernel(dst + k * 2, src + k, dst_stride, src_stride); + pack_32x16_kernel(dst + k * 2, src_ptr + k, dst_stride, src_stride); k += 16; } if (k < K) { - pack_32xK_kernel(dst + k * 2, src + k, dst_stride, src_stride, K - k); + pack_32xK_kernel(dst + k * 2, src_ptr + k, dst_stride, src_stride, K - k); } dst += 32 * dst_stride; - src += 32 * src_stride; + src_ptr += 32 * src_stride; } } -template::value || std::is_same::value), bool>::type> -static void pack_32NxK(T* dst, uint8_t* src, T* tmp, size_t N, size_t K, size_t dst_stride, size_t src_stride) { +template::value != ov::element::f32 && SRC_PREC == ov::element::u8, bool>::type = true> +static void pack_32NxK(TDST* dst, void* src, TDST* tmp, size_t N, size_t K, size_t dst_stride, size_t src_stride) { // The layout for per token per head: // |scale(f32)|zeropoint(f32)|quantized feature(u8,idx_1)|quantized feature(u8,idx_2)|...|quantized feature(u8,idx_S)| // The quantized feature will start from 8bytes=sizeof(float)+sizeof(float) - auto s = src; + auto s = reinterpret_cast::value_type*>(src); auto t = tmp; for (size_t n = 0; n < N; n ++) { auto f = reinterpret_cast(s); @@ -759,12 +760,12 @@ static void pack_32NxK(T* dst, uint8_t* src, T* tmp, size_t N, size_t K, size_t s += src_stride + 2 * sizeof(float); t += src_stride; } - pack_32NxK(dst, tmp, reinterpret_cast(0), N, K, dst_stride, src_stride); + pack_32NxK::value>(dst, tmp, reinterpret_cast(0), N, K, dst_stride, src_stride); } #endif -template -static void pack_32NxK(float* dst, T* src, float* tmp, size_t N, size_t K, size_t dst_stride, size_t src_stride) { +template::value == ov::element::f32, bool>::type = true> +static void pack_32NxK(TDST* dst, void* src, TDST* tmp, size_t N, size_t K, size_t dst_stride, size_t src_stride) { // never called 
OPENVINO_THROW("pack_32NxK: should not be called."); } @@ -1446,19 +1447,21 @@ struct MHA { auto ithr = parallel_get_thread_num(); auto* k_ptr = k_cache.ptr(block_number, hk); auto* v_ptr = v_cache.ptr(block_number, hk); - transpose_16NxK(_helper._qk_scratch_b.template ptr(batch_in_reorder, kv_block, hk), - k_ptr, - _helper._output.template ptr(ithr), - _helper._block_size, - _helper._S, _helper._block_size, _helper._S); + + transpose_16NxK::value>(_helper._qk_scratch_b.template ptr(batch_in_reorder, kv_block, hk), + k_ptr, + _helper._output.template ptr(ithr), + _helper._block_size, + _helper._S, _helper._block_size, _helper._S); + if (q_is_xf16) { - pack_32NxK(_helper._wv_scratch_b.template ptr(batch_in_reorder, kv_block, hk), - v_ptr, - _helper._output.template ptr(ithr), - _helper._block_size, - _helper._SV, - rnd_up(_helper._SV, _helper._block_size), - _helper._SV); + pack_32NxK::value>(_helper._wv_scratch_b.template ptr(batch_in_reorder, kv_block, hk), + v_ptr, + _helper._output.template ptr(ithr), + _helper._block_size, + _helper._SV, + rnd_up(_helper._SV, _helper._block_size), + _helper._SV); } else { // need to decompress if (!q_cache_is_same) { From a754404003a5bf0b66b60da974e85eb4ebc7ffc0 Mon Sep 17 00:00:00 2001 From: Zhang Yi Date: Fri, 8 Nov 2024 15:15:55 +0800 Subject: [PATCH 03/28] [CPU]make quantize grouped --- .../nodes/kernels/scaled_attn/attn_quant.cpp | 117 +++++++++++++++--- .../nodes/kernels/scaled_attn/attn_quant.hpp | 4 +- .../kernels/scaled_attn/attn_quant_kernel.hpp | 19 +++ .../nodes/kernels/scaled_attn/executor_pa.cpp | 38 ++++-- .../nodes/kernels/scaled_attn/executor_pa.hpp | 6 +- .../intel_cpu/src/nodes/paged_attn.cpp | 7 +- 6 files changed, 156 insertions(+), 35 deletions(-) diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant.cpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant.cpp index 66772bda03db51..0a91ffb089ee1a 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant.cpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant.cpp @@ -170,6 +170,38 @@ static void quant_u8(const T* src, uint8_t* dst, size_t n, float& scale, float& } } +template +static void quant_u4(const T* src, uint8_t* dst, size_t n, float& scale, float& zp) { + auto insert_half_byte = [](uint8_t dst, uint8_t val, bool high_half) -> uint8_t { + uint8_t shift = high_half ? 0 : 4; + return dst | (uint8_t) (val << shift); + }; + size_t i = 0; + float max = -FLT_MAX; + float min = FLT_MAX; + for (; i < n; i++) { + float tmp = src[i]; + max = std::max(max, tmp); + min = std::min(min, tmp); + } + scale = (max - min) / ((1 << 4) - 1); + if (scale == 0) + scale = 0.0001f; + zp = -min / scale; + i = 0; + for (; i < n; i++) { + float tmp = src[i]; + #define MIN(a, b) ((a) < (b) ? (a) : (b)) + uint8_t src_val = MIN(15, (uint8_t)(std::round(tmp / scale + zp))); + uint8_t dst_val = i % 2 == 0 ? 
0 : dst[i / 2]; + dst_val = insert_half_byte(dst_val, src_val, (uint8_t)(i % 2)); + if (i < 4) + printf("index %ld float %f src %d hex %x", i, tmp, src_val, dst_val); + dst[i / 2] = dst_val; + } + printf("quant scale %f zp %f\n", scale, zp); +} + template static void attn_quant_mt(const ov::intel_cpu::PlainTensor& k_src, const ov::intel_cpu::PlainTensor& v_src, @@ -195,34 +227,81 @@ static void attn_quant_mt(const ov::intel_cpu::PlainTensor& k_src, }); } -template +template ::type = true> static void paged_attn_quant_mt(const ov::intel_cpu::PlainTensor& k_src, const ov::intel_cpu::PlainTensor& v_src, const ov::intel_cpu::PlainTensor& k_dst, const ov::intel_cpu::PlainTensor& v_dst, - const ov::intel_cpu::PlainTensor& slot_mapping) { + const ov::intel_cpu::PlainTensor& slot_mapping, + const size_t key_group_size, + const size_t value_group_size) { size_t B = k_src.m_dims[0], H = k_src.m_dims[1], L1 = k_src.m_dims[2], S = k_src.m_dims[3], SV = v_src.m_dims[3]; size_t block_size = k_dst.m_dims[2]; + size_t _key_group_size = key_group_size == 0 ? S : key_group_size; + size_t _value_group_size = value_group_size == 0 ? SV : value_group_size; parallel_for3d(B, L1, H, [&](size_t b, size_t m, size_t h) { auto slot = slot_mapping.ptr(b)[m]; if (slot < 0) return; auto block_number = slot / block_size; auto block_offset = slot % block_size; + // The layout for per token per head: + // |scale(f32)|zeropoint(f32)|quantized feature(u8,idx_1)|quantized feature(u8,idx_2)|...|quantized feature(u8,idx_S)| + for (size_t src_offset = 0, dst_offset = 0; src_offset < S; src_offset += _key_group_size, dst_offset += _key_group_size + sizeof(float) + sizeof(float)) { + auto p_k = reinterpret_cast(k_dst.ptr::value_type>(block_number, h, block_offset, dst_offset)); + quant_u8(k_src.ptr(b, h, m, src_offset), + k_dst.ptr::value_type>(block_number, h, block_offset, dst_offset) + sizeof(float) + sizeof(float), + _key_group_size, + p_k[0], + p_k[1]); + } + + for (size_t src_offset = 0, dst_offset = 0; src_offset < SV; src_offset += _value_group_size, dst_offset += _value_group_size + sizeof(float) + sizeof(float)) { + auto p_v = reinterpret_cast(v_dst.ptr::value_type>(block_number, h, block_offset, dst_offset)); + quant_u8(v_src.ptr(b, h, m, src_offset), + v_dst.ptr::value_type>(block_number, h, block_offset, dst_offset) + sizeof(float) + sizeof(float), + _value_group_size, + p_v[0], + p_v[1]); + } + }); +} - auto p_k = reinterpret_cast(k_dst.ptr(block_number, h, block_offset)); - auto p_v = reinterpret_cast(v_dst.ptr(block_number, h, block_offset)); +template ::type = true> +static void paged_attn_quant_mt(const ov::intel_cpu::PlainTensor& k_src, + const ov::intel_cpu::PlainTensor& v_src, + const ov::intel_cpu::PlainTensor& k_dst, + const ov::intel_cpu::PlainTensor& v_dst, + const ov::intel_cpu::PlainTensor& slot_mapping, + const size_t key_group_size, + const size_t value_group_size) { + size_t B = k_src.m_dims[0], H = k_src.m_dims[1], L1 = k_src.m_dims[2], S = k_src.m_dims[3], SV = v_src.m_dims[3]; + size_t block_size = k_dst.m_dims[2]; + size_t _key_group_size = key_group_size == 0 ? S : key_group_size; + size_t _value_group_size = value_group_size == 0 ? 
SV : value_group_size; + parallel_for3d(B, L1, H, [&](size_t b, size_t m, size_t h) { + auto slot = slot_mapping.ptr(b)[m]; + if (slot < 0) return; + auto block_number = slot / block_size; + auto block_offset = slot % block_size; // The layout for per token per head: // |scale(f32)|zeropoint(f32)|quantized feature(u8,idx_1)|quantized feature(u8,idx_2)|...|quantized feature(u8,idx_S)| - quant_u8(k_src.ptr(b, h, m), - k_dst.ptr(block_number, h, block_offset) + sizeof(float) + sizeof(float), - S, - p_k[0], - p_k[1]); - quant_u8(v_src.ptr(b, h, m), - v_dst.ptr(block_number, h, block_offset) + sizeof(float) + sizeof(float), - SV, - p_v[0], - p_v[1]); + for (size_t src_offset = 0, dst_offset = 0; src_offset < S; src_offset += _key_group_size, dst_offset += _key_group_size + sizeof(float) + sizeof(float)) { + auto p_k = reinterpret_cast(k_dst.ptr::value_type>(block_number, h, block_offset, dst_offset)); + quant_u8(k_src.ptr(b, h, m, src_offset), + k_dst.ptr::value_type>(block_number, h, block_offset, dst_offset) + sizeof(float) + sizeof(float), + _key_group_size, + p_k[0], + p_k[1]); + } + + for (size_t src_offset = 0, dst_offset = 0; src_offset < SV; src_offset += _value_group_size, dst_offset += _value_group_size + sizeof(float) + sizeof(float)) { + auto p_v = reinterpret_cast(v_dst.ptr::value_type>(block_number, h, block_offset, dst_offset)); + quant_u4(v_src.ptr(b, h, m, src_offset), + v_dst.ptr::value_type>(block_number, h, block_offset, dst_offset) + sizeof(float) + sizeof(float), + _value_group_size, + p_v[0], + p_v[1]); + } }); } @@ -247,13 +326,15 @@ void paged_attn_quantkv(const ov::intel_cpu::PlainTensor& k_src, const ov::intel_cpu::PlainTensor& v_src, const ov::intel_cpu::PlainTensor& k_dst, const ov::intel_cpu::PlainTensor& v_dst, - const ov::intel_cpu::PlainTensor& slot_mapping) { + const ov::intel_cpu::PlainTensor& slot_mapping, + const size_t key_group_size, + const size_t value_group_size) { if (k_src.get_precision() == ov::element::f32 && k_dst.get_precision() == ov::element::u8) { - paged_attn_quant_mt(k_src, v_src, k_dst, v_dst, slot_mapping); + paged_attn_quant_mt(k_src, v_src, k_dst, v_dst, slot_mapping, key_group_size, value_group_size); } else if (k_src.get_precision() == ov::element::bf16 && k_dst.get_precision() == ov::element::u8) { - paged_attn_quant_mt(k_src, v_src, k_dst, v_dst, slot_mapping); + paged_attn_quant_mt(k_src, v_src, k_dst, v_dst, slot_mapping, key_group_size, value_group_size); } else if (k_src.get_precision() == ov::element::f16 && k_dst.get_precision() == ov::element::u8) { - paged_attn_quant_mt(k_src, v_src, k_dst, v_dst, slot_mapping); + paged_attn_quant_mt(k_src, v_src, k_dst, v_dst, slot_mapping, key_group_size, value_group_size); } else { OPENVINO_THROW("unsupport src type: ", k_src.get_precision(), ", dst type: ", k_dst.get_precision(), " in paged_attn_quantkv"); } diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant.hpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant.hpp index ca930a1055db2b..70711ee49cbc3c 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant.hpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant.hpp @@ -26,7 +26,9 @@ void paged_attn_quantkv(const ov::intel_cpu::PlainTensor& k_src, const ov::intel_cpu::PlainTensor& v_src, const ov::intel_cpu::PlainTensor& k_dst, const ov::intel_cpu::PlainTensor& v_dst, - const ov::intel_cpu::PlainTensor& slot_mapping); + const ov::intel_cpu::PlainTensor& slot_mapping, + const size_t key_group_size, + const size_t 
value_group_size); void attn_quant_u8(const float* src, uint8_t* dst, size_t n, float& scale, float& zp); diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant_kernel.hpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant_kernel.hpp index 4e013a004d29f9..4008c7d5a9dfb7 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant_kernel.hpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant_kernel.hpp @@ -50,6 +50,25 @@ void attn_dequant_u8_kernel(const uint8_t* src, TDST* dst, size_t n, float scale } } +template +void attn_dequant_u4_kernel(const uint8_t* src, TDST* dst, size_t n, float scale, float zp) { + auto extract_half_byte = [&](uint8_t val, bool high_half) -> uint8_t { + uint8_t shift = high_half ? 0 : 4; + return (uint8_t) ((val >> shift) & 0x000F); + }; + size_t i = 0; + // loadu_si128/epi64 does not support const qualifier; + uint8_t* src_nc = const_cast(src); + float temp[4] = {0}; + for (; i < n; ++i) { + float tmp = extract_half_byte(src_nc[i / 2], (uint8_t)(i % 2)); + if (i < 4) + printf("index %ld integral %f float %f hex %x ", i, tmp, (tmp - zp) * scale, src_nc[i / 2]); + tmp = (tmp - zp) * scale; + dst[i] = tmp; + } +} + } // namespace XARCH } // namespace Cpu } // namespace Extensions diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp index 963a437c75a773..376c1d35b0f438 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp @@ -782,6 +782,8 @@ struct MHAHelper { size_t _nthr; size_t _sliding_window; float _d_scale; + size_t _key_group_size = 0; + size_t _value_group_size = 0; PlainTensor _weight; // [nthr, H, 32, rnd_up(kv_len, block_size)], shared by first and second loop along bh PlainTensor _output; // [nthr, 32, H, S], shared by first and second loop along bh @@ -811,6 +813,10 @@ struct MHAHelper { _weight.resize({size_t{1}, size_t{1}, size_t{1}, size_t{1}}); } + explicit MHAHelper(size_t key_group_size, size_t value_group_size) : _key_group_size(key_group_size), _value_group_size(value_group_size) { + _weight.resize({size_t{1}, size_t{1}, size_t{1}, size_t{1}}); + } + void init(size_t H, size_t S, size_t SV, size_t Hk, size_t h_each_group_len, size_t block_size, size_t sliding_window, float d_scale, size_t kv_len, bool init_alibi_lookup) { // query shape: [B, H, L, S] @@ -1590,6 +1596,10 @@ struct AttentionExecutor : public PagedAttentionExecutor { AttentionExecutor() : _kernel(_helper) {} + explicit AttentionExecutor(size_t key_group_size, size_t value_group_size) + : _helper(MHAHelper(key_group_size, value_group_size)), + _kernel(_helper) {} + void init(const std::vector& inputs, const std::vector& outputs, PlainTensor& q, PlainTensor& k, PlainTensor& v, PlainTensor& k_cache, PlainTensor& v_cache, PlainTensor& past_lens, PlainTensor& subsequence_begins, PlainTensor& block_indices, PlainTensor& block_indices_begins, float& scale, size_t& sliding_window, PlainTensor& alibi_slopes, size_t& max_context_len, PlainTensor& output_emb, PlainTensor& output_score) { @@ -1676,7 +1686,7 @@ struct AttentionExecutor : public PagedAttentionExecutor { } if (k_cache.m_dt == ov::element::Type_t::u8) { - paged_attn_quantkv(k, v, k_cache, v_cache, _slot_mapping); + paged_attn_quantkv(k, v, k_cache, v_cache, _slot_mapping, _helper._key_group_size, _helper._value_group_size); } else { paged_attn_memcpy(k, v, k_cache, 
v_cache, _slot_mapping); } @@ -1702,35 +1712,39 @@ struct AttentionExecutor : public PagedAttentionExecutor { }; #endif -std::shared_ptr make_pa_executor(ov::element::Type data_type, ov::element::Type key_cache_type, ov::element::Type value_cache_type) { +std::shared_ptr make_pa_executor(ov::element::Type data_type, + ov::element::Type key_cache_type, + ov::element::Type value_cache_type, + size_t key_group_size, + size_t value_group_size) { std::shared_ptr executor; #ifdef OPENVINO_ARCH_X86_64 if (data_type == ov::element::bf16) { -#if defined(HAVE_AVX512F) +# if defined(HAVE_AVX512F) if (key_cache_type == ov::element::u8) { - executor = std::make_shared>(); + executor = std::make_shared>(key_group_size, value_group_size); } else { OPENVINO_ASSERT(key_cache_type == ov::element::bf16, "expect kvcache type bf16, current: ", key_cache_type); executor = std::make_shared>(); } -#else +# else OPENVINO_THROW("make_pa_executor: bf16 needs avx512+ hardware."); -#endif +# endif } else if (data_type == ov::element::f16) { -#if defined(HAVE_AVX512F) +# if defined(HAVE_AVX512F) if (key_cache_type == ov::element::u8) { - executor = std::make_shared>(); + executor = std::make_shared>(key_group_size, value_group_size); } else { OPENVINO_ASSERT(key_cache_type == ov::element::f16, "expect kvcache type f16, current: ", key_cache_type); executor = std::make_shared>(); } -#else - OPENVINO_THROW("make_pa_executor: f16 needs avx512+ hardware."); -#endif +# else + OPENVINO_THROW("make_pa_executor: f16 needs avx512+ hardware."); +# endif } else if (data_type == ov::element::f32) { if (key_cache_type == ov::element::u8) { - executor = std::make_shared>(); + executor = std::make_shared>(key_group_size, value_group_size); } else if (key_cache_type == ov::element::f16) { executor = std::make_shared>(); } else { diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.hpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.hpp index 247fafbfac51e8..ae8b8aa348f1e3 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.hpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.hpp @@ -18,9 +18,11 @@ namespace XARCH { std::shared_ptr make_pa_executor(ov::element::Type data_type, ov::element::Type key_cache_type, - ov::element::Type value_cache_type); + ov::element::Type value_cache_type, + size_t key_group_size, + size_t value_group_size); -} // namespace XARCH +} // namespace XARCHl } // namespace Cpu } // namespace Extensions } // namespace ov \ No newline at end of file diff --git a/src/plugins/intel_cpu/src/nodes/paged_attn.cpp b/src/plugins/intel_cpu/src/nodes/paged_attn.cpp index 314cb4dcc42731..7d8cf4f95fedc4 100644 --- a/src/plugins/intel_cpu/src/nodes/paged_attn.cpp +++ b/src/plugins/intel_cpu/src/nodes/paged_attn.cpp @@ -129,10 +129,13 @@ void PagedAttention::createPrimitive() { auto builder = [&](const PagedAttentionKey& key) -> std::shared_ptr { #ifdef OPENVINO_ARCH_X86_64 + // Since we are quantize only last dim it's safe to use the last dim of KV. 
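+        // Note: a group size of 0 falls back to the existing per-token behaviour, i.e. the whole
+        // head_size of a key/value row is quantized as one group sharing a single scale/zero-point
+        // pair (see paged_attn_quant_mt, where group_size == 0 resolves to S / SV).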
auto kCachePrecision = getOriginalInputPrecisionAtPort(PagedAttentionExecutor::ID_KCACHE); auto vCachePrecision = getOriginalInputPrecisionAtPort(PagedAttentionExecutor::ID_VCACHE); - std::cout << "PagedAttn|Kcache|" << kCachePrecision << "|Vcache|" << vCachePrecision << std::endl; - return make_pa_executor(rtPrecision, kCachePrecision, vCachePrecision); + size_t key_group_size = 0; + size_t value_group_size = 0; + std::cout << "PagedAttn|Kcache|" << kCachePrecision << "|Vcache|" << vCachePrecision << "|key_group_size|" << key_group_size << "|value_group_size|" << value_group_size << std::endl; + return make_pa_executor(rtPrecision, kCachePrecision, vCachePrecision, key_group_size, value_group_size); #else return nullptr; #endif From 2aba224dd878dc7d302546d84abe44ed12ff8fd5 Mon Sep 17 00:00:00 2001 From: Zhang Yi Date: Wed, 13 Nov 2024 16:52:04 +0800 Subject: [PATCH 04/28] [CPU]make u8 kernel grouped --- .../nodes/kernels/scaled_attn/executor_pa.cpp | 548 +++++++++++------- .../intel_cpu/src/nodes/paged_attn.cpp | 7 +- 2 files changed, 336 insertions(+), 219 deletions(-) diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp index 376c1d35b0f438..a40dba035d4b4a 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp @@ -72,7 +72,7 @@ void cvt_copy(TA* dst, TB* src, size_t n) { } template -static void attn_acc_value_block(float* out, float* weight, T* v, size_t S, size_t block_size) { +static void attn_acc_value_block(float* out, float* weight, T* v, size_t S, size_t block_size, size_t group_size = 0) { #if defined(HAVE_AVX512F) size_t j = 0; for (; j + 4 <= block_size; j += 4) { @@ -200,104 +200,138 @@ static void attn_acc_value_block(float* out, float* weight, T* v, size_t S, size } } -static void attn_acc_value_block(float* out, float* weight, uint8_t* v, size_t S, size_t block_size) { +static void attn_acc_value_block(float* out, float* weight, uint8_t* v, size_t S, size_t block_size, size_t group_size = 0) { // The layout for per token per head: // |scale(f32)|zeropoint(f32)|quantized feature(u8,idx_1)|quantized feature(u8,idx_2)|...|quantized feature(u8,idx_S)| // The quantized feature will start from 8bytes=sizeof(float)+sizeof(float) + size_t src_offset = 0; + size_t dst_offset = 0; + const size_t _group_size = group_size ? 
group_size : S; + const size_t params_offset = sizeof(float) * 2; + const size_t src_stride = S / _group_size * (_group_size + params_offset); + #if defined(HAVE_AVX512F) size_t j = 0; for (; j + 4 <= block_size; j += 4) { - auto v_f0 = reinterpret_cast(v); - auto v_f1 = reinterpret_cast(v + S + 8); - auto v_f2 = reinterpret_cast(v + 2 * (S + 8)); - auto v_f3 = reinterpret_cast(v + 3 * (S + 8)); - auto attn_w_vec0 = _mm512_set1_ps(weight[0] * v_f0[0]); - auto attn_w_vec1 = _mm512_set1_ps(weight[1] * v_f1[0]); - auto attn_w_vec2 = _mm512_set1_ps(weight[2] * v_f2[0]); - auto attn_w_vec3 = _mm512_set1_ps(weight[3] * v_f3[0]); - auto zp0 = _mm512_set1_ps(v_f0[1]); - auto zp1 = _mm512_set1_ps(v_f1[1]); - auto zp2 = _mm512_set1_ps(v_f2[1]); - auto zp3 = _mm512_set1_ps(v_f3[1]); - size_t i = 0; - v += 8; - for (; i + vec_len_f32_avx512 <= S; i += vec_len_f32_avx512) { - auto v_out = mm512_uni_loadu_ps(out + i); - auto v0 = _mm512_sub_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(_mm_loadu_si128(reinterpret_cast<__m128i*>(v + i)))), zp0); - auto v1 = _mm512_sub_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(_mm_loadu_si128(reinterpret_cast<__m128i*>(v + i + S + 8)))), zp1); - auto v2 = _mm512_sub_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(_mm_loadu_si128(reinterpret_cast<__m128i*>(v + i + 2 * (S + 8))))), zp2); - auto v3 = _mm512_sub_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(_mm_loadu_si128(reinterpret_cast<__m128i*>(v + i + 3 * (S + 8))))), zp3); - v_out = _mm512_fmadd_ps(attn_w_vec0, v0, v_out); - v_out = _mm512_fmadd_ps(attn_w_vec1, v1, v_out); - v_out = _mm512_fmadd_ps(attn_w_vec2, v2, v_out); - v_out = _mm512_fmadd_ps(attn_w_vec3, v3, v_out); - - _mm512_storeu_ps(out + i, v_out); - } - for (; i < S; i++) { - out[i] += weight[0] * (v[i] - v_f0[1]) * v_f0[0]; - out[i] += weight[1] * (v[i + S + 8] - v_f1[1]) * v_f1[0]; - out[i] += weight[2] * (v[i + 2 * (S + 8)] - v_f2[1]) * v_f2[0]; - out[i] += weight[3] * (v[i + 3 * (S + 8)] - v_f3[1]) * v_f3[0]; + dst_offset = 0; + src_offset = 0; + while (dst_offset < S) { + // process group by group + uint8_t* v_ptr = v + src_offset; + auto v_f0 = reinterpret_cast(v_ptr); + auto v_f1 = reinterpret_cast(v_ptr + src_stride); + auto v_f2 = reinterpret_cast(v_ptr + 2 * src_stride); + auto v_f3 = reinterpret_cast(v_ptr + 3 * src_stride); + auto attn_w_vec0 = _mm512_set1_ps(weight[0] * v_f0[0]); + auto attn_w_vec1 = _mm512_set1_ps(weight[1] * v_f1[0]); + auto attn_w_vec2 = _mm512_set1_ps(weight[2] * v_f2[0]); + auto attn_w_vec3 = _mm512_set1_ps(weight[3] * v_f3[0]); + auto zp0 = _mm512_set1_ps(v_f0[1]); + auto zp1 = _mm512_set1_ps(v_f1[1]); + auto zp2 = _mm512_set1_ps(v_f2[1]); + auto zp3 = _mm512_set1_ps(v_f3[1]); + uint8_t* v_data_ptr = v + src_offset + params_offset; + size_t i = 0; + for (; i + vec_len_f32_avx512 <= _group_size; i += vec_len_f32_avx512) { + auto v_out = mm512_uni_loadu_ps(out + dst_offset + i); + auto v0 = _mm512_sub_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(_mm_loadu_si128(reinterpret_cast<__m128i*>(v_data_ptr + i)))), zp0); + auto v1 = _mm512_sub_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(_mm_loadu_si128(reinterpret_cast<__m128i*>(v_data_ptr + i + src_stride)))), zp1); + auto v2 = _mm512_sub_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(_mm_loadu_si128(reinterpret_cast<__m128i*>(v_data_ptr + i + 2 * src_stride)))), zp2); + auto v3 = _mm512_sub_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(_mm_loadu_si128(reinterpret_cast<__m128i*>(v_data_ptr + i + 3 * src_stride)))), zp3); + v_out = _mm512_fmadd_ps(attn_w_vec0, v0, v_out); + v_out = 
_mm512_fmadd_ps(attn_w_vec1, v1, v_out); + v_out = _mm512_fmadd_ps(attn_w_vec2, v2, v_out); + v_out = _mm512_fmadd_ps(attn_w_vec3, v3, v_out); + _mm512_storeu_ps(out + dst_offset + i, v_out); + } + for (; i < _group_size; i++) { + out[i] += weight[0] * (v_data_ptr[i] - v_f0[1]) * v_f0[0]; + out[i] += weight[1] * (v_data_ptr[i + src_stride] - v_f1[1]) * v_f1[0]; + out[i] += weight[2] * (v_data_ptr[i + 2 * src_stride] - v_f2[1]) * v_f2[0]; + out[i] += weight[3] * (v_data_ptr[i + 3 * src_stride] - v_f3[1]) * v_f3[0]; + } + dst_offset += _group_size; + src_offset += _group_size + params_offset; } - v += 4 * (S + 8) - 8; weight += 4; + v += 4 * src_stride; } for (; j < block_size; j++) { - auto v_f0 = reinterpret_cast(v); - auto attn_w_vec0 = _mm512_set1_ps(weight[0] * v_f0[0]); - auto zp0 = _mm512_set1_ps(v_f0[1]); - size_t i = 0; - v += 8; - for (; i + vec_len_f32_avx512 <= S; i += vec_len_f32_avx512) { - auto v_out = mm512_uni_loadu_ps(out + i); - auto v0 = _mm512_sub_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(_mm_loadu_si128(reinterpret_cast<__m128i*>(v + i)))), zp0); - v_out = _mm512_fmadd_ps(attn_w_vec0, v0, v_out); - - _mm512_storeu_ps(out + i, v_out); - } - for (; i < S; i++) { - out[i] += weight[0] * (v[i] - v_f0[1]) * v_f0[0]; + dst_offset = 0; + src_offset = 0; + while (dst_offset < S) { + uint8_t* v_ptr = v + src_offset; + uint8_t* v_data_ptr = v_ptr + params_offset; + auto v_f0 = reinterpret_cast(v_ptr); + auto attn_w_vec0 = _mm512_set1_ps(weight[0] * v_f0[0]); + auto zp0 = _mm512_set1_ps(v_f0[1]); + size_t i = 0; + // printf("j %d dst_offset %d src_offset %ld src_stride %ld scale %f zp %f vec_len_f32_avx512 %ld _group_size %ld\n", j, dst_offset, src_offset, src_stride, v_f0[0], v_f0[1], vec_len_f32_avx512, _group_size); + for (; i + vec_len_f32_avx512 <= _group_size; i += vec_len_f32_avx512) { + auto v_out = mm512_uni_loadu_ps((out + dst_offset + i)); + auto v0 = _mm512_sub_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(_mm_loadu_si128(reinterpret_cast<__m128i*>(v_data_ptr + i)))), zp0); + v_out = _mm512_fmadd_ps(attn_w_vec0, v0, v_out); + + _mm512_storeu_ps((out + dst_offset + i), v_out); + } + for (; i < _group_size; i++) { + out[dst_offset + i] += weight[0] * (v_data_ptr[i] - v_f0[1]) * v_f0[0]; + } + dst_offset += _group_size; + src_offset += _group_size + params_offset; } - v += S; + v += src_stride; weight++; } return; #elif defined(HAVE_AVX2) size_t j = 0; for (; j < block_size; j++) { - auto v_f0 = reinterpret_cast(v); - auto attn_w_vec0 = _mm256_set1_ps(weight[0] * v_f0[0]); - auto zp0 = _mm256_set1_ps(v_f0[1]); - size_t i = 0; - v += 8; - for (; i + vec_len_f32_avx2 <= S; i += vec_len_f32_avx2) { - auto v_out = mm256_uni_loadu_ps(out + i); - auto v0 = _mm256_sub_ps(_mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64(reinterpret_cast<__m128i*>(v + i)))), zp0); - v_out = _mm256_fmadd_ps(attn_w_vec0, v0, v_out); - - mm256_uni_storeu_ps(out + i, v_out); - } - for (; i < S; i++) { - out[i] += weight[0] * (v[i] - v_f0[1]) * v_f0[0]; + dst_offset = 0; + src_offset = 0; + while (dst_offset < S) { + uint8_t* v_ptr = v + src_offset; + uint8_t* v_data_ptr = v_ptr + params_offset; + auto v_f0 = reinterpret_cast(v_ptr); + auto attn_w_vec0 = _mm256_set1_ps(weight[0] * v_f0[0]); + auto zp0 = _mm256_set1_ps(v_f0[1]); + size_t i = 0; + v += 8; + for (; i + vec_len_f32_avx2 <= _group_size; i += vec_len_f32_avx2) { + auto v_out = mm256_uni_loadu_ps(out + dst_offset + i); + auto v0 = _mm256_sub_ps(_mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64(reinterpret_cast<__m128i*>(v_data_ptr 
+ i)))), zp0); + v_out = _mm256_fmadd_ps(attn_w_vec0, v0, v_out); + + mm256_uni_storeu_ps(out + dst_offset + i, v_out); + } + for (; i < _group_size; i++) { + out[dst_offset + i] += weight[0] * (v_data_ptr[i] - v_f0[1]) * v_f0[0]; + } + dst_offset += _group_size; + src_offset += _group_size + params_offset; } - v += S; + v += src_stride; weight++; } return; #endif for (size_t j = 0; j < block_size; j++) { - auto v0 = reinterpret_cast(v); - v += 8; - for (size_t i = 0; i < S; i++) { - out[i] += weight[j] * (v[i] - v0[1]) * v0[0]; + dst_offset = 0; + src_offset = 0; + while (dst_offset < S) { + auto v0 = reinterpret_cast(v + src_offset); + for (size_t i = 0; i < _group_size; i++) { + out[dst_offset + i] += weight[j] * (v[i + src_offset + params_offset] - v0[1]) * v0[0]; + } + dst_offset += _group_size; + src_offset += _group_size + params_offset; } - v += S; + v += src_stride; } } template -static void dot_product_block(TA* a, TB* b, float* c, size_t n, size_t block_size) { +static void dot_product_block(TA* a, TB* b, float* c, size_t n, size_t block_size, size_t group_size = 0) { #if defined(HAVE_AVX512F) size_t j = 0; for (; j + 4 <= block_size; j += 4) { @@ -409,155 +443,212 @@ static void dot_product_block(TA* a, TB* b, float* c, size_t n, size_t block_siz } template -static void dot_product_block(TA* a, uint8_t* b, float* c, size_t n, size_t block_size) { +static void dot_product_block(TA* a, uint8_t* b, float* c, size_t n, size_t block_size, size_t group_size = 0) { // The layout for per token per head: // |scale(f32)|zeropoint(f32)|quantized feature(u8,idx_1)|quantized feature(u8,idx_2)|...|quantized feature(u8,idx_S)| // The quantized feature will start from 8bytes=sizeof(float)+sizeof(float) + size_t src_offset = 0; + size_t dst_offset = 0; + const size_t _group_size = group_size ? 
group_size : n; + const size_t params_offset = sizeof(float) * 2; + const size_t src_stride = n / _group_size * (_group_size + params_offset); #if defined(HAVE_AVX512F) size_t j = 0; for (; j + 4 <= block_size; j += 4) { - auto vsum0 = _mm512_setzero_ps(); - auto vsum1 = _mm512_setzero_ps(); - auto vsum2 = _mm512_setzero_ps(); - auto vsum3 = _mm512_setzero_ps(); - auto b0 = reinterpret_cast(b); - auto b1 = reinterpret_cast(b + n + 8); - auto b2 = reinterpret_cast(b + (n + 8) * 2); - auto b3 = reinterpret_cast(b + (n + 8) * 3); - auto v_zp0 = _mm512_set1_ps(b0[1]); - auto v_zp1 = _mm512_set1_ps(b1[1]); - auto v_zp2 = _mm512_set1_ps(b2[1]); - auto v_zp3 = _mm512_set1_ps(b3[1]); - size_t i = 0; - b += 8; - for (; i + vec_len_f32_avx512 <= n; i += vec_len_f32_avx512) { - auto va = mm512_uni_loadu_ps(a + i); - auto vb0 = _mm512_sub_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(_mm_loadu_si128(reinterpret_cast<__m128i*>(b + i)))), v_zp0); - auto vb1 = _mm512_sub_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(_mm_loadu_si128(reinterpret_cast<__m128i*>(b + i + n + 8)))), v_zp1); - auto vb2 = _mm512_sub_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(_mm_loadu_si128(reinterpret_cast<__m128i*>(b + i + 2 * (n + 8))))), v_zp2); - auto vb3 = _mm512_sub_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(_mm_loadu_si128(reinterpret_cast<__m128i*>(b + i + 3 * (n + 8))))), v_zp3); - - vsum0 = _mm512_fmadd_ps(va, vb0, vsum0); - vsum1 = _mm512_fmadd_ps(va, vb1, vsum1); - vsum2 = _mm512_fmadd_ps(va, vb2, vsum2); - vsum3 = _mm512_fmadd_ps(va, vb3, vsum3); + src_offset = 0; + dst_offset = 0; + float sum0 = 0.0f; + float sum1 = 0.0f; + float sum2 = 0.0f; + float sum3 = 0.0f; + while (dst_offset < n) { + auto vsum0 = _mm512_setzero_ps(); + auto vsum1 = _mm512_setzero_ps(); + auto vsum2 = _mm512_setzero_ps(); + auto vsum3 = _mm512_setzero_ps(); + auto b0 = reinterpret_cast(b + src_offset); + auto b1 = reinterpret_cast(b + src_offset + src_stride); + auto b2 = reinterpret_cast(b + src_offset + src_stride * 2); + auto b3 = reinterpret_cast(b + src_offset + src_stride * 3); + auto v_zp0 = _mm512_set1_ps(b0[1]); + auto v_zp1 = _mm512_set1_ps(b1[1]); + auto v_zp2 = _mm512_set1_ps(b2[1]); + auto v_zp3 = _mm512_set1_ps(b3[1]); + size_t i = 0; + uint8_t* b_data_ptr = b + src_offset + params_offset; + for (; i + vec_len_f32_avx512 <= _group_size; i += vec_len_f32_avx512) { + auto va = mm512_uni_loadu_ps(a + dst_offset + i); + auto vb0 = _mm512_sub_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(_mm_loadu_si128(reinterpret_cast<__m128i*>(b_data_ptr + i)))), v_zp0); + auto vb1 = _mm512_sub_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(_mm_loadu_si128(reinterpret_cast<__m128i*>(b_data_ptr + i + src_stride)))), v_zp1); + auto vb2 = _mm512_sub_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(_mm_loadu_si128(reinterpret_cast<__m128i*>(b_data_ptr + i + 2 * src_stride)))), v_zp2); + auto vb3 = _mm512_sub_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(_mm_loadu_si128(reinterpret_cast<__m128i*>(b_data_ptr + i + 3 * src_stride)))), v_zp3); + + vsum0 = _mm512_fmadd_ps(va, vb0, vsum0); + vsum1 = _mm512_fmadd_ps(va, vb1, vsum1); + vsum2 = _mm512_fmadd_ps(va, vb2, vsum2); + vsum3 = _mm512_fmadd_ps(va, vb3, vsum3); + } + float group_sum0 = _mm512_reduce_add_ps(vsum0); + float group_sum1 = _mm512_reduce_add_ps(vsum1); + float group_sum2 = _mm512_reduce_add_ps(vsum2); + float group_sum3 = _mm512_reduce_add_ps(vsum3); + for (; i < _group_size; i++) { + group_sum0 += a[i + dst_offset] * (b_data_ptr[i] - b0[1]); + group_sum1 += a[i + dst_offset] * (b_data_ptr[i + src_stride] - 
b1[1]); + group_sum2 += a[i + dst_offset] * (b_data_ptr[i + 2 * src_stride] - b2[1]); + group_sum3 += a[i + dst_offset] * (b_data_ptr[i + 3 * src_stride] - b3[1]); + } + sum0 += group_sum0 * b0[0]; + sum1 += group_sum1 * b1[0]; + sum2 += group_sum2 * b2[0]; + sum3 += group_sum3 * b3[0]; + dst_offset += _group_size; + src_offset += _group_size + params_offset; } - float sum0 = _mm512_reduce_add_ps(vsum0); - float sum1 = _mm512_reduce_add_ps(vsum1); - float sum2 = _mm512_reduce_add_ps(vsum2); - float sum3 = _mm512_reduce_add_ps(vsum3); - for (; i < n; i++) { - sum0 += a[i] * (b[i] - b0[1]); - sum1 += a[i] * (b[i + n + 8] - b1[1]); - sum2 += a[i] * (b[i + 2 * (n + 8)] - b2[1]); - sum3 += a[i] * (b[i + 3 * (n + 8)] - b3[1]); - } - c[0] = sum0 * b0[0]; - c[1] = sum1 * b1[0]; - c[2] = sum2 * b2[0]; - c[3] = sum3 * b3[0]; + c[0] = sum0; + c[1] = sum1; + c[2] = sum2; + c[3] = sum3; c += 4; - b += 4 * (n + 8) - 8; + b += 4 * src_stride; } for (; j < block_size; j++) { - auto vsum = _mm512_setzero_ps(); - auto b0 = reinterpret_cast(b); - auto v_zp = _mm512_set1_ps(b0[1]); - size_t i = 0; - b += 8; - for (; i + vec_len_f32_avx512 <= n; i += vec_len_f32_avx512) { - auto va = mm512_uni_loadu_ps(a + i); - auto vb = _mm512_sub_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(_mm_loadu_si128(reinterpret_cast<__m128i*>(b + i)))), v_zp); - vsum = _mm512_fmadd_ps(va, vb, vsum); - } - float sum = _mm512_reduce_add_ps(vsum); - for (; i < n; i++) { - sum += a[i] * (b[i] - b0[1]); + src_offset = 0; + dst_offset = 0; + float sum = 0; + while (dst_offset < n) { + auto vsum = _mm512_setzero_ps(); + auto b0 = reinterpret_cast(b + src_offset); + auto v_zp = _mm512_set1_ps(b0[1]); + size_t i = 0; + uint8_t* b_data_ptr = b + src_offset + params_offset; + for (; i + vec_len_f32_avx512 <= _group_size; i += vec_len_f32_avx512) { + auto va = mm512_uni_loadu_ps(a + dst_offset + i); + auto vb = _mm512_sub_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(_mm_loadu_si128(reinterpret_cast<__m128i*>(b_data_ptr + i)))), v_zp); + vsum = _mm512_fmadd_ps(va, vb, vsum); + } + float group_sum = _mm512_reduce_add_ps(vsum); + for (; i < _group_size; i++) { + group_sum += a[i + dst_offset] * (b_data_ptr[i] - b0[1]); + } + sum += group_sum * b0[0]; + dst_offset += _group_size; + src_offset += _group_size + params_offset; } - b += n; - *c++ = sum * b0[0]; + b += src_stride; + *c++ = sum; } return; #elif defined(HAVE_AVX2) size_t j = 0; for (; j + 4 <= block_size; j += 4) { - auto vsum0 = _mm256_setzero_ps(); - auto vsum1 = _mm256_setzero_ps(); - auto vsum2 = _mm256_setzero_ps(); - auto vsum3 = _mm256_setzero_ps(); - auto b0 = reinterpret_cast(b); - auto b1 = reinterpret_cast(b + n + 8); - auto b2 = reinterpret_cast(b + (n + 8) * 2); - auto b3 = reinterpret_cast(b + (n + 8) * 3); - auto v_zp0 = _mm256_set1_ps(b0[1]); - auto v_zp1 = _mm256_set1_ps(b1[1]); - auto v_zp2 = _mm256_set1_ps(b2[1]); - auto v_zp3 = _mm256_set1_ps(b3[1]); - size_t i = 0; - b += 8; - for (; i + vec_len_f32_avx2 <= n; i += vec_len_f32_avx2) { - auto va = mm256_uni_loadu_ps(a + i); - auto vb0 = _mm256_sub_ps(_mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64(reinterpret_cast<__m128i*>(b + i)))), v_zp0); - auto vb1 = _mm256_sub_ps(_mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64(reinterpret_cast<__m128i*>(b + i + n + 8)))), v_zp1); - auto vb2 = _mm256_sub_ps(_mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64(reinterpret_cast<__m128i*>(b + i + 2 * (n + 8))))), v_zp2); - auto vb3 = 
_mm256_sub_ps(_mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64(reinterpret_cast<__m128i*>(b + i + 3 * (n + 8))))), v_zp3); - - vsum0 = _mm256_fmadd_ps(va, vb0, vsum0); - vsum1 = _mm256_fmadd_ps(va, vb1, vsum1); - vsum2 = _mm256_fmadd_ps(va, vb2, vsum2); - vsum3 = _mm256_fmadd_ps(va, vb3, vsum3); - } - hsum(vsum0); - hsum(vsum1); - hsum(vsum2); - hsum(vsum3); - float sum0 = _mm256_cvtss_f32(vsum0); - float sum1 = _mm256_cvtss_f32(vsum1); - float sum2 = _mm256_cvtss_f32(vsum2); - float sum3 = _mm256_cvtss_f32(vsum3); - for (; i < n; i++) { - sum0 += a[i] * (b[i] - b0[1]); - sum1 += a[i] * (b[i + n + 8] - b1[1]); - sum2 += a[i] * (b[i + 2 * (n + 8)] - b2[1]); - sum3 += a[i] * (b[i + 3 * (n + 8)] - b3[1]); - } - c[0] = sum0 * b0[0]; - c[1] = sum1 * b1[0]; - c[2] = sum2 * b2[0]; - c[3] = sum3 * b3[0]; + src_offset = 0; + dst_offset = 0; + float sum0 = 0.0f; + float sum1 = 0.0f; + float sum2 = 0.0f; + float sum3 = 0.0f; + while (dst_offset < n) { + auto vsum0 = _mm256_setzero_ps(); + auto vsum1 = _mm256_setzero_ps(); + auto vsum2 = _mm256_setzero_ps(); + auto vsum3 = _mm256_setzero_ps(); + auto b0 = reinterpret_cast(b + src_offset); + auto b1 = reinterpret_cast(b + src_offset + src_stride); + auto b2 = reinterpret_cast(b + src_offset + src_stride * 2); + auto b3 = reinterpret_cast(b + src_offset + src_stride * 3); + auto v_zp0 = _mm256_set1_ps(b0[1]); + auto v_zp1 = _mm256_set1_ps(b1[1]); + auto v_zp2 = _mm256_set1_ps(b2[1]); + auto v_zp3 = _mm256_set1_ps(b3[1]); + size_t i = 0; + uint8_t* b_data_ptr = b + src_offset + params_offset; + for (; i + vec_len_f32_avx2 <= _group_size; i += vec_len_f32_avx2) { + auto va = mm256_uni_loadu_ps(a + dst_offset + i); + auto vb0 = _mm256_sub_ps(_mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64(reinterpret_cast<__m128i*>(b_data_ptr + i)))), v_zp0); + auto vb1 = _mm256_sub_ps(_mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64(reinterpret_cast<__m128i*>(b_data_ptr + i + src_stride)))), v_zp1); + auto vb2 = _mm256_sub_ps(_mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64(reinterpret_cast<__m128i*>(b_data_ptr + i + 2 * src_stride)))), v_zp2); + auto vb3 = _mm256_sub_ps(_mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64(reinterpret_cast<__m128i*>(b_data_ptr + i + 3 * src_stride)))), v_zp3); + + vsum0 = _mm256_fmadd_ps(va, vb0, vsum0); + vsum1 = _mm256_fmadd_ps(va, vb1, vsum1); + vsum2 = _mm256_fmadd_ps(va, vb2, vsum2); + vsum3 = _mm256_fmadd_ps(va, vb3, vsum3); + } + hsum(vsum0); + hsum(vsum1); + hsum(vsum2); + hsum(vsum3); + float group_sum0 = _mm256_cvtss_f32(vsum0); + float group_sum1 = _mm256_cvtss_f32(vsum1); + float group_sum2 = _mm256_cvtss_f32(vsum2); + float group_sum3 = _mm256_cvtss_f32(vsum3); + for (; i < _group_size; i++) { + group_sum0 += a[dst_offset + i] * (b[i] - b0[1]); + group_sum1 += a[dst_offset + i] * (b[i +src_stride] - b1[1]); + group_sum2 += a[dst_offset + i] * (b[i + 2 * src_stride] - b2[1]); + group_sum3 += a[dst_offset + i] * (b[i + 3 * src_stride] - b3[1]); + } + sum0 += group_sum0 * b0[0]; + sum1 += group_sum1 * b1[0]; + sum2 += group_sum2 * b2[0]; + sum3 += group_sum3 * b3[0]; + dst_offset += _group_size; + src_offset += _group_size + params_offset; + } + + c[0] = sum0; + c[1] = sum1; + c[2] = sum2; + c[3] = sum3; c += 4; - b += 4 * (n + 8) - 8; + b += 4 * src_stride; } for (; j < block_size; j++) { - auto vsum = _mm256_setzero_ps(); - auto b0 = reinterpret_cast(b); - auto v_zp = _mm256_set1_ps(b0[1]); - size_t i = 0; - b += 8; - for (; i + vec_len_f32_avx2 <= n; i += vec_len_f32_avx2) { - auto va = 
mm256_uni_loadu_ps(a + i); - auto vb = _mm256_sub_ps(_mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64(reinterpret_cast<__m128i*>(b + i)))), v_zp); - vsum = _mm256_fmadd_ps(va, vb, vsum); - } - hsum(vsum); - float sum = _mm256_cvtss_f32(vsum); - for (; i < n; i++) { - sum += a[i] * (b[i] - b0[1]); + src_offset = 0; + dst_offset = 0; + float sum = 0; + while (dst_offset < n) { + auto vsum = _mm256_setzero_ps(); + auto b0 = reinterpret_cast(b + src_offset); + auto v_zp = _mm256_set1_ps(b0[1]); + size_t i = 0; + uint8_t* b_data_ptr = b + src_offset + params_offset; + for (; i + vec_len_f32_avx2 <= _group_size; i += vec_len_f32_avx2) { + auto va = mm256_uni_loadu_ps(a + dst_offset + i); + auto vb = _mm256_sub_ps(_mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64(reinterpret_cast<__m128i*>(b_data_ptr + i)))), v_zp); + vsum = _mm256_fmadd_ps(va, vb, vsum); + } + hsum(vsum); + float group_sum = _mm256_cvtss_f32(vsum); + for (; i < _group_size; i++) { + group_sum += a[i + dst_offset] * (b_data_ptr[i] - b0[1]); + } + sum += group_sum * b0[0]; + dst_offset += _group_size; + src_offset += _group_size + params_offset; } - b += n; - *c++ = sum * b0[0]; + b += src_stride; + *c++ = sum; } return; #endif for (size_t j = 0; j < block_size; j++) { float sum = 0; - auto b0 = reinterpret_cast(b); - b += 8; - for (size_t i = 0; i < n; i++) { - sum += a[i] * (b[i] - b0[1]); + dst_offset = 0; + src_offset = 0; + while (dst_offset < n) { + auto b0 = reinterpret_cast(b + src_offset); + float group_sum = 0.0f; + for (size_t i = 0; i < _group_size; i++) { + group_sum += a[dst_offset + i] * (b[src_offset + params_offset + i] - b0[1]); + } + sum += group_sum * b0[0]; + dst_offset += _group_size; + src_offset += _group_size + params_offset; } - b += n; - *c++ = sum * b0[0]; + b += src_stride; + *c++ = sum; } } @@ -602,7 +693,7 @@ static void attn_reduce(T* dst, float* temp, size_t M, size_t S, size_t temp_str // N must be multiple of 16 template::type = true> -void transpose_16NxK(TDST* dst, void* src, TDST* tmp, size_t N, size_t K, size_t dst_stride, size_t src_stride) { +void transpose_16NxK(TDST* dst, void* src, TDST* tmp, size_t N, size_t K, size_t dst_stride, size_t src_stride, size_t group_size = 0) { size_t k = 0; auto* src_ptr = reinterpret_cast::value_type*>(src); for (; k + 16 <= K; k += 16) { @@ -621,7 +712,7 @@ void transpose_16NxK(TDST* dst, void* src, TDST* tmp, size_t N, size_t K, size_t } #if defined(HAVE_AVX512F) template::value), bool>::type = true> -static void transpose_16NxK(T* dst, T* src, T* tmp, size_t N, size_t K, size_t dst_stride, size_t src_stride) { +static void transpose_16NxK(T* dst, T* src, T* tmp, size_t N, size_t K, size_t dst_stride, size_t src_stride, size_t group_size = 0) { // will treat as uint32_t transpose auto s = reinterpret_cast(src); auto d = reinterpret_cast(dst); @@ -630,16 +721,24 @@ static void transpose_16NxK(T* dst, T* src, T* tmp, size_t N, size_t K, size_t d #endif template::type = true> -void transpose_16NxK(TDST* dst, void* src, TDST* tmp, size_t N, size_t K, size_t dst_stride, size_t src_stride) { +void transpose_16NxK(TDST* dst, void* src, TDST* tmp, size_t N, size_t K, size_t dst_stride, size_t src_stride, size_t group_size = 0) { // The layout for per token per head: // |scale(f32)|zeropoint(f32)|quantized feature(u8,idx_1)|quantized feature(u8,idx_2)|...|quantized feature(u8,idx_S)| // The quantized feature will start from 8bytes=sizeof(float)+sizeof(float) auto s = reinterpret_cast::value_type*>(src); auto t = tmp; + // if group_size not set, the 
whole row is used as a group + size_t _group_size = group_size ? group_size : K; for (size_t n = 0; n < N; n ++) { - auto f = reinterpret_cast(s); - attn_dequant_u8_kernel(s + 2 * sizeof(float), t, K, f[0], f[1]); - s += src_stride + 2 * sizeof(float); + size_t src_offset = 0; + size_t dst_offset = 0; + while (dst_offset < K) { + auto f = reinterpret_cast(s + src_offset); + attn_dequant_u8_kernel(s + src_offset + sizeof(float) * 2, t + dst_offset, _group_size, f[0], f[1]); + src_offset += _group_size + sizeof(float) * 2; + dst_offset += _group_size; + } + s += src_offset; t += src_stride; } transpose_16NxK::value>(dst, tmp, reinterpret_cast(0), N, K, dst_stride, src_stride); @@ -727,7 +826,7 @@ static void pack_32xK_kernel(T* dst, T* src, size_t dst_stride, size_t src_strid } template::value != ov::element::f32 && (SRC_PREC == ov::element::bf16 || SRC_PREC == ov::element::f16), bool>::type = true> -static void pack_32NxK(TDST* dst, void* src, TDST* tmp, size_t N, size_t K, size_t dst_stride, size_t src_stride) { +static void pack_32NxK(TDST* dst, void* src, TDST* tmp, size_t N, size_t K, size_t dst_stride, size_t src_stride, size_t group_size = 0) { auto src_ptr = reinterpret_cast::value_type*>(src); for (size_t n = 0; n < N; n += 32) { size_t k = 0; @@ -748,16 +847,24 @@ static void pack_32NxK(TDST* dst, void* src, TDST* tmp, size_t N, size_t K, size } template::value != ov::element::f32 && SRC_PREC == ov::element::u8, bool>::type = true> -static void pack_32NxK(TDST* dst, void* src, TDST* tmp, size_t N, size_t K, size_t dst_stride, size_t src_stride) { +static void pack_32NxK(TDST* dst, void* src, TDST* tmp, size_t N, size_t K, size_t dst_stride, size_t src_stride, size_t group_size = 0) { // The layout for per token per head: // |scale(f32)|zeropoint(f32)|quantized feature(u8,idx_1)|quantized feature(u8,idx_2)|...|quantized feature(u8,idx_S)| // The quantized feature will start from 8bytes=sizeof(float)+sizeof(float) auto s = reinterpret_cast::value_type*>(src); auto t = tmp; + // if group_size not set, the whole row is used as a group + size_t _group_size = group_size ? 
group_size : K; for (size_t n = 0; n < N; n ++) { - auto f = reinterpret_cast(s); - attn_dequant_u8_kernel(s + 2 * sizeof(float), t, K, f[0], f[1]); - s += src_stride + 2 * sizeof(float); + size_t src_offset = 0; + size_t dst_offset = 0; + while (dst_offset < K) { + auto f = reinterpret_cast(s + src_offset); + attn_dequant_u8_kernel(s + src_offset + sizeof(float) * 2, t + dst_offset, _group_size, f[0], f[1]); + src_offset += _group_size + sizeof(float) * 2; + dst_offset += _group_size; + } + s += src_offset; t += src_stride; } pack_32NxK::value>(dst, tmp, reinterpret_cast(0), N, K, dst_stride, src_stride); @@ -765,7 +872,7 @@ static void pack_32NxK(TDST* dst, void* src, TDST* tmp, size_t N, size_t K, size #endif template::value == ov::element::f32, bool>::type = true> -static void pack_32NxK(TDST* dst, void* src, TDST* tmp, size_t N, size_t K, size_t dst_stride, size_t src_stride) { +static void pack_32NxK(TDST* dst, void* src, TDST* tmp, size_t N, size_t K, size_t dst_stride, size_t src_stride, size_t group_size = 0) { // never called OPENVINO_THROW("pack_32NxK: should not be called."); } @@ -1094,7 +1201,7 @@ struct MHAHelper { for (size_t pq = 0; pq < q_len; pq++) { for (size_t h = hq_beg; h < hq_end; h++) { dot_product_block(query.ptr(h, pq), present_key.ptr(block_number, hk), - _weight.ptr(ithr, h, pq) + pk, _S, std::min(_block_size, cur_kv_len - pk)); + _weight.ptr(ithr, h, pq) + pk, _S, std::min(_block_size, cur_kv_len - pk), _key_group_size); } } } @@ -1130,14 +1237,15 @@ struct MHAHelper { memset(_output.ptr(ithr), 0, q_len * _H * _SV * sizeof(float)); for (size_t pv = 0, i = 0; pv < cur_kv_len; pv += _block_size, i++) { auto block_number = block_table[i]; - auto* v = present_value.ptr(block_number, hk); + auto* v = present_value.ptr(block_number, hk); for (size_t pq = 0; pq < q_len; pq++) { for (size_t h = hq_beg; h < hq_end; h++) { attn_acc_value_block(_output.ptr(ithr, pq, h), _weight.ptr(ithr, h, pq) + pv, v, _SV, - std::min(_block_size, cur_kv_len - pv)); + std::min(_block_size, cur_kv_len - pv), + _value_group_size); } } } @@ -1221,7 +1329,7 @@ struct MHAHelper { for (size_t pq = 0; pq < q_len; pq++) { for (size_t h = hq_beg; h < hq_end; h++) { dot_product_block(query.ptr(b, h, pq), present_key.ptr(block_number, hk), - _weight_bhl.ptr(b, h, pq) + pk, _S, std::min(_block_size, context_len - pk)); + _weight_bhl.ptr(b, h, pq) + pk, _S, std::min(_block_size, context_len - pk), _key_group_size); } } } @@ -1295,7 +1403,8 @@ struct MHAHelper { _weight_bhl.ptr(b, h, pq) + pv, v, _SV, - std::min(_block_size, context_len - pv)); + std::min(_block_size, context_len - pv), + _value_group_size); } } } @@ -1458,7 +1567,7 @@ struct MHA { k_ptr, _helper._output.template ptr(ithr), _helper._block_size, - _helper._S, _helper._block_size, _helper._S); + _helper._S, _helper._block_size, _helper._S, _helper._key_group_size); if (q_is_xf16) { pack_32NxK::value>(_helper._wv_scratch_b.template ptr(batch_in_reorder, kv_block, hk), @@ -1467,7 +1576,8 @@ struct MHA { _helper._block_size, _helper._SV, rnd_up(_helper._SV, _helper._block_size), - _helper._SV); + _helper._SV, + _helper._value_group_size); } else { // need to decompress if (!q_cache_is_same) { @@ -1623,11 +1733,15 @@ struct AttentionExecutor : public PagedAttentionExecutor { auto B_token = q.size(0); auto Hk = k_cache.size(1); + auto _key_group_size = _helper._key_group_size; + auto _value_group_size = _helper._key_group_size; // The layout for per token per head for u8 kv cache: // |scale(f32)|zeropoint(f32)|quantized 
feature(u8,idx_1)|quantized feature(u8,idx_2)|...|quantized feature(u8,idx_S)| // The actual size needs to deduct scale and zeropoint. - auto S = k_cache.size(3) - (k_cache.m_dt == ov::element::Type_t::u8 ? sizeof(float) * 2 : 0); - auto SV = v_cache.size(3) - (k_cache.m_dt == ov::element::Type_t::u8 ? sizeof(float) * 2 : 0); + size_t key_group_num = _key_group_size ? k_cache.size(3) / (_key_group_size + sizeof(float) * 2) : _key_group_size; + size_t value_group_num = _value_group_size ? v_cache.size(3) / (_value_group_size + 8) : _value_group_size; + auto S = k_cache.size(3) - (k_cache.m_dt == ov::element::Type_t::u8 ? sizeof(float) * 2 * key_group_num : 0); + auto SV = v_cache.size(3) - (k_cache.m_dt == ov::element::Type_t::u8 ? sizeof(float) * 2 * value_group_num : 0); auto block_size = k_cache.size(2); auto H = q.size(1) / S; auto h_each_group_len = 1; @@ -1643,8 +1757,8 @@ struct AttentionExecutor : public PagedAttentionExecutor { k = k.reshape({B_token, Hk, 1, S}); v = v.reshape({B_token, Hk, 1, SV}); if (k_cache.m_dt == ov::element::Type_t::u8) { - k_cache.assert_dims({0, Hk, block_size, S + sizeof(float) * 2}, true); - v_cache.assert_dims({k_cache.m_dims[0], Hk, block_size, SV + sizeof(float) * 2}); + k_cache.assert_dims({0, Hk, block_size, S + sizeof(float) * 2 * key_group_num}, true); + v_cache.assert_dims({k_cache.m_dims[0], Hk, block_size, SV + sizeof(float) * 2 * value_group_num}); } else { k_cache.assert_dims({0, Hk, block_size, S}, true); v_cache.assert_dims({k_cache.m_dims[0], Hk, block_size, SV}); diff --git a/src/plugins/intel_cpu/src/nodes/paged_attn.cpp b/src/plugins/intel_cpu/src/nodes/paged_attn.cpp index 7d8cf4f95fedc4..095b89bbeb4a12 100644 --- a/src/plugins/intel_cpu/src/nodes/paged_attn.cpp +++ b/src/plugins/intel_cpu/src/nodes/paged_attn.cpp @@ -132,8 +132,11 @@ void PagedAttention::createPrimitive() { // Since we are quantize only last dim it's safe to use the last dim of KV. 
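// Note on the size bookkeeping above: for a quantized cache, k_cache.size(3) stores the
// quantized row plus one (scale, zero-point) float pair per group, so the logical head size S
// is recovered by subtracting sizeof(float) * 2 * key_group_num. A minimal standalone sketch
// with hypothetical sizes (not taken from this patch) that checks the same arithmetic:
#include <cassert>
#include <cstddef>
int main() {
    const size_t S = 128;                                 // logical head size (u8 elements)
    const size_t group_size = 64;                         // elements sharing one scale/zp pair
    const size_t params_size = 2 * sizeof(float);         // scale + zero-point per group
    const size_t group_num = S / group_size;              // 2 groups per row
    const size_t cache_row = S + params_size * group_num; // physical k_cache.size(3) == 144
    assert(cache_row - params_size * group_num == S);     // matches how S is deduced above
    return 0;
}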
auto kCachePrecision = getOriginalInputPrecisionAtPort(PagedAttentionExecutor::ID_KCACHE); auto vCachePrecision = getOriginalInputPrecisionAtPort(PagedAttentionExecutor::ID_VCACHE); - size_t key_group_size = 0; - size_t value_group_size = 0; + size_t group_size = 64; + if (getenv("GROUP_SIZE")) + group_size = std::stoi(std::string(getenv("GROUP_SIZE"))); + size_t key_group_size = group_size; + size_t value_group_size = group_size; std::cout << "PagedAttn|Kcache|" << kCachePrecision << "|Vcache|" << vCachePrecision << "|key_group_size|" << key_group_size << "|value_group_size|" << value_group_size << std::endl; return make_pa_executor(rtPrecision, kCachePrecision, vCachePrecision, key_group_size, value_group_size); #else From fc435f6a8d1b42487ed5ea9e1c6d87f8b035b00e Mon Sep 17 00:00:00 2001 From: Zhang Yi Date: Mon, 18 Nov 2024 14:29:39 +0800 Subject: [PATCH 05/28] [CPU]U4 Group size support with reference Signed-off-by: yi3.zhang@intel.com --- .../nodes/kernels/scaled_attn/attn_quant.cpp | 93 +++++-- .../kernels/scaled_attn/attn_quant_kernel.hpp | 20 +- .../nodes/kernels/scaled_attn/executor_pa.cpp | 231 +++++++++++++++--- 3 files changed, 287 insertions(+), 57 deletions(-) diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant.cpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant.cpp index 0a91ffb089ee1a..1726ecc3a8ffd5 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant.cpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant.cpp @@ -171,11 +171,12 @@ static void quant_u8(const T* src, uint8_t* dst, size_t n, float& scale, float& } template -static void quant_u4(const T* src, uint8_t* dst, size_t n, float& scale, float& zp) { +static void quant_u4(const T* src, void* dst, size_t n, float& scale, float& zp) { auto insert_half_byte = [](uint8_t dst, uint8_t val, bool high_half) -> uint8_t { uint8_t shift = high_half ? 0 : 4; return dst | (uint8_t) (val << shift); }; + auto dst_ptr = reinterpret_cast(dst); size_t i = 0; float max = -FLT_MAX; float min = FLT_MAX; @@ -193,11 +194,42 @@ static void quant_u4(const T* src, uint8_t* dst, size_t n, float& scale, float& float tmp = src[i]; #define MIN(a, b) ((a) < (b) ? (a) : (b)) uint8_t src_val = MIN(15, (uint8_t)(std::round(tmp / scale + zp))); - uint8_t dst_val = i % 2 == 0 ? 0 : dst[i / 2]; + uint8_t dst_val = i % 2 == 0 ? 0 : dst_ptr[i / 2]; + dst_val = insert_half_byte(dst_val, src_val, (uint8_t)(i % 2)); + dst_ptr[i / 2] = dst_val; + } +} + +template +static void quant_s4(const T* src, void* dst, size_t n, float& scale, float& zp) { + auto insert_half_byte = [](uint8_t dst, uint8_t val, bool high_half) -> uint8_t { + uint8_t shift = high_half ? 0 : 4; + return dst | (uint8_t) (val << shift); + }; + auto dst_ptr = reinterpret_cast(dst); + size_t i = 0; + float max = -FLT_MAX; + float min = FLT_MAX; + for (; i < n; i++) { + float tmp = src[i]; + max = std::max(max, tmp); + min = std::min(min, tmp); + } + float max_abs = std::max(std::abs(min), std::abs(max)); + scale = max_abs / ((1 << 3) - 1); + if (scale == 0) + scale = 0.0001f; + i = 0; + for (; i < n; i++) { + float tmp = src[i]; + #define MIN(a, b) ((a) < (b) ? (a) : (b)) + // add 8.5 here is to save a clamp to (-2^3) + uint8_t src_val = MIN(15, (int8_t)(tmp / scale + 8.5f)); + uint8_t dst_val = i % 2 == 0 ? 
0 : dst_ptr[i / 2]; dst_val = insert_half_byte(dst_val, src_val, (uint8_t)(i % 2)); if (i < 4) printf("index %ld float %f src %d hex %x", i, tmp, src_val, dst_val); - dst[i / 2] = dst_val; + dst_ptr[i / 2] = dst_val; } printf("quant scale %f zp %f\n", scale, zp); } @@ -278,6 +310,7 @@ static void paged_attn_quant_mt(const ov::intel_cpu::PlainTensor& k_src, size_t block_size = k_dst.m_dims[2]; size_t _key_group_size = key_group_size == 0 ? S : key_group_size; size_t _value_group_size = value_group_size == 0 ? SV : value_group_size; + size_t sub_byte_multiplier = 8 / v_dst.get_precision().bitwidth(); parallel_for3d(B, L1, H, [&](size_t b, size_t m, size_t h) { auto slot = slot_mapping.ptr(b)[m]; if (slot < 0) return; @@ -293,14 +326,17 @@ static void paged_attn_quant_mt(const ov::intel_cpu::PlainTensor& k_src, p_k[0], p_k[1]); } - - for (size_t src_offset = 0, dst_offset = 0; src_offset < SV; src_offset += _value_group_size, dst_offset += _value_group_size + sizeof(float) + sizeof(float)) { - auto p_v = reinterpret_cast(v_dst.ptr::value_type>(block_number, h, block_offset, dst_offset)); - quant_u4(v_src.ptr(b, h, m, src_offset), - v_dst.ptr::value_type>(block_number, h, block_offset, dst_offset) + sizeof(float) + sizeof(float), - _value_group_size, - p_v[0], - p_v[1]); + + for (size_t src_offset = 0, dst_offset = 0; src_offset < SV; src_offset += _value_group_size, + dst_offset += _value_group_size / sub_byte_multiplier + sizeof(float) + sizeof(float)) { + uint8_t* v_base = reinterpret_cast( + v_dst.m_ptr.get() + + (block_number * v_dst.m_strides[0] + h * v_dst.m_strides[1] + block_offset * v_dst.m_strides[2]) / + sub_byte_multiplier + + dst_offset); + auto p_v = reinterpret_cast(v_base); + uint8_t* v_ptr = v_base + sizeof(float) * 2; + quant_u4(v_src.ptr(b, h, m, src_offset), v_ptr, _value_group_size, p_v[0], p_v[1]); } }); } @@ -329,15 +365,36 @@ void paged_attn_quantkv(const ov::intel_cpu::PlainTensor& k_src, const ov::intel_cpu::PlainTensor& slot_mapping, const size_t key_group_size, const size_t value_group_size) { - if (k_src.get_precision() == ov::element::f32 && k_dst.get_precision() == ov::element::u8) { - paged_attn_quant_mt(k_src, v_src, k_dst, v_dst, slot_mapping, key_group_size, value_group_size); - } else if (k_src.get_precision() == ov::element::bf16 && k_dst.get_precision() == ov::element::u8) { - paged_attn_quant_mt(k_src, v_src, k_dst, v_dst, slot_mapping, key_group_size, value_group_size); - } else if (k_src.get_precision() == ov::element::f16 && k_dst.get_precision() == ov::element::u8) { - paged_attn_quant_mt(k_src, v_src, k_dst, v_dst, slot_mapping, key_group_size, value_group_size); - } else { + using function_type = void (*)(const ov::intel_cpu::PlainTensor&, + const ov::intel_cpu::PlainTensor&, + const ov::intel_cpu::PlainTensor&, + const ov::intel_cpu::PlainTensor&, + const ov::intel_cpu::PlainTensor&, + const size_t, + const size_t); + static constexpr function_type funcs_fp32[] = { + paged_attn_quant_mt, + paged_attn_quant_mt, + }; + static constexpr function_type funcs_bf16[] = { + paged_attn_quant_mt, + paged_attn_quant_mt, + }; + static constexpr function_type funcs_f16[] = { + paged_attn_quant_mt, + paged_attn_quant_mt, + }; + if (k_dst.get_precision() != ov::element::u8) { OPENVINO_THROW("unsupport src type: ", k_src.get_precision(), ", dst type: ", k_dst.get_precision(), " in paged_attn_quantkv"); } + int dispatch = v_dst.get_precision() == ov::element::u8 ? 
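// For clarity, a scalar round trip of the asymmetric u4 quantization used by quant_u4 and
// attn_dequant_u4_kernel above, assuming the nibble convention as I read insert_half_byte
// (even element in the upper nibble, odd element in the lower one); the input values are
// purely illustrative:
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <vector>
int main() {
    std::vector<float> src = {0.1f, -0.3f, 0.7f, 0.2f, -0.5f, 0.4f, 0.0f, 0.6f};
    float max = *std::max_element(src.begin(), src.end());
    float min = *std::min_element(src.begin(), src.end());
    float scale = (max - min) / 15.0f;                    // u4 range is [0, 15]
    if (scale == 0.0f) scale = 0.0001f;
    float zp = -min / scale;
    std::vector<uint8_t> packed(src.size() / 2, 0);
    for (size_t i = 0; i < src.size(); ++i) {
        float q = std::round(src[i] / scale + zp);
        q = std::max(0.0f, std::min(15.0f, q));
        uint8_t v = static_cast<uint8_t>(q);
        packed[i / 2] |= (i % 2) ? v : static_cast<uint8_t>(v << 4);  // even -> upper nibble
    }
    for (size_t i = 0; i < src.size(); ++i) {
        uint8_t v = (i % 2) ? (packed[i / 2] & 0x0F) : (packed[i / 2] >> 4);
        printf("%zu: %f -> %f\n", i, src[i], (v - zp) * scale);       // dequantized value
    }
    return 0;
}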
0 : 1; + if (k_src.get_precision() == ov::element::f32) { + funcs_fp32[dispatch](k_src, v_src, k_dst, v_dst, slot_mapping, key_group_size, value_group_size); + } else if (k_src.get_precision() == ov::element::bf16) { + funcs_bf16[dispatch](k_src, v_src, k_dst, v_dst, slot_mapping, key_group_size, value_group_size); + } else if (k_src.get_precision() == ov::element::f16) { + funcs_f16[dispatch](k_src, v_src, k_dst, v_dst, slot_mapping, key_group_size, value_group_size); + } } void attn_quant_u8(const float* src, uint8_t* dst, size_t n, float& scale, float& zp) { diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant_kernel.hpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant_kernel.hpp index 4008c7d5a9dfb7..e3eec3d72efd00 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant_kernel.hpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant_kernel.hpp @@ -59,16 +59,30 @@ void attn_dequant_u4_kernel(const uint8_t* src, TDST* dst, size_t n, float scale size_t i = 0; // loadu_si128/epi64 does not support const qualifier; uint8_t* src_nc = const_cast(src); - float temp[4] = {0}; for (; i < n; ++i) { float tmp = extract_half_byte(src_nc[i / 2], (uint8_t)(i % 2)); - if (i < 4) - printf("index %ld integral %f float %f hex %x ", i, tmp, (tmp - zp) * scale, src_nc[i / 2]); tmp = (tmp - zp) * scale; dst[i] = tmp; } } +template +void attn_dequant_s4_kernel(const uint8_t* src, TDST* dst, size_t n, float scale) { + auto extract_half_byte = [&](uint8_t val, bool high_half) -> uint8_t { + uint8_t shift = high_half ? 0 : 4; + return (uint8_t) ((val >> shift) & 0x000F); + }; + size_t i = 0; + // loadu_si128/epi64 does not support const qualifier; + uint8_t* src_nc = const_cast(src); + float temp[4] = {0}; + for (; i < n; ++i) { + float tmp = extract_half_byte(src_nc[i / 2], (uint8_t)(i % 2)); + tmp = (tmp - 8) * scale; + dst[i] = tmp; + } +} + } // namespace XARCH } // namespace Cpu } // namespace Extensions diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp index a40dba035d4b4a..60f10caa57c602 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp @@ -71,7 +71,7 @@ void cvt_copy(TA* dst, TB* src, size_t n) { } } -template +template::type = true> static void attn_acc_value_block(float* out, float* weight, T* v, size_t S, size_t block_size, size_t group_size = 0) { #if defined(HAVE_AVX512F) size_t j = 0; @@ -199,7 +199,7 @@ static void attn_acc_value_block(float* out, float* weight, T* v, size_t S, size v += S; } } - +template::type = true> static void attn_acc_value_block(float* out, float* weight, uint8_t* v, size_t S, size_t block_size, size_t group_size = 0) { // The layout for per token per head: // |scale(f32)|zeropoint(f32)|quantized feature(u8,idx_1)|quantized feature(u8,idx_2)|...|quantized feature(u8,idx_S)| @@ -330,6 +330,38 @@ static void attn_acc_value_block(float* out, float* weight, uint8_t* v, size_t S } } +template::type = true> +static void attn_acc_value_block(float* out, float* weight, uint8_t* v, size_t S, size_t block_size, size_t group_size = 0) { + size_t src_offset = 0; + size_t dst_offset = 0; + const size_t _group_size = group_size ? 
group_size : S; + const size_t params_offset = sizeof(float) * 2; + auto sub_byte_multiplyer = 8 / 4; + const size_t src_stride = S / _group_size * (_group_size / sub_byte_multiplyer + params_offset); + auto extract_half_byte = [](uint8_t val, bool high_half) -> uint8_t { + uint8_t shift = high_half ? 0 : 4; + + return (uint8_t) ((val >> shift) & 0x000F); + }; + for (size_t j = 0; j < block_size; j++) { + dst_offset = 0; + src_offset = 0; + while (dst_offset < S) { + auto v0 = reinterpret_cast(v + src_offset); + for (size_t i = 0; i < _group_size; i += 2) { + uint8_t data = v[i/2 + src_offset + params_offset]; + float tmp0 = extract_half_byte(data, (bool)(i % 2)); + float tmp1 = extract_half_byte(data, (bool)((i + 1) % 2)); + out[dst_offset + i] += weight[j] * (tmp0 - v0[1]) * v0[0]; + out[dst_offset + i + 1] += weight[j] * (tmp1 - v0[1]) * v0[0]; + } + dst_offset += _group_size; + src_offset += _group_size / sub_byte_multiplyer + params_offset; + } + v += src_stride; + } +} + template static void dot_product_block(TA* a, TB* b, float* c, size_t n, size_t block_size, size_t group_size = 0) { #if defined(HAVE_AVX512F) @@ -745,26 +777,60 @@ void transpose_16NxK(TDST* dst, void* src, TDST* tmp, size_t N, size_t K, size_t } // dequant f16/u8 to float -template -static inline void dequant(T* dst, T* src, size_t N, size_t K) { +template::type = true> +static inline void dequant(T* dst, void* src, size_t N, size_t K, size_t group_size = 0) { // never called OPENVINO_THROW("dequant: should not be called."); } - -static inline void dequant(float* dst, ov::float16* src, size_t N, size_t K) { +template::type = true> +static inline void dequant(float* dst, ov::float16* src, size_t N, size_t K, size_t group_size = 0) { cvt_copy(dst, src, K * N); } -template -void dequant(TDST* dst, uint8_t* src, size_t N, size_t K) { +template::type = true> +void dequant(TDST* dst, uint8_t* src, size_t N, size_t K, size_t group_size = 0) { + // The layout for per token per head: + // |scale(f32)|zeropoint(f32)|quantized feature(u8,idx_1)|quantized feature(u8,idx_2)|...|quantized feature(u8,idx_S)| + // The quantized feature will start from 8bytes=sizeof(float)+sizeof(float) + auto s = src; + const size_t params_offset = sizeof(float) * 2; + const size_t _group_size = group_size ? group_size : K; + const size_t src_stride = K / _group_size * (_group_size + params_offset); + + for (size_t n = 0; n < N; n ++) { + size_t group_offset = 0; + size_t dst_offset = 0; + while (dst_offset < K) { + auto f = reinterpret_cast(s + group_offset); + attn_dequant_u8_kernel(s + group_offset + params_offset, dst + dst_offset, _group_size, f[0], f[1]); + group_offset += _group_size + params_offset; + dst_offset += _group_size; + } + s += src_stride; + dst += K; + } +} + +template::type = true> +void dequant(TDST* dst, uint8_t* src, size_t N, size_t K, size_t group_size = 0) { // The layout for per token per head: // |scale(f32)|zeropoint(f32)|quantized feature(u8,idx_1)|quantized feature(u8,idx_2)|...|quantized feature(u8,idx_S)| // The quantized feature will start from 8bytes=sizeof(float)+sizeof(float) auto s = src; + const size_t params_offset = sizeof(float) * 2; + const size_t _group_size = group_size ? 
group_size : K; + const size_t src_stride = K / _group_size * (_group_size + params_offset); + for (size_t n = 0; n < N; n ++) { - auto f = reinterpret_cast(s); - attn_dequant_u8_kernel(s + 2 * sizeof(float), dst, K, f[0], f[1]); - s += K + 2 * sizeof(float); + size_t group_offset = 0; + size_t dst_offset = 0; + while (dst_offset < K) { + auto f = reinterpret_cast(s + group_offset); + attn_dequant_u8_kernel(s + group_offset + params_offset, dst + dst_offset, _group_size, f[0], f[1]); + group_offset += _group_size + params_offset; + dst_offset += _group_size; + } + s += src_stride; dst += K; } } @@ -869,6 +935,31 @@ static void pack_32NxK(TDST* dst, void* src, TDST* tmp, size_t N, size_t K, size } pack_32NxK::value>(dst, tmp, reinterpret_cast(0), N, K, dst_stride, src_stride); } + +template::value != ov::element::f32 && SRC_PREC == ov::element::u4, bool>::type = true> +static void pack_32NxK(TDST* dst, void* src, TDST* tmp, size_t N, size_t K, size_t dst_stride, size_t src_stride, size_t group_size = 0) { + // The layout for per token per head: + // |scale(f32)|zeropoint(f32)|quantized feature(u8,idx_1)|quantized feature(u8,idx_2)|...|quantized feature(u8,idx_S)| + // The quantized feature will start from 8bytes=sizeof(float)+sizeof(float) + auto s = reinterpret_cast(src); + auto t = tmp; + // if group_size not set, the whole row is used as a group + const size_t sub_byte_mulitplier = 2; + size_t _group_size = group_size ? group_size : K; + for (size_t n = 0; n < N; n ++) { + size_t src_offset = 0; + size_t dst_offset = 0; + while (dst_offset < K) { + auto f = reinterpret_cast(s + src_offset); + attn_dequant_u4_kernel(s + (src_offset + sizeof(float) * 2), t + dst_offset, _group_size, f[0], f[1]); + src_offset += _group_size / sub_byte_mulitplier + sizeof(float) * 2; + dst_offset += _group_size; + } + s += src_offset; + t += src_stride; + } + pack_32NxK::value>(dst, tmp, reinterpret_cast(0), N, K, dst_stride, src_stride); +} #endif template::value == ov::element::f32, bool>::type = true> @@ -1240,12 +1331,27 @@ struct MHAHelper { auto* v = present_value.ptr(block_number, hk); for (size_t pq = 0; pq < q_len; pq++) { for (size_t h = hq_beg; h < hq_end; h++) { - attn_acc_value_block(_output.ptr(ithr, pq, h), - _weight.ptr(ithr, h, pq) + pv, - v, - _SV, - std::min(_block_size, cur_kv_len - pv), - _value_group_size); + if (present_value.get_precision() == ov::element::u4) { + auto sub_byte_multiplyer = 8 / present_value.get_precision().bitwidth(); + size_t v_stride = (block_number * present_value.m_strides[0] + hk * present_value.m_strides[1]) / sub_byte_multiplyer; + auto* v_ptr = present_value.m_ptr.get() + v_stride; + attn_acc_value_block( + _output.ptr(ithr, pq, h), + _weight.ptr(ithr, h, pq) + pv, + v_ptr, + _SV, + std::min(_block_size, cur_kv_len - pv), + _value_group_size); + } else { + attn_acc_value_block::value>( + _output.ptr(ithr, pq, h), + _weight.ptr(ithr, h, pq) + pv, + v, + _SV, + std::min(_block_size, cur_kv_len - pv), + _value_group_size); + } + } } } @@ -1399,12 +1505,26 @@ struct MHAHelper { auto* v = present_value.ptr(block_number, hk); for (size_t pq = 0; pq < q_len; pq++) { for (size_t h = hq_beg; h < hq_end; h++) { - attn_acc_value_block(_output_bhl.ptr(ithr, b, pq, h), - _weight_bhl.ptr(b, h, pq) + pv, - v, - _SV, - std::min(_block_size, context_len - pv), - _value_group_size); + if (present_value.get_precision() == ov::element::u4) { + auto sub_byte_multiplyer = 8 / present_value.get_precision().bitwidth(); + size_t v_stride = (block_number * present_value.m_strides[0] + 
hk * present_value.m_strides[1]) / sub_byte_multiplyer; + auto* v_ptr = present_value.m_ptr.get() + v_stride; + attn_acc_value_block( + _output_bhl.ptr(ithr, b, pq, h), + _weight_bhl.ptr(b, h, pq) + pv, + v_ptr, + _SV, + std::min(_block_size, context_len - pv), + _value_group_size); + } else { + attn_acc_value_block::value>( + _output_bhl.ptr(ithr, b, pq, h), + _weight_bhl.ptr(b, h, pq) + pv, + v, + _SV, + std::min(_block_size, context_len - pv), + _value_group_size); + } } } } @@ -1570,18 +1690,50 @@ struct MHA { _helper._S, _helper._block_size, _helper._S, _helper._key_group_size); if (q_is_xf16) { - pack_32NxK::value>(_helper._wv_scratch_b.template ptr(batch_in_reorder, kv_block, hk), + if (v_cache.get_precision() == ov::element::u4) { + auto sub_byte_multiplyer = 8 / v_cache.get_precision().bitwidth(); + size_t v_stride = (block_number * v_cache.m_strides[0] + hk * v_cache.m_strides[1]) / sub_byte_multiplyer; + auto* v_ptr = v_cache.m_ptr.get() + v_stride; + pack_32NxK(_helper._wv_scratch_b.template ptr(batch_in_reorder, kv_block, hk), + v_ptr, + _helper._output.template ptr(ithr), + _helper._block_size, + _helper._SV, + rnd_up(_helper._SV, _helper._block_size), + _helper._SV, + _helper._value_group_size); + } else { + pack_32NxK::value>(_helper._wv_scratch_b.template ptr(batch_in_reorder, kv_block, hk), + v_ptr, + _helper._output.template ptr(ithr), + _helper._block_size, + _helper._SV, + rnd_up(_helper._SV, _helper._block_size), + _helper._SV, + _helper._value_group_size); + } + } else { + // need to decompress + if (!q_cache_is_same) { + if (v_cache.get_precision() == ov::element::u4) { + printf("PaAttn|dequant value u4\n"); + auto sub_byte_multiplyer = 8 / v_cache.get_precision().bitwidth(); + size_t v_stride = (block_number * v_cache.m_strides[0] + hk * v_cache.m_strides[1]) / sub_byte_multiplyer; + auto* v_ptr = v_cache.m_ptr.get() + v_stride; + dequant( + _helper._wv_scratch_b.template ptr(batch_in_reorder, kv_block, hk), v_ptr, - _helper._output.template ptr(ithr), _helper._block_size, _helper._SV, - rnd_up(_helper._SV, _helper._block_size), + _helper._value_group_size); + } else { + dequant::value>( + _helper._wv_scratch_b.template ptr(batch_in_reorder, kv_block, hk), + v_ptr, + _helper._block_size, _helper._SV, _helper._value_group_size); - } else { - // need to decompress - if (!q_cache_is_same) { - dequant(_helper._wv_scratch_b.template ptr(batch_in_reorder, kv_block, hk), v_ptr, _helper._block_size, _helper._SV); + } } } }); @@ -1738,10 +1890,14 @@ struct AttentionExecutor : public PagedAttentionExecutor { // The layout for per token per head for u8 kv cache: // |scale(f32)|zeropoint(f32)|quantized feature(u8,idx_1)|quantized feature(u8,idx_2)|...|quantized feature(u8,idx_S)| // The actual size needs to deduct scale and zeropoint. - size_t key_group_num = _key_group_size ? k_cache.size(3) / (_key_group_size + sizeof(float) * 2) : _key_group_size; - size_t value_group_num = _value_group_size ? v_cache.size(3) / (_value_group_size + 8) : _value_group_size; - auto S = k_cache.size(3) - (k_cache.m_dt == ov::element::Type_t::u8 ? sizeof(float) * 2 * key_group_num : 0); - auto SV = v_cache.size(3) - (k_cache.m_dt == ov::element::Type_t::u8 ? 
sizeof(float) * 2 * value_group_num : 0); + const size_t key_sub_byte_multiplyer = 8 / k_cache.get_precision().bitwidth(); + const size_t value_sub_byte_multiplyer = 8 / v_cache.get_precision().bitwidth(); + const size_t key_params_size = sizeof(float) * 2 * key_sub_byte_multiplyer; + const size_t value_params_size = sizeof(float) * 2 * value_sub_byte_multiplyer; + size_t key_group_num = _key_group_size ? k_cache.size(3) / (_key_group_size + key_params_size) : _key_group_size; + size_t value_group_num = _value_group_size ? v_cache.size(3) / (_value_group_size + value_params_size) : _value_group_size; + auto S = k_cache.size(3) - (k_cache.get_precision().is_real() ? 0 : key_params_size * key_group_num); + auto SV = v_cache.size(3) - (v_cache.get_precision().is_real() ? 0 : value_params_size * value_group_num); auto block_size = k_cache.size(2); auto H = q.size(1) / S; auto h_each_group_len = 1; @@ -1756,9 +1912,12 @@ struct AttentionExecutor : public PagedAttentionExecutor { q = q.reshape({B_token, H, 1, S}); k = k.reshape({B_token, Hk, 1, S}); v = v.reshape({B_token, Hk, 1, SV}); + printf("k_cache prec %s shape [%ld %ld %ld %ld]\n", k_cache.get_precision().to_string().c_str(), k_cache.size(0), k_cache.size(1), k_cache.size(2), k_cache.size(3)); + printf("v_cache prec %s shape [%ld %ld %ld %ld]\n", v_cache.get_precision().to_string().c_str(), v_cache.size(0), v_cache.size(1), v_cache.size(2), v_cache.size(3)); + if (k_cache.m_dt == ov::element::Type_t::u8) { - k_cache.assert_dims({0, Hk, block_size, S + sizeof(float) * 2 * key_group_num}, true); - v_cache.assert_dims({k_cache.m_dims[0], Hk, block_size, SV + sizeof(float) * 2 * value_group_num}); + k_cache.assert_dims({0, Hk, block_size, S + key_params_size * key_group_num}, true); + v_cache.assert_dims({k_cache.m_dims[0], Hk, block_size, SV + value_params_size * value_group_num}); } else { k_cache.assert_dims({0, Hk, block_size, S}, true); v_cache.assert_dims({k_cache.m_dims[0], Hk, block_size, SV}); From d080e2a76efad980ef2a8167ddcdd1f12c55ec20 Mon Sep 17 00:00:00 2001 From: Zhang Yi Date: Thu, 28 Nov 2024 11:18:08 +0800 Subject: [PATCH 06/28] [CPU]AVX512 support for u4 kernel --- .../nodes/kernels/scaled_attn/attn_quant.cpp | 90 +++++++++++- .../kernels/scaled_attn/attn_quant_kernel.hpp | 91 ++++++++++-- .../nodes/kernels/scaled_attn/executor_pa.cpp | 134 ++++++++++++++++++ 3 files changed, 299 insertions(+), 16 deletions(-) diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant.cpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant.cpp index 1726ecc3a8ffd5..8af011fab1b50b 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant.cpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant.cpp @@ -172,24 +172,100 @@ static void quant_u8(const T* src, uint8_t* dst, size_t n, float& scale, float& template static void quant_u4(const T* src, void* dst, size_t n, float& scale, float& zp) { - auto insert_half_byte = [](uint8_t dst, uint8_t val, bool high_half) -> uint8_t { - uint8_t shift = high_half ? 
0 : 4; - return dst | (uint8_t) (val << shift); - }; - auto dst_ptr = reinterpret_cast(dst); size_t i = 0; float max = -FLT_MAX; float min = FLT_MAX; +#if defined(HAVE_AVX512F) + auto v0_max = _mm512_set1_ps(-FLT_MAX); + auto v0_min = _mm512_set1_ps(FLT_MAX); + auto v1_max = _mm512_set1_ps(-FLT_MAX); + auto v1_min = _mm512_set1_ps(FLT_MAX); + auto v2_max = _mm512_set1_ps(-FLT_MAX); + auto v2_min = _mm512_set1_ps(FLT_MAX); + auto v3_max = _mm512_set1_ps(-FLT_MAX); + auto v3_min = _mm512_set1_ps(FLT_MAX); + for (; i + 4 * vec_len_f32_avx512 <= n; i += vec_len_f32_avx512 * 4) { + auto v0 = mm512_uni_loadu_ps(src + i); + auto v1 = mm512_uni_loadu_ps(src + i + vec_len_f32_avx512); + auto v2 = mm512_uni_loadu_ps(src + i + 2 * vec_len_f32_avx512); + auto v3 = mm512_uni_loadu_ps(src + i + 3 * vec_len_f32_avx512); + v0_max = _mm512_max_ps(v0_max, v0); + v0_min = _mm512_min_ps(v0_min, v0); + v1_max = _mm512_max_ps(v1_max, v1); + v1_min = _mm512_min_ps(v1_min, v1); + v2_max = _mm512_max_ps(v2_max, v2); + v2_min = _mm512_min_ps(v2_min, v2); + v3_max = _mm512_max_ps(v3_max, v3); + v3_min = _mm512_min_ps(v3_min, v3); + } + if (i + 2 * vec_len_f32_avx512 <= n) { + auto v0 = mm512_uni_loadu_ps(src + i); + auto v1 = mm512_uni_loadu_ps(src + i + vec_len_f32_avx512); + v0_max = _mm512_max_ps(v0_max, v0); + v0_min = _mm512_min_ps(v0_min, v0); + v1_max = _mm512_max_ps(v1_max, v1); + v1_min = _mm512_min_ps(v1_min, v1); + i += 2 * vec_len_f32_avx512; + } + if (i + vec_len_f32_avx512 <= n) { + auto v0 = mm512_uni_loadu_ps(src + i); + v0_max = _mm512_max_ps(v0_max, v0); + v0_min = _mm512_min_ps(v0_min, v0); + i += vec_len_f32_avx512; + } + v0_max = _mm512_max_ps(v0_max, v1_max); + v0_min = _mm512_min_ps(v0_min, v1_min); + v2_max = _mm512_max_ps(v2_max, v3_max); + v2_min = _mm512_min_ps(v2_min, v3_min); + v0_max = _mm512_max_ps(v0_max, v2_max); + v0_min = _mm512_min_ps(v0_min, v2_min); + max = _mm512_reduce_max_ps(v0_max); + min = _mm512_reduce_min_ps(v0_min); +#else for (; i < n; i++) { float tmp = src[i]; max = std::max(max, tmp); min = std::min(min, tmp); } +#endif + auto insert_half_byte = [](uint8_t dst, uint8_t val, bool high_half) -> uint8_t { + uint8_t shift = high_half ? 
0 : 4; + return dst | (uint8_t) (val << shift); + }; + auto dst_ptr = reinterpret_cast(dst); scale = (max - min) / ((1 << 4) - 1); if (scale == 0) scale = 0.0001f; zp = -min / scale; i = 0; +#if defined(HAVE_AVX512F) + auto v_scale = _mm512_set1_ps(1 / scale); + auto v_zp = _mm512_set1_ps(zp); + auto v_zero = _mm512_setzero_epi32(); + auto v_upper = _mm512_set1_epi32(15); + printf("total size %ld avx512 size %ld\n", n, vec_len_f32_avx512); + for (; i + 2 * vec_len_f32_avx512 <= n; i += 2 * vec_len_f32_avx512) { + auto v0 = mm512_uni_loadu_ps(src + i); + auto v1 = mm512_uni_loadu_ps(src + i + vec_len_f32_avx512); + v0 = _mm512_fmadd_ps(v0, v_scale, v_zp); + v1 = _mm512_fmadd_ps(v1, v_scale, v_zp); + auto v0_i32 = _mm512_cvt_roundps_epi32(v0, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC); + auto v1_i32 = _mm512_cvt_roundps_epi32(v1, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC); + v0_i32 = _mm512_max_epi32(v0_i32, v_zero); + v1_i32 = _mm512_max_epi32(v1_i32, v_zero); + v0_i32 = _mm512_min_epi32(v0_i32, v_upper); + v1_i32 = _mm512_min_epi32(v1_i32, v_upper); + __m512i idx1 = _mm512_set_epi32(30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0); + __m512i idx2 = _mm512_set_epi32(31, 29, 27, 25, 23, 21, 19, 17, 15, 13, 11, 9, 7, 5, 3, 1); + auto first_half = _mm512_permutex2var_epi32(v0_i32, idx1, v1_i32); + auto second_half = _mm512_permutex2var_epi32(v0_i32, idx2, v1_i32); + first_half = _mm512_slli_epi32(first_half, 4); + auto mask = _mm512_set1_epi32(0x0F); + second_half = _mm512_and_epi32(second_half, mask); + auto combined = _mm512_or_epi32(first_half, second_half); + _mm512_mask_cvtusepi32_storeu_epi8(dst_ptr + i / 2, 0xffff, combined); + } +#endif for (; i < n; i++) { float tmp = src[i]; #define MIN(a, b) ((a) < (b) ? (a) : (b)) @@ -222,9 +298,9 @@ static void quant_s4(const T* src, void* dst, size_t n, float& scale, float& zp) i = 0; for (; i < n; i++) { float tmp = src[i]; - #define MIN(a, b) ((a) < (b) ? (a) : (b)) // add 8.5 here is to save a clamp to (-2^3) - uint8_t src_val = MIN(15, (int8_t)(tmp / scale + 8.5f)); + int8_t src_val = std::min((int8_t)(7), (int8_t)(tmp / scale)); + src_val = std::max((int8_t)(-8), src_val); uint8_t dst_val = i % 2 == 0 ? 0 : dst_ptr[i / 2]; dst_val = insert_half_byte(dst_val, src_val, (uint8_t)(i % 2)); if (i < 4) diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant_kernel.hpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant_kernel.hpp index e3eec3d72efd00..9b621976b5ff6c 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant_kernel.hpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant_kernel.hpp @@ -52,13 +52,55 @@ void attn_dequant_u8_kernel(const uint8_t* src, TDST* dst, size_t n, float scale template void attn_dequant_u4_kernel(const uint8_t* src, TDST* dst, size_t n, float scale, float zp) { + // 2 4bit data form a byte + /* 0,1|2,3|4,5|6,7 + / \ + 0,2,4,6|1,3,5,7 + | + permute + | + 0,1,2,3,4,5,6,7 + */ + size_t i = 0; + uint8_t* src_nc = const_cast(src); +#if defined(HAVE_AVX512F) + auto extract_half_byte2 = [&](uint8_t val, bool high_half) -> uint8_t { + uint8_t shift = high_half ? 
0 : 4; + return (uint8_t) ((val >> shift) & 0x000F); + }; + auto v_zp = _mm512_set1_ps(zp); + auto v_scale = _mm512_set1_ps(scale); + for (; i + vec_len_f32_avx512 * 2 <= n; i += vec_len_f32_avx512 * 2) { + auto high_half = _mm_loadu_si128(reinterpret_cast<__m128i*>(src_nc + i / 2)); + __m128i low_half = _mm_srli_epi16(high_half, 4); + const __m128i mask = _mm_set1_epi8(0x0F); + low_half = _mm_and_si128(mask, low_half); + high_half = _mm_and_si128(mask, high_half); + + //cvt to f32 + auto v_256_low_half = _mm512_cvtepu8_epi32(low_half); + auto v_256_high_half = _mm512_cvtepu8_epi32(high_half); + auto v_f32_low_half = _mm512_cvtepi32_ps(v_256_low_half); + auto v_f32_high_half = _mm512_cvtepi32_ps(v_256_high_half); + // q - zp + v_f32_low_half = _mm512_sub_ps(v_f32_low_half, v_zp); + v_f32_high_half = _mm512_sub_ps(v_f32_high_half, v_zp); + // (q - zp) * scale + v_f32_low_half = _mm512_mul_ps(v_f32_low_half, v_scale); + v_f32_high_half = _mm512_mul_ps(v_f32_high_half, v_scale); + + __m512i idx1 = _mm512_set_epi32(23, 7, 22, 6, 21, 5, 20, 4, 19, 3, 18, 2, 17, 1, 16, 0); + __m512i idx2 = _mm512_set_epi32(31, 15, 30, 14, 29, 13, 28, 12, 27, 11, 26, 10, 25, 9, 24, 8); + __m512 first_half = _mm512_permutex2var_ps(v_f32_low_half, idx1, v_f32_high_half); + __m512 second_half = _mm512_permutex2var_ps(v_f32_low_half, idx2, v_f32_high_half); + mm512_uni_storeu_ps(dst + i, first_half); + mm512_uni_storeu_ps(dst + i + vec_len_f32_avx512, second_half); + } +#endif auto extract_half_byte = [&](uint8_t val, bool high_half) -> uint8_t { uint8_t shift = high_half ? 0 : 4; return (uint8_t) ((val >> shift) & 0x000F); }; - size_t i = 0; - // loadu_si128/epi64 does not support const qualifier; - uint8_t* src_nc = const_cast(src); for (; i < n; ++i) { float tmp = extract_half_byte(src_nc[i / 2], (uint8_t)(i % 2)); tmp = (tmp - zp) * scale; @@ -68,17 +110,48 @@ void attn_dequant_u4_kernel(const uint8_t* src, TDST* dst, size_t n, float scale template void attn_dequant_s4_kernel(const uint8_t* src, TDST* dst, size_t n, float scale) { + // 2 4bit data form a byte + /* 0,1|2,3|4,5|6,7 + / \ + 0,2,4,6|1,3,5,7 + | + permute + | + 0,1,2,3,4,5,6,7 + */ + size_t i = 0; + uint8_t* src_nc = const_cast(src); + for (; i + vec_len_f32_avx512 * 2 <= n; i += vec_len_f32_avx512 * 2) { + auto high_half = _mm_loadu_si128(reinterpret_cast<__m128i*>(src_nc + i / 2)); + __m128i low_half = _mm_srli_epi16(high_half, 4); + const __m128i mask = _mm_set1_epi8(0x0F); + low_half = _mm_and_si128(mask, low_half); + auto v_scale = _mm512_set1_ps(1/scale); + //cvt to f32 + auto v_256_low_half = _mm512_cvtepi8_epi32(low_half); + auto v_256_high_half = _mm512_cvtepi8_epi32(high_half); + v_256_high_half = _mm512_slli_epi32(v_256_high_half, 28); + v_256_high_half = _mm512_srai_epi32(v_256_high_half, 28); + auto v_f32_low_half = _mm512_cvtepi32_ps(v_256_low_half); + auto v_f32_high_half = _mm512_cvtepi32_ps(v_256_high_half); + // q * scale + v_f32_low_half = _mm512_mul_ps(v_f32_low_half, v_scale); + v_f32_high_half = _mm512_mul_ps(v_f32_high_half, v_scale); + + __m512i idx1 = _mm512_set_epi32(23, 7, 22, 6, 21, 5, 20, 4, 19, 3, 18, 2, 17, 1, 16, 0); + __m512i idx2 = _mm512_set_epi32(31, 15, 30, 14, 29, 13, 28, 12, 27, 11, 26, 10, 25, 9, 24, 8); + __m512 first_half = _mm512_permutex2var_ps(v_f32_low_half, idx1, v_f32_high_half); + __m512 second_half = _mm512_permutex2var_ps(v_f32_low_half, idx2, v_f32_high_half); + mm512_uni_storeu_ps(dst + i, first_half); + mm512_uni_storeu_ps(dst + i + vec_len_f32_avx512, second_half); + } auto extract_half_byte = 
[&](uint8_t val, bool high_half) -> uint8_t { uint8_t shift = high_half ? 0 : 4; - return (uint8_t) ((val >> shift) & 0x000F); + return (int8_t) ((val >> shift) & 0x000F); }; - size_t i = 0; - // loadu_si128/epi64 does not support const qualifier; - uint8_t* src_nc = const_cast(src); - float temp[4] = {0}; for (; i < n; ++i) { float tmp = extract_half_byte(src_nc[i / 2], (uint8_t)(i % 2)); - tmp = (tmp - 8) * scale; + tmp = tmp * scale; dst[i] = tmp; } } diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp index 60f10caa57c602..fe401a86653dcd 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp @@ -343,6 +343,57 @@ static void attn_acc_value_block(float* out, float* weight, uint8_t* v, size_t S return (uint8_t) ((val >> shift) & 0x000F); }; +#if defined(HAVE_AVX512F) + for (size_t j = 0; j < block_size; j++) { + dst_offset = 0; + src_offset = 0; + while (dst_offset < S) { + auto v0 = reinterpret_cast(v + src_offset); + size_t i = 0; + auto attn_w_vec0 = _mm512_set1_ps(weight[j] * v0[0]); + auto v_zp = _mm512_set1_ps(v0[1]); + for (; i + vec_len_f32_avx512 * 2 < _group_size; i += vec_len_f32_avx512 * 2) { + auto high_half = _mm_loadu_si128(reinterpret_cast<__m128i*>(v + i/2 + src_offset + params_offset)); + __m128i low_half = _mm_srli_epi16(high_half, 4); + const __m128i mask = _mm_set1_epi8(0x0F); + low_half = _mm_and_si128(mask, low_half); + high_half = _mm_and_si128(mask, high_half); + + //cvt to f32 + auto v_256_low_half = _mm512_cvtepu8_epi32(low_half); + auto v_256_high_half = _mm512_cvtepu8_epi32(high_half); + auto v_f32_low_half = _mm512_cvtepi32_ps(v_256_low_half); + auto v_f32_high_half = _mm512_cvtepi32_ps(v_256_high_half); + // q - zp + v_f32_low_half = _mm512_sub_ps(v_f32_low_half, v_zp); + v_f32_high_half = _mm512_sub_ps(v_f32_high_half, v_zp); + + __m512i idx1 = _mm512_set_epi32(23, 7, 22, 6, 21, 5, 20, 4, 19, 3, 18, 2, 17, 1, 16, 0); + __m512i idx2 = _mm512_set_epi32(31, 15, 30, 14, 29, 13, 28, 12, 27, 11, 26, 10, 25, 9, 24, 8); + __m512 first_half = _mm512_permutex2var_ps(v_f32_low_half, idx1, v_f32_high_half); + __m512 second_half = _mm512_permutex2var_ps(v_f32_low_half, idx2, v_f32_high_half); + auto v_out0 = mm512_uni_loadu_ps(out + dst_offset + i); + auto v_out1 = mm512_uni_loadu_ps(out + dst_offset + i + vec_len_f32_avx512); + v_out0 = _mm512_fmadd_ps(attn_w_vec0, first_half, v_out0); + v_out1 = _mm512_fmadd_ps(attn_w_vec0, second_half, v_out1); + mm512_uni_storeu_ps(out + dst_offset + i, v_out0); + mm512_uni_storeu_ps(out + dst_offset + i + vec_len_f32_avx512, v_out1); + } + + for (; i < _group_size; i += 2) { + uint8_t data = v[i/2 + src_offset + params_offset]; + float tmp0 = extract_half_byte(data, (bool)(i % 2)); + float tmp1 = extract_half_byte(data, (bool)((i + 1) % 2)); + out[dst_offset + i] += weight[j] * (tmp0 - v0[1]) * v0[0]; + out[dst_offset + i + 1] += weight[j] * (tmp1 - v0[1]) * v0[0]; + } + dst_offset += _group_size; + src_offset += _group_size / sub_byte_multiplyer + params_offset; + } + v += src_stride; + } + return; +#endif for (size_t j = 0; j < block_size; j++) { dst_offset = 0; src_offset = 0; @@ -362,6 +413,89 @@ static void attn_acc_value_block(float* out, float* weight, uint8_t* v, size_t S } } +// template::type = true> +// static void attn_acc_value_block(float* out, float* weight, uint8_t* v, size_t S, size_t block_size, size_t group_size = 0) 
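// A scalar model of what the AVX512 u4 value path above computes per group may help: each
// group stores |scale(f32)|zp(f32)| followed by group_size/2 packed bytes, and every
// weight[j] * dequantized(v) product is accumulated into out[]. Sizes and data below are
// hypothetical:
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <vector>
int main() {
    const size_t group_size = 8;
    const float scale = 0.05f, zp = 7.0f, weight = 0.25f;
    std::vector<uint8_t> packed = {0x12, 0x34, 0x56, 0x78};  // 8 packed u4 values
    std::vector<float> out(group_size, 0.0f);
    for (size_t i = 0; i < group_size; i += 2) {
        uint8_t byte = packed[i / 2];
        float even = static_cast<float>(byte >> 4);    // element i, upper nibble
        float odd = static_cast<float>(byte & 0x0F);   // element i + 1, lower nibble
        out[i] += weight * (even - zp) * scale;
        out[i + 1] += weight * (odd - zp) * scale;
    }
    for (float v : out) printf("%f\n", v);
    return 0;
}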
{ +// size_t src_offset = 0; +// size_t dst_offset = 0; +// const size_t _group_size = group_size ? group_size : S; +// const size_t params_offset = sizeof(float) * 1; +// auto sub_byte_multiplyer = 8 / 4; +// const size_t src_stride = S / _group_size * (_group_size / sub_byte_multiplyer + params_offset); +// auto extract_half_byte = [](uint8_t val, bool high_half) -> uint8_t { +// uint8_t shift = high_half ? 0 : 4; + +// return (uint8_t) ((val >> shift) & 0x000F); +// }; +// // #if defined(HAVE_AVX512F) +// // for (size_t j = 0; j < block_size; j++) { +// // dst_offset = 0; +// // src_offset = 0; +// // while (dst_offset < S) { +// // auto v0 = reinterpret_cast(v + src_offset); +// // size_t i = 0; +// // auto attn_w_vec0 = _mm512_set1_ps(weight[j] * v0[0]); +// // auto v_zp = _mm512_set1_ps(v0[1]); +// // for (; i + vec_len_f32_avx512 * 2 < _group_size; i += vec_len_f32_avx512 * 2) { +// // auto high_half = _mm_loadu_si128(reinterpret_cast<__m128i*>(v + i/2 + src_offset + params_offset)); +// // __m128i low_half = _mm_srli_epi16(high_half, 4); +// // const __m128i mask = _mm_set1_epi8(0x0F); +// // low_half = _mm_and_si128(mask, low_half); +// // high_half = _mm_and_si128(mask, high_half); + +// // //cvt to f32 +// // auto v_256_low_half = _mm512_cvtepu8_epi32(low_half); +// // auto v_256_high_half = _mm512_cvtepu8_epi32(high_half); +// // auto v_f32_low_half = _mm512_cvtepi32_ps(v_256_low_half); +// // auto v_f32_high_half = _mm512_cvtepi32_ps(v_256_high_half); +// // // q - zp +// // v_f32_low_half = _mm512_sub_ps(v_f32_low_half, v_zp); +// // v_f32_high_half = _mm512_sub_ps(v_f32_high_half, v_zp); + +// // __m512i idx1 = _mm512_set_epi32(23, 7, 22, 6, 21, 5, 20, 4, 19, 3, 18, 2, 17, 1, 16, 0); +// // __m512i idx2 = _mm512_set_epi32(31, 15, 30, 14, 29, 13, 28, 12, 27, 11, 26, 10, 25, 9, 24, 8); +// // __m512 first_half = _mm512_permutex2var_ps(v_f32_low_half, idx1, v_f32_high_half); +// // __m512 second_half = _mm512_permutex2var_ps(v_f32_low_half, idx2, v_f32_high_half); +// // auto v_out0 = mm512_uni_loadu_ps(out + dst_offset + i); +// // auto v_out1 = mm512_uni_loadu_ps(out + dst_offset + i + vec_len_f32_avx512); +// // v_out0 = _mm512_fmadd_ps(attn_w_vec0, first_half, v_out0); +// // v_out1 = _mm512_fmadd_ps(attn_w_vec0, second_half, v_out1); +// // mm512_uni_storeu_ps(out + dst_offset + i, v_out0); +// // mm512_uni_storeu_ps(out + dst_offset + i + vec_len_f32_avx512, v_out1); +// // } + +// // for (; i < _group_size; i += 2) { +// // uint8_t data = v[i/2 + src_offset + params_offset]; +// // float tmp0 = extract_half_byte(data, (bool)(i % 2)); +// // float tmp1 = extract_half_byte(data, (bool)((i + 1) % 2)); +// // out[dst_offset + i] += weight[j] * (tmp0 - v0[1]) * v0[0]; +// // out[dst_offset + i + 1] += weight[j] * (tmp1 - v0[1]) * v0[0]; +// // } +// // dst_offset += _group_size; +// // src_offset += _group_size / sub_byte_multiplyer + params_offset; +// // } +// // v += src_stride; +// // } +// // return; +// // #endif +// for (size_t j = 0; j < block_size; j++) { +// dst_offset = 0; +// src_offset = 0; +// while (dst_offset < S) { +// auto v0 = reinterpret_cast(v + src_offset); +// for (size_t i = 0; i < _group_size; i += 2) { +// uint8_t data = v[i/2 + src_offset + params_offset]; +// float tmp0 = extract_half_byte(data, (bool)(i % 2)); +// float tmp1 = extract_half_byte(data, (bool)((i + 1) % 2)); +// out[dst_offset + i] += weight[j] * tmp0 * v0[0]; +// out[dst_offset + i + 1] += weight[j] * tmp1 * v0[0]; +// } +// dst_offset += _group_size; +// src_offset += _group_size / 
sub_byte_multiplyer + params_offset; +// } +// v += src_stride; +// } +// } + template static void dot_product_block(TA* a, TB* b, float* c, size_t n, size_t block_size, size_t group_size = 0) { #if defined(HAVE_AVX512F) From 78ef4dd4ecefbd09a6ad94410cf527af4a50c9e2 Mon Sep 17 00:00:00 2001 From: Zhang Yi Date: Fri, 29 Nov 2024 15:53:43 +0800 Subject: [PATCH 07/28] [CPU]Support S4 quantization Signed-off-by: yi3.zhang@intel.com --- .../nodes/kernels/scaled_attn/attn_quant.cpp | 72 +++- .../kernels/scaled_attn/attn_quant_kernel.hpp | 106 ++++-- .../nodes/kernels/scaled_attn/executor_pa.cpp | 359 +++++++++++++----- 3 files changed, 396 insertions(+), 141 deletions(-) diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant.cpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant.cpp index 8af011fab1b50b..67a7ba55160728 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant.cpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant.cpp @@ -221,13 +221,12 @@ static void quant_u4(const T* src, void* dst, size_t n, float& scale, float& zp) v0_min = _mm512_min_ps(v0_min, v2_min); max = _mm512_reduce_max_ps(v0_max); min = _mm512_reduce_min_ps(v0_min); -#else +#endif for (; i < n; i++) { - float tmp = src[i]; + float tmp = static_cast(src[i]); max = std::max(max, tmp); min = std::min(min, tmp); } -#endif auto insert_half_byte = [](uint8_t dst, uint8_t val, bool high_half) -> uint8_t { uint8_t shift = high_half ? 0 : 4; return dst | (uint8_t) (val << shift); @@ -277,9 +276,11 @@ static void quant_u4(const T* src, void* dst, size_t n, float& scale, float& zp) } template -static void quant_s4(const T* src, void* dst, size_t n, float& scale, float& zp) { +static void quant_s4(const T* src, void* dst, size_t n, float& scale) { auto insert_half_byte = [](uint8_t dst, uint8_t val, bool high_half) -> uint8_t { uint8_t shift = high_half ? 0 : 4; + if (high_half) + val &= 0x0F; return dst | (uint8_t) (val << shift); }; auto dst_ptr = reinterpret_cast(dst); @@ -293,21 +294,23 @@ static void quant_s4(const T* src, void* dst, size_t n, float& scale, float& zp) } float max_abs = std::max(std::abs(min), std::abs(max)); scale = max_abs / ((1 << 3) - 1); + printf("max %f min %f scale %f\n", min, max, scale); if (scale == 0) scale = 0.0001f; i = 0; for (; i < n; i++) { float tmp = src[i]; // add 8.5 here is to save a clamp to (-2^3) - int8_t src_val = std::min((int8_t)(7), (int8_t)(tmp / scale)); + float temp1 = std::round(tmp / scale); + float temp2 = (int8_t)(tmp / scale); + int8_t src_val = std::min((int8_t)(7), (int8_t)std::round(tmp / scale)); src_val = std::max((int8_t)(-8), src_val); uint8_t dst_val = i % 2 == 0 ? 
0 : dst_ptr[i / 2]; dst_val = insert_half_byte(dst_val, src_val, (uint8_t)(i % 2)); - if (i < 4) - printf("index %ld float %f src %d hex %x", i, tmp, src_val, dst_val); + if (i < 16) + printf("index %ld float %f temp1 %f temp2 %f src %d hex %x\n", i, tmp, temp1, temp2, src_val, dst_val); dst_ptr[i / 2] = dst_val; } - printf("quant scale %f zp %f\n", scale, zp); } template @@ -417,6 +420,49 @@ static void paged_attn_quant_mt(const ov::intel_cpu::PlainTensor& k_src, }); } +template ::type = true> +static void paged_attn_quant_mt(const ov::intel_cpu::PlainTensor& k_src, + const ov::intel_cpu::PlainTensor& v_src, + const ov::intel_cpu::PlainTensor& k_dst, + const ov::intel_cpu::PlainTensor& v_dst, + const ov::intel_cpu::PlainTensor& slot_mapping, + const size_t key_group_size, + const size_t value_group_size) { + size_t B = k_src.m_dims[0], H = k_src.m_dims[1], L1 = k_src.m_dims[2], S = k_src.m_dims[3], SV = v_src.m_dims[3]; + size_t block_size = k_dst.m_dims[2]; + size_t _key_group_size = key_group_size == 0 ? S : key_group_size; + size_t _value_group_size = value_group_size == 0 ? SV : value_group_size; + size_t sub_byte_multiplier = 8 / v_dst.get_precision().bitwidth(); + parallel_for3d(B, L1, H, [&](size_t b, size_t m, size_t h) { + auto slot = slot_mapping.ptr(b)[m]; + if (slot < 0) return; + auto block_number = slot / block_size; + auto block_offset = slot % block_size; + // The layout for per token per head: + // |scale(f32)|zeropoint(f32)|quantized feature(u8,idx_1)|quantized feature(u8,idx_2)|...|quantized feature(u8,idx_S)| + for (size_t src_offset = 0, dst_offset = 0; src_offset < S; src_offset += _key_group_size, dst_offset += _key_group_size + sizeof(float) + sizeof(float)) { + auto p_k = reinterpret_cast(k_dst.ptr::value_type>(block_number, h, block_offset, dst_offset)); + quant_u8(k_src.ptr(b, h, m, src_offset), + k_dst.ptr::value_type>(block_number, h, block_offset, dst_offset) + sizeof(float) + sizeof(float), + _key_group_size, + p_k[0], + p_k[1]); + } + + for (size_t src_offset = 0, dst_offset = 0; src_offset < SV; src_offset += _value_group_size, + dst_offset += _value_group_size / sub_byte_multiplier + sizeof(float)) { + uint8_t* v_base = reinterpret_cast( + v_dst.m_ptr.get() + + (block_number * v_dst.m_strides[0] + h * v_dst.m_strides[1] + block_offset * v_dst.m_strides[2]) / + sub_byte_multiplier + + dst_offset); + auto p_v = reinterpret_cast(v_base); + uint8_t* v_ptr = v_base + sizeof(float); + quant_s4(v_src.ptr(b, h, m, src_offset), v_ptr, _value_group_size, p_v[0]); + } + }); +} + void attn_quantkv(const ov::intel_cpu::PlainTensor& k_src, const ov::intel_cpu::PlainTensor& v_src, const ov::intel_cpu::PlainTensor& k_dst, @@ -451,19 +497,27 @@ void paged_attn_quantkv(const ov::intel_cpu::PlainTensor& k_src, static constexpr function_type funcs_fp32[] = { paged_attn_quant_mt, paged_attn_quant_mt, + paged_attn_quant_mt, }; static constexpr function_type funcs_bf16[] = { paged_attn_quant_mt, paged_attn_quant_mt, + paged_attn_quant_mt, }; static constexpr function_type funcs_f16[] = { paged_attn_quant_mt, paged_attn_quant_mt, + paged_attn_quant_mt, }; if (k_dst.get_precision() != ov::element::u8) { OPENVINO_THROW("unsupport src type: ", k_src.get_precision(), ", dst type: ", k_dst.get_precision(), " in paged_attn_quantkv"); } - int dispatch = v_dst.get_precision() == ov::element::u8 ? 
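// The revised quant_s4 above switches to a symmetric scheme: the scale comes from the absolute
// maximum and quantized values are rounded and clamped to [-8, 7]. A scalar sketch with
// illustrative data:
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>
int main() {
    std::vector<float> src = {0.9f, -1.2f, 0.3f, -0.7f};
    float max_abs = 0.0f;
    for (float v : src) max_abs = std::max(max_abs, std::abs(v));
    float scale = max_abs / 7.0f;       // (1 << 3) - 1
    if (scale == 0.0f) scale = 0.0001f;
    for (float v : src) {
        int q = std::max(-8, std::min(7, static_cast<int>(std::round(v / scale))));
        printf("%f -> q=%d -> %f\n", v, q, q * scale);  // dequantization is just q * scale
    }
    return 0;
}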
0 : 1; + std::map dispatch_table = { + {ov::element::u8, 0}, + {ov::element::u4, 1}, + {ov::element::i4, 2}, + }; + size_t dispatch = dispatch_table[v_dst.get_precision()]; if (k_src.get_precision() == ov::element::f32) { funcs_fp32[dispatch](k_src, v_src, k_dst, v_dst, slot_mapping, key_group_size, value_group_size); } else if (k_src.get_precision() == ov::element::bf16) { diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant_kernel.hpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant_kernel.hpp index 9b621976b5ff6c..174e0d8b5804dc 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant_kernel.hpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant_kernel.hpp @@ -64,24 +64,18 @@ void attn_dequant_u4_kernel(const uint8_t* src, TDST* dst, size_t n, float scale size_t i = 0; uint8_t* src_nc = const_cast(src); #if defined(HAVE_AVX512F) - auto extract_half_byte2 = [&](uint8_t val, bool high_half) -> uint8_t { - uint8_t shift = high_half ? 0 : 4; - return (uint8_t) ((val >> shift) & 0x000F); - }; auto v_zp = _mm512_set1_ps(zp); auto v_scale = _mm512_set1_ps(scale); for (; i + vec_len_f32_avx512 * 2 <= n; i += vec_len_f32_avx512 * 2) { - auto high_half = _mm_loadu_si128(reinterpret_cast<__m128i*>(src_nc + i / 2)); - __m128i low_half = _mm_srli_epi16(high_half, 4); - const __m128i mask = _mm_set1_epi8(0x0F); - low_half = _mm_and_si128(mask, low_half); - high_half = _mm_and_si128(mask, high_half); - - //cvt to f32 - auto v_256_low_half = _mm512_cvtepu8_epi32(low_half); - auto v_256_high_half = _mm512_cvtepu8_epi32(high_half); - auto v_f32_low_half = _mm512_cvtepi32_ps(v_256_low_half); - auto v_f32_high_half = _mm512_cvtepi32_ps(v_256_high_half); + auto data = _mm_loadu_si128(reinterpret_cast<__m128i*>(src_nc + i/2)); + auto v_i32 = _mm512_cvtepu8_epi32(data); + + auto v_512_low_half = _mm512_srli_epi32(v_i32, 4); + auto v_f32_low_half = _mm512_cvtepi32_ps(v_512_low_half); + + auto mask = _mm512_set1_epi32(0x0F); + auto v_512_high_half = _mm512_and_si512(v_i32, mask); + auto v_f32_high_half = _mm512_cvtepi32_ps(v_512_high_half); // q - zp v_f32_low_half = _mm512_sub_ps(v_f32_low_half, v_zp); v_f32_high_half = _mm512_sub_ps(v_f32_high_half, v_zp); @@ -96,6 +90,34 @@ void attn_dequant_u4_kernel(const uint8_t* src, TDST* dst, size_t n, float scale mm512_uni_storeu_ps(dst + i, first_half); mm512_uni_storeu_ps(dst + i + vec_len_f32_avx512, second_half); } + auto v256_zp = _mm256_set1_ps(zp); + auto v256_scale = _mm256_set1_ps(scale); + for (; i + vec_len_f32_avx2 * 2 <= n; i += vec_len_f32_avx2 * 2) { + auto data = _mm_loadu_si64(reinterpret_cast<__m128i*>(src_nc + i/2)); + + auto v_i32 = _mm256_cvtepu8_epi32(data); + auto v_256_low_half = _mm256_srli_epi32(v_i32, 4); + auto v_f32_low_half = _mm256_cvtepi32_ps(v_256_low_half); + + auto mask = _mm256_set1_epi32(0x0F); + auto v_256_high_half = _mm256_and_si256(v_i32, mask); + auto v_f32_high_half = _mm256_cvtepi32_ps(v_256_high_half); + // q - zp + v_f32_low_half = _mm256_sub_ps(v_f32_low_half, v256_zp); + v_f32_high_half = _mm256_sub_ps(v_f32_high_half, v256_zp); + + v_f32_low_half = _mm256_mul_ps(v_f32_low_half, v256_scale); + v_f32_high_half = _mm256_mul_ps(v_f32_high_half, v256_scale); + + __m256 first_half = _mm256_permute2f128_ps(v_f32_low_half, v_f32_high_half, 0x20); + auto idx1 = _mm256_set_epi32(7, 3, 6, 2, 5, 1, 4, 0); + first_half = _mm256_permutevar8x32_ps(first_half, idx1); + __m256 second_half = _mm256_permute2f128_ps(v_f32_low_half, v_f32_high_half, 0x31); + 
second_half = _mm256_permutevar8x32_ps(second_half, idx1); + + mm256_uni_storeu_ps(dst + i, first_half); + mm256_uni_storeu_ps(dst + i + vec_len_f32_avx2, second_half); + } #endif auto extract_half_byte = [&](uint8_t val, bool high_half) -> uint8_t { uint8_t shift = high_half ? 0 : 4; @@ -121,18 +143,17 @@ void attn_dequant_s4_kernel(const uint8_t* src, TDST* dst, size_t n, float scale */ size_t i = 0; uint8_t* src_nc = const_cast(src); +#if defined(HAVE_AVX512F) for (; i + vec_len_f32_avx512 * 2 <= n; i += vec_len_f32_avx512 * 2) { - auto high_half = _mm_loadu_si128(reinterpret_cast<__m128i*>(src_nc + i / 2)); - __m128i low_half = _mm_srli_epi16(high_half, 4); - const __m128i mask = _mm_set1_epi8(0x0F); - low_half = _mm_and_si128(mask, low_half); - auto v_scale = _mm512_set1_ps(1/scale); - //cvt to f32 - auto v_256_low_half = _mm512_cvtepi8_epi32(low_half); - auto v_256_high_half = _mm512_cvtepi8_epi32(high_half); - v_256_high_half = _mm512_slli_epi32(v_256_high_half, 28); - v_256_high_half = _mm512_srai_epi32(v_256_high_half, 28); + auto v_scale = _mm512_set1_ps(scale); + auto data = _mm_loadu_si128(reinterpret_cast<__m128i*>(src_nc + i / 2)); + // cvt to f32 + auto v_i32 = _mm512_cvtepi8_epi32(data); + + auto v_256_low_half = _mm512_srai_epi32(v_i32, 4); auto v_f32_low_half = _mm512_cvtepi32_ps(v_256_low_half); + auto v_256_high_half = _mm512_slli_epi32(v_i32, 28); + v_256_high_half = _mm512_srai_epi32(v_256_high_half, 28); auto v_f32_high_half = _mm512_cvtepi32_ps(v_256_high_half); // q * scale v_f32_low_half = _mm512_mul_ps(v_f32_low_half, v_scale); @@ -141,16 +162,43 @@ void attn_dequant_s4_kernel(const uint8_t* src, TDST* dst, size_t n, float scale __m512i idx1 = _mm512_set_epi32(23, 7, 22, 6, 21, 5, 20, 4, 19, 3, 18, 2, 17, 1, 16, 0); __m512i idx2 = _mm512_set_epi32(31, 15, 30, 14, 29, 13, 28, 12, 27, 11, 26, 10, 25, 9, 24, 8); __m512 first_half = _mm512_permutex2var_ps(v_f32_low_half, idx1, v_f32_high_half); - __m512 second_half = _mm512_permutex2var_ps(v_f32_low_half, idx2, v_f32_high_half); + __m512 second_half = _mm512_permutex2var_ps(v_f32_low_half, idx2, v_f32_high_half); mm512_uni_storeu_ps(dst + i, first_half); - mm512_uni_storeu_ps(dst + i + vec_len_f32_avx512, second_half); + mm512_uni_storeu_ps(dst + i + vec_len_f32_avx512, second_half); } - auto extract_half_byte = [&](uint8_t val, bool high_half) -> uint8_t { + + for (; i + vec_len_f32_avx2 * 2 <= n; i += vec_len_f32_avx2 * 2) { + auto v256_scale = _mm256_set1_ps(scale); + auto data = _mm_loadu_si64(reinterpret_cast<__m128i*>(src_nc + i / 2)); + + auto v_i32 = _mm256_cvtepi8_epi32(data); + auto v_256_low_half = _mm256_srai_epi32(v_i32, 4); + auto v_f32_low_half = _mm256_cvtepi32_ps(v_256_low_half); + + auto v_256_high_half = _mm256_slli_epi32(v_i32, 28); + v_256_high_half = _mm256_srai_epi32(v_256_high_half, 28); + auto v_f32_high_half = _mm256_cvtepi32_ps(v_256_high_half); + + // q * scale + v_f32_low_half = _mm256_mul_ps(v_f32_low_half, v256_scale); + v_f32_high_half = _mm256_mul_ps(v_f32_high_half, v256_scale); + + __m256 first_half = _mm256_permute2f128_ps(v_f32_low_half, v_f32_high_half, 0x20); + auto idx1 = _mm256_set_epi32(7, 3, 6, 2, 5, 1, 4, 0); + first_half = _mm256_permutevar8x32_ps(first_half, idx1); + __m256 second_half = _mm256_permute2f128_ps(v_f32_low_half, v_f32_high_half, 0x31); + second_half = _mm256_permutevar8x32_ps(second_half, idx1); + mm256_uni_storeu_ps(dst + i, first_half); + mm256_uni_storeu_ps(dst + i + vec_len_f32_avx2, second_half); + } +#endif + auto extract_half_byte = [&](uint8_t val, 
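// The slli(28)/srai(28) pair above sign-extends the low nibble; the equivalent scalar trick
// (shift the 4-bit value into the top of a signed byte, then arithmetic-shift it back) is
// sketched below. It assumes arithmetic right shift on signed types, which the vector path
// also relies on:
#include <cstdint>
#include <cstdio>
int main() {
    for (int nibble = 0; nibble < 16; ++nibble) {
        int8_t s4 = static_cast<int8_t>(nibble << 4) >> 4;  // maps 0..7 to 0..7 and 8..15 to -8..-1
        printf("raw %2d -> s4 %3d\n", nibble, s4);
    }
    return 0;
}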
bool high_half) -> int8_t { uint8_t shift = high_half ? 0 : 4; - return (int8_t) ((val >> shift) & 0x000F); + return float((val >> shift) & 0x000F); }; for (; i < n; ++i) { float tmp = extract_half_byte(src_nc[i / 2], (uint8_t)(i % 2)); + tmp = tmp > 8 ? (tmp - 16) : tmp; tmp = tmp * scale; dst[i] = tmp; } diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp index fe401a86653dcd..37887d426a1157 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp @@ -352,18 +352,17 @@ static void attn_acc_value_block(float* out, float* weight, uint8_t* v, size_t S size_t i = 0; auto attn_w_vec0 = _mm512_set1_ps(weight[j] * v0[0]); auto v_zp = _mm512_set1_ps(v0[1]); - for (; i + vec_len_f32_avx512 * 2 < _group_size; i += vec_len_f32_avx512 * 2) { - auto high_half = _mm_loadu_si128(reinterpret_cast<__m128i*>(v + i/2 + src_offset + params_offset)); - __m128i low_half = _mm_srli_epi16(high_half, 4); - const __m128i mask = _mm_set1_epi8(0x0F); - low_half = _mm_and_si128(mask, low_half); - high_half = _mm_and_si128(mask, high_half); + for (; i + vec_len_f32_avx512 * 2 <= _group_size; i += vec_len_f32_avx512 * 2) { + auto data = _mm_loadu_si128(reinterpret_cast<__m128i*>(v + i/2 + src_offset + params_offset)); + auto v_i32 = _mm512_cvtepu8_epi32(data); + + auto v_512_low_half = _mm512_srli_epi32(v_i32, 4); + auto v_f32_low_half = _mm512_cvtepi32_ps(v_512_low_half); + + auto mask = _mm512_set1_epi32(0x0F); + auto v_512_high_half = _mm512_and_si512(v_i32, mask); + auto v_f32_high_half = _mm512_cvtepi32_ps(v_512_high_half); - //cvt to f32 - auto v_256_low_half = _mm512_cvtepu8_epi32(low_half); - auto v_256_high_half = _mm512_cvtepu8_epi32(high_half); - auto v_f32_low_half = _mm512_cvtepi32_ps(v_256_low_half); - auto v_f32_high_half = _mm512_cvtepi32_ps(v_256_high_half); // q - zp v_f32_low_half = _mm512_sub_ps(v_f32_low_half, v_zp); v_f32_high_half = _mm512_sub_ps(v_f32_high_half, v_zp); @@ -380,6 +379,37 @@ static void attn_acc_value_block(float* out, float* weight, uint8_t* v, size_t S mm512_uni_storeu_ps(out + dst_offset + i + vec_len_f32_avx512, v_out1); } + auto v256_attn_w_vec0 = _mm256_set1_ps(weight[j] * v0[0]); + auto v256_zp = _mm256_set1_ps(v0[1]); + for (; i + vec_len_f32_avx2 * 2 <= _group_size; i += vec_len_f32_avx2 * 2) { + auto data = _mm_loadu_si64(reinterpret_cast<__m128i*>(v + i/2 + src_offset + params_offset)); + + auto v_i32 = _mm256_cvtepu8_epi32(data); + auto v_256_low_half = _mm256_srli_epi32(v_i32, 4); + auto v_f32_low_half = _mm256_cvtepi32_ps(v_256_low_half); + + auto mask = _mm256_set1_epi32(0x0F); + auto v_256_high_half = _mm256_and_si256(v_i32, mask); + auto v_f32_high_half = _mm256_cvtepi32_ps(v_256_high_half); + // q - zp + v_f32_low_half = _mm256_sub_ps(v_f32_low_half, v256_zp); + v_f32_high_half = _mm256_sub_ps(v_f32_high_half, v256_zp); + + auto v_out0 = mm256_uni_loadu_ps(out + dst_offset + i); + auto v_out1 = mm256_uni_loadu_ps(out + dst_offset + i + vec_len_f32_avx2); + + __m256 first_half = _mm256_permute2f128_ps(v_f32_low_half, v_f32_high_half, 0x20); + auto idx1 = _mm256_set_epi32(7, 3, 6, 2, 5, 1, 4, 0); + first_half = _mm256_permutevar8x32_ps(first_half, idx1); + __m256 second_half = _mm256_permute2f128_ps(v_f32_low_half, v_f32_high_half, 0x31); + second_half = _mm256_permutevar8x32_ps(second_half, idx1); + + v_out0 = _mm256_fmadd_ps(v256_attn_w_vec0, first_half, v_out0); + 
v_out1 = _mm256_fmadd_ps(v256_attn_w_vec0, second_half, v_out1); + mm256_uni_storeu_ps(out + dst_offset + i, v_out0); + mm256_uni_storeu_ps(out + dst_offset + i + vec_len_f32_avx2, v_out1); + } + for (; i < _group_size; i += 2) { uint8_t data = v[i/2 + src_offset + params_offset]; float tmp0 = extract_half_byte(data, (bool)(i % 2)); @@ -413,88 +443,110 @@ static void attn_acc_value_block(float* out, float* weight, uint8_t* v, size_t S } } -// template::type = true> -// static void attn_acc_value_block(float* out, float* weight, uint8_t* v, size_t S, size_t block_size, size_t group_size = 0) { -// size_t src_offset = 0; -// size_t dst_offset = 0; -// const size_t _group_size = group_size ? group_size : S; -// const size_t params_offset = sizeof(float) * 1; -// auto sub_byte_multiplyer = 8 / 4; -// const size_t src_stride = S / _group_size * (_group_size / sub_byte_multiplyer + params_offset); -// auto extract_half_byte = [](uint8_t val, bool high_half) -> uint8_t { -// uint8_t shift = high_half ? 0 : 4; - -// return (uint8_t) ((val >> shift) & 0x000F); -// }; -// // #if defined(HAVE_AVX512F) -// // for (size_t j = 0; j < block_size; j++) { -// // dst_offset = 0; -// // src_offset = 0; -// // while (dst_offset < S) { -// // auto v0 = reinterpret_cast(v + src_offset); -// // size_t i = 0; -// // auto attn_w_vec0 = _mm512_set1_ps(weight[j] * v0[0]); -// // auto v_zp = _mm512_set1_ps(v0[1]); -// // for (; i + vec_len_f32_avx512 * 2 < _group_size; i += vec_len_f32_avx512 * 2) { -// // auto high_half = _mm_loadu_si128(reinterpret_cast<__m128i*>(v + i/2 + src_offset + params_offset)); -// // __m128i low_half = _mm_srli_epi16(high_half, 4); -// // const __m128i mask = _mm_set1_epi8(0x0F); -// // low_half = _mm_and_si128(mask, low_half); -// // high_half = _mm_and_si128(mask, high_half); - -// // //cvt to f32 -// // auto v_256_low_half = _mm512_cvtepu8_epi32(low_half); -// // auto v_256_high_half = _mm512_cvtepu8_epi32(high_half); -// // auto v_f32_low_half = _mm512_cvtepi32_ps(v_256_low_half); -// // auto v_f32_high_half = _mm512_cvtepi32_ps(v_256_high_half); -// // // q - zp -// // v_f32_low_half = _mm512_sub_ps(v_f32_low_half, v_zp); -// // v_f32_high_half = _mm512_sub_ps(v_f32_high_half, v_zp); - -// // __m512i idx1 = _mm512_set_epi32(23, 7, 22, 6, 21, 5, 20, 4, 19, 3, 18, 2, 17, 1, 16, 0); -// // __m512i idx2 = _mm512_set_epi32(31, 15, 30, 14, 29, 13, 28, 12, 27, 11, 26, 10, 25, 9, 24, 8); -// // __m512 first_half = _mm512_permutex2var_ps(v_f32_low_half, idx1, v_f32_high_half); -// // __m512 second_half = _mm512_permutex2var_ps(v_f32_low_half, idx2, v_f32_high_half); -// // auto v_out0 = mm512_uni_loadu_ps(out + dst_offset + i); -// // auto v_out1 = mm512_uni_loadu_ps(out + dst_offset + i + vec_len_f32_avx512); -// // v_out0 = _mm512_fmadd_ps(attn_w_vec0, first_half, v_out0); -// // v_out1 = _mm512_fmadd_ps(attn_w_vec0, second_half, v_out1); -// // mm512_uni_storeu_ps(out + dst_offset + i, v_out0); -// // mm512_uni_storeu_ps(out + dst_offset + i + vec_len_f32_avx512, v_out1); -// // } - -// // for (; i < _group_size; i += 2) { -// // uint8_t data = v[i/2 + src_offset + params_offset]; -// // float tmp0 = extract_half_byte(data, (bool)(i % 2)); -// // float tmp1 = extract_half_byte(data, (bool)((i + 1) % 2)); -// // out[dst_offset + i] += weight[j] * (tmp0 - v0[1]) * v0[0]; -// // out[dst_offset + i + 1] += weight[j] * (tmp1 - v0[1]) * v0[0]; -// // } -// // dst_offset += _group_size; -// // src_offset += _group_size / sub_byte_multiplyer + params_offset; -// // } -// // v += src_stride; -// // 
} -// // return; -// // #endif -// for (size_t j = 0; j < block_size; j++) { -// dst_offset = 0; -// src_offset = 0; -// while (dst_offset < S) { -// auto v0 = reinterpret_cast(v + src_offset); -// for (size_t i = 0; i < _group_size; i += 2) { -// uint8_t data = v[i/2 + src_offset + params_offset]; -// float tmp0 = extract_half_byte(data, (bool)(i % 2)); -// float tmp1 = extract_half_byte(data, (bool)((i + 1) % 2)); -// out[dst_offset + i] += weight[j] * tmp0 * v0[0]; -// out[dst_offset + i + 1] += weight[j] * tmp1 * v0[0]; -// } -// dst_offset += _group_size; -// src_offset += _group_size / sub_byte_multiplyer + params_offset; -// } -// v += src_stride; -// } -// } +template::type = true> +static void attn_acc_value_block(float* out, float* weight, uint8_t* v, size_t S, size_t block_size, size_t group_size = 0) { + size_t src_offset = 0; + size_t dst_offset = 0; + const size_t _group_size = group_size ? group_size : S; + const size_t params_offset = sizeof(float); + auto sub_byte_multiplyer = 8 / 4; + const size_t src_stride = S / _group_size * (_group_size / sub_byte_multiplyer + params_offset); + auto extract_half_byte = [](uint8_t val, bool high_half) -> uint8_t { + uint8_t shift = high_half ? 0 : 4; + + return (uint8_t) ((val >> shift) & 0x000F); + }; +#if defined(HAVE_AVX512F) + for (size_t j = 0; j < block_size; j++) { + dst_offset = 0; + src_offset = 0; + while (dst_offset < S) { + auto v0 = reinterpret_cast(v + src_offset); + size_t i = 0; + auto attn_w_vec0 = _mm512_set1_ps(weight[j] * v0[0]); + for (; i + vec_len_f32_avx512 * 2 <= _group_size; i += vec_len_f32_avx512 * 2) { + auto data = _mm_loadu_si128(reinterpret_cast<__m128i*>(v + i/2 + src_offset + params_offset)); + auto v_i32 = _mm512_cvtepi8_epi32(data); + //cvt to f32 + auto v_256_low_half = _mm512_srai_epi32(v_i32, 4); + auto v_256_high_half = _mm512_slli_epi32(v_i32, 28); + v_256_high_half = _mm512_srai_epi32(v_256_high_half, 28); + + auto v_f32_low_half = _mm512_cvtepi32_ps(v_256_low_half); + auto v_f32_high_half = _mm512_cvtepi32_ps(v_256_high_half); + + __m512i idx1 = _mm512_set_epi32(23, 7, 22, 6, 21, 5, 20, 4, 19, 3, 18, 2, 17, 1, 16, 0); + __m512i idx2 = _mm512_set_epi32(31, 15, 30, 14, 29, 13, 28, 12, 27, 11, 26, 10, 25, 9, 24, 8); + __m512 first_half = _mm512_permutex2var_ps(v_f32_low_half, idx1, v_f32_high_half); + __m512 second_half = _mm512_permutex2var_ps(v_f32_low_half, idx2, v_f32_high_half); + auto v_out0 = mm512_uni_loadu_ps(out + dst_offset + i); + auto v_out1 = mm512_uni_loadu_ps(out + dst_offset + i + vec_len_f32_avx512); + v_out0 = _mm512_fmadd_ps(attn_w_vec0, first_half, v_out0); + v_out1 = _mm512_fmadd_ps(attn_w_vec0, second_half, v_out1); + mm512_uni_storeu_ps(out + dst_offset + i, v_out0); + mm512_uni_storeu_ps(out + dst_offset + i + vec_len_f32_avx512, v_out1); + } + auto v256_attn_w_vec0 = _mm256_set1_ps(weight[j] * v0[0]); + for (; i + vec_len_f32_avx2 * 2 <= _group_size; i += vec_len_f32_avx2 * 2) { + auto data = _mm_loadu_si64(reinterpret_cast<__m128i*>(v + i/2 + src_offset + params_offset)); + + auto v_i32 = _mm256_cvtepi8_epi32(data); + auto v_256_low_half = _mm256_srai_epi32(v_i32, 4); + auto v_f32_low_half = _mm256_cvtepi32_ps(v_256_low_half); + + auto v_256_high_half = _mm256_slli_epi32(v_i32, 28); + v_256_high_half = _mm256_srai_epi32(v_256_high_half, 28); + auto v_f32_high_half = _mm256_cvtepi32_ps(v_256_high_half); + + auto v_out0 = mm256_uni_loadu_ps(out + dst_offset + i); + auto v_out1 = mm256_uni_loadu_ps(out + dst_offset + i + vec_len_f32_avx2); + __m256 first_half = 
_mm256_permute2f128_ps(v_f32_low_half, v_f32_high_half, 0x20); + auto idx1 = _mm256_set_epi32(7, 3, 6, 2, 5, 1, 4, 0); + first_half = _mm256_permutevar8x32_ps(first_half, idx1); + __m256 second_half = _mm256_permute2f128_ps(v_f32_low_half, v_f32_high_half, 0x31); + second_half = _mm256_permutevar8x32_ps(second_half, idx1); + v_out0 = _mm256_fmadd_ps(v256_attn_w_vec0, first_half, v_out0); + v_out1 = _mm256_fmadd_ps(v256_attn_w_vec0, second_half, v_out1); + mm256_uni_storeu_ps(out + dst_offset + i, v_out0); + mm256_uni_storeu_ps(out + dst_offset + i + vec_len_f32_avx2, v_out1); + } + + for (; i < _group_size; i += 2) { + uint8_t data = v[i/2 + src_offset + params_offset]; + float tmp0 = extract_half_byte(data, (bool)(i % 2)); + tmp0 = tmp0 > 8 ? (tmp0 - 16) : tmp0; + float tmp1 = extract_half_byte(data, (bool)((i + 1) % 2)); + tmp1 = tmp1 > 8 ? (tmp1 - 16) : tmp1; + out[dst_offset + i] += weight[j] * (tmp0) * v0[0]; + out[dst_offset + i + 1] += weight[j] * (tmp1) * v0[0]; + } + dst_offset += _group_size; + src_offset += _group_size / sub_byte_multiplyer + params_offset; + } + v += src_stride; + } + return; +#endif + for (size_t j = 0; j < block_size; j++) { + dst_offset = 0; + src_offset = 0; + while (dst_offset < S) { + auto v0 = reinterpret_cast(v + src_offset); + for (size_t i = 0; i < _group_size; i += 2) { + uint8_t data = v[i/2 + src_offset + params_offset]; + float tmp0 = extract_half_byte(data, (bool)(i % 2)); + tmp0 = tmp0 > 8 ? (tmp0 - 16) : tmp0; + float tmp1 = extract_half_byte(data, (bool)((i + 1) % 2)); + tmp1 = tmp1 > 8 ? (tmp1 - 16) : tmp1; + out[dst_offset + i] += weight[j] * (tmp0) * v0[0]; + out[dst_offset + i + 1] += weight[j] * (tmp1) * v0[0]; + } + dst_offset += _group_size; + src_offset += _group_size / sub_byte_multiplyer + params_offset; + } + v += src_stride; + } +} template static void dot_product_block(TA* a, TB* b, float* c, size_t n, size_t block_size, size_t group_size = 0) { @@ -953,18 +1005,44 @@ void dequant(TDST* dst, uint8_t* src, size_t N, size_t K, size_t group_size = 0) auto s = src; const size_t params_offset = sizeof(float) * 2; const size_t _group_size = group_size ? group_size : K; - const size_t src_stride = K / _group_size * (_group_size + params_offset); + const size_t sub_byte_mulitplier = 2; for (size_t n = 0; n < N; n ++) { - size_t group_offset = 0; + size_t src_offset = 0; size_t dst_offset = 0; while (dst_offset < K) { - auto f = reinterpret_cast(s + group_offset); - attn_dequant_u8_kernel(s + group_offset + params_offset, dst + dst_offset, _group_size, f[0], f[1]); - group_offset += _group_size + params_offset; + // printf("dequant n %ld dst_offset %ld N %ld K %ldd group_size %ld\n", n, dst_offset, N, K, group_size); + auto f = reinterpret_cast(s + src_offset); + attn_dequant_u4_kernel(s + src_offset + params_offset, dst + dst_offset, _group_size, f[0], f[1]); + src_offset += _group_size / sub_byte_mulitplier + params_offset; dst_offset += _group_size; } - s += src_stride; + s += src_offset; + dst += K; + } +} + +template::type = true> +void dequant(TDST* dst, uint8_t* src, size_t N, size_t K, size_t group_size = 0) { + // The layout for per token per head: + // |scale(f32)|zeropoint(f32)|quantized feature(u8,idx_1)|quantized feature(u8,idx_2)|...|quantized feature(u8,idx_S)| + // The quantized feature will start from 8bytes=sizeof(float)+sizeof(float) + auto s = src; + const size_t params_offset = sizeof(float); + const size_t _group_size = group_size ? 
group_size : K; + const size_t sub_byte_mulitplier = 2; + + for (size_t n = 0; n < N; n ++) { + size_t src_offset = 0; + size_t dst_offset = 0; + while (dst_offset < K) { + // printf("dequant n %ld dst_offset %ld N %ld K %ldd group_size %ld\n", n, dst_offset, N, K, group_size); + auto f = reinterpret_cast(s + src_offset); + attn_dequant_s4_kernel(s + src_offset + params_offset, dst + dst_offset, _group_size, f[0]); + src_offset += _group_size / sub_byte_mulitplier + params_offset; + dst_offset += _group_size; + } + s += src_offset; dst += K; } } @@ -1070,7 +1148,7 @@ static void pack_32NxK(TDST* dst, void* src, TDST* tmp, size_t N, size_t K, size pack_32NxK::value>(dst, tmp, reinterpret_cast(0), N, K, dst_stride, src_stride); } -template::value != ov::element::f32 && SRC_PREC == ov::element::u4, bool>::type = true> +template::value != ov::element::f32 && (SRC_PREC == ov::element::u4), bool>::type = true> static void pack_32NxK(TDST* dst, void* src, TDST* tmp, size_t N, size_t K, size_t dst_stride, size_t src_stride, size_t group_size = 0) { // The layout for per token per head: // |scale(f32)|zeropoint(f32)|quantized feature(u8,idx_1)|quantized feature(u8,idx_2)|...|quantized feature(u8,idx_S)| @@ -1094,6 +1172,32 @@ static void pack_32NxK(TDST* dst, void* src, TDST* tmp, size_t N, size_t K, size } pack_32NxK::value>(dst, tmp, reinterpret_cast(0), N, K, dst_stride, src_stride); } + +template::value != ov::element::f32 && (SRC_PREC == ov::element::i4), bool>::type = true> +static void pack_32NxK(TDST* dst, void* src, TDST* tmp, size_t N, size_t K, size_t dst_stride, size_t src_stride, size_t group_size = 0) { + // The layout for per token per head: + // |scale(f32)|zeropoint(f32)|quantized feature(u8,idx_1)|quantized feature(u8,idx_2)|...|quantized feature(u8,idx_S)| + // The quantized feature will start from 8bytes=sizeof(float)+sizeof(float) + auto s = reinterpret_cast(src); + auto t = tmp; + // if group_size not set, the whole row is used as a group + const size_t sub_byte_mulitplier = 2; + size_t _group_size = group_size ? 
group_size : K; + printf("pack32| i4 N %ld K %ld\n", N, K); + for (size_t n = 0; n < N; n ++) { + size_t src_offset = 0; + size_t dst_offset = 0; + while (dst_offset < K) { + auto f = reinterpret_cast(s + src_offset); + attn_dequant_s4_kernel(s + (src_offset + sizeof(float)), t + dst_offset, _group_size, f[0]); + src_offset += _group_size / sub_byte_mulitplier + sizeof(float); + dst_offset += _group_size; + } + s += src_offset; + t += src_stride; + } + pack_32NxK::value>(dst, tmp, reinterpret_cast(0), N, K, dst_stride, src_stride); +} #endif template::value == ov::element::f32, bool>::type = true> @@ -1476,6 +1580,18 @@ struct MHAHelper { _SV, std::min(_block_size, cur_kv_len - pv), _value_group_size); + } else if (present_value.get_precision() == ov::element::i4) { + printf("exec_kernel_one_bh|attn_acc i4| shape %ld %ld %ld %ld\n", present_value.m_dims[0], present_value.m_dims[1], present_value.m_dims[2], present_value.m_dims[3]); + auto sub_byte_multiplyer = 8 / present_value.get_precision().bitwidth(); + size_t v_stride = (block_number * present_value.m_strides[0] + hk * present_value.m_strides[1]) / sub_byte_multiplyer; + auto* v_ptr = present_value.m_ptr.get() + v_stride; + attn_acc_value_block( + _output.ptr(ithr, pq, h), + _weight.ptr(ithr, h, pq) + pv, + v_ptr, + _SV, + std::min(_block_size, cur_kv_len - pv), + _value_group_size); } else { attn_acc_value_block::value>( _output.ptr(ithr, pq, h), @@ -1650,6 +1766,18 @@ struct MHAHelper { _SV, std::min(_block_size, context_len - pv), _value_group_size); + } else if (present_value.get_precision() == ov::element::i4) { + printf("exec_loop_bhl|attn_acc i4| shape %ld %ld %ld %ld\n", present_value.m_dims[0], present_value.m_dims[1], present_value.m_dims[2], present_value.m_dims[3]); + auto sub_byte_multiplyer = 8 / present_value.get_precision().bitwidth(); + size_t v_stride = (block_number * present_value.m_strides[0] + hk * present_value.m_strides[1]) / sub_byte_multiplyer; + auto* v_ptr = present_value.m_ptr.get() + v_stride; + attn_acc_value_block( + _output_bhl.ptr(ithr, b, pq, h), + _weight_bhl.ptr(b, h, pq) + pv, + v_ptr, + _SV, + std::min(_block_size, context_len - pv), + _value_group_size); } else { attn_acc_value_block::value>( _output_bhl.ptr(ithr, b, pq, h), @@ -1836,6 +1964,18 @@ struct MHA { rnd_up(_helper._SV, _helper._block_size), _helper._SV, _helper._value_group_size); + } else if (v_cache.get_precision() == ov::element::i4) { + auto sub_byte_multiplyer = 8 / v_cache.get_precision().bitwidth(); + size_t v_stride = (block_number * v_cache.m_strides[0] + hk * v_cache.m_strides[1]) / sub_byte_multiplyer; + auto* v_ptr = v_cache.m_ptr.get() + v_stride; + pack_32NxK(_helper._wv_scratch_b.template ptr(batch_in_reorder, kv_block, hk), + v_ptr, + _helper._output.template ptr(ithr), + _helper._block_size, + _helper._SV, + rnd_up(_helper._SV, _helper._block_size), + _helper._SV, + _helper._value_group_size); } else { pack_32NxK::value>(_helper._wv_scratch_b.template ptr(batch_in_reorder, kv_block, hk), v_ptr, @@ -1850,9 +1990,9 @@ struct MHA { // need to decompress if (!q_cache_is_same) { if (v_cache.get_precision() == ov::element::u4) { - printf("PaAttn|dequant value u4\n"); auto sub_byte_multiplyer = 8 / v_cache.get_precision().bitwidth(); - size_t v_stride = (block_number * v_cache.m_strides[0] + hk * v_cache.m_strides[1]) / sub_byte_multiplyer; + size_t v_stride = + (block_number * v_cache.m_strides[0] + hk * v_cache.m_strides[1]) / sub_byte_multiplyer; auto* v_ptr = v_cache.m_ptr.get() + v_stride; dequant( 
_helper._wv_scratch_b.template ptr(batch_in_reorder, kv_block, hk), @@ -1860,6 +2000,17 @@ struct MHA { _helper._block_size, _helper._SV, _helper._value_group_size); + } else if (v_cache.get_precision() == ov::element::i4) { + auto sub_byte_multiplyer = 8 / v_cache.get_precision().bitwidth(); + size_t v_stride = + (block_number * v_cache.m_strides[0] + hk * v_cache.m_strides[1]) / sub_byte_multiplyer; + auto* v_ptr = v_cache.m_ptr.get() + v_stride; + dequant( + _helper._wv_scratch_b.template ptr(batch_in_reorder, kv_block, hk), + v_ptr, + _helper._block_size, + _helper._SV, + _helper._value_group_size); } else { dequant::value>( _helper._wv_scratch_b.template ptr(batch_in_reorder, kv_block, hk), @@ -2027,7 +2178,9 @@ struct AttentionExecutor : public PagedAttentionExecutor { const size_t key_sub_byte_multiplyer = 8 / k_cache.get_precision().bitwidth(); const size_t value_sub_byte_multiplyer = 8 / v_cache.get_precision().bitwidth(); const size_t key_params_size = sizeof(float) * 2 * key_sub_byte_multiplyer; - const size_t value_params_size = sizeof(float) * 2 * value_sub_byte_multiplyer; + // u4 needs scale + zp. s4 needs scale. + const size_t param_size = one_of(v_cache.get_precision(), ov::element::u4, ov::element::u8) ? sizeof(float) * 2 : sizeof(float); + const size_t value_params_size = param_size * value_sub_byte_multiplyer; size_t key_group_num = _key_group_size ? k_cache.size(3) / (_key_group_size + key_params_size) : _key_group_size; size_t value_group_num = _value_group_size ? v_cache.size(3) / (_value_group_size + value_params_size) : _value_group_size; auto S = k_cache.size(3) - (k_cache.get_precision().is_real() ? 0 : key_params_size * key_group_num); From 3e821ea25e47edc0163d14c9ebf8be2d8fa39cce Mon Sep 17 00:00:00 2001 From: Zhang Yi Date: Fri, 29 Nov 2024 16:16:41 +0800 Subject: [PATCH 08/28] [CPU]use AVX512 to quant s4 --- .../nodes/kernels/scaled_attn/attn_quant.cpp | 32 +++++++++++++++++-- 1 file changed, 30 insertions(+), 2 deletions(-) diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant.cpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant.cpp index 67a7ba55160728..a8fd91eb67e64a 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant.cpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant.cpp @@ -262,7 +262,7 @@ static void quant_u4(const T* src, void* dst, size_t n, float& scale, float& zp) auto mask = _mm512_set1_epi32(0x0F); second_half = _mm512_and_epi32(second_half, mask); auto combined = _mm512_or_epi32(first_half, second_half); - _mm512_mask_cvtusepi32_storeu_epi8(dst_ptr + i / 2, 0xffff, combined); + _mm512_mask_cvtepi32_storeu_epi8(dst_ptr + i / 2, 0xffff, combined); } #endif for (; i < n; i++) { @@ -298,9 +298,37 @@ static void quant_s4(const T* src, void* dst, size_t n, float& scale) { if (scale == 0) scale = 0.0001f; i = 0; +#if defined(HAVE_AVX512F) + auto v_scale = _mm512_set1_ps(1 / scale); + auto v_upper = _mm512_set1_epi32(7.0); + auto v_lower = _mm512_set1_epi32(-8.0); + for (; i + 2 * vec_len_f32_avx512 <= n; i += 2 * vec_len_f32_avx512) { + auto v0 = mm512_uni_loadu_ps(src + i); + auto v1 = mm512_uni_loadu_ps(src + i + vec_len_f32_avx512); + v0 = _mm512_mul_ps(v0, v_scale); + v1 = _mm512_mul_ps(v1, v_scale); + auto v0_i32 = _mm512_cvt_roundps_epi32(v0, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC); + auto v1_i32 = _mm512_cvt_roundps_epi32(v1, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC); + + v0_i32 = _mm512_max_epi32(v0_i32, v_lower); + v1_i32 = _mm512_max_epi32(v1_i32, 
v_lower); + v0_i32 = _mm512_min_epi32(v0_i32, v_upper); + v1_i32 = _mm512_min_epi32(v1_i32, v_upper); + + __m512i idx1 = _mm512_set_epi32(30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0); + __m512i idx2 = _mm512_set_epi32(31, 29, 27, 25, 23, 21, 19, 17, 15, 13, 11, 9, 7, 5, 3, 1); + auto first_half = _mm512_permutex2var_epi32(v0_i32, idx1, v1_i32); + auto second_half = _mm512_permutex2var_epi32(v0_i32, idx2, v1_i32); + + auto mask = _mm512_set1_epi32(0x0F); + second_half = _mm512_and_epi32(second_half, mask); + first_half = _mm512_slli_epi32(first_half, 4); + auto combined = _mm512_or_epi32(first_half, second_half); + _mm512_mask_cvtepi32_storeu_epi8(dst_ptr + i / 2, 0xffff, combined); + } +#endif for (; i < n; i++) { float tmp = src[i]; - // add 8.5 here is to save a clamp to (-2^3) float temp1 = std::round(tmp / scale); float temp2 = (int8_t)(tmp / scale); int8_t src_val = std::min((int8_t)(7), (int8_t)std::round(tmp / scale)); From 80b093f59279f6983fbd674cf9bc1dce65176b3d Mon Sep 17 00:00:00 2001 From: Zhang Yi Date: Thu, 5 Dec 2024 16:08:09 +0800 Subject: [PATCH 09/28] [CPU]4-bit quantization with avx2 Signed-off-by: yi3.zhang@intel.com --- .../nodes/kernels/scaled_attn/attn_quant.cpp | 176 +++++++++++------- .../kernels/scaled_attn/attn_quant_kernel.hpp | 12 ++ .../nodes/kernels/scaled_attn/executor_pa.cpp | 58 +----- 3 files changed, 123 insertions(+), 123 deletions(-) diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant.cpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant.cpp index a8fd91eb67e64a..80092fdfc7e01f 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant.cpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant.cpp @@ -27,10 +27,10 @@ namespace XARCH { using namespace ov; template -static void quant_u8(const T* src, uint8_t* dst, size_t n, float& scale, float& zp) { +static void find_minmax(const T* src, size_t n, float& min, float& max) { + max = -FLT_MAX; + min = FLT_MAX; size_t i = 0; - float max = -FLT_MAX; - float min = FLT_MAX; #if defined(HAVE_AVX512F) auto v0_max = _mm512_set1_ps(-FLT_MAX); auto v0_min = _mm512_set1_ps(FLT_MAX); @@ -131,12 +131,19 @@ static void quant_u8(const T* src, uint8_t* dst, size_t n, float& scale, float& max = std::max(max, tmp); min = std::min(min, tmp); } + +} + +template +static void quant_u8(const T* src, uint8_t* dst, size_t n, float& scale, float& zp) { + size_t i = 0; + float max = -FLT_MAX; + float min = FLT_MAX; + find_minmax(src, n, min, max); scale = (max - min) / 255; if (scale == 0) scale = 0.0001f; zp = -min / scale; - - i = 0; #if defined(HAVE_AVX512F) auto v_scale = _mm512_set1_ps(1 / scale); auto v_zp = _mm512_set1_ps(zp); @@ -175,58 +182,7 @@ static void quant_u4(const T* src, void* dst, size_t n, float& scale, float& zp) size_t i = 0; float max = -FLT_MAX; float min = FLT_MAX; -#if defined(HAVE_AVX512F) - auto v0_max = _mm512_set1_ps(-FLT_MAX); - auto v0_min = _mm512_set1_ps(FLT_MAX); - auto v1_max = _mm512_set1_ps(-FLT_MAX); - auto v1_min = _mm512_set1_ps(FLT_MAX); - auto v2_max = _mm512_set1_ps(-FLT_MAX); - auto v2_min = _mm512_set1_ps(FLT_MAX); - auto v3_max = _mm512_set1_ps(-FLT_MAX); - auto v3_min = _mm512_set1_ps(FLT_MAX); - for (; i + 4 * vec_len_f32_avx512 <= n; i += vec_len_f32_avx512 * 4) { - auto v0 = mm512_uni_loadu_ps(src + i); - auto v1 = mm512_uni_loadu_ps(src + i + vec_len_f32_avx512); - auto v2 = mm512_uni_loadu_ps(src + i + 2 * vec_len_f32_avx512); - auto v3 = mm512_uni_loadu_ps(src + i + 3 * vec_len_f32_avx512); - 
v0_max = _mm512_max_ps(v0_max, v0); - v0_min = _mm512_min_ps(v0_min, v0); - v1_max = _mm512_max_ps(v1_max, v1); - v1_min = _mm512_min_ps(v1_min, v1); - v2_max = _mm512_max_ps(v2_max, v2); - v2_min = _mm512_min_ps(v2_min, v2); - v3_max = _mm512_max_ps(v3_max, v3); - v3_min = _mm512_min_ps(v3_min, v3); - } - if (i + 2 * vec_len_f32_avx512 <= n) { - auto v0 = mm512_uni_loadu_ps(src + i); - auto v1 = mm512_uni_loadu_ps(src + i + vec_len_f32_avx512); - v0_max = _mm512_max_ps(v0_max, v0); - v0_min = _mm512_min_ps(v0_min, v0); - v1_max = _mm512_max_ps(v1_max, v1); - v1_min = _mm512_min_ps(v1_min, v1); - i += 2 * vec_len_f32_avx512; - } - if (i + vec_len_f32_avx512 <= n) { - auto v0 = mm512_uni_loadu_ps(src + i); - v0_max = _mm512_max_ps(v0_max, v0); - v0_min = _mm512_min_ps(v0_min, v0); - i += vec_len_f32_avx512; - } - v0_max = _mm512_max_ps(v0_max, v1_max); - v0_min = _mm512_min_ps(v0_min, v1_min); - v2_max = _mm512_max_ps(v2_max, v3_max); - v2_min = _mm512_min_ps(v2_min, v3_min); - v0_max = _mm512_max_ps(v0_max, v2_max); - v0_min = _mm512_min_ps(v0_min, v2_min); - max = _mm512_reduce_max_ps(v0_max); - min = _mm512_reduce_min_ps(v0_min); -#endif - for (; i < n; i++) { - float tmp = static_cast(src[i]); - max = std::max(max, tmp); - min = std::min(min, tmp); - } + find_minmax(src, n, min, max); auto insert_half_byte = [](uint8_t dst, uint8_t val, bool high_half) -> uint8_t { uint8_t shift = high_half ? 0 : 4; return dst | (uint8_t) (val << shift); @@ -236,13 +192,11 @@ static void quant_u4(const T* src, void* dst, size_t n, float& scale, float& zp) if (scale == 0) scale = 0.0001f; zp = -min / scale; - i = 0; #if defined(HAVE_AVX512F) auto v_scale = _mm512_set1_ps(1 / scale); auto v_zp = _mm512_set1_ps(zp); auto v_zero = _mm512_setzero_epi32(); auto v_upper = _mm512_set1_epi32(15); - printf("total size %ld avx512 size %ld\n", n, vec_len_f32_avx512); for (; i + 2 * vec_len_f32_avx512 <= n; i += 2 * vec_len_f32_avx512) { auto v0 = mm512_uni_loadu_ps(src + i); auto v1 = mm512_uni_loadu_ps(src + i + vec_len_f32_avx512); @@ -264,6 +218,50 @@ static void quant_u4(const T* src, void* dst, size_t n, float& scale, float& zp) auto combined = _mm512_or_epi32(first_half, second_half); _mm512_mask_cvtepi32_storeu_epi8(dst_ptr + i / 2, 0xffff, combined); } +#endif +#if defined(HAVE_AVX2) || defined(HAVE_AVX512F) + auto v256_zero = _mm256_set1_epi32(0); + auto v256_upper = _mm256_set1_epi32(15); + auto v256_scale = _mm256_set1_ps(1 / scale); + auto v256_zp = _mm256_set1_ps(zp); + for (; i + vec_len_f32_avx2 * 2 <= n; i += vec_len_f32_avx2 * 2) { + auto v0 = mm256_uni_loadu_ps(src + i); + auto v1 = mm256_uni_loadu_ps(src + i + vec_len_f32_avx2); + v0 = _mm256_fmadd_ps(v0, v256_scale, v256_zp); + v1 = _mm256_fmadd_ps(v1, v256_scale, v256_zp); + v0 = _mm256_round_ps(v0, _MM_ROUND_NEAREST); + v1 = _mm256_round_ps(v1, _MM_ROUND_NEAREST); + + auto v0_i32 = _mm256_cvtps_epi32(v0); + auto v1_i32 = _mm256_cvtps_epi32(v1); + v0_i32 = _mm256_max_epi32(v0_i32, v256_zero); + v1_i32 = _mm256_max_epi32(v1_i32, v256_zero); + v0_i32 = _mm256_min_epi32(v0_i32, v256_upper); + v1_i32 = _mm256_min_epi32(v1_i32, v256_upper); + auto idx1 = _mm256_set_epi32(7, 5, 3, 1, 6, 4, 2, 0); + v0_i32 = _mm256_permutevar8x32_epi32(v0_i32, idx1); + v1_i32 = _mm256_permutevar8x32_epi32(v1_i32, idx1); + // 0,1,2,3,4,5,6,7 | 8,9,10,11,12,13,14,15 + // _mm256_permutevar8x32_epi32 + // 0,2,4,6,1,3,5,7 | 8,10,12,14,9,11,13,15 + // _mm256_permute2x128_si256 + // 0,2,4,6,8,10,12,14 | 1,3,5,7,9,11,13,15 + // shift + mask + or + // [0,1],[2,3], ..., 
[12,13], [14,15] + auto first_half = _mm256_permute2x128_si256(v0_i32, v1_i32, 0x20); + auto second_half = _mm256_permute2x128_si256(v0_i32, v1_i32, 0x31); + first_half = _mm256_slli_epi32(first_half, 4); + auto mask = _mm256_set1_epi32(0x0F); + second_half = _mm256_and_si256(second_half, mask); + auto combined = _mm256_or_si256(first_half, second_half); + + auto high4 = _mm256_extractf128_si256(combined, 1); + auto low4 = _mm256_castsi256_si128(combined); + // ignore sign bit for u4 case + auto packed = _mm_packus_epi32(low4, high4); + packed = _mm_packus_epi16(packed, packed); + _mm_storel_epi64(reinterpret_cast<__m128i*>(dst_ptr + i / 2), packed); + } #endif for (; i < n; i++) { float tmp = src[i]; @@ -287,21 +285,15 @@ static void quant_s4(const T* src, void* dst, size_t n, float& scale) { size_t i = 0; float max = -FLT_MAX; float min = FLT_MAX; - for (; i < n; i++) { - float tmp = src[i]; - max = std::max(max, tmp); - min = std::min(min, tmp); - } + find_minmax(src, n, min, max); float max_abs = std::max(std::abs(min), std::abs(max)); scale = max_abs / ((1 << 3) - 1); - printf("max %f min %f scale %f\n", min, max, scale); if (scale == 0) scale = 0.0001f; - i = 0; #if defined(HAVE_AVX512F) auto v_scale = _mm512_set1_ps(1 / scale); - auto v_upper = _mm512_set1_epi32(7.0); - auto v_lower = _mm512_set1_epi32(-8.0); + auto v_upper = _mm512_set1_epi32(7); + auto v_lower = _mm512_set1_epi32(-8); for (; i + 2 * vec_len_f32_avx512 <= n; i += 2 * vec_len_f32_avx512) { auto v0 = mm512_uni_loadu_ps(src + i); auto v1 = mm512_uni_loadu_ps(src + i + vec_len_f32_avx512); @@ -326,17 +318,57 @@ static void quant_s4(const T* src, void* dst, size_t n, float& scale) { auto combined = _mm512_or_epi32(first_half, second_half); _mm512_mask_cvtepi32_storeu_epi8(dst_ptr + i / 2, 0xffff, combined); } +#endif +#if defined(HAVE_AVX2) || defined(HAVE_AVX512F) + auto v256_lower = _mm256_set1_epi32(-8); + auto v256_upper = _mm256_set1_epi32(7); + auto v256_scale = _mm256_set1_ps(1 / scale); + for (; i + vec_len_f32_avx2 * 2 <= n; i += vec_len_f32_avx2 * 2) { + auto v0 = mm256_uni_loadu_ps(src + i); + auto v1 = mm256_uni_loadu_ps(src + i + vec_len_f32_avx2); + v0 = _mm256_mul_ps(v0, v256_scale); + v1 = _mm256_mul_ps(v1, v256_scale); + v0 = _mm256_round_ps(v0, _MM_ROUND_NEAREST); + v1 = _mm256_round_ps(v1, _MM_ROUND_NEAREST); + + auto v0_i32 = _mm256_cvtps_epi32(v0); + auto v1_i32 = _mm256_cvtps_epi32(v1); + v0_i32 = _mm256_max_epi32(v0_i32, v256_lower); + v1_i32 = _mm256_max_epi32(v1_i32, v256_lower); + v0_i32 = _mm256_min_epi32(v0_i32, v256_upper); + v1_i32 = _mm256_min_epi32(v1_i32, v256_upper); + auto idx1 = _mm256_set_epi32(7, 5, 3, 1, 6, 4, 2, 0); + v0_i32 = _mm256_permutevar8x32_epi32(v0_i32, idx1); + v1_i32 = _mm256_permutevar8x32_epi32(v1_i32, idx1); + + // 0,1,2,3,4,5,6,7 | 8,9,10,11,12,13,14,15 + // _mm256_permutevar8x32_epi32 + // 0,2,4,6,1,3,5,7 | 8,10,12,14,9,11,13,15 + // _mm256_permute2x128_si256 + // 0,2,4,6,8,10,12,14 | 1,3,5,7,9,11,13,15 + // shift + mask + or + // [0,1],[2,3], ..., [12,13], [14,15] + auto first_half = _mm256_permute2x128_si256(v0_i32, v1_i32, 0x20); + auto second_half = _mm256_permute2x128_si256(v0_i32, v1_i32, 0x31); + first_half = _mm256_slli_epi32(first_half, 4); + auto mask = _mm256_set1_epi32(0x0F); + second_half = _mm256_and_si256(second_half, mask); + auto combined = _mm256_or_si256(first_half, second_half); + + auto high4 = _mm256_extractf128_si256(combined, 1); + auto low4 = _mm256_castsi256_si128(combined); + // keep sign bit for s4 case + auto packed = 
_mm_packs_epi32(low4, high4); + packed = _mm_packs_epi16(packed, packed); + _mm_storel_epi64(reinterpret_cast<__m128i*>(dst_ptr + i / 2), packed); + } #endif for (; i < n; i++) { float tmp = src[i]; - float temp1 = std::round(tmp / scale); - float temp2 = (int8_t)(tmp / scale); int8_t src_val = std::min((int8_t)(7), (int8_t)std::round(tmp / scale)); src_val = std::max((int8_t)(-8), src_val); uint8_t dst_val = i % 2 == 0 ? 0 : dst_ptr[i / 2]; dst_val = insert_half_byte(dst_val, src_val, (uint8_t)(i % 2)); - if (i < 16) - printf("index %ld float %f temp1 %f temp2 %f src %d hex %x\n", i, tmp, temp1, temp2, src_val, dst_val); dst_ptr[i / 2] = dst_val; } } diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant_kernel.hpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant_kernel.hpp index 174e0d8b5804dc..ba98f1893607ee 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant_kernel.hpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant_kernel.hpp @@ -90,6 +90,7 @@ void attn_dequant_u4_kernel(const uint8_t* src, TDST* dst, size_t n, float scale mm512_uni_storeu_ps(dst + i, first_half); mm512_uni_storeu_ps(dst + i + vec_len_f32_avx512, second_half); } +#elif defined(HAVE_AVX2) || defined(HAVE_AVX512F) auto v256_zp = _mm256_set1_ps(zp); auto v256_scale = _mm256_set1_ps(scale); for (; i + vec_len_f32_avx2 * 2 <= n; i += vec_len_f32_avx2 * 2) { @@ -109,6 +110,11 @@ void attn_dequant_u4_kernel(const uint8_t* src, TDST* dst, size_t n, float scale v_f32_low_half = _mm256_mul_ps(v_f32_low_half, v256_scale); v_f32_high_half = _mm256_mul_ps(v_f32_high_half, v256_scale); + // 0,2,4,6,8,10,12,14 | 1,3,5,7,9,11,13,15 + // _mm256_permute2f128_ps + // 0,2,4,6,1,3,5,7 | 8,10,12,14,9,11,13,15 + // _mm256_permutevar8x32_ps + // 0,1,2,3,4,5,6,7 | 8,9,10,11,12,13,14,15 __m256 first_half = _mm256_permute2f128_ps(v_f32_low_half, v_f32_high_half, 0x20); auto idx1 = _mm256_set_epi32(7, 3, 6, 2, 5, 1, 4, 0); first_half = _mm256_permutevar8x32_ps(first_half, idx1); @@ -167,6 +173,7 @@ void attn_dequant_s4_kernel(const uint8_t* src, TDST* dst, size_t n, float scale mm512_uni_storeu_ps(dst + i + vec_len_f32_avx512, second_half); } +#elif defined(HAVE_AVX2) || defined(HAVE_AVX512F) for (; i + vec_len_f32_avx2 * 2 <= n; i += vec_len_f32_avx2 * 2) { auto v256_scale = _mm256_set1_ps(scale); auto data = _mm_loadu_si64(reinterpret_cast<__m128i*>(src_nc + i / 2)); @@ -183,6 +190,11 @@ void attn_dequant_s4_kernel(const uint8_t* src, TDST* dst, size_t n, float scale v_f32_low_half = _mm256_mul_ps(v_f32_low_half, v256_scale); v_f32_high_half = _mm256_mul_ps(v_f32_high_half, v256_scale); + // 0,2,4,6,8,10,12,14 | 1,3,5,7,9,11,13,15 + // _mm256_permute2f128_ps + // 0,2,4,6,1,3,5,7 | 8,10,12,14,9,11,13,15 + // _mm256_permutevar8x32_ps + // 0,1,2,3,4,5,6,7 | 8,9,10,11,12,13,14,15 __m256 first_half = _mm256_permute2f128_ps(v_f32_low_half, v_f32_high_half, 0x20); auto idx1 = _mm256_set_epi32(7, 3, 6, 2, 5, 1, 4, 0); first_half = _mm256_permutevar8x32_ps(first_half, idx1); diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp index 37887d426a1157..ecf69cdf6a3d0c 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp @@ -266,7 +266,6 @@ static void attn_acc_value_block(float* out, float* weight, uint8_t* v, size_t S auto attn_w_vec0 = _mm512_set1_ps(weight[0] * v_f0[0]); auto 
zp0 = _mm512_set1_ps(v_f0[1]); size_t i = 0; - // printf("j %d dst_offset %d src_offset %ld src_stride %ld scale %f zp %f vec_len_f32_avx512 %ld _group_size %ld\n", j, dst_offset, src_offset, src_stride, v_f0[0], v_f0[1], vec_len_f32_avx512, _group_size); for (; i + vec_len_f32_avx512 <= _group_size; i += vec_len_f32_avx512) { auto v_out = mm512_uni_loadu_ps((out + dst_offset + i)); auto v0 = _mm512_sub_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(_mm_loadu_si128(reinterpret_cast<__m128i*>(v_data_ptr + i)))), zp0); @@ -343,13 +342,13 @@ static void attn_acc_value_block(float* out, float* weight, uint8_t* v, size_t S return (uint8_t) ((val >> shift) & 0x000F); }; -#if defined(HAVE_AVX512F) for (size_t j = 0; j < block_size; j++) { dst_offset = 0; src_offset = 0; while (dst_offset < S) { auto v0 = reinterpret_cast(v + src_offset); size_t i = 0; +#if defined(HAVE_AVX512F) auto attn_w_vec0 = _mm512_set1_ps(weight[j] * v0[0]); auto v_zp = _mm512_set1_ps(v0[1]); for (; i + vec_len_f32_avx512 * 2 <= _group_size; i += vec_len_f32_avx512 * 2) { @@ -378,7 +377,7 @@ static void attn_acc_value_block(float* out, float* weight, uint8_t* v, size_t S mm512_uni_storeu_ps(out + dst_offset + i, v_out0); mm512_uni_storeu_ps(out + dst_offset + i + vec_len_f32_avx512, v_out1); } - +#elif defined(HAVE_AVX2) || defined(HAVE_AVX512F) auto v256_attn_w_vec0 = _mm256_set1_ps(weight[j] * v0[0]); auto v256_zp = _mm256_set1_ps(v0[1]); for (; i + vec_len_f32_avx2 * 2 <= _group_size; i += vec_len_f32_avx2 * 2) { @@ -409,27 +408,8 @@ static void attn_acc_value_block(float* out, float* weight, uint8_t* v, size_t S mm256_uni_storeu_ps(out + dst_offset + i, v_out0); mm256_uni_storeu_ps(out + dst_offset + i + vec_len_f32_avx2, v_out1); } - - for (; i < _group_size; i += 2) { - uint8_t data = v[i/2 + src_offset + params_offset]; - float tmp0 = extract_half_byte(data, (bool)(i % 2)); - float tmp1 = extract_half_byte(data, (bool)((i + 1) % 2)); - out[dst_offset + i] += weight[j] * (tmp0 - v0[1]) * v0[0]; - out[dst_offset + i + 1] += weight[j] * (tmp1 - v0[1]) * v0[0]; - } - dst_offset += _group_size; - src_offset += _group_size / sub_byte_multiplyer + params_offset; - } - v += src_stride; - } - return; #endif - for (size_t j = 0; j < block_size; j++) { - dst_offset = 0; - src_offset = 0; - while (dst_offset < S) { - auto v0 = reinterpret_cast(v + src_offset); - for (size_t i = 0; i < _group_size; i += 2) { + for (; i < _group_size; i += 2) { uint8_t data = v[i/2 + src_offset + params_offset]; float tmp0 = extract_half_byte(data, (bool)(i % 2)); float tmp1 = extract_half_byte(data, (bool)((i + 1) % 2)); @@ -456,13 +436,14 @@ static void attn_acc_value_block(float* out, float* weight, uint8_t* v, size_t S return (uint8_t) ((val >> shift) & 0x000F); }; -#if defined(HAVE_AVX512F) + for (size_t j = 0; j < block_size; j++) { dst_offset = 0; src_offset = 0; while (dst_offset < S) { auto v0 = reinterpret_cast(v + src_offset); size_t i = 0; +#if defined(HAVE_AVX512F) auto attn_w_vec0 = _mm512_set1_ps(weight[j] * v0[0]); for (; i + vec_len_f32_avx512 * 2 <= _group_size; i += vec_len_f32_avx512 * 2) { auto data = _mm_loadu_si128(reinterpret_cast<__m128i*>(v + i/2 + src_offset + params_offset)); @@ -486,6 +467,7 @@ static void attn_acc_value_block(float* out, float* weight, uint8_t* v, size_t S mm512_uni_storeu_ps(out + dst_offset + i, v_out0); mm512_uni_storeu_ps(out + dst_offset + i + vec_len_f32_avx512, v_out1); } +#elif defined(HAVE_AVX2) || defined(HAVE_AVX512F) auto v256_attn_w_vec0 = _mm256_set1_ps(weight[j] * v0[0]); for (; i + 
vec_len_f32_avx2 * 2 <= _group_size; i += vec_len_f32_avx2 * 2) { auto data = _mm_loadu_si64(reinterpret_cast<__m128i*>(v + i/2 + src_offset + params_offset)); @@ -510,29 +492,8 @@ static void attn_acc_value_block(float* out, float* weight, uint8_t* v, size_t S mm256_uni_storeu_ps(out + dst_offset + i, v_out0); mm256_uni_storeu_ps(out + dst_offset + i + vec_len_f32_avx2, v_out1); } - - for (; i < _group_size; i += 2) { - uint8_t data = v[i/2 + src_offset + params_offset]; - float tmp0 = extract_half_byte(data, (bool)(i % 2)); - tmp0 = tmp0 > 8 ? (tmp0 - 16) : tmp0; - float tmp1 = extract_half_byte(data, (bool)((i + 1) % 2)); - tmp1 = tmp1 > 8 ? (tmp1 - 16) : tmp1; - out[dst_offset + i] += weight[j] * (tmp0) * v0[0]; - out[dst_offset + i + 1] += weight[j] * (tmp1) * v0[0]; - } - dst_offset += _group_size; - src_offset += _group_size / sub_byte_multiplyer + params_offset; - } - v += src_stride; - } - return; #endif - for (size_t j = 0; j < block_size; j++) { - dst_offset = 0; - src_offset = 0; - while (dst_offset < S) { - auto v0 = reinterpret_cast(v + src_offset); - for (size_t i = 0; i < _group_size; i += 2) { + for (; i < _group_size; i += 2) { uint8_t data = v[i/2 + src_offset + params_offset]; float tmp0 = extract_half_byte(data, (bool)(i % 2)); tmp0 = tmp0 > 8 ? (tmp0 - 16) : tmp0; @@ -1011,7 +972,6 @@ void dequant(TDST* dst, uint8_t* src, size_t N, size_t K, size_t group_size = 0) size_t src_offset = 0; size_t dst_offset = 0; while (dst_offset < K) { - // printf("dequant n %ld dst_offset %ld N %ld K %ldd group_size %ld\n", n, dst_offset, N, K, group_size); auto f = reinterpret_cast(s + src_offset); attn_dequant_u4_kernel(s + src_offset + params_offset, dst + dst_offset, _group_size, f[0], f[1]); src_offset += _group_size / sub_byte_mulitplier + params_offset; @@ -1036,7 +996,6 @@ void dequant(TDST* dst, uint8_t* src, size_t N, size_t K, size_t group_size = 0) size_t src_offset = 0; size_t dst_offset = 0; while (dst_offset < K) { - // printf("dequant n %ld dst_offset %ld N %ld K %ldd group_size %ld\n", n, dst_offset, N, K, group_size); auto f = reinterpret_cast(s + src_offset); attn_dequant_s4_kernel(s + src_offset + params_offset, dst + dst_offset, _group_size, f[0]); src_offset += _group_size / sub_byte_mulitplier + params_offset; @@ -1183,7 +1142,6 @@ static void pack_32NxK(TDST* dst, void* src, TDST* tmp, size_t N, size_t K, size // if group_size not set, the whole row is used as a group const size_t sub_byte_mulitplier = 2; size_t _group_size = group_size ? 
group_size : K; - printf("pack32| i4 N %ld K %ld\n", N, K); for (size_t n = 0; n < N; n ++) { size_t src_offset = 0; size_t dst_offset = 0; @@ -1581,7 +1539,6 @@ struct MHAHelper { std::min(_block_size, cur_kv_len - pv), _value_group_size); } else if (present_value.get_precision() == ov::element::i4) { - printf("exec_kernel_one_bh|attn_acc i4| shape %ld %ld %ld %ld\n", present_value.m_dims[0], present_value.m_dims[1], present_value.m_dims[2], present_value.m_dims[3]); auto sub_byte_multiplyer = 8 / present_value.get_precision().bitwidth(); size_t v_stride = (block_number * present_value.m_strides[0] + hk * present_value.m_strides[1]) / sub_byte_multiplyer; auto* v_ptr = present_value.m_ptr.get() + v_stride; @@ -1767,7 +1724,6 @@ struct MHAHelper { std::min(_block_size, context_len - pv), _value_group_size); } else if (present_value.get_precision() == ov::element::i4) { - printf("exec_loop_bhl|attn_acc i4| shape %ld %ld %ld %ld\n", present_value.m_dims[0], present_value.m_dims[1], present_value.m_dims[2], present_value.m_dims[3]); auto sub_byte_multiplyer = 8 / present_value.get_precision().bitwidth(); size_t v_stride = (block_number * present_value.m_strides[0] + hk * present_value.m_strides[1]) / sub_byte_multiplyer; auto* v_ptr = present_value.m_ptr.get() + v_stride; From 13a496e17a896ffafef55c16b78b10f5bbf373e6 Mon Sep 17 00:00:00 2001 From: Zhang Yi Date: Fri, 6 Dec 2024 09:56:38 +0800 Subject: [PATCH 10/28] fix build on elder compiler --- .../src/nodes/kernels/scaled_attn/attn_quant_kernel.hpp | 4 ++-- .../intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant_kernel.hpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant_kernel.hpp index ba98f1893607ee..120e14ab7da5df 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant_kernel.hpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant_kernel.hpp @@ -94,7 +94,7 @@ void attn_dequant_u4_kernel(const uint8_t* src, TDST* dst, size_t n, float scale auto v256_zp = _mm256_set1_ps(zp); auto v256_scale = _mm256_set1_ps(scale); for (; i + vec_len_f32_avx2 * 2 <= n; i += vec_len_f32_avx2 * 2) { - auto data = _mm_loadu_si64(reinterpret_cast<__m128i*>(src_nc + i/2)); + auto data = _mm_loadl_epi64(reinterpret_cast<__m128i*>(src_nc + i/2)); auto v_i32 = _mm256_cvtepu8_epi32(data); auto v_256_low_half = _mm256_srli_epi32(v_i32, 4); @@ -176,7 +176,7 @@ void attn_dequant_s4_kernel(const uint8_t* src, TDST* dst, size_t n, float scale #elif defined(HAVE_AVX2) || defined(HAVE_AVX512F) for (; i + vec_len_f32_avx2 * 2 <= n; i += vec_len_f32_avx2 * 2) { auto v256_scale = _mm256_set1_ps(scale); - auto data = _mm_loadu_si64(reinterpret_cast<__m128i*>(src_nc + i / 2)); + auto data = _mm_loadl_epi64(reinterpret_cast<__m128i*>(src_nc + i / 2)); auto v_i32 = _mm256_cvtepi8_epi32(data); auto v_256_low_half = _mm256_srai_epi32(v_i32, 4); diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp index ecf69cdf6a3d0c..59ebbda4a0e2a9 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp @@ -381,7 +381,7 @@ static void attn_acc_value_block(float* out, float* weight, uint8_t* v, size_t S auto v256_attn_w_vec0 = _mm256_set1_ps(weight[j] * v0[0]); auto v256_zp = _mm256_set1_ps(v0[1]); for (; i + 
vec_len_f32_avx2 * 2 <= _group_size; i += vec_len_f32_avx2 * 2) { - auto data = _mm_loadu_si64(reinterpret_cast<__m128i*>(v + i/2 + src_offset + params_offset)); + auto data = _mm_loadl_epi64(reinterpret_cast<__m128i*>(v + i/2 + src_offset + params_offset)); auto v_i32 = _mm256_cvtepu8_epi32(data); auto v_256_low_half = _mm256_srli_epi32(v_i32, 4); @@ -470,7 +470,7 @@ static void attn_acc_value_block(float* out, float* weight, uint8_t* v, size_t S #elif defined(HAVE_AVX2) || defined(HAVE_AVX512F) auto v256_attn_w_vec0 = _mm256_set1_ps(weight[j] * v0[0]); for (; i + vec_len_f32_avx2 * 2 <= _group_size; i += vec_len_f32_avx2 * 2) { - auto data = _mm_loadu_si64(reinterpret_cast<__m128i*>(v + i/2 + src_offset + params_offset)); + auto data = _mm_loadl_epi64(reinterpret_cast<__m128i*>(v + i/2 + src_offset + params_offset)); auto v_i32 = _mm256_cvtepi8_epi32(data); auto v_256_low_half = _mm256_srai_epi32(v_i32, 4); From 92e6cb308d5d0d439cd24df1bee7055cd8d243ca Mon Sep 17 00:00:00 2001 From: Zhang Yi3 Date: Sun, 8 Dec 2024 19:01:41 -0800 Subject: [PATCH 11/28] [CPU]fix fp32 inference --- .../src/nodes/kernels/scaled_attn/executor_pa.cpp | 4 ++-- src/plugins/intel_cpu/src/nodes/paged_attn.cpp | 7 ++++++- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp index 59ebbda4a0e2a9..51b750a0ec11a0 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp @@ -2262,10 +2262,10 @@ std::shared_ptr make_pa_executor(ov::element::Type data_ if (key_cache_type == ov::element::u8) { executor = std::make_shared>(key_group_size, value_group_size); } else if (key_cache_type == ov::element::f16) { - executor = std::make_shared>(); + executor = std::make_shared>(key_group_size, value_group_size); } else { OPENVINO_ASSERT(key_cache_type == ov::element::f32, "expect kvcache type f32, current: ", key_cache_type); - executor = std::make_shared>(); + executor = std::make_shared>(key_group_size, value_group_size); } } else { OPENVINO_THROW("make_pa_executor: unsupported precision: ", data_type); diff --git a/src/plugins/intel_cpu/src/nodes/paged_attn.cpp b/src/plugins/intel_cpu/src/nodes/paged_attn.cpp index 095b89bbeb4a12..46cfb7ceee9ed2 100644 --- a/src/plugins/intel_cpu/src/nodes/paged_attn.cpp +++ b/src/plugins/intel_cpu/src/nodes/paged_attn.cpp @@ -132,7 +132,12 @@ void PagedAttention::createPrimitive() { // Since we are quantize only last dim it's safe to use the last dim of KV. 
auto kCachePrecision = getOriginalInputPrecisionAtPort(PagedAttentionExecutor::ID_KCACHE); auto vCachePrecision = getOriginalInputPrecisionAtPort(PagedAttentionExecutor::ID_VCACHE); - size_t group_size = 64; + const auto keyDims = getInputShapeAtPort(PagedAttentionExecutor::ID_KCACHE).getDims(); + const auto valueDims = getInputShapeAtPort(PagedAttentionExecutor::ID_VCACHE).getDims(); + const auto keyS = *(keyDims.end() - 1); + const auto valueS = *(valueDims.end() - 1); + + size_t group_size = keyS; if (getenv("GROUP_SIZE")) group_size = std::stoi(std::string(getenv("GROUP_SIZE"))); size_t key_group_size = group_size; From 91ebc0999e86da1ca3e0a82d242fa31165f38946 Mon Sep 17 00:00:00 2001 From: Zhang Yi3 Date: Mon, 9 Dec 2024 19:23:55 -0800 Subject: [PATCH 12/28] [CPU]set group size via hint Signed-off-by: Zhang Yi3 --- .../openvino/runtime/properties/hint/__init__.py | 2 ++ .../src/pyopenvino/core/properties/properties.cpp | 2 ++ .../python/tests/test_runtime/test_properties.py | 10 ++++++++++ .../include/openvino/runtime/properties.hpp | 12 ++++++++++++ src/plugins/intel_cpu/src/compiled_model.cpp | 6 ++++++ src/plugins/intel_cpu/src/config.cpp | 15 +++++++++++++++ src/plugins/intel_cpu/src/config.h | 2 ++ .../src/nodes/kernels/scaled_attn/executor_pa.cpp | 4 ++-- src/plugins/intel_cpu/src/nodes/paged_attn.cpp | 14 ++++---------- .../behavior/ov_executable_network/properties.cpp | 2 ++ .../custom/behavior/ov_plugin/properties.cpp | 2 ++ 11 files changed, 59 insertions(+), 12 deletions(-) diff --git a/src/bindings/python/src/openvino/runtime/properties/hint/__init__.py b/src/bindings/python/src/openvino/runtime/properties/hint/__init__.py index d1dce289d09941..53eb5a76effdb4 100644 --- a/src/bindings/python/src/openvino/runtime/properties/hint/__init__.py +++ b/src/bindings/python/src/openvino/runtime/properties/hint/__init__.py @@ -23,4 +23,6 @@ from openvino._pyopenvino.properties.hint import allow_auto_batching from openvino._pyopenvino.properties.hint import dynamic_quantization_group_size from openvino._pyopenvino.properties.hint import kv_cache_precision +from openvino._pyopenvino.properties.hint import key_cache_group_size +from openvino._pyopenvino.properties.hint import value_cache_group_size from openvino._pyopenvino.properties.hint import activations_scale_factor diff --git a/src/bindings/python/src/pyopenvino/core/properties/properties.cpp b/src/bindings/python/src/pyopenvino/core/properties/properties.cpp index 564e5f69f5ee14..cec0aae9b07a21 100644 --- a/src/bindings/python/src/pyopenvino/core/properties/properties.cpp +++ b/src/bindings/python/src/pyopenvino/core/properties/properties.cpp @@ -101,6 +101,8 @@ void regmodule_properties(py::module m) { wrap_property_RW(m_hint, ov::hint::allow_auto_batching, "allow_auto_batching"); wrap_property_RW(m_hint, ov::hint::dynamic_quantization_group_size, "dynamic_quantization_group_size"); wrap_property_RW(m_hint, ov::hint::kv_cache_precision, "kv_cache_precision"); + wrap_property_RW(m_hint, ov::hint::key_cache_group_size, "key_cache_group_size"); + wrap_property_RW(m_hint, ov::hint::value_cache_group_size, "value_cache_group_size"); wrap_property_RW(m_hint, ov::hint::activations_scale_factor, "activations_scale_factor"); // Submodule intel_cpu diff --git a/src/bindings/python/tests/test_runtime/test_properties.py b/src/bindings/python/tests/test_runtime/test_properties.py index 6065d72196b44b..d2d95c32079bea 100644 --- a/src/bindings/python/tests/test_runtime/test_properties.py +++ 
b/src/bindings/python/tests/test_runtime/test_properties.py @@ -334,6 +334,16 @@ def test_properties_ro(ov_property_ro, expected_value): "DYNAMIC_QUANTIZATION_GROUP_SIZE", ((64, 64),), ), + ( + hints.key_cache_group_size, + "KEY_CACHE_GROUP_SIZE", + ((64, 64),), + ), + ( + hints.value_cache_group_size, + "VALUE_CACHE_GROUP_SIZE", + ((64, 64),), + ), (hints.kv_cache_precision, "KV_CACHE_PRECISION", ((Type.f32, Type.f32),)), ( hints.activations_scale_factor, diff --git a/src/inference/include/openvino/runtime/properties.hpp b/src/inference/include/openvino/runtime/properties.hpp index 5674c75dd546d7..e539b7e209fcb3 100644 --- a/src/inference/include/openvino/runtime/properties.hpp +++ b/src/inference/include/openvino/runtime/properties.hpp @@ -580,6 +580,18 @@ static constexpr Property dynamic_quantization */ static constexpr Property kv_cache_precision{"KV_CACHE_PRECISION"}; +/** + * @brief Hint for device to use group_size for key cache compression + * @ingroup ov_runtime_cpp_prop_api + */ +static constexpr Property key_cache_group_size{"KEY_CACHE_GROUP_SIZE"}; + +/** + * @brief Hint for device to use group_size for value cache compression + * @ingroup ov_runtime_cpp_prop_api + */ +static constexpr Property value_cache_group_size{"VALUE_CACHE_GROUP_SIZE"}; + /** * @brief This property scales down activations to prevent overflows when inference precision is f16. * @ingroup ov_runtime_cpp_prop_api diff --git a/src/plugins/intel_cpu/src/compiled_model.cpp b/src/plugins/intel_cpu/src/compiled_model.cpp index bbee5d937be5d5..2fd048cc3a05e0 100644 --- a/src/plugins/intel_cpu/src/compiled_model.cpp +++ b/src/plugins/intel_cpu/src/compiled_model.cpp @@ -256,6 +256,8 @@ ov::Any CompiledModel::get_property(const std::string& name) const { RO_property(ov::intel_cpu::sparse_weights_decompression_rate.name()), RO_property(ov::hint::dynamic_quantization_group_size.name()), RO_property(ov::hint::kv_cache_precision.name()), + RO_property(ov::hint::key_cache_group_size.name()), + RO_property(ov::hint::value_cache_group_size.name()), }; OPENVINO_SUPPRESS_DEPRECATED_START @@ -333,6 +335,10 @@ ov::Any CompiledModel::get_property(const std::string& name) const { config.fcDynamicQuantizationGroupSize); } else if (name == ov::hint::kv_cache_precision) { return decltype(ov::hint::kv_cache_precision)::value_type(config.kvCachePrecision); + } else if (name == ov::hint::key_cache_group_size) { + return decltype(ov::hint::key_cache_group_size)::value_type(config.keyCacheGroupSize); + } else if (name == ov::hint::value_cache_group_size) { + return decltype(ov::hint::value_cache_group_size)::value_type(config.valueCacheGroupSize); } OPENVINO_THROW("Unsupported property: ", name); } diff --git a/src/plugins/intel_cpu/src/config.cpp b/src/plugins/intel_cpu/src/config.cpp index 83e4ed1c99ea3d..a25401f12566fc 100644 --- a/src/plugins/intel_cpu/src/config.cpp +++ b/src/plugins/intel_cpu/src/config.cpp @@ -375,6 +375,21 @@ void Config::readProperties(const ov::AnyMap& prop, const ModelType modelType) { ov::hint::kv_cache_precision.name(), ". Supported values: u8, bf16, f16, f32"); } + } else if (key == ov::hint::key_cache_group_size.name() || key == ov::hint::value_cache_group_size.name()) { + try { + auto const groupSize = val.as(); + if (key == ov::hint::key_cache_group_size.name()) { + keyCacheGroupSize = groupSize; + } else { + valueCacheGroupSize = groupSize; + } + } catch (ov::Exception&) { + OPENVINO_THROW("Wrong value ", + val.as(), + " for property key ", + key, + ". 
Expected only unsinged integer numbers"); + } } else if (key == ov::cache_encryption_callbacks.name()) { try { auto encryption_callbacks = val.as(); diff --git a/src/plugins/intel_cpu/src/config.h b/src/plugins/intel_cpu/src/config.h index a8439d87803fd4..b6aeeaca38e0ee 100644 --- a/src/plugins/intel_cpu/src/config.h +++ b/src/plugins/intel_cpu/src/config.h @@ -64,6 +64,8 @@ struct Config { // TODO: Executor cache may leads to incorrect behavior on oneDNN ACL primitives size_t rtCacheCapacity = 0ul; #endif + size_t keyCacheGroupSize = 0ul; + size_t valueCacheGroupSize = 0ul; ov::threading::IStreamsExecutor::Config streamExecutorConfig; int streams = 1; bool streamsChanged = false; diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp index 51b750a0ec11a0..f545cba7dd5097 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp @@ -2137,8 +2137,8 @@ struct AttentionExecutor : public PagedAttentionExecutor { // u4 needs scale + zp. s4 needs scale. const size_t param_size = one_of(v_cache.get_precision(), ov::element::u4, ov::element::u8) ? sizeof(float) * 2 : sizeof(float); const size_t value_params_size = param_size * value_sub_byte_multiplyer; - size_t key_group_num = _key_group_size ? k_cache.size(3) / (_key_group_size + key_params_size) : _key_group_size; - size_t value_group_num = _value_group_size ? v_cache.size(3) / (_value_group_size + value_params_size) : _value_group_size; + size_t key_group_num = _key_group_size ? k_cache.size(3) / (_key_group_size + key_params_size) : 1; + size_t value_group_num = _value_group_size ? v_cache.size(3) / (_value_group_size + value_params_size) : 1; auto S = k_cache.size(3) - (k_cache.get_precision().is_real() ? 0 : key_params_size * key_group_num); auto SV = v_cache.size(3) - (v_cache.get_precision().is_real() ? 0 : value_params_size * value_group_num); auto block_size = k_cache.size(2); diff --git a/src/plugins/intel_cpu/src/nodes/paged_attn.cpp b/src/plugins/intel_cpu/src/nodes/paged_attn.cpp index 46cfb7ceee9ed2..41e7274953f9e6 100644 --- a/src/plugins/intel_cpu/src/nodes/paged_attn.cpp +++ b/src/plugins/intel_cpu/src/nodes/paged_attn.cpp @@ -132,16 +132,10 @@ void PagedAttention::createPrimitive() { // Since we are quantize only last dim it's safe to use the last dim of KV. 
auto kCachePrecision = getOriginalInputPrecisionAtPort(PagedAttentionExecutor::ID_KCACHE); auto vCachePrecision = getOriginalInputPrecisionAtPort(PagedAttentionExecutor::ID_VCACHE); - const auto keyDims = getInputShapeAtPort(PagedAttentionExecutor::ID_KCACHE).getDims(); - const auto valueDims = getInputShapeAtPort(PagedAttentionExecutor::ID_VCACHE).getDims(); - const auto keyS = *(keyDims.end() - 1); - const auto valueS = *(valueDims.end() - 1); - - size_t group_size = keyS; - if (getenv("GROUP_SIZE")) - group_size = std::stoi(std::string(getenv("GROUP_SIZE"))); - size_t key_group_size = group_size; - size_t value_group_size = group_size; + const auto cpuConfig = context->getConfig(); + + size_t key_group_size = cpuConfig.keyCacheGroupSize; + size_t value_group_size = cpuConfig.valueCacheGroupSize; std::cout << "PagedAttn|Kcache|" << kCachePrecision << "|Vcache|" << vCachePrecision << "|key_group_size|" << key_group_size << "|value_group_size|" << value_group_size << std::endl; return make_pa_executor(rtPrecision, kCachePrecision, vCachePrecision, key_group_size, value_group_size); #else diff --git a/src/plugins/intel_cpu/tests/functional/custom/behavior/ov_executable_network/properties.cpp b/src/plugins/intel_cpu/tests/functional/custom/behavior/ov_executable_network/properties.cpp index 73086b78a0de95..29e5fbbe982542 100644 --- a/src/plugins/intel_cpu/tests/functional/custom/behavior/ov_executable_network/properties.cpp +++ b/src/plugins/intel_cpu/tests/functional/custom/behavior/ov_executable_network/properties.cpp @@ -41,6 +41,8 @@ TEST_F(OVClassConfigTestCPU, smoke_CpuExecNetworkSupportedPropertiesAreAvailable RO_property(ov::intel_cpu::sparse_weights_decompression_rate.name()), RO_property(ov::hint::dynamic_quantization_group_size.name()), RO_property(ov::hint::kv_cache_precision.name()), + RO_property(ov::hint::key_cache_group_size.name()), + RO_property(ov::hint::value_cache_group_size.name()), }; ov::Core ie; diff --git a/src/plugins/intel_cpu/tests/functional/custom/behavior/ov_plugin/properties.cpp b/src/plugins/intel_cpu/tests/functional/custom/behavior/ov_plugin/properties.cpp index 904d2b81dc05b6..696f73f27e1142 100644 --- a/src/plugins/intel_cpu/tests/functional/custom/behavior/ov_plugin/properties.cpp +++ b/src/plugins/intel_cpu/tests/functional/custom/behavior/ov_plugin/properties.cpp @@ -56,6 +56,8 @@ TEST_F(OVClassConfigTestCPU, smoke_PluginAllSupportedPropertiesAreAvailable) { RW_property(ov::intel_cpu::sparse_weights_decompression_rate.name()), RW_property(ov::hint::dynamic_quantization_group_size.name()), RW_property(ov::hint::kv_cache_precision.name()), + RW_property(ov::hint::key_cache_group_size.name()), + RW_property(ov::hint::value_cache_group_size.name()), }; ov::Core ie; From 685f263a75775314d0fbe6a4023bf5cda6aa8b87 Mon Sep 17 00:00:00 2001 From: Zhang Yi3 Date: Tue, 10 Dec 2024 00:18:42 -0800 Subject: [PATCH 13/28] [CPU]fix code style Signed-off-by: Zhang Yi3 --- src/plugins/intel_cpu/src/config.cpp | 2 +- .../nodes/kernels/scaled_attn/attn_quant.cpp | 102 +++++++--- .../kernels/scaled_attn/attn_quant_kernel.hpp | 5 +- .../nodes/kernels/scaled_attn/executor_pa.cpp | 182 +++++++++++++----- .../nodes/kernels/scaled_attn/executor_pa.hpp | 2 +- .../intel_cpu/src/nodes/paged_attn.cpp | 3 +- 6 files changed, 208 insertions(+), 88 deletions(-) diff --git a/src/plugins/intel_cpu/src/config.cpp b/src/plugins/intel_cpu/src/config.cpp index a25401f12566fc..257dee95546e34 100644 --- a/src/plugins/intel_cpu/src/config.cpp +++ b/src/plugins/intel_cpu/src/config.cpp @@ 
-389,7 +389,7 @@ void Config::readProperties(const ov::AnyMap& prop, const ModelType modelType) { " for property key ", key, ". Expected only unsinged integer numbers"); - } + } } else if (key == ov::cache_encryption_callbacks.name()) { try { auto encryption_callbacks = val.as(); diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant.cpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant.cpp index 80092fdfc7e01f..ee13118ec80d11 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant.cpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant.cpp @@ -131,7 +131,6 @@ static void find_minmax(const T* src, size_t n, float& min, float& max) { max = std::max(max, tmp); min = std::min(min, tmp); } - } template @@ -398,7 +397,10 @@ static void attn_quant_mt(const ov::intel_cpu::PlainTensor& k_src, }); } -template ::type = true> +template ::type = true> static void paged_attn_quant_mt(const ov::intel_cpu::PlainTensor& k_src, const ov::intel_cpu::PlainTensor& v_src, const ov::intel_cpu::PlainTensor& k_dst, @@ -417,27 +419,48 @@ static void paged_attn_quant_mt(const ov::intel_cpu::PlainTensor& k_src, auto block_offset = slot % block_size; // The layout for per token per head: // |scale(f32)|zeropoint(f32)|quantized feature(u8,idx_1)|quantized feature(u8,idx_2)|...|quantized feature(u8,idx_S)| - for (size_t src_offset = 0, dst_offset = 0; src_offset < S; src_offset += _key_group_size, dst_offset += _key_group_size + sizeof(float) + sizeof(float)) { - auto p_k = reinterpret_cast(k_dst.ptr::value_type>(block_number, h, block_offset, dst_offset)); + for (size_t src_offset = 0, dst_offset = 0; src_offset < S; + src_offset += _key_group_size, dst_offset += _key_group_size + sizeof(float) + sizeof(float)) { + auto p_k = reinterpret_cast( + k_dst.ptr::value_type>(block_number, + h, + block_offset, + dst_offset)); quant_u8(k_src.ptr(b, h, m, src_offset), - k_dst.ptr::value_type>(block_number, h, block_offset, dst_offset) + sizeof(float) + sizeof(float), - _key_group_size, - p_k[0], - p_k[1]); + k_dst.ptr::value_type>(block_number, + h, + block_offset, + dst_offset) + + sizeof(float) + sizeof(float), + _key_group_size, + p_k[0], + p_k[1]); } - - for (size_t src_offset = 0, dst_offset = 0; src_offset < SV; src_offset += _value_group_size, dst_offset += _value_group_size + sizeof(float) + sizeof(float)) { - auto p_v = reinterpret_cast(v_dst.ptr::value_type>(block_number, h, block_offset, dst_offset)); + + for (size_t src_offset = 0, dst_offset = 0; src_offset < SV; + src_offset += _value_group_size, dst_offset += _value_group_size + sizeof(float) + sizeof(float)) { + auto p_v = reinterpret_cast( + v_dst.ptr::value_type>(block_number, + h, + block_offset, + dst_offset)); quant_u8(v_src.ptr(b, h, m, src_offset), - v_dst.ptr::value_type>(block_number, h, block_offset, dst_offset) + sizeof(float) + sizeof(float), - _value_group_size, - p_v[0], - p_v[1]); + v_dst.ptr::value_type>(block_number, + h, + block_offset, + dst_offset) + + sizeof(float) + sizeof(float), + _value_group_size, + p_v[0], + p_v[1]); } }); } -template ::type = true> +template ::type = true> static void paged_attn_quant_mt(const ov::intel_cpu::PlainTensor& k_src, const ov::intel_cpu::PlainTensor& v_src, const ov::intel_cpu::PlainTensor& k_dst, @@ -457,13 +480,22 @@ static void paged_attn_quant_mt(const ov::intel_cpu::PlainTensor& k_src, auto block_offset = slot % block_size; // The layout for per token per head: // |scale(f32)|zeropoint(f32)|quantized feature(u8,idx_1)|quantized 
feature(u8,idx_2)|...|quantized feature(u8,idx_S)| - for (size_t src_offset = 0, dst_offset = 0; src_offset < S; src_offset += _key_group_size, dst_offset += _key_group_size + sizeof(float) + sizeof(float)) { - auto p_k = reinterpret_cast(k_dst.ptr::value_type>(block_number, h, block_offset, dst_offset)); + for (size_t src_offset = 0, dst_offset = 0; src_offset < S; + src_offset += _key_group_size, dst_offset += _key_group_size + sizeof(float) + sizeof(float)) { + auto p_k = reinterpret_cast( + k_dst.ptr::value_type>(block_number, + h, + block_offset, + dst_offset)); quant_u8(k_src.ptr(b, h, m, src_offset), - k_dst.ptr::value_type>(block_number, h, block_offset, dst_offset) + sizeof(float) + sizeof(float), - _key_group_size, - p_k[0], - p_k[1]); + k_dst.ptr::value_type>(block_number, + h, + block_offset, + dst_offset) + + sizeof(float) + sizeof(float), + _key_group_size, + p_k[0], + p_k[1]); } for (size_t src_offset = 0, dst_offset = 0; src_offset < SV; src_offset += _value_group_size, @@ -480,7 +512,10 @@ static void paged_attn_quant_mt(const ov::intel_cpu::PlainTensor& k_src, }); } -template ::type = true> +template ::type = true> static void paged_attn_quant_mt(const ov::intel_cpu::PlainTensor& k_src, const ov::intel_cpu::PlainTensor& v_src, const ov::intel_cpu::PlainTensor& k_dst, @@ -500,13 +535,22 @@ static void paged_attn_quant_mt(const ov::intel_cpu::PlainTensor& k_src, auto block_offset = slot % block_size; // The layout for per token per head: // |scale(f32)|zeropoint(f32)|quantized feature(u8,idx_1)|quantized feature(u8,idx_2)|...|quantized feature(u8,idx_S)| - for (size_t src_offset = 0, dst_offset = 0; src_offset < S; src_offset += _key_group_size, dst_offset += _key_group_size + sizeof(float) + sizeof(float)) { - auto p_k = reinterpret_cast(k_dst.ptr::value_type>(block_number, h, block_offset, dst_offset)); + for (size_t src_offset = 0, dst_offset = 0; src_offset < S; + src_offset += _key_group_size, dst_offset += _key_group_size + sizeof(float) + sizeof(float)) { + auto p_k = reinterpret_cast( + k_dst.ptr::value_type>(block_number, + h, + block_offset, + dst_offset)); quant_u8(k_src.ptr(b, h, m, src_offset), - k_dst.ptr::value_type>(block_number, h, block_offset, dst_offset) + sizeof(float) + sizeof(float), - _key_group_size, - p_k[0], - p_k[1]); + k_dst.ptr::value_type>(block_number, + h, + block_offset, + dst_offset) + + sizeof(float) + sizeof(float), + _key_group_size, + p_k[0], + p_k[1]); } for (size_t src_offset = 0, dst_offset = 0; src_offset < SV; src_offset += _value_group_size, diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant_kernel.hpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant_kernel.hpp index 120e14ab7da5df..434286766d5188 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant_kernel.hpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant_kernel.hpp @@ -82,7 +82,6 @@ void attn_dequant_u4_kernel(const uint8_t* src, TDST* dst, size_t n, float scale // (q - zp) * scale v_f32_low_half = _mm512_mul_ps(v_f32_low_half, v_scale); v_f32_high_half = _mm512_mul_ps(v_f32_high_half, v_scale); - __m512i idx1 = _mm512_set_epi32(23, 7, 22, 6, 21, 5, 20, 4, 19, 3, 18, 2, 17, 1, 16, 0); __m512i idx2 = _mm512_set_epi32(31, 15, 30, 14, 29, 13, 28, 12, 27, 11, 26, 10, 25, 9, 24, 8); __m512 first_half = _mm512_permutex2var_ps(v_f32_low_half, idx1, v_f32_high_half); @@ -106,7 +105,7 @@ void attn_dequant_u4_kernel(const uint8_t* src, TDST* dst, size_t n, float scale // q - zp v_f32_low_half = 
_mm256_sub_ps(v_f32_low_half, v256_zp); v_f32_high_half = _mm256_sub_ps(v_f32_high_half, v256_zp); - + v_f32_low_half = _mm256_mul_ps(v_f32_low_half, v256_scale); v_f32_high_half = _mm256_mul_ps(v_f32_high_half, v256_scale); @@ -206,7 +205,7 @@ void attn_dequant_s4_kernel(const uint8_t* src, TDST* dst, size_t n, float scale #endif auto extract_half_byte = [&](uint8_t val, bool high_half) -> int8_t { uint8_t shift = high_half ? 0 : 4; - return float((val >> shift) & 0x000F); + return static_cast((val >> shift) & 0x000F); }; for (; i < n; ++i) { float tmp = extract_half_byte(src_nc[i / 2], (uint8_t)(i % 2)); diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp index f545cba7dd5097..c7d12a0818749e 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp @@ -234,15 +234,23 @@ static void attn_acc_value_block(float* out, float* weight, uint8_t* v, size_t S size_t i = 0; for (; i + vec_len_f32_avx512 <= _group_size; i += vec_len_f32_avx512) { auto v_out = mm512_uni_loadu_ps(out + dst_offset + i); - auto v0 = _mm512_sub_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(_mm_loadu_si128(reinterpret_cast<__m128i*>(v_data_ptr + i)))), zp0); - auto v1 = _mm512_sub_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(_mm_loadu_si128(reinterpret_cast<__m128i*>(v_data_ptr + i + src_stride)))), zp1); - auto v2 = _mm512_sub_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(_mm_loadu_si128(reinterpret_cast<__m128i*>(v_data_ptr + i + 2 * src_stride)))), zp2); - auto v3 = _mm512_sub_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(_mm_loadu_si128(reinterpret_cast<__m128i*>(v_data_ptr + i + 3 * src_stride)))), zp3); + auto v0 = _mm512_sub_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32( + _mm_loadu_si128(reinterpret_cast<__m128i*>(v_data_ptr + i)))), + zp0); + auto v1 = _mm512_sub_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32( + _mm_loadu_si128(reinterpret_cast<__m128i*>(v_data_ptr + i + src_stride)))), + zp1); + auto v2 = _mm512_sub_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(_mm_loadu_si128( + reinterpret_cast<__m128i*>(v_data_ptr + i + 2 * src_stride)))), + zp2); + auto v3 = _mm512_sub_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(_mm_loadu_si128( + reinterpret_cast<__m128i*>(v_data_ptr + i + 3 * src_stride)))), + zp3); v_out = _mm512_fmadd_ps(attn_w_vec0, v0, v_out); v_out = _mm512_fmadd_ps(attn_w_vec1, v1, v_out); v_out = _mm512_fmadd_ps(attn_w_vec2, v2, v_out); v_out = _mm512_fmadd_ps(attn_w_vec3, v3, v_out); - _mm512_storeu_ps(out + dst_offset + i, v_out); + _mm512_storeu_ps(out + dst_offset + i, v_out); } for (; i < _group_size; i++) { out[i] += weight[0] * (v_data_ptr[i] - v_f0[1]) * v_f0[0]; @@ -275,7 +283,7 @@ static void attn_acc_value_block(float* out, float* weight, uint8_t* v, size_t S } for (; i < _group_size; i++) { out[dst_offset + i] += weight[0] * (v_data_ptr[i] - v_f0[1]) * v_f0[0]; - } + } dst_offset += _group_size; src_offset += _group_size + params_offset; } @@ -305,7 +313,7 @@ static void attn_acc_value_block(float* out, float* weight, uint8_t* v, size_t S } for (; i < _group_size; i++) { out[dst_offset + i] += weight[0] * (v_data_ptr[i] - v_f0[1]) * v_f0[0]; - } + } dst_offset += _group_size; src_offset += _group_size + params_offset; } @@ -411,8 +419,8 @@ static void attn_acc_value_block(float* out, float* weight, uint8_t* v, size_t S #endif for (; i < _group_size; i += 2) { uint8_t data = v[i/2 + src_offset + params_offset]; - 
float tmp0 = extract_half_byte(data, (bool)(i % 2)); - float tmp1 = extract_half_byte(data, (bool)((i + 1) % 2)); + float tmp0 = extract_half_byte(data, static_cast(i % 2)); + float tmp1 = extract_half_byte(data, static_cast((i + 1) % 2)); out[dst_offset + i] += weight[j] * (tmp0 - v0[1]) * v0[0]; out[dst_offset + i + 1] += weight[j] * (tmp1 - v0[1]) * v0[0]; } @@ -495,9 +503,9 @@ static void attn_acc_value_block(float* out, float* weight, uint8_t* v, size_t S #endif for (; i < _group_size; i += 2) { uint8_t data = v[i/2 + src_offset + params_offset]; - float tmp0 = extract_half_byte(data, (bool)(i % 2)); + float tmp0 = extract_half_byte(data, static_cast(i % 2)); tmp0 = tmp0 > 8 ? (tmp0 - 16) : tmp0; - float tmp1 = extract_half_byte(data, (bool)((i + 1) % 2)); + float tmp1 = extract_half_byte(data, static_cast((i + 1) % 2)); tmp1 = tmp1 > 8 ? (tmp1 - 16) : tmp1; out[dst_offset + i] += weight[j] * (tmp0) * v0[0]; out[dst_offset + i + 1] += weight[j] * (tmp1) * v0[0]; @@ -640,7 +648,7 @@ static void dot_product_block(TA* a, uint8_t* b, float* c, size_t n, size_t bloc float sum1 = 0.0f; float sum2 = 0.0f; float sum3 = 0.0f; - while (dst_offset < n) { + while (dst_offset < n) { auto vsum0 = _mm512_setzero_ps(); auto vsum1 = _mm512_setzero_ps(); auto vsum2 = _mm512_setzero_ps(); @@ -657,10 +665,18 @@ static void dot_product_block(TA* a, uint8_t* b, float* c, size_t n, size_t bloc uint8_t* b_data_ptr = b + src_offset + params_offset; for (; i + vec_len_f32_avx512 <= _group_size; i += vec_len_f32_avx512) { auto va = mm512_uni_loadu_ps(a + dst_offset + i); - auto vb0 = _mm512_sub_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(_mm_loadu_si128(reinterpret_cast<__m128i*>(b_data_ptr + i)))), v_zp0); - auto vb1 = _mm512_sub_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(_mm_loadu_si128(reinterpret_cast<__m128i*>(b_data_ptr + i + src_stride)))), v_zp1); - auto vb2 = _mm512_sub_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(_mm_loadu_si128(reinterpret_cast<__m128i*>(b_data_ptr + i + 2 * src_stride)))), v_zp2); - auto vb3 = _mm512_sub_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(_mm_loadu_si128(reinterpret_cast<__m128i*>(b_data_ptr + i + 3 * src_stride)))), v_zp3); + auto vb0 = _mm512_sub_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32( + _mm_loadu_si128(reinterpret_cast<__m128i*>(b_data_ptr + i)))), + v_zp0); + auto vb1 = _mm512_sub_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32( + _mm_loadu_si128(reinterpret_cast<__m128i*>(b_data_ptr + i + src_stride)))), + v_zp1); + auto vb2 = _mm512_sub_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(_mm_loadu_si128( + reinterpret_cast<__m128i*>(b_data_ptr + i + 2 * src_stride)))), + v_zp2); + auto vb3 = _mm512_sub_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(_mm_loadu_si128( + reinterpret_cast<__m128i*>(b_data_ptr + i + 3 * src_stride)))), + v_zp3); vsum0 = _mm512_fmadd_ps(va, vb0, vsum0); vsum1 = _mm512_fmadd_ps(va, vb1, vsum1); @@ -744,10 +760,18 @@ static void dot_product_block(TA* a, uint8_t* b, float* c, size_t n, size_t bloc uint8_t* b_data_ptr = b + src_offset + params_offset; for (; i + vec_len_f32_avx2 <= _group_size; i += vec_len_f32_avx2) { auto va = mm256_uni_loadu_ps(a + dst_offset + i); - auto vb0 = _mm256_sub_ps(_mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64(reinterpret_cast<__m128i*>(b_data_ptr + i)))), v_zp0); - auto vb1 = _mm256_sub_ps(_mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64(reinterpret_cast<__m128i*>(b_data_ptr + i + src_stride)))), v_zp1); - auto vb2 = 
_mm256_sub_ps(_mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64(reinterpret_cast<__m128i*>(b_data_ptr + i + 2 * src_stride)))), v_zp2); - auto vb3 = _mm256_sub_ps(_mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64(reinterpret_cast<__m128i*>(b_data_ptr + i + 3 * src_stride)))), v_zp3); + auto vb0 = _mm256_sub_ps(_mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( + _mm_loadl_epi64(reinterpret_cast<__m128i*>(b_data_ptr + i)))), + v_zp0); + auto vb1 = _mm256_sub_ps(_mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( + _mm_loadl_epi64(reinterpret_cast<__m128i*>(b_data_ptr + i + src_stride)))), + v_zp1); + auto vb2 = _mm256_sub_ps(_mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64( + reinterpret_cast<__m128i*>(b_data_ptr + i + 2 * src_stride)))), + v_zp2); + auto vb3 = _mm256_sub_ps(_mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64( + reinterpret_cast<__m128i*>(b_data_ptr + i + 3 * src_stride)))), + v_zp3); vsum0 = _mm256_fmadd_ps(va, vb0, vsum0); vsum1 = _mm256_fmadd_ps(va, vb1, vsum1); @@ -775,7 +799,6 @@ static void dot_product_block(TA* a, uint8_t* b, float* c, size_t n, size_t bloc dst_offset += _group_size; src_offset += _group_size + params_offset; } - c[0] = sum0; c[1] = sum1; c[2] = sum2; @@ -890,14 +913,31 @@ void transpose_16NxK(TDST* dst, void* src, TDST* tmp, size_t N, size_t K, size_t } } #if defined(HAVE_AVX512F) -template::value), bool>::type = true> -static void transpose_16NxK(T* dst, T* src, T* tmp, size_t N, size_t K, size_t dst_stride, size_t src_stride, size_t group_size = 0) { +template ::value), + bool>::type = true> +static void transpose_16NxK(T* dst, + T* src, + T* tmp, + size_t N, + size_t K, + size_t dst_stride, + size_t src_stride, + size_t group_size = 0) { // will treat as uint32_t transpose auto s = reinterpret_cast(src); auto d = reinterpret_cast(dst); - transpose_16NxK(d, s, reinterpret_cast(0), N, K >> 1, dst_stride, src_stride >> 1); + transpose_16NxK(d, + s, + reinterpret_cast(0), + N, + K >> 1, + dst_stride, + src_stride >> 1); } -#endif +# endif template::type = true> void transpose_16NxK(TDST* dst, void* src, TDST* tmp, size_t N, size_t K, size_t dst_stride, size_t src_stride, size_t group_size = 0) { @@ -1062,8 +1102,19 @@ static void pack_32xK_kernel(T* dst, T* src, size_t dst_stride, size_t src_strid } } -template::value != ov::element::f32 && (SRC_PREC == ov::element::bf16 || SRC_PREC == ov::element::f16), bool>::type = true> -static void pack_32NxK(TDST* dst, void* src, TDST* tmp, size_t N, size_t K, size_t dst_stride, size_t src_stride, size_t group_size = 0) { +template ::value != ov::element::f32 && + (SRC_PREC == ov::element::bf16 || SRC_PREC == ov::element::f16), + bool>::type = true> +static void pack_32NxK(TDST* dst, + void* src, + TDST* tmp, + size_t N, + size_t K, + size_t dst_stride, + size_t src_stride, + size_t group_size = 0) { auto src_ptr = reinterpret_cast::value_type*>(src); for (size_t n = 0; n < N; n += 32) { size_t k = 0; @@ -1083,16 +1134,26 @@ static void pack_32NxK(TDST* dst, void* src, TDST* tmp, size_t N, size_t K, size } } -template::value != ov::element::f32 && SRC_PREC == ov::element::u8, bool>::type = true> -static void pack_32NxK(TDST* dst, void* src, TDST* tmp, size_t N, size_t K, size_t dst_stride, size_t src_stride, size_t group_size = 0) { +template ::value != ov::element::f32 && SRC_PREC == ov::element::u8, + bool>::type = true> +static void pack_32NxK(TDST* dst, + void* src, + TDST* tmp, + size_t N, + size_t K, + size_t dst_stride, + size_t src_stride, + size_t group_size = 0) { // The layout for per token per 
head: - // |scale(f32)|zeropoint(f32)|quantized feature(u8,idx_1)|quantized feature(u8,idx_2)|...|quantized feature(u8,idx_S)| - // The quantized feature will start from 8bytes=sizeof(float)+sizeof(float) + // |scale(f32)|zeropoint(f32)|quantized feature(u8,idx_1)|quantized feature(u8,idx_2)|...|quantized + // feature(u8,idx_S)| The quantized feature will start from 8bytes=sizeof(float)+sizeof(float) auto s = reinterpret_cast::value_type*>(src); auto t = tmp; // if group_size not set, the whole row is used as a group size_t _group_size = group_size ? group_size : K; - for (size_t n = 0; n < N; n ++) { + for (size_t n = 0; n < N; n++) { size_t src_offset = 0; size_t dst_offset = 0; while (dst_offset < K) { @@ -1107,17 +1168,27 @@ static void pack_32NxK(TDST* dst, void* src, TDST* tmp, size_t N, size_t K, size pack_32NxK::value>(dst, tmp, reinterpret_cast(0), N, K, dst_stride, src_stride); } -template::value != ov::element::f32 && (SRC_PREC == ov::element::u4), bool>::type = true> -static void pack_32NxK(TDST* dst, void* src, TDST* tmp, size_t N, size_t K, size_t dst_stride, size_t src_stride, size_t group_size = 0) { +template ::value != ov::element::f32 && (SRC_PREC == ov::element::u4), + bool>::type = true> +static void pack_32NxK(TDST* dst, + void* src, + TDST* tmp, + size_t N, + size_t K, + size_t dst_stride, + size_t src_stride, + size_t group_size = 0) { // The layout for per token per head: - // |scale(f32)|zeropoint(f32)|quantized feature(u8,idx_1)|quantized feature(u8,idx_2)|...|quantized feature(u8,idx_S)| - // The quantized feature will start from 8bytes=sizeof(float)+sizeof(float) + // |scale(f32)|zeropoint(f32)|quantized feature(u8,idx_1)|quantized feature(u8,idx_2)|...|quantized + // feature(u8,idx_S)| The quantized feature will start from 8bytes=sizeof(float)+sizeof(float) auto s = reinterpret_cast(src); auto t = tmp; // if group_size not set, the whole row is used as a group const size_t sub_byte_mulitplier = 2; size_t _group_size = group_size ? group_size : K; - for (size_t n = 0; n < N; n ++) { + for (size_t n = 0; n < N; n++) { size_t src_offset = 0; size_t dst_offset = 0; while (dst_offset < K) { @@ -1132,17 +1203,27 @@ static void pack_32NxK(TDST* dst, void* src, TDST* tmp, size_t N, size_t K, size pack_32NxK::value>(dst, tmp, reinterpret_cast(0), N, K, dst_stride, src_stride); } -template::value != ov::element::f32 && (SRC_PREC == ov::element::i4), bool>::type = true> -static void pack_32NxK(TDST* dst, void* src, TDST* tmp, size_t N, size_t K, size_t dst_stride, size_t src_stride, size_t group_size = 0) { +template ::value != ov::element::f32 && (SRC_PREC == ov::element::i4), + bool>::type = true> +static void pack_32NxK(TDST* dst, + void* src, + TDST* tmp, + size_t N, + size_t K, + size_t dst_stride, + size_t src_stride, + size_t group_size = 0) { // The layout for per token per head: - // |scale(f32)|zeropoint(f32)|quantized feature(u8,idx_1)|quantized feature(u8,idx_2)|...|quantized feature(u8,idx_S)| - // The quantized feature will start from 8bytes=sizeof(float)+sizeof(float) + // |scale(f32)|zeropoint(f32)|quantized feature(u8,idx_1)|quantized feature(u8,idx_2)|...|quantized + // feature(u8,idx_S)| The quantized feature will start from 8bytes=sizeof(float)+sizeof(float) auto s = reinterpret_cast(src); auto t = tmp; // if group_size not set, the whole row is used as a group const size_t sub_byte_mulitplier = 2; size_t _group_size = group_size ? 
group_size : K; - for (size_t n = 0; n < N; n ++) { + for (size_t n = 0; n < N; n++) { size_t src_offset = 0; size_t dst_offset = 0; while (dst_offset < K) { @@ -1156,7 +1237,7 @@ static void pack_32NxK(TDST* dst, void* src, TDST* tmp, size_t N, size_t K, size } pack_32NxK::value>(dst, tmp, reinterpret_cast(0), N, K, dst_stride, src_stride); } -#endif +# endif template::value == ov::element::f32, bool>::type = true> static void pack_32NxK(TDST* dst, void* src, TDST* tmp, size_t N, size_t K, size_t dst_stride, size_t src_stride, size_t group_size = 0) { @@ -1558,7 +1639,6 @@ struct MHAHelper { std::min(_block_size, cur_kv_len - pv), _value_group_size); } - } } } @@ -2155,8 +2235,6 @@ struct AttentionExecutor : public PagedAttentionExecutor { q = q.reshape({B_token, H, 1, S}); k = k.reshape({B_token, Hk, 1, S}); v = v.reshape({B_token, Hk, 1, SV}); - printf("k_cache prec %s shape [%ld %ld %ld %ld]\n", k_cache.get_precision().to_string().c_str(), k_cache.size(0), k_cache.size(1), k_cache.size(2), k_cache.size(3)); - printf("v_cache prec %s shape [%ld %ld %ld %ld]\n", v_cache.get_precision().to_string().c_str(), v_cache.size(0), v_cache.size(1), v_cache.size(2), v_cache.size(3)); if (k_cache.m_dt == ov::element::Type_t::u8) { k_cache.assert_dims({0, Hk, block_size, S + key_params_size * key_group_num}, true); @@ -2237,27 +2315,27 @@ std::shared_ptr make_pa_executor(ov::element::Type data_ #ifdef OPENVINO_ARCH_X86_64 if (data_type == ov::element::bf16) { -# if defined(HAVE_AVX512F) +#if defined(HAVE_AVX512F) if (key_cache_type == ov::element::u8) { executor = std::make_shared>(key_group_size, value_group_size); } else { OPENVINO_ASSERT(key_cache_type == ov::element::bf16, "expect kvcache type bf16, current: ", key_cache_type); executor = std::make_shared>(); } -# else +#else OPENVINO_THROW("make_pa_executor: bf16 needs avx512+ hardware."); -# endif +#endif } else if (data_type == ov::element::f16) { -# if defined(HAVE_AVX512F) +#if defined(HAVE_AVX512F) if (key_cache_type == ov::element::u8) { executor = std::make_shared>(key_group_size, value_group_size); } else { OPENVINO_ASSERT(key_cache_type == ov::element::f16, "expect kvcache type f16, current: ", key_cache_type); executor = std::make_shared>(); } -# else +#else OPENVINO_THROW("make_pa_executor: f16 needs avx512+ hardware."); -# endif +#endif } else if (data_type == ov::element::f32) { if (key_cache_type == ov::element::u8) { executor = std::make_shared>(key_group_size, value_group_size); diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.hpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.hpp index ae8b8aa348f1e3..d386dc9f44e321 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.hpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.hpp @@ -22,7 +22,7 @@ std::shared_ptr make_pa_executor(ov::element::Type data_ size_t key_group_size, size_t value_group_size); -} // namespace XARCHl +} // namespace XARCH } // namespace Cpu } // namespace Extensions } // namespace ov \ No newline at end of file diff --git a/src/plugins/intel_cpu/src/nodes/paged_attn.cpp b/src/plugins/intel_cpu/src/nodes/paged_attn.cpp index 41e7274953f9e6..468fc72b28296a 100644 --- a/src/plugins/intel_cpu/src/nodes/paged_attn.cpp +++ b/src/plugins/intel_cpu/src/nodes/paged_attn.cpp @@ -135,8 +135,7 @@ void PagedAttention::createPrimitive() { const auto cpuConfig = context->getConfig(); size_t key_group_size = cpuConfig.keyCacheGroupSize; - size_t value_group_size = 
cpuConfig.valueCacheGroupSize; - std::cout << "PagedAttn|Kcache|" << kCachePrecision << "|Vcache|" << vCachePrecision << "|key_group_size|" << key_group_size << "|value_group_size|" << value_group_size << std::endl; + size_t value_group_size = cpuConfig.valueCacheGroupSize; return make_pa_executor(rtPrecision, kCachePrecision, vCachePrecision, key_group_size, value_group_size); #else return nullptr; From e56639ab3e6aaf60e6e7f88b9feb73855a5e0f71 Mon Sep 17 00:00:00 2001 From: Zhang Yi Date: Wed, 11 Dec 2024 09:11:07 +0800 Subject: [PATCH 14/28] [CPU]fix property test Signed-off-by: Zhang Yi3 --- .../runtime/properties/hint/__init__.py | 2 ++ .../pyopenvino/core/properties/properties.cpp | 2 ++ .../tests/test_runtime/test_properties.py | 2 ++ .../include/openvino/runtime/properties.hpp | 12 ++++++++++ src/plugins/intel_cpu/src/compiled_model.cpp | 6 +++++ src/plugins/intel_cpu/src/config.cpp | 22 +++++++++++++++++++ src/plugins/intel_cpu/src/config.h | 4 ++++ src/plugins/intel_cpu/src/plugin.cpp | 12 ++++++++++ .../ov_executable_network/properties.cpp | 2 ++ .../custom/behavior/ov_plugin/properties.cpp | 2 ++ 10 files changed, 66 insertions(+) diff --git a/src/bindings/python/src/openvino/runtime/properties/hint/__init__.py b/src/bindings/python/src/openvino/runtime/properties/hint/__init__.py index 53eb5a76effdb4..d5c5d5595e5e0b 100644 --- a/src/bindings/python/src/openvino/runtime/properties/hint/__init__.py +++ b/src/bindings/python/src/openvino/runtime/properties/hint/__init__.py @@ -23,6 +23,8 @@ from openvino._pyopenvino.properties.hint import allow_auto_batching from openvino._pyopenvino.properties.hint import dynamic_quantization_group_size from openvino._pyopenvino.properties.hint import kv_cache_precision +from openvino._pyopenvino.properties.hint import key_cache_precision +from openvino._pyopenvino.properties.hint import value_cache_precision from openvino._pyopenvino.properties.hint import key_cache_group_size from openvino._pyopenvino.properties.hint import value_cache_group_size from openvino._pyopenvino.properties.hint import activations_scale_factor diff --git a/src/bindings/python/src/pyopenvino/core/properties/properties.cpp b/src/bindings/python/src/pyopenvino/core/properties/properties.cpp index cec0aae9b07a21..2b997c6664cee0 100644 --- a/src/bindings/python/src/pyopenvino/core/properties/properties.cpp +++ b/src/bindings/python/src/pyopenvino/core/properties/properties.cpp @@ -101,6 +101,8 @@ void regmodule_properties(py::module m) { wrap_property_RW(m_hint, ov::hint::allow_auto_batching, "allow_auto_batching"); wrap_property_RW(m_hint, ov::hint::dynamic_quantization_group_size, "dynamic_quantization_group_size"); wrap_property_RW(m_hint, ov::hint::kv_cache_precision, "kv_cache_precision"); + wrap_property_RW(m_hint, ov::hint::key_cache_precision, "key_cache_precision"); + wrap_property_RW(m_hint, ov::hint::value_cache_precision, "value_cache_precision"); wrap_property_RW(m_hint, ov::hint::key_cache_group_size, "key_cache_group_size"); wrap_property_RW(m_hint, ov::hint::value_cache_group_size, "value_cache_group_size"); wrap_property_RW(m_hint, ov::hint::activations_scale_factor, "activations_scale_factor"); diff --git a/src/bindings/python/tests/test_runtime/test_properties.py b/src/bindings/python/tests/test_runtime/test_properties.py index d2d95c32079bea..d0745f84361310 100644 --- a/src/bindings/python/tests/test_runtime/test_properties.py +++ b/src/bindings/python/tests/test_runtime/test_properties.py @@ -345,6 +345,8 @@ def test_properties_ro(ov_property_ro, 
expected_value): ((64, 64),), ), (hints.kv_cache_precision, "KV_CACHE_PRECISION", ((Type.f32, Type.f32),)), + (hints.key_cache_precision, "KEY_CACHE_PRECISION", ((Type.f32, Type.f32),)), + (hints.value_cache_precision, "VALUE_CACHE_PRECISION", ((Type.f32, Type.f32),)), ( hints.activations_scale_factor, "ACTIVATIONS_SCALE_FACTOR", diff --git a/src/inference/include/openvino/runtime/properties.hpp b/src/inference/include/openvino/runtime/properties.hpp index e539b7e209fcb3..caff66750029fc 100644 --- a/src/inference/include/openvino/runtime/properties.hpp +++ b/src/inference/include/openvino/runtime/properties.hpp @@ -580,6 +580,18 @@ static constexpr Property dynamic_quantization */ static constexpr Property kv_cache_precision{"KV_CACHE_PRECISION"}; +/** + * @brief Hint for device to use specified precision for key cache compression + * @ingroup ov_runtime_cpp_prop_api + */ +static constexpr Property key_cache_precision{"KEY_CACHE_PRECISION"}; + +/** + * @brief Hint for device to use specified precision for value cache compression + * @ingroup ov_runtime_cpp_prop_api + */ +static constexpr Property value_cache_precision{"VALUE_CACHE_PRECISION"}; + /** * @brief Hint for device to use group_size for key cache compression * @ingroup ov_runtime_cpp_prop_api diff --git a/src/plugins/intel_cpu/src/compiled_model.cpp b/src/plugins/intel_cpu/src/compiled_model.cpp index 2fd048cc3a05e0..275fd0dbfff755 100644 --- a/src/plugins/intel_cpu/src/compiled_model.cpp +++ b/src/plugins/intel_cpu/src/compiled_model.cpp @@ -256,6 +256,8 @@ ov::Any CompiledModel::get_property(const std::string& name) const { RO_property(ov::intel_cpu::sparse_weights_decompression_rate.name()), RO_property(ov::hint::dynamic_quantization_group_size.name()), RO_property(ov::hint::kv_cache_precision.name()), + RO_property(ov::hint::key_cache_precision.name()), + RO_property(ov::hint::value_cache_precision.name()), RO_property(ov::hint::key_cache_group_size.name()), RO_property(ov::hint::value_cache_group_size.name()), }; @@ -335,6 +337,10 @@ ov::Any CompiledModel::get_property(const std::string& name) const { config.fcDynamicQuantizationGroupSize); } else if (name == ov::hint::kv_cache_precision) { return decltype(ov::hint::kv_cache_precision)::value_type(config.kvCachePrecision); + } else if (name == ov::hint::key_cache_precision) { + return decltype(ov::hint::key_cache_precision)::value_type(config.keyCachePrecision); + } else if (name == ov::hint::value_cache_precision) { + return decltype(ov::hint::value_cache_precision)::value_type(config.valueCachePrecision); } else if (name == ov::hint::key_cache_group_size) { return decltype(ov::hint::key_cache_group_size)::value_type(config.keyCacheGroupSize); } else if (name == ov::hint::value_cache_group_size) { diff --git a/src/plugins/intel_cpu/src/config.cpp b/src/plugins/intel_cpu/src/config.cpp index 257dee95546e34..32653626daa981 100644 --- a/src/plugins/intel_cpu/src/config.cpp +++ b/src/plugins/intel_cpu/src/config.cpp @@ -375,6 +375,26 @@ void Config::readProperties(const ov::AnyMap& prop, const ModelType modelType) { ov::hint::kv_cache_precision.name(), ". 
Supported values: u8, bf16, f16, f32"); } + } else if (key == ov::hint::key_cache_precision.name() || key == ov::hint::value_cache_precision.name()) { + try { + kvCachePrecisionSetExplicitly = true; + auto const prec = val.as(); + if (key == ov::hint::key_cache_precision.name()) { + if (one_of(prec, ov::element::f32, ov::element::f16, ov::element::bf16, ov::element::u8)) { + keyCachePrecision = prec; + } else { + OPENVINO_THROW("keyCachePrecision doesn't support value ", prec); + } + } else { + if (one_of(prec, ov::element::f32, ov::element::f16, ov::element::bf16, ov::element::u8, ov::element::u4, ov::element::i4)) { + valueCachePrecision = prec; + } else { + OPENVINO_THROW("valueCachePrecision doesn't support value ", prec); + } + } + } catch (ov::Exception&) { + + } } else if (key == ov::hint::key_cache_group_size.name() || key == ov::hint::value_cache_group_size.name()) { try { auto const groupSize = val.as(); @@ -432,6 +452,8 @@ void Config::readProperties(const ov::AnyMap& prop, const ModelType modelType) { } if (!kvCachePrecisionSetExplicitly) { kvCachePrecision = ov::element::f32; + valueCachePrecision = ov::element::f32; + keyCachePrecision = ov::element::f32; } } diff --git a/src/plugins/intel_cpu/src/config.h b/src/plugins/intel_cpu/src/config.h index b6aeeaca38e0ee..bcde841814d09c 100644 --- a/src/plugins/intel_cpu/src/config.h +++ b/src/plugins/intel_cpu/src/config.h @@ -58,9 +58,13 @@ struct Config { #endif #if defined(OPENVINO_ARCH_X86_64) ov::element::Type kvCachePrecision = ov::element::u8; + ov::element::Type keyCachePrecision = ov::element::u8; + ov::element::Type valueCachePrecision = ov::element::u8; size_t rtCacheCapacity = 5000ul; #else ov::element::Type kvCachePrecision = ov::element::f16; + ov::element::Type keyCachePrecision = ov::element::f16; + ov::element::Type valueCachePrecision = ov::element::f16; // TODO: Executor cache may leads to incorrect behavior on oneDNN ACL primitives size_t rtCacheCapacity = 0ul; #endif diff --git a/src/plugins/intel_cpu/src/plugin.cpp b/src/plugins/intel_cpu/src/plugin.cpp index 6fdbf7a4ea4dee..f16f504ee2f5a0 100644 --- a/src/plugins/intel_cpu/src/plugin.cpp +++ b/src/plugins/intel_cpu/src/plugin.cpp @@ -390,6 +390,14 @@ ov::Any Plugin::get_property(const std::string& name, const ov::AnyMap& options) engConfig.fcDynamicQuantizationGroupSize); } else if (name == ov::hint::kv_cache_precision) { return decltype(ov::hint::kv_cache_precision)::value_type(engConfig.kvCachePrecision); + } else if (name == ov::hint::key_cache_precision) { + return decltype(ov::hint::key_cache_precision)::value_type(engConfig.keyCachePrecision); + } else if (name == ov::hint::value_cache_precision) { + return decltype(ov::hint::value_cache_precision)::value_type(engConfig.valueCachePrecision); + } else if (name == ov::hint::key_cache_group_size) { + return decltype(ov::hint::key_cache_group_size)::value_type(engConfig.keyCacheGroupSize); + } else if (name == ov::hint::value_cache_group_size) { + return decltype(ov::hint::value_cache_group_size)::value_type(engConfig.valueCacheGroupSize); } return get_ro_property(name, options); } @@ -433,6 +441,10 @@ ov::Any Plugin::get_ro_property(const std::string& name, const ov::AnyMap& optio RW_property(ov::intel_cpu::sparse_weights_decompression_rate.name()), RW_property(ov::hint::dynamic_quantization_group_size.name()), RW_property(ov::hint::kv_cache_precision.name()), + RW_property(ov::hint::key_cache_precision.name()), + RW_property(ov::hint::value_cache_precision.name()), + 
RW_property(ov::hint::key_cache_group_size.name()), + RW_property(ov::hint::value_cache_group_size.name()), }; OPENVINO_SUPPRESS_DEPRECATED_START diff --git a/src/plugins/intel_cpu/tests/functional/custom/behavior/ov_executable_network/properties.cpp b/src/plugins/intel_cpu/tests/functional/custom/behavior/ov_executable_network/properties.cpp index 29e5fbbe982542..59fd31cdb34303 100644 --- a/src/plugins/intel_cpu/tests/functional/custom/behavior/ov_executable_network/properties.cpp +++ b/src/plugins/intel_cpu/tests/functional/custom/behavior/ov_executable_network/properties.cpp @@ -41,6 +41,8 @@ TEST_F(OVClassConfigTestCPU, smoke_CpuExecNetworkSupportedPropertiesAreAvailable RO_property(ov::intel_cpu::sparse_weights_decompression_rate.name()), RO_property(ov::hint::dynamic_quantization_group_size.name()), RO_property(ov::hint::kv_cache_precision.name()), + RO_property(ov::hint::key_cache_precision.name()), + RO_property(ov::hint::value_cache_precision.name()), RO_property(ov::hint::key_cache_group_size.name()), RO_property(ov::hint::value_cache_group_size.name()), }; diff --git a/src/plugins/intel_cpu/tests/functional/custom/behavior/ov_plugin/properties.cpp b/src/plugins/intel_cpu/tests/functional/custom/behavior/ov_plugin/properties.cpp index 696f73f27e1142..589f0641eae0e8 100644 --- a/src/plugins/intel_cpu/tests/functional/custom/behavior/ov_plugin/properties.cpp +++ b/src/plugins/intel_cpu/tests/functional/custom/behavior/ov_plugin/properties.cpp @@ -56,6 +56,8 @@ TEST_F(OVClassConfigTestCPU, smoke_PluginAllSupportedPropertiesAreAvailable) { RW_property(ov::intel_cpu::sparse_weights_decompression_rate.name()), RW_property(ov::hint::dynamic_quantization_group_size.name()), RW_property(ov::hint::kv_cache_precision.name()), + RW_property(ov::hint::key_cache_precision.name()), + RW_property(ov::hint::value_cache_precision.name()), RW_property(ov::hint::key_cache_group_size.name()), RW_property(ov::hint::value_cache_group_size.name()), }; From a34ce8bc61a90b852583879f98bfa50cb5ef2369 Mon Sep 17 00:00:00 2001 From: Zhang Yi Date: Wed, 11 Dec 2024 15:17:48 +0800 Subject: [PATCH 15/28] [CPU]add cache precision check Signed-off-by: Zhang Yi3 --- src/plugins/intel_cpu/src/nodes/paged_attn.cpp | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/plugins/intel_cpu/src/nodes/paged_attn.cpp b/src/plugins/intel_cpu/src/nodes/paged_attn.cpp index 468fc72b28296a..b78510d6b8934f 100644 --- a/src/plugins/intel_cpu/src/nodes/paged_attn.cpp +++ b/src/plugins/intel_cpu/src/nodes/paged_attn.cpp @@ -196,6 +196,14 @@ void PagedAttention::execute(dnnl::stream strm) { bool PagedAttention::isSupportedOperation(const std::shared_ptr& op, std::string& errorMessage) noexcept { try { + auto vCachePrecision = op->get_input_element_type(PagedAttentionExecutor::ID_VCACHE); + auto kCachePrecision = op->get_input_element_type(PagedAttentionExecutor::ID_KCACHE); + if (one_of(vCachePrecision, ov::element::i4, ov::element::u4, ov::element::u8)) { + if (kCachePrecision != ov::element::u8) { + errorMessage = "PageAttn key value cache compression doesn't support key cache prec " + kCachePrecision.to_string() + " value cache prec " + vCachePrecision.to_string(); + return false; + } + } int orgInput = static_cast(op->get_input_size()); if (op->get_type_name() == std::string("PagedAttentionExtension") && orgInput == PagedAttentionExecutor::ID_SLIDING_WINDOW + 1) { return true; From fe6c311847684c7f85a8b9c293f9dc4179cc4f42 Mon Sep 17 00:00:00 2001 From: Zhang Yi Date: Thu, 12 Dec 2024 10:13:32 +0800 Subject: [PATCH 16/28] 
[CPU]fix code style of config.cpp Signed-off-by: Zhang Yi3 --- src/plugins/intel_cpu/src/config.cpp | 14 +++++++++++++- src/plugins/intel_cpu/src/nodes/paged_attn.cpp | 3 ++- 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/src/plugins/intel_cpu/src/config.cpp b/src/plugins/intel_cpu/src/config.cpp index 32653626daa981..6262027e344032 100644 --- a/src/plugins/intel_cpu/src/config.cpp +++ b/src/plugins/intel_cpu/src/config.cpp @@ -393,7 +393,19 @@ void Config::readProperties(const ov::AnyMap& prop, const ModelType modelType) { } } } catch (ov::Exception&) { - + if (key == ov::hint::key_cache_precision.name()) { + OPENVINO_THROW("Wrong value ", + val.as(), + " for property key ", + ov::hint::key_cache_precision.name(), + ". Supported values: u8, bf16, f16, f32"); + } else { + OPENVINO_THROW("Wrong value ", + val.as(), + " for property key ", + ov::hint::value_cache_precision.name(), + ". Supported values: u4, s4, u8, bf16, f16, f32"); + } } } else if (key == ov::hint::key_cache_group_size.name() || key == ov::hint::value_cache_group_size.name()) { try { diff --git a/src/plugins/intel_cpu/src/nodes/paged_attn.cpp b/src/plugins/intel_cpu/src/nodes/paged_attn.cpp index b78510d6b8934f..8d927b2ed68ba2 100644 --- a/src/plugins/intel_cpu/src/nodes/paged_attn.cpp +++ b/src/plugins/intel_cpu/src/nodes/paged_attn.cpp @@ -200,7 +200,8 @@ bool PagedAttention::isSupportedOperation(const std::shared_ptr& auto kCachePrecision = op->get_input_element_type(PagedAttentionExecutor::ID_KCACHE); if (one_of(vCachePrecision, ov::element::i4, ov::element::u4, ov::element::u8)) { if (kCachePrecision != ov::element::u8) { - errorMessage = "PageAttn key value cache compression doesn't support key cache prec " + kCachePrecision.to_string() + " value cache prec " + vCachePrecision.to_string(); + errorMessage = "PageAttn key value cache compression doesn't support key cache prec " + + kCachePrecision.to_string() + " value cache prec " + vCachePrecision.to_string(); return false; } } From 8faadd84eeb364d0fe4534fac12375a81c0ddf99 Mon Sep 17 00:00:00 2001 From: Zhang Yi Date: Tue, 17 Dec 2024 16:51:27 +0800 Subject: [PATCH 17/28] [CPU]pre calculate count --- .../nodes/kernels/scaled_attn/attn_quant.cpp | 32 ++- .../kernels/scaled_attn/attn_quant_kernel.hpp | 10 +- .../nodes/kernels/scaled_attn/executor_pa.cpp | 217 +++++++++--------- 3 files changed, 122 insertions(+), 137 deletions(-) diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant.cpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant.cpp index 19721b4961fbb0..fb6dc8439ac9bf 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant.cpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant.cpp @@ -402,8 +402,6 @@ static void paged_attn_quant_mt(const ov::intel_cpu::PlainTensor& k_src, const size_t value_group_size) { size_t B = k_src.m_dims[0], H = k_src.m_dims[1], L1 = k_src.m_dims[2], S = k_src.m_dims[3], SV = v_src.m_dims[3]; size_t block_size = k_dst.m_dims[2]; - size_t _key_group_size = key_group_size == 0 ? S : key_group_size; - size_t _value_group_size = value_group_size == 0 ? 
SV : value_group_size; parallel_for3d(B, L1, H, [&](size_t b, size_t m, size_t h) { auto slot = slot_mapping.ptr(b)[m]; if (slot < 0) @@ -414,7 +412,7 @@ static void paged_attn_quant_mt(const ov::intel_cpu::PlainTensor& k_src, // |scale(f32)|zeropoint(f32)|quantized feature(u8,idx_1)|quantized feature(u8,idx_2)|...|quantized // feature(u8,idx_S)| for (size_t src_offset = 0, dst_offset = 0; src_offset < S; - src_offset += _key_group_size, dst_offset += _key_group_size + sizeof(float) + sizeof(float)) { + src_offset += key_group_size, dst_offset += key_group_size + sizeof(float) + sizeof(float)) { auto p_k = reinterpret_cast( k_dst.ptr::value_type>(block_number, h, @@ -426,13 +424,13 @@ static void paged_attn_quant_mt(const ov::intel_cpu::PlainTensor& k_src, block_offset, dst_offset) + sizeof(float) + sizeof(float), - _key_group_size, + key_group_size, p_k[0], p_k[1]); } for (size_t src_offset = 0, dst_offset = 0; src_offset < SV; - src_offset += _value_group_size, dst_offset += _value_group_size + sizeof(float) + sizeof(float)) { + src_offset += value_group_size, dst_offset += value_group_size + sizeof(float) + sizeof(float)) { auto p_v = reinterpret_cast( v_dst.ptr::value_type>(block_number, h, @@ -444,7 +442,7 @@ static void paged_attn_quant_mt(const ov::intel_cpu::PlainTensor& k_src, block_offset, dst_offset) + sizeof(float) + sizeof(float), - _value_group_size, + value_group_size, p_v[0], p_v[1]); } @@ -464,8 +462,6 @@ static void paged_attn_quant_mt(const ov::intel_cpu::PlainTensor& k_src, const size_t value_group_size) { size_t B = k_src.m_dims[0], H = k_src.m_dims[1], L1 = k_src.m_dims[2], S = k_src.m_dims[3], SV = v_src.m_dims[3]; size_t block_size = k_dst.m_dims[2]; - size_t _key_group_size = key_group_size == 0 ? S : key_group_size; - size_t _value_group_size = value_group_size == 0 ? 
SV : value_group_size; size_t sub_byte_multiplier = 8 / v_dst.get_precision().bitwidth(); parallel_for3d(B, L1, H, [&](size_t b, size_t m, size_t h) { auto slot = slot_mapping.ptr(b)[m]; @@ -477,7 +473,7 @@ static void paged_attn_quant_mt(const ov::intel_cpu::PlainTensor& k_src, // |scale(f32)|zeropoint(f32)|quantized feature(u8,idx_1)|quantized feature(u8,idx_2)|...|quantized // feature(u8,idx_S)| for (size_t src_offset = 0, dst_offset = 0; src_offset < S; - src_offset += _key_group_size, dst_offset += _key_group_size + sizeof(float) + sizeof(float)) { + src_offset += key_group_size, dst_offset += key_group_size + sizeof(float) + sizeof(float)) { auto p_k = reinterpret_cast( k_dst.ptr::value_type>(block_number, h, @@ -489,13 +485,13 @@ static void paged_attn_quant_mt(const ov::intel_cpu::PlainTensor& k_src, block_offset, dst_offset) + sizeof(float) + sizeof(float), - _key_group_size, + key_group_size, p_k[0], p_k[1]); } - for (size_t src_offset = 0, dst_offset = 0; src_offset < SV; src_offset += _value_group_size, - dst_offset += _value_group_size / sub_byte_multiplier + sizeof(float) + sizeof(float)) { + for (size_t src_offset = 0, dst_offset = 0; src_offset < SV; src_offset += value_group_size, + dst_offset += value_group_size / sub_byte_multiplier + sizeof(float) + sizeof(float)) { uint8_t* v_base = reinterpret_cast( v_dst.m_ptr.get() + (block_number * v_dst.m_strides[0] + h * v_dst.m_strides[1] + block_offset * v_dst.m_strides[2]) / @@ -503,7 +499,7 @@ static void paged_attn_quant_mt(const ov::intel_cpu::PlainTensor& k_src, dst_offset); auto p_v = reinterpret_cast(v_base); uint8_t* v_ptr = v_base + sizeof(float) * 2; - quant_u4(v_src.ptr(b, h, m, src_offset), v_ptr, _value_group_size, p_v[0], p_v[1]); + quant_u4(v_src.ptr(b, h, m, src_offset), v_ptr, value_group_size, p_v[0], p_v[1]); } }); } @@ -521,8 +517,6 @@ static void paged_attn_quant_mt(const ov::intel_cpu::PlainTensor& k_src, const size_t value_group_size) { size_t B = k_src.m_dims[0], H = k_src.m_dims[1], L1 = k_src.m_dims[2], S = k_src.m_dims[3], SV = v_src.m_dims[3]; size_t block_size = k_dst.m_dims[2]; - size_t _key_group_size = key_group_size == 0 ? S : key_group_size; - size_t _value_group_size = value_group_size == 0 ? 
SV : value_group_size; size_t sub_byte_multiplier = 8 / v_dst.get_precision().bitwidth(); parallel_for3d(B, L1, H, [&](size_t b, size_t m, size_t h) { auto slot = slot_mapping.ptr(b)[m]; @@ -534,7 +528,7 @@ static void paged_attn_quant_mt(const ov::intel_cpu::PlainTensor& k_src, // |scale(f32)|zeropoint(f32)|quantized feature(u8,idx_1)|quantized feature(u8,idx_2)|...|quantized // feature(u8,idx_S)| for (size_t src_offset = 0, dst_offset = 0; src_offset < S; - src_offset += _key_group_size, dst_offset += _key_group_size + sizeof(float) + sizeof(float)) { + src_offset += key_group_size, dst_offset += key_group_size + sizeof(float) + sizeof(float)) { auto p_k = reinterpret_cast( k_dst.ptr::value_type>(block_number, h, @@ -546,13 +540,13 @@ static void paged_attn_quant_mt(const ov::intel_cpu::PlainTensor& k_src, block_offset, dst_offset) + sizeof(float) + sizeof(float), - _key_group_size, + key_group_size, p_k[0], p_k[1]); } for (size_t src_offset = 0, dst_offset = 0; src_offset < SV; - src_offset += _value_group_size, dst_offset += _value_group_size / sub_byte_multiplier + sizeof(float)) { + src_offset += value_group_size, dst_offset += value_group_size / sub_byte_multiplier + sizeof(float)) { uint8_t* v_base = reinterpret_cast( v_dst.m_ptr.get() + (block_number * v_dst.m_strides[0] + h * v_dst.m_strides[1] + block_offset * v_dst.m_strides[2]) / @@ -560,7 +554,7 @@ static void paged_attn_quant_mt(const ov::intel_cpu::PlainTensor& k_src, dst_offset); auto p_v = reinterpret_cast(v_base); uint8_t* v_ptr = v_base + sizeof(float); - quant_s4(v_src.ptr(b, h, m, src_offset), v_ptr, _value_group_size, p_v[0]); + quant_s4(v_src.ptr(b, h, m, src_offset), v_ptr, value_group_size, p_v[0]); } }); } diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant_kernel.hpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant_kernel.hpp index 9b9944350f14d0..9bb1da2ccc4ec8 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant_kernel.hpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant_kernel.hpp @@ -68,6 +68,7 @@ void attn_dequant_u4_kernel(const uint8_t* src, TDST* dst, size_t n, float scale #if defined(HAVE_AVX512F) auto v_zp = _mm512_set1_ps(zp); auto v_scale = _mm512_set1_ps(scale); + auto v_zp_scale = _mm512_set1_ps(zp * scale); for (; i + vec_len_f32_avx512 * 2 <= n; i += vec_len_f32_avx512 * 2) { auto data = _mm_loadu_si128(reinterpret_cast<__m128i*>(src_nc + i / 2)); auto v_i32 = _mm512_cvtepu8_epi32(data); @@ -78,12 +79,9 @@ void attn_dequant_u4_kernel(const uint8_t* src, TDST* dst, size_t n, float scale auto mask = _mm512_set1_epi32(0x0F); auto v_512_high_half = _mm512_and_si512(v_i32, mask); auto v_f32_high_half = _mm512_cvtepi32_ps(v_512_high_half); - // q - zp - v_f32_low_half = _mm512_sub_ps(v_f32_low_half, v_zp); - v_f32_high_half = _mm512_sub_ps(v_f32_high_half, v_zp); - // (q - zp) * scale - v_f32_low_half = _mm512_mul_ps(v_f32_low_half, v_scale); - v_f32_high_half = _mm512_mul_ps(v_f32_high_half, v_scale); + // q * scale- zp * scale + v_f32_low_half = _mm512_fmsub_ps(v_f32_low_half, v_scale, v_zp_scale); + v_f32_high_half = _mm512_fmsub_ps(v_f32_high_half, v_scale, v_zp_scale); __m512i idx1 = _mm512_set_epi32(23, 7, 22, 6, 21, 5, 20, 4, 19, 3, 18, 2, 17, 1, 16, 0); __m512i idx2 = _mm512_set_epi32(31, 15, 30, 14, 29, 13, 28, 12, 27, 11, 26, 10, 25, 9, 24, 8); __m512 first_half = _mm512_permutex2var_ps(v_f32_low_half, idx1, v_f32_high_half); diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp 
b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp index 3acb0c7447db5c..3e78f645b9c95a 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp @@ -75,7 +75,7 @@ void cvt_copy(TA* dst, TB* src, size_t n) { template ::type = true> -static void attn_acc_value_block(float* out, float* weight, T* v, size_t S, size_t block_size, size_t group_size = 0) { +static void attn_acc_value_block(float* out, float* weight, T* v, size_t S, size_t block_size, size_t group_size) { # if defined(HAVE_AVX512F) size_t j = 0; for (; j + 4 <= block_size; j += 4) { @@ -209,16 +209,15 @@ static void attn_acc_value_block(float* out, float* weight, uint8_t* v, size_t S, - size_t block_size, - size_t group_size = 0) { + const size_t block_size, + const size_t group_size) { // The layout for per token per head: // |scale(f32)|zeropoint(f32)|quantized feature(u8,idx_1)|quantized feature(u8,idx_2)|...|quantized // feature(u8,idx_S)| The quantized feature will start from 8bytes=sizeof(float)+sizeof(float) size_t src_offset = 0; size_t dst_offset = 0; - const size_t _group_size = group_size ? group_size : S; const size_t params_offset = sizeof(float) * 2; - const size_t src_stride = S / _group_size * (_group_size + params_offset); + const size_t src_stride = S / group_size * (group_size + params_offset); # if defined(HAVE_AVX512F) size_t j = 0; @@ -242,7 +241,7 @@ static void attn_acc_value_block(float* out, auto zp3 = _mm512_set1_ps(v_f3[1]); uint8_t* v_data_ptr = v + src_offset + params_offset; size_t i = 0; - for (; i + vec_len_f32_avx512 <= _group_size; i += vec_len_f32_avx512) { + for (; i + vec_len_f32_avx512 <= group_size; i += vec_len_f32_avx512) { auto v_out = mm512_uni_loadu_ps(out + dst_offset + i); auto v0 = _mm512_sub_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32( _mm_loadu_si128(reinterpret_cast<__m128i*>(v_data_ptr + i)))), @@ -262,14 +261,14 @@ static void attn_acc_value_block(float* out, v_out = _mm512_fmadd_ps(attn_w_vec3, v3, v_out); _mm512_storeu_ps(out + dst_offset + i, v_out); } - for (; i < _group_size; i++) { + for (; i < group_size; i++) { out[i] += weight[0] * (v_data_ptr[i] - v_f0[1]) * v_f0[0]; out[i] += weight[1] * (v_data_ptr[i + src_stride] - v_f1[1]) * v_f1[0]; out[i] += weight[2] * (v_data_ptr[i + 2 * src_stride] - v_f2[1]) * v_f2[0]; out[i] += weight[3] * (v_data_ptr[i + 3 * src_stride] - v_f3[1]) * v_f3[0]; } - dst_offset += _group_size; - src_offset += _group_size + params_offset; + dst_offset += group_size; + src_offset += group_size + params_offset; } weight += 4; v += 4 * src_stride; @@ -284,7 +283,7 @@ static void attn_acc_value_block(float* out, auto attn_w_vec0 = _mm512_set1_ps(weight[0] * v_f0[0]); auto zp0 = _mm512_set1_ps(v_f0[1]); size_t i = 0; - for (; i + vec_len_f32_avx512 <= _group_size; i += vec_len_f32_avx512) { + for (; i + vec_len_f32_avx512 <= group_size; i += vec_len_f32_avx512) { auto v_out = mm512_uni_loadu_ps((out + dst_offset + i)); auto v0 = _mm512_sub_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32( _mm_loadu_si128(reinterpret_cast<__m128i*>(v_data_ptr + i)))), @@ -293,11 +292,11 @@ static void attn_acc_value_block(float* out, _mm512_storeu_ps((out + dst_offset + i), v_out); } - for (; i < _group_size; i++) { + for (; i < group_size; i++) { out[dst_offset + i] += weight[0] * (v_data_ptr[i] - v_f0[1]) * v_f0[0]; } - dst_offset += _group_size; - src_offset += _group_size + params_offset; + dst_offset += group_size; + src_offset += group_size + 
params_offset; } v += src_stride; weight++; @@ -316,7 +315,7 @@ static void attn_acc_value_block(float* out, auto zp0 = _mm256_set1_ps(v_f0[1]); size_t i = 0; v += 8; - for (; i + vec_len_f32_avx2 <= _group_size; i += vec_len_f32_avx2) { + for (; i + vec_len_f32_avx2 <= group_size; i += vec_len_f32_avx2) { auto v_out = mm256_uni_loadu_ps(out + dst_offset + i); auto v0 = _mm256_sub_ps(_mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( _mm_loadl_epi64(reinterpret_cast<__m128i*>(v_data_ptr + i)))), @@ -325,11 +324,11 @@ static void attn_acc_value_block(float* out, mm256_uni_storeu_ps(out + dst_offset + i, v_out); } - for (; i < _group_size; i++) { + for (; i < group_size; i++) { out[dst_offset + i] += weight[0] * (v_data_ptr[i] - v_f0[1]) * v_f0[0]; } - dst_offset += _group_size; - src_offset += _group_size + params_offset; + dst_offset += group_size; + src_offset += group_size + params_offset; } v += src_stride; weight++; @@ -341,11 +340,11 @@ static void attn_acc_value_block(float* out, src_offset = 0; while (dst_offset < S) { auto v0 = reinterpret_cast(v + src_offset); - for (size_t i = 0; i < _group_size; i++) { + for (size_t i = 0; i < group_size; i++) { out[dst_offset + i] += weight[j] * (v[i + src_offset + params_offset] - v0[1]) * v0[0]; } - dst_offset += _group_size; - src_offset += _group_size + params_offset; + dst_offset += group_size; + src_offset += group_size + params_offset; } v += src_stride; } @@ -358,14 +357,13 @@ static void attn_acc_value_block(float* out, float* weight, uint8_t* v, size_t S, - size_t block_size, - size_t group_size = 0) { + const size_t block_size, + const size_t group_size) { size_t src_offset = 0; size_t dst_offset = 0; - const size_t _group_size = group_size ? group_size : S; const size_t params_offset = sizeof(float) * 2; auto sub_byte_multiplyer = 8 / 4; - const size_t src_stride = S / _group_size * (_group_size / sub_byte_multiplyer + params_offset); + const size_t src_stride = S / group_size * (group_size / sub_byte_multiplyer + params_offset); auto extract_half_byte = [](uint8_t val, bool high_half) -> uint8_t { uint8_t shift = high_half ? 
0 : 4; @@ -380,7 +378,7 @@ static void attn_acc_value_block(float* out, # if defined(HAVE_AVX512F) auto attn_w_vec0 = _mm512_set1_ps(weight[j] * v0[0]); auto v_zp = _mm512_set1_ps(v0[1]); - for (; i + vec_len_f32_avx512 * 2 <= _group_size; i += vec_len_f32_avx512 * 2) { + for (; i + vec_len_f32_avx512 * 2 <= group_size; i += vec_len_f32_avx512 * 2) { auto data = _mm_loadu_si128(reinterpret_cast<__m128i*>(v + i / 2 + src_offset + params_offset)); auto v_i32 = _mm512_cvtepu8_epi32(data); @@ -409,7 +407,7 @@ static void attn_acc_value_block(float* out, # elif defined(HAVE_AVX2) || defined(HAVE_AVX512F) auto v256_attn_w_vec0 = _mm256_set1_ps(weight[j] * v0[0]); auto v256_zp = _mm256_set1_ps(v0[1]); - for (; i + vec_len_f32_avx2 * 2 <= _group_size; i += vec_len_f32_avx2 * 2) { + for (; i + vec_len_f32_avx2 * 2 <= group_size; i += vec_len_f32_avx2 * 2) { auto data = _mm_loadl_epi64(reinterpret_cast<__m128i*>(v + i / 2 + src_offset + params_offset)); auto v_i32 = _mm256_cvtepu8_epi32(data); @@ -438,15 +436,15 @@ static void attn_acc_value_block(float* out, mm256_uni_storeu_ps(out + dst_offset + i + vec_len_f32_avx2, v_out1); } # endif - for (; i < _group_size; i += 2) { + for (; i < group_size; i += 2) { uint8_t data = v[i / 2 + src_offset + params_offset]; float tmp0 = extract_half_byte(data, static_cast(i % 2)); float tmp1 = extract_half_byte(data, static_cast((i + 1) % 2)); out[dst_offset + i] += weight[j] * (tmp0 - v0[1]) * v0[0]; out[dst_offset + i + 1] += weight[j] * (tmp1 - v0[1]) * v0[0]; } - dst_offset += _group_size; - src_offset += _group_size / sub_byte_multiplyer + params_offset; + dst_offset += group_size; + src_offset += group_size / sub_byte_multiplyer + params_offset; } v += src_stride; } @@ -459,14 +457,13 @@ static void attn_acc_value_block(float* out, float* weight, uint8_t* v, size_t S, - size_t block_size, - size_t group_size = 0) { + const size_t block_size, + const size_t group_size) { size_t src_offset = 0; size_t dst_offset = 0; - const size_t _group_size = group_size ? group_size : S; const size_t params_offset = sizeof(float); auto sub_byte_multiplyer = 8 / 4; - const size_t src_stride = S / _group_size * (_group_size / sub_byte_multiplyer + params_offset); + const size_t src_stride = S / group_size * (group_size / sub_byte_multiplyer + params_offset); auto extract_half_byte = [](uint8_t val, bool high_half) -> uint8_t { uint8_t shift = high_half ? 
0 : 4; @@ -481,7 +478,7 @@ static void attn_acc_value_block(float* out, size_t i = 0; # if defined(HAVE_AVX512F) auto attn_w_vec0 = _mm512_set1_ps(weight[j] * v0[0]); - for (; i + vec_len_f32_avx512 * 2 <= _group_size; i += vec_len_f32_avx512 * 2) { + for (; i + vec_len_f32_avx512 * 2 <= group_size; i += vec_len_f32_avx512 * 2) { auto data = _mm_loadu_si128(reinterpret_cast<__m128i*>(v + i / 2 + src_offset + params_offset)); auto v_i32 = _mm512_cvtepi8_epi32(data); // cvt to f32 @@ -505,7 +502,7 @@ static void attn_acc_value_block(float* out, } # elif defined(HAVE_AVX2) || defined(HAVE_AVX512F) auto v256_attn_w_vec0 = _mm256_set1_ps(weight[j] * v0[0]); - for (; i + vec_len_f32_avx2 * 2 <= _group_size; i += vec_len_f32_avx2 * 2) { + for (; i + vec_len_f32_avx2 * 2 <= group_size; i += vec_len_f32_avx2 * 2) { auto data = _mm_loadl_epi64(reinterpret_cast<__m128i*>(v + i / 2 + src_offset + params_offset)); auto v_i32 = _mm256_cvtepi8_epi32(data); @@ -529,7 +526,7 @@ static void attn_acc_value_block(float* out, mm256_uni_storeu_ps(out + dst_offset + i + vec_len_f32_avx2, v_out1); } # endif - for (; i < _group_size; i += 2) { + for (; i < group_size; i += 2) { uint8_t data = v[i / 2 + src_offset + params_offset]; float tmp0 = extract_half_byte(data, static_cast(i % 2)); tmp0 = tmp0 > 8 ? (tmp0 - 16) : tmp0; @@ -538,15 +535,15 @@ static void attn_acc_value_block(float* out, out[dst_offset + i] += weight[j] * (tmp0)*v0[0]; out[dst_offset + i + 1] += weight[j] * (tmp1)*v0[0]; } - dst_offset += _group_size; - src_offset += _group_size / sub_byte_multiplyer + params_offset; + dst_offset += group_size; + src_offset += group_size / sub_byte_multiplyer + params_offset; } v += src_stride; } } template -static void dot_product_block(TA* a, TB* b, float* c, size_t n, size_t block_size, size_t group_size = 0) { +static void dot_product_block(TA* a, TB* b, float* c, size_t n, size_t block_size, size_t group_size) { # if defined(HAVE_AVX512F) size_t j = 0; for (; j + 4 <= block_size; j += 4) { @@ -658,15 +655,14 @@ static void dot_product_block(TA* a, TB* b, float* c, size_t n, size_t block_siz } template -static void dot_product_block(TA* a, uint8_t* b, float* c, size_t n, size_t block_size, size_t group_size = 0) { +static void dot_product_block(TA* a, uint8_t* b, float* c, size_t n, size_t block_size, const size_t group_size) { // The layout for per token per head: // |scale(f32)|zeropoint(f32)|quantized feature(u8,idx_1)|quantized feature(u8,idx_2)|...|quantized // feature(u8,idx_S)| The quantized feature will start from 8bytes=sizeof(float)+sizeof(float) size_t src_offset = 0; size_t dst_offset = 0; - const size_t _group_size = group_size ? 
group_size : n; const size_t params_offset = sizeof(float) * 2; - const size_t src_stride = n / _group_size * (_group_size + params_offset); + const size_t src_stride = n / group_size * (group_size + params_offset); # if defined(HAVE_AVX512F) size_t j = 0; for (; j + 4 <= block_size; j += 4) { @@ -691,7 +687,7 @@ static void dot_product_block(TA* a, uint8_t* b, float* c, size_t n, size_t bloc auto v_zp3 = _mm512_set1_ps(b3[1]); size_t i = 0; uint8_t* b_data_ptr = b + src_offset + params_offset; - for (; i + vec_len_f32_avx512 <= _group_size; i += vec_len_f32_avx512) { + for (; i + vec_len_f32_avx512 <= group_size; i += vec_len_f32_avx512) { auto va = mm512_uni_loadu_ps(a + dst_offset + i); auto vb0 = _mm512_sub_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32( _mm_loadu_si128(reinterpret_cast<__m128i*>(b_data_ptr + i)))), @@ -715,7 +711,7 @@ static void dot_product_block(TA* a, uint8_t* b, float* c, size_t n, size_t bloc float group_sum1 = _mm512_reduce_add_ps(vsum1); float group_sum2 = _mm512_reduce_add_ps(vsum2); float group_sum3 = _mm512_reduce_add_ps(vsum3); - for (; i < _group_size; i++) { + for (; i < group_size; i++) { group_sum0 += a[i + dst_offset] * (b_data_ptr[i] - b0[1]); group_sum1 += a[i + dst_offset] * (b_data_ptr[i + src_stride] - b1[1]); group_sum2 += a[i + dst_offset] * (b_data_ptr[i + 2 * src_stride] - b2[1]); @@ -725,8 +721,8 @@ static void dot_product_block(TA* a, uint8_t* b, float* c, size_t n, size_t bloc sum1 += group_sum1 * b1[0]; sum2 += group_sum2 * b2[0]; sum3 += group_sum3 * b3[0]; - dst_offset += _group_size; - src_offset += _group_size + params_offset; + dst_offset += group_size; + src_offset += group_size + params_offset; } c[0] = sum0; c[1] = sum1; @@ -745,7 +741,7 @@ static void dot_product_block(TA* a, uint8_t* b, float* c, size_t n, size_t bloc auto v_zp = _mm512_set1_ps(b0[1]); size_t i = 0; uint8_t* b_data_ptr = b + src_offset + params_offset; - for (; i + vec_len_f32_avx512 <= _group_size; i += vec_len_f32_avx512) { + for (; i + vec_len_f32_avx512 <= group_size; i += vec_len_f32_avx512) { auto va = mm512_uni_loadu_ps(a + dst_offset + i); auto vb = _mm512_sub_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32( _mm_loadu_si128(reinterpret_cast<__m128i*>(b_data_ptr + i)))), @@ -753,12 +749,12 @@ static void dot_product_block(TA* a, uint8_t* b, float* c, size_t n, size_t bloc vsum = _mm512_fmadd_ps(va, vb, vsum); } float group_sum = _mm512_reduce_add_ps(vsum); - for (; i < _group_size; i++) { + for (; i < group_size; i++) { group_sum += a[i + dst_offset] * (b_data_ptr[i] - b0[1]); } sum += group_sum * b0[0]; - dst_offset += _group_size; - src_offset += _group_size + params_offset; + dst_offset += group_size; + src_offset += group_size + params_offset; } b += src_stride; *c++ = sum; @@ -788,7 +784,7 @@ static void dot_product_block(TA* a, uint8_t* b, float* c, size_t n, size_t bloc auto v_zp3 = _mm256_set1_ps(b3[1]); size_t i = 0; uint8_t* b_data_ptr = b + src_offset + params_offset; - for (; i + vec_len_f32_avx2 <= _group_size; i += vec_len_f32_avx2) { + for (; i + vec_len_f32_avx2 <= group_size; i += vec_len_f32_avx2) { auto va = mm256_uni_loadu_ps(a + dst_offset + i); auto vb0 = _mm256_sub_ps(_mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( _mm_loadl_epi64(reinterpret_cast<__m128i*>(b_data_ptr + i)))), @@ -816,7 +812,7 @@ static void dot_product_block(TA* a, uint8_t* b, float* c, size_t n, size_t bloc float group_sum1 = _mm256_cvtss_f32(vsum1); float group_sum2 = _mm256_cvtss_f32(vsum2); float group_sum3 = _mm256_cvtss_f32(vsum3); - for (; i < _group_size; i++) { + for (; i 
< group_size; i++) { group_sum0 += a[dst_offset + i] * (b[i] - b0[1]); group_sum1 += a[dst_offset + i] * (b[i + src_stride] - b1[1]); group_sum2 += a[dst_offset + i] * (b[i + 2 * src_stride] - b2[1]); @@ -826,8 +822,8 @@ static void dot_product_block(TA* a, uint8_t* b, float* c, size_t n, size_t bloc sum1 += group_sum1 * b1[0]; sum2 += group_sum2 * b2[0]; sum3 += group_sum3 * b3[0]; - dst_offset += _group_size; - src_offset += _group_size + params_offset; + dst_offset += group_size; + src_offset += group_size + params_offset; } c[0] = sum0; c[1] = sum1; @@ -846,7 +842,7 @@ static void dot_product_block(TA* a, uint8_t* b, float* c, size_t n, size_t bloc auto v_zp = _mm256_set1_ps(b0[1]); size_t i = 0; uint8_t* b_data_ptr = b + src_offset + params_offset; - for (; i + vec_len_f32_avx2 <= _group_size; i += vec_len_f32_avx2) { + for (; i + vec_len_f32_avx2 <= group_size; i += vec_len_f32_avx2) { auto va = mm256_uni_loadu_ps(a + dst_offset + i); auto vb = _mm256_sub_ps(_mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( _mm_loadl_epi64(reinterpret_cast<__m128i*>(b_data_ptr + i)))), @@ -855,12 +851,12 @@ static void dot_product_block(TA* a, uint8_t* b, float* c, size_t n, size_t bloc } hsum(vsum); float group_sum = _mm256_cvtss_f32(vsum); - for (; i < _group_size; i++) { + for (; i < group_size; i++) { group_sum += a[i + dst_offset] * (b_data_ptr[i] - b0[1]); } sum += group_sum * b0[0]; - dst_offset += _group_size; - src_offset += _group_size + params_offset; + dst_offset += group_size; + src_offset += group_size + params_offset; } b += src_stride; *c++ = sum; @@ -874,12 +870,12 @@ static void dot_product_block(TA* a, uint8_t* b, float* c, size_t n, size_t bloc while (dst_offset < n) { auto b0 = reinterpret_cast(b + src_offset); float group_sum = 0.0f; - for (size_t i = 0; i < _group_size; i++) { + for (size_t i = 0; i < group_size; i++) { group_sum += a[dst_offset + i] * (b[src_offset + params_offset + i] - b0[1]); } sum += group_sum * b0[0]; - dst_offset += _group_size; - src_offset += _group_size + params_offset; + dst_offset += group_size; + src_offset += group_size + params_offset; } b += src_stride; *c++ = sum; @@ -936,7 +932,7 @@ void transpose_16NxK(TDST* dst, size_t K, size_t dst_stride, size_t src_stride, - size_t group_size = 0) { + size_t group_size) { size_t k = 0; auto* src_ptr = reinterpret_cast::value_type*>(src); for (; k + 16 <= K; k += 16) { @@ -966,7 +962,7 @@ static void transpose_16NxK(T* dst, size_t K, size_t dst_stride, size_t src_stride, - size_t group_size = 0) { + const size_t group_size) { // will treat as uint32_t transpose auto s = reinterpret_cast(src); auto d = reinterpret_cast(dst); @@ -976,7 +972,8 @@ static void transpose_16NxK(T* dst, N, K >> 1, dst_stride, - src_stride >> 1); + src_stride >> 1, + group_size); } # endif @@ -990,22 +987,21 @@ void transpose_16NxK(TDST* dst, size_t K, size_t dst_stride, size_t src_stride, - size_t group_size = 0) { + const size_t group_size) { // The layout for per token per head: // |scale(f32)|zeropoint(f32)|quantized feature(u8,idx_1)|quantized feature(u8,idx_2)|...|quantized // feature(u8,idx_S)| The quantized feature will start from 8bytes=sizeof(float)+sizeof(float) auto s = reinterpret_cast::value_type*>(src); auto t = tmp; // if group_size not set, the whole row is used as a group - size_t _group_size = group_size ? 
group_size : K; for (size_t n = 0; n < N; n++) { size_t src_offset = 0; size_t dst_offset = 0; while (dst_offset < K) { auto f = reinterpret_cast(s + src_offset); - attn_dequant_u8_kernel(s + src_offset + sizeof(float) * 2, t + dst_offset, _group_size, f[0], f[1]); - src_offset += _group_size + sizeof(float) * 2; - dst_offset += _group_size; + attn_dequant_u8_kernel(s + src_offset + sizeof(float) * 2, t + dst_offset, group_size, f[0], f[1]); + src_offset += group_size + sizeof(float) * 2; + dst_offset += group_size; } s += src_offset; t += src_stride; @@ -1016,44 +1012,44 @@ void transpose_16NxK(TDST* dst, N, K, dst_stride, - src_stride); + src_stride, + 0); } // dequant f16/u8 to float template ::type = true> -static inline void dequant(T* dst, void* src, size_t N, size_t K, size_t group_size = 0) { +static inline void dequant(T* dst, void* src, size_t N, size_t K, const size_t group_size) { // never called OPENVINO_THROW("dequant: should not be called."); } template ::type = true> -static inline void dequant(float* dst, ov::float16* src, size_t N, size_t K, size_t group_size = 0) { +static inline void dequant(float* dst, ov::float16* src, size_t N, size_t K, const size_t group_size) { cvt_copy(dst, src, K * N); } template ::type = true> -void dequant(TDST* dst, uint8_t* src, size_t N, size_t K, size_t group_size = 0) { +void dequant(TDST* dst, uint8_t* src, size_t N, size_t K, const size_t group_size) { // The layout for per token per head: // |scale(f32)|zeropoint(f32)|quantized feature(u8,idx_1)|quantized feature(u8,idx_2)|...|quantized // feature(u8,idx_S)| The quantized feature will start from 8bytes=sizeof(float)+sizeof(float) auto s = src; const size_t params_offset = sizeof(float) * 2; - const size_t _group_size = group_size ? group_size : K; - const size_t src_stride = K / _group_size * (_group_size + params_offset); + const size_t src_stride = K / group_size * (group_size + params_offset); for (size_t n = 0; n < N; n++) { size_t group_offset = 0; size_t dst_offset = 0; while (dst_offset < K) { auto f = reinterpret_cast(s + group_offset); - attn_dequant_u8_kernel(s + group_offset + params_offset, dst + dst_offset, _group_size, f[0], f[1]); - group_offset += _group_size + params_offset; - dst_offset += _group_size; + attn_dequant_u8_kernel(s + group_offset + params_offset, dst + dst_offset, group_size, f[0], f[1]); + group_offset += group_size + params_offset; + dst_offset += group_size; } s += src_stride; dst += K; @@ -1063,13 +1059,12 @@ void dequant(TDST* dst, uint8_t* src, size_t N, size_t K, size_t group_size = 0) template ::type = true> -void dequant(TDST* dst, uint8_t* src, size_t N, size_t K, size_t group_size = 0) { +void dequant(TDST* dst, uint8_t* src, size_t N, size_t K, const size_t group_size) { // The layout for per token per head: // |scale(f32)|zeropoint(f32)|quantized feature(u8,idx_1)|quantized feature(u8,idx_2)|...|quantized // feature(u8,idx_S)| The quantized feature will start from 8bytes=sizeof(float)+sizeof(float) auto s = src; const size_t params_offset = sizeof(float) * 2; - const size_t _group_size = group_size ? 
group_size : K; const size_t sub_byte_mulitplier = 2; for (size_t n = 0; n < N; n++) { @@ -1077,9 +1072,9 @@ void dequant(TDST* dst, uint8_t* src, size_t N, size_t K, size_t group_size = 0) size_t dst_offset = 0; while (dst_offset < K) { auto f = reinterpret_cast(s + src_offset); - attn_dequant_u4_kernel(s + src_offset + params_offset, dst + dst_offset, _group_size, f[0], f[1]); - src_offset += _group_size / sub_byte_mulitplier + params_offset; - dst_offset += _group_size; + attn_dequant_u4_kernel(s + src_offset + params_offset, dst + dst_offset, group_size, f[0], f[1]); + src_offset += group_size / sub_byte_mulitplier + params_offset; + dst_offset += group_size; } s += src_offset; dst += K; @@ -1089,13 +1084,12 @@ void dequant(TDST* dst, uint8_t* src, size_t N, size_t K, size_t group_size = 0) template ::type = true> -void dequant(TDST* dst, uint8_t* src, size_t N, size_t K, size_t group_size = 0) { +void dequant(TDST* dst, uint8_t* src, size_t N, size_t K, const size_t group_size) { // The layout for per token per head: // |scale(f32)|zeropoint(f32)|quantized feature(u8,idx_1)|quantized feature(u8,idx_2)|...|quantized // feature(u8,idx_S)| The quantized feature will start from 8bytes=sizeof(float)+sizeof(float) auto s = src; const size_t params_offset = sizeof(float); - const size_t _group_size = group_size ? group_size : K; const size_t sub_byte_mulitplier = 2; for (size_t n = 0; n < N; n++) { @@ -1103,9 +1097,9 @@ void dequant(TDST* dst, uint8_t* src, size_t N, size_t K, size_t group_size = 0) size_t dst_offset = 0; while (dst_offset < K) { auto f = reinterpret_cast(s + src_offset); - attn_dequant_s4_kernel(s + src_offset + params_offset, dst + dst_offset, _group_size, f[0]); - src_offset += _group_size / sub_byte_mulitplier + params_offset; - dst_offset += _group_size; + attn_dequant_s4_kernel(s + src_offset + params_offset, dst + dst_offset, group_size, f[0]); + src_offset += group_size / sub_byte_mulitplier + params_offset; + dst_offset += group_size; } s += src_offset; dst += K; @@ -1191,7 +1185,7 @@ static void pack_32NxK(TDST* dst, size_t K, size_t dst_stride, size_t src_stride, - size_t group_size = 0) { + const size_t group_size) { auto src_ptr = reinterpret_cast::value_type*>(src); for (size_t n = 0; n < N; n += 32) { size_t k = 0; @@ -1222,27 +1216,26 @@ static void pack_32NxK(TDST* dst, size_t K, size_t dst_stride, size_t src_stride, - size_t group_size = 0) { + const size_t group_size) { // The layout for per token per head: // |scale(f32)|zeropoint(f32)|quantized feature(u8,idx_1)|quantized feature(u8,idx_2)|...|quantized // feature(u8,idx_S)| The quantized feature will start from 8bytes=sizeof(float)+sizeof(float) auto s = reinterpret_cast::value_type*>(src); auto t = tmp; // if group_size not set, the whole row is used as a group - size_t _group_size = group_size ? 
group_size : K; for (size_t n = 0; n < N; n++) { size_t src_offset = 0; size_t dst_offset = 0; while (dst_offset < K) { auto f = reinterpret_cast(s + src_offset); - attn_dequant_u8_kernel(s + src_offset + sizeof(float) * 2, t + dst_offset, _group_size, f[0], f[1]); - src_offset += _group_size + sizeof(float) * 2; - dst_offset += _group_size; + attn_dequant_u8_kernel(s + src_offset + sizeof(float) * 2, t + dst_offset, group_size, f[0], f[1]); + src_offset += group_size + sizeof(float) * 2; + dst_offset += group_size; } s += src_offset; t += src_stride; } - pack_32NxK::value>(dst, tmp, reinterpret_cast(0), N, K, dst_stride, src_stride); + pack_32NxK::value>(dst, tmp, reinterpret_cast(0), N, K, dst_stride, src_stride, 0); } template (s + src_offset); - attn_dequant_u4_kernel(s + (src_offset + sizeof(float) * 2), t + dst_offset, _group_size, f[0], f[1]); - src_offset += _group_size / sub_byte_mulitplier + sizeof(float) * 2; - dst_offset += _group_size; + attn_dequant_u4_kernel(s + (src_offset + sizeof(float) * 2), t + dst_offset, group_size, f[0], f[1]); + src_offset += group_size / sub_byte_mulitplier + sizeof(float) * 2; + dst_offset += group_size; } s += src_offset; t += src_stride; } - pack_32NxK::value>(dst, tmp, reinterpret_cast(0), N, K, dst_stride, src_stride); + pack_32NxK::value>(dst, tmp, reinterpret_cast(0), N, K, dst_stride, src_stride, group_size); } template (s + src_offset); - attn_dequant_s4_kernel(s + (src_offset + sizeof(float)), t + dst_offset, _group_size, f[0]); - src_offset += _group_size / sub_byte_mulitplier + sizeof(float); - dst_offset += _group_size; + attn_dequant_s4_kernel(s + (src_offset + sizeof(float)), t + dst_offset, group_size, f[0]); + src_offset += group_size / sub_byte_mulitplier + sizeof(float); + dst_offset += group_size; } s += src_offset; t += src_stride; } - pack_32NxK::value>(dst, tmp, reinterpret_cast(0), N, K, dst_stride, src_stride); + pack_32NxK::value>(dst, tmp, reinterpret_cast(0), N, K, dst_stride, src_stride, group_size); } # endif @@ -1326,7 +1317,7 @@ static void pack_32NxK(TDST* dst, size_t K, size_t dst_stride, size_t src_stride, - size_t group_size = 0) { + size_t group_size) { // never called OPENVINO_THROW("pack_32NxK: should not be called."); } @@ -2414,8 +2405,6 @@ struct AttentionExecutor : public PagedAttentionExecutor { auto B_token = q.size(0); auto Hk = k_cache.size(1); - auto _key_group_size = _helper._key_group_size; - auto _value_group_size = _helper._key_group_size; // The layout for per token per head for u8 kv cache: // |scale(f32)|zeropoint(f32)|quantized feature(u8,idx_1)|quantized feature(u8,idx_2)|...|quantized // feature(u8,idx_S)| The actual size needs to deduct scale and zeropoint. @@ -2426,10 +2415,14 @@ struct AttentionExecutor : public PagedAttentionExecutor { const size_t param_size = one_of(v_cache.get_precision(), ov::element::u4, ov::element::u8) ? sizeof(float) * 2 : sizeof(float); const size_t value_params_size = param_size * value_sub_byte_multiplyer; - size_t key_group_num = _key_group_size ? k_cache.size(3) / (_key_group_size + key_params_size) : 1; - size_t value_group_num = _value_group_size ? v_cache.size(3) / (_value_group_size + value_params_size) : 1; + size_t key_group_num = _helper._key_group_size ? k_cache.size(3) / (_helper._key_group_size + key_params_size) : 1; + size_t value_group_num = _helper._value_group_size ? v_cache.size(3) / (_helper._value_group_size + value_params_size) : 1; auto S = k_cache.size(3) - (k_cache.get_precision().is_real() ? 
0 : key_params_size * key_group_num); auto SV = v_cache.size(3) - (v_cache.get_precision().is_real() ? 0 : value_params_size * value_group_num); + // revise group_size if it's zero. + _helper._key_group_size = _helper._key_group_size ? _helper._key_group_size : S; + _helper._value_group_size = _helper._value_group_size ? _helper._value_group_size : SV; + printf("key_group_size %ld, value_group_size %ld S %ld V %ld\n", _helper._key_group_size, _helper._value_group_size, S, SV); auto block_size = k_cache.size(2); auto H = q.size(1) / S; auto h_each_group_len = 1; From b4b0f0d5899e13130730e9be7c502f6c51140eb5 Mon Sep 17 00:00:00 2001 From: Zhang Yi Date: Wed, 18 Dec 2024 12:24:00 +0800 Subject: [PATCH 18/28] [CPU]Use ov::element as template args Signed-off-by: Zhang Yi3 --- .../nodes/kernels/scaled_attn/attn_quant.cpp | 1 + .../kernels/scaled_attn/attn_quant_kernel.hpp | 1 - .../nodes/kernels/scaled_attn/executor_pa.cpp | 453 ++++++++---------- .../intel_cpu/src/nodes/paged_attn.cpp | 2 +- 4 files changed, 214 insertions(+), 243 deletions(-) diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant.cpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant.cpp index fb6dc8439ac9bf..4e751d49705486 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant.cpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant.cpp @@ -627,6 +627,7 @@ void paged_attn_quantkv(const ov::intel_cpu::PlainTensor& k_src, } else if (k_src.get_precision() == ov::element::bf16) { funcs_bf16[dispatch](k_src, v_src, k_dst, v_dst, slot_mapping, key_group_size, value_group_size); } else if (k_src.get_precision() == ov::element::f16) { + printf("quantize with f16\n"); funcs_f16[dispatch](k_src, v_src, k_dst, v_dst, slot_mapping, key_group_size, value_group_size); } } diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant_kernel.hpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant_kernel.hpp index 9bb1da2ccc4ec8..acde50ade5cf7b 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant_kernel.hpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant_kernel.hpp @@ -66,7 +66,6 @@ void attn_dequant_u4_kernel(const uint8_t* src, TDST* dst, size_t n, float scale size_t i = 0; uint8_t* src_nc = const_cast(src); #if defined(HAVE_AVX512F) - auto v_zp = _mm512_set1_ps(zp); auto v_scale = _mm512_set1_ps(scale); auto v_zp_scale = _mm512_set1_ps(zp * scale); for (; i + vec_len_f32_avx512 * 2 <= n; i += vec_len_f32_avx512 * 2) { diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp index 3e78f645b9c95a..e73875d4804500 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp @@ -72,10 +72,22 @@ void cvt_copy(TA* dst, TB* src, size_t n) { } } +size_t inline get_sub_byte_multiplier(ov::element::Type type) { + return one_of(type, ov::element::i4, ov::element::u4) ? 
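// Editorial note (descriptive comment, not part of the patch): sub-byte cache types
// (u4/i4) pack two elements per byte, so this helper returns 8 / bitwidth == 2 for them
// and 1 for every full-byte type; byte offsets derived from element strides are divided
// by this multiplier throughout the executor.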
8 / type.bitwidth() : 1; +} + template ::type = true> -static void attn_acc_value_block(float* out, float* weight, T* v, size_t S, size_t block_size, size_t group_size) { + typename std::enable_if<(std::is_same::value || std::is_same::value || + std::is_same::value) && + (SRC_PREC != ov::element::u8 || SRC_PREC != ov::element::u4), + bool>::type = true> +static void attn_acc_value_block(float* out, + float* weight, + T* v, + const size_t S, + const size_t block_size, + const size_t group_size) { # if defined(HAVE_AVX512F) size_t j = 0; for (; j + 4 <= block_size; j += 4) { @@ -208,7 +220,7 @@ template ::type = true> static void attn_acc_value_block(float* out, float* weight, - uint8_t* v, - size_t S, + void* v, + const size_t S, const size_t block_size, const size_t group_size) { size_t src_offset = 0; size_t dst_offset = 0; const size_t params_offset = sizeof(float) * 2; - auto sub_byte_multiplyer = 8 / 4; - const size_t src_stride = S / group_size * (group_size / sub_byte_multiplyer + params_offset); + uint8_t* v_ptr = reinterpret_cast(v); + auto sub_byte_multiplier = 8 / 4; + const size_t src_stride = S / group_size * (group_size / sub_byte_multiplier + params_offset); auto extract_half_byte = [](uint8_t val, bool high_half) -> uint8_t { uint8_t shift = high_half ? 0 : 4; @@ -373,13 +386,13 @@ static void attn_acc_value_block(float* out, dst_offset = 0; src_offset = 0; while (dst_offset < S) { - auto v0 = reinterpret_cast(v + src_offset); + auto v0 = reinterpret_cast(v_ptr + src_offset); size_t i = 0; # if defined(HAVE_AVX512F) auto attn_w_vec0 = _mm512_set1_ps(weight[j] * v0[0]); auto v_zp = _mm512_set1_ps(v0[1]); for (; i + vec_len_f32_avx512 * 2 <= group_size; i += vec_len_f32_avx512 * 2) { - auto data = _mm_loadu_si128(reinterpret_cast<__m128i*>(v + i / 2 + src_offset + params_offset)); + auto data = _mm_loadu_si128(reinterpret_cast<__m128i*>(v_ptr + i / 2 + src_offset + params_offset)); auto v_i32 = _mm512_cvtepu8_epi32(data); auto v_512_low_half = _mm512_srli_epi32(v_i32, 4); @@ -408,7 +421,7 @@ static void attn_acc_value_block(float* out, auto v256_attn_w_vec0 = _mm256_set1_ps(weight[j] * v0[0]); auto v256_zp = _mm256_set1_ps(v0[1]); for (; i + vec_len_f32_avx2 * 2 <= group_size; i += vec_len_f32_avx2 * 2) { - auto data = _mm_loadl_epi64(reinterpret_cast<__m128i*>(v + i / 2 + src_offset + params_offset)); + auto data = _mm_loadl_epi64(reinterpret_cast<__m128i*>(v_ptr + i / 2 + src_offset + params_offset)); auto v_i32 = _mm256_cvtepu8_epi32(data); auto v_256_low_half = _mm256_srli_epi32(v_i32, 4); @@ -437,16 +450,16 @@ static void attn_acc_value_block(float* out, } # endif for (; i < group_size; i += 2) { - uint8_t data = v[i / 2 + src_offset + params_offset]; + uint8_t data = v_ptr[i / 2 + src_offset + params_offset]; float tmp0 = extract_half_byte(data, static_cast(i % 2)); float tmp1 = extract_half_byte(data, static_cast((i + 1) % 2)); out[dst_offset + i] += weight[j] * (tmp0 - v0[1]) * v0[0]; out[dst_offset + i + 1] += weight[j] * (tmp1 - v0[1]) * v0[0]; } dst_offset += group_size; - src_offset += group_size / sub_byte_multiplyer + params_offset; + src_offset += group_size / sub_byte_multiplier + params_offset; } - v += src_stride; + v_ptr += src_stride; } } @@ -455,15 +468,16 @@ template ::type = true> static void attn_acc_value_block(float* out, float* weight, - uint8_t* v, - size_t S, + void* v, + const size_t S, const size_t block_size, const size_t group_size) { size_t src_offset = 0; size_t dst_offset = 0; const size_t params_offset = sizeof(float); - auto 
sub_byte_multiplyer = 8 / 4; - const size_t src_stride = S / group_size * (group_size / sub_byte_multiplyer + params_offset); + auto sub_byte_multiplier = 8 / 4; + uint8_t* v_ptr = reinterpret_cast(v); + const size_t src_stride = S / group_size * (group_size / sub_byte_multiplier + params_offset); auto extract_half_byte = [](uint8_t val, bool high_half) -> uint8_t { uint8_t shift = high_half ? 0 : 4; @@ -474,12 +488,12 @@ static void attn_acc_value_block(float* out, dst_offset = 0; src_offset = 0; while (dst_offset < S) { - auto v0 = reinterpret_cast(v + src_offset); + auto v0 = reinterpret_cast(v_ptr + src_offset); size_t i = 0; # if defined(HAVE_AVX512F) auto attn_w_vec0 = _mm512_set1_ps(weight[j] * v0[0]); for (; i + vec_len_f32_avx512 * 2 <= group_size; i += vec_len_f32_avx512 * 2) { - auto data = _mm_loadu_si128(reinterpret_cast<__m128i*>(v + i / 2 + src_offset + params_offset)); + auto data = _mm_loadu_si128(reinterpret_cast<__m128i*>(v_ptr + i / 2 + src_offset + params_offset)); auto v_i32 = _mm512_cvtepi8_epi32(data); // cvt to f32 auto v_256_low_half = _mm512_srai_epi32(v_i32, 4); @@ -503,7 +517,7 @@ static void attn_acc_value_block(float* out, # elif defined(HAVE_AVX2) || defined(HAVE_AVX512F) auto v256_attn_w_vec0 = _mm256_set1_ps(weight[j] * v0[0]); for (; i + vec_len_f32_avx2 * 2 <= group_size; i += vec_len_f32_avx2 * 2) { - auto data = _mm_loadl_epi64(reinterpret_cast<__m128i*>(v + i / 2 + src_offset + params_offset)); + auto data = _mm_loadl_epi64(reinterpret_cast<__m128i*>(v_ptr + i / 2 + src_offset + params_offset)); auto v_i32 = _mm256_cvtepi8_epi32(data); auto v_256_low_half = _mm256_srai_epi32(v_i32, 4); @@ -527,7 +541,7 @@ static void attn_acc_value_block(float* out, } # endif for (; i < group_size; i += 2) { - uint8_t data = v[i / 2 + src_offset + params_offset]; + uint8_t data = v_ptr[i / 2 + src_offset + params_offset]; float tmp0 = extract_half_byte(data, static_cast(i % 2)); tmp0 = tmp0 > 8 ? 
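// Editorial note (descriptive comment, not part of the patch): scalar tail of the signed
// 4-bit (i4) path - nibbles above 8 are shifted down by 16 to recover their negative
// values before being scaled by the per-group scale v0[0].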
(tmp0 - 16) : tmp0; float tmp1 = extract_half_byte(data, static_cast((i + 1) % 2)); @@ -536,14 +550,19 @@ static void attn_acc_value_block(float* out, out[dst_offset + i + 1] += weight[j] * (tmp1)*v0[0]; } dst_offset += group_size; - src_offset += group_size / sub_byte_multiplyer + params_offset; + src_offset += group_size / sub_byte_multiplier + params_offset; } - v += src_stride; + v_ptr += src_stride; } } template -static void dot_product_block(TA* a, TB* b, float* c, size_t n, size_t block_size, size_t group_size) { +static void dot_product_block(TA* a, + TB* b, + float* c, + const size_t n, + const size_t block_size, + const size_t group_size) { # if defined(HAVE_AVX512F) size_t j = 0; for (; j + 4 <= block_size; j += 4) { @@ -655,7 +674,12 @@ static void dot_product_block(TA* a, TB* b, float* c, size_t n, size_t block_siz } template -static void dot_product_block(TA* a, uint8_t* b, float* c, size_t n, size_t block_size, const size_t group_size) { +static void dot_product_block(TA* a, + uint8_t* b, + float* c, + const size_t n, + const size_t block_size, + const size_t group_size) { // The layout for per token per head: // |scale(f32)|zeropoint(f32)|quantized feature(u8,idx_1)|quantized feature(u8,idx_2)|...|quantized // feature(u8,idx_S)| The quantized feature will start from 8bytes=sizeof(float)+sizeof(float) @@ -928,11 +952,11 @@ template ::value_type*>(src); for (; k + 16 <= K; k += 16) { @@ -958,10 +982,10 @@ template (src); @@ -983,10 +1007,10 @@ template ::value>(dst, - tmp, - reinterpret_cast(0), - N, - K, - dst_stride, - src_stride, - 0); + transpose_16NxK::value>(dst, tmp, reinterpret_cast(0), N, K, dst_stride, src_stride, 0); } // dequant f16/u8 to float template ::type = true> -static inline void dequant(T* dst, void* src, size_t N, size_t K, const size_t group_size) { +static inline void dequant(T* dst, void* src, const size_t N, const size_t K, const size_t group_size) { // never called OPENVINO_THROW("dequant: should not be called."); } template ::type = true> -static inline void dequant(float* dst, ov::float16* src, size_t N, size_t K, const size_t group_size) { +static inline void dequant(float* dst, ov::float16* src, const size_t N, const size_t K, const size_t group_size) { cvt_copy(dst, src, K * N); } template ::type = true> -void dequant(TDST* dst, uint8_t* src, size_t N, size_t K, const size_t group_size) { +void dequant(TDST* dst, uint8_t* src, const size_t N, const size_t K, const size_t group_size) { // The layout for per token per head: // |scale(f32)|zeropoint(f32)|quantized feature(u8,idx_1)|quantized feature(u8,idx_2)|...|quantized // feature(u8,idx_S)| The quantized feature will start from 8bytes=sizeof(float)+sizeof(float) @@ -1059,7 +1077,7 @@ void dequant(TDST* dst, uint8_t* src, size_t N, size_t K, const size_t group_siz template ::type = true> -void dequant(TDST* dst, uint8_t* src, size_t N, size_t K, const size_t group_size) { +void dequant(TDST* dst, uint8_t* src, const size_t N, const size_t K, const size_t group_size) { // The layout for per token per head: // |scale(f32)|zeropoint(f32)|quantized feature(u8,idx_1)|quantized feature(u8,idx_2)|...|quantized // feature(u8,idx_S)| The quantized feature will start from 8bytes=sizeof(float)+sizeof(float) @@ -1084,7 +1102,7 @@ void dequant(TDST* dst, uint8_t* src, size_t N, size_t K, const size_t group_siz template ::type = true> -void dequant(TDST* dst, uint8_t* src, size_t N, size_t K, const size_t group_size) { +void dequant(TDST* dst, uint8_t* src, const size_t N, const size_t K, const size_t 
group_size) { // The layout for per token per head: // |scale(f32)|zeropoint(f32)|quantized feature(u8,idx_1)|quantized feature(u8,idx_2)|...|quantized // feature(u8,idx_S)| The quantized feature will start from 8bytes=sizeof(float)+sizeof(float) @@ -1181,10 +1199,10 @@ template ::value_type*>(src); for (size_t n = 0; n < N; n += 32) { @@ -1212,10 +1230,10 @@ template ::value>(dst, tmp, reinterpret_cast(0), N, K, dst_stride, src_stride, group_size); + pack_32NxK::value>(dst, + tmp, + reinterpret_cast(0), + N, + K, + dst_stride, + src_stride, + group_size); } template ::value>(dst, tmp, reinterpret_cast(0), N, K, dst_stride, src_stride, group_size); + pack_32NxK::value>(dst, + tmp, + reinterpret_cast(0), + N, + K, + dst_stride, + src_stride, + group_size); } # endif @@ -1313,16 +1345,16 @@ template +template struct MHAHelper { // initialize once size_t _H; @@ -1458,13 +1490,11 @@ struct MHAHelper { if ((S % 32 == 0) && (block_size % 16 == 0) && (S <= 32 * 6)) { if (dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::amx_bf16) && precision_of::value == ov::element::bf16 && - precision_of::value == ov::element::bf16 && - precision_of::value == ov::element::bf16) { + precision_of::value == ov::element::bf16 && VALUE_PREC == ov::element::bf16) { _fastpath_valid_prec = ov::element::bf16; } else if (dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::amx_fp16) && precision_of::value == ov::element::f16 && - precision_of::value == ov::element::f16 && - precision_of::value == ov::element::bf16) { + precision_of::value == ov::element::f16 && VALUE_PREC == ov::element::f16) { _fastpath_valid_prec = ov::element::f16; } } @@ -1533,7 +1563,7 @@ struct MHAHelper { auto q_end = std::min(q_start + _block_size, q_len); auto q_cnt = q_end - q_start; constexpr bool q_is_xf16 = one_of(precision_of::value, ov::element::bf16, ov::element::f16); - constexpr bool q_cache_is_same = precision_of::value == precision_of::value; + constexpr bool q_cache_is_same = precision_of::value == VALUE_PREC; auto cur_kv_len_blocks = div_up(cur_kv_len, _block_size); for (size_t h = hq_beg; h < hq_end; h++) { auto* q_ptr = query.ptr(h, q_start, 0); @@ -1731,42 +1761,20 @@ struct MHAHelper { memset(_output.ptr(ithr), 0, q_len * _H * _SV * sizeof(float)); for (size_t pv = 0, i = 0; pv < cur_kv_len; pv += _block_size, i++) { auto block_number = block_table[i]; - auto* v = present_value.ptr(block_number, hk); for (size_t pq = 0; pq < q_len; pq++) { for (size_t h = hq_beg; h < hq_end; h++) { - if (present_value.get_precision() == ov::element::u4) { - auto sub_byte_multiplyer = 8 / present_value.get_precision().bitwidth(); - size_t v_stride = - (block_number * present_value.m_strides[0] + hk * present_value.m_strides[1]) / - sub_byte_multiplyer; - auto* v_ptr = present_value.m_ptr.get() + v_stride; - attn_acc_value_block(_output.ptr(ithr, pq, h), - _weight.ptr(ithr, h, pq) + pv, - v_ptr, - _SV, - std::min(_block_size, cur_kv_len - pv), - _value_group_size); - } else if (present_value.get_precision() == ov::element::i4) { - auto sub_byte_multiplyer = 8 / present_value.get_precision().bitwidth(); - size_t v_stride = - (block_number * present_value.m_strides[0] + hk * present_value.m_strides[1]) / - sub_byte_multiplyer; - auto* v_ptr = present_value.m_ptr.get() + v_stride; - attn_acc_value_block(_output.ptr(ithr, pq, h), - _weight.ptr(ithr, h, pq) + pv, - v_ptr, - _SV, - std::min(_block_size, cur_kv_len - pv), - _value_group_size); - } else { - attn_acc_value_block::value>( - _output.ptr(ithr, pq, h), - _weight.ptr(ithr, h, pq) + pv, - v, - 
_SV, - std::min(_block_size, cur_kv_len - pv), - _value_group_size); - } + auto sub_byte_multiplier = get_sub_byte_multiplier(present_value.get_precision()); + size_t v_stride = (block_number * present_value.m_strides[0] + hk * present_value.m_strides[1]) * + present_value.get_precision().size() / sub_byte_multiplier; + auto* v_ptr = reinterpret_cast::value_type*>( + present_value.m_ptr.get() + v_stride); + attn_acc_value_block::value_type, VALUE_PREC>( + _output.ptr(ithr, pq, h), + _weight.ptr(ithr, h, pq) + pv, + v_ptr, + _SV, + std::min(_block_size, cur_kv_len - pv), + _value_group_size); } } } @@ -1923,44 +1931,21 @@ struct MHAHelper { // kv_len must be valid if (pv < context_len) { auto block_number = block_indices.ptr()[block_indices_begins.ptr()[b] + pv_in_blocks]; - auto* v = present_value.ptr(block_number, hk); for (size_t pq = 0; pq < q_len; pq++) { for (size_t h = hq_beg; h < hq_end; h++) { - if (present_value.get_precision() == ov::element::u4) { - auto sub_byte_multiplyer = 8 / present_value.get_precision().bitwidth(); - size_t v_stride = - (block_number * present_value.m_strides[0] + hk * present_value.m_strides[1]) / - sub_byte_multiplyer; - auto* v_ptr = present_value.m_ptr.get() + v_stride; - attn_acc_value_block( - _output_bhl.ptr(ithr, b, pq, h), - _weight_bhl.ptr(b, h, pq) + pv, - v_ptr, - _SV, - std::min(_block_size, context_len - pv), - _value_group_size); - } else if (present_value.get_precision() == ov::element::i4) { - auto sub_byte_multiplyer = 8 / present_value.get_precision().bitwidth(); - size_t v_stride = - (block_number * present_value.m_strides[0] + hk * present_value.m_strides[1]) / - sub_byte_multiplyer; - auto* v_ptr = present_value.m_ptr.get() + v_stride; - attn_acc_value_block( - _output_bhl.ptr(ithr, b, pq, h), - _weight_bhl.ptr(b, h, pq) + pv, - v_ptr, - _SV, - std::min(_block_size, context_len - pv), - _value_group_size); - } else { - attn_acc_value_block::value>( - _output_bhl.ptr(ithr, b, pq, h), - _weight_bhl.ptr(b, h, pq) + pv, - v, - _SV, - std::min(_block_size, context_len - pv), - _value_group_size); - } + auto sub_byte_multiplier = get_sub_byte_multiplier(present_value.get_precision()); + size_t v_stride = + (block_number * present_value.m_strides[0] + hk * present_value.m_strides[1]) * + present_value.get_precision().size() / sub_byte_multiplier; + auto* v_ptr = reinterpret_cast::value_type*>( + present_value.m_ptr.get() + v_stride); + attn_acc_value_block::value_type, VALUE_PREC>( + _output_bhl.ptr(ithr, b, pq, h), + _weight_bhl.ptr(b, h, pq) + pv, + v_ptr, + _SV, + std::min(_block_size, context_len - pv), + _value_group_size); } } } @@ -1981,9 +1966,9 @@ struct MHAHelper { } }; -template +template struct MHA { - MHAHelper& _helper; + MHAHelper& _helper; struct AttnWorkItem { int32_t batch_in_reorder; // which batch in reorder buffer will be used int32_t batch_in_seq; // batch idx in sequence @@ -2083,7 +2068,7 @@ struct MHA { WorkItems _workitems; - MHA(MHAHelper& helper) : _helper(helper) {} + MHA(MHAHelper& helper) : _helper(helper) {} // one loop to handle first and second tokens void exec_loop_mixed(const PlainTensor& q, @@ -2121,7 +2106,6 @@ struct MHA { auto ithr = parallel_get_thread_num(); auto* k_ptr = k_cache.ptr(block_number, hk); - auto* v_ptr = v_cache.ptr(block_number, hk); transpose_16NxK::value>( _helper._qk_scratch_b.template ptr(batch_in_reorder, kv_block, hk), @@ -2134,79 +2118,33 @@ struct MHA { _helper._key_group_size); if (q_is_xf16) { - if (v_cache.get_precision() == ov::element::u4) { - auto sub_byte_multiplyer = 8 
/ v_cache.get_precision().bitwidth(); - size_t v_stride = - (block_number * v_cache.m_strides[0] + hk * v_cache.m_strides[1]) / sub_byte_multiplyer; - auto* v_ptr = v_cache.m_ptr.get() + v_stride; - pack_32NxK( - _helper._wv_scratch_b.template ptr(batch_in_reorder, kv_block, hk), - v_ptr, - _helper._output.template ptr(ithr), - _helper._block_size, - _helper._SV, - rnd_up(_helper._SV, _helper._block_size), - _helper._SV, - _helper._value_group_size); - } else if (v_cache.get_precision() == ov::element::i4) { - auto sub_byte_multiplyer = 8 / v_cache.get_precision().bitwidth(); - size_t v_stride = - (block_number * v_cache.m_strides[0] + hk * v_cache.m_strides[1]) / sub_byte_multiplyer; + auto sub_byte_multiplier = get_sub_byte_multiplier(v_cache.get_precision()); + size_t v_stride = (block_number * v_cache.m_strides[0] + hk * v_cache.m_strides[1]) * + v_cache.get_precision().size() / sub_byte_multiplier; + auto* v_ptr = v_cache.m_ptr.get() + v_stride; + pack_32NxK( + _helper._wv_scratch_b.template ptr(batch_in_reorder, kv_block, hk), + v_ptr, + _helper._output.template ptr(ithr), + _helper._block_size, + _helper._SV, + rnd_up(_helper._SV, _helper._block_size), + _helper._SV, + _helper._value_group_size); + } else { + // need to decompress + if (!q_cache_is_same) { + auto sub_byte_multiplier = get_sub_byte_multiplier(v_cache.get_precision()); + size_t v_stride = (block_number * v_cache.m_strides[0] + hk * v_cache.m_strides[1]) * + v_cache.get_precision().size() / sub_byte_multiplier; auto* v_ptr = v_cache.m_ptr.get() + v_stride; - pack_32NxK( - _helper._wv_scratch_b.template ptr(batch_in_reorder, kv_block, hk), - v_ptr, - _helper._output.template ptr(ithr), - _helper._block_size, - _helper._SV, - rnd_up(_helper._SV, _helper._block_size), - _helper._SV, - _helper._value_group_size); - } else { - pack_32NxK::value>( + dequant( _helper._wv_scratch_b.template ptr(batch_in_reorder, kv_block, hk), v_ptr, - _helper._output.template ptr(ithr), _helper._block_size, _helper._SV, - rnd_up(_helper._SV, _helper._block_size), - _helper._SV, _helper._value_group_size); } - } else { - // need to decompress - if (!q_cache_is_same) { - if (v_cache.get_precision() == ov::element::u4) { - auto sub_byte_multiplyer = 8 / v_cache.get_precision().bitwidth(); - size_t v_stride = - (block_number * v_cache.m_strides[0] + hk * v_cache.m_strides[1]) / sub_byte_multiplyer; - auto* v_ptr = v_cache.m_ptr.get() + v_stride; - dequant( - _helper._wv_scratch_b.template ptr(batch_in_reorder, kv_block, hk), - v_ptr, - _helper._block_size, - _helper._SV, - _helper._value_group_size); - } else if (v_cache.get_precision() == ov::element::i4) { - auto sub_byte_multiplyer = 8 / v_cache.get_precision().bitwidth(); - size_t v_stride = - (block_number * v_cache.m_strides[0] + hk * v_cache.m_strides[1]) / sub_byte_multiplyer; - auto* v_ptr = v_cache.m_ptr.get() + v_stride; - dequant( - _helper._wv_scratch_b.template ptr(batch_in_reorder, kv_block, hk), - v_ptr, - _helper._block_size, - _helper._SV, - _helper._value_group_size); - } else { - dequant::value>( - _helper._wv_scratch_b.template ptr(batch_in_reorder, kv_block, hk), - v_ptr, - _helper._block_size, - _helper._SV, - _helper._value_group_size); - } - } } }); @@ -2356,16 +2294,16 @@ struct MHA { } }; -template +template struct AttentionExecutor : public PagedAttentionExecutor { - MHAHelper _helper; - MHA _kernel; + MHAHelper _helper; + MHA _kernel; PlainTensor _slot_mapping; AttentionExecutor() : _kernel(_helper) {} explicit AttentionExecutor(size_t key_group_size, size_t 
value_group_size) - : _helper(MHAHelper(key_group_size, value_group_size)), + : _helper(MHAHelper(key_group_size, value_group_size)), _kernel(_helper) {} void init(const std::vector& inputs, @@ -2408,21 +2346,22 @@ struct AttentionExecutor : public PagedAttentionExecutor { // The layout for per token per head for u8 kv cache: // |scale(f32)|zeropoint(f32)|quantized feature(u8,idx_1)|quantized feature(u8,idx_2)|...|quantized // feature(u8,idx_S)| The actual size needs to deduct scale and zeropoint. - const size_t key_sub_byte_multiplyer = 8 / k_cache.get_precision().bitwidth(); - const size_t value_sub_byte_multiplyer = 8 / v_cache.get_precision().bitwidth(); - const size_t key_params_size = sizeof(float) * 2 * key_sub_byte_multiplyer; + const size_t key_sub_byte_multiplier = get_sub_byte_multiplier(k_cache.get_precision()); + const size_t value_sub_byte_multiplier = get_sub_byte_multiplier(v_cache.get_precision()); + const size_t key_params_size = sizeof(float) * 2 * key_sub_byte_multiplier; // u4 needs scale + zp. s4 needs scale. const size_t param_size = one_of(v_cache.get_precision(), ov::element::u4, ov::element::u8) ? sizeof(float) * 2 : sizeof(float); - const size_t value_params_size = param_size * value_sub_byte_multiplyer; - size_t key_group_num = _helper._key_group_size ? k_cache.size(3) / (_helper._key_group_size + key_params_size) : 1; - size_t value_group_num = _helper._value_group_size ? v_cache.size(3) / (_helper._value_group_size + value_params_size) : 1; + const size_t value_params_size = param_size * value_sub_byte_multiplier; + size_t key_group_num = + _helper._key_group_size ? k_cache.size(3) / (_helper._key_group_size + key_params_size) : 1; + size_t value_group_num = + _helper._value_group_size ? v_cache.size(3) / (_helper._value_group_size + value_params_size) : 1; auto S = k_cache.size(3) - (k_cache.get_precision().is_real() ? 0 : key_params_size * key_group_num); auto SV = v_cache.size(3) - (v_cache.get_precision().is_real() ? 0 : value_params_size * value_group_num); // revise group_size if it's zero. _helper._key_group_size = _helper._key_group_size ? _helper._key_group_size : S; _helper._value_group_size = _helper._value_group_size ? 
_helper._value_group_size : SV; - printf("key_group_size %ld, value_group_size %ld S %ld V %ld\n", _helper._key_group_size, _helper._value_group_size, S, SV); auto block_size = k_cache.size(2); auto H = q.size(1) / S; auto h_each_group_len = 1; @@ -2557,11 +2496,23 @@ std::shared_ptr make_pa_executor(ov::element::Type data_ if (data_type == ov::element::bf16) { # if defined(HAVE_AVX512F) if (key_cache_type == ov::element::u8) { - executor = - std::make_shared>(key_group_size, value_group_size); + if (value_cache_type == ov::element::u4) { + executor = + std::make_shared>(key_group_size, + value_group_size); + } else if (value_cache_type == ov::element::u8) { + executor = + std::make_shared>(key_group_size, + value_group_size); + } else { + OPENVINO_THROW("make_pa_executor: key_cache_type u8 with value_cache_type ", + value_cache_type.to_string(), + " is not support"); + } + } else { OPENVINO_ASSERT(key_cache_type == ov::element::bf16, "expect kvcache type bf16, current: ", key_cache_type); - executor = std::make_shared>(); + executor = std::make_shared>(); } # else OPENVINO_THROW("make_pa_executor: bf16 needs avx512+ hardware."); @@ -2569,24 +2520,44 @@ std::shared_ptr make_pa_executor(ov::element::Type data_ } else if (data_type == ov::element::f16) { # if defined(HAVE_AVX512F) if (key_cache_type == ov::element::u8) { - executor = - std::make_shared>(key_group_size, value_group_size); + if (value_cache_type == ov::element::u4) { + executor = std::make_shared>(key_group_size, + value_group_size); + } else if (value_cache_type == ov::element::u8) { + executor = std::make_shared>(key_group_size, + value_group_size); + } else { + OPENVINO_THROW("make_pa_executor: key_cache_type u8 with value_cache_type ", + value_cache_type.to_string(), + " is not support"); + } } else { OPENVINO_ASSERT(key_cache_type == ov::element::f16, "expect kvcache type f16, current: ", key_cache_type); - executor = std::make_shared>(); + executor = std::make_shared>(); } # else OPENVINO_THROW("make_pa_executor: f16 needs avx512+ hardware."); # endif } else if (data_type == ov::element::f32) { if (key_cache_type == ov::element::u8) { - executor = std::make_shared>(key_group_size, value_group_size); + if (value_cache_type == ov::element::u4) { + executor = std::make_shared>(key_group_size, + value_group_size); + } else if (value_cache_type == ov::element::u8) { + executor = std::make_shared>(key_group_size, + value_group_size); + } else { + OPENVINO_THROW("make_pa_executor: key_cache_type u8 with value_cache_type ", + value_cache_type.to_string(), + " is not support"); + } } else if (key_cache_type == ov::element::f16) { - executor = - std::make_shared>(key_group_size, value_group_size); + executor = std::make_shared>(key_group_size, + value_group_size); } else { OPENVINO_ASSERT(key_cache_type == ov::element::f32, "expect kvcache type f32, current: ", key_cache_type); - executor = std::make_shared>(key_group_size, value_group_size); + executor = + std::make_shared>(key_group_size, value_group_size); } } else { OPENVINO_THROW("make_pa_executor: unsupported precision: ", data_type); diff --git a/src/plugins/intel_cpu/src/nodes/paged_attn.cpp b/src/plugins/intel_cpu/src/nodes/paged_attn.cpp index 04fd44c20508f8..f10608c02250e1 100644 --- a/src/plugins/intel_cpu/src/nodes/paged_attn.cpp +++ b/src/plugins/intel_cpu/src/nodes/paged_attn.cpp @@ -144,7 +144,7 @@ void PagedAttention::createPrimitive() { // Since we are quantize only last dim it's safe to use the last dim of KV. 
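// Editorial sketch - illustrative only, NOT part of this patch or file. The group sizes
// read from the CPU config below may be 0; the executor rewrites 0 to "whole row as one
// group" (see the `_helper._key_group_size ? ... : S` change above). For the u8 per-token
// layout |scale(f32)|zeropoint(f32)|q(u8) x group_size|, the logical head size relates to
// the padded cache row as mirrored by this hypothetical helper (the name and parameters
// are invented here; the arithmetic follows AttentionExecutor::init()):
static inline size_t logical_head_size_u8(size_t padded_S, size_t group_size) {
    const size_t params = 2 * sizeof(float);  // per-group scale + zero-point
    const size_t groups = group_size ? padded_S / (group_size + params) : 1;
    return padded_S - params * groups;        // bytes actually holding quantized features
}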
auto kCachePrecision = getOriginalInputPrecisionAtPort(PagedAttentionExecutor::ID_KCACHE); auto vCachePrecision = getOriginalInputPrecisionAtPort(PagedAttentionExecutor::ID_VCACHE); - const auto cpuConfig = context->getConfig(); + const auto& cpuConfig = context->getConfig(); size_t key_group_size = cpuConfig.keyCacheGroupSize; size_t value_group_size = cpuConfig.valueCacheGroupSize; From 5c838f7ae4006b267ad52df4346797a9487dfdba Mon Sep 17 00:00:00 2001 From: Zhang Yi Date: Wed, 18 Dec 2024 15:47:49 +0800 Subject: [PATCH 19/28] [CPU]remove redundant marco --- .../src/nodes/kernels/scaled_attn/attn_quant_kernel.hpp | 4 ++-- .../intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant_kernel.hpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant_kernel.hpp index acde50ade5cf7b..97a7d53a2efa05 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant_kernel.hpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant_kernel.hpp @@ -88,7 +88,7 @@ void attn_dequant_u4_kernel(const uint8_t* src, TDST* dst, size_t n, float scale mm512_uni_storeu_ps(dst + i, first_half); mm512_uni_storeu_ps(dst + i + vec_len_f32_avx512, second_half); } -#elif defined(HAVE_AVX2) || defined(HAVE_AVX512F) +#elif defined(HAVE_AVX2) auto v256_zp = _mm256_set1_ps(zp); auto v256_scale = _mm256_set1_ps(scale); for (; i + vec_len_f32_avx2 * 2 <= n; i += vec_len_f32_avx2 * 2) { @@ -171,7 +171,7 @@ void attn_dequant_s4_kernel(const uint8_t* src, TDST* dst, size_t n, float scale mm512_uni_storeu_ps(dst + i + vec_len_f32_avx512, second_half); } -#elif defined(HAVE_AVX2) || defined(HAVE_AVX512F) +#elif defined(HAVE_AVX2) for (; i + vec_len_f32_avx2 * 2 <= n; i += vec_len_f32_avx2 * 2) { auto v256_scale = _mm256_set1_ps(scale); auto data = _mm_loadl_epi64(reinterpret_cast<__m128i*>(src_nc + i / 2)); diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp index e73875d4804500..587c535cc1fb05 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp @@ -417,7 +417,7 @@ static void attn_acc_value_block(float* out, mm512_uni_storeu_ps(out + dst_offset + i, v_out0); mm512_uni_storeu_ps(out + dst_offset + i + vec_len_f32_avx512, v_out1); } -# elif defined(HAVE_AVX2) || defined(HAVE_AVX512F) +# elif defined(HAVE_AVX2) auto v256_attn_w_vec0 = _mm256_set1_ps(weight[j] * v0[0]); auto v256_zp = _mm256_set1_ps(v0[1]); for (; i + vec_len_f32_avx2 * 2 <= group_size; i += vec_len_f32_avx2 * 2) { @@ -514,7 +514,7 @@ static void attn_acc_value_block(float* out, mm512_uni_storeu_ps(out + dst_offset + i, v_out0); mm512_uni_storeu_ps(out + dst_offset + i + vec_len_f32_avx512, v_out1); } -# elif defined(HAVE_AVX2) || defined(HAVE_AVX512F) +# elif defined(HAVE_AVX2) auto v256_attn_w_vec0 = _mm256_set1_ps(weight[j] * v0[0]); for (; i + vec_len_f32_avx2 * 2 <= group_size; i += vec_len_f32_avx2 * 2) { auto data = _mm_loadl_epi64(reinterpret_cast<__m128i*>(v_ptr + i / 2 + src_offset + params_offset)); From f03e23c76d8b8f1b373a32575eac46ab57d9c917 Mon Sep 17 00:00:00 2001 From: Zhang Yi Date: Thu, 19 Dec 2024 09:31:17 +0800 Subject: [PATCH 20/28] apply review comments --- .../intel_cpu/src/nodes/kernels/scaled_attn/attn_quant.cpp | 1 - 1 file changed, 1 deletion(-) diff --git 
a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant.cpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant.cpp index 4e751d49705486..fb6dc8439ac9bf 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant.cpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant.cpp @@ -627,7 +627,6 @@ void paged_attn_quantkv(const ov::intel_cpu::PlainTensor& k_src, } else if (k_src.get_precision() == ov::element::bf16) { funcs_bf16[dispatch](k_src, v_src, k_dst, v_dst, slot_mapping, key_group_size, value_group_size); } else if (k_src.get_precision() == ov::element::f16) { - printf("quantize with f16\n"); funcs_f16[dispatch](k_src, v_src, k_dst, v_dst, slot_mapping, key_group_size, value_group_size); } } From c362399b8bbf7c8346dd8b1474b3e5980c4269af Mon Sep 17 00:00:00 2001 From: Zhang Yi Date: Fri, 3 Jan 2025 13:40:59 +0800 Subject: [PATCH 21/28] [CPU]apply review comments Signed-off-by: Zhang Yi --- src/plugins/intel_cpu/src/config.cpp | 57 +++--- .../nodes/kernels/scaled_attn/attn_quant.cpp | 163 ++++-------------- .../intel_cpu/src/nodes/paged_attn.cpp | 9 +- .../ov_executable_network/properties.cpp | 30 ++++ 4 files changed, 98 insertions(+), 161 deletions(-) diff --git a/src/plugins/intel_cpu/src/config.cpp b/src/plugins/intel_cpu/src/config.cpp index 5b0d677d11b54f..b16e270b984fca 100644 --- a/src/plugins/intel_cpu/src/config.cpp +++ b/src/plugins/intel_cpu/src/config.cpp @@ -373,43 +373,42 @@ void Config::readProperties(const ov::AnyMap& prop, const ModelType modelType) { ov::hint::kv_cache_precision.name(), ". Supported values: u8, bf16, f16, f32"); } - } else if (key == ov::hint::key_cache_precision.name() || key == ov::hint::value_cache_precision.name()) { + } else if (key == ov::hint::key_cache_precision.name()) { try { kvCachePrecisionSetExplicitly = true; auto const prec = val.as(); - if (key == ov::hint::key_cache_precision.name()) { - if (one_of(prec, ov::element::f32, ov::element::f16, ov::element::bf16, ov::element::u8)) { - keyCachePrecision = prec; - } else { - OPENVINO_THROW("keyCachePrecision doesn't support value ", prec); - } + if (one_of(prec, ov::element::f32, ov::element::f16, ov::element::bf16, ov::element::u8)) { + keyCachePrecision = prec; } else { - if (one_of(prec, - ov::element::f32, - ov::element::f16, - ov::element::bf16, - ov::element::u8, - ov::element::u4, - ov::element::i4)) { - valueCachePrecision = prec; - } else { - OPENVINO_THROW("valueCachePrecision doesn't support value ", prec); - } + OPENVINO_THROW("keyCachePrecision doesn't support value ", prec); } } catch (ov::Exception&) { - if (key == ov::hint::key_cache_precision.name()) { - OPENVINO_THROW("Wrong value ", - val.as(), - " for property key ", - ov::hint::key_cache_precision.name(), - ". Supported values: u8, bf16, f16, f32"); + OPENVINO_THROW("Wrong value ", + val.as(), + " for property key ", + ov::hint::key_cache_precision.name(), + ". Supported values: u8, bf16, f16, f32"); + } + } else if (key == ov::hint::value_cache_precision.name()) { + try { + kvCachePrecisionSetExplicitly = true; + auto const prec = val.as(); + if (one_of(prec, + ov::element::f32, + ov::element::f16, + ov::element::bf16, + ov::element::u8, + ov::element::u4)) { + valueCachePrecision = prec; } else { - OPENVINO_THROW("Wrong value ", - val.as(), - " for property key ", - ov::hint::value_cache_precision.name(), - ". 
Supported values: u4, s4, u8, bf16, f16, f32"); + OPENVINO_THROW("valueCachePrecision doesn't support value ", prec); } + } catch (ov::Exception&) { + OPENVINO_THROW("Wrong value ", + val.as(), + " for property key ", + ov::hint::value_cache_precision.name(), + ". Supported values: u4, s4, u8, bf16, f16, f32"); } } else if (key == ov::hint::key_cache_group_size.name() || key == ov::hint::value_cache_group_size.name()) { try { diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant.cpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant.cpp index fb6dc8439ac9bf..40ed27bf73ea97 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant.cpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant.cpp @@ -218,7 +218,7 @@ static void quant_u4(const T* src, void* dst, size_t n, float& scale, float& zp) _mm512_mask_cvtepi32_storeu_epi8(dst_ptr + i / 2, 0xffff, combined); } #endif -#if defined(HAVE_AVX2) || defined(HAVE_AVX512F) +#if defined(HAVE_AVX2) auto v256_zero = _mm256_set1_epi32(0); auto v256_upper = _mm256_set1_epi32(15); auto v256_scale = _mm256_set1_ps(1 / scale); @@ -273,7 +273,7 @@ static void quant_u4(const T* src, void* dst, size_t n, float& scale, float& zp) } template -static void quant_s4(const T* src, void* dst, size_t n, float& scale) { +static void quant_i4(const T* src, void* dst, size_t n, float& scale) { auto insert_half_byte = [](uint8_t dst, uint8_t val, bool high_half) -> uint8_t { uint8_t shift = high_half ? 0 : 4; if (high_half) @@ -318,7 +318,7 @@ static void quant_s4(const T* src, void* dst, size_t n, float& scale) { _mm512_mask_cvtepi32_storeu_epi8(dst_ptr + i / 2, 0xffff, combined); } #endif -#if defined(HAVE_AVX2) || defined(HAVE_AVX512F) +#if defined(HAVE_AVX2) auto v256_lower = _mm256_set1_epi32(-8); auto v256_upper = _mm256_set1_epi32(7); auto v256_scale = _mm256_set1_ps(1 / scale); @@ -372,6 +372,27 @@ static void quant_s4(const T* src, void* dst, size_t n, float& scale) { } } +template ::type = true> +static void quantize(const T* src, uint8_t* dst, size_t n, float* scale_zp) { + quant_u8(src, dst, n, *scale_zp, *(scale_zp + 1)); +} + +template ::type = true> +static void quantize(const T* src, void* dst, size_t n, float* scale_zp) { + quant_u4(src, dst, n, *scale_zp, *(scale_zp + 1)); +} + +template ::type = true> +static void quantize(const T* src, void* dst, size_t n, float* scale_zp) { + quant_i4(src, dst, n, *scale_zp); +} + template static void attn_quant_mt(const ov::intel_cpu::PlainTensor& k_src, const ov::intel_cpu::PlainTensor& v_src, @@ -389,10 +410,7 @@ static void attn_quant_mt(const ov::intel_cpu::PlainTensor& k_src, }); } -template ::type = true> +template static void paged_attn_quant_mt(const ov::intel_cpu::PlainTensor& k_src, const ov::intel_cpu::PlainTensor& v_src, const ov::intel_cpu::PlainTensor& k_dst, @@ -402,6 +420,7 @@ static void paged_attn_quant_mt(const ov::intel_cpu::PlainTensor& k_src, const size_t value_group_size) { size_t B = k_src.m_dims[0], H = k_src.m_dims[1], L1 = k_src.m_dims[2], S = k_src.m_dims[3], SV = v_src.m_dims[3]; size_t block_size = k_dst.m_dims[2]; + size_t sub_byte_multiplier = 8 / v_dst.get_precision().bitwidth(); parallel_for3d(B, L1, H, [&](size_t b, size_t m, size_t h) { auto slot = slot_mapping.ptr(b)[m]; if (slot < 0) @@ -418,76 +437,15 @@ static void paged_attn_quant_mt(const ov::intel_cpu::PlainTensor& k_src, h, block_offset, dst_offset)); - quant_u8(k_src.ptr(b, h, m, src_offset), - k_dst.ptr::value_type>(block_number, - h, - block_offset, - 
dst_offset) + - sizeof(float) + sizeof(float), - key_group_size, - p_k[0], - p_k[1]); - } - - for (size_t src_offset = 0, dst_offset = 0; src_offset < SV; - src_offset += value_group_size, dst_offset += value_group_size + sizeof(float) + sizeof(float)) { - auto p_v = reinterpret_cast( - v_dst.ptr::value_type>(block_number, - h, - block_offset, - dst_offset)); - quant_u8(v_src.ptr(b, h, m, src_offset), - v_dst.ptr::value_type>(block_number, - h, - block_offset, - dst_offset) + - sizeof(float) + sizeof(float), - value_group_size, - p_v[0], - p_v[1]); - } - }); -} - -template ::type = true> -static void paged_attn_quant_mt(const ov::intel_cpu::PlainTensor& k_src, - const ov::intel_cpu::PlainTensor& v_src, - const ov::intel_cpu::PlainTensor& k_dst, - const ov::intel_cpu::PlainTensor& v_dst, - const ov::intel_cpu::PlainTensor& slot_mapping, - const size_t key_group_size, - const size_t value_group_size) { - size_t B = k_src.m_dims[0], H = k_src.m_dims[1], L1 = k_src.m_dims[2], S = k_src.m_dims[3], SV = v_src.m_dims[3]; - size_t block_size = k_dst.m_dims[2]; - size_t sub_byte_multiplier = 8 / v_dst.get_precision().bitwidth(); - parallel_for3d(B, L1, H, [&](size_t b, size_t m, size_t h) { - auto slot = slot_mapping.ptr(b)[m]; - if (slot < 0) - return; - auto block_number = slot / block_size; - auto block_offset = slot % block_size; - // The layout for per token per head: - // |scale(f32)|zeropoint(f32)|quantized feature(u8,idx_1)|quantized feature(u8,idx_2)|...|quantized - // feature(u8,idx_S)| - for (size_t src_offset = 0, dst_offset = 0; src_offset < S; - src_offset += key_group_size, dst_offset += key_group_size + sizeof(float) + sizeof(float)) { - auto p_k = reinterpret_cast( + quantize( + k_src.ptr(b, h, m, src_offset), k_dst.ptr::value_type>(block_number, h, block_offset, - dst_offset)); - quant_u8(k_src.ptr(b, h, m, src_offset), - k_dst.ptr::value_type>(block_number, - h, - block_offset, - dst_offset) + - sizeof(float) + sizeof(float), - key_group_size, - p_k[0], - p_k[1]); + dst_offset) + + sizeof(float) + sizeof(float), + key_group_size, + p_k); } for (size_t src_offset = 0, dst_offset = 0; src_offset < SV; src_offset += value_group_size, @@ -499,62 +457,7 @@ static void paged_attn_quant_mt(const ov::intel_cpu::PlainTensor& k_src, dst_offset); auto p_v = reinterpret_cast(v_base); uint8_t* v_ptr = v_base + sizeof(float) * 2; - quant_u4(v_src.ptr(b, h, m, src_offset), v_ptr, value_group_size, p_v[0], p_v[1]); - } - }); -} - -template ::type = true> -static void paged_attn_quant_mt(const ov::intel_cpu::PlainTensor& k_src, - const ov::intel_cpu::PlainTensor& v_src, - const ov::intel_cpu::PlainTensor& k_dst, - const ov::intel_cpu::PlainTensor& v_dst, - const ov::intel_cpu::PlainTensor& slot_mapping, - const size_t key_group_size, - const size_t value_group_size) { - size_t B = k_src.m_dims[0], H = k_src.m_dims[1], L1 = k_src.m_dims[2], S = k_src.m_dims[3], SV = v_src.m_dims[3]; - size_t block_size = k_dst.m_dims[2]; - size_t sub_byte_multiplier = 8 / v_dst.get_precision().bitwidth(); - parallel_for3d(B, L1, H, [&](size_t b, size_t m, size_t h) { - auto slot = slot_mapping.ptr(b)[m]; - if (slot < 0) - return; - auto block_number = slot / block_size; - auto block_offset = slot % block_size; - // The layout for per token per head: - // |scale(f32)|zeropoint(f32)|quantized feature(u8,idx_1)|quantized feature(u8,idx_2)|...|quantized - // feature(u8,idx_S)| - for (size_t src_offset = 0, dst_offset = 0; src_offset < S; - src_offset += key_group_size, dst_offset += key_group_size + sizeof(float) + 
sizeof(float)) { - auto p_k = reinterpret_cast( - k_dst.ptr::value_type>(block_number, - h, - block_offset, - dst_offset)); - quant_u8(k_src.ptr(b, h, m, src_offset), - k_dst.ptr::value_type>(block_number, - h, - block_offset, - dst_offset) + - sizeof(float) + sizeof(float), - key_group_size, - p_k[0], - p_k[1]); - } - - for (size_t src_offset = 0, dst_offset = 0; src_offset < SV; - src_offset += value_group_size, dst_offset += value_group_size / sub_byte_multiplier + sizeof(float)) { - uint8_t* v_base = reinterpret_cast( - v_dst.m_ptr.get() + - (block_number * v_dst.m_strides[0] + h * v_dst.m_strides[1] + block_offset * v_dst.m_strides[2]) / - sub_byte_multiplier + - dst_offset); - auto p_v = reinterpret_cast(v_base); - uint8_t* v_ptr = v_base + sizeof(float); - quant_s4(v_src.ptr(b, h, m, src_offset), v_ptr, value_group_size, p_v[0]); + quantize(v_src.ptr(b, h, m, src_offset), v_ptr, value_group_size, p_v); } }); } diff --git a/src/plugins/intel_cpu/src/nodes/paged_attn.cpp b/src/plugins/intel_cpu/src/nodes/paged_attn.cpp index 209183f0060ad3..54aa80e9dff7c0 100644 --- a/src/plugins/intel_cpu/src/nodes/paged_attn.cpp +++ b/src/plugins/intel_cpu/src/nodes/paged_attn.cpp @@ -211,8 +211,13 @@ bool PagedAttention::isSupportedOperation(const std::shared_ptr& try { auto vCachePrecision = op->get_input_element_type(PagedAttentionExecutor::ID_VCACHE); auto kCachePrecision = op->get_input_element_type(PagedAttentionExecutor::ID_KCACHE); - if (one_of(vCachePrecision, ov::element::i4, ov::element::u4, ov::element::u8)) { - if (kCachePrecision != ov::element::u8) { + if (one_of(vCachePrecision, + ov::element::u4, + ov::element::u8, + ov::element::f32, + ov::element::f16, + ov::element::bf16)) { + if (!one_of(kCachePrecision, ov::element::u8, ov::element::f16, ov::element::f32, ov::element::bf16)) { errorMessage = "PageAttn key value cache compression doesn't support key cache prec " + kCachePrecision.to_string() + " value cache prec " + vCachePrecision.to_string(); return false; diff --git a/src/plugins/intel_cpu/tests/functional/custom/behavior/ov_executable_network/properties.cpp b/src/plugins/intel_cpu/tests/functional/custom/behavior/ov_executable_network/properties.cpp index 59fd31cdb34303..016648a7e1026f 100644 --- a/src/plugins/intel_cpu/tests/functional/custom/behavior/ov_executable_network/properties.cpp +++ b/src/plugins/intel_cpu/tests/functional/custom/behavior/ov_executable_network/properties.cpp @@ -187,6 +187,36 @@ TEST_F(OVClassConfigTestCPU, smoke_CpuExecNetworkCheckKVCachePrecision) { ASSERT_EQ(kv_cache_precision_value, ov::element::f32); } +TEST_F(OVClassConfigTestCPU, smoke_CpuExecNetworkFinetuneKVCachePrecision) { + ov::Core core; + + core.set_property(deviceName, ov::hint::key_cache_precision(ov::element::f16)); + core.set_property(deviceName, ov::hint::value_cache_precision(ov::element::u4)); + ov::CompiledModel compiledModel = core.compile_model(model, deviceName); + + auto key_cache_precision_value = ov::element::undefined; + auto value_cache_precision_value = ov::element::undefined; + OV_ASSERT_NO_THROW(key_cache_precision_value = compiledModel.get_property(ov::hint::key_cache_precision)); + OV_ASSERT_NO_THROW(value_cache_precision_value = compiledModel.get_property(ov::hint::value_cache_precision)); + ASSERT_EQ(key_cache_precision_value, ov::element::f16); + ASSERT_EQ(value_cache_precision_value, ov::element::u4); +} + +TEST_F(OVClassConfigTestCPU, smoke_CpuExecNetworkFinetuneKVCacheGroupSize) { + ov::Core core; + + core.set_property(deviceName, 
ov::hint::key_cache_group_size(32)); + core.set_property(deviceName, ov::hint::value_cache_group_size(16)); + ov::CompiledModel compiledModel = core.compile_model(model, deviceName); + + auto key_cache_group_size_value = 0; + auto value_cache_group_size_value = 0; + OV_ASSERT_NO_THROW(key_cache_group_size_value = compiledModel.get_property(ov::hint::key_cache_group_size)); + OV_ASSERT_NO_THROW(value_cache_group_size_value = compiledModel.get_property(ov::hint::value_cache_group_size)); + ASSERT_EQ(key_cache_group_size_value, 32); + ASSERT_EQ(value_cache_group_size_value, 16); +} + TEST_F(OVClassConfigTestCPU, smoke_CpuExecNetworkCheckAccuracyModeDynamicQuantizationGroupSize) { ov::Core core; From 28bcf7b482aa9cbbb501686cd7117513895cdf94 Mon Sep 17 00:00:00 2001 From: Zhang Yi Date: Fri, 3 Jan 2025 16:09:22 +0800 Subject: [PATCH 22/28] [CPU]remove useless code of s4 Signed-off-by: Zhang Yi --- src/plugins/intel_cpu/src/config.cpp | 2 +- .../nodes/kernels/scaled_attn/attn_quant.cpp | 110 -------------- .../nodes/kernels/scaled_attn/executor_pa.cpp | 136 +----------------- 3 files changed, 2 insertions(+), 246 deletions(-) diff --git a/src/plugins/intel_cpu/src/config.cpp b/src/plugins/intel_cpu/src/config.cpp index b16e270b984fca..3b052e7094d34c 100644 --- a/src/plugins/intel_cpu/src/config.cpp +++ b/src/plugins/intel_cpu/src/config.cpp @@ -408,7 +408,7 @@ void Config::readProperties(const ov::AnyMap& prop, const ModelType modelType) { val.as(), " for property key ", ov::hint::value_cache_precision.name(), - ". Supported values: u4, s4, u8, bf16, f16, f32"); + ". Supported values: u4, u8, bf16, f16, f32"); } } else if (key == ov::hint::key_cache_group_size.name() || key == ov::hint::value_cache_group_size.name()) { try { diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant.cpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant.cpp index 40ed27bf73ea97..e25b204e670218 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant.cpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant.cpp @@ -272,106 +272,6 @@ static void quant_u4(const T* src, void* dst, size_t n, float& scale, float& zp) } } -template -static void quant_i4(const T* src, void* dst, size_t n, float& scale) { - auto insert_half_byte = [](uint8_t dst, uint8_t val, bool high_half) -> uint8_t { - uint8_t shift = high_half ? 
0 : 4; - if (high_half) - val &= 0x0F; - return dst | (uint8_t)(val << shift); - }; - auto dst_ptr = reinterpret_cast(dst); - size_t i = 0; - float max = -FLT_MAX; - float min = FLT_MAX; - find_minmax(src, n, min, max); - float max_abs = std::max(std::abs(min), std::abs(max)); - scale = max_abs / ((1 << 3) - 1); - if (scale == 0) - scale = 0.0001f; -#if defined(HAVE_AVX512F) - auto v_scale = _mm512_set1_ps(1 / scale); - auto v_upper = _mm512_set1_epi32(7); - auto v_lower = _mm512_set1_epi32(-8); - for (; i + 2 * vec_len_f32_avx512 <= n; i += 2 * vec_len_f32_avx512) { - auto v0 = mm512_uni_loadu_ps(src + i); - auto v1 = mm512_uni_loadu_ps(src + i + vec_len_f32_avx512); - v0 = _mm512_mul_ps(v0, v_scale); - v1 = _mm512_mul_ps(v1, v_scale); - auto v0_i32 = _mm512_cvt_roundps_epi32(v0, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC); - auto v1_i32 = _mm512_cvt_roundps_epi32(v1, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC); - - v0_i32 = _mm512_max_epi32(v0_i32, v_lower); - v1_i32 = _mm512_max_epi32(v1_i32, v_lower); - v0_i32 = _mm512_min_epi32(v0_i32, v_upper); - v1_i32 = _mm512_min_epi32(v1_i32, v_upper); - - __m512i idx1 = _mm512_set_epi32(30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0); - __m512i idx2 = _mm512_set_epi32(31, 29, 27, 25, 23, 21, 19, 17, 15, 13, 11, 9, 7, 5, 3, 1); - auto first_half = _mm512_permutex2var_epi32(v0_i32, idx1, v1_i32); - auto second_half = _mm512_permutex2var_epi32(v0_i32, idx2, v1_i32); - - auto mask = _mm512_set1_epi32(0x0F); - second_half = _mm512_and_epi32(second_half, mask); - first_half = _mm512_slli_epi32(first_half, 4); - auto combined = _mm512_or_epi32(first_half, second_half); - _mm512_mask_cvtepi32_storeu_epi8(dst_ptr + i / 2, 0xffff, combined); - } -#endif -#if defined(HAVE_AVX2) - auto v256_lower = _mm256_set1_epi32(-8); - auto v256_upper = _mm256_set1_epi32(7); - auto v256_scale = _mm256_set1_ps(1 / scale); - for (; i + vec_len_f32_avx2 * 2 <= n; i += vec_len_f32_avx2 * 2) { - auto v0 = mm256_uni_loadu_ps(src + i); - auto v1 = mm256_uni_loadu_ps(src + i + vec_len_f32_avx2); - v0 = _mm256_mul_ps(v0, v256_scale); - v1 = _mm256_mul_ps(v1, v256_scale); - v0 = _mm256_round_ps(v0, _MM_ROUND_NEAREST); - v1 = _mm256_round_ps(v1, _MM_ROUND_NEAREST); - - auto v0_i32 = _mm256_cvtps_epi32(v0); - auto v1_i32 = _mm256_cvtps_epi32(v1); - v0_i32 = _mm256_max_epi32(v0_i32, v256_lower); - v1_i32 = _mm256_max_epi32(v1_i32, v256_lower); - v0_i32 = _mm256_min_epi32(v0_i32, v256_upper); - v1_i32 = _mm256_min_epi32(v1_i32, v256_upper); - auto idx1 = _mm256_set_epi32(7, 5, 3, 1, 6, 4, 2, 0); - v0_i32 = _mm256_permutevar8x32_epi32(v0_i32, idx1); - v1_i32 = _mm256_permutevar8x32_epi32(v1_i32, idx1); - - // 0,1,2,3,4,5,6,7 | 8,9,10,11,12,13,14,15 - // _mm256_permutevar8x32_epi32 - // 0,2,4,6,1,3,5,7 | 8,10,12,14,9,11,13,15 - // _mm256_permute2x128_si256 - // 0,2,4,6,8,10,12,14 | 1,3,5,7,9,11,13,15 - // shift + mask + or - // [0,1],[2,3], ..., [12,13], [14,15] - auto first_half = _mm256_permute2x128_si256(v0_i32, v1_i32, 0x20); - auto second_half = _mm256_permute2x128_si256(v0_i32, v1_i32, 0x31); - first_half = _mm256_slli_epi32(first_half, 4); - auto mask = _mm256_set1_epi32(0x0F); - second_half = _mm256_and_si256(second_half, mask); - auto combined = _mm256_or_si256(first_half, second_half); - - auto high4 = _mm256_extractf128_si256(combined, 1); - auto low4 = _mm256_castsi256_si128(combined); - // keep sign bit for s4 case - auto packed = _mm_packs_epi32(low4, high4); - packed = _mm_packs_epi16(packed, packed); - _mm_storel_epi64(reinterpret_cast<__m128i*>(dst_ptr 
+ i / 2), packed); - } -#endif - for (; i < n; i++) { - float tmp = src[i]; - int8_t src_val = std::min((int8_t)(7), (int8_t)std::round(tmp / scale)); - src_val = std::max((int8_t)(-8), src_val); - uint8_t dst_val = i % 2 == 0 ? 0 : dst_ptr[i / 2]; - dst_val = insert_half_byte(dst_val, src_val, (uint8_t)(i % 2)); - dst_ptr[i / 2] = dst_val; - } -} - template ::type = true> @@ -386,13 +286,6 @@ static void quantize(const T* src, void* dst, size_t n, float* scale_zp) { quant_u4(src, dst, n, *scale_zp, *(scale_zp + 1)); } -template ::type = true> -static void quantize(const T* src, void* dst, size_t n, float* scale_zp) { - quant_i4(src, dst, n, *scale_zp); -} - template static void attn_quant_mt(const ov::intel_cpu::PlainTensor& k_src, const ov::intel_cpu::PlainTensor& v_src, @@ -500,17 +393,14 @@ void paged_attn_quantkv(const ov::intel_cpu::PlainTensor& k_src, static constexpr function_type funcs_fp32[] = { paged_attn_quant_mt, paged_attn_quant_mt, - paged_attn_quant_mt, }; static constexpr function_type funcs_bf16[] = { paged_attn_quant_mt, paged_attn_quant_mt, - paged_attn_quant_mt, }; static constexpr function_type funcs_f16[] = { paged_attn_quant_mt, paged_attn_quant_mt, - paged_attn_quant_mt, }; if (k_dst.get_precision() != ov::element::u8) { OPENVINO_THROW("unsupport src type: ", diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp index 587c535cc1fb05..bd659cb1b164f7 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp @@ -463,99 +463,6 @@ static void attn_acc_value_block(float* out, } } -template ::type = true> -static void attn_acc_value_block(float* out, - float* weight, - void* v, - const size_t S, - const size_t block_size, - const size_t group_size) { - size_t src_offset = 0; - size_t dst_offset = 0; - const size_t params_offset = sizeof(float); - auto sub_byte_multiplier = 8 / 4; - uint8_t* v_ptr = reinterpret_cast(v); - const size_t src_stride = S / group_size * (group_size / sub_byte_multiplier + params_offset); - auto extract_half_byte = [](uint8_t val, bool high_half) -> uint8_t { - uint8_t shift = high_half ? 
0 : 4; - - return (uint8_t)((val >> shift) & 0x000F); - }; - - for (size_t j = 0; j < block_size; j++) { - dst_offset = 0; - src_offset = 0; - while (dst_offset < S) { - auto v0 = reinterpret_cast(v_ptr + src_offset); - size_t i = 0; -# if defined(HAVE_AVX512F) - auto attn_w_vec0 = _mm512_set1_ps(weight[j] * v0[0]); - for (; i + vec_len_f32_avx512 * 2 <= group_size; i += vec_len_f32_avx512 * 2) { - auto data = _mm_loadu_si128(reinterpret_cast<__m128i*>(v_ptr + i / 2 + src_offset + params_offset)); - auto v_i32 = _mm512_cvtepi8_epi32(data); - // cvt to f32 - auto v_256_low_half = _mm512_srai_epi32(v_i32, 4); - auto v_256_high_half = _mm512_slli_epi32(v_i32, 28); - v_256_high_half = _mm512_srai_epi32(v_256_high_half, 28); - - auto v_f32_low_half = _mm512_cvtepi32_ps(v_256_low_half); - auto v_f32_high_half = _mm512_cvtepi32_ps(v_256_high_half); - - __m512i idx1 = _mm512_set_epi32(23, 7, 22, 6, 21, 5, 20, 4, 19, 3, 18, 2, 17, 1, 16, 0); - __m512i idx2 = _mm512_set_epi32(31, 15, 30, 14, 29, 13, 28, 12, 27, 11, 26, 10, 25, 9, 24, 8); - __m512 first_half = _mm512_permutex2var_ps(v_f32_low_half, idx1, v_f32_high_half); - __m512 second_half = _mm512_permutex2var_ps(v_f32_low_half, idx2, v_f32_high_half); - auto v_out0 = mm512_uni_loadu_ps(out + dst_offset + i); - auto v_out1 = mm512_uni_loadu_ps(out + dst_offset + i + vec_len_f32_avx512); - v_out0 = _mm512_fmadd_ps(attn_w_vec0, first_half, v_out0); - v_out1 = _mm512_fmadd_ps(attn_w_vec0, second_half, v_out1); - mm512_uni_storeu_ps(out + dst_offset + i, v_out0); - mm512_uni_storeu_ps(out + dst_offset + i + vec_len_f32_avx512, v_out1); - } -# elif defined(HAVE_AVX2) - auto v256_attn_w_vec0 = _mm256_set1_ps(weight[j] * v0[0]); - for (; i + vec_len_f32_avx2 * 2 <= group_size; i += vec_len_f32_avx2 * 2) { - auto data = _mm_loadl_epi64(reinterpret_cast<__m128i*>(v_ptr + i / 2 + src_offset + params_offset)); - - auto v_i32 = _mm256_cvtepi8_epi32(data); - auto v_256_low_half = _mm256_srai_epi32(v_i32, 4); - auto v_f32_low_half = _mm256_cvtepi32_ps(v_256_low_half); - - auto v_256_high_half = _mm256_slli_epi32(v_i32, 28); - v_256_high_half = _mm256_srai_epi32(v_256_high_half, 28); - auto v_f32_high_half = _mm256_cvtepi32_ps(v_256_high_half); - - auto v_out0 = mm256_uni_loadu_ps(out + dst_offset + i); - auto v_out1 = mm256_uni_loadu_ps(out + dst_offset + i + vec_len_f32_avx2); - __m256 first_half = _mm256_permute2f128_ps(v_f32_low_half, v_f32_high_half, 0x20); - auto idx1 = _mm256_set_epi32(7, 3, 6, 2, 5, 1, 4, 0); - first_half = _mm256_permutevar8x32_ps(first_half, idx1); - __m256 second_half = _mm256_permute2f128_ps(v_f32_low_half, v_f32_high_half, 0x31); - second_half = _mm256_permutevar8x32_ps(second_half, idx1); - v_out0 = _mm256_fmadd_ps(v256_attn_w_vec0, first_half, v_out0); - v_out1 = _mm256_fmadd_ps(v256_attn_w_vec0, second_half, v_out1); - mm256_uni_storeu_ps(out + dst_offset + i, v_out0); - mm256_uni_storeu_ps(out + dst_offset + i + vec_len_f32_avx2, v_out1); - } -# endif - for (; i < group_size; i += 2) { - uint8_t data = v_ptr[i / 2 + src_offset + params_offset]; - float tmp0 = extract_half_byte(data, static_cast(i % 2)); - tmp0 = tmp0 > 8 ? (tmp0 - 16) : tmp0; - float tmp1 = extract_half_byte(data, static_cast((i + 1) % 2)); - tmp1 = tmp1 > 8 ? 
(tmp1 - 16) : tmp1; - out[dst_offset + i] += weight[j] * (tmp0)*v0[0]; - out[dst_offset + i + 1] += weight[j] * (tmp1)*v0[0]; - } - dst_offset += group_size; - src_offset += group_size / sub_byte_multiplier + params_offset; - } - v_ptr += src_stride; - } -} - template static void dot_product_block(TA* a, TB* b, @@ -1296,47 +1203,6 @@ static void pack_32NxK(TDST* dst, src_stride, group_size); } - -template ::value != ov::element::f32 && (SRC_PREC == ov::element::i4), - bool>::type = true> -static void pack_32NxK(TDST* dst, - void* src, - TDST* tmp, - const size_t N, - const size_t K, - const size_t dst_stride, - const size_t src_stride, - const size_t group_size) { - // The layout for per token per head: - // |scale(f32)|zeropoint(f32)|quantized feature(u8,idx_1)|quantized feature(u8,idx_2)|...|quantized - // feature(u8,idx_S)| The quantized feature will start from 8bytes=sizeof(float)+sizeof(float) - auto s = reinterpret_cast(src); - auto t = tmp; - // if group_size not set, the whole row is used as a group - const size_t sub_byte_mulitplier = 2; - for (size_t n = 0; n < N; n++) { - size_t src_offset = 0; - size_t dst_offset = 0; - while (dst_offset < K) { - auto f = reinterpret_cast(s + src_offset); - attn_dequant_s4_kernel(s + (src_offset + sizeof(float)), t + dst_offset, group_size, f[0]); - src_offset += group_size / sub_byte_mulitplier + sizeof(float); - dst_offset += group_size; - } - s += src_offset; - t += src_stride; - } - pack_32NxK::value>(dst, - tmp, - reinterpret_cast(0), - N, - K, - dst_stride, - src_stride, - group_size); -} # endif template ::value, ov::element::bf16, ov::element::f16); - constexpr bool q_cache_is_same = precision_of::value == precision_of::value; + constexpr bool q_cache_is_same = precision_of::value == VALUE_PREC; auto attn_work_count = _workitems.attn_work_size(); auto reorder_work_count = _workitems.reorder_work_size(); From 56245d0828876f213b348016fe9a6070ebc0bb2e Mon Sep 17 00:00:00 2001 From: Zhang Yi Date: Sun, 5 Jan 2025 13:23:43 +0800 Subject: [PATCH 23/28] [CPU]Unify u8/u4 dequant kernel with template arg Signed-off-by: Zhang Yi --- .../nodes/kernels/scaled_attn/attn_quant.cpp | 2 +- .../kernels/scaled_attn/attn_quant_kernel.hpp | 92 ++------------- .../nodes/kernels/scaled_attn/executor_pa.cpp | 110 ++++-------------- 3 files changed, 29 insertions(+), 175 deletions(-) diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant.cpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant.cpp index e25b204e670218..26282a70fcb512 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant.cpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant.cpp @@ -429,7 +429,7 @@ void attn_quant_u8(const float* src, uint8_t* dst, size_t n, float& scale, float } void attn_dequant_u8(const uint8_t* src, float* dst, size_t n, float scale, float zp) { - attn_dequant_u8_kernel(src, dst, n, scale, zp); + attn_dequant_kernel(src, dst, n, scale, zp); } } // namespace XARCH diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant_kernel.hpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant_kernel.hpp index 97a7d53a2efa05..761a136eda2997 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant_kernel.hpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant_kernel.hpp @@ -17,8 +17,10 @@ namespace Extensions { namespace Cpu { namespace XARCH { -template -void attn_dequant_u8_kernel(const uint8_t* src, TDST* dst, size_t n, float scale, float zp) { 
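// The unified attn_dequant_kernel below replaces the separate attn_dequant_u8_kernel /
// attn_dequant_u4_kernel entry points: the cache element type becomes a template parameter and
// std::enable_if picks the u8 or u4 specialization at compile time, so callers in attn_quant.cpp
// and executor_pa.cpp dispatch through one name, e.g. (illustrative call only, assuming the
// <TDST, SRC_PREC> parameter order implied by the declarations that follow)
// attn_dequant_kernel<float, ov::element::u8>(src + params_offset, dst, group_size, scale, zp);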
+template ::type = true> +void attn_dequant_kernel(const uint8_t* src, TDST* dst, size_t n, float scale, float zp) { size_t i = 0; // loadu_si128/epi64 does not support const qualifier uint8_t* src_nc = const_cast(src); @@ -52,8 +54,10 @@ void attn_dequant_u8_kernel(const uint8_t* src, TDST* dst, size_t n, float scale } } -template -void attn_dequant_u4_kernel(const uint8_t* src, TDST* dst, size_t n, float scale, float zp) { +template ::type = true> +void attn_dequant_kernel(const uint8_t* src, TDST* dst, size_t n, float scale, float zp) { // 2 4bit data form a byte /* 0,1|2,3|4,5|6,7 / \ @@ -134,86 +138,6 @@ void attn_dequant_u4_kernel(const uint8_t* src, TDST* dst, size_t n, float scale } } -template -void attn_dequant_s4_kernel(const uint8_t* src, TDST* dst, size_t n, float scale) { - // 2 4bit data form a byte - /* 0,1|2,3|4,5|6,7 - / \ - 0,2,4,6|1,3,5,7 - | - permute - | - 0,1,2,3,4,5,6,7 - */ - size_t i = 0; - uint8_t* src_nc = const_cast(src); -#if defined(HAVE_AVX512F) - for (; i + vec_len_f32_avx512 * 2 <= n; i += vec_len_f32_avx512 * 2) { - auto v_scale = _mm512_set1_ps(scale); - auto data = _mm_loadu_si128(reinterpret_cast<__m128i*>(src_nc + i / 2)); - // cvt to f32 - auto v_i32 = _mm512_cvtepi8_epi32(data); - - auto v_256_low_half = _mm512_srai_epi32(v_i32, 4); - auto v_f32_low_half = _mm512_cvtepi32_ps(v_256_low_half); - auto v_256_high_half = _mm512_slli_epi32(v_i32, 28); - v_256_high_half = _mm512_srai_epi32(v_256_high_half, 28); - auto v_f32_high_half = _mm512_cvtepi32_ps(v_256_high_half); - // q * scale - v_f32_low_half = _mm512_mul_ps(v_f32_low_half, v_scale); - v_f32_high_half = _mm512_mul_ps(v_f32_high_half, v_scale); - - __m512i idx1 = _mm512_set_epi32(23, 7, 22, 6, 21, 5, 20, 4, 19, 3, 18, 2, 17, 1, 16, 0); - __m512i idx2 = _mm512_set_epi32(31, 15, 30, 14, 29, 13, 28, 12, 27, 11, 26, 10, 25, 9, 24, 8); - __m512 first_half = _mm512_permutex2var_ps(v_f32_low_half, idx1, v_f32_high_half); - __m512 second_half = _mm512_permutex2var_ps(v_f32_low_half, idx2, v_f32_high_half); - mm512_uni_storeu_ps(dst + i, first_half); - mm512_uni_storeu_ps(dst + i + vec_len_f32_avx512, second_half); - } - -#elif defined(HAVE_AVX2) - for (; i + vec_len_f32_avx2 * 2 <= n; i += vec_len_f32_avx2 * 2) { - auto v256_scale = _mm256_set1_ps(scale); - auto data = _mm_loadl_epi64(reinterpret_cast<__m128i*>(src_nc + i / 2)); - - auto v_i32 = _mm256_cvtepi8_epi32(data); - auto v_256_low_half = _mm256_srai_epi32(v_i32, 4); - auto v_f32_low_half = _mm256_cvtepi32_ps(v_256_low_half); - - auto v_256_high_half = _mm256_slli_epi32(v_i32, 28); - v_256_high_half = _mm256_srai_epi32(v_256_high_half, 28); - auto v_f32_high_half = _mm256_cvtepi32_ps(v_256_high_half); - - // q * scale - v_f32_low_half = _mm256_mul_ps(v_f32_low_half, v256_scale); - v_f32_high_half = _mm256_mul_ps(v_f32_high_half, v256_scale); - - // 0,2,4,6,8,10,12,14 | 1,3,5,7,9,11,13,15 - // _mm256_permute2f128_ps - // 0,2,4,6,1,3,5,7 | 8,10,12,14,9,11,13,15 - // _mm256_permutevar8x32_ps - // 0,1,2,3,4,5,6,7 | 8,9,10,11,12,13,14,15 - __m256 first_half = _mm256_permute2f128_ps(v_f32_low_half, v_f32_high_half, 0x20); - auto idx1 = _mm256_set_epi32(7, 3, 6, 2, 5, 1, 4, 0); - first_half = _mm256_permutevar8x32_ps(first_half, idx1); - __m256 second_half = _mm256_permute2f128_ps(v_f32_low_half, v_f32_high_half, 0x31); - second_half = _mm256_permutevar8x32_ps(second_half, idx1); - mm256_uni_storeu_ps(dst + i, first_half); - mm256_uni_storeu_ps(dst + i + vec_len_f32_avx2, second_half); - } -#endif - auto extract_half_byte = [&](uint8_t val, bool 
high_half) -> int8_t { - uint8_t shift = high_half ? 0 : 4; - return static_cast((val >> shift) & 0x000F); - }; - for (; i < n; ++i) { - float tmp = extract_half_byte(src_nc[i / 2], (uint8_t)(i % 2)); - tmp = tmp > 8 ? (tmp - 16) : tmp; - tmp = tmp * scale; - dst[i] = tmp; - } -} - } // namespace XARCH } // namespace Cpu } // namespace Extensions diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp index bd659cb1b164f7..955e7687ef97b3 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp @@ -930,7 +930,11 @@ void transpose_16NxK(TDST* dst, size_t dst_offset = 0; while (dst_offset < K) { auto f = reinterpret_cast(s + src_offset); - attn_dequant_u8_kernel(s + src_offset + sizeof(float) * 2, t + dst_offset, group_size, f[0], f[1]); + attn_dequant_kernel(s + src_offset + sizeof(float) * 2, + t + dst_offset, + group_size, + f[0], + f[1]); src_offset += group_size + sizeof(float) * 2; dst_offset += group_size; } @@ -958,71 +962,25 @@ static inline void dequant(float* dst, ov::float16* src, const size_t N, const s template ::type = true> -void dequant(TDST* dst, uint8_t* src, const size_t N, const size_t K, const size_t group_size) { - // The layout for per token per head: - // |scale(f32)|zeropoint(f32)|quantized feature(u8,idx_1)|quantized feature(u8,idx_2)|...|quantized - // feature(u8,idx_S)| The quantized feature will start from 8bytes=sizeof(float)+sizeof(float) - auto s = src; - const size_t params_offset = sizeof(float) * 2; - const size_t src_stride = K / group_size * (group_size + params_offset); - - for (size_t n = 0; n < N; n++) { - size_t group_offset = 0; - size_t dst_offset = 0; - while (dst_offset < K) { - auto f = reinterpret_cast(s + group_offset); - attn_dequant_u8_kernel(s + group_offset + params_offset, dst + dst_offset, group_size, f[0], f[1]); - group_offset += group_size + params_offset; - dst_offset += group_size; - } - s += src_stride; - dst += K; - } -} - -template ::type = true> + typename std::enable_if::type = true> void dequant(TDST* dst, uint8_t* src, const size_t N, const size_t K, const size_t group_size) { // The layout for per token per head: // |scale(f32)|zeropoint(f32)|quantized feature(u8,idx_1)|quantized feature(u8,idx_2)|...|quantized // feature(u8,idx_S)| The quantized feature will start from 8bytes=sizeof(float)+sizeof(float) auto s = src; const size_t params_offset = sizeof(float) * 2; - const size_t sub_byte_mulitplier = 2; - - for (size_t n = 0; n < N; n++) { - size_t src_offset = 0; - size_t dst_offset = 0; - while (dst_offset < K) { - auto f = reinterpret_cast(s + src_offset); - attn_dequant_u4_kernel(s + src_offset + params_offset, dst + dst_offset, group_size, f[0], f[1]); - src_offset += group_size / sub_byte_mulitplier + params_offset; - dst_offset += group_size; - } - s += src_offset; - dst += K; - } -} - -template ::type = true> -void dequant(TDST* dst, uint8_t* src, const size_t N, const size_t K, const size_t group_size) { - // The layout for per token per head: - // |scale(f32)|zeropoint(f32)|quantized feature(u8,idx_1)|quantized feature(u8,idx_2)|...|quantized - // feature(u8,idx_S)| The quantized feature will start from 8bytes=sizeof(float)+sizeof(float) - auto s = src; - const size_t params_offset = sizeof(float); - const size_t sub_byte_mulitplier = 2; + const size_t sub_byte_mulitplier = get_sub_byte_multiplier(SRC_PREC); for (size_t n = 
0; n < N; n++) { size_t src_offset = 0; size_t dst_offset = 0; while (dst_offset < K) { auto f = reinterpret_cast(s + src_offset); - attn_dequant_s4_kernel(s + src_offset + params_offset, dst + dst_offset, group_size, f[0]); + attn_dequant_kernel(s + src_offset + params_offset, + dst + dst_offset, + group_size, + f[0], + f[1]); src_offset += group_size / sub_byte_mulitplier + params_offset; dst_offset += group_size; } @@ -1132,40 +1090,8 @@ static void pack_32NxK(TDST* dst, template ::value != ov::element::f32 && SRC_PREC == ov::element::u8, - bool>::type = true> -static void pack_32NxK(TDST* dst, - void* src, - TDST* tmp, - const size_t N, - const size_t K, - const size_t dst_stride, - const size_t src_stride, - const size_t group_size) { - // The layout for per token per head: - // |scale(f32)|zeropoint(f32)|quantized feature(u8,idx_1)|quantized feature(u8,idx_2)|...|quantized - // feature(u8,idx_S)| The quantized feature will start from 8bytes=sizeof(float)+sizeof(float) - auto s = reinterpret_cast::value_type*>(src); - auto t = tmp; - // if group_size not set, the whole row is used as a group - for (size_t n = 0; n < N; n++) { - size_t src_offset = 0; - size_t dst_offset = 0; - while (dst_offset < K) { - auto f = reinterpret_cast(s + src_offset); - attn_dequant_u8_kernel(s + src_offset + sizeof(float) * 2, t + dst_offset, group_size, f[0], f[1]); - src_offset += group_size + sizeof(float) * 2; - dst_offset += group_size; - } - s += src_offset; - t += src_stride; - } - pack_32NxK::value>(dst, tmp, reinterpret_cast(0), N, K, dst_stride, src_stride, 0); -} - -template ::value != ov::element::f32 && (SRC_PREC == ov::element::u4), + typename std::enable_if::value != ov::element::f32 && + (SRC_PREC == ov::element::u4 || SRC_PREC == ov::element::u8), bool>::type = true> static void pack_32NxK(TDST* dst, void* src, @@ -1181,13 +1107,17 @@ static void pack_32NxK(TDST* dst, auto s = reinterpret_cast(src); auto t = tmp; // if group_size not set, the whole row is used as a group - const size_t sub_byte_mulitplier = 2; + const size_t sub_byte_mulitplier = get_sub_byte_multiplier(SRC_PREC); for (size_t n = 0; n < N; n++) { size_t src_offset = 0; size_t dst_offset = 0; while (dst_offset < K) { auto f = reinterpret_cast(s + src_offset); - attn_dequant_u4_kernel(s + (src_offset + sizeof(float) * 2), t + dst_offset, group_size, f[0], f[1]); + attn_dequant_kernel(s + (src_offset + sizeof(float) * 2), + t + dst_offset, + group_size, + f[0], + f[1]); src_offset += group_size / sub_byte_mulitplier + sizeof(float) * 2; dst_offset += group_size; } From 84f03a3fac7967bb3f71b32cb5f9deee46ec7d59 Mon Sep 17 00:00:00 2001 From: Zhang Yi Date: Mon, 6 Jan 2025 11:03:47 +0800 Subject: [PATCH 24/28] [CPU]Define key/value cache prec/group_size priority Signed-off-by: Zhang Yi --- .../openvino/runtime/properties/__init__.py | 4 + .../runtime/properties/hint/__init__.py | 4 - .../pyopenvino/core/properties/properties.cpp | 8 +- .../tests/test_runtime/test_properties.py | 24 ++-- .../include/openvino/runtime/properties.hpp | 48 +++---- src/plugins/intel_cpu/src/compiled_model.cpp | 24 ++-- src/plugins/intel_cpu/src/config.cpp | 48 +++++-- src/plugins/intel_cpu/src/config.h | 4 + .../intel_cpu/src/nodes/scaled_attn.cpp | 19 ++- src/plugins/intel_cpu/src/plugin.cpp | 24 ++-- .../ov_executable_network/properties.cpp | 128 +++++++++++++----- .../custom/behavior/ov_plugin/properties.cpp | 8 +- 12 files changed, 225 insertions(+), 118 deletions(-) diff --git a/src/bindings/python/src/openvino/runtime/properties/__init__.py 
b/src/bindings/python/src/openvino/runtime/properties/__init__.py index 3269ea42e32ac2..a02a18e556135b 100644 --- a/src/bindings/python/src/openvino/runtime/properties/__init__.py +++ b/src/bindings/python/src/openvino/runtime/properties/__init__.py @@ -30,6 +30,10 @@ from openvino._pyopenvino.properties import loaded_from_cache from openvino._pyopenvino.properties import cache_encryption_callbacks from openvino._pyopenvino.properties import weights_path +from openvino._pyopenvino.properties import key_cache_precision +from openvino._pyopenvino.properties import value_cache_precision +from openvino._pyopenvino.properties import key_cache_group_size +from openvino._pyopenvino.properties import value_cache_group_size # Submodules from openvino.runtime.properties import hint diff --git a/src/bindings/python/src/openvino/runtime/properties/hint/__init__.py b/src/bindings/python/src/openvino/runtime/properties/hint/__init__.py index d5c5d5595e5e0b..d1dce289d09941 100644 --- a/src/bindings/python/src/openvino/runtime/properties/hint/__init__.py +++ b/src/bindings/python/src/openvino/runtime/properties/hint/__init__.py @@ -23,8 +23,4 @@ from openvino._pyopenvino.properties.hint import allow_auto_batching from openvino._pyopenvino.properties.hint import dynamic_quantization_group_size from openvino._pyopenvino.properties.hint import kv_cache_precision -from openvino._pyopenvino.properties.hint import key_cache_precision -from openvino._pyopenvino.properties.hint import value_cache_precision -from openvino._pyopenvino.properties.hint import key_cache_group_size -from openvino._pyopenvino.properties.hint import value_cache_group_size from openvino._pyopenvino.properties.hint import activations_scale_factor diff --git a/src/bindings/python/src/pyopenvino/core/properties/properties.cpp b/src/bindings/python/src/pyopenvino/core/properties/properties.cpp index 2b997c6664cee0..937e9b66a0135f 100644 --- a/src/bindings/python/src/pyopenvino/core/properties/properties.cpp +++ b/src/bindings/python/src/pyopenvino/core/properties/properties.cpp @@ -44,6 +44,10 @@ void regmodule_properties(py::module m) { wrap_property_RW(m_properties, ov::force_tbb_terminate, "force_tbb_terminate"); wrap_property_RW(m_properties, ov::enable_mmap, "enable_mmap"); wrap_property_RW(m_properties, ov::weights_path, "weights_path"); + wrap_property_RW(m_properties, ov::key_cache_precision, "key_cache_precision"); + wrap_property_RW(m_properties, ov::value_cache_precision, "value_cache_precision"); + wrap_property_RW(m_properties, ov::key_cache_group_size, "key_cache_group_size"); + wrap_property_RW(m_properties, ov::value_cache_group_size, "value_cache_group_size"); wrap_property_RO(m_properties, ov::supported_properties, "supported_properties"); wrap_property_RO(m_properties, ov::available_devices, "available_devices"); @@ -101,10 +105,6 @@ void regmodule_properties(py::module m) { wrap_property_RW(m_hint, ov::hint::allow_auto_batching, "allow_auto_batching"); wrap_property_RW(m_hint, ov::hint::dynamic_quantization_group_size, "dynamic_quantization_group_size"); wrap_property_RW(m_hint, ov::hint::kv_cache_precision, "kv_cache_precision"); - wrap_property_RW(m_hint, ov::hint::key_cache_precision, "key_cache_precision"); - wrap_property_RW(m_hint, ov::hint::value_cache_precision, "value_cache_precision"); - wrap_property_RW(m_hint, ov::hint::key_cache_group_size, "key_cache_group_size"); - wrap_property_RW(m_hint, ov::hint::value_cache_group_size, "value_cache_group_size"); wrap_property_RW(m_hint, 
ov::hint::activations_scale_factor, "activations_scale_factor"); // Submodule intel_cpu diff --git a/src/bindings/python/tests/test_runtime/test_properties.py b/src/bindings/python/tests/test_runtime/test_properties.py index d0745f84361310..cbdd117c9fe97f 100644 --- a/src/bindings/python/tests/test_runtime/test_properties.py +++ b/src/bindings/python/tests/test_runtime/test_properties.py @@ -271,6 +271,18 @@ def test_properties_ro(ov_property_ro, expected_value): "WEIGHTS_PATH", (("./model.bin", "./model.bin"),), ), + ( + props.key_cache_group_size, + "KEY_CACHE_GROUP_SIZE", + ((64, 64),), + ), + ( + props.value_cache_group_size, + "VALUE_CACHE_GROUP_SIZE", + ((64, 64),), + ), + (props.key_cache_precision, "KEY_CACHE_PRECISION", ((Type.f32, Type.f32),)), + (props.value_cache_precision, "VALUE_CACHE_PRECISION", ((Type.f32, Type.f32),)), (hints.inference_precision, "INFERENCE_PRECISION_HINT", ((Type.f32, Type.f32),)), ( hints.model_priority, @@ -334,19 +346,7 @@ def test_properties_ro(ov_property_ro, expected_value): "DYNAMIC_QUANTIZATION_GROUP_SIZE", ((64, 64),), ), - ( - hints.key_cache_group_size, - "KEY_CACHE_GROUP_SIZE", - ((64, 64),), - ), - ( - hints.value_cache_group_size, - "VALUE_CACHE_GROUP_SIZE", - ((64, 64),), - ), (hints.kv_cache_precision, "KV_CACHE_PRECISION", ((Type.f32, Type.f32),)), - (hints.key_cache_precision, "KEY_CACHE_PRECISION", ((Type.f32, Type.f32),)), - (hints.value_cache_precision, "VALUE_CACHE_PRECISION", ((Type.f32, Type.f32),)), ( hints.activations_scale_factor, "ACTIVATIONS_SCALE_FACTOR", diff --git a/src/inference/include/openvino/runtime/properties.hpp b/src/inference/include/openvino/runtime/properties.hpp index 729ccc93feac1f..c7570b818f9665 100644 --- a/src/inference/include/openvino/runtime/properties.hpp +++ b/src/inference/include/openvino/runtime/properties.hpp @@ -580,30 +580,6 @@ static constexpr Property dynamic_quantization */ static constexpr Property kv_cache_precision{"KV_CACHE_PRECISION"}; -/** - * @brief Hint for device to use specified precision for key cache compression - * @ingroup ov_runtime_cpp_prop_api - */ -static constexpr Property key_cache_precision{"KEY_CACHE_PRECISION"}; - -/** - * @brief Hint for device to use specified precision for value cache compression - * @ingroup ov_runtime_cpp_prop_api - */ -static constexpr Property value_cache_precision{"VALUE_CACHE_PRECISION"}; - -/** - * @brief Hint for device to use group_size for key cache compression - * @ingroup ov_runtime_cpp_prop_api - */ -static constexpr Property key_cache_group_size{"KEY_CACHE_GROUP_SIZE"}; - -/** - * @brief Hint for device to use group_size for value cache compression - * @ingroup ov_runtime_cpp_prop_api - */ -static constexpr Property value_cache_group_size{"VALUE_CACHE_GROUP_SIZE"}; - /** * @brief This property scales down activations to prevent overflows when inference precision is f16. * @ingroup ov_runtime_cpp_prop_api @@ -1383,4 +1359,28 @@ static constexpr Property, PropertyMutability::RO> exec * @note This property is used for weightless caching. Only used when ov::CacheMode Property is set to "OPTIMIZE_SIZE". 
*/ static constexpr Property weights_path{"WEIGHTS_PATH"}; + +/** + * @brief The precision of key cache compression + * @ingroup ov_runtime_cpp_prop_api + */ +static constexpr Property key_cache_precision{"KEY_CACHE_PRECISION"}; + +/** + * @brief The precision of value cache compression + * @ingroup ov_runtime_cpp_prop_api + */ +static constexpr Property value_cache_precision{"VALUE_CACHE_PRECISION"}; + +/** + * @brief The group_size of key cache compression + * @ingroup ov_runtime_cpp_prop_api + */ +static constexpr Property key_cache_group_size{"KEY_CACHE_GROUP_SIZE"}; + +/** + * @brief The group_size of value cache compression + * @ingroup ov_runtime_cpp_prop_api + */ +static constexpr Property value_cache_group_size{"VALUE_CACHE_GROUP_SIZE"}; } // namespace ov diff --git a/src/plugins/intel_cpu/src/compiled_model.cpp b/src/plugins/intel_cpu/src/compiled_model.cpp index 59ba95ffbeb4c1..60b63f871e0c95 100644 --- a/src/plugins/intel_cpu/src/compiled_model.cpp +++ b/src/plugins/intel_cpu/src/compiled_model.cpp @@ -256,10 +256,10 @@ ov::Any CompiledModel::get_property(const std::string& name) const { RO_property(ov::intel_cpu::sparse_weights_decompression_rate.name()), RO_property(ov::hint::dynamic_quantization_group_size.name()), RO_property(ov::hint::kv_cache_precision.name()), - RO_property(ov::hint::key_cache_precision.name()), - RO_property(ov::hint::value_cache_precision.name()), - RO_property(ov::hint::key_cache_group_size.name()), - RO_property(ov::hint::value_cache_group_size.name()), + RO_property(ov::key_cache_precision.name()), + RO_property(ov::value_cache_precision.name()), + RO_property(ov::key_cache_group_size.name()), + RO_property(ov::value_cache_group_size.name()), }; OPENVINO_SUPPRESS_DEPRECATED_START @@ -336,14 +336,14 @@ ov::Any CompiledModel::get_property(const std::string& name) const { return decltype(ov::hint::dynamic_quantization_group_size)::value_type(config.fcDynamicQuantizationGroupSize); } else if (name == ov::hint::kv_cache_precision) { return decltype(ov::hint::kv_cache_precision)::value_type(config.kvCachePrecision); - } else if (name == ov::hint::key_cache_precision) { - return decltype(ov::hint::key_cache_precision)::value_type(config.keyCachePrecision); - } else if (name == ov::hint::value_cache_precision) { - return decltype(ov::hint::value_cache_precision)::value_type(config.valueCachePrecision); - } else if (name == ov::hint::key_cache_group_size) { - return decltype(ov::hint::key_cache_group_size)::value_type(config.keyCacheGroupSize); - } else if (name == ov::hint::value_cache_group_size) { - return decltype(ov::hint::value_cache_group_size)::value_type(config.valueCacheGroupSize); + } else if (name == ov::key_cache_precision) { + return decltype(ov::key_cache_precision)::value_type(config.keyCachePrecision); + } else if (name == ov::value_cache_precision) { + return decltype(ov::value_cache_precision)::value_type(config.valueCachePrecision); + } else if (name == ov::key_cache_group_size) { + return decltype(ov::key_cache_group_size)::value_type(config.keyCacheGroupSize); + } else if (name == ov::value_cache_group_size) { + return decltype(ov::value_cache_group_size)::value_type(config.valueCacheGroupSize); } OPENVINO_THROW("Unsupported property: ", name); } diff --git a/src/plugins/intel_cpu/src/config.cpp b/src/plugins/intel_cpu/src/config.cpp index 3b052e7094d34c..1004ae076eac10 100644 --- a/src/plugins/intel_cpu/src/config.cpp +++ b/src/plugins/intel_cpu/src/config.cpp @@ -373,9 +373,9 @@ void Config::readProperties(const ov::AnyMap& prop, 
const ModelType modelType) { ov::hint::kv_cache_precision.name(), ". Supported values: u8, bf16, f16, f32"); } - } else if (key == ov::hint::key_cache_precision.name()) { + } else if (key == ov::key_cache_precision.name()) { try { - kvCachePrecisionSetExplicitly = true; + keyCachePrecisionSetExplicitly = true; auto const prec = val.as(); if (one_of(prec, ov::element::f32, ov::element::f16, ov::element::bf16, ov::element::u8)) { keyCachePrecision = prec; @@ -386,12 +386,12 @@ void Config::readProperties(const ov::AnyMap& prop, const ModelType modelType) { OPENVINO_THROW("Wrong value ", val.as(), " for property key ", - ov::hint::key_cache_precision.name(), + ov::key_cache_precision.name(), ". Supported values: u8, bf16, f16, f32"); } - } else if (key == ov::hint::value_cache_precision.name()) { + } else if (key == ov::value_cache_precision.name()) { try { - kvCachePrecisionSetExplicitly = true; + valueCachePrecisionSetExplicitly = true; auto const prec = val.as(); if (one_of(prec, ov::element::f32, @@ -407,15 +407,17 @@ void Config::readProperties(const ov::AnyMap& prop, const ModelType modelType) { OPENVINO_THROW("Wrong value ", val.as(), " for property key ", - ov::hint::value_cache_precision.name(), + ov::value_cache_precision.name(), ". Supported values: u4, u8, bf16, f16, f32"); } - } else if (key == ov::hint::key_cache_group_size.name() || key == ov::hint::value_cache_group_size.name()) { + } else if (key == ov::key_cache_group_size.name() || key == ov::value_cache_group_size.name()) { try { auto const groupSize = val.as(); - if (key == ov::hint::key_cache_group_size.name()) { + if (key == ov::key_cache_group_size.name()) { + keyCacheGroupSizeSetExplicitly = true; keyCacheGroupSize = groupSize; } else { + valueCacheGroupSizeSetExplicitly = true; valueCacheGroupSize = groupSize; } } catch (ov::Exception&) { @@ -460,6 +462,13 @@ void Config::readProperties(const ov::AnyMap& prop, const ModelType modelType) { aclFastMath = true; } #endif + // key/value cache precision has higher priority, if not defined use kvCachePrecision + if (!keyCachePrecisionSetExplicitly && kvCachePrecisionSetExplicitly) { + keyCachePrecision = kvCachePrecision; + } + if (!valueCachePrecisionSetExplicitly && kvCachePrecisionSetExplicitly) { + valueCachePrecision = kvCachePrecision; + } // disable dynamic quantization and kv quantization for best accuracy if (executionMode == ov::hint::ExecutionMode::ACCURACY) { if (!fcDynamicQuantizationGroupSizeSetExplicitly) { @@ -467,9 +476,13 @@ void Config::readProperties(const ov::AnyMap& prop, const ModelType modelType) { } if (!kvCachePrecisionSetExplicitly) { kvCachePrecision = ov::element::f32; - valueCachePrecision = ov::element::f32; + } + if (!keyCachePrecisionSetExplicitly) { keyCachePrecision = ov::element::f32; } + if (!valueCachePrecisionSetExplicitly) { + valueCachePrecision = ov::element::f32; + } } if (!prop.empty()) @@ -524,6 +537,23 @@ void Config::applyRtInfo(const std::shared_ptr& model) { this->fcDynamicQuantizationGroupSize = model->get_rt_info({"runtime_options", ov::hint::dynamic_quantization_group_size.name()}); } + if (!keyCachePrecisionSetExplicitly && model->has_rt_info({"runtime_options", ov::key_cache_precision.name()})) { + this->keyCachePrecision = + model->get_rt_info({"runtime_options", ov::key_cache_precision.name()}); + } + if (!valueCachePrecisionSetExplicitly && + model->has_rt_info({"runtime_options", ov::value_cache_precision.name()})) { + this->valueCachePrecision = + model->get_rt_info({"runtime_options", 
ov::value_cache_precision.name()}); + } + if (!keyCacheGroupSizeSetExplicitly && model->has_rt_info({"runtime_options", ov::key_cache_group_size.name()})) { + this->keyCacheGroupSize = model->get_rt_info({"runtime_options", ov::key_cache_group_size.name()}); + } + if (!valueCacheGroupSizeSetExplicitly && + model->has_rt_info({"runtime_options", ov::value_cache_group_size.name()})) { + this->valueCacheGroupSize = + model->get_rt_info({"runtime_options", ov::value_cache_group_size.name()}); + } } } // namespace intel_cpu diff --git a/src/plugins/intel_cpu/src/config.h b/src/plugins/intel_cpu/src/config.h index 94d4b6e90c531d..75bfde2303a34f 100644 --- a/src/plugins/intel_cpu/src/config.h +++ b/src/plugins/intel_cpu/src/config.h @@ -48,6 +48,10 @@ struct Config { uint64_t fcDynamicQuantizationGroupSize = 32; bool fcDynamicQuantizationGroupSizeSetExplicitly = false; bool kvCachePrecisionSetExplicitly = false; + bool keyCachePrecisionSetExplicitly = false; + bool valueCachePrecisionSetExplicitly = false; + bool keyCacheGroupSizeSetExplicitly = false; + bool valueCacheGroupSizeSetExplicitly = false; #if defined(OV_CPU_WITH_ACL) bool aclFastMath = false; #endif diff --git a/src/plugins/intel_cpu/src/nodes/scaled_attn.cpp b/src/plugins/intel_cpu/src/nodes/scaled_attn.cpp index c0d19a9acd6e15..41d87e3388a035 100644 --- a/src/plugins/intel_cpu/src/nodes/scaled_attn.cpp +++ b/src/plugins/intel_cpu/src/nodes/scaled_attn.cpp @@ -1061,7 +1061,14 @@ ScaledDotProductAttention::ScaledDotProductAttention(const std::shared_ptrgetConfig(); + const auto& keyCachePrecision = cpuConfig.keyCachePrecision; + const auto& valueCachePrecision = cpuConfig.valueCachePrecision; + OPENVINO_ASSERT(valueCachePrecision == keyCachePrecision, + "CPU: SDPA node only supports same key/value cache precision"); + OPENVINO_ASSERT(one_of(keyCachePrecision, ov::element::f32, ov::element::f16, ov::element::bf16, ov::element::u8), + "CPU: SDPA only supports key/value cache precision f32, f16, bf16, u8 but gets ", + keyCachePrecision); if (const auto node = std::dynamic_pointer_cast(op)) { m_config.config.is_causal = node->get_causal(); } else if (const auto node = std::dynamic_pointer_cast(op)) { @@ -1835,12 +1842,16 @@ void ScaledDotProductAttention::updatePastkv(const MemoryPtr& mem_cur_k, const M ov::element::Type ScaledDotProductAttention::getKVCachePrecision() { ov::element::Type kvcache_precision; + // TODO: SDPA only supports same key/value cache precision. auto rtPrecision = getRuntimePrecision(); - auto kvCachePrecisionHint = context->getConfig().kvCachePrecision; + auto keyCachePrecisionHint = context->getConfig().keyCachePrecision; + auto valueCachePrecisionHint = context->getConfig().valueCachePrecision; bool enableKVCacheFP16 = m_config.config.fuse_concat && mayiuse(cpu_isa_t::avx2) && - rtPrecision != ov::element::bf16 && kvCachePrecisionHint == ov::element::f16; + rtPrecision != ov::element::bf16 && + (keyCachePrecisionHint == ov::element::f16 && valueCachePrecisionHint == ov::element::f16); kvcache_precision = enableKVCacheFP16 ? 
ov::element::f16 : rtPrecision; - bool use_int8_kv_cache_precision = kvCachePrecisionHint == ov::element::u8; + bool use_int8_kv_cache_precision = + (keyCachePrecisionHint == ov::element::u8 && valueCachePrecisionHint == ov::element::u8); if (use_int8_kv_cache_precision) kvcache_precision = ov::element::u8; else diff --git a/src/plugins/intel_cpu/src/plugin.cpp b/src/plugins/intel_cpu/src/plugin.cpp index 1c7c79a9c9c6e0..ec9b37c2c2d22e 100644 --- a/src/plugins/intel_cpu/src/plugin.cpp +++ b/src/plugins/intel_cpu/src/plugin.cpp @@ -392,14 +392,14 @@ ov::Any Plugin::get_property(const std::string& name, const ov::AnyMap& options) engConfig.fcDynamicQuantizationGroupSize); } else if (name == ov::hint::kv_cache_precision) { return decltype(ov::hint::kv_cache_precision)::value_type(engConfig.kvCachePrecision); - } else if (name == ov::hint::key_cache_precision) { - return decltype(ov::hint::key_cache_precision)::value_type(engConfig.keyCachePrecision); - } else if (name == ov::hint::value_cache_precision) { - return decltype(ov::hint::value_cache_precision)::value_type(engConfig.valueCachePrecision); - } else if (name == ov::hint::key_cache_group_size) { - return decltype(ov::hint::key_cache_group_size)::value_type(engConfig.keyCacheGroupSize); - } else if (name == ov::hint::value_cache_group_size) { - return decltype(ov::hint::value_cache_group_size)::value_type(engConfig.valueCacheGroupSize); + } else if (name == ov::key_cache_precision) { + return decltype(ov::key_cache_precision)::value_type(engConfig.keyCachePrecision); + } else if (name == ov::value_cache_precision) { + return decltype(ov::value_cache_precision)::value_type(engConfig.valueCachePrecision); + } else if (name == ov::key_cache_group_size) { + return decltype(ov::key_cache_group_size)::value_type(engConfig.keyCacheGroupSize); + } else if (name == ov::value_cache_group_size) { + return decltype(ov::value_cache_group_size)::value_type(engConfig.valueCacheGroupSize); } return get_ro_property(name, options); } @@ -443,10 +443,10 @@ ov::Any Plugin::get_ro_property(const std::string& name, const ov::AnyMap& optio RW_property(ov::intel_cpu::sparse_weights_decompression_rate.name()), RW_property(ov::hint::dynamic_quantization_group_size.name()), RW_property(ov::hint::kv_cache_precision.name()), - RW_property(ov::hint::key_cache_precision.name()), - RW_property(ov::hint::value_cache_precision.name()), - RW_property(ov::hint::key_cache_group_size.name()), - RW_property(ov::hint::value_cache_group_size.name()), + RW_property(ov::key_cache_precision.name()), + RW_property(ov::value_cache_precision.name()), + RW_property(ov::key_cache_group_size.name()), + RW_property(ov::value_cache_group_size.name()), }; OPENVINO_SUPPRESS_DEPRECATED_START diff --git a/src/plugins/intel_cpu/tests/functional/custom/behavior/ov_executable_network/properties.cpp b/src/plugins/intel_cpu/tests/functional/custom/behavior/ov_executable_network/properties.cpp index 016648a7e1026f..9d38d03e5eadde 100644 --- a/src/plugins/intel_cpu/tests/functional/custom/behavior/ov_executable_network/properties.cpp +++ b/src/plugins/intel_cpu/tests/functional/custom/behavior/ov_executable_network/properties.cpp @@ -2,14 +2,15 @@ // SPDX-License-Identifier: Apache-2.0 // +#include "openvino/runtime/properties.hpp" + #include -#include "utils/properties_test.hpp" -#include "openvino/runtime/system_conf.hpp" -#include "openvino/runtime/core.hpp" #include "openvino/runtime/compiled_model.hpp" -#include "openvino/runtime/properties.hpp" +#include "openvino/runtime/core.hpp" #include 
"openvino/runtime/intel_cpu/properties.hpp" +#include "openvino/runtime/system_conf.hpp" +#include "utils/properties_test.hpp" namespace { @@ -41,10 +42,10 @@ TEST_F(OVClassConfigTestCPU, smoke_CpuExecNetworkSupportedPropertiesAreAvailable RO_property(ov::intel_cpu::sparse_weights_decompression_rate.name()), RO_property(ov::hint::dynamic_quantization_group_size.name()), RO_property(ov::hint::kv_cache_precision.name()), - RO_property(ov::hint::key_cache_precision.name()), - RO_property(ov::hint::value_cache_precision.name()), - RO_property(ov::hint::key_cache_group_size.name()), - RO_property(ov::hint::value_cache_group_size.name()), + RO_property(ov::key_cache_precision.name()), + RO_property(ov::value_cache_precision.name()), + RO_property(ov::key_cache_group_size.name()), + RO_property(ov::value_cache_group_size.name()), }; ov::Core ie; @@ -88,7 +89,7 @@ TEST_F(OVClassConfigTestCPU, smoke_CpuExecNetworkSetROPropertiesThrow) { TEST_F(OVClassConfigTestCPU, smoke_CpuExecNetworkCheckCoreStreamsHasHigherPriorityThanThroughputHint) { ov::Core ie; - int32_t streams = 1; // throughput hint should apply higher number of streams + int32_t streams = 1; // throughput hint should apply higher number of streams int32_t value = 0; OV_ASSERT_NO_THROW(ie.set_property(deviceName, ov::num_streams(streams))); @@ -101,7 +102,7 @@ TEST_F(OVClassConfigTestCPU, smoke_CpuExecNetworkCheckCoreStreamsHasHigherPriori TEST_F(OVClassConfigTestCPU, smoke_CpuExecNetworkCheckCoreStreamsHasHigherPriorityThanLatencyHint) { ov::Core ie; - int32_t streams = ov::get_number_of_cpu_cores(); // latency hint should apply lower number of streams + int32_t streams = ov::get_number_of_cpu_cores(); // latency hint should apply lower number of streams int32_t value = 0; OV_ASSERT_NO_THROW(ie.set_property(deviceName, ov::num_streams(streams))); @@ -114,7 +115,7 @@ TEST_F(OVClassConfigTestCPU, smoke_CpuExecNetworkCheckCoreStreamsHasHigherPriori TEST_F(OVClassConfigTestCPU, smoke_CpuExecNetworkCheckModelStreamsHasHigherPriorityThanLatencyHint) { ov::Core ie; - int32_t streams = ov::get_number_of_cpu_cores(); // latency hint should apply lower number of streams + int32_t streams = ov::get_number_of_cpu_cores(); // latency hint should apply lower number of streams int32_t value = 0; OV_ASSERT_NO_THROW(ie.set_property(deviceName, ov::hint::performance_mode(ov::hint::PerformanceMode::LATENCY))); @@ -129,7 +130,7 @@ TEST_F(OVClassConfigTestCPU, smoke_CpuExecNetworkCheckModelStreamsHasHigherPrior TEST_F(OVClassConfigTestCPU, smoke_CpuExecNetworkCheckModelStreamsHasHigherPriorityThanThroughputHint) { ov::Core ie; - int32_t streams = 1; // throughput hint should apply higher number of streams + int32_t streams = 1; // throughput hint should apply higher number of streams int32_t value = 0; ov::AnyMap config; @@ -190,14 +191,14 @@ TEST_F(OVClassConfigTestCPU, smoke_CpuExecNetworkCheckKVCachePrecision) { TEST_F(OVClassConfigTestCPU, smoke_CpuExecNetworkFinetuneKVCachePrecision) { ov::Core core; - core.set_property(deviceName, ov::hint::key_cache_precision(ov::element::f16)); - core.set_property(deviceName, ov::hint::value_cache_precision(ov::element::u4)); + core.set_property(deviceName, ov::key_cache_precision(ov::element::f16)); + core.set_property(deviceName, ov::value_cache_precision(ov::element::u4)); ov::CompiledModel compiledModel = core.compile_model(model, deviceName); auto key_cache_precision_value = ov::element::undefined; auto value_cache_precision_value = ov::element::undefined; - OV_ASSERT_NO_THROW(key_cache_precision_value = 
compiledModel.get_property(ov::hint::key_cache_precision)); - OV_ASSERT_NO_THROW(value_cache_precision_value = compiledModel.get_property(ov::hint::value_cache_precision)); + OV_ASSERT_NO_THROW(key_cache_precision_value = compiledModel.get_property(ov::key_cache_precision)); + OV_ASSERT_NO_THROW(value_cache_precision_value = compiledModel.get_property(ov::value_cache_precision)); ASSERT_EQ(key_cache_precision_value, ov::element::f16); ASSERT_EQ(value_cache_precision_value, ov::element::u4); } @@ -205,14 +206,14 @@ TEST_F(OVClassConfigTestCPU, smoke_CpuExecNetworkFinetuneKVCachePrecision) { TEST_F(OVClassConfigTestCPU, smoke_CpuExecNetworkFinetuneKVCacheGroupSize) { ov::Core core; - core.set_property(deviceName, ov::hint::key_cache_group_size(32)); - core.set_property(deviceName, ov::hint::value_cache_group_size(16)); + core.set_property(deviceName, ov::key_cache_group_size(32)); + core.set_property(deviceName, ov::value_cache_group_size(16)); ov::CompiledModel compiledModel = core.compile_model(model, deviceName); auto key_cache_group_size_value = 0; auto value_cache_group_size_value = 0; - OV_ASSERT_NO_THROW(key_cache_group_size_value = compiledModel.get_property(ov::hint::key_cache_group_size)); - OV_ASSERT_NO_THROW(value_cache_group_size_value = compiledModel.get_property(ov::hint::value_cache_group_size)); + OV_ASSERT_NO_THROW(key_cache_group_size_value = compiledModel.get_property(ov::key_cache_group_size)); + OV_ASSERT_NO_THROW(value_cache_group_size_value = compiledModel.get_property(ov::value_cache_group_size)); ASSERT_EQ(key_cache_group_size_value, 32); ASSERT_EQ(value_cache_group_size_value, 16); } @@ -260,7 +261,8 @@ TEST_F(OVClassConfigTestCPU, smoke_CpuExecNetworkCheckExecutionModeIsAvailableIn ASSERT_FALSE(model_exec_mode_it->is_mutable()); } -TEST_F(OVClassConfigTestCPU, smoke_CpuExecNetworkCheckModelInferencePrecisionHasHigherPriorityThanCoreInferencePrecision) { +TEST_F(OVClassConfigTestCPU, + smoke_CpuExecNetworkCheckModelInferencePrecisionHasHigherPriorityThanCoreInferencePrecision) { ov::Core ie; auto inference_precision_value = ov::element::undefined; @@ -274,7 +276,8 @@ TEST_F(OVClassConfigTestCPU, smoke_CpuExecNetworkCheckModelInferencePrecisionHas ASSERT_EQ(inference_precision_value, bf16_if_can_be_emulated); } -TEST_F(OVClassConfigTestCPU, smoke_CpuExecNetworkCheckCoreInferencePrecisionHasHigherPriorityThanModelPerformanceExecutionMode) { +TEST_F(OVClassConfigTestCPU, + smoke_CpuExecNetworkCheckCoreInferencePrecisionHasHigherPriorityThanModelPerformanceExecutionMode) { ov::Core ie; auto execution_mode_value = ov::hint::ExecutionMode::ACCURACY; auto inference_precision_value = ov::element::undefined; @@ -292,7 +295,8 @@ TEST_F(OVClassConfigTestCPU, smoke_CpuExecNetworkCheckCoreInferencePrecisionHasH ASSERT_EQ(inference_precision_value, ov::element::f32); } -TEST_F(OVClassConfigTestCPU, smoke_CpuExecNetworkCheckModelInferencePrecisionHasHigherPriorityThanCorePerformanceExecutionMode) { +TEST_F(OVClassConfigTestCPU, + smoke_CpuExecNetworkCheckModelInferencePrecisionHasHigherPriorityThanCorePerformanceExecutionMode) { ov::Core ie; auto execution_mode_value = ov::hint::ExecutionMode::PERFORMANCE; auto inference_precision_value = ov::element::undefined; @@ -323,14 +327,13 @@ TEST_F(OVClassConfigTestCPU, smoke_CpuExecNetworkCheckLogLevel) { OV_ASSERT_NO_THROW(value = compiledModel.get_property(ov::log::level)); ASSERT_EQ(value.as(), ov::log::Level::NO); } - //check set and get - const std::vector logLevels = { - ov::log::Level::ERR, - ov::log::Level::NO, - 
-        ov::log::Level::WARNING,
-        ov::log::Level::INFO,
-        ov::log::Level::DEBUG,
-        ov::log::Level::TRACE};
+    // check set and get
+    const std::vector<ov::log::Level> logLevels = {ov::log::Level::ERR,
+                                                   ov::log::Level::NO,
+                                                   ov::log::Level::WARNING,
+                                                   ov::log::Level::INFO,
+                                                   ov::log::Level::DEBUG,
+                                                   ov::log::Level::TRACE};
 
     for (unsigned int i = 0; i < logLevels.size(); i++) {
         ov::Any value;
@@ -365,50 +368,109 @@ TEST_F(OVClassConfigTestCPU, smoke_CpuExecNetworkCheckCPURuntimOptions) {
     ov::Core ie;
     ov::Any type;
     ov::Any size;
+    ov::Any keySize;
+    ov::Any valueSize;
+    ov::Any keyCacheType;
+    ov::Any valueCacheType;
     ov::CompiledModel compiledModel;
     model->set_rt_info("f16", "runtime_options", ov::hint::kv_cache_precision.name());
     model->set_rt_info("0", "runtime_options", ov::hint::dynamic_quantization_group_size.name());
+    model->set_rt_info("32", "runtime_options", ov::key_cache_group_size.name());
+    model->set_rt_info("16", "runtime_options", ov::value_cache_group_size.name());
+    model->set_rt_info("u8", "runtime_options", ov::key_cache_precision.name());
+    model->set_rt_info("u8", "runtime_options", ov::value_cache_precision.name());
 
     OV_ASSERT_NO_THROW(compiledModel = ie.compile_model(model, deviceName));
     OV_ASSERT_NO_THROW(type = compiledModel.get_property(ov::hint::kv_cache_precision));
     OV_ASSERT_NO_THROW(size = compiledModel.get_property(ov::hint::dynamic_quantization_group_size));
+    OV_ASSERT_NO_THROW(keySize = compiledModel.get_property(ov::key_cache_group_size));
+    OV_ASSERT_NO_THROW(valueSize = compiledModel.get_property(ov::value_cache_group_size));
+    OV_ASSERT_NO_THROW(keyCacheType = compiledModel.get_property(ov::key_cache_precision));
+    OV_ASSERT_NO_THROW(valueCacheType = compiledModel.get_property(ov::value_cache_precision));
 
     ASSERT_EQ(type.as<ov::element::Type>(), ov::element::f16);
     ASSERT_EQ(size.as<uint64_t>(), 0);
+    ASSERT_EQ(keySize.as<uint64_t>(), 32);
+    ASSERT_EQ(valueSize.as<uint64_t>(), 16);
+    ASSERT_EQ(keyCacheType.as<ov::element::Type>(), ov::element::u8);
+    ASSERT_EQ(valueCacheType.as<ov::element::Type>(), ov::element::u8);
 }
 
 TEST_F(OVClassConfigTestCPU, smoke_CpuExecNetworkCheckCPURuntimOptionsWithCompileConfig) {
     ov::Core ie;
     ov::Any type;
     ov::Any size;
+    ov::Any keySize;
+    ov::Any valueSize;
+    ov::Any keyCacheType;
+    ov::Any valueCacheType;
     ov::CompiledModel compiledModel;
     model->set_rt_info("f16", "runtime_options", ov::hint::kv_cache_precision.name());
     model->set_rt_info("0", "runtime_options", ov::hint::dynamic_quantization_group_size.name());
+    model->set_rt_info("0", "runtime_options", ov::key_cache_group_size.name());
+    model->set_rt_info("0", "runtime_options", ov::value_cache_group_size.name());
+    model->set_rt_info("f32", "runtime_options", ov::key_cache_precision.name());
+    model->set_rt_info("f32", "runtime_options", ov::value_cache_precision.name());
 
     ov::AnyMap config;
     config[ov::hint::kv_cache_precision.name()] = "u8";
     config[ov::hint::dynamic_quantization_group_size.name()] = "16";
+    // property has higher priority than rt_info
+    config[ov::key_cache_group_size.name()] = "32";
+    config[ov::value_cache_group_size.name()] = "16";
+    // key/value cache precision has higher priority than kv_cache_precision
+    config[ov::key_cache_precision.name()] = "f16";
+    config[ov::value_cache_precision.name()] = "bf16";
 
     OV_ASSERT_NO_THROW(compiledModel = ie.compile_model(model, deviceName, config));
     OV_ASSERT_NO_THROW(type = compiledModel.get_property(ov::hint::kv_cache_precision));
     OV_ASSERT_NO_THROW(size = compiledModel.get_property(ov::hint::dynamic_quantization_group_size));
+    OV_ASSERT_NO_THROW(keySize = compiledModel.get_property(ov::key_cache_group_size));
+    OV_ASSERT_NO_THROW(valueSize = compiledModel.get_property(ov::value_cache_group_size));
+    OV_ASSERT_NO_THROW(keyCacheType = compiledModel.get_property(ov::key_cache_precision));
+    OV_ASSERT_NO_THROW(valueCacheType = compiledModel.get_property(ov::value_cache_precision));
 
     ASSERT_EQ(type.as<ov::element::Type>(), ov::element::u8);
     ASSERT_EQ(size.as<uint64_t>(), 16);
+    ASSERT_EQ(keySize.as<uint64_t>(), 32);
+    ASSERT_EQ(valueSize.as<uint64_t>(), 16);
+    ASSERT_EQ(keyCacheType.as<ov::element::Type>(), ov::element::f16);
+    ASSERT_EQ(valueCacheType.as<ov::element::Type>(), ov::element::bf16);
 }
 
 TEST_F(OVClassConfigTestCPU, smoke_CpuExecNetworkCheckCPURuntimOptionsWithCoreProperties) {
     ov::Core core;
     ov::Any type;
     ov::Any size;
-
+    ov::Any keySize;
+    ov::Any valueSize;
+    ov::Any keyCacheType;
+    ov::Any valueCacheType;
     core.set_property(deviceName, ov::hint::kv_cache_precision(ov::element::f32));
     core.set_property(deviceName, ov::hint::dynamic_quantization_group_size(16));
+    core.set_property(deviceName, ov::key_cache_group_size(8));
+    core.set_property(deviceName, ov::value_cache_group_size(8));
+    core.set_property(deviceName, ov::key_cache_precision(ov::element::f16));
+    core.set_property(deviceName, ov::value_cache_precision(ov::element::bf16));
 
     ov::CompiledModel compiledModel;
     model->set_rt_info("f16", "runtime_options", ov::hint::kv_cache_precision.name());
     model->set_rt_info("0", "runtime_options", ov::hint::dynamic_quantization_group_size.name());
+    model->set_rt_info("32", "runtime_options", ov::key_cache_group_size.name());
+    model->set_rt_info("16", "runtime_options", ov::value_cache_group_size.name());
+    // User's setting has higher priority than rt_info
+    model->set_rt_info("f32", "runtime_options", ov::key_cache_precision.name());
+    model->set_rt_info("f32", "runtime_options", ov::value_cache_precision.name());
 
     OV_ASSERT_NO_THROW(compiledModel = core.compile_model(model, deviceName));
     OV_ASSERT_NO_THROW(type = compiledModel.get_property(ov::hint::kv_cache_precision));
     OV_ASSERT_NO_THROW(size = compiledModel.get_property(ov::hint::dynamic_quantization_group_size));
+    OV_ASSERT_NO_THROW(keySize = compiledModel.get_property(ov::key_cache_group_size));
+    OV_ASSERT_NO_THROW(valueSize = compiledModel.get_property(ov::value_cache_group_size));
+    OV_ASSERT_NO_THROW(keyCacheType = compiledModel.get_property(ov::key_cache_precision));
+    OV_ASSERT_NO_THROW(valueCacheType = compiledModel.get_property(ov::value_cache_precision));
+
     ASSERT_EQ(type.as<ov::element::Type>(), ov::element::f32);
     ASSERT_EQ(size.as<uint64_t>(), 16);
+    ASSERT_EQ(keySize.as<uint64_t>(), 8);
+    ASSERT_EQ(valueSize.as<uint64_t>(), 8);
+    ASSERT_EQ(keyCacheType.as<ov::element::Type>(), ov::element::f16);
+    ASSERT_EQ(valueCacheType.as<ov::element::Type>(), ov::element::bf16);
 }
 
-} // namespace
+}  // namespace
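Note on usage: the hunks above only switch the test-side references from the ov::hint namespace to the top-level ov::key_cache_precision, ov::value_cache_precision, ov::key_cache_group_size and ov::value_cache_group_size properties. A minimal usage sketch follows; the device name, the model variable and the chosen values are illustrative assumptions of the sketch, only the property names come from this series.

    // Sketch only: configure separate key/value cache settings on the CPU plugin
    // and read back the effective values from the compiled model.
    #include "openvino/openvino.hpp"

    void configure_kv_cache(ov::Core& core, const std::shared_ptr<ov::Model>& model) {  // hypothetical helper
        core.set_property("CPU", ov::key_cache_precision(ov::element::u8));
        core.set_property("CPU", ov::value_cache_precision(ov::element::u8));
        core.set_property("CPU", ov::key_cache_group_size(32));
        core.set_property("CPU", ov::value_cache_group_size(32));

        ov::CompiledModel compiled = core.compile_model(model, "CPU");
        auto key_prec = compiled.get_property(ov::key_cache_precision);       // expected: u8
        auto value_group = compiled.get_property(ov::value_cache_group_size); // expected: 32
    }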
diff --git a/src/plugins/intel_cpu/tests/functional/custom/behavior/ov_plugin/properties.cpp b/src/plugins/intel_cpu/tests/functional/custom/behavior/ov_plugin/properties.cpp
index 589f0641eae0e8..c6289a4dc80716 100644
--- a/src/plugins/intel_cpu/tests/functional/custom/behavior/ov_plugin/properties.cpp
+++ b/src/plugins/intel_cpu/tests/functional/custom/behavior/ov_plugin/properties.cpp
@@ -56,10 +56,10 @@ TEST_F(OVClassConfigTestCPU, smoke_PluginAllSupportedPropertiesAreAvailable) {
         RW_property(ov::intel_cpu::sparse_weights_decompression_rate.name()),
         RW_property(ov::hint::dynamic_quantization_group_size.name()),
         RW_property(ov::hint::kv_cache_precision.name()),
-        RW_property(ov::hint::key_cache_precision.name()),
-        RW_property(ov::hint::value_cache_precision.name()),
-        RW_property(ov::hint::key_cache_group_size.name()),
-        RW_property(ov::hint::value_cache_group_size.name()),
+        RW_property(ov::key_cache_precision.name()),
+        RW_property(ov::value_cache_precision.name()),
+        RW_property(ov::key_cache_group_size.name()),
+        RW_property(ov::value_cache_group_size.name()),
     };
 
     ov::Core ie;

From e0b437e141dcd35d5d0414b3be56aae47ee57e70 Mon Sep 17 00:00:00 2001
From: Zhang Yi
Date: Mon, 6 Jan 2025 15:24:51 +0800
Subject: [PATCH 25/28] [CPU]fix prec order & check group_size

---
 src/plugins/intel_cpu/src/config.cpp            |  2 +-
 src/plugins/intel_cpu/src/nodes/scaled_attn.cpp | 14 ++++++++++++++
 2 files changed, 15 insertions(+), 1 deletion(-)

diff --git a/src/plugins/intel_cpu/src/config.cpp b/src/plugins/intel_cpu/src/config.cpp
index 1004ae076eac10..73460258356ad7 100644
--- a/src/plugins/intel_cpu/src/config.cpp
+++ b/src/plugins/intel_cpu/src/config.cpp
@@ -529,7 +529,7 @@ void Config::applyRtInfo(const std::shared_ptr& model) {
     // if user sets explicitly, it will be higher priority than rt_info
     if (!kvCachePrecisionSetExplicitly &&
         model->has_rt_info({"runtime_options", ov::hint::kv_cache_precision.name()})) {
-        this->kvCachePrecision =
+        this->kvCachePrecision = this->keyCachePrecision = this->valueCachePrecision =
             model->get_rt_info({"runtime_options", ov::hint::kv_cache_precision.name()});
     }
     if (!fcDynamicQuantizationGroupSizeSetExplicitly &&
diff --git a/src/plugins/intel_cpu/src/nodes/scaled_attn.cpp b/src/plugins/intel_cpu/src/nodes/scaled_attn.cpp
index 41d87e3388a035..fec6252373a70b 100644
--- a/src/plugins/intel_cpu/src/nodes/scaled_attn.cpp
+++ b/src/plugins/intel_cpu/src/nodes/scaled_attn.cpp
@@ -1064,6 +1064,20 @@ ScaledDotProductAttention::ScaledDotProductAttention(const std::shared_ptr
     const auto& cpuConfig = context->getConfig();
     const auto& keyCachePrecision = cpuConfig.keyCachePrecision;
     const auto& valueCachePrecision = cpuConfig.valueCachePrecision;
+    const auto keyDims = getInputShapeAtPort(1).getDims();
+    const auto valueDims = getInputShapeAtPort(2).getDims();
+    const auto keyS = *(keyDims.end() - 1);
+    const auto valueS = *(valueDims.end() - 1);
+    if (keyS % cpuConfig.keyCacheGroupSize != 0) {
+        OPENVINO_THROW("ScaledDotProductAttention AttentionExecutor creation fails key state " + std::to_string(keyS) +
+                       " cannot be divided by group size " + std::to_string(cpuConfig.keyCacheGroupSize));
+    }
+
+    if (valueS % cpuConfig.valueCacheGroupSize != 0) {
+        OPENVINO_THROW("ScaledDotProductAttention AttentionExecutor creation fails value state " +
+                       std::to_string(valueS) + " cannot be divided by group size " +
+                       std::to_string(cpuConfig.valueCacheGroupSize));
+    }
     OPENVINO_ASSERT(valueCachePrecision == keyCachePrecision,
                     "CPU: SDPA node only supports same key/value cache precision");
     OPENVINO_ASSERT(one_of(keyCachePrecision, ov::element::f32, ov::element::f16, ov::element::bf16, ov::element::u8),
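The scaled_attn.cpp hunk above adds a shape guard: the head size (the last dimension of the key/value inputs) must be an exact multiple of the configured cache group size, otherwise executor creation throws. A standalone sketch of that constraint follows; the helper name and the example values are illustrative, and the group_size != 0 short-circuit is an assumption of the sketch rather than the exact logic of the patch, which relies on the config defaults.

    #include <cstddef>
    #include <stdexcept>
    #include <string>

    // Hypothetical helper mirroring the divisibility check above.
    static void check_cache_group_size(std::size_t head_size, std::size_t group_size, const std::string& what) {
        if (group_size != 0 && head_size % group_size != 0) {
            throw std::runtime_error(what + " state " + std::to_string(head_size) +
                                     " cannot be divided by group size " + std::to_string(group_size));
        }
    }

    // check_cache_group_size(128, 32, "key");   // fine: 128 is a multiple of 32
    // check_cache_group_size(96, 64, "value");  // throws: 96 % 64 != 0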
From 0515410f950b424dd1359a2b9dcff3d62ee8a89a Mon Sep 17 00:00:00 2001
From: Zhang Yi
Date: Tue, 7 Jan 2025 11:09:47 +0800
Subject: [PATCH 26/28] [CPU]fix sdpa test

---
 src/plugins/intel_cpu/src/config.h                   | 4 ++--
 .../custom/subgraph_tests/src/classes/concat_sdp.cpp | 2 ++
 2 files changed, 4 insertions(+), 2 deletions(-)

diff --git a/src/plugins/intel_cpu/src/config.h b/src/plugins/intel_cpu/src/config.h
index 75bfde2303a34f..44b78df043cee6 100644
--- a/src/plugins/intel_cpu/src/config.h
+++ b/src/plugins/intel_cpu/src/config.h
@@ -67,8 +67,8 @@ struct Config {
     // TODO: Executor cache may leads to incorrect behavior on oneDNN ACL primitives
     size_t rtCacheCapacity = 0ul;
 #endif
-    size_t keyCacheGroupSize = 0ul;
-    size_t valueCacheGroupSize = 0ul;
+    size_t keyCacheGroupSize = 32ul;
+    size_t valueCacheGroupSize = 32ul;
     ov::threading::IStreamsExecutor::Config streamExecutorConfig;
     int streams = 1;
     bool streamsChanged = false;
diff --git a/src/plugins/intel_cpu/tests/functional/custom/subgraph_tests/src/classes/concat_sdp.cpp b/src/plugins/intel_cpu/tests/functional/custom/subgraph_tests/src/classes/concat_sdp.cpp
index 83fc0a635546fc..ab893ac060f55b 100644
--- a/src/plugins/intel_cpu/tests/functional/custom/subgraph_tests/src/classes/concat_sdp.cpp
+++ b/src/plugins/intel_cpu/tests/functional/custom/subgraph_tests/src/classes/concat_sdp.cpp
@@ -76,6 +76,8 @@ void ConcatSDPTest::SetUp() {
     auto v_ps = inputDynamicShapes[0];
     if (m_isDiffKVHeadSize) {
         v_ps[3] += m_diffKVHeadSize;
+        // v_ps[3] must be divisible by value_cache_group_size
+        configuration[ov::value_cache_group_size.name()] = "16";
     }
     inputParams.push_back(std::make_shared<ov::op::v0::Parameter>(inType, v_ps));
     inputParams[0]->set_friendly_name("q");

From 7a412f799f2bb0da46fde1fec1e291fe0285fad1 Mon Sep 17 00:00:00 2001
From: Zhang Yi
Date: Tue, 7 Jan 2025 13:16:09 +0800
Subject: [PATCH 27/28] [CPU]fix group_size in sdpa

Signed-off-by: Zhang Yi
---
 .../intel_cpu/src/nodes/scaled_attn.cpp       | 19 +++++++++----------
 src/plugins/intel_cpu/src/nodes/scaled_attn.h |  7 ++++++-
 2 files changed, 15 insertions(+), 11 deletions(-)

diff --git a/src/plugins/intel_cpu/src/nodes/scaled_attn.cpp b/src/plugins/intel_cpu/src/nodes/scaled_attn.cpp
index fec6252373a70b..aec0ff2e7d9026 100644
--- a/src/plugins/intel_cpu/src/nodes/scaled_attn.cpp
+++ b/src/plugins/intel_cpu/src/nodes/scaled_attn.cpp
@@ -1068,21 +1068,20 @@ ScaledDotProductAttention::ScaledDotProductAttention(const std::shared_ptr
     (op)) {
         m_config.config.is_causal = node->get_causal();
     } else if (const auto node = std::dynamic_pointer_cast(op)) {
diff --git a/src/plugins/intel_cpu/src/nodes/scaled_attn.h b/src/plugins/intel_cpu/src/nodes/scaled_attn.h
index 21b9056ba9517c..2917342314fafd 100644
--- a/src/plugins/intel_cpu/src/nodes/scaled_attn.h
+++ b/src/plugins/intel_cpu/src/nodes/scaled_attn.h
@@ -47,7 +47,10 @@ class ScaledDotProductAttention : public Node {
         real_order = {permute_axes[2], permute_axes[0], permute_axes[1], permute_axes[3]};
         return real_order;
     }
-
+    struct SDPAQuantParam {
+        ov::element::Type precision = ov::element::undefined;
+        size_t groupSize = 0;
+    };
     ov::element::Type getKVCachePrecision();
 
 private:
@@ -86,6 +89,8 @@ class ScaledDotProductAttention : public Node {
     // (0, 1, 2, 3) for BHLS
     // (2, 0, 1, 3) for LBHS
     std::vector<size_t> m_kvstate_layout = {2, 0, 1, 3};
+    SDPAQuantParam m_key_quant_param;
+    SDPAQuantParam m_value_quant_param;
 };
 
 }  // namespace node
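scaled_attn.h now keeps one quantization descriptor per cache direction, pairing a precision with a group size (m_key_quant_param / m_value_quant_param). A minimal sketch of that pairing follows; the struct shape follows the patch, while the concrete values and the free-standing initialisation are illustrative assumptions.

    #include <cstddef>
    #include "openvino/core/type/element_type.hpp"

    struct SDPAQuantParam {
        ov::element::Type precision = ov::element::undefined;
        std::size_t groupSize = 0;
    };

    // e.g. a u8-quantized key cache grouped by 32 elements, and a bf16 value cache
    // that keeps the group-size default (0 in this series' Config).
    SDPAQuantParam key_quant{ov::element::u8, 32};
    SDPAQuantParam value_quant{ov::element::bf16, 0};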
From 594b39299c4d40c409fe1a486bee198b89167822 Mon Sep 17 00:00:00 2001
From: Zhang Yi
Date: Tue, 7 Jan 2025 20:17:20 +0800
Subject: [PATCH 28/28] [CPU]Change default group_size

---
 src/plugins/intel_cpu/src/config.h                   | 4 ++--
 .../custom/subgraph_tests/src/classes/concat_sdp.cpp | 2 --
 2 files changed, 2 insertions(+), 4 deletions(-)

diff --git a/src/plugins/intel_cpu/src/config.h b/src/plugins/intel_cpu/src/config.h
index 44b78df043cee6..75bfde2303a34f 100644
--- a/src/plugins/intel_cpu/src/config.h
+++ b/src/plugins/intel_cpu/src/config.h
@@ -67,8 +67,8 @@ struct Config {
     // TODO: Executor cache may leads to incorrect behavior on oneDNN ACL primitives
     size_t rtCacheCapacity = 0ul;
 #endif
-    size_t keyCacheGroupSize = 32ul;
-    size_t valueCacheGroupSize = 32ul;
+    size_t keyCacheGroupSize = 0ul;
+    size_t valueCacheGroupSize = 0ul;
     ov::threading::IStreamsExecutor::Config streamExecutorConfig;
     int streams = 1;
     bool streamsChanged = false;
diff --git a/src/plugins/intel_cpu/tests/functional/custom/subgraph_tests/src/classes/concat_sdp.cpp b/src/plugins/intel_cpu/tests/functional/custom/subgraph_tests/src/classes/concat_sdp.cpp
index ab893ac060f55b..83fc0a635546fc 100644
--- a/src/plugins/intel_cpu/tests/functional/custom/subgraph_tests/src/classes/concat_sdp.cpp
+++ b/src/plugins/intel_cpu/tests/functional/custom/subgraph_tests/src/classes/concat_sdp.cpp
@@ -76,8 +76,6 @@ void ConcatSDPTest::SetUp() {
     auto v_ps = inputDynamicShapes[0];
     if (m_isDiffKVHeadSize) {
         v_ps[3] += m_diffKVHeadSize;
-        // v_ps[3] must be divisible by value_cache_group_size
-        configuration[ov::value_cache_group_size.name()] = "16";
     }
     inputParams.push_back(std::make_shared<ov::op::v0::Parameter>(inType, v_ps));
     inputParams[0]->set_friendly_name("q");