From 4821f87eee7323c32e085d70a395c198bfe8df88 Mon Sep 17 00:00:00 2001 From: Artem Balyshev Date: Tue, 26 Sep 2023 15:07:30 +0300 Subject: [PATCH] [onert-micro] Add float SelectV2 kernels This commit adds float SelectV2 kernels for onert-micro. ONE-DCO-1.0-Signed-off-by: Artem Balyshev --- .../select_v2/FloatSelectV2Kernel.h | 99 ++++++++++++ .../test_models/select_v2/NegSelectV2Kernel.h | 93 +++++++++++ .../select_v2/TestDataSelectV2Base.h | 66 ++++++++ .../pal/cmsisnn/KernelsToBuild.lst | 1 + .../luci-interpreter/pal/common/PALSelectV2.h | 53 +++++++ .../pal/mcu/KernelsToBuild.lst | 1 + .../luci-interpreter/src/kernels/SelectV2.cpp | 149 ++++++++++++++++++ .../src/kernels/SelectV2.test.cpp | 97 ++++++++++++ 8 files changed, 559 insertions(+) create mode 100644 onert-micro/luci-interpreter/include/luci_interpreter/test_models/select_v2/FloatSelectV2Kernel.h create mode 100644 onert-micro/luci-interpreter/include/luci_interpreter/test_models/select_v2/NegSelectV2Kernel.h create mode 100644 onert-micro/luci-interpreter/include/luci_interpreter/test_models/select_v2/TestDataSelectV2Base.h create mode 100644 onert-micro/luci-interpreter/pal/common/PALSelectV2.h create mode 100644 onert-micro/luci-interpreter/src/kernels/SelectV2.cpp create mode 100644 onert-micro/luci-interpreter/src/kernels/SelectV2.test.cpp diff --git a/onert-micro/luci-interpreter/include/luci_interpreter/test_models/select_v2/FloatSelectV2Kernel.h b/onert-micro/luci-interpreter/include/luci_interpreter/test_models/select_v2/FloatSelectV2Kernel.h new file mode 100644 index 00000000000..fc73dbd0f2a --- /dev/null +++ b/onert-micro/luci-interpreter/include/luci_interpreter/test_models/select_v2/FloatSelectV2Kernel.h @@ -0,0 +1,99 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LUCI_INTERPRETER_TEST_MODELS_FLOAT_SELECT_V2_KERNEL_H +#define LUCI_INTERPRETER_TEST_MODELS_FLOAT_SELECT_V2_KERNEL_H + +#include "TestDataSelectV2Base.h" + +namespace luci_interpreter +{ +namespace test_kernel +{ +namespace select_v2_float +{ +/* + * SelectV2 Kernel: Output[i] = InputCond[i] ? X[i] : Y[i] + * + * InputCond(1, 3) X(1, 3) Y(1, 3) + * | | | + * SelectV2 + * | + * Output(1, 3) + */ +const unsigned char test_kernel_model_circle[] = { + 0x18, 0x00, 0x00, 0x00, 0x43, 0x49, 0x52, 0x30, 0x00, 0x00, 0x0e, 0x00, 0x14, 0x00, 0x00, 0x00, + 0x0c, 0x00, 0x08, 0x00, 0x10, 0x00, 0x04, 0x00, 0x0e, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, + 0x38, 0x00, 0x00, 0x00, 0x98, 0x01, 0x00, 0x00, 0xb4, 0x01, 0x00, 0x00, 0x05, 0x00, 0x00, 0x00, + 0x24, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x84, 0xff, 0xff, 0xff, 0x88, 0xff, 0xff, 0xff, 0x8c, 0xff, 0xff, 0xff, + 0x90, 0xff, 0xff, 0xff, 0x94, 0xff, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x0e, 0x00, 0x18, 0x00, 0x14, 0x00, 0x10, 0x00, 0x0c, 0x00, 0x08, 0x00, 0x04, 0x00, + 0x0e, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00, 0x64, 0x00, 0x00, 0x00, + 0x68, 0x00, 0x00, 0x00, 0x74, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x6d, 0x61, 0x69, 0x6e, + 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0e, 0x00, + 0x14, 0x00, 0x00, 0x00, 0x10, 0x00, 0x0c, 0x00, 0x07, 0x00, 0x08, 0x00, 0x0e, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x62, 0x10, 0x00, 0x00, 0x00, 
0x10, 0x00, 0x00, 
0x00, 0x14, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x04, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, + 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x94, 0x00, 0x00, 0x00, + 0x60, 0x00, 0x00, 0x00, 0x2c, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xb8, 0xff, 0xff, 0xff, + 0x0c, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, + 0x6f, 0x66, 0x6d, 0x00, 0x02, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, + 0xdc, 0xff, 0xff, 0xff, 0x0c, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x65, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x03, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x10, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x08, 0x00, 0x04, 0x00, + 0x0c, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x74, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x03, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x14, 0x00, 0x10, 0x00, 0x0f, 0x00, 0x08, 0x00, 0x04, 0x00, + 0x0c, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x06, + 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x63, 0x6f, 0x6e, 0x64, 0x00, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x10, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x0c, 0x00, 0x0b, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00, + 0x0c, 0x00, 0x00, 0x00, 0x7b, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x7b, 0x11, 0x00, 0x00, 0x00, + 0x4f, 0x4e, 0x45, 0x2d, 0x74, 0x66, 0x6c, 0x69, 0x74, 0x65, 0x32, 0x63, 0x69, 0x72, 0x63, 0x6c, + 0x65, 0x00, 0x00, 0x00}; + +const std::vector input_data_1 = {true, false, false}; + +const std::vector 
input_data_2 = {1.1, 1.2, 1.3}; + +const std::vector input_data_3 = {2.1, 2.2, 2.3}; + +/* Expected result: element i is input_data_2[i] where input_data_1[i] is true, otherwise input_data_3[i]. */ +const std::vector reference_output_data = {1.1, 2.2, 2.3}; + +} // namespace select_v2_float + +/* Test data for the select_v2_float model: the constructor copies the condition/X/Y inputs, the reference output and the model bytes from namespace select_v2_float into the TestDataSelectV2Base fields. */ +class TestDataFloatSelectV2 : public TestDataSelectV2Base +{ +public: + TestDataFloatSelectV2() + { + _input_data_1 = select_v2_float::input_data_1; + _input_data_2 = select_v2_float::input_data_2; + _input_data_3 = select_v2_float::input_data_3; + _reference_output_data = select_v2_float::reference_output_data; + _test_kernel_model_circle = select_v2_float::test_kernel_model_circle; + } + + ~TestDataFloatSelectV2() override = default; +}; + +} // namespace test_kernel +} // namespace luci_interpreter + +#endif // LUCI_INTERPRETER_TEST_MODELS_FLOAT_SELECT_V2_KERNEL_H diff --git a/onert-micro/luci-interpreter/include/luci_interpreter/test_models/select_v2/NegSelectV2Kernel.h b/onert-micro/luci-interpreter/include/luci_interpreter/test_models/select_v2/NegSelectV2Kernel.h new file mode 100644 index 00000000000..0a96f9ad6bd --- /dev/null +++ b/onert-micro/luci-interpreter/include/luci_interpreter/test_models/select_v2/NegSelectV2Kernel.h @@ -0,0 +1,93 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef LUCI_INTERPRETER_TEST_MODELS_NEG_SELECT_V2_KERNEL_H +#define LUCI_INTERPRETER_TEST_MODELS_NEG_SELECT_V2_KERNEL_H + +#include "TestDataSelectV2Base.h" + +namespace luci_interpreter +{ +namespace test_kernel +{ +namespace neg_select_v2_input_type_mismatch +{ + +/* + * SelectV2 Kernel with input type mismatch (input_x_type should be equal to input_y_type): + * + * Input_cond(1, 3) - Bool input_x(1, 3) - Int32 input_y(1, 3)- Float32 + * \ | / + * \ | / + * SelectV2 + * | + * Output(1, 3) + */ +const unsigned char test_kernel_model_circle[] = { + 0x18, 0x00, 0x00, 0x00, 0x43, 0x49, 0x52, 0x30, 0x00, 0x00, 0x0e, 0x00, 0x14, 0x00, 0x00, 0x00, + 0x0c, 0x00, 0x08, 0x00, 0x10, 0x00, 0x04, 0x00, 0x0e, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, + 0x38, 0x00, 0x00, 0x00, 0x9c, 0x01, 0x00, 0x00, 0xb8, 0x01, 0x00, 0x00, 0x05, 0x00, 0x00, 0x00, + 0x24, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x84, 0xff, 0xff, 0xff, 0x88, 0xff, 0xff, 0xff, 0x8c, 0xff, 0xff, 0xff, + 0x90, 0xff, 0xff, 0xff, 0x94, 0xff, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x0e, 0x00, 0x18, 0x00, 0x14, 0x00, 0x10, 0x00, 0x0c, 0x00, 0x08, 0x00, 0x04, 0x00, + 0x0e, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00, 0x64, 0x00, 0x00, 0x00, + 0x68, 0x00, 0x00, 0x00, 0x74, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x6d, 0x61, 0x69, 0x6e, + 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0e, 0x00, + 0x14, 0x00, 0x00, 0x00, 0x10, 0x00, 0x0c, 0x00, 0x07, 0x00, 0x08, 0x00, 0x0e, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x62, 0x10, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x04, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, + 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x03, 
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x98, 0x00, 0x00, 0x00, + 0x60, 0x00, 0x00, 0x00, 0x38, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xdc, 0xff, 0xff, 0xff, + 0x0c, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, + 0x6f, 0x66, 0x6d, 0x00, 0x02, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, + 0x0c, 0x00, 0x10, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x08, 0x00, 0x04, 0x00, 0x0c, 0x00, 0x00, 0x00, + 0x0c, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x65, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, + 0xd8, 0xff, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, + 0x0c, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x74, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x14, 0x00, 0x10, 0x00, 0x0f, 0x00, + 0x08, 0x00, 0x04, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x06, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x63, 0x6f, 0x6e, 0x64, + 0x00, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x0c, 0x00, 0x0b, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x04, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x7b, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x7b, + 0x11, 0x00, 0x00, 0x00, 0x4f, 0x4e, 0x45, 0x2d, 0x74, 0x66, 0x6c, 0x69, 0x74, 0x65, 0x32, 0x63, + 0x69, 0x72, 0x63, 0x6c, 0x65, 0x00, 0x00, 0x00}; +} // namespace neg_select_v2_input_type_mismatch + +class NegTestDataInputMismatchSelectV2Kernel : public NegTestDataBase +{ +public: + NegTestDataInputMismatchSelectV2Kernel() + { + _test_kernel_model_circle = neg_select_v2_input_type_mismatch::test_kernel_model_circle; + } + + ~NegTestDataInputMismatchSelectV2Kernel() override = 
default; + + const unsigned char *get_model_ptr() override final { return _test_kernel_model_circle; } + +protected: + const unsigned char *_test_kernel_model_circle; +}; + +} // namespace test_kernel +} // namespace luci_interpreter + +#endif // LUCI_INTERPRETER_TEST_MODELS_NEG_SELECT_V2_KERNEL_H diff --git a/onert-micro/luci-interpreter/include/luci_interpreter/test_models/select_v2/TestDataSelectV2Base.h b/onert-micro/luci-interpreter/include/luci_interpreter/test_models/select_v2/TestDataSelectV2Base.h new file mode 100644 index 00000000000..b0d951ae70e --- /dev/null +++ b/onert-micro/luci-interpreter/include/luci_interpreter/test_models/select_v2/TestDataSelectV2Base.h @@ -0,0 +1,66 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#ifndef LUCI_INTERPRETER_TEST_MODELS_SELECT_V2_KERNEL_BASE_H
+#define LUCI_INTERPRETER_TEST_MODELS_SELECT_V2_KERNEL_BASE_H
+
+#include "luci_interpreter/test_models/TestDataBase.h"
+
+namespace luci_interpreter
+{
+namespace test_kernel
+{
+
+/**
+ * Holder for the data of one SelectV2 kernel test case.
+ *
+ * Derived classes fill the protected fields in their constructor. The condition input
+ * (graph input 0) is boolean and therefore has its own accessor, get_cond_input(); the
+ * X and Y inputs (graph inputs 1 and 2) and the reference output use element type T.
+ *
+ * @tparam T element type of the X / Y inputs and of the reference output
+ */
+template <typename T> class TestDataSelectV2Base : public TestDataBase<T>
+{
+public:
+  TestDataSelectV2Base() = default;
+
+  /** Returns the serialized circle model containing the single SelectV2 operator. */
+  const unsigned char *get_model_ptr() override final { return _test_kernel_model_circle; }
+
+  /**
+   * Returns the X input for i == 1 and the Y input for i == 2.
+   *
+   * Any other index is a caller error and trips the assert. When asserts are compiled
+   * out the X input is returned, so the function never runs off its end without
+   * returning a reference (which would be undefined behaviour).
+   */
+  const std::vector<T> &get_input_data_by_index(int i) override final
+  {
+    // Index 0 is the bool condition tensor; it is served by get_cond_input() instead.
+    assert((i == 1 || i == 2) && "Wrong input index");
+    if (i == 2)
+    {
+      return _input_data_3;
+    }
+    return _input_data_2;
+  }
+
+  /** Returns the boolean condition input (graph input 0). */
+  const std::vector<bool> &get_cond_input() { return _input_data_1; }
+
+  /** Returns the expected output; the kernel has a single output, so i must be 0. */
+  const std::vector<T> &get_output_data_by_index(int i) override final
+  {
+    assert(i == 0);
+    return _reference_output_data;
+  }
+
+protected:
+  std::vector<bool> _input_data_1; // condition
+  std::vector<T> _input_data_2;    // X
+  std::vector<T> _input_data_3;    // Y
+  std::vector<T> _reference_output_data;
+  const unsigned char *_test_kernel_model_circle;
+};
+
+} // namespace test_kernel
+} // namespace luci_interpreter
+
+#endif // LUCI_INTERPRETER_TEST_MODELS_SELECT_V2_KERNEL_BASE_H
diff --git a/onert-micro/luci-interpreter/pal/cmsisnn/KernelsToBuild.lst b/onert-micro/luci-interpreter/pal/cmsisnn/KernelsToBuild.lst
index 68dac6a16c5..ad138f67b50 100644
--- a/onert-micro/luci-interpreter/pal/cmsisnn/KernelsToBuild.lst
+++ b/onert-micro/luci-interpreter/pal/cmsisnn/KernelsToBuild.lst
@@ -39,6 +39,7 @@ REGISTER_KERNEL(SPLIT_V, SplitV)
 REGISTER_KERNEL(TANH, Tanh)
 REGISTER_KERNEL(TRANSPOSE, Transpose)
 REGISTER_KERNEL(SOFTMAX, Softmax)
+REGISTER_KERNEL(SELECT_V2, SelectV2)
 REGISTER_KERNEL(WHILE, While)
 REGISTER_KERNEL(RESIZE_BILINEAR, ResizeBilinear)
 REGISTER_KERNEL(NEG, Neg)
diff --git a/onert-micro/luci-interpreter/pal/common/PALSelectV2.h b/onert-micro/luci-interpreter/pal/common/PALSelectV2.h
new file mode 100644
index 00000000000..52302b71357
--- /dev/null
+++ 
b/onert-micro/luci-interpreter/pal/common/PALSelectV2.h
@@ -0,0 +1,53 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LUCI_INTERPRETER_PAL_SELECT_V2_COMMON_H
+#define LUCI_INTERPRETER_PAL_SELECT_V2_COMMON_H
+
+#include "PALUtils.h"
+#include "ProcessBroadcastShapes.h"
+
+namespace luci_interpreter_pal
+{
+
+/**
+ * Element-wise select without broadcasting:
+ *   output_data[i] = input_condition_data[i] ? input_x_data[i] : input_y_data[i]
+ *
+ * The number of processed elements is the flat size of the condition shape (1 when all
+ * four shapes hold exactly one element). The function does not check the other buffers,
+ * so the caller must make sure X, Y and the output hold at least that many elements.
+ *
+ * @tparam D element type of the condition buffer (tested for truthiness)
+ * @tparam T element type of X, Y and the output
+ */
+template <typename D, typename T>
+void Select(const luci_interpreter::RuntimeShape &input_condition_shape,
+            const D *input_condition_data, const luci_interpreter::RuntimeShape &input_x_shape,
+            const T *input_x_data, const luci_interpreter::RuntimeShape &input_y_shape,
+            const T *input_y_data, const luci_interpreter::RuntimeShape &output_shape,
+            T *output_data)
+{
+  // Allow select operator executions on mixed scalar tensors and one element
+  // tensors.
+  const bool all_single_element =
+    input_condition_shape.flatSize() == 1 && input_x_shape.flatSize() == 1 &&
+    input_y_shape.flatSize() == 1 && output_shape.flatSize() == 1;
+  const int64_t flatsize = all_single_element ? 1 : input_condition_shape.flatSize();
+
+  for (int64_t i = 0; i < flatsize; ++i)
+  {
+    output_data[i] = input_condition_data[i] ? input_x_data[i] : input_y_data[i];
+  }
+}
+
+} // namespace luci_interpreter_pal
+
+#endif // LUCI_INTERPRETER_PAL_SELECT_V2_COMMON_H
diff --git a/onert-micro/luci-interpreter/pal/mcu/KernelsToBuild.lst b/onert-micro/luci-interpreter/pal/mcu/KernelsToBuild.lst
index 076c6795751..2aa26f268df 100644
--- a/onert-micro/luci-interpreter/pal/mcu/KernelsToBuild.lst
+++ b/onert-micro/luci-interpreter/pal/mcu/KernelsToBuild.lst
@@ -46,6 +46,7 @@ REGISTER_KERNEL(TANH, Tanh)
 REGISTER_KERNEL(TRANSPOSE, Transpose)
 REGISTER_KERNEL(TRANSPOSE_CONV, TransposeConv)
 REGISTER_KERNEL(SOFTMAX, Softmax)
+REGISTER_KERNEL(SELECT_V2, SelectV2)
 REGISTER_KERNEL(WHILE, While)
 REGISTER_KERNEL(UNIDIRECTIONAL_SEQUENCE_LSTM, UnidirectionalSequenceLSTM)
 REGISTER_KERNEL(RESIZE_BILINEAR, ResizeBilinear)
diff --git a/onert-micro/luci-interpreter/src/kernels/SelectV2.cpp b/onert-micro/luci-interpreter/src/kernels/SelectV2.cpp
new file mode 100644
index 00000000000..0864cf2c959
--- /dev/null
+++ b/onert-micro/luci-interpreter/src/kernels/SelectV2.cpp
@@ -0,0 +1,149 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "Builders.h"
+#include "kernels/Utils.h"
+
+#include "PALSelectV2.h"
+
+namespace luci_interpreter
+{
+
+namespace
+{
+
+constexpr int kInputTensorCondition = 0;
+constexpr int kInputTensorX = 1;
+constexpr int kInputTensorY = 2;
+constexpr int kOutputTensor = 0;
+
+/**
+ * Runs the non-broadcasting PAL Select on the runtime shapes and data buffers of the
+ * given tensors.
+ *
+ * @tparam T element type of X, Y and the output; the condition buffer is read as bool
+ * @param need_broadcast true when the operands would require broadcasting, which is not
+ *                       implemented. This is asserted; when asserts are compiled out the
+ *                       function returns without touching the output.
+ */
+template <typename T>
+void CallSelect(const circle::Tensor *input_condition, const circle::Tensor *input_x,
+                const circle::Tensor *input_y, const circle::Tensor *output, bool need_broadcast,
+                RuntimeGraph *runtime_graph)
+{
+  if (need_broadcast)
+  {
+    assert(false && "Broadcast not supported now");
+    // There is no kernel to dispatch to for this case. Returning here keeps a build
+    // without asserts from calling through an uninitialised function pointer.
+    return;
+  }
+
+  luci_interpreter_pal::Select<bool, T>(
+    kernels::getTensorRuntimeShape(input_condition, runtime_graph),
+    kernels::getTensorData<bool>(runtime_graph->getDataByTensor(input_condition)),
+    kernels::getTensorRuntimeShape(input_x, runtime_graph),
+    kernels::getTensorData<T>(runtime_graph->getDataByTensor(input_x)),
+    kernels::getTensorRuntimeShape(input_y, runtime_graph),
+    kernels::getTensorData<T>(runtime_graph->getDataByTensor(input_y)),
+    kernels::getTensorRuntimeShape(output, runtime_graph),
+    kernels::getTensorData<T>(runtime_graph->getDataByTensor(output)));
+}
+
+} // namespace
+
+/**
+ * Validates a SelectV2 operator: the condition must be BOOL, X, Y and the output must
+ * share one element type, and the operands must not require broadcasting.
+ */
+void configure_kernel_CircleSelectV2(const circle::Operator *cur_op,
+                                     BaseRuntimeGraph *runtime_graph)
+{
+  const auto input_cond_index = cur_op->inputs()->operator[](kInputTensorCondition);
+  const auto input_x_index = cur_op->inputs()->operator[](kInputTensorX);
+  const auto input_y_index = cur_op->inputs()->operator[](kInputTensorY);
+  const auto output_index = cur_op->outputs()->operator[](kOutputTensor);
+
+  assert(input_cond_index != -1);
+  assert(input_x_index != -1);
+  assert(input_y_index != -1);
+  assert(output_index != -1);
+
+  const auto input_cond = runtime_graph->getCircleTensorByIndex(input_cond_index);
+  const auto input_x = runtime_graph->getCircleTensorByIndex(input_x_index);
+  const auto input_y = runtime_graph->getCircleTensorByIndex(input_y_index);
+  const auto output = runtime_graph->getCircleTensorByIndex(output_index);
+
+  assert(input_cond != nullptr);
+  assert(input_x != nullptr);
+  assert(input_y != nullptr);
+  assert(output != nullptr);
+
+  // Input condition should be bool
+  LUCI_INTERPRETER_CHECK(Tensor::element_type(input_cond) == DataType::BOOL);
+
+  // X, Y and Output should be the same type
+  LUCI_INTERPRETER_CHECK(Tensor::element_type(input_x) == Tensor::element_type(input_y));
+  LUCI_INTERPRETER_CHECK(Tensor::element_type(input_x) == Tensor::element_type(output));
+
+  const bool possible_mixed_scalar =
+    Tensor::num_elements(input_cond) == 1 && Tensor::num_elements(input_x) == 1 &&
+    Tensor::num_elements(input_y) == 1 && Tensor::num_elements(output) == 1;
+
+  // Only element counts are compared here, not the individual dimensions.
+  const bool same_shape = Tensor::num_elements(input_cond) == Tensor::num_elements(input_x) &&
+                          Tensor::num_elements(input_x) == Tensor::num_elements(input_y);
+
+  // Broadcast not supported now
+  LUCI_INTERPRETER_CHECK(same_shape or possible_mixed_scalar);
+}
+
+/**
+ * Executes a SelectV2 operator by dispatching on the element type of X.
+ * Only FLOAT32 is handled (and only when DIS_FLOAT is not defined); any other type
+ * trips the assert.
+ */
+void execute_kernel_CircleSelectV2(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph)
+{
+  const auto input_cond_index = cur_op->inputs()->operator[](kInputTensorCondition);
+  const auto input_x_index = cur_op->inputs()->operator[](kInputTensorX);
+  const auto input_y_index = cur_op->inputs()->operator[](kInputTensorY);
+  const auto output_index = cur_op->outputs()->operator[](kOutputTensor);
+
+  assert(input_cond_index != -1);
+  assert(input_x_index != -1);
+  assert(input_y_index != -1);
+  assert(output_index != -1);
+
+  const auto input_cond = runtime_graph->getCircleTensorByIndex(input_cond_index);
+  const auto input_x = runtime_graph->getCircleTensorByIndex(input_x_index);
+  const auto input_y = runtime_graph->getCircleTensorByIndex(input_y_index);
+  const auto output = runtime_graph->getCircleTensorByIndex(output_index);
+
+  assert(input_cond != nullptr);
+  assert(input_x != nullptr);
+  assert(input_y != nullptr);
+  assert(output != nullptr);
+
+  const bool possible_mixed_scalar =
+    Tensor::num_elements(input_cond) == 1 && Tensor::num_elements(input_x) == 1 &&
+    Tensor::num_elements(input_y) == 1 && Tensor::num_elements(output) == 1;
+
+  const bool same_shape = Tensor::num_elements(input_cond) == Tensor::num_elements(input_x) &&
+                          Tensor::num_elements(input_x) == Tensor::num_elements(input_y);
+
+  // Same condition that configure_kernel_CircleSelectV2 rejects.
+  const bool is_broadcast = not possible_mixed_scalar and not same_shape;
+
+  const auto type = Tensor::element_type(input_x);
+  switch (type)
+  {
+#ifndef DIS_FLOAT
+    case DataType::FLOAT32:
+      CallSelect<float>(input_cond, input_x, input_y, output, is_broadcast, runtime_graph);
+      break;
+#endif // DIS_FLOAT
+    default:
+      assert(false && "Unsupported type.");
+  }
+}
+
+} // namespace luci_interpreter
diff --git a/onert-micro/luci-interpreter/src/kernels/SelectV2.test.cpp b/onert-micro/luci-interpreter/src/kernels/SelectV2.test.cpp
new file mode 100644
index 00000000000..851cf5c6072
--- /dev/null
+++ b/onert-micro/luci-interpreter/src/kernels/SelectV2.test.cpp
@@ -0,0 +1,97 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "kernels/TestUtils.h"
+#include "luci_interpreter/test_models/select_v2/FloatSelectV2Kernel.h"
+#include "luci_interpreter/test_models/select_v2/NegSelectV2Kernel.h"
+
+#include "loader/ModuleLoader.h"
+
+namespace luci_interpreter
+{
+namespace
+{
+
+using namespace testing;
+
+class SelectV2Test : public ::testing::Test
+{
+  // Do nothing
+};
+
+/**
+ * Loads the single-operator model from test_data_base, copies the condition input into
+ * graph input 0 and the two data inputs into graph inputs 1 and 2, executes the module
+ * and returns the contents of graph output 0.
+ *
+ * @tparam T element type of the data inputs and of the output
+ */
+template <typename T>
+std::vector<T> checkSelectV2Kernel(test_kernel::TestDataSelectV2Base<T> *test_data_base)
+{
+  MemoryManager memory_manager{};
+  RuntimeModule runtime_module{};
+  bool dealloc_input = true;
+
+  // Load model with single op
+  auto *model_data_raw = reinterpret_cast<const char *>(test_data_base->get_model_ptr());
+  ModuleLoader::load(&runtime_module, &memory_manager, model_data_raw, dealloc_input);
+
+  auto *main_runtime_graph = runtime_module.getMainGraph();
+  assert(main_runtime_graph->getNumOfInputTensors() == 3);
+
+  // set input data
+  {
+    // Input 0: boolean condition
+    auto *input_tensor_data_1 =
+      reinterpret_cast<bool *>(main_runtime_graph->configureGraphInput(0));
+    std::copy(test_data_base->get_cond_input().begin(), test_data_base->get_cond_input().end(),
+              input_tensor_data_1);
+
+    // Input 1: X
+    auto *input_tensor_data_2 = reinterpret_cast<T *>(main_runtime_graph->configureGraphInput(1));
+    std::copy(test_data_base->get_input_data_by_index(1).begin(),
+              test_data_base->get_input_data_by_index(1).end(), input_tensor_data_2);
+
+    // Input 2: Y
+    auto *input_tensor_data_3 = reinterpret_cast<T *>(main_runtime_graph->configureGraphInput(2));
+    std::copy(test_data_base->get_input_data_by_index(2).begin(),
+              test_data_base->get_input_data_by_index(2).end(), input_tensor_data_3);
+  }
+
+  runtime_module.execute();
+
+  assert(main_runtime_graph->getNumOfOutputTensors() == 1);
+
+  T *output_data = reinterpret_cast<T *>(main_runtime_graph->getOutputDataByIndex(0));
+  // The output size is used as a byte count here and converted to a number of T elements.
+  const size_t num_elements = (main_runtime_graph->getOutputDataSizeByIndex(0) / sizeof(T));
+  std::vector<T> output_data_vector(output_data, output_data + num_elements);
+  return output_data_vector;
+}
+
+// Positive case: float X / Y selected element-wise by the boolean condition.
+TEST_F(SelectV2Test, Float_P)
+{
+  test_kernel::TestDataFloatSelectV2 test_data_kernel;
+  std::vector<float> output_data_vector = checkSelectV2Kernel(&test_data_kernel);
+  EXPECT_THAT(output_data_vector, kernels::testing::FloatArrayNear(
+                                    test_data_kernel.get_output_data_by_index(0), 0.0001f));
+}
+
+// Negative case: X and Y have different element types, so loading the model must abort.
+TEST_F(SelectV2Test, Input_type_mismatch_NEG)
+{
+  test_kernel::NegTestDataInputMismatchSelectV2Kernel test_data_kernel;
+
+  MemoryManager memory_manager{};
+  RuntimeModule runtime_module{};
+  bool dealloc_input = true;
+  // Load model with single op
+  auto *model_data_raw = reinterpret_cast<const char *>(test_data_kernel.get_model_ptr());
+  EXPECT_DEATH(ModuleLoader::load(&runtime_module, &memory_manager, model_data_raw, dealloc_input),
+               "");
+}
+
+} // namespace
+} // namespace luci_interpreter