diff --git a/src/turbomind/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.cuh b/src/turbomind/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.cuh
index c2b6039d6..a0f7490e0 100644
--- a/src/turbomind/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.cuh
+++ b/src/turbomind/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.cuh
@@ -1422,8 +1422,8 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params
 
     // Trigger the stores to global memory.
     if (Dh == Dh_MAX || co < Dh / QK_ELTS_IN_16B) {
-        int offset = params.kv_cache_per_sample_offset + kvhi * params.memory_max_len * Dh + tlength_circ * Dh
-                     + co * QK_ELTS_IN_16B + ci;
+        size_t offset = params.kv_cache_per_sample_offset + kvhi * params.memory_max_len * Dh
+                        + tlength_circ * Dh + co * QK_ELTS_IN_16B + ci;
 
         if (!QUANT_POLICY) {
             *reinterpret_cast(&params.k_cache_per_sample[bi][offset]) =
diff --git a/src/turbomind/models/llama/LlamaContextAttentionLayer.cc b/src/turbomind/models/llama/LlamaContextAttentionLayer.cc
index e8f77e1c7..881582ace 100644
--- a/src/turbomind/models/llama/LlamaContextAttentionLayer.cc
+++ b/src/turbomind/models/llama/LlamaContextAttentionLayer.cc
@@ -215,6 +215,7 @@ inline void LlamaContextAttentionLayer::forward(TensorMap*
                                 layer_offset,
                                 attention_mask,
                                 cu_seqlens,
+                                input_tensors->at("context_lengths").getPtr<int>(),
                                 batch_size,
                                 max_q_len,
                                 max_k_len,
@@ -258,6 +259,7 @@ void LlamaContextAttentionLayer::fusedMultiHeadAttention(T** key_cache_ptr
                                                          size_t cache_layer_offset,
                                                          T*     attention_mask,
                                                          int*   cu_seqlens,
+                                                         int*   context_lengths,
                                                          int    batch_size,
                                                          int    max_q_len,
                                                          int    max_k_len,
@@ -274,13 +276,13 @@
                     int(size_per_head_),
                     int(max_seq_len * size_per_head_),
                     false,
-                    int(cache_layer_offset),
+                    cache_layer_offset,
                     key_cache_ptrs};
     Layout layout_v{int(local_head_num_ * max_seq_len * size_per_head_),
                     int(size_per_head_),
                     int(max_seq_len * size_per_head_),
                     false,
-                    int(cache_layer_offset),
+                    cache_layer_offset,
                     val_cache_ptrs};
     Layout layout_o{
         int(local_head_num_ * max_q_len * size_per_head_),
@@ -298,6 +300,8 @@
         qk_buf_float_,
         cu_seqlens,
         nullptr,
+        nullptr,
+        context_lengths,
         group_size,
         layout_q,
         layout_k,
diff --git a/src/turbomind/models/llama/LlamaContextAttentionLayer.h b/src/turbomind/models/llama/LlamaContextAttentionLayer.h
index 235b575b8..f79eaa4ef 100644
--- a/src/turbomind/models/llama/LlamaContextAttentionLayer.h
+++ b/src/turbomind/models/llama/LlamaContextAttentionLayer.h
@@ -72,6 +72,7 @@ class LlamaContextAttentionLayer {
                                   size_t cache_layer_offset,
                                   T*     attention_mask,
                                   int*   cu_seqlens,
+                                  int*   context_lengths,
                                   int    batch_size,
                                   int    max_q_len,
                                   int    max_k_len,
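Why the `int offset` above becomes `size_t`: the flattened per-sample KV-cache index folds in the layer offset, the KV-head index and `memory_max_len`, and for long caches the product no longer fits in 32 bits. A standalone sketch with hypothetical sizes (not taken from the patch):

```cpp
#include <cstdio>

int main()
{
    // Hypothetical sizes: 40 KV heads, a 16k-token cache, head dim 128,
    // looking at a slice that starts 26 "layers worth" of elements in.
    long long layer = 26, kv_heads = 40, memory_max_len = 16384, Dh = 128;

    long long elems = layer * kv_heads * memory_max_len * Dh;  // 2'181'038'080
    printf("flattened element index: %lld (INT_MAX is 2'147'483'647)\n", elems);
    printf("as int32 it would wrap to: %d\n", (int)elems);  // negative after wrap
    return 0;
}
```

The same concern is behind the other int-to-size_t widenings further down in this patch (`k_batched_offset`, `batch_seqs_offset` and friends).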
diff --git a/src/turbomind/models/llama/LlamaDecoderSelfAttentionLayer.cc b/src/turbomind/models/llama/LlamaDecoderSelfAttentionLayer.cc
index 3caaf5906..103b32e88 100644
--- a/src/turbomind/models/llama/LlamaDecoderSelfAttentionLayer.cc
+++ b/src/turbomind/models/llama/LlamaDecoderSelfAttentionLayer.cc
@@ -130,7 +130,7 @@ static inline void fusedQKV_masked_attention_dispatch(const T* qkv_buf,
     params.hidden_size_per_head     = size_per_head;
     params.rotary_embedding_dim     = rotary_embedding_dim;
-    params.rotary_embedding_base = rotary_embedding_base;
+    params.rotary_embedding_base    = rotary_embedding_base;
     params.max_position_embeddings  = max_position_embeddings;
     params.use_dynamic_ntk          = use_dynamic_ntk;
     params.use_logn_attn            = use_logn_attn;
diff --git a/src/turbomind/models/llama/LlamaV2.cc b/src/turbomind/models/llama/LlamaV2.cc
index beaf3c3f6..9c48e4f81 100644
--- a/src/turbomind/models/llama/LlamaV2.cc
+++ b/src/turbomind/models/llama/LlamaV2.cc
@@ -93,7 +93,8 @@ LlamaV2::LlamaV2(size_t head_num,
     TM_LOG_DEBUG(__PRETTY_FUNCTION__);
     TM_LOG_INFO("NCCL group_id = %d", tensor_para_.group_id_);
 
-    vocab_size_padded_ = (vocab_size_padded_ + tensor_para_.world_size_ - 1) / tensor_para_.world_size_ * tensor_para_.world_size_;
+    vocab_size_padded_ =
+        (vocab_size_padded_ + tensor_para_.world_size_ - 1) / tensor_para_.world_size_ * tensor_para_.world_size_;
 
     size_t elem_bits = 0;
     if (quant_policy & QuantPolicy::kCacheKVInt8) {
@@ -171,7 +172,7 @@ void LlamaV2::initialize(const LlamaAttentionParams& attn_params,
 
     dynamic_decode_layer_ = new DynamicDecodeLayer<float>(vocab_size_,
                                                           vocab_size_padded_,
-                                                          0, // end_id, deprecated
+                                                          0,  // end_id, deprecated
                                                           stream_,
                                                           cublas_wrapper_,
                                                           allocator_,
diff --git a/src/turbomind/models/llama/LlamaWeight.cc b/src/turbomind/models/llama/LlamaWeight.cc
index 511cbe5bb..80e561442 100644
--- a/src/turbomind/models/llama/LlamaWeight.cc
+++ b/src/turbomind/models/llama/LlamaWeight.cc
@@ -95,8 +95,10 @@ void LlamaWeight::loadModel(std::string dir_path)
 
     loadWeightFromBin((T*)output_norm_weight, {hidden_units_}, dir_path + "norm.weight", model_file_type);
 
-    loadWeightFromBin(
-        (T*)post_decoder_embedding_kernel, {hidden_units_ * vocab_size_padded_}, dir_path + "output.weight", model_file_type);
+    loadWeightFromBin((T*)post_decoder_embedding_kernel,
+                      {hidden_units_ * vocab_size_padded_},
+                      dir_path + "output.weight",
+                      model_file_type);
 
     for (unsigned layer = 0; layer < num_layer_; ++layer) {
         decoder_layer_weights[layer]->loadModel(dir_path + "layers." + std::to_string(layer), model_file_type);
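The reflowed expression above rounds `vocab_size_padded_` up to a multiple of the tensor-parallel world size so the output/embedding weights split evenly across ranks. A quick worked check of that arithmetic (the vocabulary size is hypothetical):

```cpp
#include <cstdio>

int main()
{
    size_t vocab_size = 32000;  // hypothetical vocabulary size
    for (size_t world_size : {1, 2, 4, 3}) {
        // Same round-up-to-multiple expression as in the constructor.
        size_t padded = (vocab_size + world_size - 1) / world_size * world_size;
        printf("world_size=%zu -> padded=%zu\n", world_size, padded);
    }
    // Prints 32000, 32000, 32000 and 32001: only a world size that does not
    // divide the vocabulary adds padding.
    return 0;
}
```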
diff --git a/src/turbomind/models/llama/flash_attention2/block_info.h b/src/turbomind/models/llama/flash_attention2/block_info.h
index 310d1f22b..38b6aa258 100644
--- a/src/turbomind/models/llama/flash_attention2/block_info.h
+++ b/src/turbomind/models/llama/flash_attention2/block_info.h
@@ -15,10 +15,14 @@ struct BlockInfo {
     __device__ BlockInfo(const Params& params, const int bidb):
         sum_s_q(!Varlen || params.cu_seqlens_q == nullptr ? -1 : params.cu_seqlens_q[bidb]),
         sum_s_k(!Varlen || params.cu_seqlens_k == nullptr ? -1 : params.cu_seqlens_k[bidb]),
-        actual_seqlen_q(!Varlen || params.cu_seqlens_q == nullptr ? params.seqlen_q :
-                                                                    params.cu_seqlens_q[bidb + 1] - sum_s_q),
-        actual_seqlen_k(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k :
-                                                                    params.cu_seqlens_k[bidb + 1] - sum_s_k)
+        actual_seqlen_q(params.actual_seqlen_q == nullptr ?
+                            (!Varlen || params.cu_seqlens_q == nullptr ? params.seqlen_q :
+                                                                         params.cu_seqlens_q[bidb + 1] - sum_s_q) :
+                            params.actual_seqlen_q[bidb]),
+        actual_seqlen_k(params.actual_seqlen_k == nullptr ?
+                            (!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k :
+                                                                         params.cu_seqlens_k[bidb + 1] - sum_s_k) :
+                            params.actual_seqlen_k[bidb])
     {
     }
 
diff --git a/src/turbomind/models/llama/flash_attention2/flash.h b/src/turbomind/models/llama/flash_attention2/flash.h
index 576cbc8d9..8a5a7c579 100644
--- a/src/turbomind/models/llama/flash_attention2/flash.h
+++ b/src/turbomind/models/llama/flash_attention2/flash.h
@@ -16,7 +16,7 @@ constexpr int D_DIM = 2;
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
 struct Qkv_params {
-    using index_t = uint32_t;
+    using index_t = size_t;
     // The QKV matrices.
     void* __restrict__ q_ptr;
     void* __restrict__ k_ptr;
@@ -25,8 +25,8 @@ struct Qkv_params {
     // batched ptr inputs.
     void** __restrict__ k_batched_ptr = nullptr;
     void** __restrict__ v_batched_ptr = nullptr;
-    int k_batched_offset = 0;
-    int v_batched_offset = 0;
+    size_t k_batched_offset = 0;
+    size_t v_batched_offset = 0;
 
     // The stride between rows of the Q, K and V matrices.
     index_t q_batch_stride;
@@ -72,6 +72,10 @@ struct Flash_fwd_params: public Qkv_params {
     int* __restrict__ cu_seqlens_q;
     int* __restrict__ cu_seqlens_k;
 
+    // array of length b with actual length of each sequence
+    int* __restrict__ actual_seqlen_q;
+    int* __restrict__ actual_seqlen_k;
+
     void* __restrict__ blockmask;
 
     bool is_bf16;
diff --git a/src/turbomind/models/llama/flash_attention2/flash_api.cpp b/src/turbomind/models/llama/flash_attention2/flash_api.cpp
index e2f12c723..55bc92c1f 100644
--- a/src/turbomind/models/llama/flash_attention2/flash_api.cpp
+++ b/src/turbomind/models/llama/flash_attention2/flash_api.cpp
@@ -121,6 +121,9 @@ class FlashAttentionOpImpl::impl {
         fwd_params.cu_seqlens_q = params.cu_seqlens_q;
         fwd_params.cu_seqlens_k = params.cu_seqlens_k;
 
+        fwd_params.actual_seqlen_q = params.actual_seqlen_q;
+        fwd_params.actual_seqlen_k = params.actual_seqlen_k;
+
         fwd_params.blockmask = reinterpret_cast<void*>(params.mask);
 
         fwd_params.is_bf16 = false;
diff --git a/src/turbomind/models/llama/fused_multi_head_attention/llama_flash_attention_kernel.cu b/src/turbomind/models/llama/fused_multi_head_attention/llama_flash_attention_kernel.cu
index 29035421c..4fae69bd0 100644
--- a/src/turbomind/models/llama/fused_multi_head_attention/llama_flash_attention_kernel.cu
+++ b/src/turbomind/models/llama/fused_multi_head_attention/llama_flash_attention_kernel.cu
@@ -70,10 +70,10 @@ struct LlamaAttentionKernel:
         scalar_t** v_batch_seqs_ptr = nullptr;
         output_t** o_batch_seqs_ptr = nullptr;
-        int q_batch_seqs_offset = 0;
-        int k_batch_seqs_offset = 0;
-        int v_batch_seqs_offset = 0;
-        int o_batch_seqs_offset = 0;
+        size_t q_batch_seqs_offset = 0;
+        size_t k_batch_seqs_offset = 0;
+        size_t v_batch_seqs_offset = 0;
+        size_t o_batch_seqs_offset = 0;
 
         int32_t group_size = 1;
 
@@ -81,7 +81,7 @@ struct LlamaAttentionKernel:
 
        template<typename ptr_t>
        CUTLASS_DEVICE void
-       update_batched_ptr(ptr_t& data_ptr, ptr_t* batch_seq_ptr, int batch_seq_offset, int batch_id, int strideB)
+       update_batched_ptr(ptr_t& data_ptr, ptr_t* batch_seq_ptr, size_t batch_seq_offset, int batch_id, int strideB)
        {
            if (batch_seq_ptr != nullptr)
                data_ptr = batch_seq_ptr[batch_id] + batch_seq_offset;
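The `BlockInfo` initializers above add a per-batch override: a non-null `actual_seqlen_q`/`actual_seqlen_k` array now takes precedence over both the `cu_seqlens` difference and the padded `seqlen_q`/`seqlen_k` defaults. A host-side restatement of that precedence, ignoring the `Varlen` template flag (illustrative only, not the device code):

```cpp
// Mirrors the ternaries in BlockInfo's initializer list, one sequence at a time.
inline int resolve_seqlen(const int* actual_seqlen,   // per-batch lengths, may be null
                          const int* cu_seqlens,      // prefix sums, may be null
                          int        default_seqlen,  // padded length from the params
                          int        b)
{
    if (actual_seqlen != nullptr) {
        return actual_seqlen[b];  // explicit per-sequence length wins (new in this change)
    }
    if (cu_seqlens == nullptr) {
        return default_seqlen;  // fixed-length (padded) batch
    }
    return cu_seqlens[b + 1] - cu_seqlens[b];  // varlen: difference of prefix sums
}
```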
diff --git a/src/turbomind/models/llama/llama_kernels.h b/src/turbomind/models/llama/llama_kernels.h
index 6bd4644f0..06cb24e04 100644
--- a/src/turbomind/models/llama/llama_kernels.h
+++ b/src/turbomind/models/llama/llama_kernels.h
@@ -80,12 +80,12 @@ void invokeMyCopyInt(int* dst, const int* src, size_t count, cudaStream_t st);
 
 template<typename T>
 struct BaseAttentionLayout {
-    int  stride_batch;
-    int  stride_seq;
-    int  stride_head;
-    bool use_seqlens       = false;
-    int  batch_seqs_offset = 0;
-    T**  batch_seqs        = nullptr;
+    int    stride_batch;
+    int    stride_seq;
+    int    stride_head;
+    bool   use_seqlens       = false;
+    size_t batch_seqs_offset = 0;
+    T**    batch_seqs        = nullptr;
 };
 
 template<typename T>
@@ -95,10 +95,12 @@ struct BaseAttentionParams {
     T*     key;
     T*     val;
     T*     mask;
-    float* out_accum    = nullptr;
-    int*   cu_seqlens_q = nullptr;
-    int*   cu_seqlens_k = nullptr;
-    size_t group_size   = 1;
+    float* out_accum       = nullptr;
+    int*   cu_seqlens_q    = nullptr;
+    int*   cu_seqlens_k    = nullptr;
+    int*   actual_seqlen_q = nullptr;
+    int*   actual_seqlen_k = nullptr;
+    size_t group_size      = 1;
     BaseAttentionLayout<T> layout_q;
     BaseAttentionLayout<T> layout_k;
     BaseAttentionLayout<T> layout_v;
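`BaseAttentionParams` now carries both length conventions: `cu_seqlens_*` are prefix sums for the packed, token-major layout, while the new `actual_seqlen_*` arrays are plain per-batch lengths (the context attention layer passes `nullptr` for `actual_seqlen_q` and its `context_lengths` for `actual_seqlen_k`). A hypothetical helper, not part of the TurboMind API, showing how the two encodings relate:

```cpp
#include <numeric>
#include <vector>

// Builds exclusive prefix sums from per-sequence lengths,
// e.g. {5, 3, 7} -> {0, 5, 8, 15}.
inline std::vector<int> to_cu_seqlens(const std::vector<int>& lengths)
{
    std::vector<int> cu(lengths.size() + 1, 0);
    std::partial_sum(lengths.begin(), lengths.end(), cu.begin() + 1);
    return cu;
}
```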
diff --git a/tests/csrc/unittests/test_context_attention_layer.cu b/tests/csrc/unittests/test_context_attention_layer.cu
index 948cd88a6..87693de34 100644
--- a/tests/csrc/unittests/test_context_attention_layer.cu
+++ b/tests/csrc/unittests/test_context_attention_layer.cu
@@ -278,6 +278,8 @@ int main(int argc, const char* argv[])
     // auto* input_lengths = (int*)allocator.malloc(sizeof(int) * batch_size, false);
     thrust::device_vector<int> input_lengths(batch_size);
     thrust::host_vector<int>   input_lengths_host(batch_size);
+    thrust::device_vector<int> kv_lengths(batch_size);
+    thrust::host_vector<int>   kv_lengths_host(batch_size);
 
     cudaRandomUniform(query_ptr, batch_size * num_heads * seq_len * size_per_head);
     cudaRandomUniform(key_ptr, batch_size * num_heads * key_len * size_per_head);
@@ -285,13 +287,12 @@ int main(int argc, const char* argv[])
     cudaRandomUniform(mask_ptr, batch_size * seq_len * key_len);
 
     // create random length for batch
-    std::uniform_int_distribution<int> dist{seq_len / 2, seq_len};
-    auto gen = [&dist, &mersenne_engine]() { return dist(mersenne_engine); };
-    std::generate(begin(input_lengths_host), end(input_lengths_host), gen);
-    // for(int batch_id=0;batch_id<batch_size;batch_id++){
-    //     ...
-    // }
-    thrust::copy(input_lengths_host.begin(), input_lengths_host.end(), input_lengths.begin());
+    {
+        std::uniform_int_distribution<int> dist{seq_len / 2, seq_len};
+        auto gen = [&dist, &mersenne_engine]() { return dist(mersenne_engine); };
+        std::generate(begin(input_lengths_host), end(input_lengths_host), gen);
+        thrust::copy(input_lengths_host.begin(), input_lengths_host.end(), input_lengths.begin());
+    }
     size_t h_token_num = 0;
     size_t* h_pinned_token_num;
     auto input_lengths_ptr = thrust::raw_pointer_cast(input_lengths.data());
@@ -306,10 +307,16 @@
                            stream);
     cudaFreeHost((void*)h_pinned_token_num);
 
-    int* k_lens = (int*)allocator.malloc(batch_size * sizeof(int));
-    deviceFill(k_lens, batch_size, key_len, stream);
+    {
+        std::uniform_int_distribution<int> dist{seq_len, key_len};
+        auto gen = [&dist, &mersenne_engine]() { return dist(mersenne_engine); };
+        std::generate(begin(kv_lengths_host), end(kv_lengths_host), gen);
+        thrust::copy(kv_lengths_host.begin(), kv_lengths_host.end(), kv_lengths.begin());
+    }
+    auto kv_lengths_ptr = thrust::raw_pointer_cast(kv_lengths.data());
+    // deviceFill(kv_lengths_ptr, batch_size, key_len, stream);
 
-    invokeCreateCausalMasks(mask_ptr, input_lengths_ptr, k_lens, seq_len, key_len, batch_size, stream);
+    invokeCreateCausalMasks(mask_ptr, input_lengths_ptr, kv_lengths_ptr, seq_len, key_len, batch_size, stream);
     // deviceFill(mask_ptr, batch_size*key_len*seq_len, scalar_t(1), stream);
 
     // compute gt
@@ -356,6 +363,8 @@
         accum_buf_ptr,
         cu_seqlens_ptr,
         nullptr,
+        nullptr,
+        kv_lengths_ptr,
         1,
         layout_q,
         layout_k,
@@ -367,10 +376,10 @@
     int num_rows = 8;
     // printf("query:\n");
     // printMatrix(query_ptr, num_rows, 8, size_per_head, true);
-    printf("expect:\n");
-    printMatrix(expect_out_ptr, num_rows, 8, size_per_head, true);
-    printf("actual:\n");
-    printMatrix(actual_out_ptr, num_rows, 8, size_per_head, true);
+    // printf("expect:\n");
+    // printMatrix(expect_out_ptr, num_rows, 8, size_per_head, true);
+    // printf("actual:\n");
+    // printMatrix(actual_out_ptr, num_rows, 8, size_per_head, true);
 
     checkResult(
         "all close:", actual_out_ptr, expect_out_ptr, batch_size * num_heads * seq_len * size_per_head, true, true);
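For reference, the mask the updated test builds: batch entry `b` gets a query length in `[seq_len / 2, seq_len]` and a KV length in `[seq_len, key_len]`, and query row `i` should only attend to key column `j` when `j <= i + (k_len - q_len)`. A host-side sketch of that rule under the usual 'history + causal' convention (the kernel behind `invokeCreateCausalMasks` may encode it differently):

```cpp
#include <vector>

// Reference mask: 1.0 where query row i may attend to key column j, 0.0 elsewhere.
// q_len/k_len are the per-sequence lengths; max_q_len/max_k_len the padded extents.
inline std::vector<float> make_causal_mask(int q_len, int k_len, int max_q_len, int max_k_len)
{
    std::vector<float> mask(static_cast<size_t>(max_q_len) * max_k_len, 0.f);
    for (int i = 0; i < q_len; ++i) {
        for (int j = 0; j < k_len; ++j) {
            if (j <= i + (k_len - q_len)) {  // history plus causal window
                mask[static_cast<size_t>(i) * max_k_len + j] = 1.f;
            }
        }
    }
    return mask;
}
```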