From 908881734b1e08f7f466f8b80d697f00d8395800 Mon Sep 17 00:00:00 2001 From: irexyc Date: Sat, 27 Jan 2024 14:43:19 +0000 Subject: [PATCH] support lora --- .../turbomind/deploy/target_model/base.py | 1 + src/turbomind/models/llama/LlamaBatch.cc | 10 ++-- src/turbomind/models/llama/LlamaBatch.h | 1 + .../models/llama/LlamaDecoderLayerWeight.cc | 25 +++++++--- .../models/llama/LlamaDecoderLayerWeight.h | 2 + src/turbomind/models/llama/LlamaDenseWeight.h | 1 + src/turbomind/models/llama/LlamaFfnLayer.cc | 15 ++++-- src/turbomind/models/llama/LlamaLinear.h | 40 +++++++++++++-- src/turbomind/models/llama/LlamaV2.cc | 49 +++++++++++++++++-- src/turbomind/models/llama/LlamaV2.h | 13 ++++- src/turbomind/models/llama/LlamaWeight.cc | 4 +- src/turbomind/models/llama/LlamaWeight.h | 1 + src/turbomind/models/llama/SequenceManager.cc | 3 +- .../models/llama/llama_decoder_kernels.cu | 43 ++++++++++++++++ .../models/llama/llama_decoder_kernels.h | 4 ++ .../models/llama/unified_attention_layer.cc | 16 ++++-- src/turbomind/models/llama/unified_decoder.cc | 5 +- .../triton_backend/llama/LlamaTritonModel.cc | 3 ++ .../triton_backend/llama/LlamaTritonModel.h | 1 + 19 files changed, 204 insertions(+), 33 deletions(-) diff --git a/lmdeploy/turbomind/deploy/target_model/base.py b/lmdeploy/turbomind/deploy/target_model/base.py index 4f4ac3e4c..9c600b4ff 100644 --- a/lmdeploy/turbomind/deploy/target_model/base.py +++ b/lmdeploy/turbomind/deploy/target_model/base.py @@ -66,6 +66,7 @@ class TurbomindModelConfig: max_position_embeddings: int = 0 rope_scaling_factor: float = 0.0 use_logn_attn: int = 0 + lora_policy: int = 0 @classmethod def from_dict(cls, env, allow_none=False): diff --git a/src/turbomind/models/llama/LlamaBatch.cc b/src/turbomind/models/llama/LlamaBatch.cc index bdcd5b851..89e0b45bf 100644 --- a/src/turbomind/models/llama/LlamaBatch.cc +++ b/src/turbomind/models/llama/LlamaBatch.cc @@ -722,10 +722,12 @@ void LlamaBatch::AllocateBuffer(size_t batch_size, size_t session_len) context_decoder_input_buf_ = (T*)allocator_->reMalloc(context_decoder_input_buf_, sizeof(T) * max_context_token_num_ * hidden_units, false); - context_decoder_output_buf_ = - (T*)allocator_->reMalloc(context_decoder_output_buf_, sizeof(T) * max_context_token_num_ * hidden_units, false); + // double buffer for lora + context_decoder_output_buf_ = (T*)allocator_->reMalloc( + context_decoder_output_buf_, sizeof(T) * max_context_token_num_ * hidden_units * 2, false); context_decoder_ids_buf_ = (int*)allocator_->reMalloc(context_decoder_ids_buf_, sizeof(int) * max_context_token_num_, false); + lora_mask_buf_ = (int*)allocator_->reMalloc(lora_mask_buf_, sizeof(int) * max_context_token_num_, false); tmp_k_cache_buf_ = (T*)allocator_->reMalloc( tmp_k_cache_buf_, sizeof(T) * max_context_token_num_ * local_kv_head_num * head_dim, false); @@ -850,6 +852,7 @@ void LlamaBatch::FreeBuffer() allocator_->free((void**)&context_decoder_input_buf_); allocator_->free((void**)&context_decoder_output_buf_); allocator_->free((void**)&context_decoder_ids_buf_); + allocator_->free((void**)&lora_mask_buf_); allocator_->free((void**)&tmp_k_cache_buf_); allocator_->free((void**)&tmp_v_cache_buf_); @@ -1586,7 +1589,8 @@ bool LlamaBatch::Forward(GenerationState& g, int iter) max_context_cnts[p], max_context_cnts[p], h_input_length_buf_ + first, - sequences.data()); + sequences.data(), + lora_mask_buf_); if (iter == 0) { // compute logits of inputs if requested diff --git a/src/turbomind/models/llama/LlamaBatch.h b/src/turbomind/models/llama/LlamaBatch.h index 
9af3b7522..01caaefb3 100644 --- a/src/turbomind/models/llama/LlamaBatch.h +++ b/src/turbomind/models/llama/LlamaBatch.h @@ -225,6 +225,7 @@ class LlamaBatch { T* decoder_output_buf_{}; int* sequence_lengths_{}; // current sequence length int* init_ctx_lens_{}; + int* lora_mask_buf_{}; // lora float* logits_buf_{}; // combined logits float* local_logits_buf_{}; // tensor parallel local logits diff --git a/src/turbomind/models/llama/LlamaDecoderLayerWeight.cc b/src/turbomind/models/llama/LlamaDecoderLayerWeight.cc index 34c0abf86..2f6c964ea 100644 --- a/src/turbomind/models/llama/LlamaDecoderLayerWeight.cc +++ b/src/turbomind/models/llama/LlamaDecoderLayerWeight.cc @@ -34,6 +34,7 @@ LlamaDecoderLayerWeight::LlamaDecoderLayerWeight(size_t head_num, WeightType weight_type, int group_size, bool attn_bias, + int lora_policy, size_t tensor_para_size, size_t tensor_para_rank): head_num_(head_num), @@ -43,6 +44,7 @@ LlamaDecoderLayerWeight::LlamaDecoderLayerWeight(size_t head_num, inter_size_(inter_size), weight_type_(weight_type), attn_bias_(attn_bias), + lora_policy_(lora_policy), tensor_para_size_(tensor_para_size), tensor_para_rank_(tensor_para_rank) { @@ -91,7 +93,7 @@ void freeWeights(LlamaDenseWeight& weights) } template -void mallocWeights(LlamaDenseWeight& weights, bool bias) +void mallocWeights(LlamaDenseWeight& weights, bool bias, int lora_policy) { if (bias) { deviceMalloc((T**)&weights.bias, weights.output_dims); @@ -99,6 +101,9 @@ void mallocWeights(LlamaDenseWeight& weights, bool bias) const size_t bit_size = getBitSize(weights.type); if (bit_size >= 16) { // fp16, fp32 deviceMalloc((T**)&weights.kernel, weights.input_dims * weights.output_dims); + if (lora_policy) { + deviceMalloc((T**)&weights.lora_kernel, weights.input_dims * weights.output_dims); + } } else { // int8, int4 const int factor = sizeof(float) * 8 / bit_size; @@ -244,6 +249,12 @@ void loadWeights(LlamaDenseWeight& w, } } loadWeightFromBin((T*)w.kernel, {dim0, dim1}, prefix + ".weight", type, weight_slices); + if (w.lora_kernel) { + auto dot_pos = prefix.rfind("."); + auto lora_weight_file = prefix.substr(0, dot_pos) + ".lora" + prefix.substr(dot_pos) + ".weight"; + TM_LOG_INFO("loading %s", lora_weight_file.c_str()); + loadWeightFromBin((T*)w.lora_kernel, {dim0, dim1}, lora_weight_file, type, weight_slices); + } } else { // int8, int4 const int factor = sizeof(float) * 8 / bit_size; @@ -265,19 +276,19 @@ void LlamaDecoderLayerWeight::mallocWeights() deviceMalloc((T**)&self_attn_norm_weights, hidden_units_); deviceMalloc((T**)&ffn_norm_weights, hidden_units_); - turbomind::mallocWeights(self_attn_weights.qkv, attn_bias_); - turbomind::mallocWeights(self_attn_weights.output, attn_bias_); + turbomind::mallocWeights(self_attn_weights.qkv, attn_bias_, lora_policy_); + turbomind::mallocWeights(self_attn_weights.output, attn_bias_, lora_policy_); self_attn_weights.past_kv_scale = {1.f, 0.f, 1.f, 0.f}; if (weight_type_ == WeightType::kINT4) { - turbomind::mallocWeights(ffn_weights.fused_gating_intermediate, false); + turbomind::mallocWeights(ffn_weights.fused_gating_intermediate, false, lora_policy_); } else { - turbomind::mallocWeights(ffn_weights.gating, false); - turbomind::mallocWeights(ffn_weights.intermediate, false); + turbomind::mallocWeights(ffn_weights.gating, false, lora_policy_); + turbomind::mallocWeights(ffn_weights.intermediate, false, lora_policy_); } - turbomind::mallocWeights(ffn_weights.output, false); + turbomind::mallocWeights(ffn_weights.output, false, lora_policy_); } template diff --git 
a/src/turbomind/models/llama/LlamaDecoderLayerWeight.h b/src/turbomind/models/llama/LlamaDecoderLayerWeight.h index 169a3aa9e..0c36b7f60 100644 --- a/src/turbomind/models/llama/LlamaDecoderLayerWeight.h +++ b/src/turbomind/models/llama/LlamaDecoderLayerWeight.h @@ -36,6 +36,7 @@ struct LlamaDecoderLayerWeight { WeightType weight_type, int group_size, bool attn_bias, + int lora_policy, size_t tensor_para_size, size_t tensor_para_rank); ~LlamaDecoderLayerWeight(); @@ -60,6 +61,7 @@ struct LlamaDecoderLayerWeight { WeightType weight_type_; size_t bit_size_; bool attn_bias_; + int lora_policy_; size_t tensor_para_size_; size_t tensor_para_rank_; bool is_maintain_buffer_ = false; diff --git a/src/turbomind/models/llama/LlamaDenseWeight.h b/src/turbomind/models/llama/LlamaDenseWeight.h index 369f26c73..05408f40f 100644 --- a/src/turbomind/models/llama/LlamaDenseWeight.h +++ b/src/turbomind/models/llama/LlamaDenseWeight.h @@ -59,6 +59,7 @@ struct LlamaDenseWeight { size_t input_dims; size_t output_dims; void* kernel; + void* lora_kernel; WeightType type; T* bias; T* scales_and_zeros; diff --git a/src/turbomind/models/llama/LlamaFfnLayer.cc b/src/turbomind/models/llama/LlamaFfnLayer.cc index 42575af66..fb55abb71 100644 --- a/src/turbomind/models/llama/LlamaFfnLayer.cc +++ b/src/turbomind/models/llama/LlamaFfnLayer.cc @@ -88,6 +88,14 @@ void LlamaFfnLayer::forward(TensorMap* output_tensors, const T* ffn_input_data = input_tensors->at("ffn_input").getPtr(); T* ffn_output_data = output_tensors->at("ffn_output").getPtr(); + // lora + int* lora_mask = nullptr; + if (input_tensors->isExist("lora_mask")) { + lora_mask = input_tensors->at("lora_mask").getPtr(); + inter_buf_ = (T*)allocator_->reMalloc(inter_buf_, sizeof(T) * num_token * inter_size_ * 2, false); + gating_buf_ = (T*)allocator_->reMalloc(gating_buf_, sizeof(T) * num_token * inter_size_ * 2, false); + } + if (weights->fused_gating_intermediate.kernel) { NvtxScope scope("fused_silu_ffn"); linear_.forward( @@ -96,11 +104,12 @@ void LlamaFfnLayer::forward(TensorMap* output_tensors, else { { // w1(x) NvtxScope scope("w1"); - linear_.forward(gating_buf_, ffn_input_data, num_token, weights->gating); + linear_.forward(gating_buf_, ffn_input_data, num_token, weights->gating, LlamaLinear::kGemm, lora_mask); } { // w3(x) NvtxScope scope("w3"); - linear_.forward(inter_buf_, ffn_input_data, num_token, weights->intermediate); + linear_.forward( + inter_buf_, ffn_input_data, num_token, weights->intermediate, LlamaLinear::kGemm, lora_mask); } // silu(w1(x)) * w3(x) activation(num_token); @@ -108,7 +117,7 @@ void LlamaFfnLayer::forward(TensorMap* output_tensors, { // w2(x) NvtxScope scope("w2"); - linear_.forward(ffn_output_data, gating_buf_, num_token, weights->output); + linear_.forward(ffn_output_data, gating_buf_, num_token, weights->output, LlamaLinear::kGemm, lora_mask); } if (tensor_para_.world_size_ > 1) { diff --git a/src/turbomind/models/llama/LlamaLinear.h b/src/turbomind/models/llama/LlamaLinear.h index a3717b2a9..02ac80d9d 100644 --- a/src/turbomind/models/llama/LlamaLinear.h +++ b/src/turbomind/models/llama/LlamaLinear.h @@ -4,6 +4,7 @@ #include "src/turbomind/kernels/gemm_s_f16/gemm_s4_f16.h" #include "src/turbomind/models/llama/LlamaDenseWeight.h" +#include "src/turbomind/models/llama/llama_decoder_kernels.h" #include "src/turbomind/models/llama/llama_kernels.h" #include "src/turbomind/utils/cublasMMWrapper.h" #include "src/turbomind/utils/cuda_utils.h" @@ -25,14 +26,18 @@ class LlamaLinear { { } - void - forward(T* output_data, const T* 
input_data, int batch_size, const LlamaDenseWeight& weight, Type type = kGemm) + void forward(T* output_data, + const T* input_data, + int batch_size, + const LlamaDenseWeight& weight, + Type type = kGemm, + int* lora_mask = nullptr) { switch (weight.type) { case WeightType::kFP16: case WeightType::kFP32: case WeightType::kBF16: - forwardFp(output_data, input_data, batch_size, weight, type); + forwardFp(output_data, input_data, batch_size, weight, type, lora_mask); break; case WeightType::kINT4: forwardInt4(output_data, input_data, batch_size, weight, type); @@ -43,7 +48,12 @@ class LlamaLinear { } private: - void forwardFp(T* output_data, const T* input_data, int batch_size, const LlamaDenseWeight& weight, Type type) + void forwardFp(T* output_data, + const T* input_data, + int batch_size, + const LlamaDenseWeight& weight, + Type type, + int* lora_mask) { FT_CHECK(type == kGemm); cublas_wrapper_->Gemm(CUBLAS_OP_N, @@ -58,6 +68,28 @@ class LlamaLinear { output_data, weight.output_dims); sync_check_cuda_error(); + + if (lora_mask && weight.lora_kernel) { + cublas_wrapper_->Gemm(CUBLAS_OP_N, + CUBLAS_OP_N, + weight.output_dims, + batch_size, + weight.input_dims, + (const T*)weight.lora_kernel, + weight.output_dims, + input_data, + weight.input_dims, + output_data + batch_size * weight.output_dims, + weight.output_dims); + + invokeMaskAddTwoLinearOutput(output_data, + output_data + batch_size * weight.output_dims, + lora_mask, + batch_size, + weight.output_dims, + stream_); + sync_check_cuda_error(); + } } void forwardInt4(T* output_data, const T* input_data, int batch_size, const LlamaDenseWeight& weight, Type type) diff --git a/src/turbomind/models/llama/LlamaV2.cc b/src/turbomind/models/llama/LlamaV2.cc index 87772c3e3..4b1767ab8 100644 --- a/src/turbomind/models/llama/LlamaV2.cc +++ b/src/turbomind/models/llama/LlamaV2.cc @@ -63,6 +63,7 @@ LlamaV2::LlamaV2(size_t head_num, cublasMMWrapper* cublas_wrapper, IAllocator* allocator, bool is_free_buffer_after_forward, + int lora_policy, cudaDeviceProp* cuda_device_prop): head_num_(head_num), size_per_head_(size_per_head), @@ -84,6 +85,7 @@ LlamaV2::LlamaV2(size_t head_num, allocator_(allocator), is_free_buffer_after_forward_(is_free_buffer_after_forward), cuda_device_prop_(cuda_device_prop), + lora_policy_(lora_policy), debug_(isDebug()), shared_state_(shared_state) @@ -166,10 +168,20 @@ void LlamaV2::embeddingLookup(T* embeddings, const int* token_ids_buf, int ba } template -void LlamaV2::updateEmbedding(T* decoder_input, const int bsz, const int* h_input_length, const Sequence** sequences) +void LlamaV2::updateEmbedding(T* decoder_input, + const int bsz, + const int* h_input_length, + const Sequence** sequences, + int token_num, + int* lora_mask, + bool* have_embeddings) { TM_LOG_DEBUG(__PRETTY_FUNCTION__); + std::vector mask(token_num); + int* mask_ptr = mask.data(); + *have_embeddings = false; + for (int i = 0; i < bsz; i++) { const auto& seq = *sequences[i]; const auto& embeddings = seq.input_embeddings; @@ -177,18 +189,33 @@ void LlamaV2::updateEmbedding(T* decoder_input, const int bsz, const int* h_i for (int j = embeddings.size() - 1; j >= 0; j--) { int begin = ranges[j].first; int end = ranges[j].second; + if (seq.cache_len + h_input_length[i] - 1 < begin) { + continue; + } if (end <= seq.cache_len) { break; } - int off_dst = std::max(0, begin - seq.cache_len); - int off_src = std::max(0, seq.cache_len - begin); + int off_dst = std::max(0, begin - seq.cache_len); + int off_src = std::max(0, seq.cache_len - begin); + // calculate union 
of [begin, end) and [seq.cache_len, seq.cache_len + h_input_length[i]) + begin = std::max(begin, seq.cache_len); + end = std::min(end, seq.cache_len + h_input_length[i]); size_t byte_size = (end - begin) * hidden_units_ * sizeof(T); T* dst_ptr = decoder_input + off_dst * hidden_units_; auto src_ptr = embeddings[j].data() + off_src * hidden_units_ * sizeof(T); cudaMemcpyAsync(dst_ptr, src_ptr, byte_size, cudaMemcpyDefault, stream_); + std::fill_n(mask_ptr + off_dst, (end - begin), 1); + *have_embeddings = true; } decoder_input += h_input_length[i] * hidden_units_; + mask_ptr += h_input_length[i]; + } + + if (lora_policy_ && *have_embeddings) { + cudaMemcpyAsync(lora_mask, mask.data(), sizeof(int) * token_num, cudaMemcpyDefault, stream_); + cudaStreamSynchronize(stream_); } + sync_check_cuda_error(); } @@ -216,7 +243,8 @@ void LlamaV2::forwardUnified(T* out, int pf_max_context_len, int pf_session_len, const int* h_input_length, - const Sequence** sequences) + const Sequence** sequences, + int* lora_mask) { TM_LOG_DEBUG(__PRETTY_FUNCTION__); @@ -233,7 +261,14 @@ void LlamaV2::forwardUnified(T* out, hidden_units_, stream_); - updateEmbedding(decoder_input, dc_batch_size + pf_batch_size, h_input_length, sequences); + bool have_embeddings = false; + updateEmbedding(decoder_input, + dc_batch_size + pf_batch_size, + h_input_length, + sequences, + token_num, + lora_mask, + &have_embeddings); sync_check_cuda_error(); @@ -262,6 +297,10 @@ void LlamaV2::forwardUnified(T* out, {"tmp_v", {MEMORY_GPU, TYPE_UINT64, {bsz}, pf_tmp_v_ptrs}}, {"last_token_hidden_units", {MEMORY_GPU, dtype, {bsz, hidden_units_}, out}}}; + if (lora_policy_ && have_embeddings && lora_mask) { + inputs.insert({"lora_mask", {MEMORY_GPU, TYPE_INT32, {token_num}, lora_mask}}); + } + unified_decoder_->forward(&outputs, &inputs, &weights_->decoder_layer_weights); } diff --git a/src/turbomind/models/llama/LlamaV2.h b/src/turbomind/models/llama/LlamaV2.h index 551b7cb12..6354223a9 100644 --- a/src/turbomind/models/llama/LlamaV2.h +++ b/src/turbomind/models/llama/LlamaV2.h @@ -73,6 +73,7 @@ class LlamaV2 { cublasMMWrapper* cublas_wrapper, IAllocator* allocator, bool is_free_buffer_after_forward, + int lora_policy, cudaDeviceProp* cuda_device_prop); struct Control { @@ -107,7 +108,13 @@ class LlamaV2 { void embeddingLookup(T* embeddings, const int* token_ids_buf, int batch_size, int step); - void updateEmbedding(T* decoder_input, const int bsz, const int* h_input_length, const Sequence** sequences); + void updateEmbedding(T* decoder_input, + const int bsz, + const int* h_input_length, + const Sequence** sequences, + int token_num, + int* lora_mask, + bool* have_embeddings); void forwardUnified(T* out, T* decoder_output, @@ -132,7 +139,8 @@ class LlamaV2 { int pf_max_context_len, int pf_session_len, const int* h_input_length, - const Sequence** sequences); + const Sequence** sequences, + int* lora_mask); void postDecodeEmbedding(float* logits, float* local_logits, const T* decoder_output, int batch_size); @@ -163,6 +171,7 @@ class LlamaV2 { const size_t vocab_size_; size_t vocab_size_padded_; float rmsnorm_eps_ = 1e-6f; + const int lora_policy_{}; const LlamaAttentionParams attn_params_; diff --git a/src/turbomind/models/llama/LlamaWeight.cc b/src/turbomind/models/llama/LlamaWeight.cc index 6e62eaf42..eff8fa822 100644 --- a/src/turbomind/models/llama/LlamaWeight.cc +++ b/src/turbomind/models/llama/LlamaWeight.cc @@ -32,6 +32,7 @@ LlamaWeight::LlamaWeight(size_t head_num, bool attn_bias, WeightType weight_type, int group_size, + int 
lora_policy, size_t tensor_para_size, size_t tensor_para_rank): hidden_units_(head_num * size_per_head), @@ -56,6 +57,7 @@ LlamaWeight::LlamaWeight(size_t head_num, weight_type_, group_size, attn_bias, + lora_policy, tensor_para_size_, tensor_para_rank_)); } @@ -90,7 +92,7 @@ template void LlamaWeight::loadModel(std::string dir_path) { FtCudaDataType model_file_type = FtCudaDataType::FP16; - if(weight_type_ == WeightType::kBF16){ + if (weight_type_ == WeightType::kBF16) { model_file_type = FtCudaDataType::BF16; } dir_path += '/'; diff --git a/src/turbomind/models/llama/LlamaWeight.h b/src/turbomind/models/llama/LlamaWeight.h index a896a87a0..abbb91f24 100644 --- a/src/turbomind/models/llama/LlamaWeight.h +++ b/src/turbomind/models/llama/LlamaWeight.h @@ -37,6 +37,7 @@ struct LlamaWeight { bool attn_bias, WeightType weight_type, int group_size, + int lora_policy, size_t tensor_para_size, size_t tensor_para_rank); diff --git a/src/turbomind/models/llama/SequenceManager.cc b/src/turbomind/models/llama/SequenceManager.cc index 9765b6e02..dc34a0562 100644 --- a/src/turbomind/models/llama/SequenceManager.cc +++ b/src/turbomind/models/llama/SequenceManager.cc @@ -22,8 +22,7 @@ SequenceManager::SequenceManager(size_t layer_num, size_t elem_bits, int rank, IAllocator* allocator): - block_seq_len_(block_seq_len), - rank_(rank) + block_seq_len_(block_seq_len), rank_(rank) { constexpr int kBitsPerByte = 8; diff --git a/src/turbomind/models/llama/llama_decoder_kernels.cu b/src/turbomind/models/llama/llama_decoder_kernels.cu index 6bdfa2c5e..6945d6160 100644 --- a/src/turbomind/models/llama/llama_decoder_kernels.cu +++ b/src/turbomind/models/llama/llama_decoder_kernels.cu @@ -188,11 +188,54 @@ void invokeFusedAddBiasResidualRMSNorm( residual, in_out, bias, scale, eps, batch_size, n_dims); } +template +__global__ void +maskAddTwoLinearOutput(T* __restrict__ output1, T* __restrict__ output2, const int* __restrict__ mask, int dim) +{ + auto block = cg::this_thread_block(); + auto grid = cg::this_grid(); + + const auto batch_idx = block.group_index().x; + if (!mask[batch_idx]) { + return; + } + + uint4* __restrict__ out1_ptr = reinterpret_cast(output1 + batch_idx * dim); + uint4* __restrict__ out2_ptr = reinterpret_cast(output2 + batch_idx * dim); + + res_norm_t ops; + constexpr int PACK_DIM = sizeof(uint4) / sizeof(T); + float thread_sum{}; + for (auto i = block.thread_rank(); i < dim / PACK_DIM; i += block.size()) { + auto o1 = out1_ptr[i]; + auto o2 = out2_ptr[i]; + uint4 b = uint4{}; + o1 = ops.addvec(o1, o2, b, thread_sum); + out1_ptr[i] = o1; + } +} + +template +void invokeMaskAddTwoLinearOutput(T* output1, T* output2, const int* mask, int batch_size, int dim, cudaStream_t stream) +{ + constexpr int PACK_DIM = sizeof(uint4) / sizeof(T); + FT_CHECK(dim % PACK_DIM == 0); + const int n_pack = dim / PACK_DIM; + const int n_iter = ((n_pack + 1023) / 1024); // iterations when block size == 1024 + int n_threads = (n_pack + n_iter - 1) / n_iter; // adjust block size to avoid tail effect + n_threads = (n_threads + 31) / 32 * 32; // round up to the nearest multiple of warp size + maskAddTwoLinearOutput<<>>(output1, output2, mask, dim); +} + +template void invokeMaskAddTwoLinearOutput(float*, float*, const int*, int, int, cudaStream_t); +template void invokeMaskAddTwoLinearOutput(half*, half*, const int*, int, int, cudaStream_t); + template void invokeFusedAddBiasResidualRMSNorm(float*, float*, const float*, const float*, float, int, int, cudaStream_t); template void invokeFusedAddBiasResidualRMSNorm(half*, 
half*, const half*, const half*, float, int, int, cudaStream_t); #ifdef ENABLE_BF16 template void invokeFusedAddBiasResidualRMSNorm( __nv_bfloat16*, __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, float, int, int, cudaStream_t); +template void invokeMaskAddTwoLinearOutput(__nv_bfloat16*, __nv_bfloat16*, const int*, int, int, cudaStream_t); #endif } // namespace turbomind diff --git a/src/turbomind/models/llama/llama_decoder_kernels.h b/src/turbomind/models/llama/llama_decoder_kernels.h index ade0dc053..5e4593cc9 100644 --- a/src/turbomind/models/llama/llama_decoder_kernels.h +++ b/src/turbomind/models/llama/llama_decoder_kernels.h @@ -8,4 +8,8 @@ template void invokeFusedAddBiasResidualRMSNorm( T* residual, T* in_out, const T* bias, const T* scale, float eps, int batch_size, int n_dims, cudaStream_t stream); +template +void invokeMaskAddTwoLinearOutput( + T* output1, T* output2, const int* mask, int batch_size, int dim, cudaStream_t stream); + } // namespace turbomind diff --git a/src/turbomind/models/llama/unified_attention_layer.cc b/src/turbomind/models/llama/unified_attention_layer.cc index aeb8c5db4..839360cb6 100644 --- a/src/turbomind/models/llama/unified_attention_layer.cc +++ b/src/turbomind/models/llama/unified_attention_layer.cc @@ -49,8 +49,9 @@ void UnifiedAttentionLayer::allocateBuffer(size_t num_token, const int local_q_kv_head_num = local_head_num_ + 2 * local_kv_head_num_; - // no padding - qkv_buf_ = (T*)allocator_->reMalloc(qkv_buf_, sizeof(T) * num_token * local_q_kv_head_num * size_per_head_, false); + // no padding, double buffer for lora + qkv_buf_ = + (T*)allocator_->reMalloc(qkv_buf_, sizeof(T) * num_token * local_q_kv_head_num * size_per_head_ * 2, false); // qkv_buf_3_ padding is removed qkv_buf_3_ = (T*)allocator_->reMalloc(qkv_buf_3_, sizeof(T) * num_token * local_head_num_ * size_per_head_, false); @@ -179,6 +180,11 @@ inline void UnifiedAttentionLayer::forward(TensorMap* outputs, const TensorMa T* attention_out = outputs->getPtr("hidden_features"); + int* lora_mask = nullptr; + if (inputs->isExist("lora_mask")) { + lora_mask = inputs->at("lora_mask").getPtr(); + } + ///////////////////////////////////////////// /// allocate buffers allocateBuffer(num_token, // @@ -194,7 +200,7 @@ inline void UnifiedAttentionLayer::forward(TensorMap* outputs, const TensorMa ////////////////////////////////////////////// /// qkv gemm // [token_num, hidden_dim] -> [token_num, 3, local_hidden_dim] - linear_.forward(qkv_buf_, attention_input, num_token, weights->qkv); + linear_.forward(qkv_buf_, attention_input, num_token, weights->qkv, LlamaLinear::kGemm, lora_mask); if (pf_batch_size) { const int offset = dc_batch_size; @@ -240,7 +246,7 @@ inline void UnifiedAttentionLayer::forward(TensorMap* outputs, const TensorMa ////////////////////////////////////////////// /// output gemm -> - linear_.forward(attention_out, qkv_buf_3_, num_token, weights->output); + linear_.forward(attention_out, qkv_buf_3_, num_token, weights->output, LlamaLinear::kGemm, lora_mask); if (tensor_para_.world_size_ > 1) { NcclGuard nccl_guard(tensor_para_, stream_); @@ -628,6 +634,6 @@ template class UnifiedAttentionLayer; template class UnifiedAttentionLayer; #ifdef ENABLE_BF16 template class UnifiedAttentionLayer<__nv_bfloat16>; -#endif // ENABLE_BF16 +#endif // ENABLE_BF16 } // namespace turbomind diff --git a/src/turbomind/models/llama/unified_decoder.cc b/src/turbomind/models/llama/unified_decoder.cc index 8617738d2..bccea1fdd 100644 --- a/src/turbomind/models/llama/unified_decoder.cc +++ 
b/src/turbomind/models/llama/unified_decoder.cc @@ -219,6 +219,9 @@ void UnifiedDecoder::forward(TensorMap* outputs, const TensorMap* inputs, con /// feed-forward network TensorMap ffn_inputs{{"ffn_input", {MEMORY_GPU, dtype_, {token_num, hidden_units_}, decoder_output}}}; TensorMap ffn_outputs{{"ffn_output", {MEMORY_GPU, dtype_, {token_num, hidden_units_}, decoder_output}}}; + if (inputs->isExist("lora_mask")) { + ffn_inputs.insert({"lora_mask", inputs->at("lora_mask")}); + } ffn_layer_->forward(&ffn_outputs, &ffn_inputs, &weights->at(layer)->ffn_weights); const bool is_last_layer = layer == num_layer_ - 1; @@ -263,6 +266,6 @@ template class UnifiedDecoder; template class UnifiedDecoder; #ifdef ENABLE_BF16 template class UnifiedDecoder<__nv_bfloat16>; -#endif // ENABLE_BF16 +#endif // ENABLE_BF16 } // namespace turbomind diff --git a/src/turbomind/triton_backend/llama/LlamaTritonModel.cc b/src/turbomind/triton_backend/llama/LlamaTritonModel.cc index 77f6b1983..e45e9fe67 100644 --- a/src/turbomind/triton_backend/llama/LlamaTritonModel.cc +++ b/src/turbomind/triton_backend/llama/LlamaTritonModel.cc @@ -183,6 +183,7 @@ LlamaTritonModel::LlamaTritonModel(size_t tensor_para_size, attn_bias_ = reader.GetInteger("llama", "attn_bias", 0); quant_policy_ = reader.GetInteger("llama", "quant_policy", 0); group_size_ = reader.GetInteger("llama", "group_size", 0); + lora_policy_ = reader.GetInteger("llama", "lora_policy", 0); // rotary embedding parameters attn_params_.rotary_embedding_dim = reader.GetInteger("llama", "rotary_embedding"); @@ -308,6 +309,7 @@ std::unique_ptr> LlamaTritonModel::createSh cublas_wrapper.get(), allocator.get(), false, // is_free_buffer_after_forward, + lora_policy_, cuda_device_prop_ptr.get()); return std::make_unique>( @@ -367,6 +369,7 @@ void LlamaTritonModel::createSharedWeights(int device_id, int rank) attn_bias_, weight_type_, group_size_, + lora_policy_, tensor_para_size_, tensor_para_rank); // model inited with model_dir diff --git a/src/turbomind/triton_backend/llama/LlamaTritonModel.h b/src/turbomind/triton_backend/llama/LlamaTritonModel.h index ff086a909..c057eb621 100644 --- a/src/turbomind/triton_backend/llama/LlamaTritonModel.h +++ b/src/turbomind/triton_backend/llama/LlamaTritonModel.h @@ -101,6 +101,7 @@ struct LlamaTritonModel: public AbstractTransformerModel { bool attn_bias_; int quant_policy_; int group_size_; + int lora_policy_; // shared weights for each device std::vector>> shared_weights_;
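Notes on how the pieces above fit together. The core of the LoRA path is LlamaLinear::forwardFp: every buffer that can receive a LoRA-capable GEMM result (qkv_buf_, gating_buf_/inter_buf_, context_decoder_output_buf_) is allocated at twice its usual size, the base GEMM writes the first half, a second GEMM against weight.lora_kernel writes the second half at offset batch_size * weight.output_dims, and invokeMaskAddTwoLinearOutput folds the second half into the first only for tokens whose lora_mask entry is 1. The quantized forwardInt4 path is not wired up, so the branch only applies to fp32/fp16/bf16 weights. A minimal CPU sketch of that flow (illustrative only: naive loops stand in for cublas_wrapper_->Gemm, and the names here are not part of the patch):

    // CPU sketch of the masked two-GEMM combine in LlamaLinear::forwardFp
    #include <cstdio>
    #include <vector>

    static void gemm(const float* x, const float* w, float* y, int tokens, int in, int out)
    {
        for (int t = 0; t < tokens; ++t)
            for (int o = 0; o < out; ++o) {
                float acc = 0.f;
                for (int i = 0; i < in; ++i)
                    acc += x[t * in + i] * w[i * out + o];
                y[t * out + o] = acc;
            }
    }

    int main()
    {
        const int tokens = 4, in = 8, out = 8;
        std::vector<float> x(tokens * in, 1.f), w(in * out, 0.5f), w_lora(in * out, 0.25f);
        std::vector<int>   mask = {1, 0, 0, 1};  // 1 = token covered by injected embeddings

        // double-sized output: first half <- base GEMM, second half <- LoRA GEMM
        std::vector<float> y(2 * tokens * out, 0.f);
        gemm(x.data(), w.data(), y.data(), tokens, in, out);
        gemm(x.data(), w_lora.data(), y.data() + tokens * out, tokens, in, out);

        // what invokeMaskAddTwoLinearOutput does on the GPU: add the second
        // half into the first half only where mask[t] != 0
        for (int t = 0; t < tokens; ++t)
            if (mask[t])
                for (int o = 0; o < out; ++o)
                    y[t * out + o] += y[(tokens + t) * out + o];

        std::printf("token 0 (masked):   %.2f\n", y[0]);    // base + lora = 6.00
        std::printf("token 1 (unmasked): %.2f\n", y[out]);  // base only  = 4.00
        return 0;
    }

Tokens with mask 0 keep the plain base-weight result, and the extra GEMM only runs when a non-null lora_mask reaches the linear layer, which in turn only happens when the "lora_mask" tensor is attached (lora_policy enabled and the batch actually contains injected embeddings).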
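The masked add itself is the new kernel in llama_decoder_kernels.cu. Each thread block owns one token row; rows whose mask entry is 0 return immediately, and the remaining rows add the LoRA output row into the base row through uint4 loads (PACK_DIM = sizeof(uint4) / sizeof(T) elements per access, hence the FT_CHECK(dim % PACK_DIM == 0)). The launcher then shrinks the block to the smallest warp-multiple that still covers dim / PACK_DIM packets in the same number of loop iterations a 1024-thread block would need, so narrow rows do not launch mostly idle blocks. A scalar version of the same kernel, for illustration only (no uint4 packing, hypothetical name):

    // Simplified, unvectorized equivalent of maskAddTwoLinearOutput: one block
    // per token, block-stride loop over the row, early-out on mask == 0.
    template<typename T>
    __global__ void mask_add_rows(T* out1, const T* out2, const int* mask, int dim)
    {
        const int token = blockIdx.x;
        if (!mask[token]) {
            return;  // unmasked tokens keep the base GEMM result untouched
        }
        T*       dst = out1 + static_cast<size_t>(token) * dim;
        const T* src = out2 + static_cast<size_t>(token) * dim;
        for (int i = threadIdx.x; i < dim; i += blockDim.x) {
            dst[i] += src[i];
        }
    }

    // launch shape mirrors the real launcher: one block per token, e.g.
    //   mask_add_rows<<<batch_size, 256, 0, stream>>>(output1, output2, mask, dim);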
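The mask itself is produced in LlamaV2::updateEmbedding while the user-supplied input_embeddings are copied over the looked-up token embeddings: each embedding range [begin, end) is clamped to the tokens handled in this pass, [cache_len, cache_len + input_length) (the clamp computes the intersection of the two ranges, despite the in-code comment saying "union"), the covered positions are marked 1, and the host mask is uploaded to lora_mask_buf_ only when lora_policy is enabled and at least one embedding was written. The net effect is that the LoRA delta is applied exclusively to tokens whose embeddings were injected from outside; ordinary text tokens always go through the base weights alone. A host-side sketch of the per-sequence mask construction (hypothetical helper, same arithmetic as the patch):

    // Builds the lora mask for one sequence, assuming `ranges` holds the
    // [begin, end) spans of its injected embeddings in prompt coordinates.
    #include <algorithm>
    #include <utility>
    #include <vector>

    std::vector<int> build_lora_mask(int cache_len, int input_len,
                                     const std::vector<std::pair<int, int>>& ranges)
    {
        std::vector<int> mask(input_len, 0);
        for (const auto& [begin, end] : ranges) {
            const int lo = std::max(begin, cache_len);             // clamp to this pass
            const int hi = std::min(end, cache_len + input_len);
            for (int p = lo; p < hi; ++p) {
                mask[p - cache_len] = 1;  // position relative to the current chunk
            }
        }
        return mask;
    }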
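Configuration and weights are plumbed through a single integer flag: lora_policy defaults to 0 in TurbomindModelConfig, is read from the llama section by LlamaTritonModel (reader.GetInteger("llama", "lora_policy", 0)) and handed down through LlamaV2, LlamaWeight and LlamaDecoderLayerWeight. When it is non-zero, mallocWeights allocates lora_kernel at the full input_dims x output_dims of the base weight, which suggests the exporter is expected to ship a pre-composed delta matrix rather than low-rank A/B factors, and loadWeights derives its file name by inserting ".lora" before the last dot-separated component of the tensor prefix. A small sketch of that name derivation (the example prefix is illustrative, not taken from the patch):

    // Mirrors the string manipulation in loadWeights: ".lora" is spliced in
    // before the last component, then ".weight" is appended.
    #include <cstdio>
    #include <string>

    std::string lora_weight_file(const std::string& prefix)
    {
        const auto dot_pos = prefix.rfind('.');
        return prefix.substr(0, dot_pos) + ".lora" + prefix.substr(dot_pos) + ".weight";
    }

    int main()
    {
        // e.g. "layers.0.attention.w_qkv" -> "layers.0.attention.lora.w_qkv.weight"
        std::printf("%s\n", lora_weight_file("layers.0.attention.w_qkv").c_str());
        return 0;
    }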