diff --git a/cmake/features.cmake b/cmake/features.cmake index 386af20211050f..dc8ebeeb9371ad 100644 --- a/cmake/features.cmake +++ b/cmake/features.cmake @@ -52,7 +52,7 @@ ov_dependent_option (ENABLE_GPU_DEBUG_CAPS "enable GPU debug capabilities at run ov_dependent_option (ENABLE_CPU_DEBUG_CAPS "enable CPU debug capabilities at runtime" ON "ENABLE_DEBUG_CAPS;ENABLE_INTEL_CPU" OFF) ov_dependent_option (ENABLE_SNIPPETS_DEBUG_CAPS "enable Snippets debug capabilities at runtime" ON "ENABLE_DEBUG_CAPS" OFF) -ov_dependent_option (ENABLE_SNIPPETS_LIBXSMM_TPP "allow Snippets to use LIBXSMM Tensor Processing Primitives" OFF "ENABLE_INTEL_CPU AND X86_64" OFF) +ov_dependent_option (ENABLE_SNIPPETS_LIBXSMM_TPP "allow Snippets to use LIBXSMM Tensor Processing Primitives" OFF "ENABLE_INTEL_CPU AND (X86_64 OR AARCH64)" OFF) ov_option (ENABLE_PROFILING_ITT "Build with ITT tracing. Optionally configure pre-built ittnotify library though INTEL_VTUNE_DIR variable." OFF) diff --git a/src/plugins/intel_cpu/CMakeLists.txt b/src/plugins/intel_cpu/CMakeLists.txt index f47bc0b5d86fc2..9ba84937b261df 100644 --- a/src/plugins/intel_cpu/CMakeLists.txt +++ b/src/plugins/intel_cpu/CMakeLists.txt @@ -160,6 +160,10 @@ if(ENABLE_CPU_DEBUG_CAPS) add_definitions(-DCPU_DEBUG_CAPS) endif() +if(AARCH64 AND (NOT ANDROID)) + set(ENABLE_SNIPPETS_LIBXSMM_TPP ON) +endif() + if (ENABLE_SNIPPETS_LIBXSMM_TPP) # Note: LIBXSMM_DEFAULT_CONFIG needed so libxsmm_config can be included without issues add_definitions(-DSNIPPETS_LIBXSMM_TPP -DLIBXSMM_DEFAULT_CONFIG) @@ -198,7 +202,9 @@ if(NOT X86_64) ${CMAKE_CURRENT_SOURCE_DIR}/src/nodes/kernels/x64/* ${CMAKE_CURRENT_SOURCE_DIR}/src/emitters/plugin/x64/* ${CMAKE_CURRENT_SOURCE_DIR}/src/emitters/snippets/x64/* - ${CMAKE_CURRENT_SOURCE_DIR}/src/transformations/cpu_opset/x64/*) + ${CMAKE_CURRENT_SOURCE_DIR}/src/emitters/tpp/x64/* + ${CMAKE_CURRENT_SOURCE_DIR}/src/transformations/cpu_opset/x64/* + ${CMAKE_CURRENT_SOURCE_DIR}/src/transformations/tpp/x64/*) endif() if (AARCH64) @@ -208,7 +214,9 @@ endif() if(NOT (AARCH64 OR ARM)) list(APPEND EXCLUDE_PATHS ${CMAKE_CURRENT_SOURCE_DIR}/src/transformations/cpu_opset/arm/* + ${CMAKE_CURRENT_SOURCE_DIR}/src/transformations/tpp/aarch64/* ${CMAKE_CURRENT_SOURCE_DIR}/src/emitters/plugin/aarch64/* + ${CMAKE_CURRENT_SOURCE_DIR}/src/emitters/tpp/aarch64/* ${CMAKE_CURRENT_SOURCE_DIR}/src/nodes/executors/aarch64/* ${CMAKE_CURRENT_SOURCE_DIR}/src/nodes/kernels/aarch64/*) endif() diff --git a/src/plugins/intel_cpu/src/emitters/snippets/aarch64/cpu_generator.cpp b/src/plugins/intel_cpu/src/emitters/snippets/aarch64/cpu_generator.cpp index a915fb0fe17e21..c6b66a8cd7f215 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/aarch64/cpu_generator.cpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/aarch64/cpu_generator.cpp @@ -1,4 +1,4 @@ -// Copyright (C) 2024 Intel Corporation +// Copyright (C) 2024-2025 Intel Corporation // SPDX-License-Identifier: Apache-2.0 // @@ -10,6 +10,7 @@ #include "emitters/snippets/aarch64/jit_kernel_emitter.hpp" #include "emitters/snippets/aarch64/jit_loop_emitters.hpp" #include "emitters/snippets/aarch64/jit_memory_emitters.hpp" +#include "emitters/snippets/cpu_kernel_executor_table.hpp" #include "emitters/snippets/cpu_runtime_configurator.hpp" #include "emitters/utils.hpp" #include "jit_snippets_emitters.hpp" @@ -24,12 +25,17 @@ #include "transformations/cpu_opset/common/op/swish_cpu.hpp" #include "transformations/snippets/common/op/fused_mul_add.hpp" +#ifdef SNIPPETS_LIBXSMM_TPP +# include "emitters/tpp/aarch64/jit_brgemm_emitter.hpp" +# include "transformations/tpp/common/op/brgemm.hpp" +#endif + namespace ov { -#define CREATE_SNIPPETS_EMITTER(e_type) \ +#define CREATE_SNIPPETS_EMITTER(e_type, ...) \ { \ [this](const snippets::lowered::ExpressionPtr& expr) -> std::shared_ptr { \ - return std::make_shared(h.get(), isa, expr); \ + return std::make_shared(h.get(), isa, expr, ##__VA_ARGS__); \ }, \ [](const std::shared_ptr& n) -> std::set> { \ return e_type::get_supported_precisions(n); \ @@ -202,6 +208,12 @@ CPUTargetMachine::CPUTargetMachine(dnnl::impl::cpu::aarch64::cpu_isa_t host_isa, jitters[ov::intel_cpu::SwishNode::get_type_info_static()] = CREATE_CPU_EMITTER(jit_swish_emitter); jitters[ov::op::v0::Tanh::get_type_info_static()] = CREATE_CPU_EMITTER(jit_tanh_emitter); +#ifdef SNIPPETS_LIBXSMM_TPP + // brgemm + jitters[ov::intel_cpu::tpp::op::BrgemmTPP::get_type_info_static()] = + CREATE_SNIPPETS_EMITTER(jit_brgemm_emitter, configurator->get_kernel_executor_table(), compiled_kernel_cache); +#endif + // control flow jitters[snippets::op::KernelStatic::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(jit_kernel_static_emitter); jitters[snippets::op::KernelDynamic::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(jit_kernel_dynamic_emitter); diff --git a/src/plugins/intel_cpu/src/emitters/snippets/brgemm_base.cpp b/src/plugins/intel_cpu/src/emitters/snippets/brgemm_base.cpp new file mode 100644 index 00000000000000..548d8cc80a47e5 --- /dev/null +++ b/src/plugins/intel_cpu/src/emitters/snippets/brgemm_base.cpp @@ -0,0 +1,233 @@ +// Copyright (C) 2020-2025 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "brgemm_base.hpp" + +#include "common/utils.hpp" +#include "dnnl_extension_utils.h" +#include "utils/general_utils.h" + +#define PRINT(X) ss << #X << " = " << X << "\n" +#define EQ(X) X == rhs.X +#define HASH(X) seed = dnnl::impl::hash_combine(seed, X) + +namespace ov { +namespace intel_cpu { + +bool BrgemmBaseKernelConfig::is_completed() const { + return !one_of(0, m_M, m_N, m_K, m_LDA, m_LDB, m_LDC) || is_empty(); +} + +bool BrgemmBaseKernelConfig::is_empty() const { + return everyone_is(0, m_M, m_N, m_K, m_LDA, m_LDB, m_LDC, m_beta); +} + +bool BrgemmBaseKernelConfig::operator==(const BrgemmBaseKernelConfig& rhs) const { + return EQ(m_beta) && EQ(m_M) && EQ(m_N) && EQ(m_K) && EQ(m_LDA) && EQ(m_LDB) && EQ(m_LDC); +} + +void BrgemmBaseKernelConfig::update(int64_t M, int64_t N, int64_t K, float beta) { + // If M is zero, it means that Brgemm won't be executed (in Loop with work_amount = 0, for example) + // To process this case, we have to make this Config as empty (nullify runtime parameters) + if (one_of(0, M, N, K)) { + m_M = 0; + m_N = 0; + m_K = 0; + m_beta = 0; + } else { + m_M = M; + m_N = N; + m_K = K; + m_beta = beta; + } +} + +void BrgemmBaseKernelConfig::update(int64_t M, + int64_t N, + int64_t K, + int64_t LDA, + int64_t LDB, + int64_t LDC, + float beta) { + // If M is zero, it means that Brgemm won't be executed (in Loop with work_amount = 0, for example) + // To process this case, we have to make this Config as empty (nullify runtime parameters) + if (one_of(0, M, N, K)) { + m_M = 0; + m_N = 0; + m_K = 0; + m_LDA = 0; + m_LDB = 0; + m_LDC = 0; + m_beta = 0; + } else { + m_M = M; + m_N = N; + m_K = K; + m_LDA = LDA; + m_LDB = LDB; + m_LDC = LDC; + m_beta = beta; + } +} + +size_t BrgemmBaseKernelConfig::compute_hash() const { + size_t seed = 0; + HASH(m_M); + HASH(m_N); + HASH(m_K); + HASH(m_LDA); + HASH(m_LDB); + HASH(m_LDC); + HASH(m_beta); + return seed; +} + +#ifdef SNIPPETS_DEBUG_CAPS +std::string BrgemmBaseKernelConfig::to_string() const { + std::stringstream ss; + PRINT(m_M); + PRINT(m_N); + PRINT(m_K); + PRINT(m_LDA); + PRINT(m_LDB); + PRINT(m_LDC); + PRINT(m_beta); + return ss.str(); +} +#endif + +float BrgemmBaseKernelExecutor::get_beta(const ov::snippets::lowered::LoopManagerPtr& loop_manager, + int loop_id, + const ov::snippets::lowered::ExpandedLoopInfoPtr& current_expanded_loop_info) { + // Find all Expanded loops with the same Unified loop information -> they were decomposed from this Unified Loop. + // Note that LoopInfo are normalized and sorted (due to NormalizedLoopIDs pass). + // It means that previous executed Loops have Loop ID less the current Loop ID. + // - If there is executed Loop (work_amount > 0) and evaluated before the current -> the current Brgemm should have + // `beta = 1`. + // - If there is not this Loop -> the current executed Brgemm should have `beta = 0`. + if (loop_id > 0) { + const auto& current_unified_loop_info = current_expanded_loop_info->get_unified_loop_info(); + // Check the previous Loops + --loop_id; + while (loop_id >= 0) { + const auto& expanded_loop_info = + loop_manager->get_loop_info(loop_id); + if (expanded_loop_info->get_unified_loop_info() != current_unified_loop_info) + return 0; + if (expanded_loop_info->get_work_amount() > 0) { + // there is previous executed Brgemm with `beta = 0` -> the current Brgemm should have `beta = 1` + return 1; + } + --loop_id; + } + } + return 0; +} + +void BrgemmBaseKernelExecutor::update_config(const ov::snippets::lowered::ExpressionPtr& expr, + const ov::snippets::lowered::LinearIRCPtr& linear_ir, + BrgemmBaseKernelConfig& config) { + const auto& input_pds = expr->get_input_port_descriptors(); + const auto& output_pds = expr->get_output_port_descriptors(); + OV_CPU_JIT_EMITTER_ASSERT((input_pds.size() == 2 || input_pds.size() == 3) && output_pds.size() == 1, + "Invalid number of in/out port descriptors"); + + const auto in0_shape = snippets::utils::get_planar_vdims(input_pds[0]->get_shape(), input_pds[0]->get_layout()); + const auto in1_shape = snippets::utils::get_planar_vdims(input_pds[1]->get_shape(), input_pds[1]->get_layout()); + auto in0_subtensor = input_pds[0]->get_subtensor(); + auto in1_subtensor = input_pds[1]->get_subtensor(); + + // Need to update M, K, N + // 1. If the original value in subtensor is `FULL_DIM`, it means that + // Brgemm block should process full tensor by this dim -> take dimension from shape + // 2. Otherwise, Brgemm block processes part of the tensor by this dim + // (there is blocking by this dimension) -> take from Loop increment + + auto M = *++in0_subtensor.rbegin(); + auto K = *in0_subtensor.rbegin(); + auto N = *in1_subtensor.rbegin(); + + size_t loop_idx = 0; + const auto& loop_ids = expr->get_loop_ids(); + const auto& loop_manager = linear_ir->get_loop_manager(); + auto get_loop_info = [&]() { + OPENVINO_ASSERT(loop_idx < loop_ids.size(), "Loop is missed"); + return loop_manager->get_loop_info(loop_ids[loop_idx++]); + }; + + /* ------- Dimension M ----------*/ + if (ov::snippets::utils::is_full_dim_value(M)) { + M = *++in0_shape.rbegin(); + } else { + const auto& current_expanded_loop_info = get_loop_info(); + const auto& in_ports = current_expanded_loop_info->get_input_ports(); + const auto& out_ports = current_expanded_loop_info->get_output_ports(); + // Quick validation check: Should we check that port is really Brgemm port? + // If BrgemmCopyB in the Loop by M -> first input port will be BrgemmCopyB with `incremented=false` + // to avoid extra checks, we validate only first input port + auto check_port = [&](const ov::snippets::lowered::LoopPort& p) { + return p.get_dim_idx() == 1 && p.is_processed(); + }; + OPENVINO_ASSERT( + in_ports.size() > 1 && check_port(in_ports[0]) && out_ports.size() == 1 && check_port(out_ports[0]), + "Incorrect Loop by Brgemm dimension M"); + M = current_expanded_loop_info->get_work_amount() > 0 ? current_expanded_loop_info->get_increment() : 0; + input_pds[0]->set_subtensor_dim(1, M); + output_pds[0]->set_subtensor_dim(1, M); + } + + /* ------- Dimension N ----------*/ + if (ov::snippets::utils::is_full_dim_value(N)) { + N = *in1_shape.rbegin(); + } else { + const auto& current_expanded_loop_info = get_loop_info(); + const auto& in_ports = current_expanded_loop_info->get_input_ports(); + const auto& out_ports = current_expanded_loop_info->get_output_ports(); + // Quick validation check: Should we check that port is really Brgemm port? + auto check_port = [&](const ov::snippets::lowered::LoopPort& p) { + return p.get_dim_idx() == 0 && p.is_processed(); + }; + OPENVINO_ASSERT(in_ports.size() >= 2 && !in_ports.front().is_processed() && + std::all_of(in_ports.cbegin() + 1, in_ports.cend(), check_port) && out_ports.size() == 1 && + check_port(out_ports.back()), + "Incorrect Loop by Brgemm dimension N"); + N = current_expanded_loop_info->get_work_amount() > 0 ? current_expanded_loop_info->get_increment() : 0; + input_pds[1]->set_subtensor_dim(0, N); + output_pds[0]->set_subtensor_dim(0, N); + } + + /* ------- Dimension K ----------*/ + // 1. If Brgemm block processes full dimension K -> `beta = 0` + // 2. If Brgemm block processes part of the dimension K (there is blocking), need to find + // the most first executed Brgemm Block in Loops which iterate through dimension K (work_amount > 0). + // First of them will have `beta = 0`, other - `beta = 1` + float beta = 0; + if (ov::snippets::utils::is_full_dim_value(K)) { + K = *in0_shape.rbegin(); + } else { + const auto& current_expanded_loop_info = get_loop_info(); + const auto& in_ports = current_expanded_loop_info->get_input_ports(); + const auto& out_ports = current_expanded_loop_info->get_output_ports(); + // Quick validation check: Should we check that port is really Brgemm port? + OPENVINO_ASSERT(in_ports.size() >= 2 && in_ports.front().get_dim_idx() == 0 && + in_ports.front().is_processed() && in_ports.back().get_dim_idx() == 1 && + in_ports.back().is_processed() && out_ports.size() == 1 && + !out_ports.front().is_processed(), + "Incorrect Loop by Brgemm dimension K"); + K = current_expanded_loop_info->get_work_amount() > 0 ? current_expanded_loop_info->get_increment() : 0; + input_pds[0]->set_subtensor_dim(0, K); + input_pds[1]->set_subtensor_dim(1, K); + if (K > 0) + beta = get_beta(loop_manager, static_cast(loop_ids.back()), current_expanded_loop_info); + } + + config.update(static_cast(M), static_cast(N), static_cast(K), beta); +} + +#undef PRINT +#undef EQ +#undef HASH + +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/emitters/snippets/brgemm_base.hpp b/src/plugins/intel_cpu/src/emitters/snippets/brgemm_base.hpp new file mode 100644 index 00000000000000..8efd651112f369 --- /dev/null +++ b/src/plugins/intel_cpu/src/emitters/snippets/brgemm_base.hpp @@ -0,0 +1,80 @@ +// Copyright (C) 2020-2025 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "emitters/snippets/cpu_kernel_executor_table.hpp" +#include "emitters/utils.hpp" +#include "openvino/core/type/element_type.hpp" +#include "snippets/lowered/loop_info.hpp" +#include "snippets/lowered/loop_manager.hpp" +#include "utils/general_utils.h" + +namespace ov { +namespace intel_cpu { + +struct BrgemmBaseKernelConfig : public snippets::KernelExecutorBase::GenericConfig { +public: + BrgemmBaseKernelConfig() = default; + + bool is_completed() const override; + bool is_empty() const; + + void update(int64_t M, int64_t N, int64_t K, float beta); + void update(int64_t M, int64_t N, int64_t K, int64_t LDA, int64_t LDB, int64_t LDC, float beta); + + bool operator==(const BrgemmBaseKernelConfig& rhs) const; + bool operator!=(const BrgemmBaseKernelConfig& rhs) const { + return !(*this == rhs); + } + + int64_t get_M() const { + return m_M; + } + int64_t get_N() const { + return m_N; + } + int64_t get_K() const { + return m_K; + } + float get_beta() const { + return m_beta; + } + int64_t get_LDA() const { + return m_LDA; + } + int64_t get_LDB() const { + return m_LDB; + } + int64_t get_LDC() const { + return m_LDC; + } + +#ifdef SNIPPETS_DEBUG_CAPS + std::string to_string() const override; +#endif + +protected: + size_t compute_hash() const; + + int64_t m_M{0}, m_N{0}, m_K{0}, m_LDA{0}, m_LDB{0}, m_LDC{0}; + float m_beta{0}; +}; + +class BrgemmBaseKernelExecutor { +public: + virtual ~BrgemmBaseKernelExecutor() = default; + +protected: + static float get_beta(const ov::snippets::lowered::LoopManagerPtr& loop_manager, + int loop_id, + const ov::snippets::lowered::ExpandedLoopInfoPtr& current_expanded_loop_info); + + static void update_config(const ov::snippets::lowered::ExpressionPtr& expr, + const ov::snippets::lowered::LinearIRCPtr& linear_ir, + BrgemmBaseKernelConfig& config); +}; + +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/emitters/snippets/x64/cpu_generator.cpp b/src/plugins/intel_cpu/src/emitters/snippets/x64/cpu_generator.cpp index 5e4a8992aa7165..1e2c5706b1dcb9 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/x64/cpu_generator.cpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/x64/cpu_generator.cpp @@ -44,10 +44,10 @@ # include "emitters/tpp/x64/jit_eltwise_emitters.hpp" # include "emitters/tpp/x64/jit_equation_emitter.hpp" # include "emitters/tpp/x64/jit_scalar_emitter.hpp" -# include "transformations/tpp/x64/op/brgemm.hpp" +# include "transformations/tpp/common/op/brgemm.hpp" +# include "transformations/tpp/common/op/modifiers.hpp" # include "transformations/tpp/x64/op/eltwise.hpp" # include "transformations/tpp/x64/op/equation.hpp" -# include "transformations/tpp/x64/op/modifiers.hpp" # include "transformations/tpp/x64/op/reduce.hpp" # include "transformations/tpp/x64/op/scalar.hpp" // Note: for reference implementations @@ -295,7 +295,8 @@ intel_cpu::CPUTargetMachine::CPUTargetMachine(dnnl::impl::cpu::x64::cpu_isa_t ho #endif #ifdef SNIPPETS_LIBXSMM_TPP - jitters[intel_cpu::tpp::op::BrgemmTPP::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(BrgemmTppEmitter); + jitters[intel_cpu::tpp::op::BrgemmTPP::get_type_info_static()] = + CREATE_SNIPPETS_EMITTER(BrgemmTppEmitter, configurator->get_kernel_executor_table(), compiled_kernel_cache); jitters[intel_cpu::tpp::op::Add::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(BinaryEltwiseTppEmitter); jitters[intel_cpu::tpp::op::Subtract::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(BinaryEltwiseTppEmitter); jitters[intel_cpu::tpp::op::Multiply::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(BinaryEltwiseTppEmitter); diff --git a/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_brgemm_emitter.hpp b/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_brgemm_emitter.hpp index 59508f46154f28..91eec18009b42c 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_brgemm_emitter.hpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_brgemm_emitter.hpp @@ -5,7 +5,7 @@ #pragma once #include "emitters/plugin/x64/jit_emitter.hpp" -#include "emitters/snippets/x64/kernel_executors/brgemm_base.hpp" +#include "emitters/snippets/brgemm_base.hpp" #include "jit_binary_call_emitter.hpp" namespace ov { diff --git a/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm.cpp b/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm.cpp index 58a31a1804782a..ec6f77bdf76644 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm.cpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm.cpp @@ -6,7 +6,6 @@ #include "common/utils.hpp" #include "dnnl_extension_utils.h" -#include "snippets/lowered/pass/insert_specific_iterations.hpp" #include "transformations/snippets/x64/op/brgemm_cpu.hpp" #include "transformations/snippets/x64/op/brgemm_utils.hpp" @@ -21,7 +20,7 @@ BrgemmKernelConfig::BrgemmKernelConfig(const element::Type& in0_dtype, const element::Type& in1_dtype, bool is_with_comp, dnnl::impl::cpu::x64::cpu_isa_t primitive_isa) - : BrgemmBaseKernelConfig(), + : BrgemmBaseKernelConfig_x64(), m_static_params(std::make_shared(in0_dtype, in1_dtype, is_with_comp, primitive_isa)) { m_hash = compute_hash(); } @@ -78,7 +77,7 @@ std::shared_ptr BrgemmKernelExecutor::compile_kernel(const void BrgemmKernelExecutor::update_config(const ov::snippets::lowered::ExpressionPtr& expr, const ov::snippets::lowered::LinearIRCPtr& linear_ir, BrgemmKernelConfig& config) const { - return BrgemmBaseKernelExecutor::update_config(expr, linear_ir, config); + return BrgemmBaseKernelExecutor_x64::update_config(expr, linear_ir, config); } void BrgemmKernelExecutor::execute(const BrgemmKernelExecutor* executor, call_args* args) { diff --git a/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm.hpp b/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm.hpp index 9cc17049c4d3ae..69c8ca114c7912 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm.hpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm.hpp @@ -4,12 +4,12 @@ #pragma once -#include "brgemm_base.hpp" +#include "emitters/snippets/x64/kernel_executors/brgemm_base.hpp" namespace ov { namespace intel_cpu { -struct BrgemmKernelConfig : public BrgemmBaseKernelConfig { +struct BrgemmKernelConfig : public BrgemmBaseKernelConfig_x64 { public: BrgemmKernelConfig(const element::Type& in0_dtype, const element::Type& in1_dtype, @@ -59,7 +59,7 @@ struct BrgemmCompiledKernel { std::shared_ptr brgemm_kernel = nullptr; }; -class BrgemmKernelExecutor : public BrgemmBaseKernelExecutor, +class BrgemmKernelExecutor : public BrgemmBaseKernelExecutor_x64, public CPUKernelExecutor { public: struct call_args { diff --git a/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm_amx.cpp b/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm_amx.cpp index 12c52d43b2c4b8..6d2e67dc7330cc 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm_amx.cpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm_amx.cpp @@ -24,7 +24,7 @@ namespace intel_cpu { BrgemmAMXKernelConfig::BrgemmAMXKernelConfig(const element::Type& in0_dtype, const element::Type& in1_dtype, dnnl::impl::cpu::x64::cpu_isa_t primitive_isa) - : BrgemmBaseKernelConfig(), + : BrgemmBaseKernelConfig_x64(), m_static_params(std::make_shared(in0_dtype, in1_dtype, primitive_isa)) { m_hash = compute_hash(); } @@ -117,7 +117,7 @@ std::shared_ptr BrgemmAMXKernelExecutor::compile_kernel const auto& cache = m_kernel_cache.lock(); OPENVINO_ASSERT(cache, "Invalid kernel cache pointer in BrgemmAMXKernelExecutor::compile_kernel()"); - auto brgemm_key = [&config](dnnl_dim_t K, dnnl_dim_t LDA, float beta) { + auto brgemm_key = [&config](int64_t K, int64_t LDA, float beta) { auto key = config; key.update(config.get_M(), config.get_N(), K, LDA, config.get_LDB(), config.get_LDC(), beta); return key; @@ -223,7 +223,7 @@ void BrgemmAMXKernelExecutor::create_brgemm_copy_a_kernel( void BrgemmAMXKernelExecutor::update_config(const ov::snippets::lowered::ExpressionPtr& expr, const ov::snippets::lowered::LinearIRCPtr& linear_ir, BrgemmAMXKernelConfig& config) const { - return BrgemmBaseKernelExecutor::update_config(expr, linear_ir, config); + return BrgemmBaseKernelExecutor_x64::update_config(expr, linear_ir, config); } void BrgemmAMXKernelExecutor::configure_tiles_if_needed(amx_tile_config_t* config, diff --git a/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm_amx.hpp b/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm_amx.hpp index 733295ec995583..614651a8897c38 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm_amx.hpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm_amx.hpp @@ -4,18 +4,16 @@ #pragma once -#include #include -#include "brgemm_base.hpp" #include "emitters/plugin/x64/jit_emitter.hpp" -#include "emitters/snippets/cpu_kernel_executor_table.hpp" #include "emitters/snippets/jit_snippets_call_args.hpp" +#include "emitters/snippets/x64/kernel_executors/brgemm_base.hpp" namespace ov { namespace intel_cpu { -struct BrgemmAMXKernelConfig : public BrgemmBaseKernelConfig { +struct BrgemmAMXKernelConfig : public BrgemmBaseKernelConfig_x64 { public: BrgemmAMXKernelConfig(const element::Type& in0_dtype, const element::Type& in1_dtype, @@ -75,7 +73,7 @@ struct BrgemmAMXCompiledKernel { std::shared_ptr brgemm_copy_a_kernel{nullptr}; }; -class BrgemmAMXKernelExecutor : public BrgemmBaseKernelExecutor, +class BrgemmAMXKernelExecutor : public BrgemmBaseKernelExecutor_x64, public CPUKernelExecutor { public: struct call_args { diff --git a/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm_base.cpp b/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm_base.cpp index 8b3ed792fce535..c9c96797b5eded 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm_base.cpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm_base.cpp @@ -1,4 +1,4 @@ -// Copyright (C) 2020-2024 Intel Corporation +// Copyright (C) 2020-2025 Intel Corporation // SPDX-License-Identifier: Apache-2.0 // @@ -9,7 +9,6 @@ #include "transformations/snippets/x64/op/brgemm_cpu.hpp" #include "transformations/snippets/x64/op/brgemm_utils.hpp" -#define DIM_CAST(X) static_cast(X) #define DTYPE_CAST(X) static_cast(DnnlExtensionUtils::ElementTypeToDataType(X)) #define PRINT(X) ss << #X << " = " << X << "\n" #define EQ(X) X == rhs.X @@ -22,77 +21,45 @@ using namespace dnnl::impl::cpu::x64; namespace ov { namespace intel_cpu { -bool BrgemmBaseKernelConfig::is_completed() const { - return !utils::one_of(0, m_M, m_N, m_K, m_LDA, m_LDB, m_LDC) || is_empty(); -} - -bool BrgemmBaseKernelConfig::is_empty() const { - return everyone_is(0, m_M, m_N, m_K, m_LDA, m_LDB, m_LDC, m_beta); -} - -bool BrgemmBaseKernelConfig::operator==(const BrgemmBaseKernelConfig& rhs) const { - return EQ(m_hash) && EQ(m_beta) && EQ(m_M) && EQ(m_N) && EQ(m_K) && EQ(m_LDA) && EQ(m_LDB) && EQ(m_LDC) && +bool BrgemmBaseKernelConfig_x64::operator==(const BrgemmBaseKernelConfig_x64& rhs) const { + return BrgemmBaseKernelConfig::operator==(rhs) && (EQ(get_static_params()) || *get_static_params() == *(rhs.get_static_params())); } -void BrgemmBaseKernelConfig::update(dnnl_dim_t M, - dnnl_dim_t N, - dnnl_dim_t K, - dnnl_dim_t LDA, - dnnl_dim_t LDB, - dnnl_dim_t LDC, - float beta) { - // If M is zero, it means that Brgemm won't be executed (in Loop with work_amount = 0, for example) - // To process this case, we have to make this Config as empty (nullify runtime parameters) - if (utils::one_of(0, M, N, K)) { - m_M = 0; - m_N = 0; - m_K = 0; - m_LDA = 0; - m_LDB = 0; - m_LDC = 0; - m_beta = 0; - } else { - m_M = M; - m_N = N; - m_K = K; - m_LDA = LDA; - m_LDB = LDB; - m_LDC = LDC; - m_beta = beta; - } - m_hash = compute_hash(); -} - -size_t BrgemmBaseKernelConfig::compute_hash() const { +size_t BrgemmBaseKernelConfig_x64::compute_hash() const { size_t seed = get_static_params()->hash(); - HASH(m_M); - HASH(m_N); - HASH(m_K); - HASH(m_LDA); - HASH(m_LDB); - HASH(m_LDC); - HASH(m_beta); + HASH(BrgemmBaseKernelConfig::compute_hash()); return seed; } -BrgemmBaseKernelConfig::StaticBaseParams::StaticBaseParams(const element::Type& in0_dtype, - const element::Type& in1_dtype, - cpu_isa_t primitive_isa, - size_t hash_seed) +void BrgemmBaseKernelConfig_x64::update(int64_t M, + int64_t N, + int64_t K, + int64_t LDA, + int64_t LDB, + int64_t LDC, + float beta) { + BrgemmBaseKernelConfig::update(M, N, K, LDA, LDB, LDC, beta); + m_hash = compute_hash(); +} + +BrgemmBaseKernelConfig_x64::StaticBaseParams::StaticBaseParams(const element::Type& in0_dtype, + const element::Type& in1_dtype, + cpu_isa_t primitive_isa, + size_t hash_seed) : dt_in0(DTYPE_CAST(in0_dtype)), dt_in1(DTYPE_CAST(in1_dtype)), isa(primitive_isa), m_hash(compute_hash(hash_seed, dt_in0, dt_in1, isa)) {} -bool BrgemmBaseKernelConfig::StaticBaseParams::operator==(const StaticBaseParams& rhs) const { +bool BrgemmBaseKernelConfig_x64::StaticBaseParams::operator==(const StaticBaseParams& rhs) const { return EQ(hash()) && EQ(dt_in0) && EQ(dt_in1) && EQ(isa); } -size_t BrgemmBaseKernelConfig::StaticBaseParams::compute_hash(size_t hash_seed, - dnnl_data_type_t dt_in0, - dnnl_data_type_t dt_in1, - cpu_isa_t isa) { +size_t BrgemmBaseKernelConfig_x64::StaticBaseParams::compute_hash(size_t hash_seed, + dnnl_data_type_t dt_in0, + dnnl_data_type_t dt_in1, + cpu_isa_t isa) { size_t seed = hash_seed; HASH(dt_in0); HASH(dt_in1); @@ -101,7 +68,7 @@ size_t BrgemmBaseKernelConfig::StaticBaseParams::compute_hash(size_t hash_seed, } #ifdef SNIPPETS_DEBUG_CAPS -std::string BrgemmBaseKernelConfig::StaticBaseParams::to_string() const { +std::string BrgemmBaseKernelConfig_x64::StaticBaseParams::to_string() const { std::stringstream ss; PRINT(dt_in0); PRINT(dt_in1); @@ -109,171 +76,46 @@ std::string BrgemmBaseKernelConfig::StaticBaseParams::to_string() const { return ss.str(); } -std::string BrgemmBaseKernelConfig::to_string() const { +std::string BrgemmBaseKernelConfig_x64::to_string() const { std::stringstream ss; ss << get_static_params()->to_string() << "\n"; - PRINT(m_M); - PRINT(m_N); - PRINT(m_K); - PRINT(m_LDA); - PRINT(m_LDB); - PRINT(m_LDC); - PRINT(m_beta); + ss << BrgemmBaseKernelConfig::to_string() << "\n"; return ss.str(); } #endif -float BrgemmBaseKernelExecutor::get_beta(const ov::snippets::lowered::LoopManagerPtr& loop_manager, - int loop_id, - const ov::snippets::lowered::ExpandedLoopInfoPtr& current_expanded_loop_info) { - // Find all Expanded loops with the same Unified loop information -> they were decomposed from this Unified Loop. - // Note that LoopInfo are normalized and sorted (due to NormalizedLoopIDs pass). - // It means that previous executed Loops have Loop ID less the current Loop ID. - // - If there is executed Loop (work_amount > 0) and evaluated before the current -> the current Brgemm should have - // `beta = 1`. - // - If there is not this Loop -> the current executed Brgemm should have `beta = 0`. - if (loop_id > 0) { - const auto& current_unified_loop_info = current_expanded_loop_info->get_unified_loop_info(); - // Check the previous Loops - --loop_id; - while (loop_id >= 0) { - const auto& expanded_loop_info = - loop_manager->get_loop_info(loop_id); - if (expanded_loop_info->get_unified_loop_info() != current_unified_loop_info) - return 0; - if (expanded_loop_info->get_work_amount() > 0) { - // there is previous executed Brgemm with `beta = 0` -> the current Brgemm should have `beta = 1` - return 1; - } - --loop_id; - } - } - return 0; -} - -void BrgemmBaseKernelExecutor::update_config(const ov::snippets::lowered::ExpressionPtr& expr, - const ov::snippets::lowered::LinearIRCPtr& linear_ir, - BrgemmBaseKernelConfig& config) { - const auto& input_pds = expr->get_input_port_descriptors(); - const auto& output_pds = expr->get_output_port_descriptors(); - OV_CPU_JIT_EMITTER_ASSERT((input_pds.size() == 2 || input_pds.size() == 3) && output_pds.size() == 1, - "Invalid number of in/out port descriptors"); - - const auto in0_shape = snippets::utils::get_planar_vdims(input_pds[0]->get_shape(), input_pds[0]->get_layout()); - const auto in1_shape = snippets::utils::get_planar_vdims(input_pds[1]->get_shape(), input_pds[1]->get_layout()); - auto in0_subtensor = input_pds[0]->get_subtensor(); - auto in1_subtensor = input_pds[1]->get_subtensor(); - - // Need to update M, K, N - // 1. If the original value in subtensor is `FULL_DIM`, it means that - // Brgemm block should process full tensor by this dim -> take dimension from shape - // 2. Otherwise, Brgemm block processes part of the tensor by this dim - // (there is blocking by this dimension) -> take from Loop increment - - auto M = *++in0_subtensor.rbegin(); - auto K = *in0_subtensor.rbegin(); - auto N = *in1_subtensor.rbegin(); - - size_t loop_idx = 0; - const auto& loop_ids = expr->get_loop_ids(); - const auto& loop_manager = linear_ir->get_loop_manager(); - auto get_loop_info = [&]() { - OPENVINO_ASSERT(loop_idx < loop_ids.size(), "Loop is missed"); - return loop_manager->get_loop_info(loop_ids[loop_idx++]); - }; - - /* ------- Dimension M ----------*/ - if (ov::snippets::utils::is_full_dim_value(M)) { - M = *++in0_shape.rbegin(); - } else { - const auto& current_expanded_loop_info = get_loop_info(); - const auto& in_ports = current_expanded_loop_info->get_input_ports(); - const auto& out_ports = current_expanded_loop_info->get_output_ports(); - // Quick validation check: Should we check that port is really Brgemm port? - // If BrgemmCopyB in the Loop by M -> first input port will be BrgemmCopyB with `incremented=false` - // to avoid extra checks, we validate only first input port - auto check_port = [&](const ov::snippets::lowered::LoopPort& p) { - return p.get_dim_idx() == 1 && p.is_processed(); - }; - OPENVINO_ASSERT( - in_ports.size() > 1 && check_port(in_ports[0]) && out_ports.size() == 1 && check_port(out_ports[0]), - "Incorrect Loop by Brgemm dimension M"); - M = current_expanded_loop_info->get_work_amount() > 0 ? current_expanded_loop_info->get_increment() : 0; - input_pds[0]->set_subtensor_dim(1, M); - output_pds[0]->set_subtensor_dim(1, M); - } - - /* ------- Dimension N ----------*/ - if (ov::snippets::utils::is_full_dim_value(N)) { - N = *in1_shape.rbegin(); - } else { - const auto& current_expanded_loop_info = get_loop_info(); - const auto& in_ports = current_expanded_loop_info->get_input_ports(); - const auto& out_ports = current_expanded_loop_info->get_output_ports(); - // Quick validation check: Should we check that port is really Brgemm port? - auto check_port = [&](const ov::snippets::lowered::LoopPort& p) { - return p.get_dim_idx() == 0 && p.is_processed(); - }; - OPENVINO_ASSERT(in_ports.size() >= 2 && !in_ports.front().is_processed() && - std::all_of(in_ports.cbegin() + 1, in_ports.cend(), check_port) && out_ports.size() == 1 && - check_port(out_ports.back()), - "Incorrect Loop by Brgemm dimension N"); - N = current_expanded_loop_info->get_work_amount() > 0 ? current_expanded_loop_info->get_increment() : 0; - input_pds[1]->set_subtensor_dim(0, N); - output_pds[0]->set_subtensor_dim(0, N); - } - - /* ------- Dimension K ----------*/ - // 1. If Brgemm block processes full dimension K -> `beta = 0` - // 2. If Brgemm block processes part of the dimension K (there is blocking), need to find - // the most first executed Brgemm Block in Loops which iterate through dimension K (work_amount > 0). - // First of them will have `beta = 0`, other - `beta = 1` - float beta = 0; - if (ov::snippets::utils::is_full_dim_value(K)) { - K = *in0_shape.rbegin(); - } else { - const auto& current_expanded_loop_info = get_loop_info(); - const auto& in_ports = current_expanded_loop_info->get_input_ports(); - const auto& out_ports = current_expanded_loop_info->get_output_ports(); - // Quick validation check: Should we check that port is really Brgemm port? - OPENVINO_ASSERT(in_ports.size() >= 2 && in_ports.front().get_dim_idx() == 0 && - in_ports.front().is_processed() && in_ports.back().get_dim_idx() == 1 && - in_ports.back().is_processed() && out_ports.size() == 1 && - !out_ports.front().is_processed(), - "Incorrect Loop by Brgemm dimension K"); - K = current_expanded_loop_info->get_work_amount() > 0 ? current_expanded_loop_info->get_increment() : 0; - input_pds[0]->set_subtensor_dim(0, K); - input_pds[1]->set_subtensor_dim(1, K); - if (K > 0) - beta = get_beta(loop_manager, static_cast(loop_ids.back()), current_expanded_loop_info); - } +void BrgemmBaseKernelExecutor_x64::update_config(const ov::snippets::lowered::ExpressionPtr& expr, + const ov::snippets::lowered::LinearIRCPtr& linear_ir, + BrgemmBaseKernelConfig_x64& config) { + // update M/N/K/beta + BrgemmBaseKernelExecutor::update_config(expr, linear_ir, config); - const auto LDA = DIM_CAST(snippets::utils::get_dim_stride(expr->get_input_port(0))); - const auto LDC = DIM_CAST(snippets::utils::get_dim_stride(expr->get_output_port(0))); - auto LDB = DIM_CAST(snippets::utils::get_dim_stride(expr->get_input_port(1))); + const auto LDA = snippets::utils::get_dim_stride(expr->get_input_port(0)); + const auto LDC = snippets::utils::get_dim_stride(expr->get_output_port(0)); + auto LDB = snippets::utils::get_dim_stride(expr->get_input_port(1)); const auto& brgemm_node = as_type_ptr(expr->get_node()); OV_CPU_JIT_EMITTER_ASSERT(brgemm_node, "Got invalid node type in update_config"); // In case of data repacking LDB is chosen in accordance with repacking buffer size if (with_repacking(brgemm_node->get_type())) - LDB = DIM_CAST(brgemm_utils::repacking::compute_LDB(LDB, brgemm_node->get_input_element_type(1))); + LDB = brgemm_utils::repacking::compute_LDB(LDB, brgemm_node->get_input_element_type(1)); - config.update(DIM_CAST(M), DIM_CAST(N), DIM_CAST(K), LDA, LDB, LDC, beta); + config.update(config.get_M(), config.get_N(), config.get_K(), LDA, LDB, LDC, config.get_beta()); } -void BrgemmBaseKernelExecutor::create_brgemm_kernel(std::shared_ptr& kernel, - dnnl_data_type_t dt0, - dnnl_data_type_t dt1, - cpu_isa_t isa, - dnnl_dim_t M, - dnnl_dim_t N, - dnnl_dim_t K, - dnnl_dim_t LDA, - dnnl_dim_t LDB, - dnnl_dim_t LDC, - float beta, - bool with_amx, - char* palette) { +void BrgemmBaseKernelExecutor_x64::create_brgemm_kernel(std::shared_ptr& kernel, + dnnl_data_type_t dt0, + dnnl_data_type_t dt1, + cpu_isa_t isa, + dnnl_dim_t M, + dnnl_dim_t N, + dnnl_dim_t K, + dnnl_dim_t LDA, + dnnl_dim_t LDB, + dnnl_dim_t LDC, + float beta, + bool with_amx, + char* palette) { cpu::x64::brgemm_desc_t desc; OV_CPU_JIT_EMITTER_ASSERT(brgemm_desc_init(&desc, isa, @@ -305,7 +147,7 @@ void BrgemmBaseKernelExecutor::create_brgemm_kernel(std::shared_ptr(kernel_); } -void BrgemmBaseKernelExecutor::execute_brgemm_kernel( +void BrgemmBaseKernelExecutor_x64::execute_brgemm_kernel( const std::shared_ptr& kernel, const void* src, const void* wei, @@ -328,7 +170,6 @@ void BrgemmBaseKernelExecutor::execute_brgemm_kernel( (*kernel)(&brgemm_p); } -#undef DIM_CAST #undef DTYPE_CAST #undef PRINT #undef EQ diff --git a/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm_base.hpp b/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm_base.hpp index 674ea42522230b..f96187c08794cb 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm_base.hpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm_base.hpp @@ -1,4 +1,4 @@ -// Copyright (C) 2020-2024 Intel Corporation +// Copyright (C) 2020-2025 Intel Corporation // SPDX-License-Identifier: Apache-2.0 // @@ -7,33 +7,26 @@ #include #include "cpu/x64/cpu_isa_traits.hpp" -#include "emitters/plugin/x64/jit_emitter.hpp" -#include "emitters/snippets/cpu_kernel_executor_table.hpp" -#include "emitters/snippets/jit_snippets_call_args.hpp" -#include "openvino/core/type/element_type.hpp" -#include "snippets/lowered/loop_info.hpp" -#include "snippets/lowered/loop_manager.hpp" +#include "emitters/snippets/brgemm_base.hpp" namespace ov { namespace intel_cpu { -struct BrgemmBaseKernelConfig : public snippets::KernelExecutorBase::GenericConfig { +struct BrgemmBaseKernelConfig_x64 : public BrgemmBaseKernelConfig { public: - BrgemmBaseKernelConfig() = default; + BrgemmBaseKernelConfig_x64() = default; - bool is_completed() const override; size_t hash() const override { return m_hash; } - bool is_empty() const; - void update(dnnl_dim_t M, dnnl_dim_t N, dnnl_dim_t K, dnnl_dim_t LDA, dnnl_dim_t LDB, dnnl_dim_t LDC, float beta); - - bool operator==(const BrgemmBaseKernelConfig& rhs) const; - bool operator!=(const BrgemmBaseKernelConfig& rhs) const { + bool operator==(const BrgemmBaseKernelConfig_x64& rhs) const; + bool operator!=(const BrgemmBaseKernelConfig_x64& rhs) const { return !(*this == rhs); } + void update(int64_t M, int64_t N, int64_t K, int64_t LDA, int64_t LDB, int64_t LDC, float beta); + dnnl_data_type_t get_dt_in0() const { return get_static_params()->dt_in0; } @@ -44,29 +37,6 @@ struct BrgemmBaseKernelConfig : public snippets::KernelExecutorBase::GenericConf dnnl::impl::cpu::x64::cpu_isa_t get_isa() const { return get_static_params()->isa; } - float get_beta() const { - return m_beta; - } - - dnnl_dim_t get_M() const { - return m_M; - } - dnnl_dim_t get_N() const { - return m_N; - } - dnnl_dim_t get_K() const { - return m_K; - } - - dnnl_dim_t get_LDA() const { - return m_LDA; - } - dnnl_dim_t get_LDB() const { - return m_LDB; - } - dnnl_dim_t get_LDC() const { - return m_LDC; - } #ifdef SNIPPETS_DEBUG_CAPS std::string to_string() const override; @@ -105,24 +75,17 @@ struct BrgemmBaseKernelConfig : public snippets::KernelExecutorBase::GenericConf virtual std::shared_ptr get_static_params() const = 0; size_t compute_hash() const; - - dnnl_dim_t m_M{0}, m_N{0}, m_K{0}, m_LDA{0}, m_LDB{0}, m_LDC{0}; - float m_beta{0}; size_t m_hash{SIZE_MAX}; }; -class BrgemmBaseKernelExecutor { +class BrgemmBaseKernelExecutor_x64 : public BrgemmBaseKernelExecutor { public: - virtual ~BrgemmBaseKernelExecutor() = default; + virtual ~BrgemmBaseKernelExecutor_x64() = default; protected: - static float get_beta(const ov::snippets::lowered::LoopManagerPtr& loop_manager, - int loop_id, - const ov::snippets::lowered::ExpandedLoopInfoPtr& current_expanded_loop_info); - static void update_config(const ov::snippets::lowered::ExpressionPtr& expr, const ov::snippets::lowered::LinearIRCPtr& linear_ir, - BrgemmBaseKernelConfig& config); + BrgemmBaseKernelConfig_x64& config); static void create_brgemm_kernel(std::shared_ptr& kernel, dnnl_data_type_t dt0, diff --git a/src/plugins/intel_cpu/src/emitters/tpp/aarch64/jit_brgemm_emitter.cpp b/src/plugins/intel_cpu/src/emitters/tpp/aarch64/jit_brgemm_emitter.cpp new file mode 100644 index 00000000000000..22764248442000 --- /dev/null +++ b/src/plugins/intel_cpu/src/emitters/tpp/aarch64/jit_brgemm_emitter.cpp @@ -0,0 +1,86 @@ +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "jit_brgemm_emitter.hpp" + +#include "snippets/utils/utils.hpp" +#include "transformations/tpp/common/op/brgemm.hpp" + +using namespace ov::intel_cpu::tpp; +using namespace Xbyak_aarch64; + +namespace ov { +namespace intel_cpu { +namespace aarch64 { + +using jit_generator = dnnl::impl::cpu::aarch64::jit_generator; +using cpu_isa_t = dnnl::impl::cpu::aarch64::cpu_isa_t; +using ExpressionPtr = ov::snippets::lowered::ExpressionPtr; + +jit_brgemm_emitter::jit_brgemm_emitter(jit_generator* h, + cpu_isa_t isa, + const ExpressionPtr& expr, + const snippets::KernelExecutorTablePtr& kernel_table, + const ov::intel_cpu::MultiCacheWeakPtr& compiled_kernel_cache) + : jit_emitter(h, isa) { + in_out_type_ = emitter_in_out_map::gpr_to_gpr; + const auto& brgemm_node = as_type_ptr(expr->get_node()); + const auto& brg0Prc = brgemm_node->get_input_element_type(0); + const auto& brg1Prc = brgemm_node->get_input_element_type(1); + BrgemmKernelConfig kernel_config(brg0Prc, brg1Prc); + m_kernel_executor = kernel_table->register_kernel(expr, compiled_kernel_cache, kernel_config); +} + +std::set> jit_brgemm_emitter::get_supported_precisions( + const std::shared_ptr& node) { + // Note: Brgemm currently supports only fp32 on arm + return {{element::f32, element::f32}}; +} + +void jit_brgemm_emitter::validate_arguments(const std::vector& in, const std::vector& out) const { + OV_CPU_JIT_EMITTER_ASSERT(in.size() == 2, "Expects 2 input regs, got" + std::to_string(in.size())); + OV_CPU_JIT_EMITTER_ASSERT(out.size() == 1, "Expects 1 output reg, got" + std::to_string(out.size())); +} + +void jit_brgemm_emitter::emit_code(const std::vector& in, + const std::vector& out, + const std::vector& pool_vec_idxs, + const std::vector& pool_gpr_idxs) const { + validate_arguments(in, out); + emit_impl(in, out); +} + +void jit_brgemm_emitter::emit_impl(const std::vector& in, const std::vector& out) const { + validate_arguments(in, out); + std::unordered_set exclude = {}; + store_context(exclude); + + Xbyak_aarch64::XReg func_reg(9); + h->mov(func_reg, get_execute_function_ptr()); + Xbyak_aarch64::XReg x0(0); + Xbyak_aarch64::XReg x1(1); + Xbyak_aarch64::XReg x2(2); + Xbyak_aarch64::XReg x3(3); + + const auto& compiled_kernel = get_compiled_kernel_ptr(); + h->mov(x0, compiled_kernel); + h->mov(x1, Xbyak_aarch64::XReg(in[0])); + h->mov(x2, Xbyak_aarch64::XReg(in[1])); + h->mov(x3, Xbyak_aarch64::XReg(out[0])); + h->blr(func_reg); + + restore_context(exclude); +} + +const uintptr_t jit_brgemm_emitter::get_compiled_kernel_ptr() const { + return reinterpret_cast(m_kernel_executor.get()); +} + +const uintptr_t jit_brgemm_emitter::get_execute_function_ptr() const { + return reinterpret_cast(BrgemmKernelExecutor::execute); +} + +} // namespace aarch64 +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/emitters/tpp/aarch64/jit_brgemm_emitter.hpp b/src/plugins/intel_cpu/src/emitters/tpp/aarch64/jit_brgemm_emitter.hpp new file mode 100644 index 00000000000000..855771a702b6f7 --- /dev/null +++ b/src/plugins/intel_cpu/src/emitters/tpp/aarch64/jit_brgemm_emitter.hpp @@ -0,0 +1,46 @@ +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "emitters/plugin/aarch64/jit_emitter.hpp" +#include "emitters/tpp/common/kernel_executors/brgemm.hpp" + +namespace ov { +namespace intel_cpu { +namespace aarch64 { + +class jit_brgemm_emitter : public jit_emitter { +public: + jit_brgemm_emitter(dnnl::impl::cpu::aarch64::jit_generator* h, + dnnl::impl::cpu::aarch64::cpu_isa_t isa, + const ov::snippets::lowered::ExpressionPtr& expr, + const snippets::KernelExecutorTablePtr& kernel_table, + const ov::intel_cpu::MultiCacheWeakPtr& compiled_kernel_cache); + + size_t get_inputs_count() const override { + return 2; + } + + static std::set> get_supported_precisions( + const std::shared_ptr& node = nullptr); + + void emit_code(const std::vector& in_idxs, + const std::vector& out_idxs, + const std::vector& pool_vec_idxs = {}, + const std::vector& pool_gpr_idxs = {}) const override; + +private: + void validate_arguments(const std::vector& in, const std::vector& out) const override; + void emit_impl(const std::vector& in, const std::vector& out) const override; + + const uintptr_t get_execute_function_ptr() const; + const uintptr_t get_compiled_kernel_ptr() const; + + std::shared_ptr m_kernel_executor = nullptr; +}; + +} // namespace aarch64 +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/emitters/tpp/common/kernel_executors/brgemm.cpp b/src/plugins/intel_cpu/src/emitters/tpp/common/kernel_executors/brgemm.cpp new file mode 100644 index 00000000000000..be961fd6e31d1a --- /dev/null +++ b/src/plugins/intel_cpu/src/emitters/tpp/common/kernel_executors/brgemm.cpp @@ -0,0 +1,166 @@ +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "brgemm.hpp" + +#include "emitters/tpp/common/utils.hpp" +#include "transformations/tpp/common/op/brgemm.hpp" + +#define PRINT(X) ss << #X << " = " << X << "\n" +#define HASH(X) seed = dnnl::impl::hash_combine(seed, X) + +namespace ov { +namespace intel_cpu { +namespace tpp { +BrgemmKernelConfig::BrgemmKernelConfig(const element::Type& in0_dtype, const element::Type& in1_dtype) + : BrgemmBaseKernelConfig(), + m_static_params(std::make_shared(in0_dtype, in1_dtype)) {} + +bool BrgemmKernelConfig::operator==(const BrgemmKernelConfig& rhs) const { + return BrgemmBaseKernelConfig::operator==(rhs) && + (get_static_params() == rhs.get_static_params() || + *get_static_params() == *(rhs.get_static_params())); +} + +size_t BrgemmKernelConfig::compute_hash() const { + size_t static_seed = get_static_params()->hash(); + size_t dynamic_seed = BrgemmBaseKernelConfig::compute_hash(); + return dnnl::impl::hash_combine(static_seed, dynamic_seed); +} + +void BrgemmKernelConfig::update(int64_t M, int64_t N, int64_t K, int64_t LDA, int64_t LDB, int64_t LDC, float beta) { + BrgemmBaseKernelConfig::update(M, N, K, LDA, LDB, LDC, beta); + m_hash = compute_hash(); +} + +BrgemmKernelConfig::StaticParams::StaticParams(const element::Type& in0_dtype, const element::Type& in1_dtype) { + m_type_in0 = tpp::ov_to_xsmm_dtype(in0_dtype); + m_type_in1 = tpp::ov_to_xsmm_dtype(in1_dtype); + m_type_exec = LIBXSMM_DATATYPE_F32; + m_type_out0 = LIBXSMM_DATATYPE_F32; + m_compile_flags = LIBXSMM_GEMM_FLAGS('N', 'N'); + m_prefetching_flags = false; + m_hash = compute_hash(); +} + +size_t BrgemmKernelConfig::StaticParams::compute_hash() { + size_t seed = 0; + HASH(m_type_in0); + HASH(m_type_in1); + HASH(m_type_exec); + HASH(m_type_out0); + HASH(m_compile_flags); + HASH(m_prefetching_flags); + return seed; +} + +bool BrgemmKernelConfig::StaticParams::operator==(const StaticParams& rhs) const { + return m_type_in0 == rhs.m_type_in0 && m_type_in1 == rhs.m_type_in1 && m_type_exec == rhs.m_type_exec && + m_type_out0 == rhs.m_type_out0 && m_compile_flags == rhs.m_compile_flags && + m_prefetching_flags == rhs.m_prefetching_flags; +} + +#ifdef SNIPPETS_DEBUG_CAPS +std::string BrgemmKernelConfig::StaticParams::to_string() const { + std::stringstream ss; + PRINT(m_type_in0); + PRINT(m_type_in1); + PRINT(m_type_out0); + PRINT(m_type_exec); + PRINT(m_compile_flags); + PRINT(m_prefetching_flags); + return ss.str(); +} + +std::string BrgemmKernelConfig::to_string() const { + std::stringstream ss; + ss << get_static_params()->to_string() << "\n"; + ss << BrgemmBaseKernelConfig::to_string() << "\n"; + return ss.str(); +} +#endif + +BrgemmKernelExecutor::BrgemmKernelExecutor(ov::intel_cpu::MultiCacheWeakPtr kernel_cache, BrgemmKernelConfig config) + : CPUKernelExecutor(std::move(kernel_cache), std::move(config)) {} + +std::shared_ptr BrgemmKernelExecutor::compile_kernel(const BrgemmKernelConfig& config) const { + std::shared_ptr compiled_kernel = std::make_shared(); + + // Brgemm is not executable - nothing to compile + if (config.is_empty()) + return compiled_kernel; + + libxsmm_gemm_shape m_shape = libxsmm_create_gemm_shape(config.get_N(), + config.get_M(), + config.get_K(), + config.get_LDB(), + config.get_LDA(), + config.get_LDC(), + config.get_type_in0(), + config.get_type_in1(), + config.get_type_out0(), + config.get_type_exec()); + compiled_kernel->brgemm_kernel = + std::make_shared(reinterpret_cast(COMPILE_TPP_KERNEL( + libxsmm_dispatch_gemm(m_shape, config.get_compile_flags(), config.get_prefetching_flags())))); + + return compiled_kernel; +} + +void BrgemmKernelExecutor::update_config(const ov::snippets::lowered::ExpressionPtr& expr, + const ov::snippets::lowered::LinearIRCPtr& linear_ir, + BrgemmKernelConfig& config) const { + BrgemmBaseKernelExecutor::update_config(expr, linear_ir, config); + const auto& tpp_mod = std::dynamic_pointer_cast(expr->get_node()); + auto replace_full_dim = [](size_t dim, size_t replace_dim) { + if (ov::snippets::utils::is_full_dim_value(dim)) + return replace_dim; + return dim; + }; + + const auto num_ins = expr->get_node()->get_input_size(); + const auto num_outs = expr->get_node()->get_output_size(); + + std::vector io_strides(num_ins + num_outs); + + for (size_t i = 0; i < num_ins; i++) { + io_strides[i] = + replace_full_dim(tpp_mod->get_input_stride(i), expr->get_input_port_descriptor(i)->get_shape().back()); + } + + for (size_t i = 0; i < num_outs; i++) { + const auto i_off = i + num_ins; + io_strides[i_off] = + replace_full_dim(tpp_mod->get_output_stride(i), expr->get_output_port_descriptor(i)->get_shape().back()); + } + + config.update(config.get_M(), + config.get_N(), + config.get_K(), + io_strides[0], + io_strides[1], + io_strides[2], + config.get_beta()); + // update compile flag, which is depend on beta. should be part of hash. + config.set_compile_flags(config.get_beta() == 0); +} + +void BrgemmKernelExecutor::execute(const BrgemmKernelExecutor* executor, void* in0, void* in1, void* out0) { + OV_CPU_JIT_EMITTER_ASSERT(executor, "has nullptr executor"); + libxsmm_gemm_param gemm_p; + gemm_p.a.primary = in1; + gemm_p.b.primary = in0; + gemm_p.c.primary = out0; + auto brg_kernel = executor->get_kernel(); + OV_CPU_JIT_EMITTER_ASSERT(brg_kernel, "has nullptr compiler kernel"); + OV_CPU_JIT_EMITTER_ASSERT(brg_kernel->brgemm_kernel, "has nullptr compiler brgemm_kernel"); + (*(brg_kernel->brgemm_kernel))(&gemm_p); +} + +#undef PRINT +#undef HASH + +} // namespace tpp +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/emitters/tpp/common/kernel_executors/brgemm.hpp b/src/plugins/intel_cpu/src/emitters/tpp/common/kernel_executors/brgemm.hpp new file mode 100644 index 00000000000000..f84c2f766e44f6 --- /dev/null +++ b/src/plugins/intel_cpu/src/emitters/tpp/common/kernel_executors/brgemm.hpp @@ -0,0 +1,134 @@ +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "common/utils.hpp" +#include "emitters/snippets/brgemm_base.hpp" +#include "emitters/utils.hpp" +#include "libxsmm.h" + +namespace ov { +namespace intel_cpu { +namespace tpp { + +struct BrgemmKernelConfig : public BrgemmBaseKernelConfig { +public: + BrgemmKernelConfig(const element::Type& in0_dtype, const element::Type& in1_dtype); + BrgemmKernelConfig() = delete; + + std::unique_ptr get_clone_ptr() const override { + return std::unique_ptr(new BrgemmKernelConfig(*this)); + } + + bool operator==(const BrgemmKernelConfig& rhs) const; + bool operator!=(const BrgemmKernelConfig& rhs) const { + return !(*this == rhs); + } + + void update(int64_t M, int64_t N, int64_t K, int64_t LDA, int64_t LDB, int64_t LDC, float beta); + + size_t hash() const override { + return m_hash; + } + size_t compute_hash() const; + + libxsmm_bitfield get_static_compile_flags() const { + return m_static_params->m_compile_flags; + } + libxsmm_bitfield get_compile_flags() const { + return m_compile_flags; + } + void set_compile_flags(bool zero_beta) { + if (zero_beta) { + m_compile_flags = get_static_compile_flags() | LIBXSMM_GEMM_FLAG_BETA_0; + } else { + m_compile_flags = get_static_compile_flags(); + } + } + bool get_prefetching_flags() const { + return m_static_params->m_prefetching_flags; + } + libxsmm_datatype get_type_in0() const { + return m_static_params->m_type_in0; + } + libxsmm_datatype get_type_in1() const { + return m_static_params->m_type_in1; + } + libxsmm_datatype get_type_out0() const { + return m_static_params->m_type_out0; + } + libxsmm_datatype get_type_exec() const { + return m_static_params->m_type_exec; + } +#ifdef SNIPPETS_DEBUG_CAPS + std::string to_string() const override; +#endif + +private: + struct StaticParams { + StaticParams(const element::Type& in0_dtype, const element::Type& in1_dtype); + virtual ~StaticParams() = default; + + bool operator==(const StaticParams& rhs) const; + bool operator!=(const StaticParams& rhs) const { + return !(*this == rhs); + } + size_t hash() const { + return m_hash; + } + size_t compute_hash(); + +#ifdef SNIPPETS_DEBUG_CAPS + std::string to_string() const; +#endif + + libxsmm_datatype m_type_in0; + libxsmm_datatype m_type_in1; + libxsmm_datatype m_type_out0; + libxsmm_datatype m_type_exec; + libxsmm_bitfield m_compile_flags; + bool m_prefetching_flags; + + size_t m_hash{SIZE_MAX}; + }; + std::shared_ptr get_static_params() const { + return m_static_params; + } + + libxsmm_bitfield m_compile_flags{0}; + std::shared_ptr m_static_params{nullptr}; + + size_t m_hash{SIZE_MAX}; +}; + +// The `update_kernel` method verifies that a compiled kernel is not nullptr. +// However, the compiled kernel might be empty in cases if nothing is to be compiled (`Config.is_empty() == true`). +// To cover this case, we wrap the `libxsmm_gemmfunction` in the separate structure which may contain empty +// `libxsmm_gemmfunction` +struct BrgemmTppCompiledKernel { + std::shared_ptr brgemm_kernel = nullptr; +}; + +class BrgemmKernelExecutor : public BrgemmBaseKernelExecutor, + public CPUKernelExecutor { +public: + BrgemmKernelExecutor(ov::intel_cpu::MultiCacheWeakPtr kernel_cache, BrgemmKernelConfig config); + virtual ~BrgemmKernelExecutor() = default; + + // Function that will be called in runtime to execute the kernel + static void execute(const BrgemmKernelExecutor* executor, void* in0, void* in1, void* out0); + +private: + std::shared_ptr compile_kernel(const BrgemmKernelConfig& c) const override; + + void update_config(const ov::snippets::lowered::ExpressionPtr& expr, + const ov::snippets::lowered::LinearIRCPtr& linear_ir, + BrgemmKernelConfig& config) const override; +}; +#define GET_OFF_BRGEMM_ARGS(field) offsetof(BrgemmKernelExecutor::call_args, field) + +} // namespace tpp +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/emitters/tpp/common/utils.hpp b/src/plugins/intel_cpu/src/emitters/tpp/common/utils.hpp new file mode 100644 index 00000000000000..42440f14556131 --- /dev/null +++ b/src/plugins/intel_cpu/src/emitters/tpp/common/utils.hpp @@ -0,0 +1,49 @@ +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include + +#include "emitters/utils.hpp" +#include "libxsmm.h" + +namespace ov { +namespace intel_cpu { +namespace tpp { +// Note: The macro allows to automatically set appropriate environment variables for TPP/Libxsmm kernel compilation +// All TPP kernels must be compiled using this macro. +// * LIBXSMM_X86_HINT_USE_HIGH_PREC_ELTWISE_APPROX enables more accurate exp approximation and exact division in TPP +// * LIBXSMM_GEMM_K_A_PF_DIST allows to tweak prefetching for GEMM kernels +#define COMPILE_TPP_KERNEL(...) \ + [&]() { \ + setenv("LIBXSMM_X86_HINT_USE_HIGH_PREC_ELTWISE_APPROX", "1", 1); \ + setenv("LIBXSMM_GEMM_K_A_PF_DIST", "4", 1); \ + auto res = reinterpret_cast(__VA_ARGS__); \ + unsetenv("LIBXSMM_X86_HINT_USE_HIGH_PREC_ELTWISE_APPROX"); \ + unsetenv("LIBXSMM_GEMM_K_A_PF_DIST"); \ + return res; \ + }() + +inline libxsmm_datatype ov_to_xsmm_dtype(ov::element::Type_t element_type) { + switch (element_type) { + case ov::element::Type_t::f32: + return LIBXSMM_DATATYPE_F32; + case ov::element::Type_t::bf16: + return LIBXSMM_DATATYPE_BF16; + case ov::element::Type_t::f16: + return LIBXSMM_DATATYPE_F16; + case ov::element::Type_t::i8: + return LIBXSMM_DATATYPE_I8; + case ov::element::Type_t::u8: + return LIBXSMM_DATATYPE_U8; + default: + OV_CPU_JIT_EMITTER_THROW("Attempt to convert unsupported ov data type:", element_type); + return LIBXSMM_DATATYPE_IMPLICIT; + } +} + +} // namespace tpp +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/emitters/tpp/x64/jit_brgemm_emitter.cpp b/src/plugins/intel_cpu/src/emitters/tpp/x64/jit_brgemm_emitter.cpp index e873d7f7aa98eb..349a3470ec4e98 100644 --- a/src/plugins/intel_cpu/src/emitters/tpp/x64/jit_brgemm_emitter.cpp +++ b/src/plugins/intel_cpu/src/emitters/tpp/x64/jit_brgemm_emitter.cpp @@ -5,7 +5,8 @@ #include "jit_brgemm_emitter.hpp" #include "emitters/snippets/x64/jit_snippets_emitters.hpp" -#include "transformations/tpp/x64/op/brgemm.hpp" +#include "emitters/tpp/common/utils.hpp" +#include "transformations/tpp/common/op/brgemm.hpp" using jit_generator = dnnl::impl::cpu::x64::jit_generator; using cpu_isa_t = dnnl::impl::cpu::x64::cpu_isa_t; @@ -20,70 +21,19 @@ void BrgemmTppEmitter::validate_subtensors(const VectorDims& in_0, const VectorD OV_CPU_JIT_EMITTER_ASSERT(subtensors_compatible, "Incompatible subtensors"); } -BrgemmTppEmitter::BrgemmTppEmitter(jit_generator* h, cpu_isa_t isa, const ExpressionPtr& expr) +BrgemmTppEmitter::BrgemmTppEmitter(jit_generator* h, + cpu_isa_t isa, + const ExpressionPtr& expr, + const snippets::KernelExecutorTablePtr& kernel_table, + const ov::intel_cpu::MultiCacheWeakPtr& compiled_kernel_cache) : TppEmitter(h, isa, expr) { const auto& brgemm_node = as_type_ptr(expr->get_node()); OV_CPU_JIT_EMITTER_ASSERT(brgemm_node && !brgemm_node->is_dynamic(), "Invoked with invalid node type"); - - const auto& input_0_desc = expr->get_input_port_descriptor(0); - const auto& input_1_desc = expr->get_input_port_descriptor(1); - const auto& output_desc = expr->get_output_port_descriptor(0); - - std::vector leading_dimensions{brgemm_node->get_input_stride(0), - brgemm_node->get_input_stride(1), - brgemm_node->get_output_stride(0)}; - - auto in_0_prec = ov_to_xsmm_dtype(brgemm_node->get_input_element_type(0)); - auto in_1_prec = ov_to_xsmm_dtype(brgemm_node->get_input_element_type(1)); - exec_dtype = in_0_prec == LIBXSMM_DATATYPE_I8 || in_0_prec == LIBXSMM_DATATYPE_U8 ? LIBXSMM_DATATYPE_I32 - : LIBXSMM_DATATYPE_F32; - auto out_0_prec = exec_dtype == LIBXSMM_DATATYPE_I32 ? LIBXSMM_DATATYPE_I32 : LIBXSMM_DATATYPE_F32; - - const auto beta = brgemm_node->get_beta(); - OV_CPU_JIT_EMITTER_ASSERT(beta == 0 || beta == 1, "Detected unsupported beta value: " + std::to_string(beta)); - - const auto& subtensor_in0 = input_0_desc->get_subtensor(); - const auto& subtensor_in1 = input_1_desc->get_subtensor(); - const auto& subtensor_out0 = output_desc->get_subtensor(); - validate_subtensors(subtensor_in0, subtensor_in1, subtensor_out0); - - const auto K = static_cast(*subtensor_in0.rbegin()); - const auto M = static_cast(*++subtensor_in0.rbegin()); - const auto N = static_cast(*subtensor_in1.rbegin()); - - const bool is_f32_gemm = in_0_prec == in_1_prec && in_0_prec == LIBXSMM_DATATYPE_F32; - const bool is_bf16_gemm = in_0_prec == in_1_prec && in_0_prec == LIBXSMM_DATATYPE_BF16; - const bool is_i8_gemm = in_0_prec == LIBXSMM_DATATYPE_U8 || in_0_prec == LIBXSMM_DATATYPE_I8; - OV_CPU_JIT_EMITTER_ASSERT(is_f32_gemm || (is_bf16_gemm && K % 2 == 0) || (is_i8_gemm && K % 4 == 0), - "Unsupported parameter combination for kernel configuration"); - - m_compile_flags = is_f32_gemm ? LIBXSMM_GEMM_FLAGS('N', 'N') - : LIBXSMM_GEMM_VNNI_FLAGS('N', 'N', 'V', 'N') | - LIBXSMM_GEMM_FLAG_NO_SETUP_TILECONFIG | LIBXSMM_GEMM_FLAG_NO_RESET_TILECONFIG; - - if (beta == 0) - m_compile_flags |= LIBXSMM_GEMM_FLAG_BETA_0; - - if (in_0_prec == LIBXSMM_DATATYPE_U8) { - in_0_prec = LIBXSMM_DATATYPE_I8; - m_compile_flags |= LIBXSMM_GEMM_FLAG_A_UNSIGNED; - } - if (in_1_prec == LIBXSMM_DATATYPE_U8) { - in_1_prec = LIBXSMM_DATATYPE_I8; - m_compile_flags |= LIBXSMM_GEMM_FLAG_B_UNSIGNED; - } - - m_shape = libxsmm_create_gemm_shape(N, - M, - K, - io_strides[1], - io_strides[0], - io_strides[2], - in_1_prec, - in_0_prec, - out_0_prec, - exec_dtype); - m_prefetching_flags = LIBXSMM_GEMM_PREFETCH_NONE; + const auto& brg0Prc = brgemm_node->get_input_element_type(0); + const auto& brg1Prc = brgemm_node->get_input_element_type(1); + tpp::BrgemmKernelConfig kernel_config(brg0Prc, brg1Prc); + m_kernel_executor = + kernel_table->register_kernel(expr, compiled_kernel_cache, kernel_config); } std::set> BrgemmTppEmitter::get_supported_precisions(const std::shared_ptr& node) { @@ -97,16 +47,11 @@ void BrgemmTppEmitter::validate_arguments(const std::vector& in, const s } const uintptr_t BrgemmTppEmitter::get_compiled_kernel_ptr() const { - return COMPILE_TPP_KERNEL(libxsmm_dispatch_gemm(m_shape, m_compile_flags, m_prefetching_flags)); + return reinterpret_cast(m_kernel_executor.get()); } -void BrgemmTppEmitter::execute_brgemm_kernel(libxsmm_gemmfunction brg_kernel, void* in0, void* in1, void* out0) { - libxsmm_gemm_param gemm_p; - gemm_p.a.primary = in1; - gemm_p.b.primary = in0; - gemm_p.c.primary = out0; - OV_CPU_JIT_EMITTER_ASSERT(brg_kernel, "Invalid brgemm kernel pointer"); - brg_kernel(&gemm_p); +const uintptr_t BrgemmTppEmitter::get_execute_function_ptr() const { + return reinterpret_cast(tpp::BrgemmKernelExecutor::execute); } } // namespace intel_cpu diff --git a/src/plugins/intel_cpu/src/emitters/tpp/x64/jit_brgemm_emitter.hpp b/src/plugins/intel_cpu/src/emitters/tpp/x64/jit_brgemm_emitter.hpp index 129f8fa579ce9a..2b5a1c528d39c1 100644 --- a/src/plugins/intel_cpu/src/emitters/tpp/x64/jit_brgemm_emitter.hpp +++ b/src/plugins/intel_cpu/src/emitters/tpp/x64/jit_brgemm_emitter.hpp @@ -1,8 +1,9 @@ -// Copyright (C) 2020-2022 Intel Corporation +// Copyright (C) 2020-2025 Intel Corporation // SPDX-License-Identifier: Apache-2.0 // #pragma once +#include "emitters/tpp/common/kernel_executors/brgemm.hpp" #include "jit_tpp_emitter.hpp" namespace ov { @@ -12,7 +13,9 @@ class BrgemmTppEmitter : public TppEmitter { public: BrgemmTppEmitter(dnnl::impl::cpu::x64::jit_generator* h, dnnl::impl::cpu::x64::cpu_isa_t isa, - const ov::snippets::lowered::ExpressionPtr& expr); + const ov::snippets::lowered::ExpressionPtr& expr, + const snippets::KernelExecutorTablePtr& kernel_table, + const ov::intel_cpu::MultiCacheWeakPtr& compiled_kernel_cache); size_t get_inputs_num() const override { return 2; @@ -20,18 +23,15 @@ class BrgemmTppEmitter : public TppEmitter { static std::set> get_supported_precisions( const std::shared_ptr& node = nullptr); - static void execute_brgemm_kernel(libxsmm_gemmfunction brgemm_kernel, void* in0, void* in1, void* out0); - - const uintptr_t get_execute_function_ptr() const override { - return reinterpret_cast(execute_brgemm_kernel); - } + const uintptr_t get_execute_function_ptr() const override; const uintptr_t get_compiled_kernel_ptr() const override; protected: void validate_arguments(const std::vector& in, const std::vector& out) const override; static void validate_subtensors(const VectorDims& in_0, const VectorDims& in_1, const VectorDims& out_0); - libxsmm_gemm_shape m_shape; - libxsmm_bitfield m_prefetching_flags{0}; + +private: + std::shared_ptr m_kernel_executor = nullptr; }; } // namespace intel_cpu diff --git a/src/plugins/intel_cpu/src/emitters/tpp/x64/jit_tpp_emitter.cpp b/src/plugins/intel_cpu/src/emitters/tpp/x64/jit_tpp_emitter.cpp index a18b1616bb517c..61c96820dca052 100644 --- a/src/plugins/intel_cpu/src/emitters/tpp/x64/jit_tpp_emitter.cpp +++ b/src/plugins/intel_cpu/src/emitters/tpp/x64/jit_tpp_emitter.cpp @@ -5,6 +5,7 @@ #include "jit_tpp_emitter.hpp" #include "emitters/plugin/x64/utils.hpp" +#include "emitters/tpp/common/utils.hpp" #include "snippets/lowered/port_descriptor.hpp" #include "transformations/tpp/x64/op/eltwise.hpp" @@ -56,7 +57,7 @@ TppEmitter::TppEmitter(dnnl::impl::cpu::x64::jit_generator* h, }; for (size_t i = 0; i < num_ins; i++) { - io_dtypes[i] = ov_to_xsmm_dtype(node->get_input_element_type(i)); + io_dtypes[i] = tpp::ov_to_xsmm_dtype(node->get_input_element_type(i)); io_offsets[i] = tpp_mod->get_input_offset(i); io_strides[i] = replace_full_dim(tpp_mod->get_input_stride(i), expr->get_input_port_descriptor(i)->get_shape().back()); @@ -65,7 +66,7 @@ TppEmitter::TppEmitter(dnnl::impl::cpu::x64::jit_generator* h, for (size_t i = 0; i < num_outs; i++) { const auto i_off = i + num_ins; - io_dtypes[i_off] = ov_to_xsmm_dtype(node->get_output_element_type(i)); + io_dtypes[i_off] = tpp::ov_to_xsmm_dtype(node->get_output_element_type(i)); io_offsets[i_off] = tpp_mod->get_output_offset(i); io_strides[i_off] = replace_full_dim(tpp_mod->get_output_stride(i), expr->get_output_port_descriptor(i)->get_shape().back()); @@ -121,21 +122,5 @@ void TppEmitter::emit_impl(const std::vector& in, const std::vector(__VA_ARGS__); \ - unsetenv("LIBXSMM_X86_HINT_USE_HIGH_PREC_ELTWISE_APPROX"); \ - unsetenv("LIBXSMM_GEMM_K_A_PF_DIST"); \ - return res; \ - }() + class DebugTppEmitter; class TppEmitter : public jit_binary_call_emitter { friend DebugTppEmitter; @@ -34,7 +23,6 @@ class TppEmitter : public jit_binary_call_emitter { dnnl::impl::cpu::x64::cpu_isa_t isa, const ov::snippets::lowered::ExpressionPtr& expr); void emit_code(const std::vector& in, const std::vector& out) const; - static libxsmm_datatype ov_to_xsmm_dtype(ov::element::Type_t elemet_type); protected: void emit_impl(const std::vector& in, const std::vector& out) const override; diff --git a/src/plugins/intel_cpu/src/emitters/utils.hpp b/src/plugins/intel_cpu/src/emitters/utils.hpp index a987c5b5795116..92df781a1fc318 100644 --- a/src/plugins/intel_cpu/src/emitters/utils.hpp +++ b/src/plugins/intel_cpu/src/emitters/utils.hpp @@ -7,6 +7,7 @@ #include #include "openvino/core/except.hpp" +#include "openvino/core/type/element_type.hpp" namespace ov { namespace intel_cpu { diff --git a/src/plugins/intel_cpu/src/nodes/subgraph.cpp b/src/plugins/intel_cpu/src/nodes/subgraph.cpp index 43a005b27cb450..de30da867638d2 100644 --- a/src/plugins/intel_cpu/src/nodes/subgraph.cpp +++ b/src/plugins/intel_cpu/src/nodes/subgraph.cpp @@ -52,12 +52,14 @@ #ifdef SNIPPETS_LIBXSMM_TPP # include "snippets/lowered/pass/optimize_domain.hpp" -# include "transformations/tpp/x64/pass/brgemm_to_brgemm_tpp.hpp" -# include "transformations/tpp/x64/pass/eltwise_to_eltwise_tpp.hpp" -# include "transformations/tpp/x64/pass/fuse_tpp_to_equations.hpp" -# include "transformations/tpp/x64/pass/lowered/brgemm_tpp_blocking.hpp" -# include "transformations/tpp/x64/pass/lowered/set_tpp_leading_dim.hpp" -# include "transformations/tpp/x64/pass/scalar_to_scalar_tpp.hpp" +# include "transformations/tpp/common/pass/brgemm_to_brgemm_tpp.hpp" +# include "transformations/tpp/common/pass/lowered/brgemm_tpp_blocking.hpp" +# include "transformations/tpp/common/pass/lowered/set_tpp_leading_dim.hpp" +# if defined(OPENVINO_ARCH_X86_64) +# include "transformations/tpp/x64/pass/eltwise_to_eltwise_tpp.hpp" +# include "transformations/tpp/x64/pass/fuse_tpp_to_equations.hpp" +# include "transformations/tpp/x64/pass/scalar_to_scalar_tpp.hpp" +# endif #endif namespace ov { @@ -460,11 +462,20 @@ Subgraph::DataFlowPasses Subgraph::getDataFlowPasses() { # define SNIPPETS_REGISTER_PASS_RELATIVE_X86_64(PASS_PLACE, TARGET_PASS, PASS, ...) #endif // OPENVINO_ARCH_X86_64 +#if defined(OPENVINO_ARCH_ARM64) +# define SNIPPETS_REGISTER_PASS_RELATIVE_ARM64(PASS_PLACE, TARGET_PASS, PASS, ...) \ + backend_passes.emplace_back(PassPosition(PASS_PLACE, TARGET_PASS::get_type_info_static()), \ + std::make_shared(__VA_ARGS__)) +#else +# define SNIPPETS_REGISTER_PASS_RELATIVE_ARM64(PASS_PLACE, TARGET_PASS, PASS, ...) +#endif // OPENVINO_ARCH_ARM64 + SNIPPETS_REGISTER_PASS_ABSOLUTE_COMMON(Place::PipelineStart, ConvertToSwishCPU); SNIPPETS_REGISTER_PASS_RELATIVE_COMMON(Place::After, ov::snippets::pass::Canonicalization, ov::snippets::pass::AnalyzeBroadcastableInputs, broadcastable_inputs); + if (one_of(context->getConfig().inferencePrecision, ov::element::bf16, ov::element::f16) && subgraph_attrs->snippet->has_domain_sensitive_ops()) { // enforce BF16 precisions to supported operations @@ -498,54 +509,75 @@ Subgraph::DataFlowPasses Subgraph::getDataFlowPasses() { SNIPPETS_REGISTER_PASS_RELATIVE_X86_64(Place::After, ov::intel_cpu::tpp::pass::EltwiseToEltwiseTPP, ov::intel_cpu::tpp::pass::FuseTPPToEquations); + SNIPPETS_REGISTER_PASS_RELATIVE_ARM64(Place::Before, + ov::snippets::pass::PropagatePrecision, + ov::intel_cpu::tpp::pass::BrgemmToBrgemmTPP); #endif #undef SNIPPETS_REGISTER_PASS_ABSOLUTE_COMMON #undef SNIPPETS_REGISTER_PASS_RELATIVE_COMMON #undef SNIPPETS_REGISTER_PASS_ABSOLUTE_X86_64 #undef SNIPPETS_REGISTER_PASS_RELATIVE_X86_64 +#undef SNIPPETS_REGISTER_PASS_RELATIVE_ARM64 return backend_passes; } Subgraph::ControlFlowPasses Subgraph::getControlFlowPasses() const { ControlFlowPasses backend_passes; - -#if defined(OPENVINO_ARCH_X86_64) +#if defined(OPENVINO_ARCH_X86_64) || (defined(OPENVINO_ARCH_ARM64) && defined(SNIPPETS_LIBXSMM_TPP)) using PassPosition = ov::snippets::pass::PassPosition; using Place = PassPosition::Place; -# define SNIPPETS_REGISTER_PASS_RELATIVE(PASS_PLACE, TARGET_PASS, PASS, ...) \ +#endif + +#if defined(OPENVINO_ARCH_X86_64) +# define SNIPPETS_REGISTER_PASS_RELATIVE_X86_64(PASS_PLACE, TARGET_PASS, PASS, ...) \ backend_passes.emplace_back(PassPosition(PASS_PLACE, TARGET_PASS::get_type_info_static()), \ std::make_shared(__VA_ARGS__)) #else -# define SNIPPETS_REGISTER_PASS_RELATIVE(PASS_PLACE, TARGET_PASS, PASS, ...) +# define SNIPPETS_REGISTER_PASS_RELATIVE_X86_64(PASS_PLACE, TARGET_PASS, PASS, ...) #endif // OPENVINO_ARCH_X86_64 - SNIPPETS_REGISTER_PASS_RELATIVE(Place::After, - ov::snippets::lowered::pass::MarkLoops, - ov::intel_cpu::pass::BrgemmCPUBlocking); +#if defined(OPENVINO_ARCH_ARM64) +# define SNIPPETS_REGISTER_PASS_RELATIVE_ARM64(PASS_PLACE, TARGET_PASS, PASS, ...) \ + backend_passes.emplace_back(PassPosition(PASS_PLACE, TARGET_PASS::get_type_info_static()), \ + std::make_shared(__VA_ARGS__)) +#else +# define SNIPPETS_REGISTER_PASS_RELATIVE_ARM64(PASS_PLACE, TARGET_PASS, PASS, ...) +#endif // OPENVINO_ARCH_ARM64 + + SNIPPETS_REGISTER_PASS_RELATIVE_X86_64(Place::After, + ov::snippets::lowered::pass::MarkLoops, + ov::intel_cpu::pass::BrgemmCPUBlocking); - SNIPPETS_REGISTER_PASS_RELATIVE(Place::After, - ov::snippets::lowered::pass::InitLoops, - ov::intel_cpu::pass::AdjustBrgemmCopyBLoopPorts); + SNIPPETS_REGISTER_PASS_RELATIVE_X86_64(Place::After, + ov::snippets::lowered::pass::InitLoops, + ov::intel_cpu::pass::AdjustBrgemmCopyBLoopPorts); - SNIPPETS_REGISTER_PASS_RELATIVE(Place::After, - ov::snippets::lowered::pass::InsertLoops, - ov::intel_cpu::pass::FuseLoadStoreConvert); - SNIPPETS_REGISTER_PASS_RELATIVE(Place::Before, - ov::snippets::lowered::pass::InsertBuffers, - ov::intel_cpu::pass::InsertBrgemmCopyBuffers); + SNIPPETS_REGISTER_PASS_RELATIVE_X86_64(Place::After, + ov::snippets::lowered::pass::InsertLoops, + ov::intel_cpu::pass::FuseLoadStoreConvert); + SNIPPETS_REGISTER_PASS_RELATIVE_X86_64(Place::Before, + ov::snippets::lowered::pass::InsertBuffers, + ov::intel_cpu::pass::InsertBrgemmCopyBuffers); #ifdef SNIPPETS_LIBXSMM_TPP - SNIPPETS_REGISTER_PASS_RELATIVE(Place::Before, - ov::intel_cpu::pass::BrgemmCPUBlocking, - ov::intel_cpu::tpp::pass::BrgemmTPPBlocking); - SNIPPETS_REGISTER_PASS_RELATIVE(Place::After, - ov::intel_cpu::pass::FuseLoadStoreConvert, - ov::intel_cpu::tpp::pass::SetTPPLeadingDim); + SNIPPETS_REGISTER_PASS_RELATIVE_X86_64(Place::Before, + ov::intel_cpu::pass::BrgemmCPUBlocking, + ov::intel_cpu::tpp::pass::BrgemmTPPBlocking); + SNIPPETS_REGISTER_PASS_RELATIVE_X86_64(Place::After, + ov::intel_cpu::pass::FuseLoadStoreConvert, + ov::intel_cpu::tpp::pass::SetTPPLeadingDim); + SNIPPETS_REGISTER_PASS_RELATIVE_ARM64(Place::After, + ov::snippets::lowered::pass::MarkLoops, + ov::intel_cpu::tpp::pass::BrgemmTPPBlocking); + SNIPPETS_REGISTER_PASS_RELATIVE_ARM64(Place::After, + ov::snippets::lowered::pass::InsertLoops, + ov::intel_cpu::tpp::pass::SetTPPLeadingDim); #endif -#undef SNIPPETS_REGISTER_PASS_RELATIVE +#undef SNIPPETS_REGISTER_PASS_RELATIVE_X86_64 +#undef SNIPPETS_REGISTER_PASS_RELATIVE_ARM64 return backend_passes; } diff --git a/src/plugins/intel_cpu/src/transformations/snippets/aarch64/shape_inference.cpp b/src/plugins/intel_cpu/src/transformations/snippets/aarch64/shape_inference.cpp index a3c9a1c184d550..041dedb06f8896 100644 --- a/src/plugins/intel_cpu/src/transformations/snippets/aarch64/shape_inference.cpp +++ b/src/plugins/intel_cpu/src/transformations/snippets/aarch64/shape_inference.cpp @@ -7,6 +7,7 @@ #include "snippets/shape_inference/shape_infer_instances.hpp" #include "transformations/cpu_opset/common/op/swish_cpu.hpp" #include "transformations/snippets/common/op/fused_mul_add.hpp" +#include "transformations/tpp/common/op/brgemm.hpp" namespace ov { namespace snippets { @@ -42,6 +43,7 @@ ShapeInferPtr CPUShapeInferSnippetsFactory::get_specific_op_shape_infer(const ov const CPUShapeInferSnippetsFactory::TRegistry CPUShapeInferSnippetsFactory::specific_ops_registry{ SHAPE_INFER_PREDEFINED(ov::intel_cpu::FusedMulAdd, NumpyBroadcastShapeInfer), SHAPE_INFER_PREDEFINED(ov::intel_cpu::SwishNode, PassThroughShapeInfer), + SHAPE_INFER_OP_SPECIFIC_EXTERNAL(ov::intel_cpu::tpp::op::BrgemmTPP, BrgemmShapeInfer), }; #undef SHAPE_INFER_OP_SPECIFIC #undef SHAPE_INFER_PREDEFINED diff --git a/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/brgemm_to_brgemm_cpu.cpp b/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/brgemm_to_brgemm_cpu.cpp index e0a87ca288bac1..f6bf4d29818a88 100644 --- a/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/brgemm_to_brgemm_cpu.cpp +++ b/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/brgemm_to_brgemm_cpu.cpp @@ -16,7 +16,7 @@ #include "transformations/snippets/x64/op/brgemm_copy_b.hpp" #include "transformations/snippets/x64/op/brgemm_cpu.hpp" #include "transformations/snippets/x64/op/brgemm_utils.hpp" -#include "transformations/tpp/x64/op/modifiers.hpp" +#include "transformations/tpp/common/op/modifiers.hpp" #include "utils/general_utils.h" namespace ov { diff --git a/src/plugins/intel_cpu/src/transformations/snippets/x64/shape_inference.cpp b/src/plugins/intel_cpu/src/transformations/snippets/x64/shape_inference.cpp index 50a2399e93ecc4..14d72c3d3d0969 100644 --- a/src/plugins/intel_cpu/src/transformations/snippets/x64/shape_inference.cpp +++ b/src/plugins/intel_cpu/src/transformations/snippets/x64/shape_inference.cpp @@ -14,7 +14,7 @@ #include "transformations/cpu_opset/common/op/swish_cpu.hpp" #include "transformations/snippets/common/op/fused_mul_add.hpp" #ifdef SNIPPETS_LIBXSMM_TPP -# include "transformations/tpp/x64/op/brgemm.hpp" +# include "transformations/tpp/common/op/brgemm.hpp" # include "transformations/tpp/x64/op/equation.hpp" # include "transformations/tpp/x64/op/reduce.hpp" # include "transformations/tpp/x64/op/scalar.hpp" diff --git a/src/plugins/intel_cpu/src/transformations/tpp/x64/op/brgemm.cpp b/src/plugins/intel_cpu/src/transformations/tpp/common/op/brgemm.cpp similarity index 100% rename from src/plugins/intel_cpu/src/transformations/tpp/x64/op/brgemm.cpp rename to src/plugins/intel_cpu/src/transformations/tpp/common/op/brgemm.cpp diff --git a/src/plugins/intel_cpu/src/transformations/tpp/x64/op/brgemm.hpp b/src/plugins/intel_cpu/src/transformations/tpp/common/op/brgemm.hpp similarity index 96% rename from src/plugins/intel_cpu/src/transformations/tpp/x64/op/brgemm.hpp rename to src/plugins/intel_cpu/src/transformations/tpp/common/op/brgemm.hpp index cda7f58afebea8..9c450ec93b96ba 100644 --- a/src/plugins/intel_cpu/src/transformations/tpp/x64/op/brgemm.hpp +++ b/src/plugins/intel_cpu/src/transformations/tpp/common/op/brgemm.hpp @@ -5,7 +5,7 @@ #pragma once #include "modifiers.hpp" -#include "transformations/snippets/x64/op/brgemm_cpu.hpp" +#include "snippets/op/brgemm.hpp" namespace ov { namespace intel_cpu { diff --git a/src/plugins/intel_cpu/src/transformations/tpp/x64/op/modifiers.hpp b/src/plugins/intel_cpu/src/transformations/tpp/common/op/modifiers.hpp similarity index 100% rename from src/plugins/intel_cpu/src/transformations/tpp/x64/op/modifiers.hpp rename to src/plugins/intel_cpu/src/transformations/tpp/common/op/modifiers.hpp diff --git a/src/plugins/intel_cpu/src/transformations/tpp/x64/pass/brgemm_to_brgemm_tpp.cpp b/src/plugins/intel_cpu/src/transformations/tpp/common/pass/brgemm_to_brgemm_tpp.cpp similarity index 99% rename from src/plugins/intel_cpu/src/transformations/tpp/x64/pass/brgemm_to_brgemm_tpp.cpp rename to src/plugins/intel_cpu/src/transformations/tpp/common/pass/brgemm_to_brgemm_tpp.cpp index c042373f054fa2..03cc43dbe82ab1 100644 --- a/src/plugins/intel_cpu/src/transformations/tpp/x64/pass/brgemm_to_brgemm_tpp.cpp +++ b/src/plugins/intel_cpu/src/transformations/tpp/common/pass/brgemm_to_brgemm_tpp.cpp @@ -11,7 +11,7 @@ #include "snippets/itt.hpp" #include "snippets/op/brgemm.hpp" #include "snippets/utils/utils.hpp" -#include "transformations/tpp/x64/op/brgemm.hpp" +#include "transformations/tpp/common/op/brgemm.hpp" #include "utils/general_utils.h" namespace ov { diff --git a/src/plugins/intel_cpu/src/transformations/tpp/x64/pass/brgemm_to_brgemm_tpp.hpp b/src/plugins/intel_cpu/src/transformations/tpp/common/pass/brgemm_to_brgemm_tpp.hpp similarity index 100% rename from src/plugins/intel_cpu/src/transformations/tpp/x64/pass/brgemm_to_brgemm_tpp.hpp rename to src/plugins/intel_cpu/src/transformations/tpp/common/pass/brgemm_to_brgemm_tpp.hpp diff --git a/src/plugins/intel_cpu/src/transformations/tpp/x64/pass/lowered/brgemm_tpp_blocking.cpp b/src/plugins/intel_cpu/src/transformations/tpp/common/pass/lowered/brgemm_tpp_blocking.cpp similarity index 98% rename from src/plugins/intel_cpu/src/transformations/tpp/x64/pass/lowered/brgemm_tpp_blocking.cpp rename to src/plugins/intel_cpu/src/transformations/tpp/common/pass/lowered/brgemm_tpp_blocking.cpp index d9485b1c6b7b9d..a7b1bd938dafe3 100644 --- a/src/plugins/intel_cpu/src/transformations/tpp/x64/pass/lowered/brgemm_tpp_blocking.cpp +++ b/src/plugins/intel_cpu/src/transformations/tpp/common/pass/lowered/brgemm_tpp_blocking.cpp @@ -9,7 +9,6 @@ #include "snippets/lowered/loop_manager.hpp" #include "snippets/snippets_isa.hpp" #include "snippets/utils/utils.hpp" -#include "transformations/tpp/x64/op/brgemm.hpp" namespace ov { namespace intel_cpu { @@ -55,6 +54,7 @@ ov::snippets::lowered::SpecificIterationHandlers BrgemmTPPBlocking::get_k_loop_h handlers.register_pass(); return handlers; } + } // namespace pass } // namespace tpp } // namespace intel_cpu diff --git a/src/plugins/intel_cpu/src/transformations/tpp/x64/pass/lowered/brgemm_tpp_blocking.hpp b/src/plugins/intel_cpu/src/transformations/tpp/common/pass/lowered/brgemm_tpp_blocking.hpp similarity index 97% rename from src/plugins/intel_cpu/src/transformations/tpp/x64/pass/lowered/brgemm_tpp_blocking.hpp rename to src/plugins/intel_cpu/src/transformations/tpp/common/pass/lowered/brgemm_tpp_blocking.hpp index 31f4bfeadc8979..64dd36d8dfbf71 100644 --- a/src/plugins/intel_cpu/src/transformations/tpp/x64/pass/lowered/brgemm_tpp_blocking.hpp +++ b/src/plugins/intel_cpu/src/transformations/tpp/common/pass/lowered/brgemm_tpp_blocking.hpp @@ -5,7 +5,7 @@ #pragma once #include "snippets/lowered/pass/brgemm_blocking.hpp" -#include "transformations/tpp/x64/op/brgemm.hpp" +#include "transformations/tpp/common/op/brgemm.hpp" namespace ov { namespace intel_cpu { diff --git a/src/plugins/intel_cpu/src/transformations/tpp/x64/pass/lowered/set_tpp_leading_dim.cpp b/src/plugins/intel_cpu/src/transformations/tpp/common/pass/lowered/set_tpp_leading_dim.cpp similarity index 99% rename from src/plugins/intel_cpu/src/transformations/tpp/x64/pass/lowered/set_tpp_leading_dim.cpp rename to src/plugins/intel_cpu/src/transformations/tpp/common/pass/lowered/set_tpp_leading_dim.cpp index c1b981275face0..7720e0b142d45f 100644 --- a/src/plugins/intel_cpu/src/transformations/tpp/x64/pass/lowered/set_tpp_leading_dim.cpp +++ b/src/plugins/intel_cpu/src/transformations/tpp/common/pass/lowered/set_tpp_leading_dim.cpp @@ -9,7 +9,7 @@ #include "snippets/op/brgemm.hpp" #include "snippets/op/buffer.hpp" #include "snippets/utils/utils.hpp" -#include "transformations/tpp/x64/op/modifiers.hpp" +#include "transformations/tpp/common/op/modifiers.hpp" namespace ov { namespace intel_cpu { diff --git a/src/plugins/intel_cpu/src/transformations/tpp/x64/pass/lowered/set_tpp_leading_dim.hpp b/src/plugins/intel_cpu/src/transformations/tpp/common/pass/lowered/set_tpp_leading_dim.hpp similarity index 100% rename from src/plugins/intel_cpu/src/transformations/tpp/x64/pass/lowered/set_tpp_leading_dim.hpp rename to src/plugins/intel_cpu/src/transformations/tpp/common/pass/lowered/set_tpp_leading_dim.hpp diff --git a/src/plugins/intel_cpu/src/transformations/tpp/x64/op/eltwise.hpp b/src/plugins/intel_cpu/src/transformations/tpp/x64/op/eltwise.hpp index 7338450ff8257d..0e0a38c0c6161c 100644 --- a/src/plugins/intel_cpu/src/transformations/tpp/x64/op/eltwise.hpp +++ b/src/plugins/intel_cpu/src/transformations/tpp/x64/op/eltwise.hpp @@ -5,7 +5,6 @@ #pragma once #include "descriptor.hpp" -#include "modifiers.hpp" #include "openvino/op/add.hpp" #include "openvino/op/divide.hpp" #include "openvino/op/exp.hpp" @@ -14,6 +13,7 @@ #include "openvino/op/subtract.hpp" #include "snippets/op/powerstatic.hpp" #include "snippets/utils/utils.hpp" +#include "transformations/tpp/common/op/modifiers.hpp" namespace ov { namespace intel_cpu { diff --git a/src/plugins/intel_cpu/src/transformations/tpp/x64/op/equation.hpp b/src/plugins/intel_cpu/src/transformations/tpp/x64/op/equation.hpp index bf16f149b415de..0df8ad2eb776f0 100644 --- a/src/plugins/intel_cpu/src/transformations/tpp/x64/op/equation.hpp +++ b/src/plugins/intel_cpu/src/transformations/tpp/x64/op/equation.hpp @@ -5,8 +5,8 @@ #pragma once #include "descriptor.hpp" -#include "modifiers.hpp" #include "openvino/op/op.hpp" +#include "transformations/tpp/common/op/modifiers.hpp" namespace ov { namespace intel_cpu { diff --git a/src/plugins/intel_cpu/src/transformations/tpp/x64/op/reduce.hpp b/src/plugins/intel_cpu/src/transformations/tpp/x64/op/reduce.hpp index 07ed321abc7ff5..a8cf39cef7bda5 100644 --- a/src/plugins/intel_cpu/src/transformations/tpp/x64/op/reduce.hpp +++ b/src/plugins/intel_cpu/src/transformations/tpp/x64/op/reduce.hpp @@ -6,8 +6,8 @@ #include "eltwise.hpp" #include "libxsmm_typedefs.h" -#include "modifiers.hpp" #include "snippets/op/reduce.hpp" +#include "transformations/tpp/common/op/modifiers.hpp" namespace ov { namespace intel_cpu { diff --git a/src/plugins/intel_cpu/src/transformations/tpp/x64/op/scalar.cpp b/src/plugins/intel_cpu/src/transformations/tpp/x64/op/scalar.cpp index 5855481efd1d60..f43d65d15180d3 100644 --- a/src/plugins/intel_cpu/src/transformations/tpp/x64/op/scalar.cpp +++ b/src/plugins/intel_cpu/src/transformations/tpp/x64/op/scalar.cpp @@ -4,8 +4,6 @@ #include "scalar.hpp" -#include "modifiers.hpp" - namespace ov { namespace intel_cpu { namespace tpp { diff --git a/src/plugins/intel_cpu/src/transformations/tpp/x64/op/scalar.hpp b/src/plugins/intel_cpu/src/transformations/tpp/x64/op/scalar.hpp index 9807dbfafa31d0..f836e01d32554a 100644 --- a/src/plugins/intel_cpu/src/transformations/tpp/x64/op/scalar.hpp +++ b/src/plugins/intel_cpu/src/transformations/tpp/x64/op/scalar.hpp @@ -5,8 +5,8 @@ #pragma once #include "eltwise.hpp" -#include "modifiers.hpp" #include "snippets/op/reduce.hpp" +#include "transformations/tpp/common/op/modifiers.hpp" namespace ov { namespace intel_cpu { diff --git a/src/plugins/intel_cpu/src/transformations/tpp/x64/pass/scalar_to_scalar_tpp.cpp b/src/plugins/intel_cpu/src/transformations/tpp/x64/pass/scalar_to_scalar_tpp.cpp index f5188df53aeb28..33caae3e9e42d3 100644 --- a/src/plugins/intel_cpu/src/transformations/tpp/x64/pass/scalar_to_scalar_tpp.cpp +++ b/src/plugins/intel_cpu/src/transformations/tpp/x64/pass/scalar_to_scalar_tpp.cpp @@ -8,7 +8,7 @@ #include "snippets/itt.hpp" #include "snippets/lowered/port_connector.hpp" #include "snippets/op/scalar.hpp" -#include "transformations/tpp/x64/op/modifiers.hpp" +#include "transformations/tpp/common/op/modifiers.hpp" #include "transformations/tpp/x64/op/scalar.hpp" namespace ov { diff --git a/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp b/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp index 880cdd54c42812..14de24a1c89170 100644 --- a/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp +++ b/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp @@ -161,7 +161,7 @@ #include "snippets/pass/split_dimension_m.hpp" #include "snippets/pass/tokenization.hpp" #if defined(SNIPPETS_LIBXSMM_TPP) -# include "transformations/tpp/x64/pass/brgemm_to_brgemm_tpp.hpp" +# include "transformations/tpp/common/pass/brgemm_to_brgemm_tpp.hpp" #endif // Misc diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/skip_tests_config.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/skip_tests_config.cpp index 4eb4fa819e3224..221c22937c6c03 100644 --- a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/skip_tests_config.cpp +++ b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/skip_tests_config.cpp @@ -459,8 +459,9 @@ std::vector disabledTestPatterns() { // Issue: 126738 retVector.emplace_back(R"(smoke_Snippets.*\[.*\?.*\].*)"); retVector.emplace_back(R"(smoke_Snippets_Eltwise.*\[1.1..10.1..8.1..4\].*)"); - // smoke_Snippets test cases are not supported on arm64 platforms, except for smoke_Snippets_Eltwise - retVector.emplace_back(R"(smoke_Snippets(?!_Eltwise|_Convert).*)"); + // smoke_Snippets test cases are not supported on arm64 platforms, + // except for smoke_Snippets_Eltwise and smoke_Snippets_MatMul(t) + retVector.emplace_back(R"(smoke_Snippets(?!_Eltwise|_Convert|_MatMul/|_MatMult/).*)"); // arm snippets doesn't support sve_128 that required by dnnl injector jit_uni_eltwise_injector_f32 yet retVector.emplace_back(R"(smoke_Snippets_Eltwise_TwoResults.*)"); retVector.emplace_back(R"(smoke_Snippets_Eltwise/TwoInputsAndOutputs.*)"); @@ -492,13 +493,13 @@ std::vector disabledTestPatterns() { retVector.emplace_back(R"(.*smoke_LPT/RecurrentCellTransformation.CompareWithRefImpl/f32_\[1,1,3\]_CPU_f32FQ_X_level=256_.*_FQ_W_level=255.*)"); retVector.emplace_back(R"(.*smoke_static/ConvertFqRnnToQuantizedRnn.CompareWithRefs/Type=GRUSequence.*2.5.10.*2.1.4.*2.1.4.*)"); } +#if defined(OPENVINO_ARCH_X86) || defined(OPENVINO_ARCH_X86_64) if (!ov::with_cpu_x86_avx2()) { // MatMul in Snippets uses BRGEMM that is supported only on AVX2 (and newer) platforms // Disabled Snippets MHA tests as well because MHA pattern contains MatMul retVector.emplace_back(R"(.*Snippets.*MHA.*)"); retVector.emplace_back(R"(.*Snippets.*(MatMul|Matmul).*)"); } -#if defined(OPENVINO_ARCH_X86) || defined(OPENVINO_ARCH_X86_64) if (!ov::with_cpu_x86_avx512_core_fp16()) { // Skip fp16 tests for paltforms that don't support fp16 precision retVector.emplace_back(R"(.*INFERENCE_PRECISION_HINT=(F|f)16.*)"); diff --git a/src/plugins/intel_cpu/tests/unit/snippets_transformations/x64/lowered/brgemm_blocking.cpp b/src/plugins/intel_cpu/tests/unit/snippets_transformations/x64/lowered/brgemm_blocking.cpp index fc6783f3b3ca45..d1fc05ded13554 100644 --- a/src/plugins/intel_cpu/tests/unit/snippets_transformations/x64/lowered/brgemm_blocking.cpp +++ b/src/plugins/intel_cpu/tests/unit/snippets_transformations/x64/lowered/brgemm_blocking.cpp @@ -4,7 +4,7 @@ #include "transformations/snippets/x64/pass/lowered/brgemm_cpu_blocking.hpp" #ifdef SNIPPETS_LIBXSMM_TPP - #include "transformations/tpp/x64/pass/lowered/brgemm_tpp_blocking.hpp" + #include "transformations/tpp/common/pass/lowered/brgemm_tpp_blocking.hpp" #endif #include "lir_test_utils.hpp" @@ -13,7 +13,7 @@ #include "snippets/snippets_isa.hpp" #include "transformations/snippets/x64/op/brgemm_copy_b.hpp" #include "transformations/snippets/x64/op/brgemm_cpu.hpp" -#include "transformations/tpp/x64/op/brgemm.hpp" +#include "transformations/tpp/common/op/brgemm.hpp" #include "cpu/x64/cpu_isa_traits.hpp" namespace ov { diff --git a/src/plugins/intel_cpu/thirdparty/CMakeLists.txt b/src/plugins/intel_cpu/thirdparty/CMakeLists.txt index f25b52057848c0..038e012902d647 100644 --- a/src/plugins/intel_cpu/thirdparty/CMakeLists.txt +++ b/src/plugins/intel_cpu/thirdparty/CMakeLists.txt @@ -154,7 +154,12 @@ function(ov_add_onednn) endif() endfunction() +if(AARCH64 AND (NOT ANDROID)) + set(ENABLE_SNIPPETS_LIBXSMM_TPP ON) +endif() + if (ENABLE_SNIPPETS_LIBXSMM_TPP) + ov_add_compiler_flags(-Wno-missing-declarations) add_subdirectory(libxsmm) ov_install_static_lib(libxsmm ${OV_CPACK_COMP_CORE}) endif()