add attenMask in paged_attention.
yao-fengchen committed Jun 24, 2024
1 parent af6dbbe commit a8a02e0
Showing 3 changed files with 89 additions and 110 deletions.
7 changes: 1 addition & 6 deletions csrc/diopi_helper.h
@@ -35,12 +35,7 @@ struct IsOptionalArithmetic<c10::optional<T>> : std::is_arithmetic<T> {};

} // namespace type_traits

inline void checkTensorOnDevice(const at::Tensor& tensor) {
//if (tensor.device().type() == at::DeviceType::CPU) {
// DIPU_LOGE("This op only runs on Device");
// throw std::runtime_error("This op only runs on Device");
//}
}
inline void checkTensorOnDevice(const at::Tensor& tensor) {}

inline void checkTensorOnDevice(const c10::optional<at::Tensor>& tensor) {
if (tensor) {
102 changes: 32 additions & 70 deletions csrc/extensions.cpp
@@ -341,46 +341,15 @@ void extTokenSoftmaxReduceVInference(const at::Tensor& logics,
b_start_loc, b_seq_len, max_input_len, other_kv_index);
}

// void extTokenDecodeAttentionInference(const at::Tensor& q, const at::Tensor& k,
// const at::Tensor& v, at::Tensor& out,
// const at::Tensor& b_loc,
// const at::Tensor& b_start_loc,
// const at::Tensor& b_seq_len,
// int max_input_len, int other_kv_index) {
// callDiopi(diopiTokenDecodeAttentionInference, out, q, k, v, b_loc, b_start_loc,
// b_seq_len, max_input_len, other_kv_index);
// }

// void extTokenDecodeAttentionInferenceBatchOne(const at::Tensor& q, const at::Tensor& k,
// const at::Tensor& v, at::Tensor& out,
// const at::Tensor& b_loc,
// const at::Tensor& b_start_loc,
// const at::Tensor& b_seq_len,
// int max_input_len, int other_kv_index) {
// callDiopi(diopiTokenDecodeAttentionInferenceBatchOne, out, q, k, v, b_loc, b_start_loc,
// b_seq_len, max_input_len, other_kv_index);
// }

// void extIncreFlashAttention(const at::Tensor& q, const at::Tensor& k,
// const at::Tensor& v, at::Tensor& out,
// const int head, const char* layout,
// const c10::optional<at::Tensor>& padding_mask = {},
// const c10::optional<at::Tensor>& atten_mask = {},
// const OptionalIntArray& actual_seq_lengths = {},
// int64_t num_heads = 1, double scale_value = 1.0,
// const std::string& input_layout = "BSH", int64_t num_key_value_heads = 0) {
// callDiopi(diopiIncreFlashAttention, out, q, k, v, padding_mask, atten_mask,
// actual_seq_lengths, num_heads, scale_value, input_layout.c_str(), num_key_value_heads);
// }

void extPromptFlashAttention(at::Tensor& out, const at::Tensor& q,
const at::Tensor& k, const at::Tensor& v,
const at::Tensor& atten_mask,
const at::IntArrayRef& actual_seq_lengths,
int64_t max_input_len, int64_t num_heads,
int64_t num_key_value_heads, int64_t dim) {
callDiopi(diopiPromptFlashAttention, out, q, k, v, atten_mask,
actual_seq_lengths, max_input_len, num_heads, num_key_value_heads, dim);
actual_seq_lengths, max_input_len, num_heads, num_key_value_heads,
dim);
}

void extContextAttentionInference(const at::Tensor& q, const at::Tensor& k,
@@ -403,34 +372,39 @@ void extApplyPenalty(at::Tensor& logits, const at::Tensor& presence_penalty,
}

void extApplyPenaltyV2(at::Tensor& logits, const at::Tensor& presence_penalty,
const at::Tensor& frequency_penalty,
const at::Tensor& repetition_penalty,
const at::Tensor& p_token_ids,
const at::Tensor& p_token_counts) {
callDiopi(diopiApplyPenaltyV2, logits, presence_penalty, frequency_penalty, repetition_penalty,
p_token_ids, p_token_counts);
const at::Tensor& frequency_penalty,
const at::Tensor& repetition_penalty,
const at::Tensor& p_token_ids,
const at::Tensor& p_token_counts) {
callDiopi(diopiApplyPenaltyV2, logits, presence_penalty, frequency_penalty,
repetition_penalty, p_token_ids, p_token_counts);
}

void extPagedAttention(at::Tensor& out, const at::Tensor& q, const at::Tensor& k, const at::Tensor& v,
const at::IntArrayRef& actual_seq_lengths,
int64_t numHeads, int64_t numKeyValueHeads, int64_t dim,
const at::Tensor& block_table,
int64_t block_size) {
callDiopi(diopiPagedAttention, out, q, k, v, actual_seq_lengths,
numHeads, numKeyValueHeads, dim,
block_table, block_size);
void extPagedAttention(at::Tensor& out, const at::Tensor& q,
const at::Tensor& k, const at::Tensor& v,
const c10::optional<at::Tensor>& atten_mask = {},
const at::IntArrayRef& actual_seq_lengths = {},
int64_t num_heads = 1, int64_t num_kv_heads = 1,
int64_t dim = 1,
const c10::optional<at::Tensor>& block_table = {},
int64_t block_size = 1) {
callDiopi(diopiPagedAttention, out, q, k, v, atten_mask, actual_seq_lengths,
num_heads, num_kv_heads, dim, block_table, block_size);
}

void extRotaryEmbeddingV2(at::Tensor& query, at::Tensor& key, const at::Tensor& cos, const at::Tensor& sin, int64_t dim) {
void extRotaryEmbeddingV2(at::Tensor& query, at::Tensor& key,
const at::Tensor& cos, const at::Tensor& sin,
int64_t dim) {
callDiopi(diopiRotaryEmbeddingV2, query, key, cos, sin, dim);
}

void extMatmulAllReduce(at::Tensor& out, const at::Tensor& x1,
const at::Tensor& x2, const c10::optional<at::Tensor>& bias,
const at::Tensor& x2,
const c10::optional<at::Tensor>& bias,
const char* group, const char* reduce_op,
int64_t comm_turn, int64_t stream_mode) {
callDiopi(diopiMatmulAllReduce, out, x1, x2,
bias, group, reduce_op, comm_turn, stream_mode);
callDiopi(diopiMatmulAllReduce, out, x1, x2, bias, group, reduce_op,
comm_turn, stream_mode);
}

// Check whether a corresponding diopi implementation exists:
@@ -501,18 +475,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("token_softmax_reducev_inference", &extTokenSoftmaxReduceVInference,
"deeplink ext_token_softmax_reducev_inference");
}
// if (&diopiTokenDecodeAttentionInference != nullptr) {
// m.def("token_decode_attention_inference", &extTokenDecodeAttentionInference,
// "deeplink token_decode_attention_inference");
// }
// if (&diopiTokenDecodeAttentionInferenceBatchOne != nullptr) {
// m.def("token_decode_attention_inference_batch_one", &extTokenDecodeAttentionInferenceBatchOne,
// "deeplink token_decode_attention_inference");
// }
// if (&diopiIncreFlashAttention != nullptr) {
// m.def("incre_flash_attention", &extIncreFlashAttention,
// "deeplink incre_flash_attention");
// }
if (&diopiPromptFlashAttention != nullptr) {
m.def("prompt_flash_attention", &extPromptFlashAttention,
"deeplink ext_prompt_flash_attention");
@@ -540,15 +502,15 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
"deeplink ext_paged_attention");
}
if (&diopiRotaryEmbeddingV2 != nullptr) {
m.def("rotary_embedding_v2", &extRotaryEmbeddingV2, "deeplink extRotaryEmbeddingV2");
m.def("rotary_embedding_v2", &extRotaryEmbeddingV2,
"deeplink extRotaryEmbeddingV2");
}
if (&diopiMatmulAllReduce != nullptr) {
m.def("matmul_all_reduce", &extMatmulAllReduce,
"deeplink ext_matmul_all_reduce",
py::arg("out"), py::arg("x1"),
py::arg("x2"), py::arg("bias"),
py::arg("group"), py::arg("reduce_op") = "sum",
py::arg("comm_turn") = 0, py::arg("stream_mode") = 1);
"deeplink ext_matmul_all_reduce", py::arg("out"), py::arg("x1"),
py::arg("x2"), py::arg("bias"), py::arg("group"),
py::arg("reduce_op") = "sum", py::arg("comm_turn") = 0,
py::arg("stream_mode") = 1);
}
}

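For orientation, a minimal Python sketch of how the reworked paged_attention binding above could be driven once the extension is built. The import path, tensor shapes, and the boolean-mask convention (True marks positions that must not be attended, mirroring the torch.tril/logical_not pattern in patch_lightllm.py) are assumptions, not part of this commit; only the argument order follows extPagedAttention.

# Hypothetical usage sketch; everything except the argument order is an assumption.
import torch
import deeplink_ext.cpp_extensions as ext  # assumed import path for the built extension

def paged_attention_with_mask(out, q, k_cache, v_cache, seq_lens,
                              num_heads, num_kv_heads, head_dim,
                              block_table, block_size, max_seq_len):
    # seq_lens: plain list of per-request sequence lengths (the IntArrayRef argument).
    # Build an inverted causal mask: True = do not attend.
    causal = torch.tril(torch.ones(max_seq_len, max_seq_len, dtype=torch.bool)).cuda()
    atten_mask = torch.logical_not(causal)
    # Argument order mirrors extPagedAttention(out, q, k, v, atten_mask,
    # actual_seq_lengths, num_heads, num_kv_heads, dim, block_table, block_size).
    ext.paged_attention(out, q, k_cache, v_cache, atten_mask, seq_lens,
                        num_heads, num_kv_heads, head_dim,
                        block_table, block_size)
    return out
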
90 changes: 56 additions & 34 deletions deeplink_ext/patch_lightllm.py
@@ -46,7 +46,9 @@ def patch_apply_penalty():
apply_penalty_pack.apply_penalty_v2 = ext.apply_penalty_v2

def patch_context_attention_inference():
def flash_context_attention(q, k, v, out, b_start_loc, b_seq_len, max_input_len):
def flash_context_attention(
q, k, v, out, b_start_loc, b_seq_len, max_input_len
):
batch, head, dim = b_start_loc.shape[0], q.shape[1], q.shape[2]
numKeyValueHeads = k.shape[1]
assert k.shape[1] == v.shape[1]
@@ -62,52 +64,72 @@ def flash_context_attention(q, k, v, out, b_start_loc, b_seq_len, max_input_len)

single_out = out[start:end, :].view(1, single_seq_len, -1)
if single_seq_len not in mask_cache:
mask = torch.tril(torch.ones(single_seq_len, single_seq_len, dtype=torch.bool), diagonal=0).cuda()
mask = torch.tril(
torch.ones(
single_seq_len, single_seq_len, dtype=torch.bool
),
diagonal=0,
).cuda()
mask = mask.repeat(1, 1, 1)
mask = torch.logical_not(mask)
mask_cache[single_seq_len] = mask
print(f"cache mask in context attention, seqLen:{single_seq_len}")
print(
f"cache mask in context attention, seqLen:{single_seq_len}"
)
mask = mask_cache[single_seq_len]
ext.prompt_flash_attention(single_out, single_q, single_k, single_v, None, mask, [], head, scale, 2147473647, 0, "BSH", numKeyValueHeads)
ext.prompt_flash_attention(
single_out,
single_q,
single_k,
single_v,
None,
mask,
[],
head,
scale,
2147473647,
0,
"BSH",
numKeyValueHeads,
)
return out

# def fused_context_attention(out, q, k, v, mask, b_seq_len, max_input_len, head, numKeyValueHeads, dim):
# batch = b_start_loc.shape[0]
# scale = 1 / math.sqrt(dim)
# mask_key_str = str(batch) + ":" + str(max_input_len)
# if mask_key_str not in mask_cache:
# mask = torch.tril(torch.ones(max_input_len, max_input_len, dtype=torch.bool), diagonal=0).cuda()
# mask = mask.repeat(batch, 1, 1)
# mask = torch.logical_not(mask)
# mask_cache[mask_key_str] = mask
# print(f"cache mask in context attention, batch:seqLen={mask_key_str}")

# mask = mask_cache[mask_key_str]
# ext.prompt_flash_attention(out, q, k, v,
# mask, b_seq_len, max_input_len, head, numKeyValueHeads, dim)
# return out

# context_attention_pack.context_attention_fwd = (
# # flash_context_attention
# fused_context_attention
# )
context_attention_pack.prompt_flash_attention = ext.prompt_flash_attention

def patch_paged_token_attention_inference():
# def paged_token_attention(q, k_cache, v_cache, out, q_head_num, kv_head_num, head_dim, b_seq_len, block_table:torch.Tensor, block_size):
# ext.paged_attention(out, q, k_cache, v_cache, None, None,
# b_seq_len, block_table, q_head_num, kv_head_num,
# 1.0 / math.sqrt(head_dim), "BSH", block_size, 0,
# None, None, None, None, None, None, None, None
# )
# return out

token_attention_pack.paged_token_attention = ext.paged_attention
def paged_token_attention(
out,
q,
k_cache,
v_cache,
b_seq_len,
q_head_num,
kv_head_num,
head_dim,
block_table,
block_size,
):
ext.paged_attention(
out,
q,
k_cache,
v_cache,
None,
b_seq_len,
q_head_num,
kv_head_num,
head_dim,
block_table,
block_size,
)

token_attention_pack.paged_token_attention = paged_token_attention

def patch_token_attention_inference():
token_attention_pack.token_att_fwd = ext.token_attention_inference
token_attention_pack.token_decode_attention_fwd = ext.token_decode_attention_inference_batch_one#ext.token_decode_attention_inference
token_attention_pack.token_decode_attention_fwd = (
ext.token_decode_attention_inference_batch_one
) # ext.token_decode_attention_inference

def patch_token_softmax_reducev_inference():
token_attention_softmax_reducev_pack.token_softmax_reducev_fwd = (
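As a closing note on the patch_lightllm.py changes: flash_context_attention builds the inverted causal mask once per distinct sequence length and reuses it via mask_cache. A standalone sketch of that caching idea, with a hypothetical helper name, might look like this:

# Illustrative sketch only; the helper name and the (1, seq_len, seq_len) shape are assumptions.
import torch

_mask_cache = {}

def get_inverted_causal_mask(seq_len):
    # True marks positions that must not be attended, matching the
    # torch.tril + logical_not convention used in flash_context_attention.
    if seq_len not in _mask_cache:
        allowed = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool)).cuda()
        _mask_cache[seq_len] = torch.logical_not(allowed).unsqueeze(0)  # (1, seq_len, seq_len)
    return _mask_cache[seq_len]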
