diff --git a/onert-micro/onert-micro/include/execute/OMUtils.h b/onert-micro/onert-micro/include/execute/OMUtils.h index 177c61b2535..b45feb08953 100644 --- a/onert-micro/onert-micro/include/execute/OMUtils.h +++ b/onert-micro/onert-micro/include/execute/OMUtils.h @@ -23,6 +23,9 @@ #include "core/OMRuntimeShape.h" #include "core/OMKernelData.h" +#include "execute/OMKernelExecutionBuilder.h" +#include "execute/OMRuntimeKernel.h" + namespace onert_micro { namespace execute @@ -157,6 +160,9 @@ void calculateQuantParams(core::ArithmeticQuantParams ¶ms, const circle::Ten const circle::Tensor *input2, const circle::Tensor *output, circle::ActivationFunctionType act); +OMStatus SISOHeader(const OMExecuteArgs &execute_args, const circle::Tensor **input, + const circle::Tensor **output, uint8_t **input_data, uint8_t **output_data); + } // namespace execute } // namespace onert_micro diff --git a/onert-micro/onert-micro/src/execute/OMUtils.cpp b/onert-micro/onert-micro/src/execute/OMUtils.cpp index 23e463707b8..9bda002018c 100644 --- a/onert-micro/onert-micro/src/execute/OMUtils.cpp +++ b/onert-micro/onert-micro/src/execute/OMUtils.cpp @@ -156,6 +156,41 @@ void onert_micro::execute::readQuantParams(const circle::Tensor *tensor, long &z scale = tensor->quantization()->scale()->operator[](0); } +OMStatus onert_micro::execute::SISOHeader(const OMExecuteArgs &execute_args, + const circle::Tensor **input, + const circle::Tensor **output, uint8_t **input_data, + uint8_t **output_data) +{ + OMStatus status; + + core::OMRuntimeContext &runtime_context = execute_args.runtime_context; + core::OMRuntimeStorage &runtime_storage = execute_args.runtime_storage; + uint16_t op_index = execute_args.kernel_index; + + { + OMRuntimeKernel runtime_kernel; + runtime_kernel.readKernel(op_index, runtime_context); + + *input = runtime_kernel.inputs[0]; + *output = runtime_kernel.outputs[0]; + + assert(*input != nullptr); + assert(*output != nullptr); + + status = runtime_kernel.getDataFromStorage(op_index, runtime_storage, runtime_context); + if (status != Ok) + return status; + + *input_data = runtime_kernel.inputs_data[0]; + *output_data = runtime_kernel.outputs_data[0]; + } + + assert(*input_data != nullptr); + assert(*output_data != nullptr); + + return status; +} + void onert_micro::execute::calculateQuantParams(core::ArithmeticQuantParams ¶ms, const circle::Tensor *input1, const circle::Tensor *input2, diff --git a/onert-micro/onert-micro/src/execute/kernels/Cast.cpp b/onert-micro/onert-micro/src/execute/kernels/Cast.cpp index 5c9ef1ed279..becfa24e52b 100644 --- a/onert-micro/onert-micro/src/execute/kernels/Cast.cpp +++ b/onert-micro/onert-micro/src/execute/kernels/Cast.cpp @@ -40,38 +40,15 @@ constexpr uint32_t outputTensorIdx = 0; // NOTE: doesnt currently support dynamic shapes OMStatus onert_micro::execute::execute_kernel_CircleCast(const OMExecuteArgs &execute_args) { - core::OMRuntimeContext &runtime_context = execute_args.runtime_context; - core::OMRuntimeStorage &runtime_storage = execute_args.runtime_storage; - uint16_t op_index = execute_args.kernel_index; - const circle::Tensor *input = nullptr; const circle::Tensor *output = nullptr; uint8_t *input_data = nullptr; uint8_t *output_data = nullptr; - OMStatus status = Ok; - - { - OMRuntimeKernel runtime_kernel; - runtime_kernel.readKernel(op_index, runtime_context); - - input = runtime_kernel.inputs[inputTensorIdx]; - output = runtime_kernel.outputs[outputTensorIdx]; - - assert(input != nullptr); - assert(output != nullptr); - - status = runtime_kernel.getDataFromStorage(op_index, runtime_storage, runtime_context); - if (status != Ok) - return status; - - input_data = runtime_kernel.inputs_data[inputTensorIdx]; - output_data = runtime_kernel.outputs_data[outputTensorIdx]; - } + SISOHeader(execute_args, &input, &output, &input_data, &output_data); - assert(input_data != nullptr); - assert(output_data != nullptr); + OMStatus status; switch (input->type()) { diff --git a/onert-micro/onert-micro/src/execute/kernels/L2Normalize.cpp b/onert-micro/onert-micro/src/execute/kernels/L2Normalize.cpp index 3b7aac5c545..a0a46c53b75 100644 --- a/onert-micro/onert-micro/src/execute/kernels/L2Normalize.cpp +++ b/onert-micro/onert-micro/src/execute/kernels/L2Normalize.cpp @@ -21,6 +21,7 @@ #include "execute/OMKernelExecutionBuilder.h" #include "execute/OMRuntimeKernel.h" +#include "execute/OMUtils.h" #include "PALL2Normalize.h" using namespace onert_micro; @@ -37,38 +38,15 @@ constexpr uint32_t outputTensorIdx = 0; // NOTE: doesnt currently support dynamic shapes OMStatus onert_micro::execute::execute_kernel_CircleL2Normalize(const OMExecuteArgs &execute_args) { - core::OMRuntimeContext &runtime_context = execute_args.runtime_context; - core::OMRuntimeStorage &runtime_storage = execute_args.runtime_storage; - uint16_t op_index = execute_args.kernel_index; - const circle::Tensor *input = nullptr; const circle::Tensor *output = nullptr; uint8_t *input_data = nullptr; uint8_t *output_data = nullptr; - OMStatus status = Ok; - - { - OMRuntimeKernel runtime_kernel; - runtime_kernel.readKernel(op_index, runtime_context); - - input = runtime_kernel.inputs[inputTensorIdx]; - output = runtime_kernel.outputs[outputTensorIdx]; - - assert(input != nullptr); - assert(output != nullptr); - - status = runtime_kernel.getDataFromStorage(op_index, runtime_storage, runtime_context); - if (status != Ok) - return status; - - input_data = runtime_kernel.inputs_data[inputTensorIdx]; - output_data = runtime_kernel.outputs_data[outputTensorIdx]; - } + SISOHeader(execute_args, &input, &output, &input_data, &output_data); - assert(input_data != nullptr); - assert(output_data != nullptr); + OMStatus status; switch (input->type()) { diff --git a/onert-micro/onert-micro/src/execute/kernels/LogSoftmax.cpp b/onert-micro/onert-micro/src/execute/kernels/LogSoftmax.cpp index fbbf253d35e..decec3f3d04 100644 --- a/onert-micro/onert-micro/src/execute/kernels/LogSoftmax.cpp +++ b/onert-micro/onert-micro/src/execute/kernels/LogSoftmax.cpp @@ -21,6 +21,7 @@ #include "execute/OMKernelExecutionBuilder.h" #include "execute/OMRuntimeKernel.h" +#include "execute/OMUtils.h" #include "PALLogSoftmax.h" using namespace onert_micro; @@ -37,38 +38,15 @@ constexpr uint32_t outputTensorIdx = 0; // NOTE: doesnt currently support dynamic shapes OMStatus onert_micro::execute::execute_kernel_CircleLogSoftmax(const OMExecuteArgs &execute_args) { - core::OMRuntimeContext &runtime_context = execute_args.runtime_context; - core::OMRuntimeStorage &runtime_storage = execute_args.runtime_storage; - uint16_t op_index = execute_args.kernel_index; - const circle::Tensor *input = nullptr; const circle::Tensor *output = nullptr; uint8_t *input_data = nullptr; uint8_t *output_data = nullptr; - OMStatus status = Ok; - - { - OMRuntimeKernel runtime_kernel; - runtime_kernel.readKernel(op_index, runtime_context); - - input = runtime_kernel.inputs[inputTensorIdx]; - output = runtime_kernel.outputs[outputTensorIdx]; - - assert(input != nullptr); - assert(output != nullptr); - - status = runtime_kernel.getDataFromStorage(op_index, runtime_storage, runtime_context); - if (status != Ok) - return status; - - input_data = runtime_kernel.inputs_data[inputTensorIdx]; - output_data = runtime_kernel.outputs_data[outputTensorIdx]; - } + SISOHeader(execute_args, &input, &output, &input_data, &output_data); - assert(input_data != nullptr); - assert(output_data != nullptr); + OMStatus status; switch (input->type()) { diff --git a/onert-micro/onert-micro/src/execute/kernels/Logistic.cpp b/onert-micro/onert-micro/src/execute/kernels/Logistic.cpp index adf05d2b219..a5f93de6edb 100644 --- a/onert-micro/onert-micro/src/execute/kernels/Logistic.cpp +++ b/onert-micro/onert-micro/src/execute/kernels/Logistic.cpp @@ -21,6 +21,8 @@ #include "execute/OMKernelExecutionBuilder.h" #include "execute/OMRuntimeKernel.h" +#include "execute/OMUtils.h" + #include "PALLogistic.h" using namespace onert_micro; @@ -37,38 +39,15 @@ constexpr uint32_t outputTensorIdx = 0; // NOTE: doesnt currently support dynamic shapes OMStatus onert_micro::execute::execute_kernel_CircleLogistic(const OMExecuteArgs &execute_args) { - core::OMRuntimeContext &runtime_context = execute_args.runtime_context; - core::OMRuntimeStorage &runtime_storage = execute_args.runtime_storage; - uint16_t op_index = execute_args.kernel_index; - const circle::Tensor *input = nullptr; const circle::Tensor *output = nullptr; uint8_t *input_data = nullptr; uint8_t *output_data = nullptr; - OMStatus status = Ok; - - { - OMRuntimeKernel runtime_kernel; - runtime_kernel.readKernel(op_index, runtime_context); - - input = runtime_kernel.inputs[inputTensorIdx]; - output = runtime_kernel.outputs[outputTensorIdx]; - - assert(input != nullptr); - assert(output != nullptr); - - status = runtime_kernel.getDataFromStorage(op_index, runtime_storage, runtime_context); - if (status != Ok) - return status; - - input_data = runtime_kernel.inputs_data[inputTensorIdx]; - output_data = runtime_kernel.outputs_data[outputTensorIdx]; - } + SISOHeader(execute_args, &input, &output, &input_data, &output_data); - assert(input_data != nullptr); - assert(output_data != nullptr); + OMStatus status; switch (input->type()) {