diff --git a/src/bindings/python/src/openvino/runtime/properties/__init__.py b/src/bindings/python/src/openvino/runtime/properties/__init__.py index 3269ea42e32ac2..a02a18e556135b 100644 --- a/src/bindings/python/src/openvino/runtime/properties/__init__.py +++ b/src/bindings/python/src/openvino/runtime/properties/__init__.py @@ -30,6 +30,10 @@ from openvino._pyopenvino.properties import loaded_from_cache from openvino._pyopenvino.properties import cache_encryption_callbacks from openvino._pyopenvino.properties import weights_path +from openvino._pyopenvino.properties import key_cache_precision +from openvino._pyopenvino.properties import value_cache_precision +from openvino._pyopenvino.properties import key_cache_group_size +from openvino._pyopenvino.properties import value_cache_group_size # Submodules from openvino.runtime.properties import hint diff --git a/src/bindings/python/src/pyopenvino/core/properties/properties.cpp b/src/bindings/python/src/pyopenvino/core/properties/properties.cpp index 564e5f69f5ee14..937e9b66a0135f 100644 --- a/src/bindings/python/src/pyopenvino/core/properties/properties.cpp +++ b/src/bindings/python/src/pyopenvino/core/properties/properties.cpp @@ -44,6 +44,10 @@ void regmodule_properties(py::module m) { wrap_property_RW(m_properties, ov::force_tbb_terminate, "force_tbb_terminate"); wrap_property_RW(m_properties, ov::enable_mmap, "enable_mmap"); wrap_property_RW(m_properties, ov::weights_path, "weights_path"); + wrap_property_RW(m_properties, ov::key_cache_precision, "key_cache_precision"); + wrap_property_RW(m_properties, ov::value_cache_precision, "value_cache_precision"); + wrap_property_RW(m_properties, ov::key_cache_group_size, "key_cache_group_size"); + wrap_property_RW(m_properties, ov::value_cache_group_size, "value_cache_group_size"); wrap_property_RO(m_properties, ov::supported_properties, "supported_properties"); wrap_property_RO(m_properties, ov::available_devices, "available_devices"); diff --git a/src/bindings/python/tests/test_runtime/test_properties.py b/src/bindings/python/tests/test_runtime/test_properties.py index 6065d72196b44b..cbdd117c9fe97f 100644 --- a/src/bindings/python/tests/test_runtime/test_properties.py +++ b/src/bindings/python/tests/test_runtime/test_properties.py @@ -271,6 +271,18 @@ def test_properties_ro(ov_property_ro, expected_value): "WEIGHTS_PATH", (("./model.bin", "./model.bin"),), ), + ( + props.key_cache_group_size, + "KEY_CACHE_GROUP_SIZE", + ((64, 64),), + ), + ( + props.value_cache_group_size, + "VALUE_CACHE_GROUP_SIZE", + ((64, 64),), + ), + (props.key_cache_precision, "KEY_CACHE_PRECISION", ((Type.f32, Type.f32),)), + (props.value_cache_precision, "VALUE_CACHE_PRECISION", ((Type.f32, Type.f32),)), (hints.inference_precision, "INFERENCE_PRECISION_HINT", ((Type.f32, Type.f32),)), ( hints.model_priority, diff --git a/src/inference/include/openvino/runtime/properties.hpp b/src/inference/include/openvino/runtime/properties.hpp index 8baea3ed408656..c7570b818f9665 100644 --- a/src/inference/include/openvino/runtime/properties.hpp +++ b/src/inference/include/openvino/runtime/properties.hpp @@ -1359,4 +1359,28 @@ static constexpr Property, PropertyMutability::RO> exec * @note This property is used for weightless caching. Only used when ov::CacheMode Property is set to "OPTIMIZE_SIZE". */ static constexpr Property weights_path{"WEIGHTS_PATH"}; + +/** + * @brief The precision of key cache compression + * @ingroup ov_runtime_cpp_prop_api + */ +static constexpr Property key_cache_precision{"KEY_CACHE_PRECISION"}; + +/** + * @brief The precision of value cache compression + * @ingroup ov_runtime_cpp_prop_api + */ +static constexpr Property value_cache_precision{"VALUE_CACHE_PRECISION"}; + +/** + * @brief The group_size of key cache compression + * @ingroup ov_runtime_cpp_prop_api + */ +static constexpr Property key_cache_group_size{"KEY_CACHE_GROUP_SIZE"}; + +/** + * @brief The group_size of value cache compression + * @ingroup ov_runtime_cpp_prop_api + */ +static constexpr Property value_cache_group_size{"VALUE_CACHE_GROUP_SIZE"}; } // namespace ov diff --git a/src/plugins/intel_cpu/src/compiled_model.cpp b/src/plugins/intel_cpu/src/compiled_model.cpp index f81c7dbbced99d..60b63f871e0c95 100644 --- a/src/plugins/intel_cpu/src/compiled_model.cpp +++ b/src/plugins/intel_cpu/src/compiled_model.cpp @@ -256,6 +256,10 @@ ov::Any CompiledModel::get_property(const std::string& name) const { RO_property(ov::intel_cpu::sparse_weights_decompression_rate.name()), RO_property(ov::hint::dynamic_quantization_group_size.name()), RO_property(ov::hint::kv_cache_precision.name()), + RO_property(ov::key_cache_precision.name()), + RO_property(ov::value_cache_precision.name()), + RO_property(ov::key_cache_group_size.name()), + RO_property(ov::value_cache_group_size.name()), }; OPENVINO_SUPPRESS_DEPRECATED_START @@ -332,6 +336,14 @@ ov::Any CompiledModel::get_property(const std::string& name) const { return decltype(ov::hint::dynamic_quantization_group_size)::value_type(config.fcDynamicQuantizationGroupSize); } else if (name == ov::hint::kv_cache_precision) { return decltype(ov::hint::kv_cache_precision)::value_type(config.kvCachePrecision); + } else if (name == ov::key_cache_precision) { + return decltype(ov::key_cache_precision)::value_type(config.keyCachePrecision); + } else if (name == ov::value_cache_precision) { + return decltype(ov::value_cache_precision)::value_type(config.valueCachePrecision); + } else if (name == ov::key_cache_group_size) { + return decltype(ov::key_cache_group_size)::value_type(config.keyCacheGroupSize); + } else if (name == ov::value_cache_group_size) { + return decltype(ov::value_cache_group_size)::value_type(config.valueCacheGroupSize); } OPENVINO_THROW("Unsupported property: ", name); } diff --git a/src/plugins/intel_cpu/src/config.cpp b/src/plugins/intel_cpu/src/config.cpp index 7d1ee05897e81d..73460258356ad7 100644 --- a/src/plugins/intel_cpu/src/config.cpp +++ b/src/plugins/intel_cpu/src/config.cpp @@ -373,6 +373,60 @@ void Config::readProperties(const ov::AnyMap& prop, const ModelType modelType) { ov::hint::kv_cache_precision.name(), ". Supported values: u8, bf16, f16, f32"); } + } else if (key == ov::key_cache_precision.name()) { + try { + keyCachePrecisionSetExplicitly = true; + auto const prec = val.as(); + if (one_of(prec, ov::element::f32, ov::element::f16, ov::element::bf16, ov::element::u8)) { + keyCachePrecision = prec; + } else { + OPENVINO_THROW("keyCachePrecision doesn't support value ", prec); + } + } catch (ov::Exception&) { + OPENVINO_THROW("Wrong value ", + val.as(), + " for property key ", + ov::key_cache_precision.name(), + ". Supported values: u8, bf16, f16, f32"); + } + } else if (key == ov::value_cache_precision.name()) { + try { + valueCachePrecisionSetExplicitly = true; + auto const prec = val.as(); + if (one_of(prec, + ov::element::f32, + ov::element::f16, + ov::element::bf16, + ov::element::u8, + ov::element::u4)) { + valueCachePrecision = prec; + } else { + OPENVINO_THROW("valueCachePrecision doesn't support value ", prec); + } + } catch (ov::Exception&) { + OPENVINO_THROW("Wrong value ", + val.as(), + " for property key ", + ov::value_cache_precision.name(), + ". Supported values: u4, u8, bf16, f16, f32"); + } + } else if (key == ov::key_cache_group_size.name() || key == ov::value_cache_group_size.name()) { + try { + auto const groupSize = val.as(); + if (key == ov::key_cache_group_size.name()) { + keyCacheGroupSizeSetExplicitly = true; + keyCacheGroupSize = groupSize; + } else { + valueCacheGroupSizeSetExplicitly = true; + valueCacheGroupSize = groupSize; + } + } catch (ov::Exception&) { + OPENVINO_THROW("Wrong value ", + val.as(), + " for property key ", + key, + ". Expected only unsinged integer numbers"); + } } else if (key == ov::cache_encryption_callbacks.name()) { try { auto encryption_callbacks = val.as(); @@ -408,6 +462,13 @@ void Config::readProperties(const ov::AnyMap& prop, const ModelType modelType) { aclFastMath = true; } #endif + // key/value cache precision has higher priority, if not defined use kvCachePrecision + if (!keyCachePrecisionSetExplicitly && kvCachePrecisionSetExplicitly) { + keyCachePrecision = kvCachePrecision; + } + if (!valueCachePrecisionSetExplicitly && kvCachePrecisionSetExplicitly) { + valueCachePrecision = kvCachePrecision; + } // disable dynamic quantization and kv quantization for best accuracy if (executionMode == ov::hint::ExecutionMode::ACCURACY) { if (!fcDynamicQuantizationGroupSizeSetExplicitly) { @@ -416,6 +477,12 @@ void Config::readProperties(const ov::AnyMap& prop, const ModelType modelType) { if (!kvCachePrecisionSetExplicitly) { kvCachePrecision = ov::element::f32; } + if (!keyCachePrecisionSetExplicitly) { + keyCachePrecision = ov::element::f32; + } + if (!valueCachePrecisionSetExplicitly) { + valueCachePrecision = ov::element::f32; + } } if (!prop.empty()) @@ -462,7 +529,7 @@ void Config::applyRtInfo(const std::shared_ptr& model) { // if user sets explicitly, it will be higher priority than rt_info if (!kvCachePrecisionSetExplicitly && model->has_rt_info({"runtime_options", ov::hint::kv_cache_precision.name()})) { - this->kvCachePrecision = + this->kvCachePrecision = this->keyCachePrecision = this->valueCachePrecision = model->get_rt_info({"runtime_options", ov::hint::kv_cache_precision.name()}); } if (!fcDynamicQuantizationGroupSizeSetExplicitly && @@ -470,6 +537,23 @@ void Config::applyRtInfo(const std::shared_ptr& model) { this->fcDynamicQuantizationGroupSize = model->get_rt_info({"runtime_options", ov::hint::dynamic_quantization_group_size.name()}); } + if (!keyCachePrecisionSetExplicitly && model->has_rt_info({"runtime_options", ov::key_cache_precision.name()})) { + this->keyCachePrecision = + model->get_rt_info({"runtime_options", ov::key_cache_precision.name()}); + } + if (!valueCachePrecisionSetExplicitly && + model->has_rt_info({"runtime_options", ov::value_cache_precision.name()})) { + this->valueCachePrecision = + model->get_rt_info({"runtime_options", ov::value_cache_precision.name()}); + } + if (!keyCacheGroupSizeSetExplicitly && model->has_rt_info({"runtime_options", ov::key_cache_group_size.name()})) { + this->keyCacheGroupSize = model->get_rt_info({"runtime_options", ov::key_cache_group_size.name()}); + } + if (!valueCacheGroupSizeSetExplicitly && + model->has_rt_info({"runtime_options", ov::value_cache_group_size.name()})) { + this->valueCacheGroupSize = + model->get_rt_info({"runtime_options", ov::value_cache_group_size.name()}); + } } } // namespace intel_cpu diff --git a/src/plugins/intel_cpu/src/config.h b/src/plugins/intel_cpu/src/config.h index 1aa08f4412f0b3..75bfde2303a34f 100644 --- a/src/plugins/intel_cpu/src/config.h +++ b/src/plugins/intel_cpu/src/config.h @@ -48,17 +48,27 @@ struct Config { uint64_t fcDynamicQuantizationGroupSize = 32; bool fcDynamicQuantizationGroupSizeSetExplicitly = false; bool kvCachePrecisionSetExplicitly = false; + bool keyCachePrecisionSetExplicitly = false; + bool valueCachePrecisionSetExplicitly = false; + bool keyCacheGroupSizeSetExplicitly = false; + bool valueCacheGroupSizeSetExplicitly = false; #if defined(OV_CPU_WITH_ACL) bool aclFastMath = false; #endif #if defined(OPENVINO_ARCH_X86_64) ov::element::Type kvCachePrecision = ov::element::u8; + ov::element::Type keyCachePrecision = ov::element::u8; + ov::element::Type valueCachePrecision = ov::element::u8; size_t rtCacheCapacity = 5000ul; #else ov::element::Type kvCachePrecision = ov::element::f16; + ov::element::Type keyCachePrecision = ov::element::f16; + ov::element::Type valueCachePrecision = ov::element::f16; // TODO: Executor cache may leads to incorrect behavior on oneDNN ACL primitives size_t rtCacheCapacity = 0ul; #endif + size_t keyCacheGroupSize = 0ul; + size_t valueCacheGroupSize = 0ul; ov::threading::IStreamsExecutor::Config streamExecutorConfig; int streams = 1; bool streamsChanged = false; diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant.cpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant.cpp index 095180d659142e..26282a70fcb512 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant.cpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant.cpp @@ -27,10 +27,10 @@ namespace XARCH { using namespace ov; template -static void quant_u8(const T* src, uint8_t* dst, size_t n, float& scale, float& zp) { +static void find_minmax(const T* src, size_t n, float& min, float& max) { + max = -FLT_MAX; + min = FLT_MAX; size_t i = 0; - float max = -FLT_MAX; - float min = FLT_MAX; #if defined(HAVE_AVX512F) auto v0_max = _mm512_set1_ps(-FLT_MAX); auto v0_min = _mm512_set1_ps(FLT_MAX); @@ -131,12 +131,18 @@ static void quant_u8(const T* src, uint8_t* dst, size_t n, float& scale, float& max = std::max(max, tmp); min = std::min(min, tmp); } +} + +template +static void quant_u8(const T* src, uint8_t* dst, size_t n, float& scale, float& zp) { + size_t i = 0; + float max = -FLT_MAX; + float min = FLT_MAX; + find_minmax(src, n, min, max); scale = (max - min) / 255; if (scale == 0) scale = 0.0001f; zp = -min / scale; - - i = 0; #if defined(HAVE_AVX512F) auto v_scale = _mm512_set1_ps(1 / scale); auto v_zp = _mm512_set1_ps(zp); @@ -170,6 +176,116 @@ static void quant_u8(const T* src, uint8_t* dst, size_t n, float& scale, float& } } +template +static void quant_u4(const T* src, void* dst, size_t n, float& scale, float& zp) { + size_t i = 0; + float max = -FLT_MAX; + float min = FLT_MAX; + find_minmax(src, n, min, max); + auto insert_half_byte = [](uint8_t dst, uint8_t val, bool high_half) -> uint8_t { + uint8_t shift = high_half ? 0 : 4; + return dst | (uint8_t)(val << shift); + }; + auto dst_ptr = reinterpret_cast(dst); + scale = (max - min) / ((1 << 4) - 1); + if (scale == 0) + scale = 0.0001f; + zp = -min / scale; +#if defined(HAVE_AVX512F) + auto v_scale = _mm512_set1_ps(1 / scale); + auto v_zp = _mm512_set1_ps(zp); + auto v_zero = _mm512_setzero_epi32(); + auto v_upper = _mm512_set1_epi32(15); + for (; i + 2 * vec_len_f32_avx512 <= n; i += 2 * vec_len_f32_avx512) { + auto v0 = mm512_uni_loadu_ps(src + i); + auto v1 = mm512_uni_loadu_ps(src + i + vec_len_f32_avx512); + v0 = _mm512_fmadd_ps(v0, v_scale, v_zp); + v1 = _mm512_fmadd_ps(v1, v_scale, v_zp); + auto v0_i32 = _mm512_cvt_roundps_epi32(v0, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC); + auto v1_i32 = _mm512_cvt_roundps_epi32(v1, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC); + v0_i32 = _mm512_max_epi32(v0_i32, v_zero); + v1_i32 = _mm512_max_epi32(v1_i32, v_zero); + v0_i32 = _mm512_min_epi32(v0_i32, v_upper); + v1_i32 = _mm512_min_epi32(v1_i32, v_upper); + __m512i idx1 = _mm512_set_epi32(30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0); + __m512i idx2 = _mm512_set_epi32(31, 29, 27, 25, 23, 21, 19, 17, 15, 13, 11, 9, 7, 5, 3, 1); + auto first_half = _mm512_permutex2var_epi32(v0_i32, idx1, v1_i32); + auto second_half = _mm512_permutex2var_epi32(v0_i32, idx2, v1_i32); + first_half = _mm512_slli_epi32(first_half, 4); + auto mask = _mm512_set1_epi32(0x0F); + second_half = _mm512_and_epi32(second_half, mask); + auto combined = _mm512_or_epi32(first_half, second_half); + _mm512_mask_cvtepi32_storeu_epi8(dst_ptr + i / 2, 0xffff, combined); + } +#endif +#if defined(HAVE_AVX2) + auto v256_zero = _mm256_set1_epi32(0); + auto v256_upper = _mm256_set1_epi32(15); + auto v256_scale = _mm256_set1_ps(1 / scale); + auto v256_zp = _mm256_set1_ps(zp); + for (; i + vec_len_f32_avx2 * 2 <= n; i += vec_len_f32_avx2 * 2) { + auto v0 = mm256_uni_loadu_ps(src + i); + auto v1 = mm256_uni_loadu_ps(src + i + vec_len_f32_avx2); + v0 = _mm256_fmadd_ps(v0, v256_scale, v256_zp); + v1 = _mm256_fmadd_ps(v1, v256_scale, v256_zp); + v0 = _mm256_round_ps(v0, _MM_ROUND_NEAREST); + v1 = _mm256_round_ps(v1, _MM_ROUND_NEAREST); + + auto v0_i32 = _mm256_cvtps_epi32(v0); + auto v1_i32 = _mm256_cvtps_epi32(v1); + v0_i32 = _mm256_max_epi32(v0_i32, v256_zero); + v1_i32 = _mm256_max_epi32(v1_i32, v256_zero); + v0_i32 = _mm256_min_epi32(v0_i32, v256_upper); + v1_i32 = _mm256_min_epi32(v1_i32, v256_upper); + auto idx1 = _mm256_set_epi32(7, 5, 3, 1, 6, 4, 2, 0); + v0_i32 = _mm256_permutevar8x32_epi32(v0_i32, idx1); + v1_i32 = _mm256_permutevar8x32_epi32(v1_i32, idx1); + // 0,1,2,3,4,5,6,7 | 8,9,10,11,12,13,14,15 + // _mm256_permutevar8x32_epi32 + // 0,2,4,6,1,3,5,7 | 8,10,12,14,9,11,13,15 + // _mm256_permute2x128_si256 + // 0,2,4,6,8,10,12,14 | 1,3,5,7,9,11,13,15 + // shift + mask + or + // [0,1],[2,3], ..., [12,13], [14,15] + auto first_half = _mm256_permute2x128_si256(v0_i32, v1_i32, 0x20); + auto second_half = _mm256_permute2x128_si256(v0_i32, v1_i32, 0x31); + first_half = _mm256_slli_epi32(first_half, 4); + auto mask = _mm256_set1_epi32(0x0F); + second_half = _mm256_and_si256(second_half, mask); + auto combined = _mm256_or_si256(first_half, second_half); + + auto high4 = _mm256_extractf128_si256(combined, 1); + auto low4 = _mm256_castsi256_si128(combined); + // ignore sign bit for u4 case + auto packed = _mm_packus_epi32(low4, high4); + packed = _mm_packus_epi16(packed, packed); + _mm_storel_epi64(reinterpret_cast<__m128i*>(dst_ptr + i / 2), packed); + } +#endif + for (; i < n; i++) { + float tmp = src[i]; +#define MIN(a, b) ((a) < (b) ? (a) : (b)) + uint8_t src_val = MIN(15, (uint8_t)(std::round(tmp / scale + zp))); + uint8_t dst_val = i % 2 == 0 ? 0 : dst_ptr[i / 2]; + dst_val = insert_half_byte(dst_val, src_val, (uint8_t)(i % 2)); + dst_ptr[i / 2] = dst_val; + } +} + +template ::type = true> +static void quantize(const T* src, uint8_t* dst, size_t n, float* scale_zp) { + quant_u8(src, dst, n, *scale_zp, *(scale_zp + 1)); +} + +template ::type = true> +static void quantize(const T* src, void* dst, size_t n, float* scale_zp) { + quant_u4(src, dst, n, *scale_zp, *(scale_zp + 1)); +} + template static void attn_quant_mt(const ov::intel_cpu::PlainTensor& k_src, const ov::intel_cpu::PlainTensor& v_src, @@ -187,36 +303,55 @@ static void attn_quant_mt(const ov::intel_cpu::PlainTensor& k_src, }); } -template +template static void paged_attn_quant_mt(const ov::intel_cpu::PlainTensor& k_src, const ov::intel_cpu::PlainTensor& v_src, const ov::intel_cpu::PlainTensor& k_dst, const ov::intel_cpu::PlainTensor& v_dst, - const ov::intel_cpu::PlainTensor& slot_mapping) { + const ov::intel_cpu::PlainTensor& slot_mapping, + const size_t key_group_size, + const size_t value_group_size) { size_t B = k_src.m_dims[0], H = k_src.m_dims[1], L1 = k_src.m_dims[2], S = k_src.m_dims[3], SV = v_src.m_dims[3]; size_t block_size = k_dst.m_dims[2]; + size_t sub_byte_multiplier = 8 / v_dst.get_precision().bitwidth(); parallel_for3d(B, L1, H, [&](size_t b, size_t m, size_t h) { auto slot = slot_mapping.ptr(b)[m]; if (slot < 0) return; auto block_number = slot / block_size; auto block_offset = slot % block_size; - - auto p_k = reinterpret_cast(k_dst.ptr(block_number, h, block_offset)); - auto p_v = reinterpret_cast(v_dst.ptr(block_number, h, block_offset)); // The layout for per token per head: // |scale(f32)|zeropoint(f32)|quantized feature(u8,idx_1)|quantized feature(u8,idx_2)|...|quantized // feature(u8,idx_S)| - quant_u8(k_src.ptr(b, h, m), - k_dst.ptr(block_number, h, block_offset) + sizeof(float) + sizeof(float), - S, - p_k[0], - p_k[1]); - quant_u8(v_src.ptr(b, h, m), - v_dst.ptr(block_number, h, block_offset) + sizeof(float) + sizeof(float), - SV, - p_v[0], - p_v[1]); + for (size_t src_offset = 0, dst_offset = 0; src_offset < S; + src_offset += key_group_size, dst_offset += key_group_size + sizeof(float) + sizeof(float)) { + auto p_k = reinterpret_cast( + k_dst.ptr::value_type>(block_number, + h, + block_offset, + dst_offset)); + quantize( + k_src.ptr(b, h, m, src_offset), + k_dst.ptr::value_type>(block_number, + h, + block_offset, + dst_offset) + + sizeof(float) + sizeof(float), + key_group_size, + p_k); + } + + for (size_t src_offset = 0, dst_offset = 0; src_offset < SV; src_offset += value_group_size, + dst_offset += value_group_size / sub_byte_multiplier + sizeof(float) + sizeof(float)) { + uint8_t* v_base = reinterpret_cast( + v_dst.m_ptr.get() + + (block_number * v_dst.m_strides[0] + h * v_dst.m_strides[1] + block_offset * v_dst.m_strides[2]) / + sub_byte_multiplier + + dst_offset); + auto p_v = reinterpret_cast(v_base); + uint8_t* v_ptr = v_base + sizeof(float) * 2; + quantize(v_src.ptr(b, h, m, src_offset), v_ptr, value_group_size, p_v); + } }); } @@ -245,20 +380,48 @@ void paged_attn_quantkv(const ov::intel_cpu::PlainTensor& k_src, const ov::intel_cpu::PlainTensor& v_src, const ov::intel_cpu::PlainTensor& k_dst, const ov::intel_cpu::PlainTensor& v_dst, - const ov::intel_cpu::PlainTensor& slot_mapping) { - if (k_src.get_precision() == ov::element::f32 && k_dst.get_precision() == ov::element::u8) { - paged_attn_quant_mt(k_src, v_src, k_dst, v_dst, slot_mapping); - } else if (k_src.get_precision() == ov::element::bf16 && k_dst.get_precision() == ov::element::u8) { - paged_attn_quant_mt(k_src, v_src, k_dst, v_dst, slot_mapping); - } else if (k_src.get_precision() == ov::element::f16 && k_dst.get_precision() == ov::element::u8) { - paged_attn_quant_mt(k_src, v_src, k_dst, v_dst, slot_mapping); - } else { + const ov::intel_cpu::PlainTensor& slot_mapping, + const size_t key_group_size, + const size_t value_group_size) { + using function_type = void (*)(const ov::intel_cpu::PlainTensor&, + const ov::intel_cpu::PlainTensor&, + const ov::intel_cpu::PlainTensor&, + const ov::intel_cpu::PlainTensor&, + const ov::intel_cpu::PlainTensor&, + const size_t, + const size_t); + static constexpr function_type funcs_fp32[] = { + paged_attn_quant_mt, + paged_attn_quant_mt, + }; + static constexpr function_type funcs_bf16[] = { + paged_attn_quant_mt, + paged_attn_quant_mt, + }; + static constexpr function_type funcs_f16[] = { + paged_attn_quant_mt, + paged_attn_quant_mt, + }; + if (k_dst.get_precision() != ov::element::u8) { OPENVINO_THROW("unsupport src type: ", k_src.get_precision(), ", dst type: ", k_dst.get_precision(), " in paged_attn_quantkv"); } + std::map dispatch_table = { + {ov::element::u8, 0}, + {ov::element::u4, 1}, + {ov::element::i4, 2}, + }; + size_t dispatch = dispatch_table[v_dst.get_precision()]; + if (k_src.get_precision() == ov::element::f32) { + funcs_fp32[dispatch](k_src, v_src, k_dst, v_dst, slot_mapping, key_group_size, value_group_size); + } else if (k_src.get_precision() == ov::element::bf16) { + funcs_bf16[dispatch](k_src, v_src, k_dst, v_dst, slot_mapping, key_group_size, value_group_size); + } else if (k_src.get_precision() == ov::element::f16) { + funcs_f16[dispatch](k_src, v_src, k_dst, v_dst, slot_mapping, key_group_size, value_group_size); + } } void attn_quant_u8(const float* src, uint8_t* dst, size_t n, float& scale, float& zp) { @@ -266,7 +429,7 @@ void attn_quant_u8(const float* src, uint8_t* dst, size_t n, float& scale, float } void attn_dequant_u8(const uint8_t* src, float* dst, size_t n, float scale, float zp) { - attn_dequant_u8_kernel(src, dst, n, scale, zp); + attn_dequant_kernel(src, dst, n, scale, zp); } } // namespace XARCH diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant.hpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant.hpp index 2f39f74f5b3460..364e5775861ed2 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant.hpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant.hpp @@ -27,7 +27,9 @@ void paged_attn_quantkv(const ov::intel_cpu::PlainTensor& k_src, const ov::intel_cpu::PlainTensor& v_src, const ov::intel_cpu::PlainTensor& k_dst, const ov::intel_cpu::PlainTensor& v_dst, - const ov::intel_cpu::PlainTensor& slot_mapping); + const ov::intel_cpu::PlainTensor& slot_mapping, + const size_t key_group_size, + const size_t value_group_size); void attn_quant_u8(const float* src, uint8_t* dst, size_t n, float& scale, float& zp); diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant_kernel.hpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant_kernel.hpp index 759d0005103871..761a136eda2997 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant_kernel.hpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant_kernel.hpp @@ -17,8 +17,10 @@ namespace Extensions { namespace Cpu { namespace XARCH { -template -void attn_dequant_u8_kernel(const uint8_t* src, TDST* dst, size_t n, float scale, float zp) { +template ::type = true> +void attn_dequant_kernel(const uint8_t* src, TDST* dst, size_t n, float scale, float zp) { size_t i = 0; // loadu_si128/epi64 does not support const qualifier uint8_t* src_nc = const_cast(src); @@ -52,6 +54,90 @@ void attn_dequant_u8_kernel(const uint8_t* src, TDST* dst, size_t n, float scale } } +template ::type = true> +void attn_dequant_kernel(const uint8_t* src, TDST* dst, size_t n, float scale, float zp) { + // 2 4bit data form a byte + /* 0,1|2,3|4,5|6,7 + / \ + 0,2,4,6|1,3,5,7 + | + permute + | + 0,1,2,3,4,5,6,7 + */ + size_t i = 0; + uint8_t* src_nc = const_cast(src); +#if defined(HAVE_AVX512F) + auto v_scale = _mm512_set1_ps(scale); + auto v_zp_scale = _mm512_set1_ps(zp * scale); + for (; i + vec_len_f32_avx512 * 2 <= n; i += vec_len_f32_avx512 * 2) { + auto data = _mm_loadu_si128(reinterpret_cast<__m128i*>(src_nc + i / 2)); + auto v_i32 = _mm512_cvtepu8_epi32(data); + + auto v_512_low_half = _mm512_srli_epi32(v_i32, 4); + auto v_f32_low_half = _mm512_cvtepi32_ps(v_512_low_half); + + auto mask = _mm512_set1_epi32(0x0F); + auto v_512_high_half = _mm512_and_si512(v_i32, mask); + auto v_f32_high_half = _mm512_cvtepi32_ps(v_512_high_half); + // q * scale- zp * scale + v_f32_low_half = _mm512_fmsub_ps(v_f32_low_half, v_scale, v_zp_scale); + v_f32_high_half = _mm512_fmsub_ps(v_f32_high_half, v_scale, v_zp_scale); + __m512i idx1 = _mm512_set_epi32(23, 7, 22, 6, 21, 5, 20, 4, 19, 3, 18, 2, 17, 1, 16, 0); + __m512i idx2 = _mm512_set_epi32(31, 15, 30, 14, 29, 13, 28, 12, 27, 11, 26, 10, 25, 9, 24, 8); + __m512 first_half = _mm512_permutex2var_ps(v_f32_low_half, idx1, v_f32_high_half); + __m512 second_half = _mm512_permutex2var_ps(v_f32_low_half, idx2, v_f32_high_half); + mm512_uni_storeu_ps(dst + i, first_half); + mm512_uni_storeu_ps(dst + i + vec_len_f32_avx512, second_half); + } +#elif defined(HAVE_AVX2) + auto v256_zp = _mm256_set1_ps(zp); + auto v256_scale = _mm256_set1_ps(scale); + for (; i + vec_len_f32_avx2 * 2 <= n; i += vec_len_f32_avx2 * 2) { + auto data = _mm_loadl_epi64(reinterpret_cast<__m128i*>(src_nc + i / 2)); + + auto v_i32 = _mm256_cvtepu8_epi32(data); + auto v_256_low_half = _mm256_srli_epi32(v_i32, 4); + auto v_f32_low_half = _mm256_cvtepi32_ps(v_256_low_half); + + auto mask = _mm256_set1_epi32(0x0F); + auto v_256_high_half = _mm256_and_si256(v_i32, mask); + auto v_f32_high_half = _mm256_cvtepi32_ps(v_256_high_half); + // q - zp + v_f32_low_half = _mm256_sub_ps(v_f32_low_half, v256_zp); + v_f32_high_half = _mm256_sub_ps(v_f32_high_half, v256_zp); + + v_f32_low_half = _mm256_mul_ps(v_f32_low_half, v256_scale); + v_f32_high_half = _mm256_mul_ps(v_f32_high_half, v256_scale); + + // 0,2,4,6,8,10,12,14 | 1,3,5,7,9,11,13,15 + // _mm256_permute2f128_ps + // 0,2,4,6,1,3,5,7 | 8,10,12,14,9,11,13,15 + // _mm256_permutevar8x32_ps + // 0,1,2,3,4,5,6,7 | 8,9,10,11,12,13,14,15 + __m256 first_half = _mm256_permute2f128_ps(v_f32_low_half, v_f32_high_half, 0x20); + auto idx1 = _mm256_set_epi32(7, 3, 6, 2, 5, 1, 4, 0); + first_half = _mm256_permutevar8x32_ps(first_half, idx1); + __m256 second_half = _mm256_permute2f128_ps(v_f32_low_half, v_f32_high_half, 0x31); + second_half = _mm256_permutevar8x32_ps(second_half, idx1); + + mm256_uni_storeu_ps(dst + i, first_half); + mm256_uni_storeu_ps(dst + i + vec_len_f32_avx2, second_half); + } +#endif + auto extract_half_byte = [&](uint8_t val, bool high_half) -> uint8_t { + uint8_t shift = high_half ? 0 : 4; + return (uint8_t)((val >> shift) & 0x000F); + }; + for (; i < n; ++i) { + float tmp = extract_half_byte(src_nc[i / 2], (uint8_t)(i % 2)); + tmp = (tmp - zp) * scale; + dst[i] = tmp; + } +} + } // namespace XARCH } // namespace Cpu } // namespace Extensions diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp index a74021d8ac0d05..955e7687ef97b3 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp @@ -72,8 +72,22 @@ void cvt_copy(TA* dst, TB* src, size_t n) { } } -template -static void attn_acc_value_block(float* out, float* weight, T* v, size_t S, size_t block_size) { +size_t inline get_sub_byte_multiplier(ov::element::Type type) { + return one_of(type, ov::element::i4, ov::element::u4) ? 8 / type.bitwidth() : 1; +} + +template ::value || std::is_same::value || + std::is_same::value) && + (SRC_PREC != ov::element::u8 || SRC_PREC != ov::element::u4), + bool>::type = true> +static void attn_acc_value_block(float* out, + float* weight, + T* v, + const size_t S, + const size_t block_size, + const size_t group_size) { # if defined(HAVE_AVX512F) size_t j = 0; for (; j + 4 <= block_size; j += 4) { @@ -200,117 +214,262 @@ static void attn_acc_value_block(float* out, float* weight, T* v, size_t S, size v += S; } } - -static void attn_acc_value_block(float* out, float* weight, uint8_t* v, size_t S, size_t block_size) { +template ::type = true> +static void attn_acc_value_block(float* out, + float* weight, + uint8_t* v, + const size_t S, + const size_t block_size, + const size_t group_size) { // The layout for per token per head: // |scale(f32)|zeropoint(f32)|quantized feature(u8,idx_1)|quantized feature(u8,idx_2)|...|quantized // feature(u8,idx_S)| The quantized feature will start from 8bytes=sizeof(float)+sizeof(float) + size_t src_offset = 0; + size_t dst_offset = 0; + const size_t params_offset = sizeof(float) * 2; + const size_t src_stride = S / group_size * (group_size + params_offset); + # if defined(HAVE_AVX512F) size_t j = 0; for (; j + 4 <= block_size; j += 4) { - auto v_f0 = reinterpret_cast(v); - auto v_f1 = reinterpret_cast(v + S + 8); - auto v_f2 = reinterpret_cast(v + 2 * (S + 8)); - auto v_f3 = reinterpret_cast(v + 3 * (S + 8)); - auto attn_w_vec0 = _mm512_set1_ps(weight[0] * v_f0[0]); - auto attn_w_vec1 = _mm512_set1_ps(weight[1] * v_f1[0]); - auto attn_w_vec2 = _mm512_set1_ps(weight[2] * v_f2[0]); - auto attn_w_vec3 = _mm512_set1_ps(weight[3] * v_f3[0]); - auto zp0 = _mm512_set1_ps(v_f0[1]); - auto zp1 = _mm512_set1_ps(v_f1[1]); - auto zp2 = _mm512_set1_ps(v_f2[1]); - auto zp3 = _mm512_set1_ps(v_f3[1]); - size_t i = 0; - v += 8; - for (; i + vec_len_f32_avx512 <= S; i += vec_len_f32_avx512) { - auto v_out = mm512_uni_loadu_ps(out + i); - auto v0 = _mm512_sub_ps( - _mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(_mm_loadu_si128(reinterpret_cast<__m128i*>(v + i)))), - zp0); - auto v1 = _mm512_sub_ps( - _mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(_mm_loadu_si128(reinterpret_cast<__m128i*>(v + i + S + 8)))), - zp1); - auto v2 = _mm512_sub_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32( - _mm_loadu_si128(reinterpret_cast<__m128i*>(v + i + 2 * (S + 8))))), - zp2); - auto v3 = _mm512_sub_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32( - _mm_loadu_si128(reinterpret_cast<__m128i*>(v + i + 3 * (S + 8))))), - zp3); - v_out = _mm512_fmadd_ps(attn_w_vec0, v0, v_out); - v_out = _mm512_fmadd_ps(attn_w_vec1, v1, v_out); - v_out = _mm512_fmadd_ps(attn_w_vec2, v2, v_out); - v_out = _mm512_fmadd_ps(attn_w_vec3, v3, v_out); - - _mm512_storeu_ps(out + i, v_out); - } - for (; i < S; i++) { - out[i] += weight[0] * (v[i] - v_f0[1]) * v_f0[0]; - out[i] += weight[1] * (v[i + S + 8] - v_f1[1]) * v_f1[0]; - out[i] += weight[2] * (v[i + 2 * (S + 8)] - v_f2[1]) * v_f2[0]; - out[i] += weight[3] * (v[i + 3 * (S + 8)] - v_f3[1]) * v_f3[0]; + dst_offset = 0; + src_offset = 0; + while (dst_offset < S) { + // process group by group + uint8_t* v_ptr = v + src_offset; + auto v_f0 = reinterpret_cast(v_ptr); + auto v_f1 = reinterpret_cast(v_ptr + src_stride); + auto v_f2 = reinterpret_cast(v_ptr + 2 * src_stride); + auto v_f3 = reinterpret_cast(v_ptr + 3 * src_stride); + auto attn_w_vec0 = _mm512_set1_ps(weight[0] * v_f0[0]); + auto attn_w_vec1 = _mm512_set1_ps(weight[1] * v_f1[0]); + auto attn_w_vec2 = _mm512_set1_ps(weight[2] * v_f2[0]); + auto attn_w_vec3 = _mm512_set1_ps(weight[3] * v_f3[0]); + auto zp0 = _mm512_set1_ps(v_f0[1]); + auto zp1 = _mm512_set1_ps(v_f1[1]); + auto zp2 = _mm512_set1_ps(v_f2[1]); + auto zp3 = _mm512_set1_ps(v_f3[1]); + uint8_t* v_data_ptr = v + src_offset + params_offset; + size_t i = 0; + for (; i + vec_len_f32_avx512 <= group_size; i += vec_len_f32_avx512) { + auto v_out = mm512_uni_loadu_ps(out + dst_offset + i); + auto v0 = _mm512_sub_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32( + _mm_loadu_si128(reinterpret_cast<__m128i*>(v_data_ptr + i)))), + zp0); + auto v1 = _mm512_sub_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32( + _mm_loadu_si128(reinterpret_cast<__m128i*>(v_data_ptr + i + src_stride)))), + zp1); + auto v2 = _mm512_sub_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(_mm_loadu_si128( + reinterpret_cast<__m128i*>(v_data_ptr + i + 2 * src_stride)))), + zp2); + auto v3 = _mm512_sub_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(_mm_loadu_si128( + reinterpret_cast<__m128i*>(v_data_ptr + i + 3 * src_stride)))), + zp3); + v_out = _mm512_fmadd_ps(attn_w_vec0, v0, v_out); + v_out = _mm512_fmadd_ps(attn_w_vec1, v1, v_out); + v_out = _mm512_fmadd_ps(attn_w_vec2, v2, v_out); + v_out = _mm512_fmadd_ps(attn_w_vec3, v3, v_out); + _mm512_storeu_ps(out + dst_offset + i, v_out); + } + for (; i < group_size; i++) { + out[i] += weight[0] * (v_data_ptr[i] - v_f0[1]) * v_f0[0]; + out[i] += weight[1] * (v_data_ptr[i + src_stride] - v_f1[1]) * v_f1[0]; + out[i] += weight[2] * (v_data_ptr[i + 2 * src_stride] - v_f2[1]) * v_f2[0]; + out[i] += weight[3] * (v_data_ptr[i + 3 * src_stride] - v_f3[1]) * v_f3[0]; + } + dst_offset += group_size; + src_offset += group_size + params_offset; } - v += 4 * (S + 8) - 8; weight += 4; + v += 4 * src_stride; } for (; j < block_size; j++) { - auto v_f0 = reinterpret_cast(v); - auto attn_w_vec0 = _mm512_set1_ps(weight[0] * v_f0[0]); - auto zp0 = _mm512_set1_ps(v_f0[1]); - size_t i = 0; - v += 8; - for (; i + vec_len_f32_avx512 <= S; i += vec_len_f32_avx512) { - auto v_out = mm512_uni_loadu_ps(out + i); - auto v0 = _mm512_sub_ps( - _mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(_mm_loadu_si128(reinterpret_cast<__m128i*>(v + i)))), - zp0); - v_out = _mm512_fmadd_ps(attn_w_vec0, v0, v_out); - - _mm512_storeu_ps(out + i, v_out); - } - for (; i < S; i++) { - out[i] += weight[0] * (v[i] - v_f0[1]) * v_f0[0]; + dst_offset = 0; + src_offset = 0; + while (dst_offset < S) { + uint8_t* v_ptr = v + src_offset; + uint8_t* v_data_ptr = v_ptr + params_offset; + auto v_f0 = reinterpret_cast(v_ptr); + auto attn_w_vec0 = _mm512_set1_ps(weight[0] * v_f0[0]); + auto zp0 = _mm512_set1_ps(v_f0[1]); + size_t i = 0; + for (; i + vec_len_f32_avx512 <= group_size; i += vec_len_f32_avx512) { + auto v_out = mm512_uni_loadu_ps((out + dst_offset + i)); + auto v0 = _mm512_sub_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32( + _mm_loadu_si128(reinterpret_cast<__m128i*>(v_data_ptr + i)))), + zp0); + v_out = _mm512_fmadd_ps(attn_w_vec0, v0, v_out); + + _mm512_storeu_ps((out + dst_offset + i), v_out); + } + for (; i < group_size; i++) { + out[dst_offset + i] += weight[0] * (v_data_ptr[i] - v_f0[1]) * v_f0[0]; + } + dst_offset += group_size; + src_offset += group_size + params_offset; } - v += S; + v += src_stride; weight++; } return; # elif defined(HAVE_AVX2) size_t j = 0; for (; j < block_size; j++) { - auto v_f0 = reinterpret_cast(v); - auto attn_w_vec0 = _mm256_set1_ps(weight[0] * v_f0[0]); - auto zp0 = _mm256_set1_ps(v_f0[1]); - size_t i = 0; - v += 8; - for (; i + vec_len_f32_avx2 <= S; i += vec_len_f32_avx2) { - auto v_out = mm256_uni_loadu_ps(out + i); - auto v0 = _mm256_sub_ps( - _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64(reinterpret_cast<__m128i*>(v + i)))), - zp0); - v_out = _mm256_fmadd_ps(attn_w_vec0, v0, v_out); - - mm256_uni_storeu_ps(out + i, v_out); - } - for (; i < S; i++) { - out[i] += weight[0] * (v[i] - v_f0[1]) * v_f0[0]; + dst_offset = 0; + src_offset = 0; + while (dst_offset < S) { + uint8_t* v_ptr = v + src_offset; + uint8_t* v_data_ptr = v_ptr + params_offset; + auto v_f0 = reinterpret_cast(v_ptr); + auto attn_w_vec0 = _mm256_set1_ps(weight[0] * v_f0[0]); + auto zp0 = _mm256_set1_ps(v_f0[1]); + size_t i = 0; + v += 8; + for (; i + vec_len_f32_avx2 <= group_size; i += vec_len_f32_avx2) { + auto v_out = mm256_uni_loadu_ps(out + dst_offset + i); + auto v0 = _mm256_sub_ps(_mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( + _mm_loadl_epi64(reinterpret_cast<__m128i*>(v_data_ptr + i)))), + zp0); + v_out = _mm256_fmadd_ps(attn_w_vec0, v0, v_out); + + mm256_uni_storeu_ps(out + dst_offset + i, v_out); + } + for (; i < group_size; i++) { + out[dst_offset + i] += weight[0] * (v_data_ptr[i] - v_f0[1]) * v_f0[0]; + } + dst_offset += group_size; + src_offset += group_size + params_offset; } - v += S; + v += src_stride; weight++; } return; # endif for (size_t j = 0; j < block_size; j++) { - auto v0 = reinterpret_cast(v); - v += 8; - for (size_t i = 0; i < S; i++) { - out[i] += weight[j] * (v[i] - v0[1]) * v0[0]; + dst_offset = 0; + src_offset = 0; + while (dst_offset < S) { + auto v0 = reinterpret_cast(v + src_offset); + for (size_t i = 0; i < group_size; i++) { + out[dst_offset + i] += weight[j] * (v[i + src_offset + params_offset] - v0[1]) * v0[0]; + } + dst_offset += group_size; + src_offset += group_size + params_offset; } - v += S; + v += src_stride; + } +} + +template ::type = true> +static void attn_acc_value_block(float* out, + float* weight, + void* v, + const size_t S, + const size_t block_size, + const size_t group_size) { + size_t src_offset = 0; + size_t dst_offset = 0; + const size_t params_offset = sizeof(float) * 2; + uint8_t* v_ptr = reinterpret_cast(v); + auto sub_byte_multiplier = 8 / 4; + const size_t src_stride = S / group_size * (group_size / sub_byte_multiplier + params_offset); + auto extract_half_byte = [](uint8_t val, bool high_half) -> uint8_t { + uint8_t shift = high_half ? 0 : 4; + + return (uint8_t)((val >> shift) & 0x000F); + }; + for (size_t j = 0; j < block_size; j++) { + dst_offset = 0; + src_offset = 0; + while (dst_offset < S) { + auto v0 = reinterpret_cast(v_ptr + src_offset); + size_t i = 0; +# if defined(HAVE_AVX512F) + auto attn_w_vec0 = _mm512_set1_ps(weight[j] * v0[0]); + auto v_zp = _mm512_set1_ps(v0[1]); + for (; i + vec_len_f32_avx512 * 2 <= group_size; i += vec_len_f32_avx512 * 2) { + auto data = _mm_loadu_si128(reinterpret_cast<__m128i*>(v_ptr + i / 2 + src_offset + params_offset)); + auto v_i32 = _mm512_cvtepu8_epi32(data); + + auto v_512_low_half = _mm512_srli_epi32(v_i32, 4); + auto v_f32_low_half = _mm512_cvtepi32_ps(v_512_low_half); + + auto mask = _mm512_set1_epi32(0x0F); + auto v_512_high_half = _mm512_and_si512(v_i32, mask); + auto v_f32_high_half = _mm512_cvtepi32_ps(v_512_high_half); + + // q - zp + v_f32_low_half = _mm512_sub_ps(v_f32_low_half, v_zp); + v_f32_high_half = _mm512_sub_ps(v_f32_high_half, v_zp); + + __m512i idx1 = _mm512_set_epi32(23, 7, 22, 6, 21, 5, 20, 4, 19, 3, 18, 2, 17, 1, 16, 0); + __m512i idx2 = _mm512_set_epi32(31, 15, 30, 14, 29, 13, 28, 12, 27, 11, 26, 10, 25, 9, 24, 8); + __m512 first_half = _mm512_permutex2var_ps(v_f32_low_half, idx1, v_f32_high_half); + __m512 second_half = _mm512_permutex2var_ps(v_f32_low_half, idx2, v_f32_high_half); + auto v_out0 = mm512_uni_loadu_ps(out + dst_offset + i); + auto v_out1 = mm512_uni_loadu_ps(out + dst_offset + i + vec_len_f32_avx512); + v_out0 = _mm512_fmadd_ps(attn_w_vec0, first_half, v_out0); + v_out1 = _mm512_fmadd_ps(attn_w_vec0, second_half, v_out1); + mm512_uni_storeu_ps(out + dst_offset + i, v_out0); + mm512_uni_storeu_ps(out + dst_offset + i + vec_len_f32_avx512, v_out1); + } +# elif defined(HAVE_AVX2) + auto v256_attn_w_vec0 = _mm256_set1_ps(weight[j] * v0[0]); + auto v256_zp = _mm256_set1_ps(v0[1]); + for (; i + vec_len_f32_avx2 * 2 <= group_size; i += vec_len_f32_avx2 * 2) { + auto data = _mm_loadl_epi64(reinterpret_cast<__m128i*>(v_ptr + i / 2 + src_offset + params_offset)); + + auto v_i32 = _mm256_cvtepu8_epi32(data); + auto v_256_low_half = _mm256_srli_epi32(v_i32, 4); + auto v_f32_low_half = _mm256_cvtepi32_ps(v_256_low_half); + + auto mask = _mm256_set1_epi32(0x0F); + auto v_256_high_half = _mm256_and_si256(v_i32, mask); + auto v_f32_high_half = _mm256_cvtepi32_ps(v_256_high_half); + // q - zp + v_f32_low_half = _mm256_sub_ps(v_f32_low_half, v256_zp); + v_f32_high_half = _mm256_sub_ps(v_f32_high_half, v256_zp); + + auto v_out0 = mm256_uni_loadu_ps(out + dst_offset + i); + auto v_out1 = mm256_uni_loadu_ps(out + dst_offset + i + vec_len_f32_avx2); + + __m256 first_half = _mm256_permute2f128_ps(v_f32_low_half, v_f32_high_half, 0x20); + auto idx1 = _mm256_set_epi32(7, 3, 6, 2, 5, 1, 4, 0); + first_half = _mm256_permutevar8x32_ps(first_half, idx1); + __m256 second_half = _mm256_permute2f128_ps(v_f32_low_half, v_f32_high_half, 0x31); + second_half = _mm256_permutevar8x32_ps(second_half, idx1); + + v_out0 = _mm256_fmadd_ps(v256_attn_w_vec0, first_half, v_out0); + v_out1 = _mm256_fmadd_ps(v256_attn_w_vec0, second_half, v_out1); + mm256_uni_storeu_ps(out + dst_offset + i, v_out0); + mm256_uni_storeu_ps(out + dst_offset + i + vec_len_f32_avx2, v_out1); + } +# endif + for (; i < group_size; i += 2) { + uint8_t data = v_ptr[i / 2 + src_offset + params_offset]; + float tmp0 = extract_half_byte(data, static_cast(i % 2)); + float tmp1 = extract_half_byte(data, static_cast((i + 1) % 2)); + out[dst_offset + i] += weight[j] * (tmp0 - v0[1]) * v0[0]; + out[dst_offset + i + 1] += weight[j] * (tmp1 - v0[1]) * v0[0]; + } + dst_offset += group_size; + src_offset += group_size / sub_byte_multiplier + params_offset; + } + v_ptr += src_stride; } } template -static void dot_product_block(TA* a, TB* b, float* c, size_t n, size_t block_size) { +static void dot_product_block(TA* a, + TB* b, + float* c, + const size_t n, + const size_t block_size, + const size_t group_size) { # if defined(HAVE_AVX512F) size_t j = 0; for (; j + 4 <= block_size; j += 4) { @@ -422,175 +581,235 @@ static void dot_product_block(TA* a, TB* b, float* c, size_t n, size_t block_siz } template -static void dot_product_block(TA* a, uint8_t* b, float* c, size_t n, size_t block_size) { +static void dot_product_block(TA* a, + uint8_t* b, + float* c, + const size_t n, + const size_t block_size, + const size_t group_size) { // The layout for per token per head: // |scale(f32)|zeropoint(f32)|quantized feature(u8,idx_1)|quantized feature(u8,idx_2)|...|quantized // feature(u8,idx_S)| The quantized feature will start from 8bytes=sizeof(float)+sizeof(float) + size_t src_offset = 0; + size_t dst_offset = 0; + const size_t params_offset = sizeof(float) * 2; + const size_t src_stride = n / group_size * (group_size + params_offset); # if defined(HAVE_AVX512F) size_t j = 0; for (; j + 4 <= block_size; j += 4) { - auto vsum0 = _mm512_setzero_ps(); - auto vsum1 = _mm512_setzero_ps(); - auto vsum2 = _mm512_setzero_ps(); - auto vsum3 = _mm512_setzero_ps(); - auto b0 = reinterpret_cast(b); - auto b1 = reinterpret_cast(b + n + 8); - auto b2 = reinterpret_cast(b + (n + 8) * 2); - auto b3 = reinterpret_cast(b + (n + 8) * 3); - auto v_zp0 = _mm512_set1_ps(b0[1]); - auto v_zp1 = _mm512_set1_ps(b1[1]); - auto v_zp2 = _mm512_set1_ps(b2[1]); - auto v_zp3 = _mm512_set1_ps(b3[1]); - size_t i = 0; - b += 8; - for (; i + vec_len_f32_avx512 <= n; i += vec_len_f32_avx512) { - auto va = mm512_uni_loadu_ps(a + i); - auto vb0 = _mm512_sub_ps( - _mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(_mm_loadu_si128(reinterpret_cast<__m128i*>(b + i)))), - v_zp0); - auto vb1 = _mm512_sub_ps( - _mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(_mm_loadu_si128(reinterpret_cast<__m128i*>(b + i + n + 8)))), - v_zp1); - auto vb2 = _mm512_sub_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32( - _mm_loadu_si128(reinterpret_cast<__m128i*>(b + i + 2 * (n + 8))))), - v_zp2); - auto vb3 = _mm512_sub_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32( - _mm_loadu_si128(reinterpret_cast<__m128i*>(b + i + 3 * (n + 8))))), - v_zp3); - - vsum0 = _mm512_fmadd_ps(va, vb0, vsum0); - vsum1 = _mm512_fmadd_ps(va, vb1, vsum1); - vsum2 = _mm512_fmadd_ps(va, vb2, vsum2); - vsum3 = _mm512_fmadd_ps(va, vb3, vsum3); + src_offset = 0; + dst_offset = 0; + float sum0 = 0.0f; + float sum1 = 0.0f; + float sum2 = 0.0f; + float sum3 = 0.0f; + while (dst_offset < n) { + auto vsum0 = _mm512_setzero_ps(); + auto vsum1 = _mm512_setzero_ps(); + auto vsum2 = _mm512_setzero_ps(); + auto vsum3 = _mm512_setzero_ps(); + auto b0 = reinterpret_cast(b + src_offset); + auto b1 = reinterpret_cast(b + src_offset + src_stride); + auto b2 = reinterpret_cast(b + src_offset + src_stride * 2); + auto b3 = reinterpret_cast(b + src_offset + src_stride * 3); + auto v_zp0 = _mm512_set1_ps(b0[1]); + auto v_zp1 = _mm512_set1_ps(b1[1]); + auto v_zp2 = _mm512_set1_ps(b2[1]); + auto v_zp3 = _mm512_set1_ps(b3[1]); + size_t i = 0; + uint8_t* b_data_ptr = b + src_offset + params_offset; + for (; i + vec_len_f32_avx512 <= group_size; i += vec_len_f32_avx512) { + auto va = mm512_uni_loadu_ps(a + dst_offset + i); + auto vb0 = _mm512_sub_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32( + _mm_loadu_si128(reinterpret_cast<__m128i*>(b_data_ptr + i)))), + v_zp0); + auto vb1 = _mm512_sub_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32( + _mm_loadu_si128(reinterpret_cast<__m128i*>(b_data_ptr + i + src_stride)))), + v_zp1); + auto vb2 = _mm512_sub_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(_mm_loadu_si128( + reinterpret_cast<__m128i*>(b_data_ptr + i + 2 * src_stride)))), + v_zp2); + auto vb3 = _mm512_sub_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(_mm_loadu_si128( + reinterpret_cast<__m128i*>(b_data_ptr + i + 3 * src_stride)))), + v_zp3); + + vsum0 = _mm512_fmadd_ps(va, vb0, vsum0); + vsum1 = _mm512_fmadd_ps(va, vb1, vsum1); + vsum2 = _mm512_fmadd_ps(va, vb2, vsum2); + vsum3 = _mm512_fmadd_ps(va, vb3, vsum3); + } + float group_sum0 = _mm512_reduce_add_ps(vsum0); + float group_sum1 = _mm512_reduce_add_ps(vsum1); + float group_sum2 = _mm512_reduce_add_ps(vsum2); + float group_sum3 = _mm512_reduce_add_ps(vsum3); + for (; i < group_size; i++) { + group_sum0 += a[i + dst_offset] * (b_data_ptr[i] - b0[1]); + group_sum1 += a[i + dst_offset] * (b_data_ptr[i + src_stride] - b1[1]); + group_sum2 += a[i + dst_offset] * (b_data_ptr[i + 2 * src_stride] - b2[1]); + group_sum3 += a[i + dst_offset] * (b_data_ptr[i + 3 * src_stride] - b3[1]); + } + sum0 += group_sum0 * b0[0]; + sum1 += group_sum1 * b1[0]; + sum2 += group_sum2 * b2[0]; + sum3 += group_sum3 * b3[0]; + dst_offset += group_size; + src_offset += group_size + params_offset; } - float sum0 = _mm512_reduce_add_ps(vsum0); - float sum1 = _mm512_reduce_add_ps(vsum1); - float sum2 = _mm512_reduce_add_ps(vsum2); - float sum3 = _mm512_reduce_add_ps(vsum3); - for (; i < n; i++) { - sum0 += a[i] * (b[i] - b0[1]); - sum1 += a[i] * (b[i + n + 8] - b1[1]); - sum2 += a[i] * (b[i + 2 * (n + 8)] - b2[1]); - sum3 += a[i] * (b[i + 3 * (n + 8)] - b3[1]); - } - c[0] = sum0 * b0[0]; - c[1] = sum1 * b1[0]; - c[2] = sum2 * b2[0]; - c[3] = sum3 * b3[0]; + c[0] = sum0; + c[1] = sum1; + c[2] = sum2; + c[3] = sum3; c += 4; - b += 4 * (n + 8) - 8; + b += 4 * src_stride; } for (; j < block_size; j++) { - auto vsum = _mm512_setzero_ps(); - auto b0 = reinterpret_cast(b); - auto v_zp = _mm512_set1_ps(b0[1]); - size_t i = 0; - b += 8; - for (; i + vec_len_f32_avx512 <= n; i += vec_len_f32_avx512) { - auto va = mm512_uni_loadu_ps(a + i); - auto vb = _mm512_sub_ps( - _mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(_mm_loadu_si128(reinterpret_cast<__m128i*>(b + i)))), - v_zp); - vsum = _mm512_fmadd_ps(va, vb, vsum); - } - float sum = _mm512_reduce_add_ps(vsum); - for (; i < n; i++) { - sum += a[i] * (b[i] - b0[1]); + src_offset = 0; + dst_offset = 0; + float sum = 0; + while (dst_offset < n) { + auto vsum = _mm512_setzero_ps(); + auto b0 = reinterpret_cast(b + src_offset); + auto v_zp = _mm512_set1_ps(b0[1]); + size_t i = 0; + uint8_t* b_data_ptr = b + src_offset + params_offset; + for (; i + vec_len_f32_avx512 <= group_size; i += vec_len_f32_avx512) { + auto va = mm512_uni_loadu_ps(a + dst_offset + i); + auto vb = _mm512_sub_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32( + _mm_loadu_si128(reinterpret_cast<__m128i*>(b_data_ptr + i)))), + v_zp); + vsum = _mm512_fmadd_ps(va, vb, vsum); + } + float group_sum = _mm512_reduce_add_ps(vsum); + for (; i < group_size; i++) { + group_sum += a[i + dst_offset] * (b_data_ptr[i] - b0[1]); + } + sum += group_sum * b0[0]; + dst_offset += group_size; + src_offset += group_size + params_offset; } - b += n; - *c++ = sum * b0[0]; + b += src_stride; + *c++ = sum; } return; # elif defined(HAVE_AVX2) size_t j = 0; for (; j + 4 <= block_size; j += 4) { - auto vsum0 = _mm256_setzero_ps(); - auto vsum1 = _mm256_setzero_ps(); - auto vsum2 = _mm256_setzero_ps(); - auto vsum3 = _mm256_setzero_ps(); - auto b0 = reinterpret_cast(b); - auto b1 = reinterpret_cast(b + n + 8); - auto b2 = reinterpret_cast(b + (n + 8) * 2); - auto b3 = reinterpret_cast(b + (n + 8) * 3); - auto v_zp0 = _mm256_set1_ps(b0[1]); - auto v_zp1 = _mm256_set1_ps(b1[1]); - auto v_zp2 = _mm256_set1_ps(b2[1]); - auto v_zp3 = _mm256_set1_ps(b3[1]); - size_t i = 0; - b += 8; - for (; i + vec_len_f32_avx2 <= n; i += vec_len_f32_avx2) { - auto va = mm256_uni_loadu_ps(a + i); - auto vb0 = _mm256_sub_ps( - _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64(reinterpret_cast<__m128i*>(b + i)))), - v_zp0); - auto vb1 = _mm256_sub_ps( - _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64(reinterpret_cast<__m128i*>(b + i + n + 8)))), - v_zp1); - auto vb2 = _mm256_sub_ps(_mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( - _mm_loadl_epi64(reinterpret_cast<__m128i*>(b + i + 2 * (n + 8))))), - v_zp2); - auto vb3 = _mm256_sub_ps(_mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( - _mm_loadl_epi64(reinterpret_cast<__m128i*>(b + i + 3 * (n + 8))))), - v_zp3); - - vsum0 = _mm256_fmadd_ps(va, vb0, vsum0); - vsum1 = _mm256_fmadd_ps(va, vb1, vsum1); - vsum2 = _mm256_fmadd_ps(va, vb2, vsum2); - vsum3 = _mm256_fmadd_ps(va, vb3, vsum3); + src_offset = 0; + dst_offset = 0; + float sum0 = 0.0f; + float sum1 = 0.0f; + float sum2 = 0.0f; + float sum3 = 0.0f; + while (dst_offset < n) { + auto vsum0 = _mm256_setzero_ps(); + auto vsum1 = _mm256_setzero_ps(); + auto vsum2 = _mm256_setzero_ps(); + auto vsum3 = _mm256_setzero_ps(); + auto b0 = reinterpret_cast(b + src_offset); + auto b1 = reinterpret_cast(b + src_offset + src_stride); + auto b2 = reinterpret_cast(b + src_offset + src_stride * 2); + auto b3 = reinterpret_cast(b + src_offset + src_stride * 3); + auto v_zp0 = _mm256_set1_ps(b0[1]); + auto v_zp1 = _mm256_set1_ps(b1[1]); + auto v_zp2 = _mm256_set1_ps(b2[1]); + auto v_zp3 = _mm256_set1_ps(b3[1]); + size_t i = 0; + uint8_t* b_data_ptr = b + src_offset + params_offset; + for (; i + vec_len_f32_avx2 <= group_size; i += vec_len_f32_avx2) { + auto va = mm256_uni_loadu_ps(a + dst_offset + i); + auto vb0 = _mm256_sub_ps(_mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( + _mm_loadl_epi64(reinterpret_cast<__m128i*>(b_data_ptr + i)))), + v_zp0); + auto vb1 = _mm256_sub_ps(_mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( + _mm_loadl_epi64(reinterpret_cast<__m128i*>(b_data_ptr + i + src_stride)))), + v_zp1); + auto vb2 = _mm256_sub_ps(_mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64( + reinterpret_cast<__m128i*>(b_data_ptr + i + 2 * src_stride)))), + v_zp2); + auto vb3 = _mm256_sub_ps(_mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64( + reinterpret_cast<__m128i*>(b_data_ptr + i + 3 * src_stride)))), + v_zp3); + + vsum0 = _mm256_fmadd_ps(va, vb0, vsum0); + vsum1 = _mm256_fmadd_ps(va, vb1, vsum1); + vsum2 = _mm256_fmadd_ps(va, vb2, vsum2); + vsum3 = _mm256_fmadd_ps(va, vb3, vsum3); + } + hsum(vsum0); + hsum(vsum1); + hsum(vsum2); + hsum(vsum3); + float group_sum0 = _mm256_cvtss_f32(vsum0); + float group_sum1 = _mm256_cvtss_f32(vsum1); + float group_sum2 = _mm256_cvtss_f32(vsum2); + float group_sum3 = _mm256_cvtss_f32(vsum3); + for (; i < group_size; i++) { + group_sum0 += a[dst_offset + i] * (b[i] - b0[1]); + group_sum1 += a[dst_offset + i] * (b[i + src_stride] - b1[1]); + group_sum2 += a[dst_offset + i] * (b[i + 2 * src_stride] - b2[1]); + group_sum3 += a[dst_offset + i] * (b[i + 3 * src_stride] - b3[1]); + } + sum0 += group_sum0 * b0[0]; + sum1 += group_sum1 * b1[0]; + sum2 += group_sum2 * b2[0]; + sum3 += group_sum3 * b3[0]; + dst_offset += group_size; + src_offset += group_size + params_offset; } - hsum(vsum0); - hsum(vsum1); - hsum(vsum2); - hsum(vsum3); - float sum0 = _mm256_cvtss_f32(vsum0); - float sum1 = _mm256_cvtss_f32(vsum1); - float sum2 = _mm256_cvtss_f32(vsum2); - float sum3 = _mm256_cvtss_f32(vsum3); - for (; i < n; i++) { - sum0 += a[i] * (b[i] - b0[1]); - sum1 += a[i] * (b[i + n + 8] - b1[1]); - sum2 += a[i] * (b[i + 2 * (n + 8)] - b2[1]); - sum3 += a[i] * (b[i + 3 * (n + 8)] - b3[1]); - } - c[0] = sum0 * b0[0]; - c[1] = sum1 * b1[0]; - c[2] = sum2 * b2[0]; - c[3] = sum3 * b3[0]; + c[0] = sum0; + c[1] = sum1; + c[2] = sum2; + c[3] = sum3; c += 4; - b += 4 * (n + 8) - 8; + b += 4 * src_stride; } for (; j < block_size; j++) { - auto vsum = _mm256_setzero_ps(); - auto b0 = reinterpret_cast(b); - auto v_zp = _mm256_set1_ps(b0[1]); - size_t i = 0; - b += 8; - for (; i + vec_len_f32_avx2 <= n; i += vec_len_f32_avx2) { - auto va = mm256_uni_loadu_ps(a + i); - auto vb = _mm256_sub_ps( - _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64(reinterpret_cast<__m128i*>(b + i)))), - v_zp); - vsum = _mm256_fmadd_ps(va, vb, vsum); - } - hsum(vsum); - float sum = _mm256_cvtss_f32(vsum); - for (; i < n; i++) { - sum += a[i] * (b[i] - b0[1]); + src_offset = 0; + dst_offset = 0; + float sum = 0; + while (dst_offset < n) { + auto vsum = _mm256_setzero_ps(); + auto b0 = reinterpret_cast(b + src_offset); + auto v_zp = _mm256_set1_ps(b0[1]); + size_t i = 0; + uint8_t* b_data_ptr = b + src_offset + params_offset; + for (; i + vec_len_f32_avx2 <= group_size; i += vec_len_f32_avx2) { + auto va = mm256_uni_loadu_ps(a + dst_offset + i); + auto vb = _mm256_sub_ps(_mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( + _mm_loadl_epi64(reinterpret_cast<__m128i*>(b_data_ptr + i)))), + v_zp); + vsum = _mm256_fmadd_ps(va, vb, vsum); + } + hsum(vsum); + float group_sum = _mm256_cvtss_f32(vsum); + for (; i < group_size; i++) { + group_sum += a[i + dst_offset] * (b_data_ptr[i] - b0[1]); + } + sum += group_sum * b0[0]; + dst_offset += group_size; + src_offset += group_size + params_offset; } - b += n; - *c++ = sum * b0[0]; + b += src_stride; + *c++ = sum; } return; # endif for (size_t j = 0; j < block_size; j++) { float sum = 0; - auto b0 = reinterpret_cast(b); - b += 8; - for (size_t i = 0; i < n; i++) { - sum += a[i] * (b[i] - b0[1]); + dst_offset = 0; + src_offset = 0; + while (dst_offset < n) { + auto b0 = reinterpret_cast(b + src_offset); + float group_sum = 0.0f; + for (size_t i = 0; i < group_size; i++) { + group_sum += a[dst_offset + i] * (b[src_offset + params_offset + i] - b0[1]); + } + sum += group_sum * b0[0]; + dst_offset += group_size; + src_offset += group_size + params_offset; } - b += n; - *c++ = sum * b0[0]; + b += src_stride; + *c++ = sum; } } @@ -634,73 +853,138 @@ static void attn_reduce(T* dst, float* temp, size_t M, size_t S, size_t temp_str } // N must be multiple of 16 -template -void transpose_16NxK(TDST* dst, TSRC* src, TDST* tmp, size_t N, size_t K, size_t dst_stride, size_t src_stride) { +template ::type = true> +void transpose_16NxK(TDST* dst, + void* src, + TDST* tmp, + const size_t N, + const size_t K, + const size_t dst_stride, + const size_t src_stride, + const size_t group_size) { size_t k = 0; + auto* src_ptr = reinterpret_cast::value_type*>(src); for (; k + 16 <= K; k += 16) { for (size_t n = 0; n < N; n += 16) { - transpose_16x16_kernel(dst + n, src + n * src_stride, dst_stride, src_stride); + transpose_16x16_kernel(dst + n, src_ptr + n * src_stride, dst_stride, src_stride); } dst += 16 * dst_stride; - src += 16; + src_ptr += 16; } if (k < K) { for (size_t n = 0; n < N; n += 16) { - transpose_16xK_kernel(dst + n, src + n * src_stride, K - k, dst_stride, src_stride); + transpose_16xK_kernel(dst + n, src_ptr + n * src_stride, K - k, dst_stride, src_stride); } } } - # if defined(HAVE_AVX512F) template ::value || std::is_same::value), bool>::type> -static void transpose_16NxK(T* dst, T* src, T* tmp, size_t N, size_t K, size_t dst_stride, size_t src_stride) { + ov::element::Type_t SRC_PREC, + typename std::enable_if<(SRC_PREC == ov::element::bf16 || SRC_PREC == ov::element::f16) && + (SRC_PREC == precision_of::value), + bool>::type = true> +static void transpose_16NxK(T* dst, + T* src, + T* tmp, + const size_t N, + const size_t K, + const size_t dst_stride, + const size_t src_stride, + const size_t group_size) { // will treat as uint32_t transpose auto s = reinterpret_cast(src); auto d = reinterpret_cast(dst); - transpose_16NxK(d, s, reinterpret_cast(0), N, K >> 1, dst_stride, src_stride >> 1); + transpose_16NxK(d, + s, + reinterpret_cast(0), + N, + K >> 1, + dst_stride, + src_stride >> 1, + group_size); } # endif -template -void transpose_16NxK(TDST* dst, uint8_t* src, TDST* tmp, size_t N, size_t K, size_t dst_stride, size_t src_stride) { +template ::type = true> +void transpose_16NxK(TDST* dst, + void* src, + TDST* tmp, + const size_t N, + const size_t K, + const size_t dst_stride, + const size_t src_stride, + const size_t group_size) { // The layout for per token per head: // |scale(f32)|zeropoint(f32)|quantized feature(u8,idx_1)|quantized feature(u8,idx_2)|...|quantized // feature(u8,idx_S)| The quantized feature will start from 8bytes=sizeof(float)+sizeof(float) - auto s = src; + auto s = reinterpret_cast::value_type*>(src); auto t = tmp; + // if group_size not set, the whole row is used as a group for (size_t n = 0; n < N; n++) { - auto f = reinterpret_cast(s); - attn_dequant_u8_kernel(s + 2 * sizeof(float), t, K, f[0], f[1]); - s += src_stride + 2 * sizeof(float); + size_t src_offset = 0; + size_t dst_offset = 0; + while (dst_offset < K) { + auto f = reinterpret_cast(s + src_offset); + attn_dequant_kernel(s + src_offset + sizeof(float) * 2, + t + dst_offset, + group_size, + f[0], + f[1]); + src_offset += group_size + sizeof(float) * 2; + dst_offset += group_size; + } + s += src_offset; t += src_stride; } - transpose_16NxK(dst, tmp, reinterpret_cast(0), N, K, dst_stride, src_stride); + transpose_16NxK::value>(dst, tmp, reinterpret_cast(0), N, K, dst_stride, src_stride, 0); } // dequant f16/u8 to float -template -static inline void dequant(T* dst, T* src, size_t N, size_t K) { +template ::type = true> +static inline void dequant(T* dst, void* src, const size_t N, const size_t K, const size_t group_size) { // never called OPENVINO_THROW("dequant: should not be called."); } - -static inline void dequant(float* dst, ov::float16* src, size_t N, size_t K) { +template ::type = true> +static inline void dequant(float* dst, ov::float16* src, const size_t N, const size_t K, const size_t group_size) { cvt_copy(dst, src, K * N); } -template -void dequant(TDST* dst, uint8_t* src, size_t N, size_t K) { +template ::type = true> +void dequant(TDST* dst, uint8_t* src, const size_t N, const size_t K, const size_t group_size) { // The layout for per token per head: // |scale(f32)|zeropoint(f32)|quantized feature(u8,idx_1)|quantized feature(u8,idx_2)|...|quantized // feature(u8,idx_S)| The quantized feature will start from 8bytes=sizeof(float)+sizeof(float) auto s = src; + const size_t params_offset = sizeof(float) * 2; + const size_t sub_byte_mulitplier = get_sub_byte_multiplier(SRC_PREC); + for (size_t n = 0; n < N; n++) { - auto f = reinterpret_cast(s); - attn_dequant_u8_kernel(s + 2 * sizeof(float), dst, K, f[0], f[1]); - s += K + 2 * sizeof(float); + size_t src_offset = 0; + size_t dst_offset = 0; + while (dst_offset < K) { + auto f = reinterpret_cast(s + src_offset); + attn_dequant_kernel(s + src_offset + params_offset, + dst + dst_offset, + group_size, + f[0], + f[1]); + src_offset += group_size / sub_byte_mulitplier + params_offset; + dst_offset += group_size; + } + s += src_offset; dst += K; } } @@ -772,54 +1056,101 @@ static void pack_32xK_kernel(T* dst, T* src, size_t dst_stride, size_t src_strid } } -template ::value || std::is_same::value), bool>::type> -static void pack_32NxK(T* dst, T* src, T* tmp, size_t N, size_t K, size_t dst_stride, size_t src_stride) { +template ::value != ov::element::f32 && + (SRC_PREC == ov::element::bf16 || SRC_PREC == ov::element::f16), + bool>::type = true> +static void pack_32NxK(TDST* dst, + void* src, + TDST* tmp, + const size_t N, + const size_t K, + const size_t dst_stride, + const size_t src_stride, + const size_t group_size) { + auto src_ptr = reinterpret_cast::value_type*>(src); for (size_t n = 0; n < N; n += 32) { size_t k = 0; for (; k + 32 <= K; k += 32) { - pack_32x32_kernel(dst + k * 2, src + k, dst_stride, src_stride); + pack_32x32_kernel(dst + k * 2, src_ptr + k, dst_stride, src_stride); } if (k + 16 <= K) { - pack_32x16_kernel(dst + k * 2, src + k, dst_stride, src_stride); + pack_32x16_kernel(dst + k * 2, src_ptr + k, dst_stride, src_stride); k += 16; } if (k < K) { - pack_32xK_kernel(dst + k * 2, src + k, dst_stride, src_stride, K - k); + pack_32xK_kernel(dst + k * 2, src_ptr + k, dst_stride, src_stride, K - k); } dst += 32 * dst_stride; - src += 32 * src_stride; + src_ptr += 32 * src_stride; } } -template ::value || std::is_same::value), bool>::type> -static void pack_32NxK(T* dst, uint8_t* src, T* tmp, size_t N, size_t K, size_t dst_stride, size_t src_stride) { +template ::value != ov::element::f32 && + (SRC_PREC == ov::element::u4 || SRC_PREC == ov::element::u8), + bool>::type = true> +static void pack_32NxK(TDST* dst, + void* src, + TDST* tmp, + const size_t N, + const size_t K, + const size_t dst_stride, + const size_t src_stride, + const size_t group_size) { // The layout for per token per head: // |scale(f32)|zeropoint(f32)|quantized feature(u8,idx_1)|quantized feature(u8,idx_2)|...|quantized // feature(u8,idx_S)| The quantized feature will start from 8bytes=sizeof(float)+sizeof(float) - auto s = src; + auto s = reinterpret_cast(src); auto t = tmp; + // if group_size not set, the whole row is used as a group + const size_t sub_byte_mulitplier = get_sub_byte_multiplier(SRC_PREC); for (size_t n = 0; n < N; n++) { - auto f = reinterpret_cast(s); - attn_dequant_u8_kernel(s + 2 * sizeof(float), t, K, f[0], f[1]); - s += src_stride + 2 * sizeof(float); + size_t src_offset = 0; + size_t dst_offset = 0; + while (dst_offset < K) { + auto f = reinterpret_cast(s + src_offset); + attn_dequant_kernel(s + (src_offset + sizeof(float) * 2), + t + dst_offset, + group_size, + f[0], + f[1]); + src_offset += group_size / sub_byte_mulitplier + sizeof(float) * 2; + dst_offset += group_size; + } + s += src_offset; t += src_stride; } - pack_32NxK(dst, tmp, reinterpret_cast(0), N, K, dst_stride, src_stride); + pack_32NxK::value>(dst, + tmp, + reinterpret_cast(0), + N, + K, + dst_stride, + src_stride, + group_size); } # endif -template -static void pack_32NxK(float* dst, T* src, float* tmp, size_t N, size_t K, size_t dst_stride, size_t src_stride) { +template ::value == ov::element::f32, bool>::type = true> +static void pack_32NxK(TDST* dst, + void* src, + TDST* tmp, + const size_t N, + const size_t K, + const size_t dst_stride, + const size_t src_stride, + const size_t group_size) { // never called OPENVINO_THROW("pack_32NxK: should not be called."); } -template +template struct MHAHelper { // initialize once size_t _H; @@ -831,6 +1162,8 @@ struct MHAHelper { size_t _nthr; size_t _sliding_window; float _d_scale; + size_t _key_group_size = 0; + size_t _value_group_size = 0; PlainTensor _weight; // [nthr, H, 32, rnd_up(kv_len, block_size)], shared by first and second loop along bh PlainTensor _output; // [nthr, 32, H, S], shared by first and second loop along bh @@ -860,6 +1193,12 @@ struct MHAHelper { _weight.resize({size_t{1}, size_t{1}, size_t{1}, size_t{1}}); } + explicit MHAHelper(size_t key_group_size, size_t value_group_size) + : _key_group_size(key_group_size), + _value_group_size(value_group_size) { + _weight.resize({size_t{1}, size_t{1}, size_t{1}, size_t{1}}); + } + void init(size_t H, size_t S, size_t SV, @@ -947,11 +1286,11 @@ struct MHAHelper { if ((S % 32 == 0) && (block_size % 16 == 0) && (S <= 32 * 6)) { if (dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::amx_bf16) && precision_of::value == ov::element::bf16 && - precision_of::value == ov::element::bf16) { + precision_of::value == ov::element::bf16 && VALUE_PREC == ov::element::bf16) { _fastpath_valid_prec = ov::element::bf16; } else if (dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::amx_fp16) && precision_of::value == ov::element::f16 && - precision_of::value == ov::element::f16) { + precision_of::value == ov::element::f16 && VALUE_PREC == ov::element::f16) { _fastpath_valid_prec = ov::element::f16; } } @@ -1020,7 +1359,7 @@ struct MHAHelper { auto q_end = std::min(q_start + _block_size, q_len); auto q_cnt = q_end - q_start; constexpr bool q_is_xf16 = one_of(precision_of::value, ov::element::bf16, ov::element::f16); - constexpr bool q_cache_is_same = precision_of::value == precision_of::value; + constexpr bool q_cache_is_same = precision_of::value == VALUE_PREC; auto cur_kv_len_blocks = div_up(cur_kv_len, _block_size); for (size_t h = hq_beg; h < hq_end; h++) { auto* q_ptr = query.ptr(h, q_start, 0); @@ -1164,7 +1503,7 @@ struct MHAHelper { for (size_t pq = 0; pq < q_len; pq++) { for (size_t h = hq_beg; h < hq_end; h++) { (*_gemv)(query.ptr(h, pq), - present_key.ptr(block_number, hk), + present_key.ptr(block_number, hk), _weight.ptr(ithr, h, pq) + pk); } } @@ -1176,10 +1515,11 @@ struct MHAHelper { for (size_t pq = 0; pq < q_len; pq++) { for (size_t h = hq_beg; h < hq_end; h++) { dot_product_block(query.ptr(h, pq), - present_key.ptr(block_number, hk), + present_key.ptr(block_number, hk), _weight.ptr(ithr, h, pq) + pk, _S, - std::min(_block_size, cur_kv_len - pk)); + std::min(_block_size, cur_kv_len - pk), + _key_group_size); } } } @@ -1217,14 +1557,20 @@ struct MHAHelper { memset(_output.ptr(ithr), 0, q_len * _H * _SV * sizeof(float)); for (size_t pv = 0, i = 0; pv < cur_kv_len; pv += _block_size, i++) { auto block_number = block_table[i]; - auto* v = present_value.ptr(block_number, hk); for (size_t pq = 0; pq < q_len; pq++) { for (size_t h = hq_beg; h < hq_end; h++) { - attn_acc_value_block(_output.ptr(ithr, pq, h), - _weight.ptr(ithr, h, pq) + pv, - v, - _SV, - std::min(_block_size, cur_kv_len - pv)); + auto sub_byte_multiplier = get_sub_byte_multiplier(present_value.get_precision()); + size_t v_stride = (block_number * present_value.m_strides[0] + hk * present_value.m_strides[1]) * + present_value.get_precision().size() / sub_byte_multiplier; + auto* v_ptr = reinterpret_cast::value_type*>( + present_value.m_ptr.get() + v_stride); + attn_acc_value_block::value_type, VALUE_PREC>( + _output.ptr(ithr, pq, h), + _weight.ptr(ithr, h, pq) + pv, + v_ptr, + _SV, + std::min(_block_size, cur_kv_len - pv), + _value_group_size); } } } @@ -1301,7 +1647,7 @@ struct MHAHelper { for (size_t pq = 0; pq < q_len; pq++) { for (size_t h = hq_beg; h < hq_end; h++) { (*_gemv)(query.ptr(b, h, pq), - present_key.ptr(block_number, hk), + present_key.ptr(block_number, hk), _weight_bhl.ptr(b, h, pq) + pk); } } @@ -1310,10 +1656,11 @@ struct MHAHelper { for (size_t pq = 0; pq < q_len; pq++) { for (size_t h = hq_beg; h < hq_end; h++) { dot_product_block(query.ptr(b, h, pq), - present_key.ptr(block_number, hk), + present_key.ptr(block_number, hk), _weight_bhl.ptr(b, h, pq) + pk, _S, - std::min(_block_size, context_len - pk)); + std::min(_block_size, context_len - pk), + _key_group_size); } } } @@ -1380,14 +1727,21 @@ struct MHAHelper { // kv_len must be valid if (pv < context_len) { auto block_number = block_indices.ptr()[block_indices_begins.ptr()[b] + pv_in_blocks]; - auto* v = present_value.ptr(block_number, hk); for (size_t pq = 0; pq < q_len; pq++) { for (size_t h = hq_beg; h < hq_end; h++) { - attn_acc_value_block(_output_bhl.ptr(ithr, b, pq, h), - _weight_bhl.ptr(b, h, pq) + pv, - v, - _SV, - std::min(_block_size, context_len - pv)); + auto sub_byte_multiplier = get_sub_byte_multiplier(present_value.get_precision()); + size_t v_stride = + (block_number * present_value.m_strides[0] + hk * present_value.m_strides[1]) * + present_value.get_precision().size() / sub_byte_multiplier; + auto* v_ptr = reinterpret_cast::value_type*>( + present_value.m_ptr.get() + v_stride); + attn_acc_value_block::value_type, VALUE_PREC>( + _output_bhl.ptr(ithr, b, pq, h), + _weight_bhl.ptr(b, h, pq) + pv, + v_ptr, + _SV, + std::min(_block_size, context_len - pv), + _value_group_size); } } } @@ -1408,9 +1762,9 @@ struct MHAHelper { } }; -template +template struct MHA { - MHAHelper& _helper; + MHAHelper& _helper; struct AttnWorkItem { int32_t batch_in_reorder; // which batch in reorder buffer will be used int32_t batch_in_seq; // batch idx in sequence @@ -1510,7 +1864,7 @@ struct MHA { WorkItems _workitems; - MHA(MHAHelper& helper) : _helper(helper) {} + MHA(MHAHelper& helper) : _helper(helper) {} // one loop to handle first and second tokens void exec_loop_mixed(const PlainTensor& q, @@ -1527,7 +1881,7 @@ struct MHA { auto Hk = v_cache.m_dims[1]; constexpr bool q_is_xf16 = one_of(precision_of::value, ov::element::bf16, ov::element::f16); - constexpr bool q_cache_is_same = precision_of::value == precision_of::value; + constexpr bool q_cache_is_same = precision_of::value == VALUE_PREC; auto attn_work_count = _workitems.attn_work_size(); auto reorder_work_count = _workitems.reorder_work_size(); @@ -1547,30 +1901,45 @@ struct MHA { return; auto ithr = parallel_get_thread_num(); - auto* k_ptr = k_cache.ptr(block_number, hk); - auto* v_ptr = v_cache.ptr(block_number, hk); - transpose_16NxK(_helper._qk_scratch_b.template ptr(batch_in_reorder, kv_block, hk), - k_ptr, - _helper._output.template ptr(ithr), - _helper._block_size, - _helper._S, - _helper._block_size, - _helper._S); + auto* k_ptr = k_cache.ptr(block_number, hk); + + transpose_16NxK::value>( + _helper._qk_scratch_b.template ptr(batch_in_reorder, kv_block, hk), + k_ptr, + _helper._output.template ptr(ithr), + _helper._block_size, + _helper._S, + _helper._block_size, + _helper._S, + _helper._key_group_size); + if (q_is_xf16) { - pack_32NxK(_helper._wv_scratch_b.template ptr(batch_in_reorder, kv_block, hk), - v_ptr, - _helper._output.template ptr(ithr), - _helper._block_size, - _helper._SV, - rnd_up(_helper._SV, _helper._block_size), - _helper._SV); + auto sub_byte_multiplier = get_sub_byte_multiplier(v_cache.get_precision()); + size_t v_stride = (block_number * v_cache.m_strides[0] + hk * v_cache.m_strides[1]) * + v_cache.get_precision().size() / sub_byte_multiplier; + auto* v_ptr = v_cache.m_ptr.get() + v_stride; + pack_32NxK( + _helper._wv_scratch_b.template ptr(batch_in_reorder, kv_block, hk), + v_ptr, + _helper._output.template ptr(ithr), + _helper._block_size, + _helper._SV, + rnd_up(_helper._SV, _helper._block_size), + _helper._SV, + _helper._value_group_size); } else { // need to decompress if (!q_cache_is_same) { - dequant(_helper._wv_scratch_b.template ptr(batch_in_reorder, kv_block, hk), - v_ptr, - _helper._block_size, - _helper._SV); + auto sub_byte_multiplier = get_sub_byte_multiplier(v_cache.get_precision()); + size_t v_stride = (block_number * v_cache.m_strides[0] + hk * v_cache.m_strides[1]) * + v_cache.get_precision().size() / sub_byte_multiplier; + auto* v_ptr = v_cache.m_ptr.get() + v_stride; + dequant( + _helper._wv_scratch_b.template ptr(batch_in_reorder, kv_block, hk), + v_ptr, + _helper._block_size, + _helper._SV, + _helper._value_group_size); } } }); @@ -1721,14 +2090,18 @@ struct MHA { } }; -template +template struct AttentionExecutor : public PagedAttentionExecutor { - MHAHelper _helper; - MHA _kernel; + MHAHelper _helper; + MHA _kernel; PlainTensor _slot_mapping; AttentionExecutor() : _kernel(_helper) {} + explicit AttentionExecutor(size_t key_group_size, size_t value_group_size) + : _helper(MHAHelper(key_group_size, value_group_size)), + _kernel(_helper) {} + void init(const std::vector& inputs, const std::vector& outputs, PlainTensor& q, @@ -1769,8 +2142,22 @@ struct AttentionExecutor : public PagedAttentionExecutor { // The layout for per token per head for u8 kv cache: // |scale(f32)|zeropoint(f32)|quantized feature(u8,idx_1)|quantized feature(u8,idx_2)|...|quantized // feature(u8,idx_S)| The actual size needs to deduct scale and zeropoint. - auto S = k_cache.size(3) - (k_cache.m_dt == ov::element::Type_t::u8 ? sizeof(float) * 2 : 0); - auto SV = v_cache.size(3) - (k_cache.m_dt == ov::element::Type_t::u8 ? sizeof(float) * 2 : 0); + const size_t key_sub_byte_multiplier = get_sub_byte_multiplier(k_cache.get_precision()); + const size_t value_sub_byte_multiplier = get_sub_byte_multiplier(v_cache.get_precision()); + const size_t key_params_size = sizeof(float) * 2 * key_sub_byte_multiplier; + // u4 needs scale + zp. s4 needs scale. + const size_t param_size = + one_of(v_cache.get_precision(), ov::element::u4, ov::element::u8) ? sizeof(float) * 2 : sizeof(float); + const size_t value_params_size = param_size * value_sub_byte_multiplier; + size_t key_group_num = + _helper._key_group_size ? k_cache.size(3) / (_helper._key_group_size + key_params_size) : 1; + size_t value_group_num = + _helper._value_group_size ? v_cache.size(3) / (_helper._value_group_size + value_params_size) : 1; + auto S = k_cache.size(3) - (k_cache.get_precision().is_real() ? 0 : key_params_size * key_group_num); + auto SV = v_cache.size(3) - (v_cache.get_precision().is_real() ? 0 : value_params_size * value_group_num); + // revise group_size if it's zero. + _helper._key_group_size = _helper._key_group_size ? _helper._key_group_size : S; + _helper._value_group_size = _helper._value_group_size ? _helper._value_group_size : SV; auto block_size = k_cache.size(2); auto H = q.size(1) / S; auto h_each_group_len = 1; @@ -1785,9 +2172,10 @@ struct AttentionExecutor : public PagedAttentionExecutor { q = q.reshape({B_token, H, 1, S}); k = k.reshape({B_token, Hk, 1, S}); v = v.reshape({B_token, Hk, 1, SV}); + if (k_cache.m_dt == ov::element::Type_t::u8) { - k_cache.assert_dims({0, Hk, block_size, S + sizeof(float) * 2}, true); - v_cache.assert_dims({k_cache.m_dims[0], Hk, block_size, SV + sizeof(float) * 2}); + k_cache.assert_dims({0, Hk, block_size, S + key_params_size * key_group_num}, true); + v_cache.assert_dims({k_cache.m_dims[0], Hk, block_size, SV + value_params_size * value_group_num}); } else { k_cache.assert_dims({0, Hk, block_size, S}, true); v_cache.assert_dims({k_cache.m_dims[0], Hk, block_size, SV}); @@ -1837,7 +2225,13 @@ struct AttentionExecutor : public PagedAttentionExecutor { } if (k_cache.m_dt == ov::element::Type_t::u8) { - paged_attn_quantkv(k, v, k_cache, v_cache, _slot_mapping); + paged_attn_quantkv(k, + v, + k_cache, + v_cache, + _slot_mapping, + _helper._key_group_size, + _helper._value_group_size); } else { paged_attn_memcpy(k, v, k_cache, v_cache, _slot_mapping); } @@ -1887,40 +2281,79 @@ struct AttentionExecutor : public PagedAttentionExecutor { }; #endif -std::shared_ptr make_pa_executor(ov::element::Type data_type, ov::element::Type kvcache_type) { +std::shared_ptr make_pa_executor(ov::element::Type data_type, + ov::element::Type key_cache_type, + ov::element::Type value_cache_type, + size_t key_group_size, + size_t value_group_size) { std::shared_ptr executor; #ifdef OPENVINO_ARCH_X86_64 if (data_type == ov::element::bf16) { # if defined(HAVE_AVX512F) - if (kvcache_type == ov::element::u8) { - executor = std::make_shared>(); + if (key_cache_type == ov::element::u8) { + if (value_cache_type == ov::element::u4) { + executor = + std::make_shared>(key_group_size, + value_group_size); + } else if (value_cache_type == ov::element::u8) { + executor = + std::make_shared>(key_group_size, + value_group_size); + } else { + OPENVINO_THROW("make_pa_executor: key_cache_type u8 with value_cache_type ", + value_cache_type.to_string(), + " is not support"); + } + } else { - OPENVINO_ASSERT(kvcache_type == ov::element::bf16, "expect kvcache type bf16, current: ", kvcache_type); - executor = std::make_shared>(); + OPENVINO_ASSERT(key_cache_type == ov::element::bf16, "expect kvcache type bf16, current: ", key_cache_type); + executor = std::make_shared>(); } # else OPENVINO_THROW("make_pa_executor: bf16 needs avx512+ hardware."); # endif } else if (data_type == ov::element::f16) { # if defined(HAVE_AVX512F) - if (kvcache_type == ov::element::u8) { - executor = std::make_shared>(); + if (key_cache_type == ov::element::u8) { + if (value_cache_type == ov::element::u4) { + executor = std::make_shared>(key_group_size, + value_group_size); + } else if (value_cache_type == ov::element::u8) { + executor = std::make_shared>(key_group_size, + value_group_size); + } else { + OPENVINO_THROW("make_pa_executor: key_cache_type u8 with value_cache_type ", + value_cache_type.to_string(), + " is not support"); + } } else { - OPENVINO_ASSERT(kvcache_type == ov::element::f16, "expect kvcache type f16, current: ", kvcache_type); - executor = std::make_shared>(); + OPENVINO_ASSERT(key_cache_type == ov::element::f16, "expect kvcache type f16, current: ", key_cache_type); + executor = std::make_shared>(); } # else OPENVINO_THROW("make_pa_executor: f16 needs avx512+ hardware."); # endif } else if (data_type == ov::element::f32) { - if (kvcache_type == ov::element::u8) { - executor = std::make_shared>(); - } else if (kvcache_type == ov::element::f16) { - executor = std::make_shared>(); + if (key_cache_type == ov::element::u8) { + if (value_cache_type == ov::element::u4) { + executor = std::make_shared>(key_group_size, + value_group_size); + } else if (value_cache_type == ov::element::u8) { + executor = std::make_shared>(key_group_size, + value_group_size); + } else { + OPENVINO_THROW("make_pa_executor: key_cache_type u8 with value_cache_type ", + value_cache_type.to_string(), + " is not support"); + } + } else if (key_cache_type == ov::element::f16) { + executor = std::make_shared>(key_group_size, + value_group_size); } else { - OPENVINO_ASSERT(kvcache_type == ov::element::f32, "expect kvcache type f32, current: ", kvcache_type); - executor = std::make_shared>(); + OPENVINO_ASSERT(key_cache_type == ov::element::f32, "expect kvcache type f32, current: ", key_cache_type); + executor = + std::make_shared>(key_group_size, value_group_size); } } else { OPENVINO_THROW("make_pa_executor: unsupported precision: ", data_type); diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.hpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.hpp index d28125b3898460..64e4eefc3b760d 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.hpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.hpp @@ -17,7 +17,11 @@ namespace Extensions { namespace Cpu { namespace XARCH { -std::shared_ptr make_pa_executor(ov::element::Type data_type, ov::element::Type kvcache_type); +std::shared_ptr make_pa_executor(ov::element::Type data_type, + ov::element::Type key_cache_type, + ov::element::Type value_cache_type, + size_t key_group_size, + size_t value_group_size); } // namespace XARCH } // namespace Cpu diff --git a/src/plugins/intel_cpu/src/nodes/paged_attn.cpp b/src/plugins/intel_cpu/src/nodes/paged_attn.cpp index b51b2b3d8029a9..54aa80e9dff7c0 100644 --- a/src/plugins/intel_cpu/src/nodes/paged_attn.cpp +++ b/src/plugins/intel_cpu/src/nodes/paged_attn.cpp @@ -84,13 +84,14 @@ void PagedAttention::initSupportedPrimitiveDescriptors() { OPENVINO_ASSERT(orgInputNumber == 13, "The input number of PagedAttention should be 13."); // kvcache, float, [] - auto past_kv_input_mem_precision = getOriginalInputPrecisionAtPort(PagedAttentionExecutor::ID_KCACHE); + auto past_key_input_mem_precision = getOriginalInputPrecisionAtPort(PagedAttentionExecutor::ID_KCACHE); + auto past_value_input_mem_precision = getOriginalInputPrecisionAtPort(PagedAttentionExecutor::ID_VCACHE); config.inConfs[PagedAttentionExecutor::ID_KCACHE].setMemDesc( creatorsMap.at(LayoutType::ncsp) - ->createSharedDesc(past_kv_input_mem_precision, getInputShapeAtPort(PagedAttentionExecutor::ID_KCACHE))); + ->createSharedDesc(past_key_input_mem_precision, getInputShapeAtPort(PagedAttentionExecutor::ID_KCACHE))); config.inConfs[PagedAttentionExecutor::ID_VCACHE].setMemDesc( creatorsMap.at(LayoutType::ncsp) - ->createSharedDesc(past_kv_input_mem_precision, getInputShapeAtPort(PagedAttentionExecutor::ID_VCACHE))); + ->createSharedDesc(past_value_input_mem_precision, getInputShapeAtPort(PagedAttentionExecutor::ID_VCACHE))); // past_lens, int, [b_seq] config.inConfs[PagedAttentionExecutor::ID_PAST_LENS].setMemDesc( creatorsMap.at(LayoutType::ncsp) @@ -140,8 +141,14 @@ void PagedAttention::createPrimitive() { auto builder = [&](const PagedAttentionKey& key) -> std::shared_ptr { #ifdef OPENVINO_ARCH_X86_64 - auto kvCachePrecision = getOriginalInputPrecisionAtPort(PagedAttentionExecutor::ID_KCACHE); - return make_pa_executor(rtPrecision, kvCachePrecision); + // Since we are quantize only last dim it's safe to use the last dim of KV. + auto kCachePrecision = getOriginalInputPrecisionAtPort(PagedAttentionExecutor::ID_KCACHE); + auto vCachePrecision = getOriginalInputPrecisionAtPort(PagedAttentionExecutor::ID_VCACHE); + const auto& cpuConfig = context->getConfig(); + + size_t key_group_size = cpuConfig.keyCacheGroupSize; + size_t value_group_size = cpuConfig.valueCacheGroupSize; + return make_pa_executor(rtPrecision, kCachePrecision, vCachePrecision, key_group_size, value_group_size); #else return nullptr; #endif @@ -202,6 +209,20 @@ void PagedAttention::execute(dnnl::stream strm) { bool PagedAttention::isSupportedOperation(const std::shared_ptr& op, std::string& errorMessage) noexcept { try { + auto vCachePrecision = op->get_input_element_type(PagedAttentionExecutor::ID_VCACHE); + auto kCachePrecision = op->get_input_element_type(PagedAttentionExecutor::ID_KCACHE); + if (one_of(vCachePrecision, + ov::element::u4, + ov::element::u8, + ov::element::f32, + ov::element::f16, + ov::element::bf16)) { + if (!one_of(kCachePrecision, ov::element::u8, ov::element::f16, ov::element::f32, ov::element::bf16)) { + errorMessage = "PageAttn key value cache compression doesn't support key cache prec " + + kCachePrecision.to_string() + " value cache prec " + vCachePrecision.to_string(); + return false; + } + } int orgInput = static_cast(op->get_input_size()); if (op->get_type_name() == std::string("PagedAttentionExtension") && orgInput == PagedAttentionExecutor::ID_SLIDING_WINDOW + 1) { diff --git a/src/plugins/intel_cpu/src/nodes/scaled_attn.cpp b/src/plugins/intel_cpu/src/nodes/scaled_attn.cpp index 7fe3fc8dc5045d..aec0ff2e7d9026 100644 --- a/src/plugins/intel_cpu/src/nodes/scaled_attn.cpp +++ b/src/plugins/intel_cpu/src/nodes/scaled_attn.cpp @@ -436,12 +436,14 @@ struct MHAKernel { } T* v_ptr = is_xf16 ? &wv_scratch_b.at({b, h / h_each_group_len, 0}) : &present_value.at({b, h / h_each_group_len, 0, 0}); - wv_gemm_ptr->executeGemm(m_cnt < m_block_size, - w_ptr, - v_ptr, - fp32_out_ptr, - wsp.data() + tid * wsp_size_per_thread, - wv_scratch_a ? &wv_scratch_a.at({tid, 0}) : nullptr); + wv_gemm_ptr->executeGemm(m_cntget_scratch_a_size()> 0 + ? &wv_scratch_a.at({tid, 0}) + : nullptr); if (is_xf16) { if (has_out_transpose) { attn_memcpy2d_kernel(&fp32_out.at({b, m_start, h, 0}), @@ -1059,6 +1061,26 @@ ScaledDotProductAttention::ScaledDotProductAttention(const std::shared_ptrgetConfig(); + const auto& keyCachePrecision = cpuConfig.keyCachePrecision; + const auto& valueCachePrecision = cpuConfig.valueCachePrecision; + const auto keyDims = getInputShapeAtPort(1).getDims(); + const auto valueDims = getInputShapeAtPort(2).getDims(); + const auto keyS = *(keyDims.end() - 1); + const auto valueS = *(valueDims.end() - 1); + OPENVINO_ASSERT(valueCachePrecision == keyCachePrecision, + "CPU: SDPA node only supports same key/value cache precision"); + OPENVINO_ASSERT(one_of(keyCachePrecision, ov::element::f32, ov::element::f16, ov::element::bf16, ov::element::u8), + "CPU: SDPA only supports key/value cache precision f32, f16, bf16, u8 but gets ", + keyCachePrecision); + m_key_quant_param.groupSize = (cpuConfig.keyCacheGroupSize == 0 || keyS % cpuConfig.keyCacheGroupSize != 0) + ? keyS + : cpuConfig.keyCacheGroupSize; + m_key_quant_param.precision = keyCachePrecision; + m_value_quant_param.groupSize = (cpuConfig.valueCacheGroupSize == 0 || valueS % cpuConfig.valueCacheGroupSize != 0) + ? valueS + : cpuConfig.valueCacheGroupSize; + m_key_quant_param.precision = valueCachePrecision; if (const auto node = std::dynamic_pointer_cast(op)) { m_config.config.is_causal = node->get_causal(); @@ -1833,12 +1855,16 @@ void ScaledDotProductAttention::updatePastkv(const MemoryPtr& mem_cur_k, const M ov::element::Type ScaledDotProductAttention::getKVCachePrecision() { ov::element::Type kvcache_precision; + // TODO: SDPA only supports same key/value cache precision. auto rtPrecision = getRuntimePrecision(); - auto kvCachePrecisionHint = context->getConfig().kvCachePrecision; + auto keyCachePrecisionHint = context->getConfig().keyCachePrecision; + auto valueCachePrecisionHint = context->getConfig().valueCachePrecision; bool enableKVCacheFP16 = m_config.config.fuse_concat && mayiuse(cpu_isa_t::avx2) && - rtPrecision != ov::element::bf16 && kvCachePrecisionHint == ov::element::f16; + rtPrecision != ov::element::bf16 && + (keyCachePrecisionHint == ov::element::f16 && valueCachePrecisionHint == ov::element::f16); kvcache_precision = enableKVCacheFP16 ? ov::element::f16 : rtPrecision; - bool use_int8_kv_cache_precision = kvCachePrecisionHint == ov::element::u8; + bool use_int8_kv_cache_precision = + (keyCachePrecisionHint == ov::element::u8 && valueCachePrecisionHint == ov::element::u8); if (use_int8_kv_cache_precision) kvcache_precision = ov::element::u8; else diff --git a/src/plugins/intel_cpu/src/nodes/scaled_attn.h b/src/plugins/intel_cpu/src/nodes/scaled_attn.h index 21b9056ba9517c..2917342314fafd 100644 --- a/src/plugins/intel_cpu/src/nodes/scaled_attn.h +++ b/src/plugins/intel_cpu/src/nodes/scaled_attn.h @@ -47,7 +47,10 @@ class ScaledDotProductAttention : public Node { real_order = {permute_axes[2], permute_axes[0], permute_axes[1], permute_axes[3]}; return real_order; } - + struct SDPAQuantParam { + ov::element::Type precision = ov::element::undefined; + size_t groupSize = 0; + }; ov::element::Type getKVCachePrecision(); private: @@ -86,6 +89,8 @@ class ScaledDotProductAttention : public Node { // (0, 1, 2, 3) for BHLS // (2, 0, 1, 3) for LBHS std::vector m_kvstate_layout = {2, 0, 1, 3}; + SDPAQuantParam m_key_quant_param; + SDPAQuantParam m_value_quant_param; }; } // namespace node diff --git a/src/plugins/intel_cpu/src/plugin.cpp b/src/plugins/intel_cpu/src/plugin.cpp index b3c2aa0b298a5a..ec9b37c2c2d22e 100644 --- a/src/plugins/intel_cpu/src/plugin.cpp +++ b/src/plugins/intel_cpu/src/plugin.cpp @@ -392,6 +392,14 @@ ov::Any Plugin::get_property(const std::string& name, const ov::AnyMap& options) engConfig.fcDynamicQuantizationGroupSize); } else if (name == ov::hint::kv_cache_precision) { return decltype(ov::hint::kv_cache_precision)::value_type(engConfig.kvCachePrecision); + } else if (name == ov::key_cache_precision) { + return decltype(ov::key_cache_precision)::value_type(engConfig.keyCachePrecision); + } else if (name == ov::value_cache_precision) { + return decltype(ov::value_cache_precision)::value_type(engConfig.valueCachePrecision); + } else if (name == ov::key_cache_group_size) { + return decltype(ov::key_cache_group_size)::value_type(engConfig.keyCacheGroupSize); + } else if (name == ov::value_cache_group_size) { + return decltype(ov::value_cache_group_size)::value_type(engConfig.valueCacheGroupSize); } return get_ro_property(name, options); } @@ -435,6 +443,10 @@ ov::Any Plugin::get_ro_property(const std::string& name, const ov::AnyMap& optio RW_property(ov::intel_cpu::sparse_weights_decompression_rate.name()), RW_property(ov::hint::dynamic_quantization_group_size.name()), RW_property(ov::hint::kv_cache_precision.name()), + RW_property(ov::key_cache_precision.name()), + RW_property(ov::value_cache_precision.name()), + RW_property(ov::key_cache_group_size.name()), + RW_property(ov::value_cache_group_size.name()), }; OPENVINO_SUPPRESS_DEPRECATED_START diff --git a/src/plugins/intel_cpu/tests/functional/custom/behavior/ov_executable_network/properties.cpp b/src/plugins/intel_cpu/tests/functional/custom/behavior/ov_executable_network/properties.cpp index 73086b78a0de95..9d38d03e5eadde 100644 --- a/src/plugins/intel_cpu/tests/functional/custom/behavior/ov_executable_network/properties.cpp +++ b/src/plugins/intel_cpu/tests/functional/custom/behavior/ov_executable_network/properties.cpp @@ -2,14 +2,15 @@ // SPDX-License-Identifier: Apache-2.0 // +#include "openvino/runtime/properties.hpp" + #include -#include "utils/properties_test.hpp" -#include "openvino/runtime/system_conf.hpp" -#include "openvino/runtime/core.hpp" #include "openvino/runtime/compiled_model.hpp" -#include "openvino/runtime/properties.hpp" +#include "openvino/runtime/core.hpp" #include "openvino/runtime/intel_cpu/properties.hpp" +#include "openvino/runtime/system_conf.hpp" +#include "utils/properties_test.hpp" namespace { @@ -41,6 +42,10 @@ TEST_F(OVClassConfigTestCPU, smoke_CpuExecNetworkSupportedPropertiesAreAvailable RO_property(ov::intel_cpu::sparse_weights_decompression_rate.name()), RO_property(ov::hint::dynamic_quantization_group_size.name()), RO_property(ov::hint::kv_cache_precision.name()), + RO_property(ov::key_cache_precision.name()), + RO_property(ov::value_cache_precision.name()), + RO_property(ov::key_cache_group_size.name()), + RO_property(ov::value_cache_group_size.name()), }; ov::Core ie; @@ -84,7 +89,7 @@ TEST_F(OVClassConfigTestCPU, smoke_CpuExecNetworkSetROPropertiesThrow) { TEST_F(OVClassConfigTestCPU, smoke_CpuExecNetworkCheckCoreStreamsHasHigherPriorityThanThroughputHint) { ov::Core ie; - int32_t streams = 1; // throughput hint should apply higher number of streams + int32_t streams = 1; // throughput hint should apply higher number of streams int32_t value = 0; OV_ASSERT_NO_THROW(ie.set_property(deviceName, ov::num_streams(streams))); @@ -97,7 +102,7 @@ TEST_F(OVClassConfigTestCPU, smoke_CpuExecNetworkCheckCoreStreamsHasHigherPriori TEST_F(OVClassConfigTestCPU, smoke_CpuExecNetworkCheckCoreStreamsHasHigherPriorityThanLatencyHint) { ov::Core ie; - int32_t streams = ov::get_number_of_cpu_cores(); // latency hint should apply lower number of streams + int32_t streams = ov::get_number_of_cpu_cores(); // latency hint should apply lower number of streams int32_t value = 0; OV_ASSERT_NO_THROW(ie.set_property(deviceName, ov::num_streams(streams))); @@ -110,7 +115,7 @@ TEST_F(OVClassConfigTestCPU, smoke_CpuExecNetworkCheckCoreStreamsHasHigherPriori TEST_F(OVClassConfigTestCPU, smoke_CpuExecNetworkCheckModelStreamsHasHigherPriorityThanLatencyHint) { ov::Core ie; - int32_t streams = ov::get_number_of_cpu_cores(); // latency hint should apply lower number of streams + int32_t streams = ov::get_number_of_cpu_cores(); // latency hint should apply lower number of streams int32_t value = 0; OV_ASSERT_NO_THROW(ie.set_property(deviceName, ov::hint::performance_mode(ov::hint::PerformanceMode::LATENCY))); @@ -125,7 +130,7 @@ TEST_F(OVClassConfigTestCPU, smoke_CpuExecNetworkCheckModelStreamsHasHigherPrior TEST_F(OVClassConfigTestCPU, smoke_CpuExecNetworkCheckModelStreamsHasHigherPriorityThanThroughputHint) { ov::Core ie; - int32_t streams = 1; // throughput hint should apply higher number of streams + int32_t streams = 1; // throughput hint should apply higher number of streams int32_t value = 0; ov::AnyMap config; @@ -183,6 +188,36 @@ TEST_F(OVClassConfigTestCPU, smoke_CpuExecNetworkCheckKVCachePrecision) { ASSERT_EQ(kv_cache_precision_value, ov::element::f32); } +TEST_F(OVClassConfigTestCPU, smoke_CpuExecNetworkFinetuneKVCachePrecision) { + ov::Core core; + + core.set_property(deviceName, ov::key_cache_precision(ov::element::f16)); + core.set_property(deviceName, ov::value_cache_precision(ov::element::u4)); + ov::CompiledModel compiledModel = core.compile_model(model, deviceName); + + auto key_cache_precision_value = ov::element::undefined; + auto value_cache_precision_value = ov::element::undefined; + OV_ASSERT_NO_THROW(key_cache_precision_value = compiledModel.get_property(ov::key_cache_precision)); + OV_ASSERT_NO_THROW(value_cache_precision_value = compiledModel.get_property(ov::value_cache_precision)); + ASSERT_EQ(key_cache_precision_value, ov::element::f16); + ASSERT_EQ(value_cache_precision_value, ov::element::u4); +} + +TEST_F(OVClassConfigTestCPU, smoke_CpuExecNetworkFinetuneKVCacheGroupSize) { + ov::Core core; + + core.set_property(deviceName, ov::key_cache_group_size(32)); + core.set_property(deviceName, ov::value_cache_group_size(16)); + ov::CompiledModel compiledModel = core.compile_model(model, deviceName); + + auto key_cache_group_size_value = 0; + auto value_cache_group_size_value = 0; + OV_ASSERT_NO_THROW(key_cache_group_size_value = compiledModel.get_property(ov::key_cache_group_size)); + OV_ASSERT_NO_THROW(value_cache_group_size_value = compiledModel.get_property(ov::value_cache_group_size)); + ASSERT_EQ(key_cache_group_size_value, 32); + ASSERT_EQ(value_cache_group_size_value, 16); +} + TEST_F(OVClassConfigTestCPU, smoke_CpuExecNetworkCheckAccuracyModeDynamicQuantizationGroupSize) { ov::Core core; @@ -226,7 +261,8 @@ TEST_F(OVClassConfigTestCPU, smoke_CpuExecNetworkCheckExecutionModeIsAvailableIn ASSERT_FALSE(model_exec_mode_it->is_mutable()); } -TEST_F(OVClassConfigTestCPU, smoke_CpuExecNetworkCheckModelInferencePrecisionHasHigherPriorityThanCoreInferencePrecision) { +TEST_F(OVClassConfigTestCPU, + smoke_CpuExecNetworkCheckModelInferencePrecisionHasHigherPriorityThanCoreInferencePrecision) { ov::Core ie; auto inference_precision_value = ov::element::undefined; @@ -240,7 +276,8 @@ TEST_F(OVClassConfigTestCPU, smoke_CpuExecNetworkCheckModelInferencePrecisionHas ASSERT_EQ(inference_precision_value, bf16_if_can_be_emulated); } -TEST_F(OVClassConfigTestCPU, smoke_CpuExecNetworkCheckCoreInferencePrecisionHasHigherPriorityThanModelPerformanceExecutionMode) { +TEST_F(OVClassConfigTestCPU, + smoke_CpuExecNetworkCheckCoreInferencePrecisionHasHigherPriorityThanModelPerformanceExecutionMode) { ov::Core ie; auto execution_mode_value = ov::hint::ExecutionMode::ACCURACY; auto inference_precision_value = ov::element::undefined; @@ -258,7 +295,8 @@ TEST_F(OVClassConfigTestCPU, smoke_CpuExecNetworkCheckCoreInferencePrecisionHasH ASSERT_EQ(inference_precision_value, ov::element::f32); } -TEST_F(OVClassConfigTestCPU, smoke_CpuExecNetworkCheckModelInferencePrecisionHasHigherPriorityThanCorePerformanceExecutionMode) { +TEST_F(OVClassConfigTestCPU, + smoke_CpuExecNetworkCheckModelInferencePrecisionHasHigherPriorityThanCorePerformanceExecutionMode) { ov::Core ie; auto execution_mode_value = ov::hint::ExecutionMode::PERFORMANCE; auto inference_precision_value = ov::element::undefined; @@ -289,14 +327,13 @@ TEST_F(OVClassConfigTestCPU, smoke_CpuExecNetworkCheckLogLevel) { OV_ASSERT_NO_THROW(value = compiledModel.get_property(ov::log::level)); ASSERT_EQ(value.as(), ov::log::Level::NO); } - //check set and get - const std::vector logLevels = { - ov::log::Level::ERR, - ov::log::Level::NO, - ov::log::Level::WARNING, - ov::log::Level::INFO, - ov::log::Level::DEBUG, - ov::log::Level::TRACE}; + // check set and get + const std::vector logLevels = {ov::log::Level::ERR, + ov::log::Level::NO, + ov::log::Level::WARNING, + ov::log::Level::INFO, + ov::log::Level::DEBUG, + ov::log::Level::TRACE}; for (unsigned int i = 0; i < logLevels.size(); i++) { ov::Any value; @@ -331,50 +368,109 @@ TEST_F(OVClassConfigTestCPU, smoke_CpuExecNetworkCheckCPURuntimOptions) { ov::Core ie; ov::Any type; ov::Any size; + ov::Any keySize; + ov::Any valueSize; + ov::Any keyCacheType; + ov::Any valueCacheType; ov::CompiledModel compiledModel; model->set_rt_info("f16", "runtime_options", ov::hint::kv_cache_precision.name()); model->set_rt_info("0", "runtime_options", ov::hint::dynamic_quantization_group_size.name()); + model->set_rt_info("32", "runtime_options", ov::key_cache_group_size.name()); + model->set_rt_info("16", "runtime_options", ov::value_cache_group_size.name()); + model->set_rt_info("u8", "runtime_options", ov::key_cache_precision.name()); + model->set_rt_info("u8", "runtime_options", ov::value_cache_precision.name()); OV_ASSERT_NO_THROW(compiledModel = ie.compile_model(model, deviceName)); OV_ASSERT_NO_THROW(type = compiledModel.get_property(ov::hint::kv_cache_precision)); OV_ASSERT_NO_THROW(size = compiledModel.get_property(ov::hint::dynamic_quantization_group_size)); + OV_ASSERT_NO_THROW(keySize = compiledModel.get_property(ov::key_cache_group_size)); + OV_ASSERT_NO_THROW(valueSize = compiledModel.get_property(ov::value_cache_group_size)); + OV_ASSERT_NO_THROW(keyCacheType = compiledModel.get_property(ov::key_cache_precision)); + OV_ASSERT_NO_THROW(valueCacheType = compiledModel.get_property(ov::value_cache_precision)); ASSERT_EQ(type.as(), ov::element::f16); ASSERT_EQ(size.as(), 0); + ASSERT_EQ(keySize.as(), 32); + ASSERT_EQ(valueSize.as(), 16); + ASSERT_EQ(keyCacheType.as(), ov::element::u8); + ASSERT_EQ(valueCacheType.as(), ov::element::u8); } TEST_F(OVClassConfigTestCPU, smoke_CpuExecNetworkCheckCPURuntimOptionsWithCompileConfig) { ov::Core ie; ov::Any type; ov::Any size; + ov::Any keySize; + ov::Any valueSize; + ov::Any keyCacheType; + ov::Any valueCacheType; ov::CompiledModel compiledModel; model->set_rt_info("f16", "runtime_options", ov::hint::kv_cache_precision.name()); model->set_rt_info("0", "runtime_options", ov::hint::dynamic_quantization_group_size.name()); + model->set_rt_info("0", "runtime_options", ov::key_cache_group_size.name()); + model->set_rt_info("0", "runtime_options", ov::value_cache_group_size.name()); + model->set_rt_info("f32", "runtime_options", ov::key_cache_precision.name()); + model->set_rt_info("f32", "runtime_options", ov::value_cache_precision.name()); ov::AnyMap config; config[ov::hint::kv_cache_precision.name()] = "u8"; config[ov::hint::dynamic_quantization_group_size.name()] = "16"; + // propperty has higher priority than rt_info + config[ov::key_cache_group_size.name()] = "32"; + config[ov::value_cache_group_size.name()] = "16"; + // key/value cache prec has higher priority than kvCachePrec + config[ov::key_cache_precision.name()] = "f16"; + config[ov::value_cache_precision.name()] = "bf16"; OV_ASSERT_NO_THROW(compiledModel = ie.compile_model(model, deviceName, config)); OV_ASSERT_NO_THROW(type = compiledModel.get_property(ov::hint::kv_cache_precision)); OV_ASSERT_NO_THROW(size = compiledModel.get_property(ov::hint::dynamic_quantization_group_size)); + OV_ASSERT_NO_THROW(keySize = compiledModel.get_property(ov::key_cache_group_size)); + OV_ASSERT_NO_THROW(valueSize = compiledModel.get_property(ov::value_cache_group_size)); + OV_ASSERT_NO_THROW(keyCacheType = compiledModel.get_property(ov::key_cache_precision)); + OV_ASSERT_NO_THROW(valueCacheType = compiledModel.get_property(ov::value_cache_precision)); ASSERT_EQ(type.as(), ov::element::u8); ASSERT_EQ(size.as(), 16); + ASSERT_EQ(keySize.as(), 32); + ASSERT_EQ(valueSize.as(), 16); + ASSERT_EQ(keyCacheType.as(), ov::element::f16); + ASSERT_EQ(valueCacheType.as(), ov::element::bf16); } TEST_F(OVClassConfigTestCPU, smoke_CpuExecNetworkCheckCPURuntimOptionsWithCoreProperties) { ov::Core core; ov::Any type; ov::Any size; - + ov::Any keySize; + ov::Any valueSize; + ov::Any keyCacheType; + ov::Any valueCacheType; core.set_property(deviceName, ov::hint::kv_cache_precision(ov::element::f32)); core.set_property(deviceName, ov::hint::dynamic_quantization_group_size(16)); + core.set_property(deviceName, ov::key_cache_group_size(8)); + core.set_property(deviceName, ov::value_cache_group_size(8)); + core.set_property(deviceName, ov::key_cache_precision(ov::element::f16)); + core.set_property(deviceName, ov::value_cache_precision(ov::element::bf16)); ov::CompiledModel compiledModel; model->set_rt_info("f16", "runtime_options", ov::hint::kv_cache_precision.name()); model->set_rt_info("0", "runtime_options", ov::hint::dynamic_quantization_group_size.name()); + model->set_rt_info("32", "runtime_options", ov::key_cache_group_size.name()); + model->set_rt_info("16", "runtime_options", ov::value_cache_group_size.name()); + // User's setting has higher priority than rt_info + model->set_rt_info("f32", "runtime_options", ov::key_cache_precision.name()); + model->set_rt_info("f32", "runtime_options", ov::value_cache_precision.name()); OV_ASSERT_NO_THROW(compiledModel = core.compile_model(model, deviceName)); OV_ASSERT_NO_THROW(type = compiledModel.get_property(ov::hint::kv_cache_precision)); OV_ASSERT_NO_THROW(size = compiledModel.get_property(ov::hint::dynamic_quantization_group_size)); + OV_ASSERT_NO_THROW(keySize = compiledModel.get_property(ov::key_cache_group_size)); + OV_ASSERT_NO_THROW(valueSize = compiledModel.get_property(ov::value_cache_group_size)); + OV_ASSERT_NO_THROW(keyCacheType = compiledModel.get_property(ov::key_cache_precision)); + OV_ASSERT_NO_THROW(valueCacheType = compiledModel.get_property(ov::value_cache_precision)); + ASSERT_EQ(type.as(), ov::element::f32); ASSERT_EQ(size.as(), 16); + ASSERT_EQ(keySize.as(), 8); + ASSERT_EQ(valueSize.as(), 8); + ASSERT_EQ(keyCacheType.as(), ov::element::f16); + ASSERT_EQ(valueCacheType.as(), ov::element::bf16); } -} // namespace +} // namespace diff --git a/src/plugins/intel_cpu/tests/functional/custom/behavior/ov_plugin/properties.cpp b/src/plugins/intel_cpu/tests/functional/custom/behavior/ov_plugin/properties.cpp index 904d2b81dc05b6..c6289a4dc80716 100644 --- a/src/plugins/intel_cpu/tests/functional/custom/behavior/ov_plugin/properties.cpp +++ b/src/plugins/intel_cpu/tests/functional/custom/behavior/ov_plugin/properties.cpp @@ -56,6 +56,10 @@ TEST_F(OVClassConfigTestCPU, smoke_PluginAllSupportedPropertiesAreAvailable) { RW_property(ov::intel_cpu::sparse_weights_decompression_rate.name()), RW_property(ov::hint::dynamic_quantization_group_size.name()), RW_property(ov::hint::kv_cache_precision.name()), + RW_property(ov::key_cache_precision.name()), + RW_property(ov::value_cache_precision.name()), + RW_property(ov::key_cache_group_size.name()), + RW_property(ov::value_cache_group_size.name()), }; ov::Core ie;