
Commit

support lora
irexyc committed Jan 27, 2024
1 parent da190ef commit 9088817
Showing 19 changed files with 204 additions and 33 deletions.
1 change: 1 addition & 0 deletions lmdeploy/turbomind/deploy/target_model/base.py
@@ -66,6 +66,7 @@ class TurbomindModelConfig:
max_position_embeddings: int = 0
rope_scaling_factor: float = 0.0
use_logn_attn: int = 0
lora_policy: int = 0

@classmethod
def from_dict(cls, env, allow_none=False):
10 changes: 7 additions & 3 deletions src/turbomind/models/llama/LlamaBatch.cc
@@ -722,10 +722,12 @@ void LlamaBatch<T>::AllocateBuffer(size_t batch_size, size_t session_len)

context_decoder_input_buf_ =
(T*)allocator_->reMalloc(context_decoder_input_buf_, sizeof(T) * max_context_token_num_ * hidden_units, false);
context_decoder_output_buf_ =
(T*)allocator_->reMalloc(context_decoder_output_buf_, sizeof(T) * max_context_token_num_ * hidden_units, false);
// double buffer for lora
context_decoder_output_buf_ = (T*)allocator_->reMalloc(
context_decoder_output_buf_, sizeof(T) * max_context_token_num_ * hidden_units * 2, false);
context_decoder_ids_buf_ =
(int*)allocator_->reMalloc(context_decoder_ids_buf_, sizeof(int) * max_context_token_num_, false);
lora_mask_buf_ = (int*)allocator_->reMalloc(lora_mask_buf_, sizeof(int) * max_context_token_num_, false);

tmp_k_cache_buf_ = (T*)allocator_->reMalloc(
tmp_k_cache_buf_, sizeof(T) * max_context_token_num_ * local_kv_head_num * head_dim, false);
@@ -850,6 +852,7 @@ void LlamaBatch<T>::FreeBuffer()
allocator_->free((void**)&context_decoder_input_buf_);
allocator_->free((void**)&context_decoder_output_buf_);
allocator_->free((void**)&context_decoder_ids_buf_);
allocator_->free((void**)&lora_mask_buf_);

allocator_->free((void**)&tmp_k_cache_buf_);
allocator_->free((void**)&tmp_v_cache_buf_);
@@ -1586,7 +1589,8 @@ bool LlamaBatch<T>::Forward(GenerationState& g, int iter)
max_context_cnts[p],
max_context_cnts[p],
h_input_length_buf_ + first,
sequences.data());
sequences.data(),
lora_mask_buf_);

if (iter == 0) {
// compute logits of inputs if requested
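
Note on the change above: context_decoder_output_buf_ is now allocated at twice its previous size and a per-token lora_mask_buf_ is added. A minimal sketch of the assumed layout (illustrative names only; the real indexing lives in LlamaLinear<T>::forwardFp): the first half of the doubled buffer holds the regular GEMM output, the second half holds the LoRA GEMM output, and the mask marks which tokens of the flattened batch should receive the LoRA delta.

#include <cstddef>

// Illustrative view of the doubled output buffer: the base path and the LoRA
// path each get a contiguous [token_num, hidden_units] slice of one allocation.
// lora_mask_buf_[i] == 1 marks token i of the flattened batch for the LoRA delta.
template<typename T>
struct LoraDoubleBuffer {
    T* base;  // output of the regular GEMM
    T* lora;  // output of the LoRA GEMM, later merged into `base` via the mask
};

template<typename T>
LoraDoubleBuffer<T> splitLoraBuffer(T* buf, size_t token_num, size_t hidden_units)
{
    return {buf, buf + token_num * hidden_units};
}
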
1 change: 1 addition & 0 deletions src/turbomind/models/llama/LlamaBatch.h
@@ -225,6 +225,7 @@ class LlamaBatch {
T* decoder_output_buf_{};
int* sequence_lengths_{}; // current sequence length
int* init_ctx_lens_{};
int* lora_mask_buf_{}; // lora

float* logits_buf_{}; // combined logits
float* local_logits_buf_{}; // tensor parallel local logits
25 changes: 18 additions & 7 deletions src/turbomind/models/llama/LlamaDecoderLayerWeight.cc
@@ -34,6 +34,7 @@ LlamaDecoderLayerWeight<T>::LlamaDecoderLayerWeight(size_t head_num,
WeightType weight_type,
int group_size,
bool attn_bias,
int lora_policy,
size_t tensor_para_size,
size_t tensor_para_rank):
head_num_(head_num),
@@ -43,6 +44,7 @@ LlamaDecoderLayerWeight<T>::LlamaDecoderLayerWeight(size_t head_num,
inter_size_(inter_size),
weight_type_(weight_type),
attn_bias_(attn_bias),
lora_policy_(lora_policy),
tensor_para_size_(tensor_para_size),
tensor_para_rank_(tensor_para_rank)
{
@@ -91,14 +93,17 @@ void freeWeights(LlamaDenseWeight<T>& weights)
}

template<typename T>
void mallocWeights(LlamaDenseWeight<T>& weights, bool bias)
void mallocWeights(LlamaDenseWeight<T>& weights, bool bias, int lora_policy)
{
if (bias) {
deviceMalloc((T**)&weights.bias, weights.output_dims);
}
const size_t bit_size = getBitSize(weights.type);
if (bit_size >= 16) { // fp16, fp32
deviceMalloc((T**)&weights.kernel, weights.input_dims * weights.output_dims);
if (lora_policy) {
deviceMalloc((T**)&weights.lora_kernel, weights.input_dims * weights.output_dims);
}
}
else { // int8, int4
const int factor = sizeof(float) * 8 / bit_size;
@@ -244,6 +249,12 @@ void loadWeights(LlamaDenseWeight<T>& w,
}
}
loadWeightFromBin((T*)w.kernel, {dim0, dim1}, prefix + ".weight", type, weight_slices);
if (w.lora_kernel) {
auto dot_pos = prefix.rfind(".");
auto lora_weight_file = prefix.substr(0, dot_pos) + ".lora" + prefix.substr(dot_pos) + ".weight";
TM_LOG_INFO("loading %s", lora_weight_file.c_str());
loadWeightFromBin((T*)w.lora_kernel, {dim0, dim1}, lora_weight_file, type, weight_slices);
}
}
else { // int8, int4
const int factor = sizeof(float) * 8 / bit_size;
@@ -265,19 +276,19 @@ void LlamaDecoderLayerWeight<T>::mallocWeights()
deviceMalloc((T**)&self_attn_norm_weights, hidden_units_);
deviceMalloc((T**)&ffn_norm_weights, hidden_units_);

turbomind::mallocWeights(self_attn_weights.qkv, attn_bias_);
turbomind::mallocWeights(self_attn_weights.output, attn_bias_);
turbomind::mallocWeights(self_attn_weights.qkv, attn_bias_, lora_policy_);
turbomind::mallocWeights(self_attn_weights.output, attn_bias_, lora_policy_);
self_attn_weights.past_kv_scale = {1.f, 0.f, 1.f, 0.f};

if (weight_type_ == WeightType::kINT4) {
turbomind::mallocWeights(ffn_weights.fused_gating_intermediate, false);
turbomind::mallocWeights(ffn_weights.fused_gating_intermediate, false, lora_policy_);
}
else {
turbomind::mallocWeights(ffn_weights.gating, false);
turbomind::mallocWeights(ffn_weights.intermediate, false);
turbomind::mallocWeights(ffn_weights.gating, false, lora_policy_);
turbomind::mallocWeights(ffn_weights.intermediate, false, lora_policy_);
}

turbomind::mallocWeights(ffn_weights.output, false);
turbomind::mallocWeights(ffn_weights.output, false, lora_policy_);
}

template<typename T>
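
For reference, the loader above derives the LoRA weight filename by inserting ".lora" before the last component of the tensor prefix and appending ".weight". A self-contained sketch of that naming rule (the example prefix is hypothetical; real prefixes depend on the converted TurboMind model layout):

#include <iostream>
#include <string>

// Mirrors the naming logic in loadWeights(): insert ".lora" before the final
// component of the tensor prefix, then append ".weight".
std::string loraWeightFile(const std::string& prefix)
{
    const auto dot_pos = prefix.rfind(".");
    return prefix.substr(0, dot_pos) + ".lora" + prefix.substr(dot_pos) + ".weight";
}

int main()
{
    // Hypothetical prefix; actual names come from the exported model files.
    std::cout << loraWeightFile("layers.0.attention.w_qkv") << "\n";
    // prints: layers.0.attention.lora.w_qkv.weight
}
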
2 changes: 2 additions & 0 deletions src/turbomind/models/llama/LlamaDecoderLayerWeight.h
@@ -36,6 +36,7 @@ struct LlamaDecoderLayerWeight {
WeightType weight_type,
int group_size,
bool attn_bias,
int lora_policy,
size_t tensor_para_size,
size_t tensor_para_rank);
~LlamaDecoderLayerWeight();
@@ -60,6 +61,7 @@ struct LlamaDecoderLayerWeight {
WeightType weight_type_;
size_t bit_size_;
bool attn_bias_;
int lora_policy_;
size_t tensor_para_size_;
size_t tensor_para_rank_;
bool is_maintain_buffer_ = false;
1 change: 1 addition & 0 deletions src/turbomind/models/llama/LlamaDenseWeight.h
@@ -59,6 +59,7 @@ struct LlamaDenseWeight {
size_t input_dims;
size_t output_dims;
void* kernel;
void* lora_kernel;
WeightType type;
T* bias;
T* scales_and_zeros;
15 changes: 12 additions & 3 deletions src/turbomind/models/llama/LlamaFfnLayer.cc
@@ -88,6 +88,14 @@ void LlamaFfnLayer<T>::forward(TensorMap* output_tensors,
const T* ffn_input_data = input_tensors->at("ffn_input").getPtr<T>();
T* ffn_output_data = output_tensors->at("ffn_output").getPtr<T>();

// lora
int* lora_mask = nullptr;
if (input_tensors->isExist("lora_mask")) {
lora_mask = input_tensors->at("lora_mask").getPtr<int>();
inter_buf_ = (T*)allocator_->reMalloc(inter_buf_, sizeof(T) * num_token * inter_size_ * 2, false);
gating_buf_ = (T*)allocator_->reMalloc(gating_buf_, sizeof(T) * num_token * inter_size_ * 2, false);
}

if (weights->fused_gating_intermediate.kernel) {
NvtxScope scope("fused_silu_ffn");
linear_.forward(
@@ -96,19 +104,20 @@
else {
{ // w1(x)
NvtxScope scope("w1");
linear_.forward(gating_buf_, ffn_input_data, num_token, weights->gating);
linear_.forward(gating_buf_, ffn_input_data, num_token, weights->gating, LlamaLinear<T>::kGemm, lora_mask);
}
{ // w3(x)
NvtxScope scope("w3");
linear_.forward(inter_buf_, ffn_input_data, num_token, weights->intermediate);
linear_.forward(
inter_buf_, ffn_input_data, num_token, weights->intermediate, LlamaLinear<T>::kGemm, lora_mask);
}
// silu(w1(x)) * w3(x)
activation(num_token);
}

{ // w2(x)
NvtxScope scope("w2");
linear_.forward(ffn_output_data, gating_buf_, num_token, weights->output);
linear_.forward(ffn_output_data, gating_buf_, num_token, weights->output, LlamaLinear<T>::kGemm, lora_mask);
}

if (tensor_para_.world_size_ > 1) {
40 changes: 36 additions & 4 deletions src/turbomind/models/llama/LlamaLinear.h
@@ -4,6 +4,7 @@

#include "src/turbomind/kernels/gemm_s_f16/gemm_s4_f16.h"
#include "src/turbomind/models/llama/LlamaDenseWeight.h"
#include "src/turbomind/models/llama/llama_decoder_kernels.h"
#include "src/turbomind/models/llama/llama_kernels.h"
#include "src/turbomind/utils/cublasMMWrapper.h"
#include "src/turbomind/utils/cuda_utils.h"
@@ -25,14 +26,18 @@ class LlamaLinear {
{
}

void
forward(T* output_data, const T* input_data, int batch_size, const LlamaDenseWeight<T>& weight, Type type = kGemm)
void forward(T* output_data,
const T* input_data,
int batch_size,
const LlamaDenseWeight<T>& weight,
Type type = kGemm,
int* lora_mask = nullptr)
{
switch (weight.type) {
case WeightType::kFP16:
case WeightType::kFP32:
case WeightType::kBF16:
forwardFp(output_data, input_data, batch_size, weight, type);
forwardFp(output_data, input_data, batch_size, weight, type, lora_mask);
break;
case WeightType::kINT4:
forwardInt4(output_data, input_data, batch_size, weight, type);
@@ -43,7 +48,12 @@
}

private:
void forwardFp(T* output_data, const T* input_data, int batch_size, const LlamaDenseWeight<T>& weight, Type type)
void forwardFp(T* output_data,
const T* input_data,
int batch_size,
const LlamaDenseWeight<T>& weight,
Type type,
int* lora_mask)
{
FT_CHECK(type == kGemm);
cublas_wrapper_->Gemm(CUBLAS_OP_N,
@@ -58,6 +68,28 @@
output_data,
weight.output_dims);
sync_check_cuda_error();

if (lora_mask && weight.lora_kernel) {
cublas_wrapper_->Gemm(CUBLAS_OP_N,
CUBLAS_OP_N,
weight.output_dims,
batch_size,
weight.input_dims,
(const T*)weight.lora_kernel,
weight.output_dims,
input_data,
weight.input_dims,
output_data + batch_size * weight.output_dims,
weight.output_dims);

invokeMaskAddTwoLinearOutput(output_data,
output_data + batch_size * weight.output_dims,
lora_mask,
batch_size,
weight.output_dims,
stream_);
sync_check_cuda_error();
}
}

void forwardInt4(T* output_data, const T* input_data, int batch_size, const LlamaDenseWeight<T>& weight, Type type)
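
The merge step above calls invokeMaskAddTwoLinearOutput, whose implementation is not part of this excerpt. A CPU reference of its assumed per-token semantics (base GEMM result updated in place, LoRA GEMM result added only where the mask flag is set) might look like this:

#include <cstddef>

// Assumed semantics of invokeMaskAddTwoLinearOutput (CPU sketch, not the CUDA kernel):
// tokens with mask[i] == 1 get the LoRA delta added on top of the base output;
// all other tokens keep the base output untouched.
template<typename T>
void maskAddTwoLinearOutputRef(T*         out,       // [token_num, dims] base output, updated in place
                               const T*   lora_out,  // [token_num, dims] LoRA output
                               const int* mask,      // [token_num] 1 = apply LoRA delta
                               int        token_num,
                               int        dims)
{
    for (int i = 0; i < token_num; ++i) {
        if (!mask[i]) {
            continue;
        }
        for (int j = 0; j < dims; ++j) {
            out[static_cast<size_t>(i) * dims + j] += lora_out[static_cast<size_t>(i) * dims + j];
        }
    }
}

The LoRA GEMM writes to output_data + batch_size * weight.output_dims, i.e. the second half of the doubled buffers allocated earlier, so the merge needs no extra temporary.
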
49 changes: 44 additions & 5 deletions src/turbomind/models/llama/LlamaV2.cc
@@ -63,6 +63,7 @@ LlamaV2<T>::LlamaV2(size_t head_num,
cublasMMWrapper* cublas_wrapper,
IAllocator* allocator,
bool is_free_buffer_after_forward,
int lora_policy,
cudaDeviceProp* cuda_device_prop):
head_num_(head_num),
size_per_head_(size_per_head),
@@ -84,6 +85,7 @@ LlamaV2<T>::LlamaV2(size_t head_num,
allocator_(allocator),
is_free_buffer_after_forward_(is_free_buffer_after_forward),
cuda_device_prop_(cuda_device_prop),
lora_policy_(lora_policy),
debug_(isDebug()),
shared_state_(shared_state)

@@ -166,29 +168,54 @@ void LlamaV2<T>::embeddingLookup(T* embeddings, const int* token_ids_buf, int ba
}

template<typename T>
void LlamaV2<T>::updateEmbedding(T* decoder_input, const int bsz, const int* h_input_length, const Sequence** sequences)
void LlamaV2<T>::updateEmbedding(T* decoder_input,
const int bsz,
const int* h_input_length,
const Sequence** sequences,
int token_num,
int* lora_mask,
bool* have_embeddings)
{
TM_LOG_DEBUG(__PRETTY_FUNCTION__);

std::vector<int> mask(token_num);
int* mask_ptr = mask.data();
*have_embeddings = false;

for (int i = 0; i < bsz; i++) {
const auto& seq = *sequences[i];
const auto& embeddings = seq.input_embeddings;
const auto& ranges = seq.input_embedding_ranges;
for (int j = embeddings.size() - 1; j >= 0; j--) {
int begin = ranges[j].first;
int end = ranges[j].second;
if (seq.cache_len + h_input_length[i] - 1 < begin) {
continue;
}
if (end <= seq.cache_len) {
break;
}
int off_dst = std::max(0, begin - seq.cache_len);
int off_src = std::max(0, seq.cache_len - begin);
int off_dst = std::max(0, begin - seq.cache_len);
int off_src = std::max(0, seq.cache_len - begin);
// calculate union of [begin, end) and [seq.cache_len, seq.cache_len + h_input_length[i])
begin = std::max(begin, seq.cache_len);
end = std::min(end, seq.cache_len + h_input_length[i]);
size_t byte_size = (end - begin) * hidden_units_ * sizeof(T);
T* dst_ptr = decoder_input + off_dst * hidden_units_;
auto src_ptr = embeddings[j].data() + off_src * hidden_units_ * sizeof(T);
cudaMemcpyAsync(dst_ptr, src_ptr, byte_size, cudaMemcpyDefault, stream_);
std::fill_n(mask_ptr + off_dst, (end - begin), 1);
*have_embeddings = true;
}
decoder_input += h_input_length[i] * hidden_units_;
mask_ptr += h_input_length[i];
}

if (lora_policy_ && *have_embeddings) {
cudaMemcpyAsync(lora_mask, mask.data(), sizeof(int) * token_num, cudaMemcpyDefault, stream_);
cudaStreamSynchronize(stream_);
}

sync_check_cuda_error();
}

@@ -216,7 +243,8 @@ void LlamaV2<T>::forwardUnified(T* out,
int pf_max_context_len,
int pf_session_len,
const int* h_input_length,
const Sequence** sequences)
const Sequence** sequences,
int* lora_mask)
{
TM_LOG_DEBUG(__PRETTY_FUNCTION__);

@@ -233,7 +261,14 @@
hidden_units_,
stream_);

updateEmbedding(decoder_input, dc_batch_size + pf_batch_size, h_input_length, sequences);
bool have_embeddings = false;
updateEmbedding(decoder_input,
dc_batch_size + pf_batch_size,
h_input_length,
sequences,
token_num,
lora_mask,
&have_embeddings);

sync_check_cuda_error();

@@ -262,6 +297,10 @@
{"tmp_v", {MEMORY_GPU, TYPE_UINT64, {bsz}, pf_tmp_v_ptrs}},
{"last_token_hidden_units", {MEMORY_GPU, dtype, {bsz, hidden_units_}, out}}};

if (lora_policy_ && have_embeddings && lora_mask) {
inputs.insert({"lora_mask", {MEMORY_GPU, TYPE_INT32, {token_num}, lora_mask}});
}

unified_decoder_->forward(&outputs, &inputs, &weights_->decoder_layer_weights);
}

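
updateEmbedding now also builds the per-token LoRA mask on the host. A simplified CPU sketch of the assumed rule for a single sequence (names are illustrative; the commit handles all sequences of the batch back to back and uploads the mask to lora_mask_buf_ only when lora_policy_ is set and embeddings were actually injected):

#include <algorithm>
#include <utility>
#include <vector>

// Sketch: flag the tokens of the current decode window whose positions are
// covered by an injected input-embedding range [begin, end).
std::vector<int> buildLoraMask(int cache_len,
                               int input_length,
                               const std::vector<std::pair<int, int>>& embedding_ranges)
{
    std::vector<int> mask(input_length, 0);
    for (const auto& range : embedding_ranges) {
        // clip the range to the window [cache_len, cache_len + input_length)
        const int lo = std::max(range.first, cache_len);
        const int hi = std::min(range.second, cache_len + input_length);
        for (int pos = lo; pos < hi; ++pos) {
            mask[pos - cache_len] = 1;
        }
    }
    return mask;
}
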
(Diff truncated: the remaining changed files in this commit are not shown here.)
