Skip to content

Commit

Permalink
[onert-micro] Add cmsis-nn FullyConnected kernel
Browse files Browse the repository at this point in the history
This commit adds the CMSIS-NN FullyConnected kernel.

ONE-DCO-1.0-Signed-off-by: Artem Balyshev <[email protected]>
  • Loading branch information
Artem Balyshev committed Sep 20, 2023
1 parent e35823f commit 42cba5c
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 40 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ REGISTER_KERNEL(CONV_2D, Conv2D)
REGISTER_KERNEL(LOGISTIC, Logistic)
REGISTER_KERNEL(GATHER, Gather)
REGISTER_KERNEL(EXP, Exp)
REGISTER_KERNEL(FULLY_CONNECTED, FullyConnected)
REGISTER_KERNEL(GREATER, Greater)
REGISTER_KERNEL(GREATER_EQUAL, GreaterEqual)
REGISTER_KERNEL(EXPAND_DIMS, ExpandDims)
Expand Down
Empty file.
114 changes: 74 additions & 40 deletions onert-micro/luci-interpreter/pal/cmsisnn/PALFullyConnected.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,52 +14,26 @@
* limitations under the License.
*/

#ifndef LUCI_INTERPRETER_PAL_FULLYCONNECTED_H
#define LUCI_INTERPRETER_PAL_FULLYCONNECTED_H
#ifndef LUCI_INTERPRETER_PAL_FULLY_CONNECTED_H
#define LUCI_INTERPRETER_PAL_FULLY_CONNECTED_H

#include "PALFullyConnectedCommon.h"

#include <tensorflow/lite/kernels/internal/reference/fully_connected.h>
#include <tensorflow/lite/kernels/internal/reference/integer_ops/fully_connected.h>
#include <arm_nnfunctions.h>

namespace luci_interpreter_pal
{
/**
 * Generic FullyConnected entry point.
 *
 * This overload is a placeholder: it performs no computation and only trips
 * an assertion, so reaching it in a debug build aborts with "FullyConnected NYI".
 * Every argument is ignored.
 */
template <typename T>
static inline void FullyConnected(const tflite::FullyConnectedParams &params,
                                  const tflite::RuntimeShape &input_shape, const T *input_data,
                                  const tflite::RuntimeShape &filter_shape, const T *filter_data,
                                  const tflite::RuntimeShape &bias_shape, const int32_t *bias_data,
                                  const tflite::RuntimeShape &output_shape, T *output_data)
{
  // Explicitly discard all arguments so unused-parameter warnings stay quiet.
  (void)params;
  (void)input_shape;
  (void)input_data;
  (void)filter_shape;
  (void)filter_data;
  (void)bias_shape;
  (void)bias_data;
  (void)output_shape;
  (void)output_data;

  // MARK: this operation is not supported at the moment.
  assert(false && "FullyConnected NYI");
}

template <>
inline void
FullyConnected<int8_t>(const tflite::FullyConnectedParams &params,
const tflite::RuntimeShape &input_shape, const int8_t *input_data,
const tflite::RuntimeShape &filter_shape, const int8_t *filter_data,
const tflite::RuntimeShape &bias_shape, const int32_t *bias_data,
const tflite::RuntimeShape &output_shape, int8_t *output_data)
inline void FullyConnected<int8_t>(const luci_interpreter_pal::FullyConnectedParams &params,
const int32_t *, const int8_t *input_data,
const int32_t *filter_shape, const int8_t *filter_data,
const int32_t *bias_data, const int32_t *output_shape,
int8_t *output_data)
{
assert(output_shape.DimensionsCount() == 2);

const int batches = output_shape.Dims(0);
const int output_depth = output_shape.Dims(1);

const int filter_dim_count = filter_shape.DimensionsCount();
const int accum_depth = filter_shape.Dims(filter_dim_count - 1);
const int batches = output_shape[0];
const int output_depth = output_shape[1];
const int accum_depth = filter_shape[1];

cmsis_nn_fc_params fc_params;
fc_params.input_offset = params.input_offset;
Expand Down Expand Up @@ -107,8 +81,68 @@ FullyConnected<int8_t>(const tflite::FullyConnectedParams &params,
auto res =
arm_fully_connected_s8(&ctx, &fc_params, &quant_params, &input_dims, input_data, &filter_dims,
filter_data, &bias_dims, bias_data, &output_dims, output_data);
assert(res == ARM_MATH_SUCCESS);
assert(res == ARM_CMSIS_NN_SUCCESS);
}

/**
 * FullyConnected specialization for int16 input/output, int8 weights and
 * int64 bias. It packs the arguments into CMSIS-NN structures and forwards the
 * computation to arm_fully_connected_s16.
 *
 * @param params        quantization offsets, output multiplier/shift and
 *                      activation clamp range copied into the CMSIS-NN params
 * @param (unnamed)     first shape argument; not used by this specialization
 * @param input_data    int16 input values
 * @param filter_shape  filter dimensions; element [1] is read as the
 *                      accumulation depth (assumes a 2-D filter - TODO confirm
 *                      against callers)
 * @param filter_data   int8 weight values
 * @param bias_data     int64 bias values, passed through unchanged
 * @param output_shape  output dimensions; [0] is read as the batch count and
 *                      [1] as the output depth
 * @param output_data   int16 destination buffer
 *
 * The kernel status is verified with assert only, so a failure is not reported
 * in builds that define NDEBUG.
 */
template <>
inline void FullyConnected(const luci_interpreter_pal::FullyConnectedParams &params,
                           const int32_t *, const int16_t *input_data, const int32_t *filter_shape,
                           const int8_t *filter_data, const int64_t *bias_data,
                           const int32_t *output_shape, int16_t *output_data)
{
  const int batches = output_shape[0];
  const int output_depth = output_shape[1];
  const int accum_depth = filter_shape[1];

  cmsis_nn_fc_params fc_params;
  fc_params.input_offset = params.input_offset;
  fc_params.output_offset = params.output_offset;
  fc_params.filter_offset = params.weights_offset;
  fc_params.activation.min = params.quantized_activation_min;
  fc_params.activation.max = params.quantized_activation_max;

  cmsis_nn_per_tensor_quant_params quant_params;
  quant_params.multiplier = params.output_multiplier;
  quant_params.shift = params.output_shift;

  cmsis_nn_dims input_dims;
  input_dims.n = batches;
  input_dims.h = 1;
  input_dims.w = 1;
  input_dims.c = accum_depth;

  cmsis_nn_dims filter_dims;
  filter_dims.n = accum_depth;
  filter_dims.h = 1;
  filter_dims.w = 1;
  filter_dims.c = output_depth;

  cmsis_nn_dims bias_dims;
  bias_dims.n = 1;
  bias_dims.h = 1;
  bias_dims.w = 1;
  bias_dims.c = output_depth;

  cmsis_nn_dims output_dims;
  output_dims.n = batches;
  output_dims.h = 1;
  output_dims.w = 1;
  output_dims.c = output_depth;

  // Allocate scratch memory only when the kernel actually requests some. This
  // avoids a pointless heap allocation for a zero-sized request and keeps a
  // non-positive size from being converted into a huge unsigned length.
  // No null check is needed afterwards: std::make_unique reports allocation
  // failure through operator new, it never hands back a null pointer.
  const int32_t buf_size = arm_fully_connected_s16_get_buffer_size(&filter_dims);
  std::unique_ptr<int8_t[]> buffer;
  if (buf_size > 0)
  {
    buffer = std::make_unique<int8_t[]>(buf_size);
  }

  cmsis_nn_context ctx;
  ctx.buf = buffer.get(); // nullptr when no scratch memory is required
  ctx.size = buf_size > 0 ? buf_size : 0;

  const auto res =
    arm_fully_connected_s16(&ctx, &fc_params, &quant_params, &input_dims, input_data, &filter_dims,
                            filter_data, &bias_dims, bias_data, &output_dims, output_data);
  assert(res == ARM_CMSIS_NN_SUCCESS);
  (void)res; // assert compiles away under NDEBUG; keep `res` formally used
}

} // namespace luci_interpreter_pal

#endif // LUCI_INTERPRETER_PAL_FULLYCONNECTED_H
#endif // LUCI_INTERPRETER_PAL_FULLY_CONNECTED_H

0 comments on commit 42cba5c

Please sign in to comment.