Disable attention mask when it is not needed (#813)
* disable attention mask when not needed

* fix for sm<80 and float data type
lzhangzz authored Dec 11, 2023
1 parent d5a8946 commit b8354da
Showing 3 changed files with 31 additions and 10 deletions.
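
In short, the commit only allocates and fills the causal attention mask when the attention path cannot apply causal masking on its own (no fused MHA, a pre-sm80 GPU, or a non-16-bit data type). A minimal standalone sketch of that decision, using assumed parameter names rather than the repository's API (the actual check is in the unified_decoder.h hunk at the bottom of this diff):

    #include <cstdio>

    // Sketch of the gating condition introduced by this commit (assumed names,
    // not turbomind's API). elem_size stands in for sizeof(T): 2 means half/bf16.
    bool need_causal_mask(bool use_fmha, int sm_version, int elem_size)
    {
        // An explicit mask is only required when the fused attention kernel is
        // not used, the GPU is older than sm80, or the data type is not 16-bit.
        return !use_fmha || sm_version < 80 || elem_size != 2;
    }

    int main()
    {
        std::printf("%d\n", need_causal_mask(true, 80, 2));   // 0: mask skipped
        std::printf("%d\n", need_causal_mask(true, 75, 2));   // 1: sm < 80
        std::printf("%d\n", need_causal_mask(true, 80, 4));   // 1: float data type
        std::printf("%d\n", need_causal_mask(false, 80, 2));  // 1: unfused MHA
    }

When the flag is false, allocateBuffer keeps only a one-element placeholder so attention_mask_ stays non-null, and the invokeCreateCausalMasks launch in forward is skipped entirely.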
src/turbomind/models/llama/LlamaBatch.cc: 2 additions, 0 deletions
@@ -894,6 +894,8 @@ LlamaBatch<T>::LlamaBatch(const EngineParams& params, int cache_block_seq_len, i
         session_len_ = max_session_len;
     }
 
+    FT_CHECK(max_context_token_num_ >= session_len_);
+
     for (auto& s : states_) {
         s.requests.resize(max_batch_size_);
         s.sequences.resize(max_batch_size_);
src/turbomind/models/llama/unified_decoder.cc: 18 additions, 10 deletions
@@ -15,8 +15,14 @@ void UnifiedDecoder<T>::allocateBuffer(size_t num_token, size_t pf_batch_size, s
     TM_LOG_DEBUG(__PRETTY_FUNCTION__);
 
     if (pf_batch_size) {
-        attention_mask_ =
-            (T*)allocator_->reMalloc(attention_mask_, sizeof(T) * pf_batch_size * pf_max_q_len * pf_max_k_len, false);
+        if (need_causal_mask_) {
+            attention_mask_ = (T*)allocator_->reMalloc(
+                attention_mask_, sizeof(T) * pf_batch_size * pf_max_q_len * pf_max_k_len, false);
+        }
+        else {
+            // just to avoid nullptr
+            attention_mask_ = (T*)allocator_->reMalloc(attention_mask_, sizeof(T), false);
+        }
         padding_offset_ =
             (int*)allocator_->reMalloc(padding_offset_, sizeof(int) * pf_batch_size * pf_max_q_len, false);
         cu_seqlens_ = (int*)allocator_->reMalloc(cu_seqlens_, sizeof(int) * (pf_batch_size + 1), false);
@@ -162,14 +168,16 @@ void UnifiedDecoder<T>::forward(TensorMap* outputs, const TensorMap* inputs, con
 
         FT_CHECK(tmp_token_num == token_num - dc_batch_size);
 
-        invokeCreateCausalMasks(attention_mask_,
-                                input_length + pf_offset,
-                                context_length + pf_offset,
-                                pf_max_q_len,
-                                pf_max_k_len,
-                                pf_batch_size,
-                                stream_);
-        sync_check_cuda_error();
+        if (need_causal_mask_) {
+            invokeCreateCausalMasks(attention_mask_,
+                                    input_length + pf_offset,
+                                    context_length + pf_offset,
+                                    pf_max_q_len,
+                                    pf_max_k_len,
+                                    pf_batch_size,
+                                    stream_);
+            sync_check_cuda_error();
+        }
     }
 
     /////////////////////////////////////////////
src/turbomind/models/llama/unified_decoder.h: 11 additions, 0 deletions
@@ -5,6 +5,7 @@
 #include "src/turbomind/models/llama/llama_params.h"
 #include "src/turbomind/models/llama/unified_attention_layer.h"
 #include "src/turbomind/utils/cublasMMWrapper.h"
+#include "src/turbomind/utils/cuda_utils.h"
 #include "src/turbomind/utils/nccl_utils.h"
 
 namespace turbomind {
@@ -46,6 +47,8 @@ class UnifiedDecoder {
 
     const DataType dtype_;
 
+    bool need_causal_mask_{false};
+
     using WeightType = LlamaDecoderLayerWeight<T>;
 
     void forwardSelfAttn(T* attn_io,
@@ -88,6 +91,14 @@ class UnifiedDecoder {
         tensor_para_(tensor_para),
         dtype_(getTensorType<T>())
     {
+#ifdef _MSC_VER
+        // Both unfused MHA and flash attention 1 need causal mask
+        need_causal_mask_ = true;
+#endif
+        // attention mask is not used for FA-1 (which requires sm80+ and half/bf16 data type)
+        if (!use_fmha || (getSMVersion() < 80 || sizeof(T) != 2)) {
+            need_causal_mask_ = true;
+        }
         initialize(attn_params, kv_head_num, use_fmha, cache_block_seq_len, quant_policy);
     }
 
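As an aside for readers without the codebase: getSMVersion() comes from the newly included cuda_utils.h, and an SM version in this style is conventionally derived from the device's compute capability. A standalone equivalent, offered only as a sketch and not as turbomind's implementation:

    #include <cuda_runtime.h>
    #include <cstdio>

    // Returns major*10 + minor, e.g. 80 for sm_80 (A100), 75 for sm_75 (T4).
    // Sketch only; the repository's getSMVersion() may differ in details.
    int sm_version(int device = 0)
    {
        cudaDeviceProp prop{};
        cudaGetDeviceProperties(&prop, device);
        return prop.major * 10 + prop.minor;
    }

    int main()
    {
        std::printf("SM version: %d\n", sm_version());
    }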
