[onert-micro] Add cmsis-nn Mul kernel
This commit adds a CMSIS-NN Mul kernel: MUL is registered in the kernel build list, the cmsisnn PAL now dispatches int8/int16 multiplication to arm_elementwise_mul_s8 / arm_elementwise_mul_s16, and Mul.cpp gains a quantized (S8/S16) evaluation path.

ONE-DCO-1.0-Signed-off-by: Artem Balyshev <[email protected]>
Artem Balyshev committed Sep 20, 2023
1 parent e35823f commit 5c6ee58
Showing 3 changed files with 85 additions and 47 deletions.
@@ -25,6 +25,7 @@ REGISTER_KERNEL(LESS_EQUAL, LessEqual)
 REGISTER_KERNEL(LOGICAL_AND, LogicalAnd)
 REGISTER_KERNEL(LOGICAL_OR, LogicalOr)
 REGISTER_KERNEL(LEAKY_RELU, LeakyRelu)
+REGISTER_KERNEL(MUL, Mul)
 REGISTER_KERNEL(CONCATENATION, Concatenation)
 REGISTER_KERNEL(SHAPE, Shape)
 REGISTER_KERNEL(NOT_EQUAL, NotEqual)
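
For context: REGISTER_KERNEL lists like the one above are typically consumed with the X-macro pattern, where the build includes the list several times under different definitions of REGISTER_KERNEL. A minimal, self-contained sketch of that pattern (all names below are invented for illustration and are not onert-micro's actual builder):

#include <cstdio>
#include <cstring>

// Stand-in for the contents of the registration list file:
#define KERNEL_LIST          \
  REGISTER_KERNEL(MUL, Mul)  \
  REGISTER_KERNEL(SHAPE, Shape)

// Pass 1: define one handler per listed kernel.
#define REGISTER_KERNEL(op, name) \
  void execute_##name() { std::printf(#name "\n"); }
KERNEL_LIST
#undef REGISTER_KERNEL

// Pass 2: build a dispatch table from the very same list.
struct Entry
{
  const char *op;
  void (*fn)();
};
#define REGISTER_KERNEL(op, name) {#op, execute_##name},
static const Entry kKernels[] = {KERNEL_LIST};
#undef REGISTER_KERNEL

int main()
{
  for (const Entry &e : kKernels)
    if (std::strcmp(e.op, "MUL") == 0)
      e.fn(); // dispatch to the Mul handler
  return 0;
}

Under this pattern, adding the single REGISTER_KERNEL(MUL, Mul) line is enough to both declare and register the kernel in every pass over the list.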
36 changes: 21 additions & 15 deletions onert-micro/luci-interpreter/pal/cmsisnn/PALMul.h
@@ -17,29 +17,35 @@
 #ifndef LUCI_INTERPRETER_PAL_MUL_H
 #define LUCI_INTERPRETER_PAL_MUL_H
 
-#include <tensorflow/lite/kernels/internal/reference/mul.h>
+#include "PALMulCommon.h"
+#include "arm_nnfunctions.h"
 
 namespace luci_interpreter_pal
 {
-template <typename T>
-static inline void Mul(tflite::ArithmeticParams &params, const tflite::RuntimeShape &input1_shape,
-                       const T *input1_data, const tflite::RuntimeShape &input2_shape,
-                       const T *input2_data, const tflite::RuntimeShape &output_shape,
-                       T *output_data)
+template <>
+inline void Mul<int8_t>(const ArithmeticParams &params, const int flat_size,
+                        const int8_t *input1_data, const int8_t *input2_data, int8_t *output_data)
 {
-  tflite::reference_ops::BroadcastMul4DSlow(params, input1_shape, input1_data, input2_shape,
-                                            input2_data, output_shape, output_data);
+  auto status = arm_elementwise_mul_s8(
+    input1_data, input2_data, params.input1_offset, params.input2_offset, output_data,
+    params.output_offset, params.output_multiplier, params.output_shift,
+    params.quantized_activation_min, params.quantized_activation_max, flat_size);
+  assert(status == ARM_CMSIS_NN_SUCCESS);
 }
 
-template <typename T>
-static inline void
-BroadcastMul4DSlow(tflite::ArithmeticParams &params, const tflite::RuntimeShape &input1_shape,
-                   const T *input1_data, const tflite::RuntimeShape &input2_shape,
-                   const T *input2_data, const tflite::RuntimeShape &output_shape, T *output_data)
+template <>
+inline void Mul<int16_t>(const ArithmeticParams &params, const int flat_size,
+                         const int16_t *input1_data, const int16_t *input2_data,
+                         int16_t *output_data)
 {
-  tflite::reference_ops::BroadcastMul4DSlow(params, input1_shape, input1_data, input2_shape,
-                                            input2_data, output_shape, output_data);
+  auto status = arm_elementwise_mul_s16(
+    input1_data, input2_data, params.input1_offset, params.input2_offset, output_data,
+    params.output_offset, params.output_multiplier, params.output_shift,
+    params.quantized_activation_min, params.quantized_activation_max, flat_size);
+  assert(status == ARM_CMSIS_NN_SUCCESS);
 }
 
 } // namespace luci_interpreter_pal
 
 #endif // LUCI_INTERPRETER_PAL_MUL_H
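
The two specializations above delegate the inner loop to CMSIS-NN. As a rough mental model, each output element is computed as sketched below (illustration only: the real kernels work in fixed point via output_multiplier/output_shift and use saturating vector instructions, not the double used here):

#include <algorithm>
#include <cmath>
#include <cstdint>

int8_t mul_s8_reference(int8_t a, int8_t b,
                        int32_t input1_offset, int32_t input2_offset,
                        int32_t output_offset, double effective_scale,
                        int32_t act_min, int32_t act_max)
{
  // Work in zero-point-corrected space; the offsets are the negated input
  // zero points, as set up in evalQuantized below.
  const int32_t raw = (a + input1_offset) * (b + input2_offset);
  // Requantize: effective_scale = input1_scale * input2_scale / output_scale.
  int32_t out = static_cast<int32_t>(std::lround(raw * effective_scale)) + output_offset;
  // Fused-activation clamp (e.g. [-128, 127] when no activation is fused).
  out = std::min(act_max, std::max(act_min, out));
  return static_cast<int8_t>(out);
}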
95 changes: 63 additions & 32 deletions onert-micro/luci-interpreter/src/kernels/Mul.cpp
@@ -25,6 +25,63 @@
 namespace luci_interpreter
 {
 
+namespace
+{
+
+#ifndef DIS_QUANT
+void evalQuantized(const circle::Tensor *input1, const circle::Tensor *input2,
+                   const circle::Tensor *output, const circle::MulOptions *options,
+                   BaseRuntimeGraph *runtime_graph, DataType type)
+{
+  assert((type == DataType::S16 or type == DataType::S8) && "Wrong Type");
+
+  luci_interpreter_pal::ArithmeticParams params{};
+  luci_interpreter::RuntimeShape input_shape1 =
+    kernels::getTensorRuntimeShape(input1, runtime_graph);
+  luci_interpreter::RuntimeShape input_shape2 =
+    kernels::getTensorRuntimeShape(input2, runtime_graph);
+
+  const bool need_broadcast =
+    luci_interpreter_pal::ProcessBroadcastShapes(input_shape1, input_shape2, &params);
+
+  assert(need_broadcast == false && "Broadcast for INT8 and INT16 is not supported yet");
+
+  params.input1_offset = -Tensor::zero_point(input1);
+  params.input2_offset = -Tensor::zero_point(input2);
+  params.output_offset = Tensor::zero_point(output);
+
+  const auto input1_scale = static_cast<double>(Tensor::scale(input1));
+  const auto input2_scale = static_cast<double>(Tensor::scale(input2));
+  const auto output_scale = static_cast<double>(Tensor::scale(output));
+
+  double real_multiplier = input1_scale * input2_scale / output_scale;
+
+  kernels::quantizeMultiplier(real_multiplier, &params.output_multiplier, &params.output_shift);
+
+  kernels::calculateActivationRangeQuantized(luci_actfunc(options->fused_activation_function()),
+                                             output, &params.quantized_activation_min,
+                                             &params.quantized_activation_max);
+  if (type == DataType::S8)
+  {
+    luci_interpreter_pal::Mul(
+      params, input_shape1.flatSize(),
+      kernels::getTensorData<int8_t>(runtime_graph->getDataByTensor(input1)),
+      kernels::getTensorData<int8_t>(runtime_graph->getDataByTensor(input2)),
+      kernels::getTensorData<int8_t>(runtime_graph->getDataByTensor(output)));
+  }
+  else
+  {
+    luci_interpreter_pal::Mul(
+      params, input_shape1.flatSize(),
+      kernels::getTensorData<int16_t>(runtime_graph->getDataByTensor(input1)),
+      kernels::getTensorData<int16_t>(runtime_graph->getDataByTensor(input2)),
+      kernels::getTensorData<int16_t>(runtime_graph->getDataByTensor(output)));
+  }
+}
+#endif // DIS_QUANT
+
+} // namespace
+
 void configure_kernel_CircleMul(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph)
 {
   kernels::TISOKernel kernel(cur_op, runtime_graph);
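
The only non-obvious step in evalQuantized is kernels::quantizeMultiplier, which splits real_multiplier into a Q31 fixed-point mantissa plus a power-of-two shift. A sketch following the widely used TFLite formulation, which this helper is assumed to match:

#include <cmath>
#include <cstdint>

// Decompose `real` as mantissa * 2^shift with mantissa in [0.5, 1), then
// store the mantissa as a Q31 integer. Sketch of the TFLite-style
// QuantizeMultiplier; assumed equivalent to kernels::quantizeMultiplier.
void quantize_multiplier(double real, int32_t *quantized_multiplier, int *shift)
{
  if (real == 0.0)
  {
    *quantized_multiplier = 0;
    *shift = 0;
    return;
  }
  const double mantissa = std::frexp(real, shift);
  int64_t q = static_cast<int64_t>(std::round(mantissa * (1ll << 31)));
  if (q == (1ll << 31)) // rounding can push the mantissa to exactly 1.0
  {
    q /= 2;
    ++*shift;
  }
  *quantized_multiplier = static_cast<int32_t>(q);
}

For example, with input scales 0.5 and 0.25 and output scale 0.1 (invented values), real_multiplier = 1.25 = 0.625 * 2^1, so output_multiplier = round(0.625 * 2^31) = 1342177280 and output_shift = 1.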
@@ -57,8 +114,8 @@ void execute_kernel_CircleMul(const circle::Operator *cur_op, BaseRuntimeGraph *
     kernels::getTensorRuntimeShape(kernel.input2(), runtime_graph);
 
   bool is_inplace = runtime_graph->is_inplace_op(cur_op);
-
-  switch (Tensor::element_type(kernel.input1()))
+  const auto type = Tensor::element_type(kernel.input1());
+  switch (type)
   {
 #ifndef DIS_FLOAT
     case DataType::FLOAT32:
@@ -113,41 +170,15 @@ void execute_kernel_CircleMul(const circle::Operator *cur_op, BaseRuntimeGraph *
       }
     }
     break;
-#if 0
 #ifndef DIS_QUANT
-    // TODO: check quantize Mul
-    case DataType::U8:
     case DataType::S8:
     case DataType::S16:
     {
-      auto tiso_func = [](const luci_interpreter_pal::ArithmeticParams &params,
-                          const luci_interpreter::RuntimeShape &input1_shape,
-                          const uint8_t *input1_data,
-                          const luci_interpreter::RuntimeShape &input2_shape,
-                          const uint8_t *input2_data,
-                          const luci_interpreter::RuntimeShape &output_shape,
-                          uint8_t *output_data) {
-        luci_interpreter_pal::Mul(params, input1_shape, input1_data, input2_shape, input2_data,
-                                  output_shape, output_data);
-      };
-      auto broadcast_tiso_func = [](const luci_interpreter_pal::ArithmeticParams &params,
-                                    const luci_interpreter::RuntimeShape &input1_shape,
-                                    const uint8_t *input1_data,
-                                    const luci_interpreter::RuntimeShape &input2_shape,
-                                    const uint8_t *input2_data,
-                                    const luci_interpreter::RuntimeShape &output_shape,
-                                    uint8_t *output_data) {
-        luci_interpreter_pal::BroadcastMul4DSlow(params, input1_shape, input1_data, input2_shape,
-                                                 input2_data, output_shape, output_data);
-      };
-      if (is_inplace)
-      {
-        kernels::evalTISOInplaceQuantizedKernel<uint8_t>(tiso_func, broadcast_tiso_func, &kernel,
-                                                         options);
-      }
-      else
-      {
-        kernels::TISOData kernel_data = kernel.readData();
-        kernels::evalTISOQuantizedKernel<uint8_t>(tiso_func, broadcast_tiso_func, &kernel,
-                                                  &kernel_data, options);
-      }
+      evalQuantized(kernel.input1(), kernel.input2(), kernel.output(), options, runtime_graph,
+                    type);
     }
     break;
 #endif // DIS_QUANT
-#endif // 0
     default:
       assert(false && "Unsupported type.");
   }
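
Finally, a hypothetical standalone call into the new PAL entry point, to show how the pieces fit together (all values invented; building and running this requires a CMSIS-NN-enabled target, and the ArithmeticParams fields follow the evalQuantized setup above):

#include "PALMul.h" // the header changed in this commit
#include <cstdint>

int main()
{
  // Invented example data: 4 elements, zero points of 0 everywhere.
  const int8_t in1[4] = {10, 20, -30, 40};
  const int8_t in2[4] = {5, -5, 5, -5};
  int8_t out[4] = {};

  luci_interpreter_pal::ArithmeticParams p{};
  p.input1_offset = 0;               // -zero_point(input1)
  p.input2_offset = 0;               // -zero_point(input2)
  p.output_offset = 0;               //  zero_point(output)
  p.output_multiplier = 1342177280;  // with output_shift = 1, encodes 1.25
  p.output_shift = 1;
  p.quantized_activation_min = -128; // no fused activation: full int8 range
  p.quantized_activation_max = 127;

  // Each out[i] is roughly in1[i] * in2[i] * 1.25, saturated to [-128, 127].
  luci_interpreter_pal::Mul<int8_t>(p, /*flat_size=*/4, in1, in2, out);
  return 0;
}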
