Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[onert-micro] Add cmsis-nn FullyConnected kernel #11564

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ REGISTER_KERNEL(CONV_2D, Conv2D)
REGISTER_KERNEL(LOGISTIC, Logistic)
REGISTER_KERNEL(GATHER, Gather)
REGISTER_KERNEL(EXP, Exp)
REGISTER_KERNEL(FULLY_CONNECTED, FullyConnected)
REGISTER_KERNEL(GREATER, Greater)
REGISTER_KERNEL(GREATER_EQUAL, GreaterEqual)
REGISTER_KERNEL(EXPAND_DIMS, ExpandDims)
Expand Down
Empty file.
114 changes: 74 additions & 40 deletions onert-micro/luci-interpreter/pal/cmsisnn/PALFullyConnected.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,52 +14,26 @@
* limitations under the License.
*/

#ifndef LUCI_INTERPRETER_PAL_FULLYCONNECTED_H
#define LUCI_INTERPRETER_PAL_FULLYCONNECTED_H
#ifndef LUCI_INTERPRETER_PAL_FULLY_CONNECTED_H
#define LUCI_INTERPRETER_PAL_FULLY_CONNECTED_H

#include "PALFullyConnectedCommon.h"

#include <tensorflow/lite/kernels/internal/reference/fully_connected.h>
#include <tensorflow/lite/kernels/internal/reference/integer_ops/fully_connected.h>
#include <arm_nnfunctions.h>

namespace luci_interpreter_pal
{
template <typename T>
static inline void FullyConnected(const tflite::FullyConnectedParams &params,
const tflite::RuntimeShape &input_shape, const T *input_data,
const tflite::RuntimeShape &filter_shape, const T *filter_data,
const tflite::RuntimeShape &bias_shape, const int32_t *bias_data,
const tflite::RuntimeShape &output_shape, T *output_data)
{
{
// MARK: At this moment this operation doesn't support
assert(false && "FullyConnected NYI");
(void)params;
(void)input_shape;
(void)input_data;
(void)filter_shape;
(void)filter_data;
(void)bias_shape;
(void)bias_data;
(void)output_shape;
(void)output_data;
}
}

template <>
inline void
FullyConnected<int8_t>(const tflite::FullyConnectedParams &params,
const tflite::RuntimeShape &input_shape, const int8_t *input_data,
const tflite::RuntimeShape &filter_shape, const int8_t *filter_data,
const tflite::RuntimeShape &bias_shape, const int32_t *bias_data,
const tflite::RuntimeShape &output_shape, int8_t *output_data)
inline void FullyConnected<int8_t>(const luci_interpreter_pal::FullyConnectedParams &params,
const int32_t *, const int8_t *input_data,
const int32_t *filter_shape, const int8_t *filter_data,
const int32_t *bias_data, const int32_t *output_shape,
int8_t *output_data)
{
assert(output_shape.DimensionsCount() == 2);

const int batches = output_shape.Dims(0);
const int output_depth = output_shape.Dims(1);

const int filter_dim_count = filter_shape.DimensionsCount();
const int accum_depth = filter_shape.Dims(filter_dim_count - 1);
const int batches = output_shape[0];
const int output_depth = output_shape[1];
const int accum_depth = filter_shape[1];

cmsis_nn_fc_params fc_params;
fc_params.input_offset = params.input_offset;
Expand Down Expand Up @@ -107,8 +81,68 @@ FullyConnected<int8_t>(const tflite::FullyConnectedParams &params,
auto res =
arm_fully_connected_s8(&ctx, &fc_params, &quant_params, &input_dims, input_data, &filter_dims,
filter_data, &bias_dims, bias_data, &output_dims, output_data);
assert(res == ARM_MATH_SUCCESS);
assert(res == ARM_CMSIS_NN_SUCCESS);
}

// int16 activations / int8 weights / int64 bias fully-connected kernel,
// dispatched to the CMSIS-NN optimized implementation (arm_fully_connected_s16).
// Shapes arrive as flat int32 dim arrays: output_shape = {batches, output_depth},
// filter_shape[1] = accumulation (input) depth. The first shape argument is unused.
template <>
inline void FullyConnected(const luci_interpreter_pal::FullyConnectedParams &params,
                           const int32_t *, const int16_t *input_data, const int32_t *filter_shape,
                           const int8_t *filter_data, const int64_t *bias_data,
                           const int32_t *output_shape, int16_t *output_data)
{
  const int n_batches = output_shape[0];
  const int out_channels = output_shape[1];
  const int in_channels = filter_shape[1];

  // Translate the interpreter's quantization parameters into CMSIS-NN form.
  cmsis_nn_fc_params fc_cfg;
  fc_cfg.input_offset = params.input_offset;
  fc_cfg.filter_offset = params.weights_offset;
  fc_cfg.output_offset = params.output_offset;
  fc_cfg.activation.min = params.quantized_activation_min;
  fc_cfg.activation.max = params.quantized_activation_max;

  cmsis_nn_per_tensor_quant_params requant;
  requant.multiplier = params.output_multiplier;
  requant.shift = params.output_shift;

  // CMSIS-NN treats the FC input as an NHWC tensor with h = w = 1.
  cmsis_nn_dims in_dims;
  in_dims.n = n_batches;
  in_dims.h = 1;
  in_dims.w = 1;
  in_dims.c = in_channels;

  // Weights: n is the accumulation depth, c the number of output units.
  cmsis_nn_dims wt_dims;
  wt_dims.n = in_channels;
  wt_dims.h = 1;
  wt_dims.w = 1;
  wt_dims.c = out_channels;

  // One bias value per output unit.
  cmsis_nn_dims b_dims;
  b_dims.n = 1;
  b_dims.h = 1;
  b_dims.w = 1;
  b_dims.c = out_channels;

  cmsis_nn_dims out_dims;
  out_dims.n = n_batches;
  out_dims.h = 1;
  out_dims.w = 1;
  out_dims.c = out_channels;

  // Scratch buffer: size is implementation-defined and queried from CMSIS-NN.
  const int32_t scratch_bytes = arm_fully_connected_s16_get_buffer_size(&wt_dims);
  auto scratch = std::make_unique<int8_t[]>(scratch_bytes);
  assert(scratch != nullptr);

  cmsis_nn_context cmsis_ctx;
  cmsis_ctx.buf = scratch.get();
  cmsis_ctx.size = scratch_bytes;

  const auto status =
    arm_fully_connected_s16(&cmsis_ctx, &fc_cfg, &requant, &in_dims, input_data, &wt_dims,
                            filter_data, &b_dims, bias_data, &out_dims, output_data);
  assert(status == ARM_CMSIS_NN_SUCCESS);
}

} // namespace luci_interpreter_pal

#endif // LUCI_INTERPRETER_PAL_FULLYCONNECTED_H
#endif // LUCI_INTERPRETER_PAL_FULLY_CONNECTED_H
2 changes: 1 addition & 1 deletion onert-micro/luci-interpreter/pal/mcu/PALFullyConnected.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ inline void FullyConnected(const luci_interpreter_pal::FullyConnectedParams &, c
const int32_t *, int16_t *)
{
// MARK: At this moment this operation doesn't support
assert(false && "FullyConnected INT8 NYI");
assert(false && "FullyConnected INT16 NYI");
}

} // namespace luci_interpreter_pal
Expand Down
54 changes: 41 additions & 13 deletions onert-micro/luci-interpreter/src/kernels/FullyConnected.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,8 @@ void evalFloat(const circle::Tensor *input, const circle::Tensor *weights,
#ifndef DIS_QUANT
void evalQuantized(const circle::Tensor *input, const circle::Tensor *weights,
const circle::Tensor *bias, const circle::Tensor *output,
const circle::FullyConnectedOptions *options, BaseRuntimeGraph *runtime_graph)
const circle::FullyConnectedOptions *options, BaseRuntimeGraph *runtime_graph,
DataType type)
{
double real_multiplier = 0.0;
int output_shift;
Expand All @@ -80,7 +81,9 @@ void evalQuantized(const circle::Tensor *input, const circle::Tensor *weights,
&output_activation_max);

int32_t input_offset = -Tensor::zero_point(input);
int32_t filter_offset = -Tensor::zero_point(weights);
int32_t filter_offset = 0;
if (type == DataType::U8)
filter_offset = -Tensor::zero_point(weights);
int32_t output_offset = Tensor::zero_point(output);

luci_interpreter_pal::FullyConnectedParams op_params{};
Expand Down Expand Up @@ -112,11 +115,31 @@ void evalQuantized(const circle::Tensor *input, const circle::Tensor *weights,

int32_t output_shape[kMaxSmallSize];
kernels::getTensorDims(output, runtime_graph, output_shape);

luci_interpreter_pal::FullyConnected(
op_params, input_shape, kernels::getTensorData<uint8_t>(input_data), weights_shape,
kernels::getTensorData<uint8_t>(weights_data), kernels::getTensorData<int32_t>(bias_data),
output_shape, kernels::getTensorData<uint8_t>(output_data));
if (type == DataType::S8)
{
luci_interpreter_pal::FullyConnected<int8_t>(
op_params, input_shape, kernels::getTensorData<int8_t>(input_data), weights_shape,
kernels::getTensorData<int8_t>(weights_data), kernels::getTensorData<int32_t>(bias_data),
output_shape, kernels::getTensorData<int8_t>(output_data));
}
else if (type == DataType::U8)
{
luci_interpreter_pal::FullyConnected<uint8_t>(
op_params, input_shape, kernels::getTensorData<uint8_t>(input_data), weights_shape,
kernels::getTensorData<uint8_t>(weights_data), kernels::getTensorData<int32_t>(bias_data),
output_shape, kernels::getTensorData<uint8_t>(output_data));
}
else if (type == DataType::S16)
{
luci_interpreter_pal::FullyConnected(
op_params, input_shape, kernels::getTensorData<int16_t>(input_data), weights_shape,
kernels::getTensorData<int8_t>(weights_data), kernels::getTensorData<int64_t>(bias_data),
output_shape, kernels::getTensorData<int16_t>(output_data));
}
else
{
assert(false && "Unsupported quantize type");
}
}
#endif

Expand Down Expand Up @@ -160,9 +183,12 @@ void configure_kernel_CircleFullyConnected(const circle::Operator *cur_op,
}
else if (Tensor::element_type(weights) == DataType::S8)
{
LUCI_INTERPRETER_CHECK(Tensor::element_type(input) == DataType::S8);
LUCI_INTERPRETER_CHECK(Tensor::element_type(output) == DataType::S8);
LUCI_INTERPRETER_CHECK(!bias || Tensor::element_type(bias) == DataType::S32)
LUCI_INTERPRETER_CHECK(Tensor::element_type(input) == DataType::S8 ||
Tensor::element_type(input) == DataType::S16);
LUCI_INTERPRETER_CHECK(Tensor::element_type(output) == DataType::S8 ||
Tensor::element_type(output) == DataType::S16);
LUCI_INTERPRETER_CHECK(!bias || Tensor::element_type(bias) == DataType::S32 ||
Tensor::element_type(bias) == DataType::S64)
}
#endif // DIS_QUANT
else
Expand Down Expand Up @@ -210,12 +236,14 @@ void execute_kernel_CircleFullyConnected(const circle::Operator *cur_op,
assert(output != nullptr);

const auto *options = cur_op->builtin_options_as_FullyConnectedOptions();

switch (Tensor::element_type(input))
const auto input_type = Tensor::element_type(input);
switch (input_type)
{
#ifndef DIS_QUANT
case DataType::U8:
evalQuantized(input, weights, bias, output, options, runtime_graph);
case DataType::S8:
case DataType::S16:
evalQuantized(input, weights, bias, output, options, runtime_graph, input_type);
break;
#endif // DIS_QUANT
#ifndef DIS_FLOAT
Expand Down