Skip to content

Commit

Permalink
修改softmax函数的实现
Browse files Browse the repository at this point in the history
  • Loading branch information
zjhellofss committed Dec 14, 2023
1 parent a425a0a commit 34e5331
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions source/layer/details/softmax.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,33 +102,33 @@ StatusCode SoftmaxLayer::Forward(const std::vector<std::shared_ptr<Tensor<float>
#pragma omp parallel for collapse(2)
for (uint32_t outer_size = 0; outer_size < outer_sizes; ++outer_size) {
for (uint32_t inner_size = 0; inner_size < inner_sizes; ++inner_size) {
float max_value = std::numeric_limits<float>::lowest();
// 迭代当前dim中的数据,并找到其中的最大值
float max_value = std::numeric_limits<float>::lowest();
uint32_t base_index = outer_size * axis_sizes * inner_sizes + inner_size;

std::vector<float> tmp_storage(axis_sizes);
for (uint32_t axis_size = 0; axis_size < axis_sizes; ++axis_size) {
uint32_t index = base_index + axis_size * inner_sizes;
float cur_value = input_values.at(index);
if (cur_value > max_value) {
max_value = cur_value;
}
tmp_storage.at(axis_size) = cur_value;
}

// 迭代当前dim中的数据,并进行求和
float sum_value = 0.f;
for (uint32_t axis_size = 0; axis_size < axis_sizes; ++axis_size) {
uint32_t index = base_index + axis_size * inner_sizes;
float cur_value = input_values.at(index);
float cur_value = tmp_storage.at(axis_size);
float exp_sub_value = fmath::exp(cur_value - max_value);

sum_value += exp_sub_value;
output_values.at(index) = exp_sub_value;
tmp_storage.at(axis_size) = exp_sub_value;
}

// 迭代当前dim中的数据,求exp(cur_value - max_value) / sum_value
for (uint32_t axis_size = 0; axis_size < axis_sizes; ++axis_size) {
uint32_t index = base_index + axis_size * inner_sizes;

float exp_sub_value = output_values.at(index);
float exp_sub_value = tmp_storage.at(axis_size);
output_values.at(index) = exp_sub_value / sum_value;
}
}
Expand Down

0 comments on commit 34e5331

Please sign in to comment.