diff --git a/onert-micro/luci-interpreter/include/luci_interpreter/test_models/rsqrt/FloatRsqrtKernel.h b/onert-micro/luci-interpreter/include/luci_interpreter/test_models/rsqrt/FloatRsqrtKernel.h new file mode 100644 index 00000000000..51fa2df144a --- /dev/null +++ b/onert-micro/luci-interpreter/include/luci_interpreter/test_models/rsqrt/FloatRsqrtKernel.h @@ -0,0 +1,89 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LUCI_INTERPRETER_TEST_MODELS_FLOAT_RSQRT_KERNEL_H +#define LUCI_INTERPRETER_TEST_MODELS_FLOAT_RSQRT_KERNEL_H + +#include "TestDataRsqrtBase.h" + +namespace luci_interpreter +{ +namespace test_kernel +{ +namespace rsqrt_float +{ +/* + * Rsqrt Kernel: + * + * Input(1, 3, 3, 2) + * | + * Rsqrt + * | + * Output(1, 3, 3, 2) + */ +const unsigned char test_kernel_model_circle[] = { + 0x18, 0x00, 0x00, 0x00, 0x43, 0x49, 0x52, 0x30, 0x00, 0x00, 0x0e, 0x00, 0x14, 0x00, 0x00, 0x00, + 0x0c, 0x00, 0x08, 0x00, 0x10, 0x00, 0x04, 0x00, 0x0e, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, + 0x2c, 0x00, 0x00, 0x00, 0x14, 0x01, 0x00, 0x00, 0x30, 0x01, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, + 0x18, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xf8, 0xff, 0xff, 0xff, + 0xfc, 0xff, 0xff, 0xff, 0x04, 0x00, 0x04, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x14, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0e, 0x00, 0x18, 0x00, 0x14, 0x00, 0x10, 0x00, 0x0c, 0x00, + 0x08, 0x00, 0x04, 0x00, 0x0e, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00, + 0x48, 0x00, 0x00, 0x00, 0x4c, 0x00, 0x00, 0x00, 0x50, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x6d, 0x61, 0x69, 0x6e, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x0a, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x08, 0x00, 0x04, 0x00, 0x0a, 0x00, 0x00, 0x00, + 0x08, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0xd4, 0xff, 0xff, 0xff, 0x0c, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x0c, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x6f, 0x66, 0x6d, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x0c, 0x00, 0x10, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x08, 0x00, 0x04, 0x00, 0x0c, 0x00, 0x00, 0x00, + 0x0c, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, + 0x69, 0x66, 0x6d, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, + 0x03, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, + 0x0c, 0x00, 0x0c, 0x00, 0x0b, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00, 0x0c, 0x00, 0x00, 0x00, + 0x4c, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x4c, 0x11, 0x00, 0x00, 0x00, 0x4f, 0x4e, 0x45, 0x2d, + 0x74, 0x66, 0x6c, 0x69, 0x74, 0x65, 0x32, 0x63, 0x69, 0x72, 0x63, 0x6c, 0x65, 0x00, 0x00, 0x00}; + +const std::vector input_data = { + 23.484858, 44.7245, 37.12629, 48.7498, 22.87489, 38.79412, 16.368074, 17.066277, 21.366201, + 22.673359, 42.429207, 45.798313, 29.01582, 32.38278, 24.38851, 32.807495, 24.529152, 41.338783}; + +const std::vector reference_output_data = { + 0.20635073, 0.14952964, 0.16411914, 0.14322327, 0.20908386, 0.16055249, + 0.2471731, 0.24206422, 0.21633977, 0.21001102, 0.15352091, 0.14776625, + 0.18564472, 0.17572881, 0.20249178, 0.17458762, 0.2019104, 0.1555325}; + +} // namespace rsqrt_float + +class TestDataFloatRsqrt : public TestDataRsqrtBase +{ +public: + TestDataFloatRsqrt() + { + _input_data = rsqrt_float::input_data; + _reference_output_data = rsqrt_float::reference_output_data; + _test_kernel_model_circle = rsqrt_float::test_kernel_model_circle; + } + + ~TestDataFloatRsqrt() override = default; +}; + +} // namespace test_kernel +} // namespace luci_interpreter + +#endif // LUCI_INTERPRETER_TEST_MODELS_FLOAT_RSQRT_KERNEL_H diff --git a/onert-micro/luci-interpreter/include/luci_interpreter/test_models/rsqrt/NegRsqrtKernel.h b/onert-micro/luci-interpreter/include/luci_interpreter/test_models/rsqrt/NegRsqrtKernel.h new file mode 100644 index 00000000000..0ee2eb57372 --- /dev/null +++ b/onert-micro/luci-interpreter/include/luci_interpreter/test_models/rsqrt/NegRsqrtKernel.h @@ -0,0 +1,83 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LUCI_INTERPRETER_TEST_MODELS_NEG_RSQRT_KERNEL_H +#define LUCI_INTERPRETER_TEST_MODELS_NEG_RSQRT_KERNEL_H + +#include "luci_interpreter/test_models/TestDataBase.h" + +namespace luci_interpreter +{ +namespace test_kernel +{ +namespace neg_input_output_type_mismatch_kernel +{ +/* + * Rsqrt Kernel with input output type mismatch: + * + * Input(1, 3, 3, 2) - Float32 + * | + * Rsqrt + * | + * Output(1, 3, 3, 2) - Int32 + */ +const unsigned char test_kernel_model_circle[] = { + 0x18, 0x00, 0x00, 0x00, 0x43, 0x49, 0x52, 0x30, 0x00, 0x00, 0x0e, 0x00, 0x14, 0x00, 0x00, 0x00, + 0x0c, 0x00, 0x08, 0x00, 0x10, 0x00, 0x04, 0x00, 0x0e, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, + 0x2c, 0x00, 0x00, 0x00, 0x24, 0x01, 0x00, 0x00, 0x40, 0x01, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, + 0x18, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xf8, 0xff, 0xff, 0xff, + 0xfc, 0xff, 0xff, 0xff, 0x04, 0x00, 0x04, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x14, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0e, 0x00, 0x18, 0x00, 0x14, 0x00, 0x10, 0x00, 0x0c, 0x00, + 0x08, 0x00, 0x04, 0x00, 0x0e, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00, + 0x48, 0x00, 0x00, 0x00, 0x4c, 0x00, 0x00, 0x00, 0x50, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x6d, 0x61, 0x69, 0x6e, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x0a, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x08, 0x00, 0x04, 0x00, 0x0a, 0x00, 0x00, 0x00, + 0x08, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x50, 0x00, 0x00, 0x00, + 0x10, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x14, 0x00, 0x10, 0x00, 0x0f, 0x00, 0x08, 0x00, 0x04, 0x00, + 0x0c, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, + 0x0c, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x6f, 0x66, 0x6d, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x0c, 0x00, 0x10, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x08, 0x00, 0x04, 0x00, 0x0c, 0x00, 0x00, 0x00, + 0x0c, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, + 0x69, 0x66, 0x6d, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, + 0x03, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, + 0x0c, 0x00, 0x0c, 0x00, 0x0b, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00, 0x0c, 0x00, 0x00, 0x00, + 0x4c, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x4c, 0x11, 0x00, 0x00, 0x00, 0x4f, 0x4e, 0x45, 0x2d, + 0x74, 0x66, 0x6c, 0x69, 0x74, 0x65, 0x32, 0x63, 0x69, 0x72, 0x63, 0x6c, 0x65, 0x00, 0x00, 0x00}; +} // namespace neg_input_output_type_mismatch_kernel + +class NegTestDataInputOutputTypeMismatchRsqrtKernel : public NegTestDataBase +{ +public: + NegTestDataInputOutputTypeMismatchRsqrtKernel() + { + _test_kernel_model_circle = neg_input_output_type_mismatch_kernel::test_kernel_model_circle; + } + + ~NegTestDataInputOutputTypeMismatchRsqrtKernel() override = default; + + const unsigned char *get_model_ptr() override final { return _test_kernel_model_circle; } + +protected: + const unsigned char *_test_kernel_model_circle; +}; + +} // namespace test_kernel +} // namespace luci_interpreter + +#endif // LUCI_INTERPRETER_TEST_MODELS_NEG_RSQRT_KERNEL_H diff --git a/onert-micro/luci-interpreter/include/luci_interpreter/test_models/rsqrt/TestDataRsqrtBase.h b/onert-micro/luci-interpreter/include/luci_interpreter/test_models/rsqrt/TestDataRsqrtBase.h new file mode 100644 index 00000000000..d0a0cc5774f --- /dev/null +++ b/onert-micro/luci-interpreter/include/luci_interpreter/test_models/rsqrt/TestDataRsqrtBase.h @@ -0,0 +1,60 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LUCI_INTERPRETER_TEST_MODELS_RSQRT_KERNEL_BASE_H +#define LUCI_INTERPRETER_TEST_MODELS_RSQRT_KERNEL_BASE_H + +#include "luci_interpreter/test_models/TestDataBase.h" + +namespace luci_interpreter +{ +namespace test_kernel +{ + +template class TestDataRsqrtBase : public TestDataBase +{ +public: + TestDataRsqrtBase() = default; + + const unsigned char *get_model_ptr() override final { return _test_kernel_model_circle; } + + const std::vector &get_input_data_by_index(int i) override final + { + switch (i) + { + case 0: + return _input_data; + default: + assert(false && "Wrong input index"); + } + } + + const std::vector &get_output_data_by_index(int i) override final + { + assert(i == 0); + return _reference_output_data; + } + +protected: + std::vector _input_data; + std::vector _reference_output_data; + const unsigned char *_test_kernel_model_circle; +}; + +} // namespace test_kernel +} // namespace luci_interpreter + +#endif // LUCI_INTERPRETER_TEST_MODELS_RSQRT_KERNEL_BASE_H diff --git a/onert-micro/luci-interpreter/pal/cmsisnn/KernelsToBuild.lst b/onert-micro/luci-interpreter/pal/cmsisnn/KernelsToBuild.lst index 930fa0d5c72..3208613e908 100644 --- a/onert-micro/luci-interpreter/pal/cmsisnn/KernelsToBuild.lst +++ b/onert-micro/luci-interpreter/pal/cmsisnn/KernelsToBuild.lst @@ -40,4 +40,5 @@ REGISTER_KERNEL(TRANSPOSE, Transpose) REGISTER_KERNEL(SOFTMAX, Softmax) REGISTER_KERNEL(WHILE, While) REGISTER_KERNEL(RESIZE_BILINEAR, ResizeBilinear) +REGISTER_KERNEL(RSQRT, Rsqrt) REGISTER_KERNEL(NEG, Neg) diff --git a/onert-micro/luci-interpreter/pal/common/PALRsqrt.h b/onert-micro/luci-interpreter/pal/common/PALRsqrt.h new file mode 100644 index 00000000000..a8f4871adc3 --- /dev/null +++ b/onert-micro/luci-interpreter/pal/common/PALRsqrt.h @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * Copyright 2020 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LUCI_INTERPRETER_PAL_RSQRT_COMMON_H +#define LUCI_INTERPRETER_PAL_RSQRT_COMMON_H + +#include "PALUtils.h" +#include + +namespace luci_interpreter_pal +{ + +inline void Rsqrt(const int flat_size, const float *input_data, float *output_data) +{ + for (int i = 0; i < flat_size; ++i) + { + output_data[i] = 1.f / std::sqrt(input_data[i]); + } +} + +} // namespace luci_interpreter_pal + +#endif // LUCI_INTERPRETER_PAL_RSQRT_COMMON_H diff --git a/onert-micro/luci-interpreter/pal/mcu/KernelsToBuild.lst b/onert-micro/luci-interpreter/pal/mcu/KernelsToBuild.lst index d06543c6988..355c2568461 100644 --- a/onert-micro/luci-interpreter/pal/mcu/KernelsToBuild.lst +++ b/onert-micro/luci-interpreter/pal/mcu/KernelsToBuild.lst @@ -48,4 +48,5 @@ REGISTER_KERNEL(SOFTMAX, Softmax) REGISTER_KERNEL(WHILE, While) REGISTER_KERNEL(UNIDIRECTIONAL_SEQUENCE_LSTM, UnidirectionalSequenceLSTM) REGISTER_KERNEL(RESIZE_BILINEAR, ResizeBilinear) +REGISTER_KERNEL(RSQRT, Rsqrt) REGISTER_KERNEL(NEG, Neg) diff --git a/onert-micro/luci-interpreter/src/kernels/Rsqrt.cpp b/onert-micro/luci-interpreter/src/kernels/Rsqrt.cpp index c45c3e4cac7..7df4f921183 100644 --- a/onert-micro/luci-interpreter/src/kernels/Rsqrt.cpp +++ b/onert-micro/luci-interpreter/src/kernels/Rsqrt.cpp @@ -14,53 +14,63 @@ * limitations under the License. */ -#include "kernels/Rsqrt.h" +#include "Builders.h" #include "kernels/Utils.h" +#include "SISOKernel.h" -#include +#include "PALRsqrt.h" namespace luci_interpreter { -namespace kernels +void configure_kernel_CircleRsqrt(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph) { + kernels::SISOKernel kernel(cur_op, runtime_graph); -Rsqrt::Rsqrt(const Tensor *input, Tensor *output) : Kernel({input}, {output}) {} - -void Rsqrt::configure() -{ - if (input()->element_type() != output()->element_type()) - { - assert(false && "Input/output tensor data type mismatch."); - } - // TODO: enable it only if kernel with dynamic shapes - output()->resize(input()->shape()); + LUCI_INTERPRETER_CHECK(Tensor::element_type(kernel.input()) == + Tensor::element_type(kernel.output())); + LUCI_INTERPRETER_CHECK(Tensor::num_elements(kernel.input()) == + Tensor::num_elements(kernel.output())); + LUCI_INTERPRETER_CHECK(Tensor::num_dims(kernel.input()) == Tensor::num_dims(kernel.output())); } -void Rsqrt::execute() const +void execute_kernel_CircleRsqrt(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph) { - switch (input()->element_type()) + kernels::SISOKernel kernel(cur_op, runtime_graph); + + const auto *input_data = runtime_graph->getDataByTensor(kernel.input()); + assert(input_data); + + auto *output_data = runtime_graph->getDataByTensor(kernel.output()); + + bool is_inplace = runtime_graph->is_inplace_op(cur_op); + + switch (Tensor::element_type(kernel.input())) { +#ifndef DIS_FLOAT case DataType::FLOAT32: - evalFloat(); - break; + { + const float *input_data_float = kernels::getTensorData(input_data); + float *output_data_float = kernels::getTensorData(output_data); + if (is_inplace) + { + output_data_float = const_cast(input_data_float); + } + assert(output_data_float); + + const int flat_size = + kernels::getTensorRuntimeShape(kernel.input(), runtime_graph).flatSize(); + + luci_interpreter_pal::Rsqrt(flat_size, input_data_float, output_data_float); + break; + } +#endif // DIS_FLOAT default: - assert(false && "Unsupported type."); + assert(false && "Unsupported type"); } -} -void Rsqrt::evalFloat() const -{ - auto in = getTensorData(input()); - auto out = getTensorData(output()); - auto size = getTensorShape(input()).FlatSize(); - for (auto i = in; i != in + size; ++i) - { - *out = 1.f / std::sqrt(*i); - ++out; - } + if (is_inplace) + runtime_graph->makeInplaceOperation(kernel.input(), kernel.output()); } - -} // namespace kernels } // namespace luci_interpreter diff --git a/onert-micro/luci-interpreter/src/kernels/Rsqrt.h b/onert-micro/luci-interpreter/src/kernels/Rsqrt.h deleted file mode 100644 index adc5bcfa2cb..00000000000 --- a/onert-micro/luci-interpreter/src/kernels/Rsqrt.h +++ /dev/null @@ -1,46 +0,0 @@ -/* - * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef LUCI_INTERPRETER_KERNELS_RSQRT_H -#define LUCI_INTERPRETER_KERNELS_RSQRT_H - -#include "core/Kernel.h" -#include "core/KernelParams.h" - -namespace luci_interpreter -{ -namespace kernels -{ - -class Rsqrt : public Kernel -{ -public: - Rsqrt(const Tensor *input, Tensor *output); - - const Tensor *input() const { return _inputs[0]; } - Tensor *output() const { return _outputs[0]; } - - void configure() override; - void execute() const override; - -private: - void evalFloat() const; -}; - -} // namespace kernels -} // namespace luci_interpreter - -#endif // LUCI_INTERPRETER_KERNELS_RSQRT_H diff --git a/onert-micro/luci-interpreter/src/kernels/Rsqrt.test.cpp b/onert-micro/luci-interpreter/src/kernels/Rsqrt.test.cpp index 3c649423281..84c08e6afbb 100644 --- a/onert-micro/luci-interpreter/src/kernels/Rsqrt.test.cpp +++ b/onert-micro/luci-interpreter/src/kernels/Rsqrt.test.cpp @@ -14,77 +14,73 @@ * limitations under the License. */ -#include "kernels/Rsqrt.h" #include "kernels/TestUtils.h" -#include "luci_interpreter/TestMemoryManager.h" +#include "luci_interpreter/test_models/rsqrt/FloatRsqrtKernel.h" +#include "luci_interpreter/test_models/rsqrt/NegRsqrtKernel.h" + +#include "loader/ModuleLoader.h" namespace luci_interpreter { -namespace kernels -{ namespace { using namespace testing; -void Check(std::initializer_list input_shape, std::initializer_list output_shape, - std::initializer_list input_data, std::initializer_list output_data) +class RsqrtTest : public ::testing::Test { - std::unique_ptr memory_manager = std::make_unique(); + // Do nothing +}; - Tensor input_tensor = - makeInputTensor(input_shape, input_data, memory_manager.get()); - Tensor output_tensor = makeOutputTensor(DataType::FLOAT32); +template std::vector checkRsqrtKernel(test_kernel::TestDataBase *test_data_base) +{ + MemoryManager memory_manager{}; + RuntimeModule runtime_module{}; + bool dealloc_input = true; - Rsqrt kernel(&input_tensor, &output_tensor); - kernel.configure(); - memory_manager->allocate_memory(output_tensor); - kernel.execute(); + // Load model with single op + auto *model_data_raw = reinterpret_cast(test_data_base->get_model_ptr()); + ModuleLoader::load(&runtime_module, &memory_manager, model_data_raw, dealloc_input); - EXPECT_THAT(extractTensorData(output_tensor), FloatArrayNear(output_data)); - EXPECT_THAT(extractTensorShape(output_tensor), ::testing::ElementsAreArray(output_shape)); -} + auto *main_runtime_graph = runtime_module.getMainGraph(); + assert(main_runtime_graph->getNumOfInputTensors() == 1); -TEST(RsqrtTest, SimpleRsqrt) -{ - Check( - /*input_shape=*/{1, 2, 4, 1}, /*output_shape=*/{1, 2, 4, 1}, - /*input_data=*/ - { - 5, 4, 8, 2, // - 6, 7.5, 9, 0.3, // - }, - /*output_data=*/ - { - 0.44721360, 0.5, 0.35355339, 0.70710678, // - 0.40824829, 0.36514837, 0.33333333, 1.8257419, // - }); -} + // Set input data + { + auto *input_tensor_data = reinterpret_cast(main_runtime_graph->configureGraphInput(0)); + std::copy(test_data_base->get_input_data_by_index(0).begin(), + test_data_base->get_input_data_by_index(0).end(), input_tensor_data); + } -TEST(RsqrtTest, Input_Output_Type_NEG) -{ - std::unique_ptr memory_manager = std::make_unique(); + runtime_module.execute(); - Tensor input_tensor = makeInputTensor({1}, {1.f}, memory_manager.get()); - Tensor output_tensor = makeOutputTensor(DataType::S32); + assert(main_runtime_graph->getNumOfOutputTensors() == 1); - Rsqrt kernel(&input_tensor, &output_tensor); - EXPECT_ANY_THROW(kernel.configure()); + T *output_data = reinterpret_cast(main_runtime_graph->getOutputDataByIndex(0)); + const size_t num_elements = (main_runtime_graph->getOutputDataSizeByIndex(0) / sizeof(T)); + std::vector output_data_vector(output_data, output_data + num_elements); + return output_data_vector; } -TEST(RsqrtTest, Invalid_Input_Type_NEG) +TEST_F(RsqrtTest, Float_P) { - std::unique_ptr memory_manager = std::make_unique(); - - Tensor input_tensor = makeInputTensor({1}, {1}, memory_manager.get()); - Tensor output_tensor = makeOutputTensor(DataType::S64); + test_kernel::TestDataFloatRsqrt test_data_kernel; + std::vector output_data_vector = checkRsqrtKernel(&test_data_kernel); + EXPECT_THAT(output_data_vector, kernels::testing::FloatArrayNear( + test_data_kernel.get_output_data_by_index(0), 0.0001f)); +} - Rsqrt kernel(&input_tensor, &output_tensor); - kernel.configure(); - memory_manager->allocate_memory(output_tensor); - EXPECT_ANY_THROW(kernel.execute()); +TEST_F(RsqrtTest, Input_output_type_mismatch_NEG) +{ + test_kernel::NegTestDataInputOutputTypeMismatchRsqrtKernel test_data_kernel; + MemoryManager memory_manager{}; + RuntimeModule runtime_module{}; + bool dealloc_input = true; + // Load model with single op + auto *model_data_raw = reinterpret_cast(test_data_kernel.get_model_ptr()); + EXPECT_DEATH(ModuleLoader::load(&runtime_module, &memory_manager, model_data_raw, dealloc_input), + ""); } } // namespace -} // namespace kernels } // namespace luci_interpreter