diff --git a/csrc/extensions.cpp b/csrc/extensions.cpp
index 186f792..d664259 100644
--- a/csrc/extensions.cpp
+++ b/csrc/extensions.cpp
@@ -375,17 +375,12 @@ void extTokenSoftmaxReduceVInference(const at::Tensor& logics,
 
 void extPromptFlashAttention(at::Tensor& out, const at::Tensor& q,
                              const at::Tensor& k, const at::Tensor& v,
-                             const c10::optional<at::Tensor>& padding_mask = {},
-                             const c10::optional<at::Tensor>& atten_mask = {},
-                             const at::IntArrayRef& actual_seq_lengths = {},
-                             int64_t num_heads = 1, double scale_value = 1.0,
-                             int64_t pre_tokens = 2147473647,
-                             int64_t next_tokens = 0,
-                             const std::string& input_layout = "BSH",
-                             int64_t num_key_value_heads = 0) {
-  callDiopi(diopiPromptFlashAttention, out, q, k, v, padding_mask, atten_mask,
-            actual_seq_lengths, num_heads, scale_value, pre_tokens,
-            next_tokens, input_layout.c_str(), num_key_value_heads);
+                             const at::Tensor& atten_mask,
+                             const at::IntArrayRef& actual_seq_lengths,
+                             int64_t max_input_len, int64_t num_heads,
+                             int64_t num_key_value_heads, int64_t dim) {
+  callDiopi(diopiPromptFlashAttention, out, q, k, v, atten_mask,
+            actual_seq_lengths, max_input_len, num_heads, num_key_value_heads, dim);
 }
 
 void extContextAttentionInference(const at::Tensor& q, const at::Tensor& k,
@@ -417,24 +412,13 @@ void extApplyPenaltyV2(at::Tensor& logits, const at::Tensor& presence_penalty,
 }
 
 void extPagedAttention(at::Tensor& out, const at::Tensor& q, const at::Tensor& k, const at::Tensor& v,
-                       const c10::optional<at::Tensor>& padding_mask = {},
-                       const c10::optional<at::Tensor>& atten_mask = {},
-                       const at::IntArrayRef& actual_seq_lengths = {},
-                       const c10::optional<at::Tensor>& block_table = {},
-                       int64_t num_heads = 1, int64_t num_key_value_heads = 0,
-                       double scale_value = 1.0, const std::string& input_layout = "BSH",
-                       int64_t block_size = 0, int64_t inner_precise = 1,
-                       const c10::optional<at::Tensor>& antiquant_scale = {}, const c10::optional<at::Tensor>& antiquant_offset = {},
-                       const c10::optional<at::Tensor>& dequant_scale1 = {}, const c10::optional<at::Tensor>& quant_scale1 = {},
-                       const c10::optional<at::Tensor>& dequant_scale2 = {}, const c10::optional<at::Tensor>& quant_scale2 = {},
-                       const c10::optional<at::Tensor>& quant_offset2 = {}, const c10::optional<at::Tensor>& kv_padding_size = {}
-                       ) {
-  callDiopi(diopiPagedAttention, out, q, k, v, padding_mask, atten_mask, actual_seq_lengths,
-            antiquant_scale, antiquant_offset,
-            block_table,
-            dequant_scale1, quant_scale1, dequant_scale2, quant_scale2, quant_offset2, kv_padding_size,
-            num_heads, scale_value, input_layout.c_str(), num_key_value_heads, block_size, inner_precise
-            );
+                       const at::IntArrayRef& actual_seq_lengths,
+                       int64_t numHeads, int64_t numKeyValueHeads, int64_t dim,
+                       const at::Tensor& block_table,
+                       int64_t block_size) {
+  callDiopi(diopiPagedAttention, out, q, k, v, actual_seq_lengths,
+            numHeads, numKeyValueHeads, dim,
+            block_table, block_size);
 }
 
 void extRotaryEmbeddingV2(at::Tensor& query, at::Tensor& key, const at::Tensor& cos, const at::Tensor& sin, int64_t dim) {
diff --git a/deeplink_ext/patch_lightllm.py b/deeplink_ext/patch_lightllm.py
index 5d343c5..1121796 100644
--- a/deeplink_ext/patch_lightllm.py
+++ b/deeplink_ext/patch_lightllm.py
@@ -71,38 +71,38 @@ def flash_context_attention(q, k, v, out, b_start_loc, b_seq_len, max_input_len)
             ext.prompt_flash_attention(single_out, single_q, single_k, single_v,
                                        None, mask, [], head, scale, 2147473647, 0, "BSH", numKeyValueHeads)
         return out
-    def fused_context_attention(q, k, v, out, b_start_loc, b_seq_len, max_input_len, head, numKeyValueHeads, dim):
-        batch = b_start_loc.shape[0]
-        scale = 1 / math.sqrt(dim)
-        mask_key_str = str(batch) + ":" + str(max_input_len)
-        if mask_key_str not in mask_cache:
-            mask = torch.tril(torch.ones(max_input_len, max_input_len, dtype=torch.bool), diagonal=0).cuda()
-            mask = mask.repeat(batch, 1, 1)
-            mask = torch.logical_not(mask)
-            mask_cache[mask_key_str] = mask
-            print(f"cache mask in context attention, batch:seqLen={mask_key_str}")
+    # def fused_context_attention(out, q, k, v, mask, b_seq_len, max_input_len, head, numKeyValueHeads, dim):
+    #     batch = b_start_loc.shape[0]
+    #     scale = 1 / math.sqrt(dim)
+    #     mask_key_str = str(batch) + ":" + str(max_input_len)
+    #     if mask_key_str not in mask_cache:
+    #         mask = torch.tril(torch.ones(max_input_len, max_input_len, dtype=torch.bool), diagonal=0).cuda()
+    #         mask = mask.repeat(batch, 1, 1)
+    #         mask = torch.logical_not(mask)
+    #         mask_cache[mask_key_str] = mask
+    #         print(f"cache mask in context attention, batch:seqLen={mask_key_str}")
 
-        mask = mask_cache[mask_key_str]
-        ext.prompt_flash_attention(out, q, k, v,
-                                   None, mask, b_seq_len, head, scale, 2147473647, 0, "BSH", numKeyValueHeads)
-        return out
-
-    context_attention_pack.context_attention_fwd = (
-        # flash_context_attention
-        fused_context_attention
-    )
+    #     mask = mask_cache[mask_key_str]
+    #     ext.prompt_flash_attention(out, q, k, v,
+    #                                mask, b_seq_len, max_input_len, head, numKeyValueHeads, dim)
+    #     return out
+    # context_attention_pack.context_attention_fwd = (
+    #     # flash_context_attention
+    #     fused_context_attention
+    # )
+    context_attention_pack.prompt_flash_attention = ext.prompt_flash_attention
 
 
 def patch_paged_token_attention_inference():
-    def paged_token_attention(q, k_cache, v_cache, out, q_head_num, kv_head_num, head_dim, b_seq_len, block_table:torch.Tensor, block_size):
-        ext.paged_attention(out, q, k_cache, v_cache, None, None,
-                            b_seq_len, block_table, q_head_num, kv_head_num,
-                            1.0 / math.sqrt(head_dim), "BSH", block_size, 0,
-                            None, None, None, None, None, None, None, None
-                            )
-        return out
+    # def paged_token_attention(q, k_cache, v_cache, out, q_head_num, kv_head_num, head_dim, b_seq_len, block_table:torch.Tensor, block_size):
+    #     ext.paged_attention(out, q, k_cache, v_cache, None, None,
+    #                         b_seq_len, block_table, q_head_num, kv_head_num,
+    #                         1.0 / math.sqrt(head_dim), "BSH", block_size, 0,
+    #                         None, None, None, None, None, None, None, None
+    #                         )
+    #     return out
 
-    token_attention_pack.paged_token_attention = (paged_token_attention)
+    token_attention_pack.paged_token_attention = ext.paged_attention
 
 
 def patch_token_attention_inference():
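
Note (not part of the patch): a minimal sketch of calling the slimmed-down prefill binding from Python. The argument order follows the new extPromptFlashAttention signature above; the import path deeplink_ext.cpp_extensions, the device placement via .cuda(), and the tensor shapes are illustrative assumptions, not something this diff specifies.

    import torch
    import deeplink_ext.cpp_extensions as ext  # assumed import path for the bindings

    # Illustrative sizes; real values come from lightllm's batch state.
    batch, max_input_len, num_heads, num_key_value_heads, dim = 2, 16, 32, 32, 128

    # Assumed packed [total_tokens, heads * dim] layout for q/k/v/out; the diff does
    # not pin down the expected layout, so treat these shapes as placeholders.
    q = torch.randn(batch * max_input_len, num_heads * dim).cuda()
    k = torch.randn(batch * max_input_len, num_key_value_heads * dim).cuda()
    v = torch.randn(batch * max_input_len, num_key_value_heads * dim).cuda()
    out = torch.empty_like(q)

    # Causal mask built the same way as the commented-out fused_context_attention helper.
    mask = torch.tril(torch.ones(max_input_len, max_input_len, dtype=torch.bool), diagonal=0).cuda()
    mask = torch.logical_not(mask.repeat(batch, 1, 1))

    b_seq_len = [max_input_len] * batch  # per-sequence prompt lengths (actual_seq_lengths)

    # New binding: (out, q, k, v, atten_mask, actual_seq_lengths, max_input_len,
    # num_heads, num_key_value_heads, dim); scale/pre_tokens/next_tokens/layout are gone.
    ext.prompt_flash_attention(out, q, k, v, mask, b_seq_len,
                               max_input_len, num_heads, num_key_value_heads, dim)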
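
A decode-side call would follow the same pattern, matching the new extPagedAttention parameter list: ext.paged_attention(out, q, k_cache, v_cache, actual_seq_lengths, numHeads, numKeyValueHeads, dim, block_table, block_size), where block_table and block_size describe the paged KV cache layout as in the previous paged_token_attention helper.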