From 34e5331ace2950d6b6b6025236b711c04b16c8ed Mon Sep 17 00:00:00 2001 From: zjhellofss Date: Thu, 14 Dec 2023 23:20:48 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E6=94=B9softmax=E5=87=BD=E6=95=B0?= =?UTF-8?q?=E7=9A=84=E5=AE=9E=E7=8E=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- source/layer/details/softmax.cpp | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/source/layer/details/softmax.cpp b/source/layer/details/softmax.cpp index 8c177116..a0f3cf95 100644 --- a/source/layer/details/softmax.cpp +++ b/source/layer/details/softmax.cpp @@ -102,33 +102,33 @@ StatusCode SoftmaxLayer::Forward(const std::vector #pragma omp parallel for collapse(2) for (uint32_t outer_size = 0; outer_size < outer_sizes; ++outer_size) { for (uint32_t inner_size = 0; inner_size < inner_sizes; ++inner_size) { - float max_value = std::numeric_limits::lowest(); // 迭代当前dim中的数据,并找到其中的最大值 + float max_value = std::numeric_limits::lowest(); uint32_t base_index = outer_size * axis_sizes * inner_sizes + inner_size; + + std::vector tmp_storage(axis_sizes); for (uint32_t axis_size = 0; axis_size < axis_sizes; ++axis_size) { uint32_t index = base_index + axis_size * inner_sizes; float cur_value = input_values.at(index); if (cur_value > max_value) { max_value = cur_value; } + tmp_storage.at(axis_size) = cur_value; } // 迭代当前dim中的数据,并进行求和 float sum_value = 0.f; for (uint32_t axis_size = 0; axis_size < axis_sizes; ++axis_size) { - uint32_t index = base_index + axis_size * inner_sizes; - float cur_value = input_values.at(index); + float cur_value = tmp_storage.at(axis_size); float exp_sub_value = fmath::exp(cur_value - max_value); - sum_value += exp_sub_value; - output_values.at(index) = exp_sub_value; + tmp_storage.at(axis_size) = exp_sub_value; } // 迭代当前dim中的数据,求exp(cur_value - max_value) / sum_value for (uint32_t axis_size = 0; axis_size < axis_sizes; ++axis_size) { uint32_t index = base_index + axis_size * inner_sizes; - - float exp_sub_value = output_values.at(index); + float exp_sub_value = tmp_storage.at(axis_size); output_values.at(index) = exp_sub_value / sum_value; } }