
Commit

Fix crash when session len is too large
lvhan028 committed Sep 16, 2023
1 parent 2dec28a commit 2a33625
Showing 23 changed files with 104 additions and 72 deletions.
15 changes: 10 additions & 5 deletions lmdeploy/model.py
@@ -55,7 +55,7 @@ def get_prompt(self, prompt, sequence_start=True):

@abstractmethod
def decorate_prompt(self, prompt, sequence_start):
pass
return prompt

@staticmethod
def _translate_messages(messages: List):
@@ -176,8 +176,8 @@ class InternLMChat7B(BaseModel):
def __init__(self,
system='',
user='<|User|>',
eoh='<eoh>',
eoa='<eoa>',
eoh='',
eoa='',
assistant='<|Bot|>',
**kwargs):
super().__init__(**kwargs)
@@ -231,19 +231,22 @@ def messages2prompt(self, messages, sequence_start=True):
@property
def stop_words(self):
"""Return the stop-words' token ids."""
return [103027, 103028]
return [103028]


@MODELS.register_module(name='internlm-chat-7b-8k')
class InternLMChat7B8K(InternLMChat7B):
"""Chat template and generation parameters of InternLM-Chat-7B-8K."""

def __init__(self, session_len=8192, **kwargs):
def __init__(self, session_len=8192, repetition_penalty=1.02, **kwargs):
super(InternLMChat7B8K, self).__init__(**kwargs)
self.session_len = session_len
self.repetition_penalty = repetition_penalty


@MODELS.register_module(name='baichuan-7b')
class Baichuan7B(BaseModel):
"""Generation parameters of Baichuan-7B base model."""

def __init__(self, repetition_penalty=1.1, **kwargs):
super().__init__(**kwargs)
@@ -252,6 +255,8 @@ def __init__(self, repetition_penalty=1.1, **kwargs):

@MODELS.register_module(name='baichuan2-7b')
class Baichuan2_7B(BaseModel):
"""Chat template and generation parameters of Baichuan2-7B-Base and
Baichuan2-7B-Chat models."""

def __init__(self,
temperature=0.3,
7 changes: 4 additions & 3 deletions lmdeploy/turbomind/tokenizer.py
@@ -112,7 +112,7 @@ class HuggingFaceTokenizer:

def __init__(self, model_dir: str):
from transformers import (AutoTokenizer, CodeLlamaTokenizerFast,
LlamaTokenizerFast)
LlamaTokenizer, LlamaTokenizerFast)
model_file = osp.join(model_dir, 'tokenizer.model')
backend_tokenizer_file = osp.join(model_dir, 'tokenizer.json')
model_file_exists = osp.exists(model_file)
@@ -121,8 +121,9 @@ def __init__(self, model_dir: str):
'It may take long time to initialize the tokenizer.')
self.model = AutoTokenizer.from_pretrained(model_dir,
trust_remote_code=True)
self.need_padding = isinstance(self.model, LlamaTokenizerFast) \
or isinstance(self.model, CodeLlamaTokenizerFast)
self.need_padding = type(self.model) in [
LlamaTokenizer, LlamaTokenizerFast, CodeLlamaTokenizerFast
]
self._no_prefix_space_tokens = None
# save tokenizer.json to reuse
if not osp.exists(backend_tokenizer_file) and model_file_exists:
@@ -1422,7 +1422,7 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
// Trigger the stores to global memory.
if (Dh == Dh_MAX || co < Dh / QK_ELTS_IN_16B) {

int offset = params.kv_cache_per_sample_offset + kvhi * params.memory_max_len * Dh + tlength_circ * Dh
size_t offset = params.kv_cache_per_sample_offset + kvhi * params.memory_max_len * Dh + tlength_circ * Dh
+ co * QK_ELTS_IN_16B + ci;

if (!QUANT_POLICY) {
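The `int offset` → `size_t offset` line above is the core of this fix: the per-sample KV-cache offset grows with the session length (`params.memory_max_len`), and once `kv_cache_per_sample_offset + kvhi * memory_max_len * Dh` exceeds `INT_MAX`, a 32-bit accumulator wraps and the store lands outside the cache. A minimal sketch of the overflow follows; the layer count, KV-head count, session length, and head dimension are assumptions chosen for illustration, not values taken from this repository.

```cpp
#include <cstdio>
#include <cstddef>

int main()
{
    // Illustrative sizes only -- assumed for this sketch, not read from lmdeploy.
    const size_t num_layers     = 40;     // decoder layers
    const size_t local_kv_heads = 40;     // KV heads on this rank
    const size_t memory_max_len = 32768;  // a "too large" session length
    const size_t Dh             = 128;    // head dimension

    // Offset of the last layer / last KV head inside one sample's cache block,
    // mirroring kv_cache_per_sample_offset + kvhi * memory_max_len * Dh above.
    const size_t offset64 = (num_layers - 1) * local_kv_heads * memory_max_len * Dh
                            + (local_kv_heads - 1) * memory_max_len * Dh;

    const int offset32 = static_cast<int>(offset64);  // what a 32-bit `offset` would hold

    std::printf("64-bit offset: %zu elements\n", offset64);  // ~6.7e9, well past INT_MAX
    std::printf("32-bit offset: %d\n", offset32);            // wrapped, meaningless value
    return 0;
}
```

Widening the accumulator to `size_t` keeps the address arithmetic in 64 bits; the `index_t`, batched-pointer-offset, and layout-offset changes in the files below apply the same treatment to the other indexing paths.
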
12 changes: 6 additions & 6 deletions src/turbomind/models/llama/Barrier.h
@@ -19,9 +19,9 @@ class Barrier {
FT_CHECK(count == 1);
}

Barrier(const Barrier&) = delete;
Barrier& operator=(const Barrier&) = delete;
Barrier(Barrier&&) noexcept = delete;
Barrier(const Barrier&) = delete;
Barrier& operator=(const Barrier&) = delete;
Barrier(Barrier&&) noexcept = delete;
Barrier& operator=(Barrier&&) noexcept = delete;

void wait() {}
@@ -39,9 +39,9 @@ class Barrier {
pthread_barrier_init(&barrier_, nullptr, count);
}

Barrier(const Barrier&) = delete;
Barrier& operator=(const Barrier&) = delete;
Barrier(Barrier&&) noexcept = delete;
Barrier(const Barrier&) = delete;
Barrier& operator=(const Barrier&) = delete;
Barrier(Barrier&&) noexcept = delete;
Barrier& operator=(Barrier&&) noexcept = delete;

void wait()
5 changes: 3 additions & 2 deletions src/turbomind/models/llama/LlamaBatch.cc
@@ -899,8 +899,9 @@ void LlamaBatch<T>::outputContextLogits(T* context_decoder_

if (context_logits_buf_ == nullptr) {
NcclGuard guard(llama_->tensor_para_, stream_, true);
context_logits_buf_ = (float*)allocator_->malloc(sizeof(float) * llama_->vocab_size_padded_ * max_context_token_num_);
const auto tp = llama_->tensor_para_.world_size_;
context_logits_buf_ =
(float*)allocator_->malloc(sizeof(float) * llama_->vocab_size_padded_ * max_context_token_num_);
const auto tp = llama_->tensor_para_.world_size_;
if (tp > 1) {
FT_CHECK(llama_->vocab_size_padded_ % tp == 0);
const auto local_vocab_size = llama_->vocab_size_padded_ / tp;
8 changes: 6 additions & 2 deletions src/turbomind/models/llama/LlamaContextAttentionLayer.cc
@@ -215,6 +215,7 @@ inline void LlamaContextAttentionLayer<T>::forward(TensorMap*
layer_offset,
attention_mask,
cu_seqlens,
input_tensors->at("context_lengths").getPtr<int>(),
batch_size,
max_q_len,
max_k_len,
@@ -258,6 +259,7 @@ void LlamaContextAttentionLayer<T>::fusedMultiHeadAttention(T** key_cache_ptr
size_t cache_layer_offset,
T* attention_mask,
int* cu_seqlens,
int* context_lengths,
int batch_size,
int max_q_len,
int max_k_len,
@@ -274,13 +276,13 @@ void LlamaContextAttentionLayer<T>::fusedMultiHeadAttention(T** key_cache_ptr
int(size_per_head_),
int(max_seq_len * size_per_head_),
false,
int(cache_layer_offset),
cache_layer_offset,
key_cache_ptrs};
Layout layout_v{int(local_head_num_ * max_seq_len * size_per_head_),
int(size_per_head_),
int(max_seq_len * size_per_head_),
false,
int(cache_layer_offset),
cache_layer_offset,
val_cache_ptrs};
Layout layout_o{
int(local_head_num_ * max_q_len * size_per_head_),
@@ -298,6 +300,8 @@ void LlamaContextAttentionLayer<T>::fusedMultiHeadAttention(T** key_cache_ptr
qk_buf_float_,
cu_seqlens,
nullptr,
nullptr,
context_lengths,
group_size,
layout_q,
layout_k,
1 change: 1 addition & 0 deletions src/turbomind/models/llama/LlamaContextAttentionLayer.h
@@ -72,6 +72,7 @@ class LlamaContextAttentionLayer {
size_t cache_layer_offset,
T* attention_mask,
int* cu_seqlens,
int* context_lengths,
int batch_size,
int max_q_len,
int max_k_len,
2 changes: 1 addition & 1 deletion src/turbomind/models/llama/LlamaDecoderLayerWeight.h
@@ -38,7 +38,7 @@ struct LlamaDecoderLayerWeight {
size_t tensor_para_size,
size_t tensor_para_rank);
~LlamaDecoderLayerWeight();
LlamaDecoderLayerWeight(const LlamaDecoderLayerWeight& other) = delete;
LlamaDecoderLayerWeight(const LlamaDecoderLayerWeight& other) = delete;
LlamaDecoderLayerWeight& operator=(const LlamaDecoderLayerWeight& other) = delete;

void loadModel(std::string dir_path, FtCudaDataType model_file_type);
@@ -130,7 +130,7 @@ static inline void fusedQKV_masked_attention_dispatch(const T* qkv_buf,

params.hidden_size_per_head = size_per_head;
params.rotary_embedding_dim = rotary_embedding_dim;
params.rotary_embedding_base = rotary_embedding_base;
params.rotary_embedding_base = rotary_embedding_base;
params.max_position_embeddings = max_position_embeddings;
params.use_dynamic_ntk = use_dynamic_ntk;
params.use_logn_attn = use_logn_attn;
3 changes: 1 addition & 2 deletions src/turbomind/models/llama/LlamaDenseWeight.h
@@ -25,8 +25,7 @@

namespace turbomind {

enum class WeightType : int
{
enum class WeightType : int {
kFP32,
kFP16,
kFP8, // not supported yet
3 changes: 1 addition & 2 deletions src/turbomind/models/llama/LlamaLinear.h
@@ -15,8 +15,7 @@ namespace turbomind {
template<typename T>
class LlamaLinear {
public:
enum Type
{
enum Type {
kGemm,
kFusedSiluFfn
};
5 changes: 3 additions & 2 deletions src/turbomind/models/llama/LlamaV2.cc
@@ -93,7 +93,8 @@ LlamaV2<T>::LlamaV2(size_t head_num,
TM_LOG_DEBUG(__PRETTY_FUNCTION__);
TM_LOG_INFO("NCCL group_id = %d", tensor_para_.group_id_);

vocab_size_padded_ = (vocab_size_padded_ + tensor_para_.world_size_ - 1) / tensor_para_.world_size_ * tensor_para_.world_size_;
vocab_size_padded_ =
(vocab_size_padded_ + tensor_para_.world_size_ - 1) / tensor_para_.world_size_ * tensor_para_.world_size_;

size_t elem_bits = 0;
if (quant_policy & QuantPolicy::kCacheKVInt8) {
@@ -171,7 +172,7 @@ void LlamaV2<T>::initialize(const LlamaAttentionParams& attn_params,

dynamic_decode_layer_ = new DynamicDecodeLayer<float>(vocab_size_,
vocab_size_padded_,
0, // end_id, deprecated
0, // end_id, deprecated
stream_,
cublas_wrapper_,
allocator_,
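The reflowed assignment above is the usual ceil-to-multiple idiom: `vocab_size_padded_` is rounded up to a multiple of the tensor-parallel world size so the padded logits can be split evenly across ranks (the `FT_CHECK(llama_->vocab_size_padded_ % tp == 0)` in `LlamaBatch.cc` above relies on this). A small sketch of the idiom, with an assumed vocabulary size:

```cpp
#include <cstdio>
#include <cstddef>

// Round x up to the next multiple of n (n > 0) -- the idiom behind vocab_size_padded_.
static size_t round_up_to_multiple(size_t x, size_t n)
{
    return (x + n - 1) / n * n;
}

int main()
{
    const size_t vocab_size = 32001;  // assumed value, for illustration only
    std::printf("tp=1: %zu\n", round_up_to_multiple(vocab_size, 1));  // 32001
    std::printf("tp=2: %zu\n", round_up_to_multiple(vocab_size, 2));  // 32002
    std::printf("tp=4: %zu\n", round_up_to_multiple(vocab_size, 4));  // 32004
    return 0;
}
```
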
6 changes: 4 additions & 2 deletions src/turbomind/models/llama/LlamaWeight.cc
@@ -95,8 +95,10 @@ void LlamaWeight<T>::loadModel(std::string dir_path)

loadWeightFromBin((T*)output_norm_weight, {hidden_units_}, dir_path + "norm.weight", model_file_type);

loadWeightFromBin(
(T*)post_decoder_embedding_kernel, {hidden_units_ * vocab_size_padded_}, dir_path + "output.weight", model_file_type);
loadWeightFromBin((T*)post_decoder_embedding_kernel,
{hidden_units_ * vocab_size_padded_},
dir_path + "output.weight",
model_file_type);

for (unsigned layer = 0; layer < num_layer_; ++layer) {
decoder_layer_weights[layer]->loadModel(dir_path + "layers." + std::to_string(layer), model_file_type);
2 changes: 1 addition & 1 deletion src/turbomind/models/llama/LlamaWeight.h
@@ -42,7 +42,7 @@ struct LlamaWeight {

~LlamaWeight();

LlamaWeight(const LlamaWeight& other) = delete;
LlamaWeight(const LlamaWeight& other) = delete;
LlamaWeight& operator=(const LlamaWeight& other) = delete;

void loadModel(std::string dir_path);
3 changes: 1 addition & 2 deletions src/turbomind/models/llama/Request.h
@@ -25,8 +25,7 @@ struct Request {
using Callback = std::function<void(std::unordered_map<std::string, Tensor>*)>;
Callback stream_cb;

enum
{
enum {
kInvalid = 1,
kConflict = 2,
kBusy = 3,
12 changes: 8 additions & 4 deletions src/turbomind/models/llama/flash_attention2/block_info.h
@@ -15,10 +15,14 @@ struct BlockInfo {
__device__ BlockInfo(const Params& params, const int bidb):
sum_s_q(!Varlen || params.cu_seqlens_q == nullptr ? -1 : params.cu_seqlens_q[bidb]),
sum_s_k(!Varlen || params.cu_seqlens_k == nullptr ? -1 : params.cu_seqlens_k[bidb]),
actual_seqlen_q(!Varlen || params.cu_seqlens_q == nullptr ? params.seqlen_q :
params.cu_seqlens_q[bidb + 1] - sum_s_q),
actual_seqlen_k(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k :
params.cu_seqlens_k[bidb + 1] - sum_s_k)
actual_seqlen_q(params.actual_seqlen_q == nullptr ?
(!Varlen || params.cu_seqlens_q == nullptr ? params.seqlen_q :
params.cu_seqlens_q[bidb + 1] - sum_s_q) :
params.actual_seqlen_q[bidb]),
actual_seqlen_k(params.actual_seqlen_k == nullptr ?
(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k :
params.cu_seqlens_k[bidb + 1] - sum_s_k) :
params.actual_seqlen_k[bidb])
{
}

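With this change `BlockInfo` prefers an explicit per-sequence length array when the caller provides one (`actual_seqlen_q` / `actual_seqlen_k`, which `LlamaContextAttentionLayer.cc` above appears to fill on the key side from its new `context_lengths` argument), and only falls back to the cumulative-length arrays or the fixed padded length when that array is null. A plain-C++ sketch of the resolution order for one batch element; the helper name is illustrative, not part of the source:

```cpp
// Mirrors the ternaries in the BlockInfo constructor above for one batch index `bidb`.
static int resolve_seqlen(const int* actual_seqlen,  // per-sequence lengths, may be null
                          const int* cu_seqlens,     // cumulative lengths, may be null
                          int        fixed_seqlen,   // params.seqlen_{q,k}
                          bool       varlen,         // the Varlen template flag
                          int        bidb)
{
    if (actual_seqlen != nullptr) {
        return actual_seqlen[bidb];                      // explicit per-sequence length wins
    }
    if (varlen && cu_seqlens != nullptr) {
        return cu_seqlens[bidb + 1] - cu_seqlens[bidb];  // derive it from cumulative offsets
    }
    return fixed_seqlen;                                 // otherwise the padded maximum
}
```

Feeding real context lengths through this path lets the attention kernel bound each sequence at its actual length rather than at the padded maximum.
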
10 changes: 7 additions & 3 deletions src/turbomind/models/llama/flash_attention2/flash.h
@@ -16,7 +16,7 @@ constexpr int D_DIM = 2;
////////////////////////////////////////////////////////////////////////////////////////////////////

struct Qkv_params {
using index_t = uint32_t;
using index_t = size_t;
// The QKV matrices.
void* __restrict__ q_ptr;
void* __restrict__ k_ptr;
@@ -25,8 +25,8 @@ struct Qkv_params {
// batched ptr inputs.
void** __restrict__ k_batched_ptr = nullptr;
void** __restrict__ v_batched_ptr = nullptr;
int k_batched_offset = 0;
int v_batched_offset = 0;
size_t k_batched_offset = 0;
size_t v_batched_offset = 0;

// The stride between rows of the Q, K and V matrices.
index_t q_batch_stride;
@@ -72,6 +72,10 @@ struct Flash_fwd_params: public Qkv_params {
int* __restrict__ cu_seqlens_q;
int* __restrict__ cu_seqlens_k;

// array of length b with actual length of each sequence
int* __restrict__ actual_seqlen_q;
int* __restrict__ actual_seqlen_k;

void* __restrict__ blockmask;

bool is_bf16;
3 changes: 3 additions & 0 deletions src/turbomind/models/llama/flash_attention2/flash_api.cpp
@@ -121,6 +121,9 @@ class FlashAttentionOpImpl<T, FMHA_VERSION>::impl {
fwd_params.cu_seqlens_q = params.cu_seqlens_q;
fwd_params.cu_seqlens_k = params.cu_seqlens_k;

fwd_params.actual_seqlen_q = params.actual_seqlen_q;
fwd_params.actual_seqlen_k = params.actual_seqlen_k;

fwd_params.blockmask = reinterpret_cast<void*>(params.mask);

fwd_params.is_bf16 = false;
@@ -70,18 +70,18 @@ struct LlamaAttentionKernel:
scalar_t** v_batch_seqs_ptr = nullptr;
output_t** o_batch_seqs_ptr = nullptr;

int q_batch_seqs_offset = 0;
int k_batch_seqs_offset = 0;
int v_batch_seqs_offset = 0;
int o_batch_seqs_offset = 0;
size_t q_batch_seqs_offset = 0;
size_t k_batch_seqs_offset = 0;
size_t v_batch_seqs_offset = 0;
size_t o_batch_seqs_offset = 0;

int32_t group_size = 1;

float scale;

template<typename ptr_t>
CUTLASS_DEVICE void
update_batched_ptr(ptr_t& data_ptr, ptr_t* batch_seq_ptr, int batch_seq_offset, int batch_id, int strideB)
update_batched_ptr(ptr_t& data_ptr, ptr_t* batch_seq_ptr, size_t batch_seq_offset, int batch_id, int strideB)
{
if (batch_seq_ptr != nullptr)
data_ptr = batch_seq_ptr[batch_id] + batch_seq_offset;
16 changes: 9 additions & 7 deletions src/turbomind/models/llama/llama_kernels.h
@@ -83,9 +83,9 @@ struct BaseAttentionLayout {
int stride_batch;
int stride_seq;
int stride_head;
bool use_seqlens = false;
int batch_seqs_offset = 0;
T** batch_seqs = nullptr;
bool use_seqlens = false;
size_t batch_seqs_offset = 0;
T** batch_seqs = nullptr;
};

template<typename T>
Expand All @@ -95,10 +95,12 @@ struct BaseAttentionParams {
T* key;
T* val;
T* mask;
float* out_accum = nullptr;
int* cu_seqlens_q = nullptr;
int* cu_seqlens_k = nullptr;
size_t group_size = 1;
float* out_accum = nullptr;
int* cu_seqlens_q = nullptr;
int* cu_seqlens_k = nullptr;
int* actual_seqlen_q = nullptr;
int* actual_seqlen_k = nullptr;
size_t group_size = 1;
BaseAttentionLayout<T> layout_q;
BaseAttentionLayout<T> layout_k;
BaseAttentionLayout<T> layout_v;