Skip to content

Commit

Permalink
executor cache
Browse files Browse the repository at this point in the history
x
  • Loading branch information
chenhu-wang committed Jan 8, 2025
1 parent 7f0a72b commit 6ca4f1b
Show file tree
Hide file tree
Showing 13 changed files with 79 additions and 85 deletions.
2 changes: 1 addition & 1 deletion cmake/features.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ ov_dependent_option (ENABLE_GPU_DEBUG_CAPS "enable GPU debug capabilities at run
ov_dependent_option (ENABLE_CPU_DEBUG_CAPS "enable CPU debug capabilities at runtime" ON "ENABLE_DEBUG_CAPS;ENABLE_INTEL_CPU" OFF)
ov_dependent_option (ENABLE_SNIPPETS_DEBUG_CAPS "enable Snippets debug capabilities at runtime" ON "ENABLE_DEBUG_CAPS" OFF)

ov_dependent_option (ENABLE_SNIPPETS_LIBXSMM_TPP "allow Snippets to use LIBXSMM Tensor Processing Primitives" OFF "ENABLE_INTEL_CPU AND X86_64" OFF)
ov_dependent_option (ENABLE_SNIPPETS_LIBXSMM_TPP "allow Snippets to use LIBXSMM Tensor Processing Primitives" OFF "ENABLE_INTEL_CPU" OFF)

ov_option (ENABLE_PROFILING_ITT "Build with ITT tracing. Optionally configure pre-built ittnotify library though INTEL_VTUNE_DIR variable." OFF)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (C) 2024 Intel Corporation
// Copyright (C) 2024-2025 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

Expand All @@ -10,6 +10,7 @@
#include "emitters/snippets/aarch64/jit_kernel_emitter.hpp"
#include "emitters/snippets/aarch64/jit_loop_emitters.hpp"
#include "emitters/snippets/aarch64/jit_memory_emitters.hpp"
#include "emitters/snippets/aarch64/jit_brgemm_emitter.hpp"
#include "emitters/snippets/cpu_runtime_configurator.hpp"
#include "emitters/utils.hpp"
#include "jit_snippets_emitters.hpp"
Expand All @@ -23,13 +24,15 @@
#include "snippets/snippets_isa.hpp"
#include "transformations/cpu_opset/common/op/swish_cpu.hpp"
#include "transformations/snippets/common/op/fused_mul_add.hpp"
#include "transformations/tpp/x64/op/brgemm.hpp"
#include "emitters/snippets/cpu_kernel_executor_table.hpp"

namespace ov {

#define CREATE_SNIPPETS_EMITTER(e_type) \
#define CREATE_SNIPPETS_EMITTER(e_type, ...) \
{ \
[this](const snippets::lowered::ExpressionPtr& expr) -> std::shared_ptr<snippets::Emitter> { \
return std::make_shared<e_type>(h.get(), isa, expr); \
return std::make_shared<e_type>(h.get(), isa, expr, ##__VA_ARGS__); \
}, \
[](const std::shared_ptr<ov::Node>& n) -> std::set<std::vector<element::Type>> { \
return e_type::get_supported_precisions(n); \
Expand Down Expand Up @@ -202,6 +205,10 @@ CPUTargetMachine::CPUTargetMachine(dnnl::impl::cpu::aarch64::cpu_isa_t host_isa,
jitters[ov::intel_cpu::SwishNode::get_type_info_static()] = CREATE_CPU_EMITTER(jit_swish_emitter);
jitters[ov::op::v0::Tanh::get_type_info_static()] = CREATE_CPU_EMITTER(jit_tanh_emitter);

// brgemm
jitters[ov::intel_cpu::tpp::op::BrgemmTPP::get_type_info_static()] =
CREATE_SNIPPETS_EMITTER(jit_brgemm_emitter, configurator->get_kernel_executor_table(), compiled_kernel_cache);

// control flow
jitters[snippets::op::KernelStatic::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(jit_kernel_static_emitter);
jitters[snippets::op::KernelDynamic::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(jit_kernel_dynamic_emitter);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (C) 2024 Intel Corporation
// Copyright (C) 2024-2025 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

Expand All @@ -8,6 +8,7 @@
#include "cpu/aarch64/jit_generator.hpp"
#include "snippets/generator.hpp"
#include "snippets/target_machine.hpp"
#include "cache/multi_cache.h"

namespace ov {
namespace intel_cpu {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,7 @@

#include "jit_brgemm_emitter.hpp"

#include "emitters/plugin/x64/utils.hpp"
#include "emitters/snippets/x64/kernel_executors/brgemm.hpp"
#include "emitters/snippets/x64/kernel_executors/brgemm_amx.hpp"
#include "snippets/utils/utils.hpp"
#include "transformations/snippets/x64/op/brgemm_cpu.hpp"
#include "transformations/snippets/x64/op/brgemm_utils.hpp"
#include "transformations/tpp/x64/op/brgemm.hpp"

using namespace Xbyak_aarch64;
Expand All @@ -33,13 +28,12 @@ jit_brgemm_emitter::jit_brgemm_emitter(jit_generator* h,
const auto& brg0Prc = brgemm_node->get_input_element_type(0);
const auto& brg1Prc = brgemm_node->get_input_element_type(1);
BrgemmKernelConfig kernel_config(brg0Prc, brg1Prc, isa);
m_kernel_executor =
kernel_table->register_kernel<BrgemmKernelExecutor>(expr, compiled_kernel_cache, kernel_config);
m_kernel_executor = kernel_table->register_kernel<BrgemmKernelExecutor>(expr, compiled_kernel_cache, kernel_config);
}

std::set<std::vector<element::Type>> jit_brgemm_emitter::get_supported_precisions(
const std::shared_ptr<ov::Node>& node) {
// Note: Brgemm currently supports only fp32
// Note: Brgemm currently supports only fp32 on arm
return {{element::f32, element::f32}};
}

Expand All @@ -48,7 +42,7 @@ void jit_brgemm_emitter::validate_arguments(const std::vector<size_t>& in, const
OV_CPU_JIT_EMITTER_ASSERT(out.size() == 1, "Expects 1 output reg, got" + std::to_string(out.size()));
}

void jit_brgemm_emitter::emit_code(const std::vector<size_t> &in, const std::vector<size_t> &out) const {
void jit_brgemm_emitter::emit_code(const std::vector<size_t>& in, const std::vector<size_t>& out) const {
validate_arguments(in, out);
emit_impl(in, out);
}
Expand All @@ -60,7 +54,6 @@ void jit_brgemm_emitter::emit_impl(const std::vector<size_t>& in, const std::vec

Xbyak_aarch64::XReg func_reg(9);
h->mov(func_reg, get_execute_function_ptr());

Xbyak_aarch64::XReg x0(0);
Xbyak_aarch64::XReg x1(1);
Xbyak_aarch64::XReg x2(2);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,13 @@ class jit_brgemm_emitter : public jit_emitter {
const ov::intel_cpu::MultiCacheWeakPtr& compiled_kernel_cache);

size_t get_inputs_count() const override {
return m_memory_offsets.size() - 1;
return 2;
}

static std::set<std::vector<element::Type>> get_supported_precisions(
const std::shared_ptr<ov::Node>& node = nullptr);

void emit_code(const std::vector<size_t> &in,
const std::vector<size_t> &out) const;
void emit_code(const std::vector<size_t>& in, const std::vector<size_t>& out) const;

private:
void validate_arguments(const std::vector<size_t>& in, const std::vector<size_t>& out) const override;
Expand All @@ -36,11 +35,6 @@ class jit_brgemm_emitter : public jit_emitter {
const uintptr_t get_execute_function_ptr() const;
const uintptr_t get_compiled_kernel_ptr() const;

// Note: offsets order: A, B, C (+ scratchpad, if needed). Values can be dynamic_value if offset is calculated in
// runtime
std::vector<size_t> m_memory_offsets{};
// Note: cluster ids order: A, B, C (+ scratchpad, if needed). Values can be dynamic_value if there is no buffer
std::vector<size_t> m_buffer_ids{};
std::shared_ptr<BrgemmKernelExecutor> m_kernel_executor = nullptr;
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,11 @@
//

#include "brgemm.hpp"

#include "transformations/tpp/x64/op/brgemm.hpp"

using namespace Xbyak;
using namespace dnnl::impl;
using namespace dnnl::impl::cpu::x64;

#define HASH(X) seed = hash_combine(seed, X)

namespace ov {
namespace intel_cpu {
Expand All @@ -28,45 +26,37 @@ BrgemmKernelConfig::BrgemmKernelConfig(const element::Type& in0_dtype,
const element::Type& in1_dtype,
dnnl::impl::cpu::aarch64::cpu_isa_t primitive_isa)
: BrgemmBaseKernelConfig(),
m_static_params(
std::make_shared<StaticParams>(in0_dtype, in1_dtype, primitive_isa)) {
m_static_params(std::make_shared<StaticParams>(in0_dtype, in1_dtype, primitive_isa)) {
m_hash = compute_hash();
}

BrgemmKernelConfig::StaticParams::StaticParams(const element::Type& in0_dtype,
const element::Type& in1_dtype,
dnnl::impl::cpu::aarch64::cpu_isa_t primitive_isa)
: StaticBaseParams(in0_dtype, in1_dtype, dnnl::impl::cpu::x64::cpu_isa_t::isa_undef, compute_hash(primitive_isa)),
m_prefetching_flags(false),
isa(primitive_isa) {
m_type_in0 = ov_to_xsmm_dtype(in0_dtype);
m_type_in1 = ov_to_xsmm_dtype(in1_dtype);
m_type_exec = LIBXSMM_DATATYPE_F32;
m_type_out0 = LIBXSMM_DATATYPE_F32;
m_compile_flags = LIBXSMM_GEMM_FLAGS('N', 'N');
}
: StaticBaseParams(in0_dtype, in1_dtype, dnnl::impl::cpu::x64::cpu_isa_t::isa_undef, compute_hash(primitive_isa)) {
m_type_in0 = ov_to_xsmm_dtype(in0_dtype);
m_type_in1 = ov_to_xsmm_dtype(in1_dtype);
m_type_exec = LIBXSMM_DATATYPE_F32;
m_type_out0 = LIBXSMM_DATATYPE_F32;
m_compile_flags = LIBXSMM_GEMM_FLAGS('N', 'N');
m_prefetching_flags = false;
isa = primitive_isa;
}

// Returns the hash contribution of the aarch64 ISA identifier: hash_combine() is applied to a zero seed and
// `aarch_isa`. The StaticParams constructor forwards this value to the StaticBaseParams constructor.
size_t BrgemmKernelConfig::StaticParams::compute_hash(dnnl::impl::cpu::aarch64::cpu_isa_t aarch_isa) {
return hash_combine(0, aarch_isa);
}

bool BrgemmKernelConfig::StaticParams::operator==(const StaticParams& rhs) const {
return StaticBaseParams::operator==(rhs) &&
isa == rhs.isa &&
m_type_in0 == rhs.m_type_in0 &&
m_type_in1 == rhs.m_type_in1 &&
m_type_exec == rhs.m_type_exec &&
m_type_out0 == rhs.m_type_out0 &&
m_compile_flags == rhs.m_compile_flags &&
m_prefetching_flags == rhs.m_prefetching_flags;
return StaticBaseParams::operator==(rhs) && isa == rhs.isa && m_type_in0 == rhs.m_type_in0 &&
m_type_in1 == rhs.m_type_in1 && m_type_exec == rhs.m_type_exec && m_type_out0 == rhs.m_type_out0 &&
m_compile_flags == rhs.m_compile_flags && m_prefetching_flags == rhs.m_prefetching_flags;
}

BrgemmKernelExecutor::BrgemmKernelExecutor(ov::intel_cpu::MultiCacheWeakPtr kernel_cache,
BrgemmKernelConfig config)
BrgemmKernelExecutor::BrgemmKernelExecutor(ov::intel_cpu::MultiCacheWeakPtr kernel_cache, BrgemmKernelConfig config)
: CPUKernelExecutor<BrgemmKernelConfig, BrgemmTppCompiledKernel>(std::move(kernel_cache), std::move(config)) {}

std::shared_ptr<BrgemmTppCompiledKernel> BrgemmKernelExecutor::compile_kernel(
const BrgemmKernelConfig& config) const {
std::shared_ptr<BrgemmTppCompiledKernel> BrgemmKernelExecutor::compile_kernel(const BrgemmKernelConfig& config) const {
std::shared_ptr<BrgemmTppCompiledKernel> compiled_kernel = std::make_shared<BrgemmTppCompiledKernel>();

// Brgemm is not executable - nothing to compile
Expand All @@ -84,8 +74,8 @@ std::shared_ptr<BrgemmTppCompiledKernel> BrgemmKernelExecutor::compile_kernel(
config.get_type_out0(),
config.get_type_exec());
const auto& compile_flag = config.get_compile_flags();
auto refreshed_compile_flag = config.get_beta() == 0 ? config.get_compile_flags() | LIBXSMM_GEMM_FLAG_BETA_0 :
compile_flag;
auto refreshed_compile_flag =
config.get_beta() == 0 ? config.get_compile_flags() | LIBXSMM_GEMM_FLAG_BETA_0 : compile_flag;
compiled_kernel->brgemm_kernel = std::make_shared<libxsmm_gemmfunction>(COMPILE_BRGEMM_TPP_KERNEL(
libxsmm_dispatch_gemm(m_shape, refreshed_compile_flag, config.get_prefetching_flags())));

Expand Down Expand Up @@ -119,7 +109,13 @@ void BrgemmKernelExecutor::update_config(const ov::snippets::lowered::Expression
replace_full_dim(tpp_mod->get_output_stride(i), expr->get_output_port_descriptor(i)->get_shape().back());
}

config.update(config.get_M(), config.get_N(), config.get_K(), io_strides[0], io_strides[1], io_strides[2], config.get_beta());
config.update(config.get_M(),
config.get_N(),
config.get_K(),
io_strides[0],
io_strides[1],
io_strides[2],
config.get_beta());
// Update the compile flags: they depend on beta and should be part of the hash.
config.set_compile_flags(config.get_beta() == 0);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,28 +4,21 @@

#pragma once

#include <cpu/x64/brgemm/brgemm.hpp>
#include "libxsmm.h"

#include "cpu/aarch64/cpu_isa_traits.hpp"
#include "emitters/plugin/aarch64/jit_emitter.hpp"
#include "emitters/snippets/cpu_kernel_executor_table.hpp"
// #include "emitters/snippets/jit_snippets_call_args.hpp"
// #include "openvino/core/type/element_type.hpp"
// #include "snippets/lowered/loop_info.hpp"
// #include "snippets/lowered/loop_manager.hpp"
#include "emitters/snippets/brgemm_base.hpp"
#include "emitters/utils.hpp"
#include "emitters/snippets/x64/kernel_executors/brgemm_base.hpp"
#include "libxsmm.h"

namespace ov {
namespace intel_cpu {
namespace aarch64 {

struct BrgemmKernelConfig : public BrgemmBaseKernelConfig {
public:
BrgemmKernelConfig(const element::Type& in0_dtype,
const element::Type& in1_dtype,
dnnl::impl::cpu::aarch64::cpu_isa_t primitive_isa = dnnl::impl::cpu::aarch64::cpu_isa_t::isa_undef);
BrgemmKernelConfig(
const element::Type& in0_dtype,
const element::Type& in1_dtype,
dnnl::impl::cpu::aarch64::cpu_isa_t primitive_isa = dnnl::impl::cpu::aarch64::cpu_isa_t::isa_undef);
BrgemmKernelConfig() = delete;

std::unique_ptr<snippets::KernelExecutorBase::GenericConfig> get_clone_ptr() const override {
Expand Down Expand Up @@ -71,7 +64,7 @@ struct BrgemmKernelConfig : public BrgemmBaseKernelConfig {
}

private:
struct StaticParams : public StaticBaseParams{
struct StaticParams : public StaticBaseParams {
StaticParams(const element::Type& in0_dtype,
const element::Type& in1_dtype,
dnnl::impl::cpu::aarch64::cpu_isa_t primitive_isa);
Expand All @@ -83,19 +76,19 @@ struct BrgemmKernelConfig : public BrgemmBaseKernelConfig {
}
size_t compute_hash(dnnl::impl::cpu::aarch64::cpu_isa_t aarch_isa);

dnnl::impl::cpu::aarch64::cpu_isa_t isa{dnnl::impl::cpu::aarch64::isa_undef};
dnnl::impl::cpu::aarch64::cpu_isa_t isa;
libxsmm_datatype m_type_in0;
libxsmm_datatype m_type_in1;
libxsmm_datatype m_type_out0;
libxsmm_datatype m_type_exec;
libxsmm_bitfield m_compile_flags {0};
const bool m_prefetching_flags{false};
libxsmm_bitfield m_compile_flags;
bool m_prefetching_flags;
};
// Accessor override: returns the stored m_static_params (a shared_ptr<StaticParams>) converted to the
// base-class pointer type std::shared_ptr<StaticBaseParams>.
std::shared_ptr<StaticBaseParams> get_static_params() const override {
return m_static_params;
}

libxsmm_bitfield m_compile_flags {0};
libxsmm_bitfield m_compile_flags{0};
std::shared_ptr<StaticParams> m_static_params{nullptr};
};

Expand All @@ -113,11 +106,11 @@ class BrgemmKernelExecutor : public BrgemmBaseKernelExecutor,
BrgemmKernelExecutor(ov::intel_cpu::MultiCacheWeakPtr kernel_cache, BrgemmKernelConfig config);
virtual ~BrgemmKernelExecutor() = default;

/** Function that will be called in runtime to execute the kernel */
// Function that will be called in runtime to execute the kernel
static void execute(const BrgemmKernelExecutor* executor, void* in0, void* in1, void* out0);

private:
std::shared_ptr<BrgemmTppCompiledKernel> compile_kernel(const BrgemmKernelConfig& c) const;
std::shared_ptr<BrgemmTppCompiledKernel> compile_kernel(const BrgemmKernelConfig& c) const override;

void update_config(const ov::snippets::lowered::ExpressionPtr& expr,
const ov::snippets::lowered::LinearIRCPtr& linear_ir,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,11 @@ void BrgemmBaseKernelExecutor::update_config(const ov::snippets::lowered::Expres
beta = get_beta(loop_manager, static_cast<int>(loop_ids.back()), current_expanded_loop_info);
}

#ifndef OPENVINO_ARCH_X86_64
config.update(DIM_CAST(M), DIM_CAST(N), DIM_CAST(K), 0, 0, 0, beta);
return;
#endif

const auto LDA = DIM_CAST(snippets::utils::get_dim_stride(expr->get_input_port(0)));
const auto LDC = DIM_CAST(snippets::utils::get_dim_stride(expr->get_output_port(0)));
auto LDB = DIM_CAST(snippets::utils::get_dim_stride(expr->get_input_port(1)));
Expand All @@ -261,7 +266,6 @@ void BrgemmBaseKernelExecutor::update_config(const ov::snippets::lowered::Expres
// In case of data repacking LDB is chosen in accordance with repacking buffer size
if (with_repacking(brgemm_node->get_type()))
LDB = DIM_CAST(brgemm_utils::repacking::compute_LDB(LDB, brgemm_node->get_input_element_type(1)));

config.update(DIM_CAST(M), DIM_CAST(N), DIM_CAST(K), LDA, LDB, LDC, beta);
}

Expand Down Expand Up @@ -327,6 +331,7 @@ void BrgemmBaseKernelExecutor::execute_brgemm_kernel(
brgemm_p.do_post_ops = with_comp;
brgemm_p.do_apply_comp = with_comp;
brgemm_p.skip_accm = 0;

brgemm_p.BS = 1; // default value
OV_CPU_JIT_EMITTER_ASSERT(kernel, "has nullptr Brgemm kernel");
(*kernel)(&brgemm_p);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
#pragma once

#include "emitters/plugin/x64/jit_emitter.hpp"
#include "emitters/snippets/x64/kernel_executors/brgemm_base.hpp"
#include "emitters/snippets/brgemm_base.hpp"

namespace ov {
namespace intel_cpu {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

#pragma once

#include "brgemm_base.hpp"
#include "emitters/snippets/brgemm_base.hpp"

namespace ov {
namespace intel_cpu {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
#include <cpu/x64/brgemm/brgemm.hpp>
#include <cpu/x64/matmul/brgemm_matmul_copy_utils.hpp>

#include "brgemm_base.hpp"
#include "emitters/snippets/brgemm_base.hpp"
#include "emitters/plugin/x64/jit_emitter.hpp"
#include "emitters/snippets/cpu_kernel_executor_table.hpp"
#include "emitters/snippets/jit_snippets_call_args.hpp"
Expand Down
Loading

0 comments on commit 6ca4f1b

Please sign in to comment.