Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CPU] Change kvcache default type of PagedAttention to u8 for CPU plugin #1206

Open — this pull request wants to merge 8 commits into the base branch `master` ("Choose a base branch") from the author's branch.
22 changes: 3 additions & 19 deletions src/cpp/src/device_config.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,30 +36,14 @@ class DeviceConfig {
m_block_size = get_block_size_by_device(device);

if (m_device == "CPU") {
auto inference_precision = core.get_property(device, ov::hint::inference_precision);
m_kv_cache_type = inference_precision == ov::element::bf16 ? ov::element::bf16 : ov::element::f16;

// if user sets precision hint, kv cache type should be changed
const auto inference_precision_it = plugin_config.find(ov::hint::inference_precision.name());
if (inference_precision_it != plugin_config.end()) {
const auto inference_precision = inference_precision_it->second.as<ov::element::Type>();
if (inference_precision == ov::element::f32) {
m_kv_cache_type = ov::element::f32;
} else if (inference_precision == ov::element::f16) {
m_kv_cache_type = ov::element::f16;
} else if (inference_precision == ov::element::bf16) {
m_kv_cache_type = ov::element::bf16;
} else {
// use default f32
m_kv_cache_type = ov::element::f32;
}
}

// if user sets ov::kv_cache_precision hint
const auto kv_cache_precision_it = plugin_config.find(ov::hint::kv_cache_precision.name());
if (kv_cache_precision_it != plugin_config.end()) {
const auto kv_cache_precision = kv_cache_precision_it->second.as<ov::element::Type>();
m_kv_cache_type = kv_cache_precision;
} else {
// x86 and arm have different default kv cache type
ilya-lavrenov marked this conversation as resolved.
Show resolved Hide resolved
m_kv_cache_type = core.get_property(device, ov::hint::kv_cache_precision);
}
} else if (m_device.find("GPU") != std::string::npos) {
auto inference_precision = core.get_property(device, ov::hint::inference_precision);
Expand Down
Loading