
Commit

fix
zhaochaoxing committed Jun 7, 2024
1 parent f45114f commit af6dbbe
Showing 2 changed files with 40 additions and 56 deletions.
42 changes: 13 additions & 29 deletions csrc/extensions.cpp
@@ -375,17 +375,12 @@ void extTokenSoftmaxReduceVInference(const at::Tensor& logics,

void extPromptFlashAttention(at::Tensor& out, const at::Tensor& q,
                             const at::Tensor& k, const at::Tensor& v,
-                            const c10::optional<at::Tensor>& padding_mask = {},
-                            const c10::optional<at::Tensor>& atten_mask = {},
-                            const at::IntArrayRef& actual_seq_lengths = {},
-                            int64_t num_heads = 1, double scale_value = 1.0,
-                            int64_t pre_tokens = 2147473647,
-                            int64_t next_tokens = 0,
-                            const std::string& input_layout = "BSH",
-                            int64_t num_key_value_heads = 0) {
-  callDiopi(diopiPromptFlashAttention, out, q, k, v, padding_mask, atten_mask,
-            actual_seq_lengths, num_heads, scale_value, pre_tokens,
-            next_tokens, input_layout.c_str(), num_key_value_heads);
+                            const at::Tensor& atten_mask,
+                            const at::IntArrayRef& actual_seq_lengths,
+                            int64_t max_input_len, int64_t num_heads,
+                            int64_t num_key_value_heads, int64_t dim) {
+  callDiopi(diopiPromptFlashAttention, out, q, k, v, atten_mask,
+            actual_seq_lengths, max_input_len, num_heads, num_key_value_heads, dim);
}

void extContextAttentionInference(const at::Tensor& q, const at::Tensor& k,
@@ -417,24 +412,13 @@ void extApplyPenaltyV2(at::Tensor& logits, const at::Tensor& presence_penalty,
}

void extPagedAttention(at::Tensor& out, const at::Tensor& q, const at::Tensor& k, const at::Tensor& v,
-                      const c10::optional<at::Tensor>& padding_mask = {},
-                      const c10::optional<at::Tensor>& atten_mask = {},
-                      const at::IntArrayRef& actual_seq_lengths = {},
-                      const c10::optional<at::Tensor>& block_table = {},
-                      int64_t num_heads = 1, int64_t num_key_value_heads = 0,
-                      double scale_value = 1.0, const std::string& input_layout = "BSH",
-                      int64_t block_size = 0, int64_t inner_precise = 1,
-                      const c10::optional<at::Tensor>& antiquant_scale = {}, const c10::optional<at::Tensor>& antiquant_offset = {},
-                      const c10::optional<at::Tensor>& dequant_scale1 = {}, const c10::optional<at::Tensor>& quant_scale1 = {},
-                      const c10::optional<at::Tensor>& dequant_scale2 = {}, const c10::optional<at::Tensor>& quant_scale2 = {},
-                      const c10::optional<at::Tensor>& quant_offset2 = {}, const c10::optional<at::Tensor>& kv_padding_size = {}
-                      ) {
-  callDiopi(diopiPagedAttention, out, q, k, v, padding_mask, atten_mask, actual_seq_lengths,
-            antiquant_scale, antiquant_offset,
-            block_table,
-            dequant_scale1, quant_scale1, dequant_scale2, quant_scale2, quant_offset2, kv_padding_size,
-            num_heads, scale_value, input_layout.c_str(), num_key_value_heads, block_size, inner_precise
-            );
+                      const at::IntArrayRef& actual_seq_lengths,
+                      int64_t numHeads, int64_t numKeyValueHeads, int64_t dim,
+                      const at::Tensor& block_table,
+                      int64_t block_size) {
+  callDiopi(diopiPagedAttention, out, q, k, v, actual_seq_lengths,
+            numHeads, numKeyValueHeads, dim,
+            block_table, block_size);
}

void extRotaryEmbeddingV2(at::Tensor& query, at::Tensor& key, const at::Tensor& cos, const at::Tensor& sin, int64_t dim) {
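For reference, the trimmed extPromptFlashAttention binding above drops the optional padding-mask, scale, pre/next-token, and layout arguments, so a caller now passes everything positionally in the new order (out, q, k, v, atten_mask, actual_seq_lengths, max_input_len, num_heads, num_key_value_heads, dim). Below is a minimal, hedged sketch of such a call from Python; the import path, helper name, and tensor handling are illustrative assumptions, not part of this commit.

import torch

# Assumed import path for the compiled extension; the real module name may differ.
import deeplink_ext.cpp_extensions as ext


def prompt_flash_attention_example(q, k, v, atten_mask, b_seq_len,
                                   max_input_len, head, num_kv_heads, dim):
    # Positional order mirrors the updated C++ signature:
    # out, q, k, v, atten_mask, actual_seq_lengths, max_input_len,
    # num_heads, num_key_value_heads, dim
    out = torch.empty_like(q)
    ext.prompt_flash_attention(out, q, k, v, atten_mask, b_seq_len,
                               max_input_len, head, num_kv_heads, dim)
    return out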
54 changes: 27 additions & 27 deletions deeplink_ext/patch_lightllm.py
@@ -71,38 +71,38 @@ def flash_context_attention(q, k, v, out, b_start_loc, b_seq_len, max_input_len)
ext.prompt_flash_attention(single_out, single_q, single_k, single_v, None, mask, [], head, scale, 2147473647, 0, "BSH", numKeyValueHeads)
return out

-    def fused_context_attention(q, k, v, out, b_start_loc, b_seq_len, max_input_len, head, numKeyValueHeads, dim):
-        batch = b_start_loc.shape[0]
-        scale = 1 / math.sqrt(dim)
-        mask_key_str = str(batch) + ":" + str(max_input_len)
-        if mask_key_str not in mask_cache:
-            mask = torch.tril(torch.ones(max_input_len, max_input_len, dtype=torch.bool), diagonal=0).cuda()
-            mask = mask.repeat(batch, 1, 1)
-            mask = torch.logical_not(mask)
-            mask_cache[mask_key_str] = mask
-            print(f"cache mask in context attention, batch:seqLen={mask_key_str}")
+    # def fused_context_attention(out, q, k, v, mask, b_seq_len, max_input_len, head, numKeyValueHeads, dim):
+    #     batch = b_start_loc.shape[0]
+    #     scale = 1 / math.sqrt(dim)
+    #     mask_key_str = str(batch) + ":" + str(max_input_len)
+    #     if mask_key_str not in mask_cache:
+    #         mask = torch.tril(torch.ones(max_input_len, max_input_len, dtype=torch.bool), diagonal=0).cuda()
+    #         mask = mask.repeat(batch, 1, 1)
+    #         mask = torch.logical_not(mask)
+    #         mask_cache[mask_key_str] = mask
+    #         print(f"cache mask in context attention, batch:seqLen={mask_key_str}")

-        mask = mask_cache[mask_key_str]
-        ext.prompt_flash_attention(out, q, k, v,
-                                   None, mask, b_seq_len, head, scale, 2147473647, 0, "BSH", numKeyValueHeads)
-        return out

-    context_attention_pack.context_attention_fwd = (
-        # flash_context_attention
-        fused_context_attention
-    )
+    #     mask = mask_cache[mask_key_str]
+    #     ext.prompt_flash_attention(out, q, k, v,
+    #                                mask, b_seq_len, max_input_len, head, numKeyValueHeads, dim)
+    #     return out

+    # context_attention_pack.context_attention_fwd = (
+    #     # flash_context_attention
+    #     fused_context_attention
+    # )
+    context_attention_pack.prompt_flash_attention = ext.prompt_flash_attention

def patch_paged_token_attention_inference():
-    def paged_token_attention(q, k_cache, v_cache, out, q_head_num, kv_head_num, head_dim, b_seq_len, block_table:torch.Tensor, block_size):
-        ext.paged_attention(out, q, k_cache, v_cache, None, None,
-                            b_seq_len, block_table, q_head_num, kv_head_num,
-                            1.0 / math.sqrt(head_dim), "BSH", block_size, 0,
-                            None, None, None, None, None, None, None, None
-        )
-        return out
+    # def paged_token_attention(q, k_cache, v_cache, out, q_head_num, kv_head_num, head_dim, b_seq_len, block_table:torch.Tensor, block_size):
+    #     ext.paged_attention(out, q, k_cache, v_cache, None, None,
+    #                         b_seq_len, block_table, q_head_num, kv_head_num,
+    #                         1.0 / math.sqrt(head_dim), "BSH", block_size, 0,
+    #                         None, None, None, None, None, None, None, None
+    #     )
+    #     return out

-    token_attention_pack.paged_token_attention = (paged_token_attention)
+    token_attention_pack.paged_token_attention = ext.paged_attention


def patch_token_attention_inference():
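With this patch, lightllm's paged token attention binds directly to ext.paged_attention instead of a local wrapper, so the caller is expected to supply arguments in the order of the new C++ extPagedAttention signature. A hedged sketch under that assumption; the import path and parameter names are illustrative, not taken from the repository.

import torch

# Assumed import path for the compiled extension; adjust to the real module.
import deeplink_ext.cpp_extensions as ext


def paged_attention_example(q, k_cache, v_cache, b_seq_len, q_head_num,
                            kv_head_num, head_dim, block_table, block_size):
    # Positional order follows the trimmed C++ binding:
    # out, q, k, v, actual_seq_lengths, num_heads, num_key_value_heads,
    # dim, block_table, block_size
    out = torch.empty_like(q)
    ext.paged_attention(out, q, k_cache, v_cache, b_seq_len,
                        q_head_num, kv_head_num, head_dim,
                        block_table, block_size)
    return out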
