-
Notifications
You must be signed in to change notification settings - Fork 2.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[CPU][ARM] Snippets MatMul via brgemm emitter and executor #28304
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -248,6 +248,11 @@ void BrgemmBaseKernelExecutor::update_config(const ov::snippets::lowered::Expres | |
beta = get_beta(loop_manager, static_cast<int>(loop_ids.back()), current_expanded_loop_info); | ||
} | ||
|
||
#ifndef OPENVINO_ARCH_X86_64 | ||
config.update(DIM_CAST(M), DIM_CAST(N), DIM_CAST(K), 0, 0, 0, beta); | ||
return; | ||
#endif | ||
|
||
Comment on lines
+251
to
+255
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'd say that there should be common cross-arch base class with method If we use |
||
const auto LDA = DIM_CAST(snippets::utils::get_dim_stride(expr->get_input_port(0))); | ||
const auto LDC = DIM_CAST(snippets::utils::get_dim_stride(expr->get_output_port(0))); | ||
auto LDB = DIM_CAST(snippets::utils::get_dim_stride(expr->get_input_port(1))); | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
// Copyright (C) 2025 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#include "jit_brgemm_emitter.hpp" | ||
|
||
#include "snippets/utils/utils.hpp" | ||
#include "transformations/tpp/common/op/brgemm.hpp" | ||
|
||
using namespace Xbyak_aarch64; | ||
|
||
namespace ov { | ||
namespace intel_cpu { | ||
namespace aarch64 { | ||
|
||
using jit_generator = dnnl::impl::cpu::aarch64::jit_generator; | ||
using cpu_isa_t = dnnl::impl::cpu::aarch64::cpu_isa_t; | ||
using ExpressionPtr = ov::snippets::lowered::ExpressionPtr; | ||
|
||
/// Creates the emitter and registers the Brgemm kernel executor for `expr`. | ||
/// @param h JIT generator the emitter writes code into. | ||
/// @param isa Target ISA; forwarded to the base emitter and to the kernel config. | ||
/// @param expr Expression whose node must be a BrgemmTPP operation. | ||
/// @param kernel_table Table in which the BrgemmKernelExecutor is registered. | ||
/// @param compiled_kernel_cache Cache handle forwarded to the registered executor. | ||
jit_brgemm_emitter::jit_brgemm_emitter(jit_generator* h, | ||
                                       cpu_isa_t isa, | ||
                                       const ExpressionPtr& expr, | ||
                                       const snippets::KernelExecutorTablePtr& kernel_table, | ||
                                       const ov::intel_cpu::MultiCacheWeakPtr& compiled_kernel_cache) | ||
    : jit_emitter(h, isa) { | ||
    in_out_type_ = emitter_in_out_map::gpr_to_gpr; | ||
    // as_type_ptr yields nullptr for any other node type: report it instead of dereferencing a null pointer. | ||
    const auto brgemm_node = as_type_ptr<intel_cpu::tpp::op::BrgemmTPP>(expr->get_node()); | ||
    OV_CPU_JIT_EMITTER_ASSERT(brgemm_node != nullptr, "Expects BrgemmTPP node"); | ||
    const auto& brg0Prc = brgemm_node->get_input_element_type(0); | ||
    const auto& brg1Prc = brgemm_node->get_input_element_type(1); | ||
    // The kernel config is keyed by both input precisions and the ISA. | ||
    BrgemmKernelConfig kernel_config(brg0Prc, brg1Prc, isa); | ||
    m_kernel_executor = kernel_table->register_kernel<BrgemmKernelExecutor>(expr, compiled_kernel_cache, kernel_config); | ||
} | ||
|
||
/// Reports the input precision combinations accepted by the Brgemm emitter. | ||
/// The `node` argument is not consulted. | ||
std::set<std::vector<element::Type>> jit_brgemm_emitter::get_supported_precisions( | ||
    const std::shared_ptr<ov::Node>& node) { | ||
    // Only the fp32 x fp32 combination is available for Brgemm on ARM at the moment. | ||
    const std::vector<element::Type> fp32_inputs{element::f32, element::f32}; | ||
    return {fp32_inputs}; | ||
} | ||
|
||
/// Checks the register index lists handed to the emitter. | ||
/// Asserts that exactly two input registers and one output register are supplied. | ||
void jit_brgemm_emitter::validate_arguments(const std::vector<size_t>& in, const std::vector<size_t>& out) const { | ||
    // Keep a space after "got" so the reported count does not fuse with the message text. | ||
    OV_CPU_JIT_EMITTER_ASSERT(in.size() == 2, "Expects 2 input regs, got " + std::to_string(in.size())); | ||
    OV_CPU_JIT_EMITTER_ASSERT(out.size() == 1, "Expects 1 output reg, got " + std::to_string(out.size())); | ||
} | ||
|
||
/// Public entry point: emits the Brgemm call for the given register indices. | ||
/// @param in Indices of the two input registers. | ||
/// @param out Index of the single output register. | ||
/// @param pool_vec_idxs Not used by this emitter. | ||
/// @param pool_gpr_idxs Not used by this emitter. | ||
void jit_brgemm_emitter::emit_code(const std::vector<size_t>& in, | ||
                                   const std::vector<size_t>& out, | ||
                                   const std::vector<size_t>& pool_vec_idxs, | ||
                                   const std::vector<size_t>& pool_gpr_idxs) const { | ||
    // emit_impl validates `in`/`out` as its first step, so checking here as well would only repeat the work. | ||
    emit_impl(in, out); | ||
} | ||
|
||
// Emits an indirect call to BrgemmKernelExecutor::execute with four arguments: | ||
// the executor address, then the registers in[0], in[1] and out[0]. | ||
// `in` must hold two register indices and `out` one (checked by validate_arguments). | ||
void jit_brgemm_emitter::emit_impl(const std::vector<size_t>& in, const std::vector<size_t>& out) const { | ||
validate_arguments(in, out); | ||
// Empty exclusion set: nothing is excluded from store_context/restore_context. | ||
// presumably these spill and reload the register state around the call - confirm in jit_emitter. | ||
std::unordered_set<size_t> exclude = {}; | ||
store_context(exclude); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please note that we will merge #27391 soon. This PR efficently provides efficient work with reg spills - we will able to spill only needed (live) registers Just for information and to align with other our activities 😊 |
||
|
||
// x9 holds the callee address; x0-x3 receive the four call arguments. | ||
// NOTE(review): x9 and x0..x2 are written before in[0], in[1] and out[0] are read below. | ||
// If any of those indices is 9, 0, 1 or 2, its value is overwritten before being copied - | ||
// confirm the register allocator can never assign these registers to the Brgemm ports. | ||
Xbyak_aarch64::XReg func_reg(9); | ||
h->mov(func_reg, get_execute_function_ptr()); | ||
Xbyak_aarch64::XReg x0(0); | ||
Xbyak_aarch64::XReg x1(1); | ||
Xbyak_aarch64::XReg x2(2); | ||
Xbyak_aarch64::XReg x3(3); | ||
|
||
// First argument: address of the executor registered in the constructor (an immediate). | ||
// The reference below binds to the temporary returned by value. | ||
const auto& compiled_kernel = get_compiled_kernel_ptr(); | ||
h->mov(x0, compiled_kernel); | ||
// Remaining arguments: copies of the two input registers and the output register. | ||
h->mov(x1, Xbyak_aarch64::XReg(in[0])); | ||
h->mov(x2, Xbyak_aarch64::XReg(in[1])); | ||
h->mov(x3, Xbyak_aarch64::XReg(out[0])); | ||
// Branch-with-link to the address held in x9. | ||
h->blr(func_reg); | ||
|
||
// Counterpart of store_context above, using the same (empty) exclusion set. | ||
restore_context(exclude); | ||
} | ||
|
||
const uintptr_t jit_brgemm_emitter::get_compiled_kernel_ptr() const { | ||
return reinterpret_cast<const uintptr_t>(m_kernel_executor.get()); | ||
} | ||
|
||
const uintptr_t jit_brgemm_emitter::get_execute_function_ptr() const { | ||
return reinterpret_cast<const uintptr_t>(BrgemmKernelExecutor::execute); | ||
} | ||
|
||
} // namespace aarch64 | ||
} // namespace intel_cpu | ||
} // namespace ov |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
// Copyright (C) 2025 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#pragma once | ||
|
||
#include "emitters/plugin/aarch64/jit_emitter.hpp" | ||
#include "emitters/tpp/aarch64/kernel_executors/brgemm.hpp" | ||
|
||
namespace ov { | ||
namespace intel_cpu { | ||
namespace aarch64 { | ||
|
||
/// JIT emitter for the AArch64 snippets Brgemm operation. | ||
/// The emitted code calls BrgemmKernelExecutor::execute through the executor registered at construction. | ||
class jit_brgemm_emitter : public jit_emitter { | ||
public: | ||
/// Builds a BrgemmKernelConfig from the input element types of the node behind `expr` | ||
/// and registers a BrgemmKernelExecutor for `expr` in `kernel_table`. | ||
jit_brgemm_emitter(dnnl::impl::cpu::aarch64::jit_generator* h, | ||
dnnl::impl::cpu::aarch64::cpu_isa_t isa, | ||
const ov::snippets::lowered::ExpressionPtr& expr, | ||
const snippets::KernelExecutorTablePtr& kernel_table, | ||
const ov::intel_cpu::MultiCacheWeakPtr& compiled_kernel_cache); | ||
|
||
/// The emitter takes two inputs. | ||
size_t get_inputs_count() const override { | ||
return 2; | ||
} | ||
|
||
/// Supported input precision combinations; the implementation returns only {f32, f32} | ||
/// and does not consult `node`. | ||
static std::set<std::vector<element::Type>> get_supported_precisions( | ||
const std::shared_ptr<ov::Node>& node = nullptr); | ||
|
||
/// Forwards `in_idxs`/`out_idxs` to emit_impl; the implementation does not use the pool index arguments. | ||
void emit_code(const std::vector<size_t>& in_idxs, | ||
const std::vector<size_t>& out_idxs, | ||
const std::vector<size_t>& pool_vec_idxs = {}, | ||
const std::vector<size_t>& pool_gpr_idxs = {}) const override; | ||
|
||
private: | ||
/// Asserts that exactly two input and one output register indices are given. | ||
void validate_arguments(const std::vector<size_t>& in, const std::vector<size_t>& out) const override; | ||
/// Emits the call to BrgemmKernelExecutor::execute with the executor address, in[0], in[1] and out[0]. | ||
void emit_impl(const std::vector<size_t>& in, const std::vector<size_t>& out) const override; | ||
|
||
/// Address of BrgemmKernelExecutor::execute as an integer. | ||
const uintptr_t get_execute_function_ptr() const; | ||
/// Address of the object held by m_kernel_executor as an integer. | ||
const uintptr_t get_compiled_kernel_ptr() const; | ||
|
||
/// Executor returned by KernelExecutorTable::register_kernel in the constructor. | ||
std::shared_ptr<BrgemmKernelExecutor> m_kernel_executor = nullptr; | ||
}; | ||
|
||
} // namespace aarch64 | ||
} // namespace intel_cpu | ||
} // namespace ov |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There are also RISCV64, AArch32, etc. Can we restrict the condition to only the supported architectures?