feat: support some ops for lightllm and lmdeploy #115

Open · wants to merge 1 commit into base: llama2_infer_910b
7 changes: 1 addition & 6 deletions csrc/diopi_helper.h
@@ -35,12 +35,7 @@ struct IsOptionalArithmetic<c10::optional<T>> : std::is_arithmetic<T> {};

} // namespace type_traits

inline void checkTensorOnDevice(const at::Tensor& tensor) {
//if (tensor.device().type() == at::DeviceType::CPU) {
// DIPU_LOGE("This op only runs on Device");
// throw std::runtime_error("This op only runs on Device");
//}
}
inline void checkTensorOnDevice(const at::Tensor& tensor) {}

inline void checkTensorOnDevice(const c10::optional<at::Tensor>& tensor) {
if (tensor) {
102 changes: 32 additions & 70 deletions csrc/extensions.cpp
@@ -341,46 +341,15 @@ void extTokenSoftmaxReduceVInference(const at::Tensor& logics,
b_start_loc, b_seq_len, max_input_len, other_kv_index);
}

// void extTokenDecodeAttentionInference(const at::Tensor& q, const at::Tensor& k,
// const at::Tensor& v, at::Tensor& out,
// const at::Tensor& b_loc,
// const at::Tensor& b_start_loc,
// const at::Tensor& b_seq_len,
// int max_input_len, int other_kv_index) {
// callDiopi(diopiTokenDecodeAttentionInference, out, q, k, v, b_loc, b_start_loc,
// b_seq_len, max_input_len, other_kv_index);
// }

// void extTokenDecodeAttentionInferenceBatchOne(const at::Tensor& q, const at::Tensor& k,
// const at::Tensor& v, at::Tensor& out,
// const at::Tensor& b_loc,
// const at::Tensor& b_start_loc,
// const at::Tensor& b_seq_len,
// int max_input_len, int other_kv_index) {
// callDiopi(diopiTokenDecodeAttentionInferenceBatchOne, out, q, k, v, b_loc, b_start_loc,
// b_seq_len, max_input_len, other_kv_index);
// }

// void extIncreFlashAttention(const at::Tensor& q, const at::Tensor& k,
// const at::Tensor& v, at::Tensor& out,
// const int head, const char* layout,
// const c10::optional<at::Tensor>& padding_mask = {},
// const c10::optional<at::Tensor>& atten_mask = {},
// const OptionalIntArray& actual_seq_lengths = {},
// int64_t num_heads = 1, double scale_value = 1.0,
// const std::string& input_layout = "BSH", int64_t num_key_value_heads = 0) {
// callDiopi(diopiIncreFlashAttention, out, q, k, v, padding_mask, atten_mask,
// actual_seq_lengths, num_heads, scale_value, input_layout.c_str(), num_key_value_heads);
// }

void extPromptFlashAttention(at::Tensor& out, const at::Tensor& q,
const at::Tensor& k, const at::Tensor& v,
const at::Tensor& atten_mask,
const at::IntArrayRef& actual_seq_lengths,
int64_t max_input_len, int64_t num_heads,
int64_t num_key_value_heads, int64_t dim) {
callDiopi(diopiPromptFlashAttention, out, q, k, v, atten_mask,
actual_seq_lengths, max_input_len, num_heads, num_key_value_heads, dim);
actual_seq_lengths, max_input_len, num_heads, num_key_value_heads,
dim);
}

void extContextAttentionInference(const at::Tensor& q, const at::Tensor& k,
@@ -403,34 +372,39 @@ void extApplyPenalty(at::Tensor& logits, const at::Tensor& presence_penalty,
}

void extApplyPenaltyV2(at::Tensor& logits, const at::Tensor& presence_penalty,
const at::Tensor& frequency_penalty,
const at::Tensor& repetition_penalty,
const at::Tensor& p_token_ids,
const at::Tensor& p_token_counts) {
callDiopi(diopiApplyPenaltyV2, logits, presence_penalty, frequency_penalty, repetition_penalty,
p_token_ids, p_token_counts);
const at::Tensor& frequency_penalty,
const at::Tensor& repetition_penalty,
const at::Tensor& p_token_ids,
const at::Tensor& p_token_counts) {
callDiopi(diopiApplyPenaltyV2, logits, presence_penalty, frequency_penalty,
repetition_penalty, p_token_ids, p_token_counts);
}
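
A minimal usage sketch for the V2 penalty wrapper above, assuming the op is exported to Python as ext.apply_penalty_v2 (the name patch_lightllm.py wires up below); the import path, tensor shapes and dtypes are illustrative assumptions, not the lightllm contract.

import torch
import deeplink_ext.cpp_extensions as ext  # import path is an assumption

vocab_size = 8
logits = torch.randn(2, vocab_size).cuda()                           # [batch, vocab], updated in place
presence_penalty = torch.full((2,), 0.2).cuda()
frequency_penalty = torch.full((2,), 0.1).cuda()
repetition_penalty = torch.full((2,), 1.1).cuda()
p_token_ids = torch.tensor([1, 3, 5], dtype=torch.int32).cuda()      # previously generated token ids
p_token_counts = torch.tensor([2, 1, 4], dtype=torch.int32).cuda()   # how often each id occurred

# Argument order mirrors the C++ wrapper above.
ext.apply_penalty_v2(logits, presence_penalty, frequency_penalty,
                     repetition_penalty, p_token_ids, p_token_counts)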

void extPagedAttention(at::Tensor& out, const at::Tensor& q, const at::Tensor& k, const at::Tensor& v,
const at::IntArrayRef& actual_seq_lengths,
int64_t numHeads, int64_t numKeyValueHeads, int64_t dim,
const at::Tensor& block_table,
int64_t block_size) {
callDiopi(diopiPagedAttention, out, q, k, v, actual_seq_lengths,
numHeads, numKeyValueHeads, dim,
block_table, block_size);
void extPagedAttention(at::Tensor& out, const at::Tensor& q,
const at::Tensor& k, const at::Tensor& v,
const c10::optional<at::Tensor>& atten_mask = {},
const at::IntArrayRef& actual_seq_lengths = {},
int64_t num_heads = 1, int64_t num_kv_heads = 1,
int64_t dim = 1,
const c10::optional<at::Tensor>& block_table = {},
int64_t block_size = 1) {
callDiopi(diopiPagedAttention, out, q, k, v, atten_mask, actual_seq_lengths,
num_heads, num_kv_heads, dim, block_table, block_size);
}

void extRotaryEmbeddingV2(at::Tensor& query, at::Tensor& key, const at::Tensor& cos, const at::Tensor& sin, int64_t dim) {
void extRotaryEmbeddingV2(at::Tensor& query, at::Tensor& key,
const at::Tensor& cos, const at::Tensor& sin,
int64_t dim) {
callDiopi(diopiRotaryEmbeddingV2, query, key, cos, sin, dim);
}
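
A minimal sketch of calling the v2 rotary embedding entry point, assuming the pybind name rotary_embedding_v2 registered below; the flattened [tokens, heads * head_dim] layout and half-precision dtypes are assumptions.

import torch
import deeplink_ext.cpp_extensions as ext  # import path is an assumption

tokens, heads, head_dim = 16, 32, 128
query = torch.randn(tokens, heads * head_dim, dtype=torch.float16).cuda()
key = torch.randn(tokens, heads * head_dim, dtype=torch.float16).cuda()
cos = torch.randn(tokens, head_dim // 2, dtype=torch.float16).cuda()
sin = torch.randn(tokens, head_dim // 2, dtype=torch.float16).cuda()

# query and key are rotated in place; the last argument is the per-head rotary dimension.
ext.rotary_embedding_v2(query, key, cos, sin, head_dim)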

void extMatmulAllReduce(at::Tensor& out, const at::Tensor& x1,
const at::Tensor& x2, const c10::optional<at::Tensor>& bias,
const at::Tensor& x2,
const c10::optional<at::Tensor>& bias,
const char* group, const char* reduce_op,
int64_t comm_turn, int64_t stream_mode) {
callDiopi(diopiMatmulAllReduce, out, x1, x2,
bias, group, reduce_op, comm_turn, stream_mode);
callDiopi(diopiMatmulAllReduce, out, x1, x2, bias, group, reduce_op,
comm_turn, stream_mode);
}
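
A keyword-argument sketch for the fused matmul + all-reduce wrapper, following the py::arg list registered below; the communication group name and tensor shapes are assumptions for illustration.

import torch
import deeplink_ext.cpp_extensions as ext  # import path is an assumption

x1 = torch.randn(4, 8, dtype=torch.float16).cuda()
x2 = torch.randn(8, 16, dtype=torch.float16).cuda()
out = torch.empty(4, 16, dtype=torch.float16).cuda()

# reduce_op, comm_turn and stream_mode fall back to the defaults declared in
# the pybind registration below ("sum", 0 and 1).
ext.matmul_all_reduce(out=out, x1=x1, x2=x2, bias=None,
                      group="hccl_world_group")  # group name is an assumption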

// Check whether a corresponding diopi implementation exists:
@@ -501,18 +475,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("token_softmax_reducev_inference", &extTokenSoftmaxReduceVInference,
"deeplink ext_token_softmax_reducev_inference");
}
// if (&diopiTokenDecodeAttentionInference != nullptr) {
// m.def("token_decode_attention_inference", &extTokenDecodeAttentionInference,
// "deeplink token_decode_attention_inference");
// }
// if (&diopiTokenDecodeAttentionInferenceBatchOne != nullptr) {
// m.def("token_decode_attention_inference_batch_one", &extTokenDecodeAttentionInferenceBatchOne,
// "deeplink token_decode_attention_inference");
// }
// if (&diopiIncreFlashAttention != nullptr) {
// m.def("incre_flash_attention", &extIncreFlashAttention,
// "deeplink incre_flash_attention");
// }
if (&diopiPromptFlashAttention != nullptr) {
m.def("prompt_flash_attention", &extPromptFlashAttention,
"deeplink ext_prompt_flash_attention");
@@ -540,15 +502,15 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
"deeplink ext_paged_attention");
}
if (&diopiRotaryEmbeddingV2 != nullptr) {
m.def("rotary_embedding_v2", &extRotaryEmbeddingV2, "deeplink extRotaryEmbeddingV2");
m.def("rotary_embedding_v2", &extRotaryEmbeddingV2,
"deeplink extRotaryEmbeddingV2");
}
if (&diopiMatmulAllReduce != nullptr) {
m.def("matmul_all_reduce", &extMatmulAllReduce,
"deeplink ext_matmul_all_reduce",
py::arg("out"), py::arg("x1"),
py::arg("x2"), py::arg("bias"),
py::arg("group"), py::arg("reduce_op") = "sum",
py::arg("comm_turn") = 0, py::arg("stream_mode") = 1);
"deeplink ext_matmul_all_reduce", py::arg("out"), py::arg("x1"),
py::arg("x2"), py::arg("bias"), py::arg("group"),
py::arg("reduce_op") = "sum", py::arg("comm_turn") = 0,
py::arg("stream_mode") = 1);
}
}

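Because every m.def above is guarded by a null check on the corresponding diopi symbol, an op is only exported when its DIOPI implementation is actually linked in. A short sketch of probing for the newly added ops from Python before patching, assuming the compiled module imports as deeplink_ext.cpp_extensions (the exact import path is an assumption):

import deeplink_ext.cpp_extensions as ext  # import path is an assumption

for name in ("prompt_flash_attention", "apply_penalty_v2", "paged_attention",
             "rotary_embedding_v2", "matmul_all_reduce"):
    # Ops whose diopi symbol is a null pointer are never registered, so a
    # plain hasattr() check is enough to detect them at runtime.
    print(name, "available" if hasattr(ext, name) else "missing")
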
90 changes: 56 additions & 34 deletions deeplink_ext/patch_lightllm.py
@@ -46,7 +46,9 @@ def patch_apply_penalty():
apply_penalty_pack.apply_penalty_v2 = ext.apply_penalty_v2

def patch_context_attention_inference():
def flash_context_attention(q, k, v, out, b_start_loc, b_seq_len, max_input_len):
def flash_context_attention(
q, k, v, out, b_start_loc, b_seq_len, max_input_len
):
batch, head, dim = b_start_loc.shape[0], q.shape[1], q.shape[2]
numKeyValueHeads = k.shape[1]
assert k.shape[1] == v.shape[1]
@@ -62,52 +64,72 @@ def flash_context_attention(q, k, v, out, b_start_loc, b_seq_len, max_input_len)

single_out = out[start:end, :].view(1, single_seq_len, -1)
if single_seq_len not in mask_cache:
mask = torch.tril(torch.ones(single_seq_len, single_seq_len, dtype=torch.bool), diagonal=0).cuda()
mask = torch.tril(
torch.ones(
single_seq_len, single_seq_len, dtype=torch.bool
),
diagonal=0,
).cuda()
mask = mask.repeat(1, 1, 1)
mask = torch.logical_not(mask)
mask_cache[single_seq_len] = mask
print(f"cache mask in context attention, seqLen:{single_seq_len}")
print(
f"cache mask in context attention, seqLen:{single_seq_len}"
)
mask = mask_cache[single_seq_len]
ext.prompt_flash_attention(single_out, single_q, single_k, single_v, None, mask, [], head, scale, 2147473647, 0, "BSH", numKeyValueHeads)
ext.prompt_flash_attention(
single_out,
single_q,
single_k,
single_v,
None,
mask,
[],
head,
scale,
2147473647,
0,
"BSH",
numKeyValueHeads,
)
return out

# def fused_context_attention(out, q, k, v, mask, b_seq_len, max_input_len, head, numKeyValueHeads, dim):
# batch = b_start_loc.shape[0]
# scale = 1 / math.sqrt(dim)
# mask_key_str = str(batch) + ":" + str(max_input_len)
# if mask_key_str not in mask_cache:
# mask = torch.tril(torch.ones(max_input_len, max_input_len, dtype=torch.bool), diagonal=0).cuda()
# mask = mask.repeat(batch, 1, 1)
# mask = torch.logical_not(mask)
# mask_cache[mask_key_str] = mask
# print(f"cache mask in context attention, batch:seqLen={mask_key_str}")

# mask = mask_cache[mask_key_str]
# ext.prompt_flash_attention(out, q, k, v,
# mask, b_seq_len, max_input_len, head, numKeyValueHeads, dim)
# return out

# context_attention_pack.context_attention_fwd = (
# # flash_context_attention
# fused_context_attention
# )
context_attention_pack.prompt_flash_attention = ext.prompt_flash_attention

def patch_paged_token_attention_inference():
# def paged_token_attention(q, k_cache, v_cache, out, q_head_num, kv_head_num, head_dim, b_seq_len, block_table:torch.Tensor, block_size):
# ext.paged_attention(out, q, k_cache, v_cache, None, None,
# b_seq_len, block_table, q_head_num, kv_head_num,
# 1.0 / math.sqrt(head_dim), "BSH", block_size, 0,
# None, None, None, None, None, None, None, None
# )
# return out

token_attention_pack.paged_token_attention = ext.paged_attention
def paged_token_attention(
out,
q,
k_cache,
v_cache,
b_seq_len,
q_head_num,
kv_head_num,
head_dim,
block_table,
block_size,
):
ext.paged_attention(
out,
q,
k_cache,
v_cache,
None,
b_seq_len,
q_head_num,
kv_head_num,
head_dim,
block_table,
block_size,
)

token_attention_pack.paged_token_attention = paged_token_attention
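
A hypothetical decode-step call into the paged_token_attention wrapper defined above (after patching, lightllm reaches it through token_attention_pack.paged_token_attention); the head counts, head_dim and paged-KV cache layout are illustrative assumptions, not the lightllm contract.

import torch

num_blocks, block_size, q_heads, kv_heads, head_dim = 16, 128, 32, 8, 128
q = torch.randn(1, q_heads * head_dim, dtype=torch.float16).cuda()   # one decode token
k_cache = torch.randn(num_blocks, block_size, kv_heads * head_dim,
                      dtype=torch.float16).cuda()
v_cache = torch.randn_like(k_cache)
out = torch.empty_like(q)
b_seq_len = torch.tensor([77], dtype=torch.int32).cuda()             # cached tokens for this request
block_table = torch.arange(num_blocks, dtype=torch.int32).cuda().view(1, -1)

# paged_token_attention is the wrapper defined above; it forwards to
# ext.paged_attention with atten_mask=None.
paged_token_attention(out, q, k_cache, v_cache, b_seq_len,
                      q_heads, kv_heads, head_dim, block_table, block_size)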

def patch_token_attention_inference():
token_attention_pack.token_att_fwd = ext.token_attention_inference
token_attention_pack.token_decode_attention_fwd = ext.token_decode_attention_inference_batch_one#ext.token_decode_attention_inference
token_attention_pack.token_decode_attention_fwd = (
ext.token_decode_attention_inference_batch_one
) # ext.token_decode_attention_inference

def patch_token_softmax_reducev_inference():
token_attention_softmax_reducev_pack.token_softmax_reducev_fwd = (