Skip to content

Commit

Permalink
[CPU]fix group_size in sdpa
Browse files Browse the repository at this point in the history
Signed-off-by: Zhang Yi <[email protected]>
  • Loading branch information
zhangYiIntel committed Jan 7, 2025
1 parent 0515410 commit 7a412f7
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 11 deletions.
19 changes: 9 additions & 10 deletions src/plugins/intel_cpu/src/nodes/scaled_attn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1068,21 +1068,20 @@ ScaledDotProductAttention::ScaledDotProductAttention(const std::shared_ptr<ov::N
const auto valueDims = getInputShapeAtPort(2).getDims();
const auto keyS = *(keyDims.end() - 1);
const auto valueS = *(valueDims.end() - 1);
if (keyS % cpuConfig.keyCacheGroupSize != 0) {
OPENVINO_THROW("ScaledDotProductAttention AttentionExecutor creation fails key state " + std::to_string(keyS) +
" cannot be divided by group size " + std::to_string(cpuConfig.keyCacheGroupSize));
}

if (valueS % cpuConfig.valueCacheGroupSize != 0) {
OPENVINO_THROW("ScaledDotProductAttention AttentionExecutor creation fails value state " +
std::to_string(keyS) + " cannot be divided by group size " +
std::to_string(cpuConfig.valueCacheGroupSize));
}
OPENVINO_ASSERT(valueCachePrecision == keyCachePrecision,
"CPU: SDPA node only supports same key/value cache precision");
OPENVINO_ASSERT(one_of(keyCachePrecision, ov::element::f32, ov::element::f16, ov::element::bf16, ov::element::u8),
"CPU: SDPA only supports key/value cache precision f32, f16, bf16, u8 but gets ",
keyCachePrecision);
m_key_quant_param.groupSize = (cpuConfig.keyCacheGroupSize == 0 || keyS % cpuConfig.keyCacheGroupSize != 0)
? keyS
: cpuConfig.keyCacheGroupSize;
m_key_quant_param.precision = keyCachePrecision;
m_value_quant_param.groupSize = (cpuConfig.valueCacheGroupSize == 0 || valueS % cpuConfig.valueCacheGroupSize != 0)
? valueS
: cpuConfig.valueCacheGroupSize;
m_key_quant_param.precision = valueCachePrecision;

if (const auto node = std::dynamic_pointer_cast<const ov::op::v13::ScaledDotProductAttention>(op)) {
m_config.config.is_causal = node->get_causal();
} else if (const auto node = std::dynamic_pointer_cast<const ScaledDotProductAttentionWithKVCache>(op)) {
Expand Down
7 changes: 6 additions & 1 deletion src/plugins/intel_cpu/src/nodes/scaled_attn.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,10 @@ class ScaledDotProductAttention : public Node {
real_order = {permute_axes[2], permute_axes[0], permute_axes[1], permute_axes[3]};
return real_order;
}

struct SDPAQuantParam {
ov::element::Type precision = ov::element::undefined;
size_t groupSize = 0;
};
ov::element::Type getKVCachePrecision();

private:
Expand Down Expand Up @@ -86,6 +89,8 @@ class ScaledDotProductAttention : public Node {
// (0, 1, 2, 3) for BHLS
// (2, 0, 1, 3) for LBHS
std::vector<size_t> m_kvstate_layout = {2, 0, 1, 3};
SDPAQuantParam m_key_quant_param;
SDPAQuantParam m_value_quant_param;
};

} // namespace node
Expand Down

0 comments on commit 7a412f7

Please sign in to comment.