From f2b6e69bf75f7600f607a0d95bda2d83dc14eddb Mon Sep 17 00:00:00 2001
From: Artem Balyshev
Date: Mon, 26 Aug 2024 12:51:42 +0300
Subject: [PATCH] [onert-micro] Add BroadcastTo op

This pr adds BroadcastTo operation to onert-micro.

ONE-DCO-1.0-Signed-off-by: Artem Balyshev
+
+namespace onert_micro
+{
+namespace execute
+{
+namespace pal
+{
+
+/**
+ * @brief Recursive helper of BroadcastTo that fills the output one dimension at a time.
+ *
+ * Dimensions before @p last_broadcasting_dim are walked over the input extents. Once all inner
+ * slices of a dimension are written, the first output slice of that dimension is replicated if
+ * the output extent differs from the input extent. At @p last_broadcasting_dim the input bytes
+ * themselves are copied, once per output element along that dimension.
+ *
+ * @param input_desc            extents/strides of the input
+ * @param input_data            raw input bytes
+ * @param output_desc           extents/strides of the output
+ * @param output_data           raw output bytes, written in place
+ * @param indexes               scratch subscript with N entries; BroadcastTo passes it zero-filled
+ * @param dim                   dimension handled by this call (BroadcastTo starts at 0)
+ * @param last_broadcasting_dim innermost dimension whose input and output extents differ
+ * @param type_size             size of one element in bytes
+ *
+ * NOTE(review): wherever the extents differ a single slice is replicated, which presumably
+ * relies on the input extent being 1 there - nothing in this function checks it.
+ * NOTE(review): the template parameter lists (<int N>, NdArrayDesc<N>) were absent from the
+ * reviewed text and are restored from the uses of N below - confirm against the original.
+ */
+template <int N>
+void BroadcastImpl(const NdArrayDesc<N> &input_desc, const uint8_t *input_data,
+                   const NdArrayDesc<N> &output_desc, uint8_t *output_data, int indexes[N], int dim,
+                   const int last_broadcasting_dim, const uint32_t type_size)
+{
+  // Copy data from input to output.
+  if (dim == last_broadcasting_dim)
+  {
+    int copy_size = output_desc.strides[dim] * type_size;
+    const uint8_t *data_src = input_data + subscriptToIndex(input_desc, indexes) * type_size;
+    uint8_t *data_dst = output_data + subscriptToIndex(output_desc, indexes) * type_size;
+    for (int i = 0; i < output_desc.extents[dim]; ++i, data_dst += copy_size)
+    {
+      memcpy(data_dst, data_src, copy_size);
+    }
+    return;
+  }
+
+  // Recurse into the next dimension for every input index of this one.
+  for (indexes[dim] = 0; indexes[dim] < input_desc.extents[dim]; ++indexes[dim])
+  {
+    BroadcastImpl<N>(input_desc, input_data, output_desc, output_data, indexes, dim + 1,
+                     last_broadcasting_dim, type_size);
+  }
+
+  // Duplicate data in output tensor: slice 0 of this dimension is already written, so the loop
+  // starts at 1 and copies it into the remaining slices.
+  indexes[dim] = 0;
+  if (input_desc.extents[dim] != output_desc.extents[dim])
+  {
+    int copy_size = output_desc.strides[dim] * type_size;
+    uint8_t *data_src = output_data + subscriptToIndex(output_desc, indexes) * type_size;
+    uint8_t *data_dst = data_src + copy_size;
+    for (int i = 1; i < output_desc.extents[dim]; ++i, data_dst += copy_size)
+    {
+      memcpy(data_dst, data_src, copy_size);
+    }
+  }
+}
+
+/**
+ * @brief Broadcasts a tensor to the output shape by copying raw bytes.
+ *
+ * Both shapes are extended to N dimensions first. If no extent differs the input is copied
+ * straight through; otherwise BroadcastImpl performs the replication.
+ *
+ * @param unextended_input_shape  shape of the input tensor
+ * @param input_data              raw input bytes
+ * @param unextended_output_shape shape of the output tensor
+ * @param output_data             raw output bytes, written in place
+ * @param data_type               element type; only its byte size is used
+ * @return Ok (no other status is produced here)
+ */
+template <int N>
+inline OMStatus BroadcastTo(const core::OMRuntimeShape &unextended_input_shape,
+                            const uint8_t *input_data,
+                            const core::OMRuntimeShape &unextended_output_shape,
+                            uint8_t *output_data, core::OMDataType data_type)
+{
+  NdArrayDesc<N> input_desc;
+  NdArrayDesc<N> output_desc;
+  copyDimsToDesc(core::OMRuntimeShape::extendedShape(N, unextended_input_shape), &input_desc);
+  copyDimsToDesc(core::OMRuntimeShape::extendedShape(N, unextended_output_shape), &output_desc);
+
+  // Find the innermost dimension that is broadcast. At this dimension the data is
+  // copied from the input tensor to the output tensor.
+  int last_broadcast_dim = -1;
+  for (int i = N - 1; i >= 0; --i)
+  {
+    if (input_desc.extents[i] != output_desc.extents[i])
+    {
+      last_broadcast_dim = i;
+      break;
+    }
+  }
+
+  // If non-broadcasting, just copy data from input to output tensor.
+  if (last_broadcast_dim == -1)
+  {
+    // Size the copy by the element width of data_type. sizeof(data_type) is the size of the
+    // OMDataType value itself, not of a tensor element, so it gives the wrong byte count for
+    // every element type whose width differs from it.
+    memcpy(output_data, input_data,
+           unextended_input_shape.flatSize() * core::getOMDataTypeSize(data_type));
+    return Ok;
+  }
+
+  // Broadcasting using memcpy.
+  int indexes[N] = {0};
+  BroadcastImpl<N>(input_desc, input_data, output_desc, output_data, indexes, 0,
+                   last_broadcast_dim, core::getOMDataTypeSize(data_type));
+
+  return Ok;
+}
+
+} // namespace pal
+} // namespace execute
+} // namespace onert_micro
+
+#endif // ONERT_MICRO_EXECUTE_PAL_BROADCAST_TO_COMMON_H
diff --git a/onert-micro/onert-micro/include/pal/mcu/CustomKernelsToBuild.lst b/onert-micro/onert-micro/include/pal/mcu/CustomKernelsToBuild.lst
index e69de29bb2d..884f8dbdb86 100644
--- a/onert-micro/onert-micro/include/pal/mcu/CustomKernelsToBuild.lst
+++ b/onert-micro/onert-micro/include/pal/mcu/CustomKernelsToBuild.lst
@@ -0,0 +1 @@
+REGISTER_CUSTOM_KERNEL(BROADCAST_TO, "BroadcastTo")
diff --git a/onert-micro/onert-micro/include/test_models/broadcast_to/FloatBroadcastToKernel.h b/onert-micro/onert-micro/include/test_models/broadcast_to/FloatBroadcastToKernel.h
new file mode 100644
index 00000000000..7edc7242afb
--- /dev/null
+++ b/onert-micro/onert-micro/include/test_models/broadcast_to/FloatBroadcastToKernel.h
@@ -0,0 +1,94 @@
+/*
+ * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef ONERT_MICRO_TEST_MODELS_FLOAT_BROADCAST_TO_KERNEL_H
+#define ONERT_MICRO_TEST_MODELS_FLOAT_BROADCAST_TO_KERNEL_H
+
+#include "TestDataBroadcastToBase.h"
+
+namespace onert_micro
+{
+namespace test_model
+{
+namespace broadcast_to_float
+{
+/*
+ * BroadcastTo Kernel (float): the op only prepends a unit dimension here,
+ * so the reference output holds exactly the input values.
+ * Input(2, 3)
+ *     |
+ * BroadcastTo
+ *     |
+ * Output(1, 2, 3)
+ */
+const unsigned char test_kernel_model_circle[] = {
+  0x18, 0x00, 0x00, 0x00, 0x43, 0x49, 0x52, 0x30, 0x00, 0x00, 0x0e, 0x00, 0x14, 0x00, 0x00, 0x00,
+  0x0c, 0x00, 0x08, 0x00, 0x10, 0x00, 0x04, 0x00, 0x0e, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00,
+  0x50, 0x00, 0x00, 0x00, 0x98, 0x01, 0x00, 0x00, 0xc8, 0x01, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00,
+  0x3c, 0x00, 0x00, 0x00, 0x30, 0x00, 0x00, 0x00, 0x28, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00,
+  0x00, 0x00, 0x06, 0x00, 0x08, 0x00, 0x04, 0x00, 0x06, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00,
+  0x0c, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00,
+  0xf8, 0xff, 0xff, 0xff, 0xfc, 0xff, 0xff, 0xff, 0x04, 0x00, 0x04, 0x00, 0x04, 0x00, 0x00, 0x00,
+  0x01, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0e, 0x00, 0x18, 0x00, 0x14, 0x00,
+  0x10, 0x00, 0x0c, 0x00, 0x08, 0x00, 0x04, 0x00, 0x0e, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00,
+  0x1c, 0x00, 0x00, 0x00, 0x6c, 0x00, 0x00, 0x00, 0x70, 0x00, 0x00, 0x00, 0x74, 0x00, 0x00, 0x00,
+  0x04, 0x00, 0x00, 0x00, 0x6d, 0x61, 0x69, 0x6e, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00,
+  0x14, 0x00, 0x00, 0x00, 0x10, 0x00, 0x10, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x08, 0x00, 0x00, 0x00,
+  0x00, 0x00, 0x04, 0x00, 0x10, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x20, 0x00, 0x00, 0x00,
+  0x24, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x54, 0x00, 0x54, 0x69, 0x64, 0x78, 0x00, 0x02,
+  0x08, 0x07, 0x02, 0x01, 0x02, 0x00, 0x02, 0x04, 0x04, 0x04, 0x24, 0x01, 0x01, 0x00, 0x00, 0x00,
+  0x02, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00,
+  0x01, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+  0x03, 0x00, 0x00, 0x00, 0x7c, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00,
+  0x9c, 0xff, 0xff, 0xff, 0x0c, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00,
+  0x06, 0x00, 0x00, 0x00, 0x62, 0x63, 0x5f, 0x6f, 0x66, 0x6d, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00,
+  0x01, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x14, 0x00,
+  0x10, 0x00, 0x0f, 0x00, 0x08, 0x00, 0x04, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00,
+  0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x14, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00,
+  0x62, 0x63, 0x5f, 0x73, 0x68, 0x61, 0x70, 0x65, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00,
+  0x03, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x10, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x08, 0x00, 0x04, 0x00,
+  0x0c, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00,
+  0x08, 0x00, 0x00, 0x00, 0x62, 0x63, 0x5f, 0x69, 0x6e, 0x70, 0x75, 0x74, 0x00, 0x00, 0x00, 0x00,
+  0x02, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00,
+  0x10, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x10, 0x00, 0x0f, 0x00, 0x04, 0x00, 0x00, 0x00, 0x08, 0x00,
+  0x0c, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x20, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x20,
+  0x0b, 0x00, 0x00, 0x00, 0x42, 0x72, 0x6f, 0x61, 0x64, 0x63, 0x61, 0x73, 0x74, 0x54, 0x6f, 0x00,
+  0x11, 0x00, 0x00, 0x00, 0x4f, 0x4e, 0x45, 0x2d, 0x74, 0x66, 0x6c, 0x69, 0x74, 0x65, 0x32, 0x63,
+  0x69, 0x72, 0x63, 0x6c, 0x65, 0x00, 0x00, 0x00};
+
+const std::vector<float> input_data = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0}; // 2 x 3 input
+
+const std::vector<float> reference_output_data = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0}; // same values
+
+} // namespace broadcast_to_float
+// NOTE(review): the <float> arguments in this file were missing from the reviewed text - confirm.
+class TestDataFloatBroadcastTo : public TestDataBroadcastToBase<float>
+{
+public:
+  TestDataFloatBroadcastTo()
+  {
+    _input_data = broadcast_to_float::input_data;
+    _reference_output_data = broadcast_to_float::reference_output_data;
+    _test_kernel_model_circle = broadcast_to_float::test_kernel_model_circle;
+  }
+
+  ~TestDataFloatBroadcastTo() override = default;
+};
+
+} // namespace test_model
+} // namespace onert_micro
+
+#endif // ONERT_MICRO_TEST_MODELS_FLOAT_BROADCAST_TO_KERNEL_H
diff --git a/onert-micro/onert-micro/include/test_models/broadcast_to/NegBroadcastToKernel.h b/onert-micro/onert-micro/include/test_models/broadcast_to/NegBroadcastToKernel.h
new file mode 100644
index 00000000000..c98d5ba46a5
--- /dev/null
+++ b/onert-micro/onert-micro/include/test_models/broadcast_to/NegBroadcastToKernel.h
@@ -0,0 +1,93 @@
+/*
+ * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef ONERT_MICRO_TEST_MODELS_NEG_BROADCAST_TO_KERNEL_H
+#define ONERT_MICRO_TEST_MODELS_NEG_BROADCAST_TO_KERNEL_H
+
+#include "test_models/TestDataBase.h"
+
+namespace onert_micro
+{
+namespace test_model
+{
+namespace neg_input_output_type_mismatch_broadcast_to_kernel
+{
+/*
+ * BroadcastTo Kernel with input/output type mismatch (negative test model):
+ *
+ * Input(2, 3) - Float32
+ *       |
+ * BroadcastTo
+ *       |
+ * Output(1, 2, 3) - Int32
+ */
+const unsigned char test_kernel_model_circle[] = { // identifier bytes "CIR0" at offset 4
+  0x18, 0x00, 0x00, 0x00, 0x43, 0x49, 0x52, 0x30, 0x00, 0x00, 0x0e, 0x00, 0x14, 0x00, 0x00, 0x00,
+  0x0c, 0x00, 0x08, 0x00, 0x10, 0x00, 0x04, 0x00, 0x0e, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00,
+  0x50, 0x00, 0x00, 0x00, 0x9c, 0x01, 0x00, 0x00, 0xcc, 0x01, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00,
+  0x3c, 0x00, 0x00, 0x00, 0x30, 0x00, 0x00, 0x00, 0x28, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00,
+  0x00, 0x00, 0x06, 0x00, 0x08, 0x00, 0x04, 0x00, 0x06, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00,
+  0x0c, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00,
+  0xf8, 0xff, 0xff, 0xff, 0xfc, 0xff, 0xff, 0xff, 0x04, 0x00, 0x04, 0x00, 0x04, 0x00, 0x00, 0x00,
+  0x01, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0e, 0x00, 0x18, 0x00, 0x14, 0x00,
+  0x10, 0x00, 0x0c, 0x00, 0x08, 0x00, 0x04, 0x00, 0x0e, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00,
+  0x1c, 0x00, 0x00, 0x00, 0x6c, 0x00, 0x00, 0x00, 0x70, 0x00, 0x00, 0x00, 0x74, 0x00, 0x00, 0x00,
+  0x04, 0x00, 0x00, 0x00, 0x6d, 0x61, 0x69, 0x6e, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00,
+  0x14, 0x00, 0x00, 0x00, 0x10, 0x00, 0x10, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x08, 0x00, 0x00, 0x00,
+  0x00, 0x00, 0x04, 0x00, 0x10, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x20, 0x00, 0x00, 0x00,
+  0x24, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x54, 0x00, 0x54, 0x69, 0x64, 0x78, 0x00, 0x02,
+  0x08, 0x07, 0x02, 0x01, 0x02, 0x00, 0x02, 0x04, 0x04, 0x04, 0x24, 0x01, 0x01, 0x00, 0x00, 0x00,
+  0x02, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00,
+  0x01, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+  0x03, 0x00, 0x00, 0x00, 0x80, 0x00, 0x00, 0x00, 0x44, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00,
+  0xd0, 0xff, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02,
+  0x10, 0x00, 0x00, 0x00, 0x06, 0x00, 0x00, 0x00, 0x62, 0x63, 0x5f, 0x6f, 0x66, 0x6d, 0x00, 0x00,
+  0x03, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00,
+  0x0c, 0x00, 0x14, 0x00, 0x10, 0x00, 0x0f, 0x00, 0x08, 0x00, 0x04, 0x00, 0x0c, 0x00, 0x00, 0x00,
+  0x10, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x14, 0x00, 0x00, 0x00,
+  0x08, 0x00, 0x00, 0x00, 0x62, 0x63, 0x5f, 0x73, 0x68, 0x61, 0x70, 0x65, 0x00, 0x00, 0x00, 0x00,
+  0x01, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x10, 0x00, 0x0c, 0x00, 0x00, 0x00,
+  0x08, 0x00, 0x04, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00,
+  0x14, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x62, 0x63, 0x5f, 0x69, 0x6e, 0x70, 0x75, 0x74,
+  0x00, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00,
+  0x01, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x10, 0x00, 0x0f, 0x00, 0x04, 0x00,
+  0x00, 0x00, 0x08, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x20, 0x00, 0x00, 0x00,
+  0x00, 0x00, 0x00, 0x20, 0x0b, 0x00, 0x00, 0x00, 0x42, 0x72, 0x6f, 0x61, 0x64, 0x63, 0x61, 0x73,
+  0x74, 0x54, 0x6f, 0x00, 0x11, 0x00, 0x00, 0x00, 0x4f, 0x4e, 0x45, 0x2d, 0x74, 0x66, 0x6c, 0x69,
+  0x74, 0x65, 0x32, 0x63, 0x69, 0x72, 0x63, 0x6c, 0x65, 0x00, 0x00, 0x00};
+} // namespace neg_input_output_type_mismatch_broadcast_to_kernel
+// Negative fixture: exposes only the mismatched-type model above through get_model_ptr().
+class NegTestDataInputOutputTypeMismatchBroadcastToKernel : public NegTestDataBase
+{
+public:
+  NegTestDataInputOutputTypeMismatchBroadcastToKernel()
+  {
+    _test_kernel_model_circle =
+      neg_input_output_type_mismatch_broadcast_to_kernel::test_kernel_model_circle;
+  }
+
+  ~NegTestDataInputOutputTypeMismatchBroadcastToKernel() override = default;
+
+  const unsigned char *get_model_ptr() override final { return _test_kernel_model_circle; }
+
+protected:
+  const unsigned char *_test_kernel_model_circle; // points at the static array above
+};
+
+} // namespace test_model
+} // namespace onert_micro
+
+#endif // ONERT_MICRO_TEST_MODELS_NEG_BROADCAST_TO_KERNEL_H
diff --git a/onert-micro/onert-micro/include/test_models/broadcast_to/TestDataBroadcastToBase.h b/onert-micro/onert-micro/include/test_models/broadcast_to/TestDataBroadcastToBase.h
new file mode 100644
index 00000000000..37630c3340b
--- /dev/null
+++ b/onert-micro/onert-micro/include/test_models/broadcast_to/TestDataBroadcastToBase.h
@@ -0,0 +1,61 @@
+/*
+ * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef ONERT_MICRO_TEST_MODELS_BROADCAST_TO_KERNEL_BASE_H
+#define ONERT_MICRO_TEST_MODELS_BROADCAST_TO_KERNEL_BASE_H
+
+#include "test_models/TestDataBase.h"
+#include <cassert>
+
+namespace onert_micro
+{
+namespace test_model
+{
+// Shared BroadcastTo test fixture. NOTE(review): <typename T>/<T>/<cassert> restored - confirm.
+template <typename T> class TestDataBroadcastToBase : public TestDataBase<T>
+{
+public:
+  TestDataBroadcastToBase() = default;
+
+  const unsigned char *get_model_ptr() override final { return _test_kernel_model_circle; }
+
+  const std::vector<T> &get_input_data_by_index(int i) override final
+  {
+    // Index 0 is the only input this fixture provides. The previous
+    // switch/default form had no return after the assert, so with asserts
+    // compiled out (NDEBUG) control fell off the end of a function that
+    // returns a reference - undefined behaviour. Always return the vector.
+    assert(i == 0 && "Wrong input index");
+    (void)i; // keeps NDEBUG builds free of an unused-parameter warning
+    return _input_data;
+  }
+
+  const std::vector<T> &get_output_data_by_index(int i) override final
+  {
+    assert(i == 0);
+    return _reference_output_data;
+  }
+
+protected:
+  std::vector<T> _input_data;
+  std::vector<T> _reference_output_data;
+  const unsigned char *_test_kernel_model_circle = nullptr; // set by the derived fixture
+};
+
+} // namespace test_model
+} // namespace onert_micro
+
+#endif // ONERT_MICRO_TEST_MODELS_BROADCAST_TO_KERNEL_BASE_H
diff --git a/onert-micro/onert-micro/src/execute/CMakeLists.txt b/onert-micro/onert-micro/src/execute/CMakeLists.txt
index 621b54ab0d6..1368c8fee43 100644
--- a/onert-micro/onert-micro/src/execute/CMakeLists.txt
+++ b/onert-micro/onert-micro/src/execute/CMakeLists.txt
@@ -32,7 +32,7 @@ endmacro(REGISTER_KERNEL)
 # To add REGISTER_KERNEL list
 include(${KERNEL_REGISTER_FILE})
-macro(REGISTER_CUSTOM_KERNEL NODE)
+macro(REGISTER_CUSTOM_KERNEL OPERATOR, NODE)
     list(APPEND SOURCES "kernels/${NODE}.cpp")
 endmacro(REGISTER_CUSTOM_KERNEL)
@@ -61,7 +61,7 @@ endmacro(REGISTER_KERNEL)
 include(${KERNEL_REGISTER_FILE})
-macro(REGISTER_CUSTOM_KERNEL NODE)
+macro(REGISTER_CUSTOM_KERNEL OPERATOR, NODE)
     list(APPEND TEST_SOURCES "kernels/tests/${NODE}.test.cpp")
 endmacro(REGISTER_CUSTOM_KERNEL)
diff --git a/onert-micro/onert-micro/src/execute/kernels/BroadcastTo.cpp b/onert-micro/onert-micro/src/execute/kernels/BroadcastTo.cpp
new file mode 100644
index 00000000000..439a64eb76f
--- /dev/null
+++ b/onert-micro/onert-micro/src/execute/kernels/BroadcastTo.cpp
@@ -0,0 +1,76 @@
+/*
+ * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "OMStatus.h"
+
+#include "core/OMUtils.h"
+#include "core/OMKernelData.h"
+#include "core/OMDataType.h"
+
+#include "execute/OMKernelExecutionBuilder.h"
+#include "execute/OMUtils.h"
+#include "execute/OMRuntimeKernel.h"
+
+#include "PALBroadcastTo.h"
+
+using namespace onert_micro;
+using namespace onert_micro::execute;
+
+namespace
+{
+
+constexpr int kMaxDims = 5;
+
+constexpr uint32_t input1TensorIdx = 0;
+constexpr uint32_t outputTensorIdx = 0;
+
+} // namespace
+
+// Executes BROADCAST_TO by copying the raw bytes of the first input into the output buffer.
+// NOTE: dynamic shapes are not supported yet, so the second input is ignored and the
+// output tensor's own shape is used as the target shape.
+OMStatus onert_micro::execute::execute_kernel_CircleBROADCAST_TO(const OMExecuteArgs &execute_args)
+{
+  core::OMRuntimeContext &runtime_context = execute_args.runtime_context;
+  core::OMRuntimeStorage &runtime_storage = execute_args.runtime_storage;
+  const uint16_t op_index = execute_args.kernel_index;
+
+  // Read kernel. Stop on failure rather than continuing with a kernel that was not read.
+  execute::OMRuntimeKernel runtime_kernel;
+  const OMStatus read_status = runtime_kernel.readKernel(op_index, runtime_context);
+  if (read_status != Ok)
+    return read_status;
+
+  const circle::Tensor *input1 = runtime_kernel.inputs[input1TensorIdx];
+  const circle::Tensor *output = runtime_kernel.outputs[outputTensorIdx];
+  assert(input1 != nullptr);
+  assert(output != nullptr);
+
+  runtime_kernel.getDataFromStorage(op_index, runtime_storage, runtime_context);
+
+  // The input is only read, so hold it through a const pointer; no const_cast is needed.
+  const uint8_t *input_data = runtime_kernel.inputs_data[input1TensorIdx];
+  uint8_t *output_data = runtime_kernel.outputs_data[outputTensorIdx];
+  assert(input_data != nullptr);
+  assert(output_data != nullptr);
+
+  const core::OMRuntimeShape input1_shape(input1);
+  const core::OMRuntimeShape output_shape(output);
+
+  // NOTE(review): <kMaxDims> restored (constant is otherwise unused) - confirm against original.
+  return pal::BroadcastTo<kMaxDims>(input1_shape, input_data, output_shape, output_data,
+                                    core::OMDataType(input1->type()));
+}
diff --git a/onert-micro/onert-micro/src/execute/kernels/tests/BroadcastTo.test.cpp b/onert-micro/onert-micro/src/execute/kernels/tests/BroadcastTo.test.cpp
new file mode 100644
index 00000000000..c9c61094a2b
--- /dev/null
+++ b/onert-micro/onert-micro/src/execute/kernels/tests/BroadcastTo.test.cpp
@@ -0,0 +1,52 @@
+/*
+ * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "execute/OMTestUtils.h"
+#include "test_models/broadcast_to/FloatBroadcastToKernel.h"
+#include "test_models/broadcast_to/NegBroadcastToKernel.h"
+
+namespace onert_micro
+{
+namespace execute
+{
+namespace testing
+{
+
+using namespace testing;
+
+class BroadcastToTest : public ::testing::Test
+{
+  // Do nothing
+};
+
+TEST_F(BroadcastToTest, Float_P)
+{
+  onert_micro::test_model::TestDataFloatBroadcastTo test_data_kernel;
+  std::vector<float> output_data_vector = // NOTE(review): <float> restored here - confirm
+    onert_micro::execute::testing::checkKernel<float>(1, &test_data_kernel);
+  EXPECT_THAT(output_data_vector, test_data_kernel.get_output_data_by_index(0));
+}
+
+TEST_F(BroadcastToTest, Input_output_type_mismatch_NEG)
+{
+  onert_micro::test_model::NegTestDataInputOutputTypeMismatchBroadcastToKernel test_data_kernel;
+
+  EXPECT_DEATH(checkNEGSISOKernel(&test_data_kernel), "");
+}
+
+} // namespace testing
+} // namespace execute
+} // namespace onert_micro
diff --git a/onert-micro/onert-micro/src/import/CMakeLists.txt b/onert-micro/onert-micro/src/import/CMakeLists.txt
index e865ada32ff..ac9835d382b 100644
--- a/onert-micro/onert-micro/src/import/CMakeLists.txt
+++ b/onert-micro/onert-micro/src/import/CMakeLists.txt
@@ -21,7 +21,7 @@ endmacro(REGISTER_KERNEL)
 # To add REGISTER_KERNEL list
 include(${KERNEL_REGISTER_FILE})
-macro(REGISTER_CUSTOM_KERNEL NODE)
+macro(REGISTER_CUSTOM_KERNEL OPERATOR, NODE)
     list(APPEND SOURCES "kernels/${NODE}.cpp")
 endmacro(REGISTER_CUSTOM_KERNEL)
diff --git a/onert-micro/onert-micro/src/import/kernels/BroadcastTo.cpp b/onert-micro/onert-micro/src/import/kernels/BroadcastTo.cpp
new file mode 100644
index 00000000000..5f439c5ca26
--- /dev/null
+++ b/onert-micro/onert-micro/src/import/kernels/BroadcastTo.cpp
@@ -0,0 +1,100 @@
+/*
+ * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "import/OMKernelConfigureBuilder.h"
+#include "core/OMUtils.h"
+#include "OMStatus.h"
+#include "execute/OMRuntimeKernel.h"
+
+using namespace onert_micro;
+using namespace onert_micro::core;
+
+namespace
+{
+
+constexpr uint32_t input1TensorIdx = 0;
+constexpr uint32_t input2TensorIdx = 1;
+constexpr uint32_t outputTensorIdx = 0;
+
+} // namespace
+
+// Import-time validation of a BROADCAST_TO node.
+// Returns the readKernel() or checkCondition() failure status if one occurs, NoQuantization
+// when an INT8/INT16 tensor lacks a single scale/zero-point pair (unless DIS_QUANT is defined),
+// and Ok otherwise.
+OMStatus
+onert_micro::import::configure_kernel_CircleBROADCAST_TO(const OMConfigureArgs &config_args)
+{
+  OMRuntimeContext &runtime_context = config_args.runtime_context;
+  uint16_t op_index = config_args.kernel_index;
+
+  onert_micro::execute::OMRuntimeKernel runtime_kernel;
+
+  OMStatus status = runtime_kernel.readKernel(op_index, runtime_context);
+  if (status != Ok)
+    return status;
+
+  const circle::Tensor *input1 = runtime_kernel.inputs[input1TensorIdx];
+  const circle::Tensor *input2 = runtime_kernel.inputs[input2TensorIdx];
+  const circle::Tensor *output = runtime_kernel.outputs[outputTensorIdx];
+
+  assert(input1 != nullptr);
+  assert(input2 != nullptr);
+  assert(output != nullptr);
+
+  // The data tensor and the output must share one element type.
+  status = utils::checkCondition(input1->type() == output->type());
+  if (status != Ok)
+    return status;
+
+  // The second input must be an INT32 tensor.
+  status = utils::checkCondition(input2->type() == circle::TensorType_INT32);
+  if (status != Ok)
+    return status;
+
+  // Only INT8/INT16 tensors go through the quantization checks below.
+  if (input1->type() != circle::TensorType_INT8 and input1->type() != circle::TensorType_INT16)
+    return Ok;
+
+#ifndef DIS_QUANT
+
+  // Check quantization params
+  if (input1->quantization() == nullptr or output->quantization() == nullptr)
+  {
+    return NoQuantization;
+  }
+
+  if (input1->quantization()->scale() == nullptr or
+      input1->quantization()->zero_point() == nullptr or
+      input1->quantization()->scale()->size() != 1 or
+      input1->quantization()->zero_point()->size() != 1)
+  {
+    return NoQuantization;
+  }
+
+  if (output->quantization()->scale() == nullptr or
+      output->quantization()->zero_point() == nullptr or
+      output->quantization()->scale()->size() != 1 or
+      output->quantization()->zero_point()->size() != 1)
+  {
+    return NoQuantization;
+  }
+
+#endif // DIS_QUANT
+
+  // All checks passed.
+  return Ok;
+}