From 08aeea716bf456634376021fe702e69befb13db0 Mon Sep 17 00:00:00 2001
From: Alexandra Sidorova
Date: Wed, 22 Jan 2025 07:30:14 +0100
Subject: [PATCH] [Snippets][CPU] Disable MHA tokenization in LLM

---
 .../transformation_pipeline.cpp              | 26 +++++++++++++++++++++-----
 1 file changed, 21 insertions(+), 5 deletions(-)

diff --git a/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp b/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp
index da61917a146db0..ee954c018e6332 100644
--- a/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp
+++ b/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp
@@ -1031,18 +1031,34 @@ void Transformations::MainSnippets(void) {
     }
     CPU_REGISTER_PASS_COMMON(snippetsManager, snippets::pass::SnippetsTokenization, tokenization_config);
 
-    // - CPU Plugin Subgraph supports f32, bf16, quantized and fp16(on avx_512_core_amx_fp16 target) BRGEMM
-    const bool isMHASupported =
-#if defined(OPENVINO_ARCH_ARM64)
-        false;
-#else
+#if defined(OPENVINO_ARCH_X86_64)
+    // Currently, Snippets do not provide efficient execution for single-token inference in the LLM case.
+    // To avoid performance degradations, we disable MHA tokenization into Subgraphs in LLMs.
+    // We consider the presence of a `ScaledDotProductAttentionWithKVCache` op in the model as a sign that
+    // the model is an LLM.
+    const auto is_LLM = [this]() {
+        // Note: the variable `ops` must not outlive the `SnippetsTokenization` execution.
+        // Otherwise, it would extend the lifetime of the ops (since they are stored as shared ptrs), and
+        // they would stay visible in the model during the tokenization passes even after being removed or replaced.
+        const auto ops = model->get_ops();
+        return std::any_of(ops.cbegin(), ops.cend(), [](const std::shared_ptr<ov::Node>& op) {
+            return ov::is_type<ScaledDotProductAttentionWithKVCache>(op);
+        });
+    }();
+
+    // CPU Plugin Subgraph supports f32, bf16, quantized and fp16 (on the avx512_core_amx_fp16 target) BRGEMM
+    const auto is_infer_prc_supported_by_MHA =
         (dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx2) &&
          one_of(config.inferencePrecision, ov::element::f32, element::undefined)) ||
         (dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core) &&
          one_of(config.inferencePrecision, ov::element::bf16, ov::element::f32, element::undefined)) ||
         (dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core_amx_fp16) &&
          one_of(config.inferencePrecision, ov::element::f16));
+    const bool isMHASupported = !is_LLM && is_infer_prc_supported_by_MHA;
+#else
+    const bool isMHASupported = false;
 #endif
+
    if (!isMHASupported) {
        CPU_DISABLE_PASS_COMMON(snippetsManager, snippets::pass::TokenizeMHASnippets);
        CPU_DISABLE_PASS_COMMON(snippetsManager, snippets::pass::ExtractReshapesFromMHA);
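
Reviewer note on the `is_LLM` check: it is written as an immediately-invoked lambda so that the local `ops` vector returned by `model->get_ops()` is destroyed as soon as the check finishes. Because `get_ops()` hands out `shared_ptr`s to the nodes, keeping that vector alive for the rest of `MainSnippets` would keep removed or replaced nodes reachable during the tokenization passes, as the in-code comment explains. Below is a minimal standalone sketch of the same idiom, with hypothetical `Node`/`SdpaNode` stand-ins for `ov::Node` and `ScaledDotProductAttentionWithKVCache` (no OpenVINO dependency) — a sketch of the pattern, not the plugin's actual code:

    #include <algorithm>
    #include <memory>
    #include <vector>

    // Hypothetical stand-ins for ov::Node and the SDPA-with-KV-cache op.
    struct Node { virtual ~Node() = default; };
    struct SdpaNode : Node {};

    int main() {
        // Stand-in for the model's op list; get_ops() returns such a vector by value.
        std::vector<std::shared_ptr<Node>> model_ops;
        model_ops.push_back(std::make_shared<Node>());
        model_ops.push_back(std::make_shared<SdpaNode>());

        // Immediately-invoked lambda: the local copy `ops` (and the extra
        // shared_ptr references it holds) is destroyed as soon as the lambda
        // returns, so later passes do not see artificially kept-alive nodes.
        const bool is_llm = [&model_ops]() {
            const auto ops = model_ops;  // mirrors `model->get_ops()`
            return std::any_of(ops.cbegin(), ops.cend(), [](const std::shared_ptr<Node>& op) {
                return dynamic_cast<const SdpaNode*>(op.get()) != nullptr;  // mirrors ov::is_type<...>(op)
            });
        }();

        return is_llm ? 0 : 1;
    }

The scoping is the point of the idiom: once the immediately-invoked lambda returns, only the model itself holds references to the nodes, which mirrors how the patch avoids extending node lifetimes past the LLM check.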