diff --git a/src/plugins/intel_cpu/src/nodes/scaled_attn.cpp b/src/plugins/intel_cpu/src/nodes/scaled_attn.cpp
index fec6252373a70b..aec0ff2e7d9026 100644
--- a/src/plugins/intel_cpu/src/nodes/scaled_attn.cpp
+++ b/src/plugins/intel_cpu/src/nodes/scaled_attn.cpp
@@ -1068,21 +1068,20 @@ ScaledDotProductAttention::ScaledDotProductAttention(const std::shared_ptr
     if (const auto node = std::dynamic_pointer_cast<const ov::op::v13::ScaledDotProductAttention>(op)) {
         m_config.config.is_causal = node->get_causal();
     } else if (const auto node = std::dynamic_pointer_cast<const ScaledDotProductAttentionWithKVCache>(op)) {
diff --git a/src/plugins/intel_cpu/src/nodes/scaled_attn.h b/src/plugins/intel_cpu/src/nodes/scaled_attn.h
index 21b9056ba9517c..2917342314fafd 100644
--- a/src/plugins/intel_cpu/src/nodes/scaled_attn.h
+++ b/src/plugins/intel_cpu/src/nodes/scaled_attn.h
@@ -47,7 +47,10 @@ class ScaledDotProductAttention : public Node {
         real_order = {permute_axes[2], permute_axes[0], permute_axes[1], permute_axes[3]};
         return real_order;
     }
-
+    struct SDPAQuantParam {
+        ov::element::Type precision = ov::element::undefined;
+        size_t groupSize = 0;
+    };
     ov::element::Type getKVCachePrecision();
 
 private:
@@ -86,6 +89,8 @@ class ScaledDotProductAttention : public Node {
     // (0, 1, 2, 3) for BHLS
     // (2, 0, 1, 3) for LBHS
     std::vector<size_t> m_kvstate_layout = {2, 0, 1, 3};
+    SDPAQuantParam m_key_quant_param;
+    SDPAQuantParam m_value_quant_param;
 };
 
 }  // namespace node
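
The new SDPAQuantParam struct pairs a target element type with a quantization group size, and the node keeps one instance per key cache and per value cache. To make the role of groupSize concrete, below is a minimal standalone sketch of group-wise affine u8 quantization of one KV-cache row. Everything here (QuantParam, QuantizedGroup, quantize_row_u8) is a hypothetical illustration, not the plugin's actual kernel code, and plain min/max u8 quantization is only one possible scheme the precision field could select.

// Hypothetical sketch (not OpenVINO plugin code): group-wise u8 quantization
// of a single KV-cache row, driven by a {precision, groupSize}-style param.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <vector>

struct QuantParam {
    size_t groupSize = 0;  // 0 means one group spanning the whole row
};

// Each group carries its own scale/zero-point, so an outlier only
// degrades accuracy within its own group rather than the whole row.
struct QuantizedGroup {
    float scale;
    float zero_point;
    std::vector<uint8_t> data;
};

std::vector<QuantizedGroup> quantize_row_u8(const std::vector<float>& row,
                                            const QuantParam& p) {
    const size_t group = p.groupSize ? p.groupSize : row.size();
    std::vector<QuantizedGroup> out;
    for (size_t start = 0; start < row.size(); start += group) {
        const size_t end = std::min(start + group, row.size());
        // Min/max over the group define the affine mapping to [0, 255].
        float lo = row[start], hi = row[start];
        for (size_t i = start; i < end; ++i) {
            lo = std::min(lo, row[i]);
            hi = std::max(hi, row[i]);
        }
        QuantizedGroup g;
        g.scale = (hi > lo) ? (hi - lo) / 255.0f : 1.0f;
        g.zero_point = lo;
        g.data.reserve(end - start);
        for (size_t i = start; i < end; ++i) {
            g.data.push_back(static_cast<uint8_t>(
                std::lround((row[i] - g.zero_point) / g.scale)));
        }
        out.push_back(std::move(g));
    }
    return out;
}

int main() {
    std::vector<float> row = {0.1f, -0.5f, 2.0f, 0.0f, 1.5f, -1.0f, 0.7f, 0.3f};
    QuantParam p{4};  // two groups of four values each
    auto q = quantize_row_u8(row, p);
    std::printf("groups: %zu, first scale: %f\n", q.size(), q[0].scale);
    return 0;
}

Keeping separate m_key_quant_param and m_value_quant_param members allows key and value caches to use different precisions or group sizes, since the two tensors typically have different sensitivity to quantization error.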