Skip to content

Commit

Permalink
修改rmsnorm的实现方式
Browse files Browse the repository at this point in the history
  • Loading branch information
shenshen.fu committed Mar 6, 2024
1 parent a5f1ebc commit aeec18e
Showing 1 changed file with 4 additions and 9 deletions.
13 changes: 4 additions & 9 deletions source/layer/details/rms_norm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,6 @@ void RMSNormLayer::set_weights(const std::vector<float>& weights) {

void RMSNormLayer::set_weights(const std::vector<std::shared_ptr<Tensor<float>>>& weights) {
CHECK(weights.size() == weights_.size());
for (uint32_t i = 0; i < weights.size(); ++i) {
if (this->weights_.at(i) != nullptr) {
CHECK(this->weights_.at(i)->rows() == weights.at(i)->rows());
CHECK(this->weights_.at(i)->cols() == weights.at(i)->cols());
CHECK(this->weights_.at(i)->channels() == weights.at(i)->channels());
}
}
this->weights_ = weights;
}

Expand Down Expand Up @@ -89,8 +82,10 @@ StatusCode RMSNormLayer::Forward(const std::vector<std::shared_ptr<Tensor<float>

const size_t size = input->size();
arma::fvec input_vec(input->raw_ptr(), size, false, true);
float mean_value = arma::mean(input_vec % input_vec);
float norm_value = 1.f / std::sqrt(mean_value + eps_);

const arma::fvec& input_pow_vec = arma::pow(input_vec, 2.f);
const float mean_value = arma::mean(input_pow_vec);
const float norm_value = 1.f / std::sqrt(mean_value + eps_);
arma::fvec output_vec(output->raw_ptr(), size, false, true);
output_vec = weight_vec % (norm_value * input_vec);
}
Expand Down

0 comments on commit aeec18e

Please sign in to comment.