Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CPU][ARM]Snippets MatMul via brgemm emitter and executor #28304

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cmake/features.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ ov_dependent_option (ENABLE_GPU_DEBUG_CAPS "enable GPU debug capabilities at run
ov_dependent_option (ENABLE_CPU_DEBUG_CAPS "enable CPU debug capabilities at runtime" ON "ENABLE_DEBUG_CAPS;ENABLE_INTEL_CPU" OFF)
ov_dependent_option (ENABLE_SNIPPETS_DEBUG_CAPS "enable Snippets debug capabilities at runtime" ON "ENABLE_DEBUG_CAPS" OFF)

ov_dependent_option (ENABLE_SNIPPETS_LIBXSMM_TPP "allow Snippets to use LIBXSMM Tensor Processing Primitives" OFF "ENABLE_INTEL_CPU AND X86_64" OFF)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are also RISCV64, AArch32, etc. Can we add only the supported archs to the condition?

ov_dependent_option (ENABLE_SNIPPETS_LIBXSMM_TPP "allow Snippets to use LIBXSMM Tensor Processing Primitives" OFF "ENABLE_INTEL_CPU" OFF)

ov_option (ENABLE_PROFILING_ITT "Build with ITT tracing. Optionally configure pre-built ittnotify library though INTEL_VTUNE_DIR variable." OFF)

Expand Down
10 changes: 9 additions & 1 deletion src/plugins/intel_cpu/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,10 @@ if(ENABLE_CPU_DEBUG_CAPS)
add_definitions(-DCPU_DEBUG_CAPS)
endif()

# On AArch64 builds that do not target Android, unconditionally switch the
# Snippets LIBXSMM TPP backend on; this overrides whatever value
# ENABLE_SNIPPETS_LIBXSMM_TPP had before this point.
if(AARCH64 AND (NOT ANDROID))
set(ENABLE_SNIPPETS_LIBXSMM_TPP ON)
endif()

if (ENABLE_SNIPPETS_LIBXSMM_TPP)
# Note: LIBXSMM_DEFAULT_CONFIG needed so libxsmm_config can be included without issues
add_definitions(-DSNIPPETS_LIBXSMM_TPP -DLIBXSMM_DEFAULT_CONFIG)
Expand Down Expand Up @@ -198,7 +202,9 @@ if(NOT X86_64)
${CMAKE_CURRENT_SOURCE_DIR}/src/nodes/kernels/x64/*
${CMAKE_CURRENT_SOURCE_DIR}/src/emitters/plugin/x64/*
${CMAKE_CURRENT_SOURCE_DIR}/src/emitters/snippets/x64/*
${CMAKE_CURRENT_SOURCE_DIR}/src/transformations/cpu_opset/x64/*)
${CMAKE_CURRENT_SOURCE_DIR}/src/emitters/tpp/x64/*
${CMAKE_CURRENT_SOURCE_DIR}/src/transformations/cpu_opset/x64/*
${CMAKE_CURRENT_SOURCE_DIR}/src/transformations/tpp/x64/*)
endif()

if (AARCH64)
Expand All @@ -208,7 +214,9 @@ endif()

# When the target is neither AArch64 nor ARM, append the ARM/AArch64-specific
# source globs (cpu_opset, tpp transformations, plugin and tpp emitters,
# executors and kernels) to EXCLUDE_PATHS.
if(NOT (AARCH64 OR ARM))
list(APPEND EXCLUDE_PATHS ${CMAKE_CURRENT_SOURCE_DIR}/src/transformations/cpu_opset/arm/*
${CMAKE_CURRENT_SOURCE_DIR}/src/transformations/tpp/aarch64/*
${CMAKE_CURRENT_SOURCE_DIR}/src/emitters/plugin/aarch64/*
${CMAKE_CURRENT_SOURCE_DIR}/src/emitters/tpp/aarch64/*
${CMAKE_CURRENT_SOURCE_DIR}/src/nodes/executors/aarch64/*
${CMAKE_CURRENT_SOURCE_DIR}/src/nodes/kernels/aarch64/*)
endif()
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (C) 2024 Intel Corporation
// Copyright (C) 2024-2025 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

Expand All @@ -10,6 +10,7 @@
#include "emitters/snippets/aarch64/jit_kernel_emitter.hpp"
#include "emitters/snippets/aarch64/jit_loop_emitters.hpp"
#include "emitters/snippets/aarch64/jit_memory_emitters.hpp"
#include "emitters/snippets/cpu_kernel_executor_table.hpp"
#include "emitters/snippets/cpu_runtime_configurator.hpp"
#include "emitters/utils.hpp"
#include "jit_snippets_emitters.hpp"
Expand All @@ -24,12 +25,17 @@
#include "transformations/cpu_opset/common/op/swish_cpu.hpp"
#include "transformations/snippets/common/op/fused_mul_add.hpp"

#ifdef SNIPPETS_LIBXSMM_TPP
# include "emitters/tpp/aarch64/jit_brgemm_emitter.hpp"
# include "transformations/tpp/common/op/brgemm.hpp"
#endif

namespace ov {

#define CREATE_SNIPPETS_EMITTER(e_type) \
#define CREATE_SNIPPETS_EMITTER(e_type, ...) \
{ \
[this](const snippets::lowered::ExpressionPtr& expr) -> std::shared_ptr<snippets::Emitter> { \
return std::make_shared<e_type>(h.get(), isa, expr); \
return std::make_shared<e_type>(h.get(), isa, expr, ##__VA_ARGS__); \
}, \
[](const std::shared_ptr<ov::Node>& n) -> std::set<std::vector<element::Type>> { \
return e_type::get_supported_precisions(n); \
Expand Down Expand Up @@ -202,6 +208,12 @@ CPUTargetMachine::CPUTargetMachine(dnnl::impl::cpu::aarch64::cpu_isa_t host_isa,
jitters[ov::intel_cpu::SwishNode::get_type_info_static()] = CREATE_CPU_EMITTER(jit_swish_emitter);
jitters[ov::op::v0::Tanh::get_type_info_static()] = CREATE_CPU_EMITTER(jit_tanh_emitter);

#ifdef SNIPPETS_LIBXSMM_TPP
// brgemm
jitters[ov::intel_cpu::tpp::op::BrgemmTPP::get_type_info_static()] =
CREATE_SNIPPETS_EMITTER(jit_brgemm_emitter, configurator->get_kernel_executor_table(), compiled_kernel_cache);
#endif

// control flow
jitters[snippets::op::KernelStatic::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(jit_kernel_static_emitter);
jitters[snippets::op::KernelDynamic::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(jit_kernel_dynamic_emitter);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,11 @@ void BrgemmBaseKernelExecutor::update_config(const ov::snippets::lowered::Expres
beta = get_beta(loop_manager, static_cast<int>(loop_ids.back()), current_expanded_loop_info);
}

#ifndef OPENVINO_ARCH_X86_64
config.update(DIM_CAST(M), DIM_CAST(N), DIM_CAST(K), 0, 0, 0, beta);
return;
#endif

Comment on lines +251 to +255
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd say that there should be a common cross-arch base class with a method init_runtime_params(M,N,K,LDA,LDB,LDC).
Then x64 dnnl executors update LDB and LDA if needed. aarch64 (tpp) should call update_config(...) with these parameters.

If we use ifdef in common code, I think this is a sign of problematic code and we should resolve it.

const auto LDA = DIM_CAST(snippets::utils::get_dim_stride(expr->get_input_port(0)));
const auto LDC = DIM_CAST(snippets::utils::get_dim_stride(expr->get_output_port(0)));
auto LDB = DIM_CAST(snippets::utils::get_dim_stride(expr->get_input_port(1)));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,13 @@
#include <cpu/x64/brgemm/brgemm.hpp>

#include "cpu/x64/cpu_isa_traits.hpp"
#include "emitters/plugin/x64/jit_emitter.hpp"
#include "emitters/snippets/cpu_kernel_executor_table.hpp"
#include "emitters/snippets/jit_snippets_call_args.hpp"
#include "emitters/utils.hpp"
#include "openvino/core/type/element_type.hpp"
#include "snippets/lowered/loop_info.hpp"
#include "snippets/lowered/loop_manager.hpp"
#include "utils/general_utils.h"

namespace ov {
namespace intel_cpu {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,10 @@
# include "emitters/tpp/x64/jit_eltwise_emitters.hpp"
# include "emitters/tpp/x64/jit_equation_emitter.hpp"
# include "emitters/tpp/x64/jit_scalar_emitter.hpp"
# include "transformations/tpp/x64/op/brgemm.hpp"
# include "transformations/tpp/common/op/brgemm.hpp"
# include "transformations/tpp/common/op/modifiers.hpp"
# include "transformations/tpp/x64/op/eltwise.hpp"
# include "transformations/tpp/x64/op/equation.hpp"
# include "transformations/tpp/x64/op/modifiers.hpp"
# include "transformations/tpp/x64/op/reduce.hpp"
# include "transformations/tpp/x64/op/scalar.hpp"
// Note: for reference implementations
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
#pragma once

#include "emitters/plugin/x64/jit_emitter.hpp"
#include "emitters/snippets/x64/kernel_executors/brgemm_base.hpp"
#include "emitters/snippets/brgemm_base.hpp"
#include "jit_binary_call_emitter.hpp"

namespace ov {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

#pragma once

#include "brgemm_base.hpp"
#include "emitters/snippets/brgemm_base.hpp"

namespace ov {
namespace intel_cpu {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
#include <cpu/x64/brgemm/brgemm.hpp>
#include <cpu/x64/matmul/brgemm_matmul_copy_utils.hpp>

#include "brgemm_base.hpp"
#include "emitters/plugin/x64/jit_emitter.hpp"
#include "emitters/snippets/brgemm_base.hpp"
#include "emitters/snippets/cpu_kernel_executor_table.hpp"
#include "emitters/snippets/jit_snippets_call_args.hpp"

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
// Copyright (C) 2025 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "jit_brgemm_emitter.hpp"

#include "snippets/utils/utils.hpp"
#include "transformations/tpp/common/op/brgemm.hpp"

using namespace Xbyak_aarch64;

namespace ov {
namespace intel_cpu {
namespace aarch64 {

using jit_generator = dnnl::impl::cpu::aarch64::jit_generator;
using cpu_isa_t = dnnl::impl::cpu::aarch64::cpu_isa_t;
using ExpressionPtr = ov::snippets::lowered::ExpressionPtr;

// Builds the emitter for a BrgemmTPP expression.
//
// h                     - JIT generator the code is emitted into.
// isa                   - target ISA, forwarded to the base emitter and to the kernel config.
// expr                  - lowered expression; its node must be a BrgemmTPP operation.
// kernel_table          - table in which the Brgemm kernel executor is registered.
// compiled_kernel_cache - cache handle forwarded to the registered executor.
//
// The emitter works on general-purpose registers only (gpr_to_gpr).
jit_brgemm_emitter::jit_brgemm_emitter(jit_generator* h,
                                       cpu_isa_t isa,
                                       const ExpressionPtr& expr,
                                       const snippets::KernelExecutorTablePtr& kernel_table,
                                       const ov::intel_cpu::MultiCacheWeakPtr& compiled_kernel_cache)
    : jit_emitter(h, isa) {
    in_out_type_ = emitter_in_out_map::gpr_to_gpr;

    // as_type_ptr yields nullptr on a type mismatch, so fail loudly instead of dereferencing it.
    const auto brgemm_node = as_type_ptr<intel_cpu::tpp::op::BrgemmTPP>(expr->get_node());
    OV_CPU_JIT_EMITTER_ASSERT(brgemm_node != nullptr, "Expects BrgemmTPP node");

    // The kernel configuration is keyed by the precisions of both matrix inputs.
    const auto& brg0Prc = brgemm_node->get_input_element_type(0);
    const auto& brg1Prc = brgemm_node->get_input_element_type(1);
    BrgemmKernelConfig kernel_config(brg0Prc, brg1Prc, isa);
    m_kernel_executor = kernel_table->register_kernel<BrgemmKernelExecutor>(expr, compiled_kernel_cache, kernel_config);
}

// Reports the input precision combinations this emitter accepts.
// The `node` argument is not consulted: on ARM only the f32 x f32 pair is supported.
std::set<std::vector<element::Type>> jit_brgemm_emitter::get_supported_precisions(
    const std::shared_ptr<ov::Node>& node) {
    const std::vector<element::Type> f32_inputs{element::f32, element::f32};

    std::set<std::vector<element::Type>> supported;
    supported.insert(f32_inputs);
    return supported;
}

// Checks that exactly two input register indices and one output register index were supplied.
// Raises through OV_CPU_JIT_EMITTER_ASSERT otherwise; the message reports the actual count.
void jit_brgemm_emitter::validate_arguments(const std::vector<size_t>& in, const std::vector<size_t>& out) const {
    // Trailing space after "got" keeps the count separated from the text (previously rendered as e.g. "got3").
    OV_CPU_JIT_EMITTER_ASSERT(in.size() == 2, "Expects 2 input regs, got " + std::to_string(in.size()));
    OV_CPU_JIT_EMITTER_ASSERT(out.size() == 1, "Expects 1 output reg, got " + std::to_string(out.size()));
}

// Entry point for code emission: validates the register indices and forwards them to emit_impl.
// pool_vec_idxs and pool_gpr_idxs are accepted for interface compatibility but are not used here.
void jit_brgemm_emitter::emit_code(const std::vector<size_t>& in,
const std::vector<size_t>& out,
const std::vector<size_t>& pool_vec_idxs,
const std::vector<size_t>& pool_gpr_idxs) const {
// Note: emit_impl performs the same validation again.
validate_arguments(in, out);
emit_impl(in, out);
}

// Emits the call to BrgemmKernelExecutor::execute from the generated kernel.
// `in` holds the indices of the two input GPRs, `out` the index of the single output GPR.
// Emitted sequence: store context, load the callee address and four arguments, branch-and-link,
// restore context.
void jit_brgemm_emitter::emit_impl(const std::vector<size_t>& in, const std::vector<size_t>& out) const {
validate_arguments(in, out);
// An empty exclusion set is handed to both store_context and restore_context.
std::unordered_set<size_t> exclude = {};
store_context(exclude);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please note that we will merge #27391 soon. This PR efficently provides efficient work with reg spills - we will able to spill only needed (live) registers

Just for information and to align with other our activities 😊


// x9 receives the address of the function that is called via blr below.
Xbyak_aarch64::XReg func_reg(9);
h->mov(func_reg, get_execute_function_ptr());
// x0..x3 are loaded with the four call arguments (AAPCS64 passes the first integer arguments in x0-x7).
Xbyak_aarch64::XReg x0(0);
Xbyak_aarch64::XReg x1(1);
Xbyak_aarch64::XReg x2(2);
Xbyak_aarch64::XReg x3(3);

// Argument 0: address of the kernel executor object.
const auto& compiled_kernel = get_compiled_kernel_ptr();
h->mov(x0, compiled_kernel);
// Arguments 1..3: first input, second input and output registers, moved one after another.
// NOTE(review): the moves are sequential, so if in[0]/in[1]/out[0] is itself one of x0..x3 or x9,
// its value is overwritten before it is read (e.g. in[0] == 0, or in[1] == 1) - confirm that the
// register allocation can never produce such indices, or stage the values first.
h->mov(x1, Xbyak_aarch64::XReg(in[0]));
h->mov(x2, Xbyak_aarch64::XReg(in[1]));
h->mov(x3, Xbyak_aarch64::XReg(out[0]));
h->blr(func_reg);

restore_context(exclude);
}

// Returns the address of the registered kernel executor object as an integer,
// so it can be loaded into a register as an immediate.
const uintptr_t jit_brgemm_emitter::get_compiled_kernel_ptr() const {
    const auto* executor = m_kernel_executor.get();
    return reinterpret_cast<uintptr_t>(executor);
}

// Returns the address of BrgemmKernelExecutor::execute as an integer,
// so it can be loaded into a register as an immediate.
const uintptr_t jit_brgemm_emitter::get_execute_function_ptr() const {
    const auto execute_fn = &BrgemmKernelExecutor::execute;
    return reinterpret_cast<uintptr_t>(execute_fn);
}

} // namespace aarch64
} // namespace intel_cpu
} // namespace ov
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
// Copyright (C) 2025 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "emitters/plugin/aarch64/jit_emitter.hpp"
#include "emitters/tpp/aarch64/kernel_executors/brgemm.hpp"

namespace ov {
namespace intel_cpu {
namespace aarch64 {

// AArch64 JIT emitter for the TPP Brgemm operation.
// It registers a BrgemmKernelExecutor for the expression at construction time and emits
// a call to that executor from the generated code.
class jit_brgemm_emitter : public jit_emitter {
public:
// h                     - JIT generator the code is emitted into.
// isa                   - target ISA.
// expr                  - lowered expression whose node is the Brgemm operation.
// kernel_table          - table in which the kernel executor is registered.
// compiled_kernel_cache - cache handle passed on to the registered executor.
jit_brgemm_emitter(dnnl::impl::cpu::aarch64::jit_generator* h,
dnnl::impl::cpu::aarch64::cpu_isa_t isa,
const ov::snippets::lowered::ExpressionPtr& expr,
const snippets::KernelExecutorTablePtr& kernel_table,
const ov::intel_cpu::MultiCacheWeakPtr& compiled_kernel_cache);

// The emitter takes two inputs.
size_t get_inputs_count() const override {
return 2;
}

// Returns the supported input precision combinations (only {f32, f32} on ARM).
static std::set<std::vector<element::Type>> get_supported_precisions(
const std::shared_ptr<ov::Node>& node = nullptr);

// Validates the register indices and emits the executor call.
// The pool index arguments are not used by this emitter.
void emit_code(const std::vector<size_t>& in_idxs,
const std::vector<size_t>& out_idxs,
const std::vector<size_t>& pool_vec_idxs = {},
const std::vector<size_t>& pool_gpr_idxs = {}) const override;

private:
// Asserts that there are exactly 2 input and 1 output register indices.
void validate_arguments(const std::vector<size_t>& in, const std::vector<size_t>& out) const override;
// Emits the context save, argument setup, call and context restore.
void emit_impl(const std::vector<size_t>& in, const std::vector<size_t>& out) const override;

// Address of BrgemmKernelExecutor::execute as an integer.
const uintptr_t get_execute_function_ptr() const;
// Address of the executor object held in m_kernel_executor as an integer.
const uintptr_t get_compiled_kernel_ptr() const;

// Executor obtained from the kernel executor table in the constructor.
std::shared_ptr<BrgemmKernelExecutor> m_kernel_executor = nullptr;
};

} // namespace aarch64
} // namespace intel_cpu
} // namespace ov
Loading
Loading