Skip to content

Commit

Permalink
[onert-micro] Add float Rsqrt kernel (#11576)
Browse files Browse the repository at this point in the history
This commit adds float Rsqrt kernel

ONE-DCO-1.0-Signed-off-by: Artem Balyshev <[email protected]>

Co-authored-by: Artem Balyshev <[email protected]>
  • Loading branch information
BalyshevArtem and Artem Balyshev authored Sep 26, 2023
1 parent c3fb7e8 commit 6b32a35
Show file tree
Hide file tree
Showing 9 changed files with 357 additions and 126 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
/*
* Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef LUCI_INTERPRETER_TEST_MODELS_FLOAT_RSQRT_KERNEL_H
#define LUCI_INTERPRETER_TEST_MODELS_FLOAT_RSQRT_KERNEL_H

#include "TestDataRsqrtBase.h"

namespace luci_interpreter
{
namespace test_kernel
{
namespace rsqrt_float
{
/*
* Rsqrt Kernel:
*
* Input(1, 3, 3, 2)
* |
* Rsqrt
* |
* Output(1, 3, 3, 2)
*/
const unsigned char test_kernel_model_circle[] = {
0x18, 0x00, 0x00, 0x00, 0x43, 0x49, 0x52, 0x30, 0x00, 0x00, 0x0e, 0x00, 0x14, 0x00, 0x00, 0x00,
0x0c, 0x00, 0x08, 0x00, 0x10, 0x00, 0x04, 0x00, 0x0e, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00,
0x2c, 0x00, 0x00, 0x00, 0x14, 0x01, 0x00, 0x00, 0x30, 0x01, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00,
0x18, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xf8, 0xff, 0xff, 0xff,
0xfc, 0xff, 0xff, 0xff, 0x04, 0x00, 0x04, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00,
0x14, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0e, 0x00, 0x18, 0x00, 0x14, 0x00, 0x10, 0x00, 0x0c, 0x00,
0x08, 0x00, 0x04, 0x00, 0x0e, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00,
0x48, 0x00, 0x00, 0x00, 0x4c, 0x00, 0x00, 0x00, 0x50, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00,
0x6d, 0x61, 0x69, 0x6e, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00,
0x00, 0x00, 0x0a, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x08, 0x00, 0x04, 0x00, 0x0a, 0x00, 0x00, 0x00,
0x08, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00,
0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00,
0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00, 0x00,
0x04, 0x00, 0x00, 0x00, 0xd4, 0xff, 0xff, 0xff, 0x0c, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00,
0x0c, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x6f, 0x66, 0x6d, 0x00, 0x04, 0x00, 0x00, 0x00,
0x01, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00,
0x0c, 0x00, 0x10, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x08, 0x00, 0x04, 0x00, 0x0c, 0x00, 0x00, 0x00,
0x0c, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00,
0x69, 0x66, 0x6d, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00,
0x03, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00,
0x0c, 0x00, 0x0c, 0x00, 0x0b, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00, 0x0c, 0x00, 0x00, 0x00,
0x4c, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x4c, 0x11, 0x00, 0x00, 0x00, 0x4f, 0x4e, 0x45, 0x2d,
0x74, 0x66, 0x6c, 0x69, 0x74, 0x65, 0x32, 0x63, 0x69, 0x72, 0x63, 0x6c, 0x65, 0x00, 0x00, 0x00};

const std::vector<float> input_data = {
23.484858, 44.7245, 37.12629, 48.7498, 22.87489, 38.79412, 16.368074, 17.066277, 21.366201,
22.673359, 42.429207, 45.798313, 29.01582, 32.38278, 24.38851, 32.807495, 24.529152, 41.338783};

const std::vector<float> reference_output_data = {
0.20635073, 0.14952964, 0.16411914, 0.14322327, 0.20908386, 0.16055249,
0.2471731, 0.24206422, 0.21633977, 0.21001102, 0.15352091, 0.14776625,
0.18564472, 0.17572881, 0.20249178, 0.17458762, 0.2019104, 0.1555325};

} // namespace rsqrt_float

class TestDataFloatRsqrt : public TestDataRsqrtBase<float>
{
public:
TestDataFloatRsqrt()
{
_input_data = rsqrt_float::input_data;
_reference_output_data = rsqrt_float::reference_output_data;
_test_kernel_model_circle = rsqrt_float::test_kernel_model_circle;
}

~TestDataFloatRsqrt() override = default;
};

} // namespace test_kernel
} // namespace luci_interpreter

#endif // LUCI_INTERPRETER_TEST_MODELS_FLOAT_RSQRT_KERNEL_H
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
/*
* Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef LUCI_INTERPRETER_TEST_MODELS_NEG_RSQRT_KERNEL_H
#define LUCI_INTERPRETER_TEST_MODELS_NEG_RSQRT_KERNEL_H

#include "luci_interpreter/test_models/TestDataBase.h"

namespace luci_interpreter
{
namespace test_kernel
{
namespace neg_input_output_type_mismatch_kernel
{
/*
* Rsqrt Kernel with input output type mismatch:
*
* Input(1, 3, 3, 2) - Float32
* |
* Rsqrt
* |
* Output(1, 3, 3, 2) - Int32
*/
const unsigned char test_kernel_model_circle[] = {
0x18, 0x00, 0x00, 0x00, 0x43, 0x49, 0x52, 0x30, 0x00, 0x00, 0x0e, 0x00, 0x14, 0x00, 0x00, 0x00,
0x0c, 0x00, 0x08, 0x00, 0x10, 0x00, 0x04, 0x00, 0x0e, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00,
0x2c, 0x00, 0x00, 0x00, 0x24, 0x01, 0x00, 0x00, 0x40, 0x01, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00,
0x18, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xf8, 0xff, 0xff, 0xff,
0xfc, 0xff, 0xff, 0xff, 0x04, 0x00, 0x04, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00,
0x14, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0e, 0x00, 0x18, 0x00, 0x14, 0x00, 0x10, 0x00, 0x0c, 0x00,
0x08, 0x00, 0x04, 0x00, 0x0e, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00,
0x48, 0x00, 0x00, 0x00, 0x4c, 0x00, 0x00, 0x00, 0x50, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00,
0x6d, 0x61, 0x69, 0x6e, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00,
0x00, 0x00, 0x0a, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x08, 0x00, 0x04, 0x00, 0x0a, 0x00, 0x00, 0x00,
0x08, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00,
0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00,
0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x50, 0x00, 0x00, 0x00,
0x10, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x14, 0x00, 0x10, 0x00, 0x0f, 0x00, 0x08, 0x00, 0x04, 0x00,
0x0c, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02,
0x0c, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x6f, 0x66, 0x6d, 0x00, 0x04, 0x00, 0x00, 0x00,
0x01, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00,
0x0c, 0x00, 0x10, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x08, 0x00, 0x04, 0x00, 0x0c, 0x00, 0x00, 0x00,
0x0c, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00,
0x69, 0x66, 0x6d, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00,
0x03, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00,
0x0c, 0x00, 0x0c, 0x00, 0x0b, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00, 0x0c, 0x00, 0x00, 0x00,
0x4c, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x4c, 0x11, 0x00, 0x00, 0x00, 0x4f, 0x4e, 0x45, 0x2d,
0x74, 0x66, 0x6c, 0x69, 0x74, 0x65, 0x32, 0x63, 0x69, 0x72, 0x63, 0x6c, 0x65, 0x00, 0x00, 0x00};
} // namespace neg_input_output_type_mismatch_kernel

class NegTestDataInputOutputTypeMismatchRsqrtKernel : public NegTestDataBase
{
public:
NegTestDataInputOutputTypeMismatchRsqrtKernel()
{
_test_kernel_model_circle = neg_input_output_type_mismatch_kernel::test_kernel_model_circle;
}

~NegTestDataInputOutputTypeMismatchRsqrtKernel() override = default;

const unsigned char *get_model_ptr() override final { return _test_kernel_model_circle; }

protected:
const unsigned char *_test_kernel_model_circle;
};

} // namespace test_kernel
} // namespace luci_interpreter

#endif // LUCI_INTERPRETER_TEST_MODELS_NEG_RSQRT_KERNEL_H
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
/*
* Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef LUCI_INTERPRETER_TEST_MODELS_RSQRT_KERNEL_BASE_H
#define LUCI_INTERPRETER_TEST_MODELS_RSQRT_KERNEL_BASE_H

#include "luci_interpreter/test_models/TestDataBase.h"

namespace luci_interpreter
{
namespace test_kernel
{

template <typename T> class TestDataRsqrtBase : public TestDataBase<T>
{
public:
TestDataRsqrtBase() = default;

const unsigned char *get_model_ptr() override final { return _test_kernel_model_circle; }

const std::vector<T> &get_input_data_by_index(int i) override final
{
switch (i)
{
case 0:
return _input_data;
default:
assert(false && "Wrong input index");
}
}

const std::vector<T> &get_output_data_by_index(int i) override final
{
assert(i == 0);
return _reference_output_data;
}

protected:
std::vector<T> _input_data;
std::vector<T> _reference_output_data;
const unsigned char *_test_kernel_model_circle;
};

} // namespace test_kernel
} // namespace luci_interpreter

#endif // LUCI_INTERPRETER_TEST_MODELS_RSQRT_KERNEL_BASE_H
Original file line number Diff line number Diff line change
Expand Up @@ -43,4 +43,5 @@ REGISTER_KERNEL(TRANSPOSE, Transpose)
REGISTER_KERNEL(SOFTMAX, Softmax)
REGISTER_KERNEL(WHILE, While)
REGISTER_KERNEL(RESIZE_BILINEAR, ResizeBilinear)
REGISTER_KERNEL(RSQRT, Rsqrt)
REGISTER_KERNEL(NEG, Neg)
37 changes: 37 additions & 0 deletions onert-micro/luci-interpreter/pal/common/PALRsqrt.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
/*
* Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef LUCI_INTERPRETER_PAL_RSQRT_COMMON_H
#define LUCI_INTERPRETER_PAL_RSQRT_COMMON_H

#include "PALUtils.h"
#include <cmath>

namespace luci_interpreter_pal
{

inline void Rsqrt(const int flat_size, const float *input_data, float *output_data)
{
for (int i = 0; i < flat_size; ++i)
{
output_data[i] = 1.f / std::sqrt(input_data[i]);
}
}

} // namespace luci_interpreter_pal

#endif // LUCI_INTERPRETER_PAL_RSQRT_COMMON_H
1 change: 1 addition & 0 deletions onert-micro/luci-interpreter/pal/mcu/KernelsToBuild.lst
Original file line number Diff line number Diff line change
Expand Up @@ -49,4 +49,5 @@ REGISTER_KERNEL(SOFTMAX, Softmax)
REGISTER_KERNEL(WHILE, While)
REGISTER_KERNEL(UNIDIRECTIONAL_SEQUENCE_LSTM, UnidirectionalSequenceLSTM)
REGISTER_KERNEL(RESIZE_BILINEAR, ResizeBilinear)
REGISTER_KERNEL(RSQRT, Rsqrt)
REGISTER_KERNEL(NEG, Neg)
72 changes: 41 additions & 31 deletions onert-micro/luci-interpreter/src/kernels/Rsqrt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,53 +14,63 @@
* limitations under the License.
*/

#include "kernels/Rsqrt.h"
#include "Builders.h"
#include "kernels/Utils.h"
#include "SISOKernel.h"

#include <cmath>
#include "PALRsqrt.h"

namespace luci_interpreter
{

namespace kernels
void configure_kernel_CircleRsqrt(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph)
{
kernels::SISOKernel kernel(cur_op, runtime_graph);

Rsqrt::Rsqrt(const Tensor *input, Tensor *output) : Kernel({input}, {output}) {}

void Rsqrt::configure()
{
if (input()->element_type() != output()->element_type())
{
assert(false && "Input/output tensor data type mismatch.");
}
// TODO: enable it only if kernel with dynamic shapes
output()->resize(input()->shape());
LUCI_INTERPRETER_CHECK(Tensor::element_type(kernel.input()) ==
Tensor::element_type(kernel.output()));
LUCI_INTERPRETER_CHECK(Tensor::num_elements(kernel.input()) ==
Tensor::num_elements(kernel.output()));
LUCI_INTERPRETER_CHECK(Tensor::num_dims(kernel.input()) == Tensor::num_dims(kernel.output()));
}

void Rsqrt::execute() const
void execute_kernel_CircleRsqrt(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph)
{
switch (input()->element_type())
kernels::SISOKernel kernel(cur_op, runtime_graph);

const auto *input_data = runtime_graph->getDataByTensor(kernel.input());
assert(input_data);

auto *output_data = runtime_graph->getDataByTensor(kernel.output());

bool is_inplace = runtime_graph->is_inplace_op(cur_op);

switch (Tensor::element_type(kernel.input()))
{
#ifndef DIS_FLOAT
case DataType::FLOAT32:
evalFloat();
break;
{
const float *input_data_float = kernels::getTensorData<float>(input_data);
float *output_data_float = kernels::getTensorData<float>(output_data);
if (is_inplace)
{
output_data_float = const_cast<float *>(input_data_float);
}

assert(output_data_float);

const int flat_size =
kernels::getTensorRuntimeShape(kernel.input(), runtime_graph).flatSize();

luci_interpreter_pal::Rsqrt(flat_size, input_data_float, output_data_float);
break;
}
#endif // DIS_FLOAT
default:
assert(false && "Unsupported type.");
assert(false && "Unsupported type");
}
}

void Rsqrt::evalFloat() const
{
auto in = getTensorData<float>(input());
auto out = getTensorData<float>(output());
auto size = getTensorShape(input()).FlatSize();
for (auto i = in; i != in + size; ++i)
{
*out = 1.f / std::sqrt(*i);
++out;
}
if (is_inplace)
runtime_graph->makeInplaceOperation(kernel.input(), kernel.output());
}

} // namespace kernels
} // namespace luci_interpreter
Loading

0 comments on commit 6b32a35

Please sign in to comment.