Skip to content

Commit

Permalink
[CPU] [ARM64] jit clamp (openvinotoolkit#23086)
Browse files Browse the repository at this point in the history
### Details:
 - *[CPU] [AARCH64] jit eltwise Clamp*

### Tickets:
 - *CVS-133829*
  • Loading branch information
eshoguli authored Mar 15, 2024
1 parent 3a4d1ff commit 307d0fe
Show file tree
Hide file tree
Showing 5 changed files with 130 additions and 34 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

#include <memory>
#include "common/utils.hpp"
#include "emitters/utils.hpp"

namespace ov {
namespace intel_cpu {
Expand Down Expand Up @@ -50,15 +51,13 @@ void jit_add_emitter::emit_impl(const std::vector<size_t> &in_vec_idxs, const st
if (host_isa_ == dnnl::impl::cpu::aarch64::asimd) {
emit_isa<dnnl::impl::cpu::aarch64::asimd>(in_vec_idxs, out_vec_idxs);
} else {
OPENVINO_THROW("Can't create jit eltwise kernel");
OV_CPU_JIT_EMITTER_THROW("Can't create jit eltwise kernel");
}
}

template <dnnl::impl::cpu::aarch64::cpu_isa_t isa>
void jit_add_emitter::emit_isa(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const {
if (exec_prc_ != ov::element::f32) {
OPENVINO_THROW("unsupported precision: " + exec_prc_.to_string());
}
OV_CPU_JIT_EMITTER_ASSERT(exec_prc_ == ov::element::f32, "unsupported precision: " + exec_prc_.to_string());

using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits<isa>::TReg;
TReg src0 = TReg(in_vec_idxs[0]);
Expand All @@ -72,6 +71,70 @@ std::set<std::vector<element::Type>> jit_add_emitter::get_supported_precisions(c
return {{element::f32, element::f32}};
}

/// CLAMP ///
// Builds a Clamp emitter from a graph node.
//
// The node is downcast to ov::op::v0::Clamp; the bounds returned by get_min()/get_max()
// are converted to float and stored, after which the constants table is prepared.
//
// host      JIT generator forwarded to the jit_emitter base.
// host_isa  ISA forwarded to the jit_emitter base.
// node      Node expected to be an ov::op::v0::Clamp; raises through
//           OV_CPU_JIT_EMITTER_THROW when the cast fails.
jit_clamp_emitter::jit_clamp_emitter(dnnl::impl::cpu::aarch64::jit_generator* host,
                                     dnnl::impl::cpu::aarch64::cpu_isa_t host_isa,
                                     const std::shared_ptr<ov::Node>& node)
    : jit_emitter(host, host_isa, node, get_arithmetic_binary_exec_precision(node)) {
    const auto clamp_op = std::dynamic_pointer_cast<ov::op::v0::Clamp>(node);
    if (!clamp_op) {
        OV_CPU_JIT_EMITTER_THROW("Can't cast to ov::op::v0::Clamp");
    }

    // The bounds must be assigned before prepare_table(): register_table_entries() reads them.
    // NOTE(review): presumably prepare_table() invokes register_table_entries() — confirm in jit_emitter.
    min = static_cast<float>(clamp_op->get_min());
    max = static_cast<float>(clamp_op->get_max());

    prepare_table();
}

jit_clamp_emitter::jit_clamp_emitter(dnnl::impl::cpu::aarch64::jit_generator* host,
dnnl::impl::cpu::aarch64::cpu_isa_t host_isa,
const float min,
const float max,
const ov::element::Type exec_prc)
: jit_emitter(host, host_isa, exec_prc),
min(min),
max(max) {
prepare_table();
}

// Clamp reads a single input vector (emit_isa uses only in_vec_idxs[0]).
size_t jit_clamp_emitter::get_inputs_count() const {
    return 1;
}

// One auxiliary vector register: emit_isa loads each bound into aux_vec_idxs[0].
size_t jit_clamp_emitter::get_aux_vecs_count() const {
    return 1;
}

// One auxiliary general-purpose register.
// NOTE(review): presumably consumed by the table addressing helper (table_val2) — confirm in jit_emitter.
size_t jit_clamp_emitter::get_aux_gprs_count() const {
    return 1;
}

// Registers the two clamp bounds in the emitter's constants table under the keys
// "min" and "max"; emit_isa later fetches them by these keys.
void jit_clamp_emitter::register_table_entries() {
    // The table stores raw bit patterns, so the float bounds are reinterpreted as integers.
    const auto lower_bits = dnnl::impl::float2int(min);
    const auto upper_bits = dnnl::impl::float2int(max);

    // NOTE(review): the trailing `true` is presumably the broadcast flag — confirm against push_arg_entry_of.
    push_arg_entry_of("min", lower_bits, true);
    push_arg_entry_of("max", upper_bits, true);
}

// Dispatches code generation to the ISA-specific implementation.
// Only asimd is handled; any other host ISA raises through OV_CPU_JIT_EMITTER_THROW.
void jit_clamp_emitter::emit_impl(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const {
    if (host_isa_ != dnnl::impl::cpu::aarch64::asimd) {
        OV_CPU_JIT_EMITTER_THROW("Can't create jit eltwise kernel");
    }

    emit_isa<dnnl::impl::cpu::aarch64::asimd>(in_vec_idxs, out_vec_idxs);
}

template <dnnl::impl::cpu::aarch64::cpu_isa_t isa>
void jit_clamp_emitter::emit_isa(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const {
OV_CPU_JIT_EMITTER_ASSERT(exec_prc_ == ov::element::f32, "unsupported precision: " + exec_prc_.to_string());

using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits<isa>::TReg;
TReg src = TReg(in_vec_idxs[0]);
TReg aux = TReg(aux_vec_idxs[0]);
TReg dst = TReg(out_vec_idxs[0]);

h->ld1r(aux.s, table_val2("min"));
h->fmax(dst.s, src.s, aux.s);
h->ld1r(aux.s, table_val2("max"));
h->fmin(dst.s, dst.s, aux.s);
}

// Returns the supported input-precision combinations for Clamp: a single f32/f32 entry.
// The node argument is not inspected.
// NOTE(review): the entry lists two precisions while get_inputs_count() reports one input —
// confirm whether a single-element {f32} entry is what the precision helper expects.
std::set<std::vector<element::Type>> jit_clamp_emitter::get_supported_precisions(const std::shared_ptr<ov::Node>& node) {
    const std::vector<element::Type> f32_combination{element::f32, element::f32};
    return {f32_combination};
}

/// DIVIDE ///
jit_divide_emitter::jit_divide_emitter(dnnl::impl::cpu::aarch64::jit_generator *host,
dnnl::impl::cpu::aarch64::cpu_isa_t host_isa,
Expand All @@ -89,15 +152,13 @@ void jit_divide_emitter::emit_impl(const std::vector<size_t> &in_vec_idxs, const
if (host_isa_ == dnnl::impl::cpu::aarch64::asimd) {
emit_isa<dnnl::impl::cpu::aarch64::asimd>(in_vec_idxs, out_vec_idxs);
} else {
OPENVINO_THROW("Can't create jit eltwise kernel");
OV_CPU_JIT_EMITTER_THROW("Can't create jit eltwise kernel");
}
}

template <dnnl::impl::cpu::aarch64::cpu_isa_t isa>
void jit_divide_emitter::emit_isa(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const {
if (exec_prc_ != ov::element::f32) {
OPENVINO_THROW("unsupported precision: " + exec_prc_.to_string());
}
OV_CPU_JIT_EMITTER_ASSERT(exec_prc_ == ov::element::f32, "unsupported precision: " + exec_prc_.to_string());

using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits<isa>::TReg;
TReg src0 = TReg(in_vec_idxs[0]);
Expand Down Expand Up @@ -133,15 +194,13 @@ void jit_mul_add_emitter::emit_impl(const std::vector<size_t> &in_vec_idxs, cons
if (host_isa_ == dnnl::impl::cpu::aarch64::asimd) {
emit_isa<dnnl::impl::cpu::aarch64::asimd>(in_vec_idxs, out_vec_idxs);
} else {
OPENVINO_THROW("Can't create jit eltwise kernel");
OV_CPU_JIT_EMITTER_THROW("Can't create jit eltwise kernel");
}
}

template <dnnl::impl::cpu::aarch64::cpu_isa_t isa>
void jit_mul_add_emitter::emit_isa(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const {
if (exec_prc_ != ov::element::f32) {
OPENVINO_THROW("unsupported precision: " + exec_prc_.to_string());
}
OV_CPU_JIT_EMITTER_ASSERT(exec_prc_ == ov::element::f32, "unsupported precision: " + exec_prc_.to_string());

using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits<isa>::TReg;
const TReg dst = TReg(out_vec_idxs[0]);
Expand Down Expand Up @@ -191,15 +250,13 @@ void jit_multiply_emitter::emit_impl(const std::vector<size_t> &in_vec_idxs, con
if (host_isa_ == dnnl::impl::cpu::aarch64::asimd) {
emit_isa<dnnl::impl::cpu::aarch64::asimd>(in_vec_idxs, out_vec_idxs);
} else {
OPENVINO_THROW("Can't create jit eltwise kernel");
OV_CPU_JIT_EMITTER_THROW("Can't create jit eltwise kernel");
}
}

template <dnnl::impl::cpu::aarch64::cpu_isa_t isa>
void jit_multiply_emitter::emit_isa(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const {
if (exec_prc_ != ov::element::f32) {
OPENVINO_THROW("unsupported precision: " + exec_prc_.to_string());
}
OV_CPU_JIT_EMITTER_ASSERT(exec_prc_ == ov::element::f32, "unsupported precision: " + exec_prc_.to_string());

using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits<isa>::TReg;
TReg src0 = TReg(in_vec_idxs[0]);
Expand All @@ -221,7 +278,7 @@ jit_power_static_emitter::jit_power_static_emitter(dnnl::impl::cpu::aarch64::jit
: jit_emitter(host, host_isa, node, exec_prc) {
auto powerStaticNode = ov::as_type_ptr<ov::snippets::op::PowerStatic>(node);
if (powerStaticNode == nullptr) {
OPENVINO_THROW("Can't cast to snippets::op::PowerStatic");
OV_CPU_JIT_EMITTER_THROW("Can't cast to snippets::op::PowerStatic");
}

power = powerStaticNode->get_power();
Expand Down Expand Up @@ -264,15 +321,13 @@ void jit_power_static_emitter::emit_impl(const std::vector<size_t>& in_vec_idxs,
if (host_isa_ == dnnl::impl::cpu::aarch64::asimd) {
emit_isa<dnnl::impl::cpu::aarch64::asimd>(in_vec_idxs, out_vec_idxs);
} else {
OPENVINO_THROW("Can't create jit eltwise kernel");
OV_CPU_JIT_EMITTER_THROW("Can't create jit eltwise kernel");
}
}

template <dnnl::impl::cpu::aarch64::cpu_isa_t isa>
void jit_power_static_emitter::emit_isa(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const {
if (exec_prc_ != ov::element::f32) {
OPENVINO_THROW("unsupported precision: " + exec_prc_.to_string());
}
OV_CPU_JIT_EMITTER_ASSERT(exec_prc_ == ov::element::f32, "unsupported precision: " + exec_prc_.to_string());

using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits<isa>::TReg;
TReg dst = TReg(out_vec_idxs[0]);
Expand Down Expand Up @@ -377,15 +432,13 @@ void jit_prelu_emitter::emit_impl(const std::vector<size_t>& in_vec_idxs, const
if (host_isa_ == dnnl::impl::cpu::aarch64::asimd) {
emit_isa<dnnl::impl::cpu::aarch64::asimd>(in_vec_idxs, out_vec_idxs);
} else {
OPENVINO_THROW("Can't create jit eltwise kernel");
OV_CPU_JIT_EMITTER_THROW("Can't create jit eltwise kernel");
}
}

template <dnnl::impl::cpu::aarch64::cpu_isa_t isa>
void jit_prelu_emitter::emit_isa(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const {
if (exec_prc_ != ov::element::f32) {
OPENVINO_THROW("unsupported precision: " + exec_prc_.to_string());
}
OV_CPU_JIT_EMITTER_ASSERT(exec_prc_ == ov::element::f32, "unsupported precision: " + exec_prc_.to_string());

using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits<isa>::TReg;

Expand Down Expand Up @@ -424,15 +477,13 @@ void jit_relu_emitter::emit_impl(const std::vector<size_t>& in_vec_idxs, const s
if (host_isa_ == dnnl::impl::cpu::aarch64::asimd) {
emit_isa<dnnl::impl::cpu::aarch64::asimd>(in_vec_idxs, out_vec_idxs);
} else {
OPENVINO_THROW("Can't create jit eltwise kernel");
OV_CPU_JIT_EMITTER_THROW("Can't create jit eltwise kernel");
}
}

template <dnnl::impl::cpu::aarch64::cpu_isa_t isa>
void jit_relu_emitter::emit_isa(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const {
if (exec_prc_ != ov::element::f32) {
OPENVINO_THROW("unsupported precision: " + exec_prc_.to_string());
}
OV_CPU_JIT_EMITTER_ASSERT(exec_prc_ == ov::element::f32, "unsupported precision: " + exec_prc_.to_string());

using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits<isa>::TReg;

Expand Down Expand Up @@ -462,15 +513,13 @@ void jit_subtract_emitter::emit_impl(const std::vector<size_t> &in_vec_idxs, con
if (host_isa_ == dnnl::impl::cpu::aarch64::asimd) {
emit_isa<dnnl::impl::cpu::aarch64::asimd>(in_vec_idxs, out_vec_idxs);
} else {
OPENVINO_THROW("Can't create jit eltwise kernel");
OV_CPU_JIT_EMITTER_THROW("Can't create jit eltwise kernel");
}
}

template <dnnl::impl::cpu::aarch64::cpu_isa_t isa>
void jit_subtract_emitter::emit_isa(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const {
if (exec_prc_ != ov::element::f32) {
OPENVINO_THROW("unsupported precision: " + exec_prc_.to_string());
}
OV_CPU_JIT_EMITTER_ASSERT(exec_prc_ == ov::element::f32, "unsupported precision: " + exec_prc_.to_string());

using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits<isa>::TReg;
TReg src0 = TReg(in_vec_idxs[0]);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,37 @@ class jit_add_emitter : public jit_emitter {
void emit_isa(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const;
};

// JIT emitter for the element-wise Clamp operation on AArch64.
// Holds the lower/upper bounds as floats and can be built either from explicit
// bounds or from a graph node.
class jit_clamp_emitter : public jit_emitter {
public:
    // Creates the emitter from explicit bounds; exec_prc defaults to f32.
    jit_clamp_emitter(dnnl::impl::cpu::aarch64::jit_generator *host,
                      dnnl::impl::cpu::aarch64::cpu_isa_t host_isa,
                      const float min,
                      const float max,
                      const ov::element::Type exec_prc = ov::element::f32);

    // Creates the emitter from a graph node.
    jit_clamp_emitter(dnnl::impl::cpu::aarch64::jit_generator *host,
                      dnnl::impl::cpu::aarch64::cpu_isa_t host_isa,
                      const std::shared_ptr<ov::Node>& node);

    // Number of input vector registers.
    size_t get_inputs_count() const override;

    // Number of auxiliary vector registers.
    size_t get_aux_vecs_count() const override;

    // Number of auxiliary general-purpose registers.
    size_t get_aux_gprs_count() const override;

    // Registers the constants this emitter needs in the emitter table.
    void register_table_entries() override;

    // Supported input-precision combinations; the node argument is optional.
    static std::set<std::vector<element::Type>> get_supported_precisions(const std::shared_ptr<ov::Node>& node = nullptr);

private:
    // Clamp bounds (lower, upper).
    float min;
    float max;

    // Entry point invoked by the base class; dispatches to the ISA-specific emit_isa.
    void emit_impl(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const override;

    // ISA-specific code generation.
    template <dnnl::impl::cpu::aarch64::cpu_isa_t isa>
    void emit_isa(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const;
};

class jit_divide_emitter : public jit_emitter {
public:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ bool JitEltwiseExecutor::isSupported(
const float gamma) {
const auto is_supported = one_of(algorithm,
Algorithm::EltwiseAdd,
Algorithm::EltwiseClamp,
Algorithm::EltwiseDivide,
Algorithm::EltwiseMultiply,
Algorithm::EltwiseMulAdd,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -479,6 +479,17 @@ struct EltwiseEmitter {
}
};

// Factory specialization for Clamp: the emitter is built from explicit bounds,
// taking opData.alpha as the lower bound and opData.beta as the upper bound.
template<>
struct EltwiseEmitter<jit_clamp_emitter> {
    void operator()(EltwiseEmitterContext& ctx) {
        const auto& op = ctx.opData;
        ctx.emitter = std::make_shared<jit_clamp_emitter>(ctx.host, ctx.host_isa, op.alpha, op.beta, ctx.exec_prc);
    }
};

template<>
struct EltwiseEmitter<jit_power_static_emitter> {
void operator()(EltwiseEmitterContext& ctx) {
Expand All @@ -503,6 +514,7 @@ std::shared_ptr<jit_emitter> jit_uni_eltwise_generic<isa>::create_eltwise_emitte

OV_SWITCH(intel_cpu, EltwiseEmitter, ctx, data.algo,
OV_CASE(Algorithm::EltwiseAdd, ov::intel_cpu::aarch64::jit_add_emitter),
OV_CASE(Algorithm::EltwiseClamp, ov::intel_cpu::aarch64::jit_clamp_emitter),
OV_CASE(Algorithm::EltwiseDivide, ov::intel_cpu::aarch64::jit_divide_emitter),
OV_CASE(Algorithm::EltwiseMulAdd, ov::intel_cpu::aarch64::jit_mul_add_emitter),
OV_CASE(Algorithm::EltwiseMultiply, ov::intel_cpu::aarch64::jit_multiply_emitter),
Expand Down Expand Up @@ -656,6 +668,7 @@ std::set<std::vector<element::Type>> eltwise_precision_helper::get_supported_pre
OV_SWITCH(intel_cpu, SupportedPrecisions, precisions, algo,
OV_CASE(Algorithm::EltwiseRelu, jit_relu_emitter),
OV_CASE(Algorithm::EltwiseAdd, jit_add_emitter),
OV_CASE(Algorithm::EltwiseClamp, jit_clamp_emitter),
OV_CASE(Algorithm::EltwiseDivide, jit_divide_emitter),
OV_CASE(Algorithm::EltwiseMulAdd, jit_mul_add_emitter),
OV_CASE(Algorithm::EltwiseMultiply, jit_multiply_emitter),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,9 @@ std::string ActivationLayerCPUTest::getPrimitiveType(const utils::ActivationType
const std::vector<std::pair<ov::PartialShape, std::vector<ov::Shape>>>& input_shapes) const {
#if defined(OV_CPU_WITH_ACL)
#if defined(OPENVINO_ARCH_ARM64)
if ((element_type == ov::element::f32) && (activation_type == utils::ActivationTypes::Relu)) {
if ((element_type == ov::element::f32) &&
((activation_type == utils::ActivationTypes::Clamp) ||
(activation_type == utils::ActivationTypes::Relu))) {
return "jit";
}

Expand Down

0 comments on commit 307d0fe

Please sign in to comment.