diff --git a/src/turbomind/models/llama/LlamaDecoderLayerWeight.cc b/src/turbomind/models/llama/LlamaDecoderLayerWeight.cc
index 7ed657a9b..2d68ef353 100644
--- a/src/turbomind/models/llama/LlamaDecoderLayerWeight.cc
+++ b/src/turbomind/models/llama/LlamaDecoderLayerWeight.cc
@@ -303,6 +303,43 @@ void loadWeights(
     }
 }
 
+template<typename T>
+void loadWeights(LlamaDenseWeight<T>& w, std::string prefix, FtCudaDataType model_file_type)
+{
+    auto weight_file  = prefix + ".weight";
+    auto qweight_file = prefix + ".qweight";
+
+    if (!std::filesystem::exists(weight_file) && !std::filesystem::exists(qweight_file)) {
+        TM_LOG_ERROR("%s and %s do not exist", weight_file.c_str(), qweight_file.c_str());
+        FT_CHECK(false);
+    }
+
+    size_t     dim0 = w.input_dims;
+    size_t     dim1 = w.output_dims;
+    const auto type = model_file_type;
+
+    if (w.bias) {
+        loadWeightFromBin((T*)w.bias, {1, dim1}, prefix + ".bias", type);
+    }
+    const size_t bit_size = getBitSize(w.type);
+    if (bit_size >= 16) {  // fp16, fp32
+        loadWeightFromBin((T*)w.kernel, {dim0, dim1}, prefix + ".weight", type);
+    }
+    else {  // int8, int4
+        const int factor = sizeof(float) * 8 / bit_size;
+
+        FT_CHECK(dim1 % factor == 0);
+
+        std::vector<size_t> w_shape{dim0, dim1 / factor * sizeof(uint32_t)};
+        loadWeightFromBin((int8_t*)w.kernel, w_shape, prefix + ".qweight", FtCudaDataType::INT8);
+
+        const size_t group_count = w.group_size > 0 ? dim0 / w.group_size : 1;
+
+        loadWeightFromBin((half*)w.scales, {group_count, dim1}, prefix + ".scales", type);
+        loadWeightFromBin((half*)w.zeros, {group_count, dim1}, prefix + ".zeros", type);
+    }
+}
+
 template<typename T>
 void LlamaDecoderLayerWeight<T>::mallocWeights()
 {
@@ -357,10 +394,22 @@ void LlamaDecoderLayerWeight<T>::loadModel(std::string dir_path, FtCudaDataType
     loadWeights(self_attn_weights.qkv, dir_path + ".attention.w_qkv", tensor_para_rank_, type, tensor_para_size_);
 
     loadWeights(self_attn_weights.output, dir_path + ".attention.wo", tensor_para_rank_, type, tensor_para_size_);
-
-    loadWeights(ffn_weights.gating, dir_path + ".feed_forward.w1", tensor_para_rank_, type, tensor_para_size_);
-    loadWeights(ffn_weights.intermediate, dir_path + ".feed_forward.w3", tensor_para_rank_, type, tensor_para_size_);
-    loadWeights(ffn_weights.output, dir_path + ".feed_forward.w2", tensor_para_rank_, type, tensor_para_size_);
+    if (moe_weights.experts.empty()) {
+        loadWeights(ffn_weights.gating, dir_path + ".feed_forward.w1", tensor_para_rank_, type, tensor_para_size_);
+        loadWeights(
+            ffn_weights.intermediate, dir_path + ".feed_forward.w3", tensor_para_rank_, type, tensor_para_size_);
+        loadWeights(ffn_weights.output, dir_path + ".feed_forward.w2", tensor_para_rank_, type, tensor_para_size_);
+    }
+    else {
+        loadWeights(moe_weights.gate, dir_path + ".moe_ffn.gate", type);
+        for (size_t i = 0; i < moe_weights.experts.size(); ++i) {
+            std::string weight_name = dir_path + ".moe_ffn.experts." + std::to_string(i);
+            loadWeights(moe_weights.experts[i].gating, weight_name + ".w1", tensor_para_rank_, type, tensor_para_size_);
+            loadWeights(
+                moe_weights.experts[i].intermediate, weight_name + ".w3", tensor_para_rank_, type, tensor_para_size_);
+            loadWeights(moe_weights.experts[i].output, weight_name + ".w2", tensor_para_rank_, type, tensor_para_size_);
+        }
+    }
 }
 
 template<typename T>
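
Note (illustration, not part of the patch): in the int8/int4 branch of the new loadWeights overload, factor = sizeof(float) * 8 / bit_size quantized values are packed into each 32-bit word, so a [dim0, dim1] kernel is read back as dim0 x (dim1 / factor) * sizeof(uint32_t) bytes, with one row of scales and zeros per group_size input rows. A minimal standalone sketch of that size arithmetic, with hypothetical names (quant_file_sizes and the example dimensions are not from the patch):

#include <cstddef>
#include <cstdint>
#include <cstdio>

// Hypothetical helper mirroring the size arithmetic in the loadWeights overload above:
// prints the byte count of the packed ".qweight" file and the element count of the
// ".scales"/".zeros" files for a [dim0, dim1] quantized kernel.
static void quant_file_sizes(size_t dim0, size_t dim1, size_t bit_size, int group_size)
{
    const size_t factor      = sizeof(float) * 8 / bit_size;               // values per uint32 word: 8 for int4, 4 for int8
    const size_t qweight     = dim0 * (dim1 / factor) * sizeof(uint32_t);  // bytes, matches w_shape in the patch
    const size_t group_count = group_size > 0 ? dim0 / group_size : 1;
    const size_t scales      = group_count * dim1;                         // half elements; ".zeros" has the same count
    std::printf("qweight: %zu B, scales/zeros: %zu halves each\n", qweight, scales);
}

int main()
{
    quant_file_sizes(4096, 4096, 4, 128);  // e.g. a 4096x4096 int4 weight with group size 128
    return 0;
}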
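
A second sketch, also hypothetical, of the on-disk naming the MoE branch implies: the gate is loaded whole via the new three-argument overload (no tensor-parallel slicing), while each expert's w1/w3/w2 go through the sliced loader under the ".moe_ffn.experts.<i>" prefix. The layer prefix "layers.0" and expert count 8 below are assumptions for illustration only:

#include <cstdio>
#include <string>
#include <vector>

int main()
{
    const std::string prefix     = "layers.0";  // assumed layer prefix, i.e. dir_path in loadModel
    const size_t      expert_num = 8;           // assumed expert count

    // Gate weight plus, per expert, the gating/intermediate/output projections.
    std::vector<std::string> files{prefix + ".moe_ffn.gate.weight"};
    for (size_t i = 0; i < expert_num; ++i) {
        const std::string base = prefix + ".moe_ffn.experts." + std::to_string(i);
        for (const char* w : {".w1", ".w3", ".w2"}) {  // gating, intermediate, output
            files.push_back(base + w + ".weight");     // ".qweight"/".scales"/".zeros" instead when quantized
        }
    }
    for (const auto& f : files) {
        std::puts(f.c_str());
    }
    return 0;
}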