Skip to content

Commit

Permalink
[Snippets][CPU] Disable MHA tokenization in LLM
Browse files Browse the repository at this point in the history
  • Loading branch information
a-sidorova committed Jan 22, 2025
1 parent d757efd commit 121697d
Showing 1 changed file with 14 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1031,18 +1031,27 @@ void Transformations::MainSnippets(void) {
}
CPU_REGISTER_PASS_COMMON(snippetsManager, snippets::pass::SnippetsTokenization, tokenization_config);

// - CPU Plugin Subgraph supports f32, bf16, quantized and fp16(on avx_512_core_amx_fp16 target) BRGEMM
const bool isMHASupported =
#if defined(OPENVINO_ARCH_ARM64)
false;
#else
#if defined(OPENVINO_ARCH_X86_64)
// Currently, Snippets don't provide efficient execution for single token inference in LLM case.
// To avoid performance degradations, we disable MHA tokenization into Subgraphs in LLMs'.
// We consider the presence of `ScaledDotProductAttentionWithKVCache` op in the model as a sign that this model is LLM.
const auto ops = model->get_ordered_ops();
const auto is_LLM = std::any_of(ops.cbegin(), ops.cend(), [](const std::shared_ptr<ov::Node>& op) {
return ov::is_type<intel_cpu::ScaledDotProductAttentionWithKVCache>(op);
});
// CPU Plugin Subgraph supports f32, bf16, quantized and fp16(on avx_512_core_amx_fp16 target) BRGEMM
const auto is_infer_prc_supported_by_MHA =
(dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx2) &&
one_of(config.inferencePrecision, ov::element::f32, element::undefined)) ||
(dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core) &&
one_of(config.inferencePrecision, ov::element::bf16, ov::element::f32, element::undefined)) ||
(dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core_amx_fp16) &&
one_of(config.inferencePrecision, ov::element::f16));
const bool isMHASupported = !is_LLM && is_infer_prc_supported_by_MHA;
#else
const bool isMHASupported = false;
#endif

if (!isMHASupported) {
CPU_DISABLE_PASS_COMMON(snippetsManager, snippets::pass::TokenizeMHASnippets);
CPU_DISABLE_PASS_COMMON(snippetsManager, snippets::pass::ExtractReshapesFromMHA);
Expand Down

0 comments on commit 121697d

Please sign in to comment.