Fix missing read of moe_ffn weights from converted tm model (#2698)
* fix missing read of moe_ffn weights

* fix linting

* fix linting

* fix linting
lvhan028 authored Nov 4, 2024
1 parent e557f05 commit 5f577c2
Showing 1 changed file with 53 additions and 4 deletions.
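In short, loadModel previously read only the dense FFN weights (feed_forward.w1/w3/w2), so for converted MoE models the per-expert weights under the .moe_ffn.* prefixes were never loaded. As a quick illustration of the naming scheme the fixed loader requests, here is a small self-contained sketch derived from the diff below; the ffnWeightPrefixes helper and the layers.<n> layer-prefix form are assumptions for illustration, not code from the commit.

// Hypothetical helper, not part of the commit: lists the FFN weight prefixes the
// loader requests for one decoder layer after this fix. Prefix construction mirrors
// the diff below; the "layers.<n>" layer prefix is an assumption about how the
// converted tm model names its tensors.
#include <cstddef>
#include <initializer_list>
#include <iostream>
#include <string>
#include <vector>

std::vector<std::string> ffnWeightPrefixes(const std::string& layer_prefix, size_t num_experts)
{
    std::vector<std::string> prefixes;
    if (num_experts == 0) {  // dense FFN: unchanged behavior
        for (const char* w : {".feed_forward.w1", ".feed_forward.w3", ".feed_forward.w2"}) {
            prefixes.push_back(layer_prefix + w);
        }
    }
    else {  // MoE FFN: the path this commit adds
        prefixes.push_back(layer_prefix + ".moe_ffn.gate");  // router gate
        for (size_t i = 0; i < num_experts; ++i) {
            const std::string expert = layer_prefix + ".moe_ffn.experts." + std::to_string(i);
            prefixes.push_back(expert + ".w1");
            prefixes.push_back(expert + ".w3");
            prefixes.push_back(expert + ".w2");
        }
    }
    return prefixes;
}

int main()
{
    for (const auto& p : ffnWeightPrefixes("layers.0", 4)) {
        std::cout << p << '\n';
    }
}

For a layer with four experts this prints layers.0.moe_ffn.gate, then layers.0.moe_ffn.experts.0.w1, .w3, .w2, and so on for each expert; the dense branch keeps the old feed_forward.* names.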
src/turbomind/models/llama/LlamaDecoderLayerWeight.cc (53 additions, 4 deletions)
@@ -303,6 +303,43 @@ void loadWeights(
    }
}

template<typename T>
void loadWeights(LlamaDenseWeight<T>& w, std::string prefix, FtCudaDataType model_file_type)
{
    auto weight_file = prefix + ".weight";
    auto qweight_file = prefix + ".qweight";

    if (!std::filesystem::exists(weight_file) && !std::filesystem::exists(qweight_file)) {
        TM_LOG_ERROR("%s and %s does not exist", weight_file.c_str(), qweight_file.c_str());
        FT_CHECK(false);
    }

    size_t dim0 = w.input_dims;
    size_t dim1 = w.output_dims;
    const auto type = model_file_type;

    if (w.bias) {
        loadWeightFromBin((T*)w.bias, {1, dim1}, prefix + ".bias", type);
    }
    const size_t bit_size = getBitSize(w.type);
    if (bit_size >= 16) {  // fp16, fp32
        loadWeightFromBin((T*)w.kernel, {dim0, dim1}, prefix + ".weight", type);
    }
    else {  // int8, int4
        const int factor = sizeof(float) * 8 / bit_size;

        FT_CHECK(dim1 % factor == 0);

        std::vector<size_t> w_shape{dim0, dim1 / factor * sizeof(uint32_t)};
        loadWeightFromBin((int8_t*)w.kernel, w_shape, prefix + ".qweight", FtCudaDataType::INT8);

        const size_t group_count = w.group_size > 0 ? dim0 / w.group_size : 1;

        loadWeightFromBin((half*)w.scales, {group_count, dim1}, prefix + ".scales", type);
        loadWeightFromBin((half*)w.zeros, {group_count, dim1}, prefix + ".zeros", type);
    }
}

template<typename T>
void LlamaDecoderLayerWeight<T>::mallocWeights()
{
@@ -357,10 +394,22 @@ void LlamaDecoderLayerWeight<T>::loadModel(std::string dir_path, FtCudaDataType
    loadWeights(self_attn_weights.qkv, dir_path + ".attention.w_qkv", tensor_para_rank_, type, tensor_para_size_);

    loadWeights(self_attn_weights.output, dir_path + ".attention.wo", tensor_para_rank_, type, tensor_para_size_);

    loadWeights(ffn_weights.gating, dir_path + ".feed_forward.w1", tensor_para_rank_, type, tensor_para_size_);
    loadWeights(ffn_weights.intermediate, dir_path + ".feed_forward.w3", tensor_para_rank_, type, tensor_para_size_);
    loadWeights(ffn_weights.output, dir_path + ".feed_forward.w2", tensor_para_rank_, type, tensor_para_size_);
    if (moe_weights.experts.empty()) {
        loadWeights(ffn_weights.gating, dir_path + ".feed_forward.w1", tensor_para_rank_, type, tensor_para_size_);
        loadWeights(
            ffn_weights.intermediate, dir_path + ".feed_forward.w3", tensor_para_rank_, type, tensor_para_size_);
        loadWeights(ffn_weights.output, dir_path + ".feed_forward.w2", tensor_para_rank_, type, tensor_para_size_);
    }
    else {
        loadWeights(moe_weights.gate, dir_path + ".moe_ffn.gate", type);
        for (size_t i = 0; i < moe_weights.experts.size(); ++i) {
            std::string weight_name = dir_path + ".moe_ffn.experts." + std::to_string(i);
            loadWeights(moe_weights.experts[i].gating, weight_name + ".w1", tensor_para_rank_, type, tensor_para_size_);
            loadWeights(
                moe_weights.experts[i].intermediate, weight_name + ".w3", tensor_para_rank_, type, tensor_para_size_);
            loadWeights(moe_weights.experts[i].output, weight_name + ".w2", tensor_para_rank_, type, tensor_para_size_);
        }
    }
}

template<typename T>
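
A note on the quantized branch of the new single-rank loadWeights overload above: factor = 32 / bit_size values are packed into each uint32_t word (8 for 4-bit weights, 4 for 8-bit), so the packed kernel is read as dim0 x (dim1 / factor) 32-bit words, expressed in bytes for the int8 load, and the scales/zeros tensors carry one row per group of group_size input rows, or a single row when group_size is 0. Below is a minimal standalone sketch of that arithmetic; the names are illustrative, not repository code.

// Standalone illustration of the qweight/scales shape arithmetic used in the
// new loadWeights overload; the helper and struct names here are hypothetical.
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstdio>

struct QuantShape {
    size_t qweight_bytes;  // bytes occupied by the packed kernel
    size_t scale_rows;     // rows in the scales/zeros tensors
};

QuantShape quantShape(size_t dim0, size_t dim1, size_t bit_size, size_t group_size)
{
    const size_t factor = sizeof(uint32_t) * 8 / bit_size;  // values per 32-bit word: 8 for int4, 4 for int8
    assert(dim1 % factor == 0);                              // mirrors FT_CHECK(dim1 % factor == 0)
    QuantShape s;
    s.qweight_bytes = dim0 * (dim1 / factor) * sizeof(uint32_t);
    s.scale_rows    = group_size > 0 ? dim0 / group_size : 1;  // one scale/zero row per input-row group
    return s;
}

int main()
{
    // e.g. a 4096 x 11008 projection, 4-bit weights, group size 128
    const QuantShape s = quantShape(4096, 11008, 4, 128);
    std::printf("qweight bytes: %zu, scale rows: %zu\n", s.qweight_bytes, s.scale_rows);
}

For the 4096 x 11008, 4-bit, group-size-128 example this gives 4096 x 1376 packed words (about 22.5 MB) and 32 scale rows.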
