diff --git a/csrc/diopi_helper.h b/csrc/diopi_helper.h
index 063f072..b5e9c87 100644
--- a/csrc/diopi_helper.h
+++ b/csrc/diopi_helper.h
@@ -35,12 +35,7 @@ struct IsOptionalArithmetic<c10::optional<T>> : std::is_arithmetic<T> {};
 
 }  // namespace type_traits
 
-inline void checkTensorOnDevice(const at::Tensor& tensor) {
-  //if (tensor.device().type() == at::DeviceType::CPU) {
-  //  DIPU_LOGE("This op only runs on Device");
-  //  throw std::runtime_error("This op only runs on Device");
-  //}
-}
+inline void checkTensorOnDevice(const at::Tensor& tensor) {}
 
 inline void checkTensorOnDevice(const c10::optional<at::Tensor>& tensor) {
   if (tensor) {
diff --git a/csrc/extensions.cpp b/csrc/extensions.cpp
index d664259..742b221 100644
--- a/csrc/extensions.cpp
+++ b/csrc/extensions.cpp
@@ -341,46 +341,15 @@ void extTokenSoftmaxReduceVInference(const at::Tensor& logics,
             b_start_loc, b_seq_len, max_input_len, other_kv_index);
 }
 
-// void extTokenDecodeAttentionInference(const at::Tensor& q, const at::Tensor& k,
-//                                       const at::Tensor& v, at::Tensor& out,
-//                                       const at::Tensor& b_loc,
-//                                       const at::Tensor& b_start_loc,
-//                                       const at::Tensor& b_seq_len,
-//                                       int max_input_len, int other_kv_index) {
-//   callDiopi(diopiTokenDecodeAttentionInference, out, q, k, v, b_loc, b_start_loc,
-//             b_seq_len, max_input_len, other_kv_index);
-// }
-
-// void extTokenDecodeAttentionInferenceBatchOne(const at::Tensor& q, const at::Tensor& k,
-//                                               const at::Tensor& v, at::Tensor& out,
-//                                               const at::Tensor& b_loc,
-//                                               const at::Tensor& b_start_loc,
-//                                               const at::Tensor& b_seq_len,
-//                                               int max_input_len, int other_kv_index) {
-//   callDiopi(diopiTokenDecodeAttentionInferenceBatchOne, out, q, k, v, b_loc, b_start_loc,
-//             b_seq_len, max_input_len, other_kv_index);
-// }
-
-// void extIncreFlashAttention(const at::Tensor& q, const at::Tensor& k,
-//                             const at::Tensor& v, at::Tensor& out,
-//                             const int head, const char* layout,
-//                             const c10::optional<at::Tensor>& padding_mask = {},
-//                             const c10::optional<at::Tensor>& atten_mask = {},
-//                             const OptionalIntArray& actual_seq_lengths = {},
-//                             int64_t num_heads = 1, double scale_value = 1.0,
-//                             const std::string& input_layout = "BSH", int64_t num_key_value_heads = 0) {
-//   callDiopi(diopiIncreFlashAttention, out, q, k, v, padding_mask, atten_mask,
-//             actual_seq_lengths, num_heads, scale_value, input_layout.c_str(), num_key_value_heads);
-// }
-
 void extPromptFlashAttention(at::Tensor& out, const at::Tensor& q,
                              const at::Tensor& k, const at::Tensor& v,
                              const at::Tensor& atten_mask,
                              const at::IntArrayRef& actual_seq_lengths,
-                             int64_t max_input_len, int64_t num_heads,
+                             int64_t max_input_len, int64_t num_heads,
                              int64_t num_key_value_heads, int64_t dim) {
   callDiopi(diopiPromptFlashAttention, out, q, k, v, atten_mask,
-            actual_seq_lengths, max_input_len, num_heads, num_key_value_heads, dim);
+            actual_seq_lengths, max_input_len, num_heads, num_key_value_heads,
+            dim);
 }
 
 void extContextAttentionInference(const at::Tensor& q, const at::Tensor& k,
@@ -403,34 +372,39 @@ void extApplyPenalty(at::Tensor& logits, const at::Tensor& presence_penalty,
 }
 
 void extApplyPenaltyV2(at::Tensor& logits, const at::Tensor& presence_penalty,
-                     const at::Tensor& frequency_penalty,
-                     const at::Tensor& repetition_penalty,
-                     const at::Tensor& p_token_ids,
-                     const at::Tensor& p_token_counts) {
-  callDiopi(diopiApplyPenaltyV2, logits, presence_penalty, frequency_penalty, repetition_penalty,
-            p_token_ids, p_token_counts);
+                       const at::Tensor& frequency_penalty,
+                       const at::Tensor& repetition_penalty,
+                       const at::Tensor& p_token_ids,
+                       const at::Tensor& p_token_counts) {
+  callDiopi(diopiApplyPenaltyV2, logits, presence_penalty, frequency_penalty,
+            repetition_penalty, p_token_ids, p_token_counts);
 }
 
-void extPagedAttention(at::Tensor& out, const at::Tensor& q, const at::Tensor& k, const at::Tensor& v,
-                       const at::IntArrayRef& actual_seq_lengths,
-                       int64_t numHeads, int64_t numKeyValueHeads, int64_t dim,
-                       const at::Tensor& block_table,
-                       int64_t block_size) {
-  callDiopi(diopiPagedAttention, out, q, k, v, actual_seq_lengths,
-            numHeads, numKeyValueHeads, dim,
-            block_table, block_size);
+void extPagedAttention(at::Tensor& out, const at::Tensor& q,
+                       const at::Tensor& k, const at::Tensor& v,
+                       const c10::optional<at::Tensor>& atten_mask = {},
+                       const at::IntArrayRef& actual_seq_lengths = {},
+                       int64_t numHeads = 1, int64_t numKeyValueHeads = 1,
+                       int64_t dim = 1,
+                       const c10::optional<at::Tensor>& block_table = {},
+                       int64_t block_size = 1) {
+  callDiopi(diopiPagedAttention, out, q, k, v, atten_mask, actual_seq_lengths,
+            numHeads, numKeyValueHeads, dim, block_table, block_size);
 }
 
-void extRotaryEmbeddingV2(at::Tensor& query, at::Tensor& key, const at::Tensor& cos, const at::Tensor& sin, int64_t dim) {
+void extRotaryEmbeddingV2(at::Tensor& query, at::Tensor& key,
+                          const at::Tensor& cos, const at::Tensor& sin,
+                          int64_t dim) {
   callDiopi(diopiRotaryEmbeddingV2, query, key, cos, sin, dim);
 }
 
 void extMatmulAllReduce(at::Tensor& out, const at::Tensor& x1,
-                        const at::Tensor& x2, const c10::optional<at::Tensor>& bias,
+                        const at::Tensor& x2,
+                        const c10::optional<at::Tensor>& bias,
                         const char* group, const char* reduce_op,
                         int64_t comm_turn, int64_t stream_mode) {
-  callDiopi(diopiMatmulAllReduce, out, x1, x2,
-            bias, group, reduce_op, comm_turn, stream_mode);
+  callDiopi(diopiMatmulAllReduce, out, x1, x2, bias, group, reduce_op,
+            comm_turn, stream_mode);
 }
 
 // Check whether a corresponding diopi implementation exists:
@@ -501,18 +475,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
     m.def("token_softmax_reducev_inference", &extTokenSoftmaxReduceVInference,
           "deeplink ext_token_softmax_reducev_inference");
   }
-  // if (&diopiTokenDecodeAttentionInference != nullptr) {
-  //   m.def("token_decode_attention_inference", &extTokenDecodeAttentionInference,
-  //         "deeplink token_decode_attention_inference");
-  // }
-  // if (&diopiTokenDecodeAttentionInferenceBatchOne != nullptr) {
-  //   m.def("token_decode_attention_inference_batch_one", &extTokenDecodeAttentionInferenceBatchOne,
-  //         "deeplink token_decode_attention_inference");
-  // }
-  // if (&diopiIncreFlashAttention != nullptr) {
-  //   m.def("incre_flash_attention", &extIncreFlashAttention,
-  //         "deeplink incre_flash_attention");
-  // }
   if (&diopiPromptFlashAttention != nullptr) {
     m.def("prompt_flash_attention", &extPromptFlashAttention,
           "deeplink ext_prompt_flash_attention");
@@ -540,15 +502,15 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
           "deeplink ext_paged_attention");
   }
   if (&diopiRotaryEmbeddingV2 != nullptr) {
-    m.def("rotary_embedding_v2", &extRotaryEmbeddingV2, "deeplink extRotaryEmbeddingV2");
+    m.def("rotary_embedding_v2", &extRotaryEmbeddingV2,
+          "deeplink extRotaryEmbeddingV2");
  }
   if (&diopiMatmulAllReduce != nullptr) {
     m.def("matmul_all_reduce", &extMatmulAllReduce,
-          "deeplink ext_matmul_all_reduce",
-          py::arg("out"), py::arg("x1"),
-          py::arg("x2"), py::arg("bias"),
-          py::arg("group"), py::arg("reduce_op") = "sum",
-          py::arg("comm_turn") = 0, py::arg("stream_mode") = 1);
+          "deeplink ext_matmul_all_reduce", py::arg("out"), py::arg("x1"),
+          py::arg("x2"), py::arg("bias"), py::arg("group"),
+          py::arg("reduce_op") = "sum", py::arg("comm_turn") = 0,
+          py::arg("stream_mode") = 1);
   }
 }
diff --git a/deeplink_ext/patch_lightllm.py b/deeplink_ext/patch_lightllm.py
index 1121796..2a5c096 100644
--- a/deeplink_ext/patch_lightllm.py
+++ b/deeplink_ext/patch_lightllm.py
@@ -46,7 +46,9 @@ def patch_apply_penalty():
         apply_penalty_pack.apply_penalty_v2 = ext.apply_penalty_v2
 
     def patch_context_attention_inference():
-        def flash_context_attention(q, k, v, out, b_start_loc, b_seq_len, max_input_len):
+        def flash_context_attention(
+            q, k, v, out, b_start_loc, b_seq_len, max_input_len
+        ):
             batch, head, dim = b_start_loc.shape[0], q.shape[1], q.shape[2]
             numKeyValueHeads = k.shape[1]
             assert k.shape[1] == v.shape[1]
@@ -62,52 +64,72 @@ def flash_context_attention(q, k, v, out, b_start_loc, b_seq_len, max_input_len)
             single_out = out[start:end, :].view(1, single_seq_len, -1)
 
             if single_seq_len not in mask_cache:
-                mask = torch.tril(torch.ones(single_seq_len, single_seq_len, dtype=torch.bool), diagonal=0).cuda()
+                mask = torch.tril(
+                    torch.ones(
+                        single_seq_len, single_seq_len, dtype=torch.bool
+                    ),
+                    diagonal=0,
+                ).cuda()
                 mask = mask.repeat(1, 1, 1)
                 mask = torch.logical_not(mask)
                 mask_cache[single_seq_len] = mask
-                print(f"cache mask in context attention, seqLen:{single_seq_len}")
+                print(
+                    f"cache mask in context attention, seqLen:{single_seq_len}"
+                )
 
             mask = mask_cache[single_seq_len]
-            ext.prompt_flash_attention(single_out, single_q, single_k, single_v, None, mask, [], head, scale, 2147473647, 0, "BSH", numKeyValueHeads)
+            ext.prompt_flash_attention(
+                single_out,
+                single_q,
+                single_k,
+                single_v,
+                None,
+                mask,
+                [],
+                head,
+                scale,
+                2147473647,
+                0,
+                "BSH",
+                numKeyValueHeads,
+            )
         return out
 
-        # def fused_context_attention(out, q, k, v, mask, b_seq_len, max_input_len, head, numKeyValueHeads, dim):
-        #     batch = b_start_loc.shape[0]
-        #     scale = 1 / math.sqrt(dim)
-        #     mask_key_str = str(batch) + ":" + str(max_input_len)
-        #     if mask_key_str not in mask_cache:
-        #         mask = torch.tril(torch.ones(max_input_len, max_input_len, dtype=torch.bool), diagonal=0).cuda()
-        #         mask = mask.repeat(batch, 1, 1)
-        #         mask = torch.logical_not(mask)
-        #         mask_cache[mask_key_str] = mask
-        #         print(f"cache mask in context attention, batch:seqLen={mask_key_str}")
-
-        #     mask = mask_cache[mask_key_str]
-        #     ext.prompt_flash_attention(out, q, k, v,
-        #                                mask, b_seq_len, max_input_len, head, numKeyValueHeads, dim)
-        #     return out
-
-        # context_attention_pack.context_attention_fwd = (
-        #     # flash_context_attention
-        #     fused_context_attention
-        # )
         context_attention_pack.prompt_flash_attention = ext.prompt_flash_attention
 
     def patch_paged_token_attention_inference():
-        # def paged_token_attention(q, k_cache, v_cache, out, q_head_num, kv_head_num, head_dim, b_seq_len, block_table:torch.Tensor, block_size):
-        #     ext.paged_attention(out, q, k_cache, v_cache, None, None,
-        #                         b_seq_len, block_table, q_head_num, kv_head_num,
-        #                         1.0 / math.sqrt(head_dim), "BSH", block_size, 0,
-        #                         None, None, None, None, None, None, None, None
-        #                         )
-        #     return out
-
-        token_attention_pack.paged_token_attention = ext.paged_attention
+        def paged_token_attention(
+            out,
+            q,
+            k_cache,
+            v_cache,
+            b_seq_len,
+            q_head_num,
+            kv_head_num,
+            head_dim,
+            block_table,
+            block_size,
+        ):
+            ext.paged_attention(
+                out,
+                q,
+                k_cache,
+                v_cache,
+                None,
+                b_seq_len,
+                q_head_num,
+                kv_head_num,
+                head_dim,
+                block_table,
+                block_size,
+            )
+        token_attention_pack.paged_token_attention = paged_token_attention
 
     def patch_token_attention_inference():
         token_attention_pack.token_att_fwd = ext.token_attention_inference
-        token_attention_pack.token_decode_attention_fwd = ext.token_decode_attention_inference_batch_one#ext.token_decode_attention_inference
+        token_attention_pack.token_decode_attention_fwd = (
+            ext.token_decode_attention_inference_batch_one
+        )  # ext.token_decode_attention_inference
 
     def patch_token_softmax_reducev_inference():
         token_attention_softmax_reducev_pack.token_softmax_reducev_fwd = (