Skip to content

Commit

Permalink
修改softmax函数的实现
Browse files Browse the repository at this point in the history
  • Loading branch information
zjhellofss committed Dec 16, 2023
1 parent 34e5331 commit ef145f4
Showing 1 changed file with 40 additions and 9 deletions.
49 changes: 40 additions & 9 deletions source/layer/details/softmax.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,11 +94,10 @@ StatusCode SoftmaxLayer::Forward(const std::vector<std::shared_ptr<Tensor<float>
std::accumulate(raw_shapes.begin(), raw_shapes.begin() + dim, 1, std::multiplies());

// dim轴数据的数量
const uint32_t axis_sizes = raw_shapes.at(dim);
int32_t axis_sizes = static_cast<int32_t>(raw_shapes.at(dim));
CHECK_EQ(axis_sizes * outer_sizes * inner_sizes, input->size());

const auto& input_values = input->values(true);
std::vector<float> output_values(input_values.size());
std::vector<float> output_values = input->values(true);
#pragma omp parallel for collapse(2)
for (uint32_t outer_size = 0; outer_size < outer_sizes; ++outer_size) {
for (uint32_t inner_size = 0; inner_size < inner_sizes; ++inner_size) {
Expand All @@ -109,27 +108,59 @@ StatusCode SoftmaxLayer::Forward(const std::vector<std::shared_ptr<Tensor<float>
std::vector<float> tmp_storage(axis_sizes);
for (uint32_t axis_size = 0; axis_size < axis_sizes; ++axis_size) {
uint32_t index = base_index + axis_size * inner_sizes;
float cur_value = input_values.at(index);
float cur_value = output_values.at(index);
if (cur_value > max_value) {
max_value = cur_value;
}
tmp_storage.at(axis_size) = cur_value;
}

// 迭代当前dim中的数据,并进行求和
float sum_value = 0.f;
for (uint32_t axis_size = 0; axis_size < axis_sizes; ++axis_size) {
int32_t axis_size = 0;
#ifdef __AVX2__
int32_t packet_size = 8;
float* tmp_storage_ptr = tmp_storage.data();
__m256 sum_vec = _mm256_setzero_ps();
__m256 max_value256 = _mm256_set1_ps(max_value);
for (; axis_size <= axis_sizes - packet_size; axis_size += packet_size) {
__m256 p = _mm256_loadu_ps(tmp_storage_ptr);
__m256 exp_sub_value = fmath::exp_ps256(p - max_value256);
_mm256_storeu_ps(tmp_storage_ptr, exp_sub_value);
sum_vec = _mm256_add_ps(sum_vec, exp_sub_value);
tmp_storage_ptr += packet_size;
}
float result[8];
_mm256_storeu_ps(result, sum_vec);
for (int j = 0; j < packet_size; ++j) {
sum_value += result[j];
}
#endif
for (; axis_size < axis_sizes; ++axis_size) {
float cur_value = tmp_storage.at(axis_size);
float exp_sub_value = fmath::exp(cur_value - max_value);
sum_value += exp_sub_value;
tmp_storage.at(axis_size) = exp_sub_value;
}

#ifdef __AVX2__
tmp_storage_ptr = tmp_storage.data();
sum_vec = _mm256_set1_ps(sum_value);
for (axis_size = 0; axis_size <= axis_sizes - packet_size; axis_size += packet_size) {
__m256 p = _mm256_loadu_ps(tmp_storage_ptr);
__m256 div_value = _mm256_div_ps(p, sum_vec);
_mm256_storeu_ps(tmp_storage_ptr, div_value);
tmp_storage_ptr += packet_size;
}
#endif
for (; axis_size < axis_sizes; ++axis_size) {
tmp_storage.at(axis_size) = tmp_storage.at(axis_size) / sum_value;
}

// 迭代当前dim中的数据,求exp(cur_value - max_value) / sum_value
for (uint32_t axis_size = 0; axis_size < axis_sizes; ++axis_size) {
for (axis_size = 0; axis_size < axis_sizes; ++axis_size) {
uint32_t index = base_index + axis_size * inner_sizes;
float exp_sub_value = tmp_storage.at(axis_size);
output_values.at(index) = exp_sub_value / sum_value;
float div_value = tmp_storage.at(axis_size);
output_values.at(index) = div_value;
}
}
}
Expand Down

0 comments on commit ef145f4

Please sign in to comment.