[Fix] Support actual seqlen in flash-attention2 (#418)
* support actual seqlen

* fix lint

* update variable types

* lint

* update type

* fix lint

---------
q.yao authored Sep 18, 2023
1 parent 3a7880a commit abe9f7b
Showing 12 changed files with 75 additions and 45 deletions.
@@ -1422,8 +1422,8 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
// Trigger the stores to global memory.
if (Dh == Dh_MAX || co < Dh / QK_ELTS_IN_16B) {

int offset = params.kv_cache_per_sample_offset + kvhi * params.memory_max_len * Dh + tlength_circ * Dh
+ co * QK_ELTS_IN_16B + ci;
size_t offset = params.kv_cache_per_sample_offset + kvhi * params.memory_max_len * Dh
+ tlength_circ * Dh + co * QK_ELTS_IN_16B + ci;

if (!QUANT_POLICY) {
*reinterpret_cast<Qk_vec_m*>(&params.k_cache_per_sample[bi][offset]) =
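The only functional change in this hunk is widening the KV-cache offset from int to size_t. A quick back-of-the-envelope check of why that matters, with illustrative sizes (not values from this diff) and assuming the per-sample offset scales with layers × KV heads × max cached length × head size:

#include <cstdint>
#include <cstdio>

int main()
{
    // Illustrative model/config sizes, chosen only to show the overflow.
    const size_t num_layers     = 40;
    const size_t local_kv_heads = 40;
    const size_t memory_max_len = 32768;  // cached tokens per sequence
    const size_t Dh             = 128;    // size per head

    // kv_cache_per_sample_offset covers all preceding layers of one sample,
    // before kvhi * memory_max_len * Dh + tlength_circ * Dh + ... is added on top.
    const size_t per_sample_offset = num_layers * local_kv_heads * memory_max_len * Dh;

    // ~6.7e9 elements: a 32-bit `int offset` would already have wrapped around here.
    std::printf("per-sample offset = %zu, INT32_MAX = %d\n", per_sample_offset, INT32_MAX);
    return 0;
}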
8 changes: 6 additions & 2 deletions src/turbomind/models/llama/LlamaContextAttentionLayer.cc
@@ -215,6 +215,7 @@ inline void LlamaContextAttentionLayer<T>::forward(TensorMap*
layer_offset,
attention_mask,
cu_seqlens,
input_tensors->at("context_lengths").getPtr<int>(),
batch_size,
max_q_len,
max_k_len,
@@ -258,6 +259,7 @@ void LlamaContextAttentionLayer<T>::fusedMultiHeadAttention(T** key_cache_ptr
size_t cache_layer_offset,
T* attention_mask,
int* cu_seqlens,
int* context_lengths,
int batch_size,
int max_q_len,
int max_k_len,
@@ -274,13 +276,13 @@ void LlamaContextAttentionLayer<T>::fusedMultiHeadAttention(T** key_cache_ptr
int(size_per_head_),
int(max_seq_len * size_per_head_),
false,
int(cache_layer_offset),
cache_layer_offset,
key_cache_ptrs};
Layout layout_v{int(local_head_num_ * max_seq_len * size_per_head_),
int(size_per_head_),
int(max_seq_len * size_per_head_),
false,
int(cache_layer_offset),
cache_layer_offset,
val_cache_ptrs};
Layout layout_o{
int(local_head_num_ * max_q_len * size_per_head_),
@@ -298,6 +300,8 @@ void LlamaContextAttentionLayer<T>::fusedMultiHeadAttention(T** key_cache_ptr
qk_buf_float_,
cu_seqlens,
nullptr,
nullptr,
context_lengths,
group_size,
layout_q,
layout_k,
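For readers skimming the hunks above: the layer now reads the per-request context_lengths tensor and hands it to fusedMultiHeadAttention, which forwards it as the actual KV length of each sequence. A trimmed-down sketch of that wiring (the struct below is a stand-in for BaseAttentionParams<T> from llama_kernels.h, showing only the fields this commit touches):

#include <cstddef>

// Stand-in for the relevant part of BaseAttentionParams<T>.
struct AttentionParamsSketch {
    int*   cu_seqlens_q    = nullptr;  // prefix sums of the packed query lengths
    int*   cu_seqlens_k    = nullptr;
    int*   actual_seqlen_q = nullptr;  // per-sequence overrides, one int per batch entry
    int*   actual_seqlen_k = nullptr;
    size_t group_size      = 1;
};

// Mirrors how forward() / fusedMultiHeadAttention() now fill the params;
// context_lengths is input_tensors->at("context_lengths").getPtr<int>().
void wire_actual_seqlens(AttentionParamsSketch& p, int* cu_seqlens, int* context_lengths)
{
    p.cu_seqlens_q    = cu_seqlens;       // unchanged
    p.cu_seqlens_k    = nullptr;          // unchanged
    p.actual_seqlen_q = nullptr;          // query lengths still come from cu_seqlens_q
    p.actual_seqlen_k = context_lengths;  // new: true KV length of each sequence
}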
1 change: 1 addition & 0 deletions src/turbomind/models/llama/LlamaContextAttentionLayer.h
@@ -72,6 +72,7 @@ class LlamaContextAttentionLayer {
size_t cache_layer_offset,
T* attention_mask,
int* cu_seqlens,
int* context_lengths,
int batch_size,
int max_q_len,
int max_k_len,
@@ -130,7 +130,7 @@ static inline void fusedQKV_masked_attention_dispatch(const T* qkv_buf,

params.hidden_size_per_head = size_per_head;
params.rotary_embedding_dim = rotary_embedding_dim;
params.rotary_embedding_base = rotary_embedding_base;
params.max_position_embeddings = max_position_embeddings;
params.use_dynamic_ntk = use_dynamic_ntk;
params.use_logn_attn = use_logn_attn;
5 changes: 3 additions & 2 deletions src/turbomind/models/llama/LlamaV2.cc
@@ -93,7 +93,8 @@ LlamaV2<T>::LlamaV2(size_t head_num,
TM_LOG_DEBUG(__PRETTY_FUNCTION__);
TM_LOG_INFO("NCCL group_id = %d", tensor_para_.group_id_);

vocab_size_padded_ = (vocab_size_padded_ + tensor_para_.world_size_ - 1) / tensor_para_.world_size_ * tensor_para_.world_size_;
vocab_size_padded_ =
(vocab_size_padded_ + tensor_para_.world_size_ - 1) / tensor_para_.world_size_ * tensor_para_.world_size_;

size_t elem_bits = 0;
if (quant_policy & QuantPolicy::kCacheKVInt8) {
@@ -171,7 +172,7 @@ void LlamaV2<T>::initialize(const LlamaAttentionParams& attn_params,

dynamic_decode_layer_ = new DynamicDecodeLayer<float>(vocab_size_,
vocab_size_padded_,
0, // end_id, deprecated
stream_,
cublas_wrapper_,
allocator_,
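The reformatted statement in the first LlamaV2.cc hunk is the usual round-up-to-a-multiple idiom: vocab_size_padded_ is padded so it divides evenly across tensor-parallel ranks. A small check with made-up numbers:

#include <cstddef>
#include <cstdio>

int main()
{
    // Made-up values; the real ones come from the model and the tensor_para_ config.
    size_t vocab_size_padded = 32000;
    size_t world_size        = 3;

    // (v + w - 1) / w * w rounds v up to the next multiple of w.
    vocab_size_padded = (vocab_size_padded + world_size - 1) / world_size * world_size;

    std::printf("padded vocab = %zu\n", vocab_size_padded);  // 32001 = 10667 * 3
    return 0;
}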
6 changes: 4 additions & 2 deletions src/turbomind/models/llama/LlamaWeight.cc
@@ -95,8 +95,10 @@ void LlamaWeight<T>::loadModel(std::string dir_path)

loadWeightFromBin((T*)output_norm_weight, {hidden_units_}, dir_path + "norm.weight", model_file_type);

loadWeightFromBin(
(T*)post_decoder_embedding_kernel, {hidden_units_ * vocab_size_padded_}, dir_path + "output.weight", model_file_type);
loadWeightFromBin((T*)post_decoder_embedding_kernel,
{hidden_units_ * vocab_size_padded_},
dir_path + "output.weight",
model_file_type);

for (unsigned layer = 0; layer < num_layer_; ++layer) {
decoder_layer_weights[layer]->loadModel(dir_path + "layers." + std::to_string(layer), model_file_type);
12 changes: 8 additions & 4 deletions src/turbomind/models/llama/flash_attention2/block_info.h
@@ -15,10 +15,14 @@ struct BlockInfo {
__device__ BlockInfo(const Params& params, const int bidb):
sum_s_q(!Varlen || params.cu_seqlens_q == nullptr ? -1 : params.cu_seqlens_q[bidb]),
sum_s_k(!Varlen || params.cu_seqlens_k == nullptr ? -1 : params.cu_seqlens_k[bidb]),
actual_seqlen_q(!Varlen || params.cu_seqlens_q == nullptr ? params.seqlen_q :
params.cu_seqlens_q[bidb + 1] - sum_s_q),
actual_seqlen_k(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k :
params.cu_seqlens_k[bidb + 1] - sum_s_k)
actual_seqlen_q(params.actual_seqlen_q == nullptr ?
(!Varlen || params.cu_seqlens_q == nullptr ? params.seqlen_q :
params.cu_seqlens_q[bidb + 1] - sum_s_q) :
params.actual_seqlen_q[bidb]),
actual_seqlen_k(params.actual_seqlen_k == nullptr ?
(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k :
params.cu_seqlens_k[bidb + 1] - sum_s_k) :
params.actual_seqlen_k[bidb])
{
}

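Restating the new BlockInfo logic outside of its initializer-list ternaries: each sequence's effective length is resolved from, in order, the explicit actual_seqlen_* array, then the cu_seqlens_* prefix sums, then the uniform seqlen_*. The same order written as a plain standalone function for readability:

// Resolution order used for both actual_seqlen_q and actual_seqlen_k.
int resolve_seqlen(bool varlen,
                   const int* actual_seqlen,  // per-sequence lengths (length b), may be nullptr
                   const int* cu_seqlens,     // cumulative offsets (length b + 1), may be nullptr
                   int        uniform_seqlen, // params.seqlen_q / params.seqlen_k
                   int        bidb)           // batch index
{
    if (actual_seqlen != nullptr) {
        return actual_seqlen[bidb];                      // explicit per-sequence length wins
    }
    if (varlen && cu_seqlens != nullptr) {
        return cu_seqlens[bidb + 1] - cu_seqlens[bidb];  // derive from prefix sums
    }
    return uniform_seqlen;                               // fixed length for the whole batch
}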
10 changes: 7 additions & 3 deletions src/turbomind/models/llama/flash_attention2/flash.h
@@ -16,7 +16,7 @@ constexpr int D_DIM = 2;
////////////////////////////////////////////////////////////////////////////////////////////////////

struct Qkv_params {
using index_t = uint32_t;
using index_t = size_t;
// The QKV matrices.
void* __restrict__ q_ptr;
void* __restrict__ k_ptr;
@@ -25,8 +25,8 @@ struct Qkv_params {
// batched ptr inputs.
void** __restrict__ k_batched_ptr = nullptr;
void** __restrict__ v_batched_ptr = nullptr;
int k_batched_offset = 0;
int v_batched_offset = 0;
size_t k_batched_offset = 0;
size_t v_batched_offset = 0;

// The stride between rows of the Q, K and V matrices.
index_t q_batch_stride;
@@ -72,6 +72,10 @@ struct Flash_fwd_params: public Qkv_params {
int* __restrict__ cu_seqlens_q;
int* __restrict__ cu_seqlens_k;

// array of length b with actual length of each sequence
int* __restrict__ actual_seqlen_q;
int* __restrict__ actual_seqlen_k;

void* __restrict__ blockmask;

bool is_bf16;
3 changes: 3 additions & 0 deletions src/turbomind/models/llama/flash_attention2/flash_api.cpp
@@ -121,6 +121,9 @@ class FlashAttentionOpImpl<T, FMHA_VERSION>::impl {
fwd_params.cu_seqlens_q = params.cu_seqlens_q;
fwd_params.cu_seqlens_k = params.cu_seqlens_k;

fwd_params.actual_seqlen_q = params.actual_seqlen_q;
fwd_params.actual_seqlen_k = params.actual_seqlen_k;

fwd_params.blockmask = reinterpret_cast<void*>(params.mask);

fwd_params.is_bf16 = false;
@@ -70,18 +70,18 @@ struct LlamaAttentionKernel:
scalar_t** v_batch_seqs_ptr = nullptr;
output_t** o_batch_seqs_ptr = nullptr;

int q_batch_seqs_offset = 0;
int k_batch_seqs_offset = 0;
int v_batch_seqs_offset = 0;
int o_batch_seqs_offset = 0;
size_t q_batch_seqs_offset = 0;
size_t k_batch_seqs_offset = 0;
size_t v_batch_seqs_offset = 0;
size_t o_batch_seqs_offset = 0;

int32_t group_size = 1;

float scale;

template<typename ptr_t>
CUTLASS_DEVICE void
update_batched_ptr(ptr_t& data_ptr, ptr_t* batch_seq_ptr, int batch_seq_offset, int batch_id, int strideB)
update_batched_ptr(ptr_t& data_ptr, ptr_t* batch_seq_ptr, size_t batch_seq_offset, int batch_id, int strideB)
{
if (batch_seq_ptr != nullptr)
data_ptr = batch_seq_ptr[batch_id] + batch_seq_offset;
22 changes: 12 additions & 10 deletions src/turbomind/models/llama/llama_kernels.h
@@ -80,12 +80,12 @@ void invokeMyCopyInt(int* dst, const int* src, size_t count, cudaStream_t st);

template<typename T>
struct BaseAttentionLayout {
int stride_batch;
int stride_seq;
int stride_head;
bool use_seqlens = false;
int batch_seqs_offset = 0;
T** batch_seqs = nullptr;
int stride_batch;
int stride_seq;
int stride_head;
bool use_seqlens = false;
size_t batch_seqs_offset = 0;
T** batch_seqs = nullptr;
};

template<typename T>
@@ -95,10 +95,12 @@ struct BaseAttentionParams {
T* key;
T* val;
T* mask;
float* out_accum = nullptr;
int* cu_seqlens_q = nullptr;
int* cu_seqlens_k = nullptr;
size_t group_size = 1;
float* out_accum = nullptr;
int* cu_seqlens_q = nullptr;
int* cu_seqlens_k = nullptr;
int* actual_seqlen_q = nullptr;
int* actual_seqlen_k = nullptr;
size_t group_size = 1;
BaseAttentionLayout<T> layout_q;
BaseAttentionLayout<T> layout_k;
BaseAttentionLayout<T> layout_v;
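Both new BaseAttentionParams fields default to nullptr, so callers that never set them keep the pre-commit cu_seqlens behaviour; only call sites that pass a per-sequence array (the context attention layer above, the unit test below) opt in. The batch_seqs_offset widening follows the same overflow reasoning as the kernel hunk at the top: with per-sample KV-cache pointers, the offset added to batch_seqs[b] is the per-sample cache layer offset, which no longer fits in an int for large models. A rough sketch of how the layout's base pointer is presumably selected (an assumption drawn from how layout_k/layout_v are filled above and from update_batched_ptr, not a copy of repository code):

#include <cstddef>

// Dense tensors use data + b * stride_batch; per-sample KV caches use the
// pointer array plus the (now size_t) cache-layer offset.
template<typename T>
T* sequence_base(T* data, T** batch_seqs, size_t batch_seqs_offset,
                 size_t stride_batch, int b)
{
    if (batch_seqs != nullptr) {
        return batch_seqs[b] + batch_seqs_offset;  // e.g. key_cache_ptrs[b] + cache_layer_offset
    }
    return data + static_cast<size_t>(b) * stride_batch;
}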
37 changes: 23 additions & 14 deletions tests/csrc/unittests/test_context_attention_layer.cu
@@ -278,20 +278,21 @@ int main(int argc, const char* argv[])
// auto* input_lengths = (int*)allocator.malloc(sizeof(int) * batch_size, false);
thrust::device_vector<int> input_lengths(batch_size);
thrust::host_vector<int> input_lengths_host(batch_size);
thrust::device_vector<int> kv_lengths(batch_size);
thrust::host_vector<int> kv_lengths_host(batch_size);

cudaRandomUniform<scalar_t>(query_ptr, batch_size * num_heads * seq_len * size_per_head);
cudaRandomUniform<scalar_t>(key_ptr, batch_size * num_heads * key_len * size_per_head);
cudaRandomUniform<scalar_t>(val_ptr, batch_size * num_heads * key_len * size_per_head);
cudaRandomUniform<scalar_t>(mask_ptr, batch_size * seq_len * key_len);

// create random length for batch
std::uniform_int_distribution<int> dist{seq_len / 2, seq_len};
auto gen = [&dist, &mersenne_engine]() { return dist(mersenne_engine); };
std::generate(begin(input_lengths_host), end(input_lengths_host), gen);
// for(int batch_id=0;batch_id<batch_size;++batch_id){
// input_lengths_host[batch_id] = seq_len;
// }
thrust::copy(input_lengths_host.begin(), input_lengths_host.end(), input_lengths.begin());
{
std::uniform_int_distribution<int> dist{seq_len / 2, seq_len};
auto gen = [&dist, &mersenne_engine]() { return dist(mersenne_engine); };
std::generate(begin(input_lengths_host), end(input_lengths_host), gen);
thrust::copy(input_lengths_host.begin(), input_lengths_host.end(), input_lengths.begin());
}
size_t h_token_num = 0;
size_t* h_pinned_token_num;
auto input_lengths_ptr = thrust::raw_pointer_cast(input_lengths.data());
@@ -306,10 +307,16 @@ int main(int argc, const char* argv[])
stream);
cudaFreeHost((void*)h_pinned_token_num);

int* k_lens = (int*)allocator.malloc(batch_size * sizeof(int));
deviceFill(k_lens, batch_size, key_len, stream);
{
std::uniform_int_distribution<int> dist{seq_len, key_len};
auto gen = [&dist, &mersenne_engine]() { return dist(mersenne_engine); };
std::generate(begin(kv_lengths_host), end(kv_lengths_host), gen);
thrust::copy(kv_lengths_host.begin(), kv_lengths_host.end(), kv_lengths.begin());
}
auto kv_lengths_ptr = thrust::raw_pointer_cast(kv_lengths.data());
// deviceFill(kv_lengths_ptr, batch_size, key_len, stream);

invokeCreateCausalMasks(mask_ptr, input_lengths_ptr, k_lens, seq_len, key_len, batch_size, stream);
invokeCreateCausalMasks(mask_ptr, input_lengths_ptr, kv_lengths_ptr, seq_len, key_len, batch_size, stream);
// deviceFill(mask_ptr, batch_size*key_len*seq_len, scalar_t(1), stream);

// compute gt
@@ -356,6 +363,8 @@ int main(int argc, const char* argv[])
accum_buf_ptr,
cu_seqlens_ptr,
nullptr,
nullptr,
kv_lengths_ptr,
1,
layout_q,
layout_k,
@@ -367,10 +376,10 @@ int main(int argc, const char* argv[])
int num_rows = 8;
// printf("query:\n");
// printMatrix(query_ptr, num_rows, 8, size_per_head, true);
printf("expect:\n");
printMatrix(expect_out_ptr, num_rows, 8, size_per_head, true);
printf("actual:\n");
printMatrix(actual_out_ptr, num_rows, 8, size_per_head, true);
// printf("expect:\n");
// printMatrix(expect_out_ptr, num_rows, 8, size_per_head, true);
// printf("actual:\n");
// printMatrix(actual_out_ptr, num_rows, 8, size_per_head, true);
checkResult(
"all close:", actual_out_ptr, expect_out_ptr, batch_size * num_heads * seq_len * size_per_head, true, true);

