diff --git a/onert-micro/onert-micro/include/execute/OMUtils.h b/onert-micro/onert-micro/include/execute/OMUtils.h index 109dbea23fa..1365bfd39cb 100644 --- a/onert-micro/onert-micro/include/execute/OMUtils.h +++ b/onert-micro/onert-micro/include/execute/OMUtils.h @@ -167,6 +167,17 @@ OMStatus TISOHeader(const OMExecuteArgs &execute_args, const circle::Tensor **in const circle::Tensor **input2, const circle::Tensor **output, OMRuntimeKernel *runtime_kernel); +inline int calculateInputRadius(int input_integer_bits, int input_left_shift, int total_signed_bits) +{ + const double max_input_rescaled = 1.0 * ((1 << input_integer_bits) - 1) * + (1LL << (total_signed_bits - input_integer_bits)) / + (1LL << input_left_shift); + // Tighten bound using floor. Suppose that we could use the exact value. + // After scaling the difference, the result would be at the maximum. Thus we + // must ensure that our value has lower magnitude. + return static_cast(std::floor(max_input_rescaled)); +} + } // namespace execute } // namespace onert_micro diff --git a/onert-micro/onert-micro/src/execute/kernels/Softmax.cpp b/onert-micro/onert-micro/src/execute/kernels/Softmax.cpp index 6b868be20b8..78b750eb703 100644 --- a/onert-micro/onert-micro/src/execute/kernels/Softmax.cpp +++ b/onert-micro/onert-micro/src/execute/kernels/Softmax.cpp @@ -23,6 +23,8 @@ #include "PALSoftmax.h" +#include "execute/OMUtils.h" + using namespace onert_micro; using namespace onert_micro::execute; @@ -32,6 +34,18 @@ namespace constexpr uint32_t inputTensorIdx = 0; constexpr uint32_t outputTensorIdx = 0; +static const int kScaledDiffIntegerBits = 5; +void preprocessSoftmaxScaling(double beta, double input_scale, int input_integer_bits, + int32_t *quantized_multiplier, int *left_shift) +{ + const double max_real_multiplier = (1LL << 31) - 1.0; + const double input_beta_real_multiplier = + std::min(beta * input_scale * (1 << (31 - input_integer_bits)), max_real_multiplier); + + onert_micro::execute::quantizeMultiplier(input_beta_real_multiplier, quantized_multiplier, + left_shift); +} + } // namespace // NOTE: doesnt currently support dynamic shapes @@ -126,6 +140,14 @@ OMStatus onert_micro::execute::execute_kernel_CircleSoftmax(const OMExecuteArgs params.output_zp = output->quantization()->zero_point()->operator[](0); params.input_zp = input->quantization()->zero_point()->operator[](0); + int left_shift = 0; + preprocessSoftmaxScaling(static_cast(params.beta), + static_cast(params.input_scale), kScaledDiffIntegerBits, + ¶ms.input_multiplier, &left_shift); + params.input_left_shift = left_shift; + params.diff_min = -1.0 * onert_micro::execute::calculateInputRadius( + kScaledDiffIntegerBits, params.input_left_shift, 31); + status = pal::Softmax(params, core::utils::castInputData(input_data), core::utils::castOutputData(output_data)); }