63 changes: 50 additions & 13 deletions csrc/deepep/deep_ep.cpp
@@ -127,15 +127,30 @@ std::tuple<at::Tensor, std::optional<at::Tensor>, at::Tensor, at::Tensor, at::Te
EP_HOST_ASSERT(low_latency_mode);
at::Tensor new_x = x;
this->new_topk_idx = topk_idx;
if (topk_idx.size(0) == 0) {
if (topk_idx.size(0) <= 2) {
this->is_padding = true;
this->ori_x = x.clone();
new_x = torch::ones({1, 7168}, x.options());
this->new_topk_idx = torch::arange(0, 8, topk_idx.options()).reshape({1, 8});
this->padding_cnt = 3 - topk_idx.size(0);
std::vector<at::Tensor> x_blocks;
std::vector<at::Tensor> topk_blocks;
if (topk_idx.size(0) != 0) {
x_blocks.emplace_back(x);
topk_blocks.emplace_back(topk_idx);
} else {
this->ori_x = x.clone();
}
for (int i = 0; i < this->padding_cnt; i++) {
at::Tensor tmp_x = torch::ones({1, 7168}, x.options());
at::Tensor tmp_topk = torch::arange(0, 8, topk_idx.options()).reshape({1, 8});
x_blocks.emplace_back(tmp_x);
topk_blocks.emplace_back(tmp_topk);
}
new_x = torch::cat(x_blocks, 0);
this->new_topk_idx = torch::cat(topk_blocks, 0);
}

auto num_tokens = static_cast<int>(new_x.size(0)), hidden = static_cast<int>(new_x.size(1));
auto num_scales = hidden / 128, num_topk = static_cast<int>(new_topk_idx.size(1));

auto num_local_experts = num_experts / (num_ranks - shared_expert_rank_num);
auto num_max_tokens = 0;
if (rank < shared_expert_rank_num) {
@@ -144,12 +159,13 @@ std::tuple<at::Tensor, std::optional<at::Tensor>, at::Tensor, at::Tensor, at::Te
} else { // moe expert
num_max_tokens = num_max_dispatch_tokens_per_rank * num_ranks * num_local_experts;
}
auto max_size = num_tokens * num_topk > num_max_tokens * 128 ? num_tokens * num_topk : num_max_tokens * 128;

// Allocate packed tensors
auto device = new_x.device();
auto packed_recv_x = at::empty({num_max_tokens, hidden}, new_x.options().dtype(use_fp8 ? at::kChar : at::kBFloat16));
auto packed_recv_x_scales = at::empty({num_max_tokens}, at::dtype(at::kFloat).device(device));
auto expandIdx = at::empty({num_tokens * num_topk}, at::dtype(at::kInt).device(device));
auto assist_info_for_combine = at::empty({max_size}, at::dtype(at::kInt).device(device));
auto packed_recv_count = at::empty({num_local_experts * num_ranks}, at::dtype(at::kInt).device(device));
auto tp_recv_count = at::empty({1}, at::dtype(at::kInt).device(device));
auto expertTokenNumsOut = at::empty({num_local_experts}, at::dtype(at::kLong).device(device));
@@ -163,6 +179,9 @@ std::tuple<at::Tensor, std::optional<at::Tensor>, at::Tensor, at::Tensor, at::Te
int64_t expert_shard_type = 0;
int64_t expert_token_nums_type = 1;
int64_t global_bs = num_max_dispatch_tokens_per_rank * num_ranks;
std::string comm_log = "0";
char *comm_log_ptr = const_cast<char *>(comm_log.c_str());


// get ep & tp name
char hcom_ep_name[128];
@@ -172,7 +191,7 @@ std::tuple<at::Tensor, std::optional<at::Tensor>, at::Tensor, at::Tensor, at::Te
HCCL_CHECK(HcclGetCommName(ep_comm, hcom_ep_name));
}

EXEC_NPU_CMD(aclnnMoeDistributeDispatch,
EXEC_NPU_CMD(aclnnMoeDistributeDispatchV2,
new_x,
new_topk_idx,
scales, // smooth scales,
@@ -191,9 +210,10 @@ std::tuple<at::Tensor, std::optional<at::Tensor>, at::Tensor, at::Tensor, at::Te
quant_mode,
global_bs, // global_bs
expert_token_nums_type, // expert_token_nums_type
comm_log_ptr,
packed_recv_x,
packed_recv_x_scales, // dynamicScalesOut
expandIdx,
assist_info_for_combine,
expertTokenNumsOut,
packed_recv_count,
tp_recv_count,
@@ -203,7 +223,7 @@ std::tuple<at::Tensor, std::optional<at::Tensor>, at::Tensor, at::Tensor, at::Te
std::optional<EventHandle> event;

// Return values
return {packed_recv_x, packed_recv_x_scales, packed_recv_count, expandIdx, expertTokenNumsOut, event, std::function<void()>([]{})};
return {packed_recv_x, packed_recv_x_scales, packed_recv_count, assist_info_for_combine, expertTokenNumsOut, event, std::function<void()>([]{})};
}

int Buffer::get_rdma_rank() const {
@@ -219,8 +239,16 @@ std::tuple<at::Tensor, std::optional<EventHandle>, std::optional<std::function<v
at::Tensor new_idx = topk_idx;
at::Tensor new_scales = topk_weights;
if (this->is_padding) {
std::vector<at::Tensor> scales_blocks;
if (this->padding_cnt != 3) {
scales_blocks.emplace_back(topk_weights);
}
for (int i = 0; i < this->padding_cnt; i++) {
at::Tensor tmp_scales = torch::zeros({1, 8}, topk_weights.options());
scales_blocks.emplace_back(tmp_scales);
}
new_idx = this->new_topk_idx;
this->new_scales = torch::zeros({1, 8}, topk_weights.options());
this->new_scales = torch::cat(scales_blocks, 0);
new_scales = this->new_scales;
}
// Tensor checks
@@ -238,7 +266,7 @@ std::tuple<at::Tensor, std::optional<EventHandle>, std::optional<std::function<v
auto device = x.device();
at::Tensor expand_x = x;
at::Tensor expert_ids = new_idx;
at::Tensor expand_idx = src_info; // handle[0] = src_info
at::Tensor assist_info_for_combine = src_info; // handle[0] = src_info
at::Tensor ep_send_counts = ep_send_count;
at::Tensor expert_scales = new_scales;
at::Tensor tp_send_counts = at::empty({1}, at::dtype(at::kInt).device(device));
@@ -254,13 +282,16 @@ std::tuple<at::Tensor, std::optional<EventHandle>, std::optional<std::function<v

auto num_combined_tokens = static_cast<int>(new_scales.size(0));
auto hidden = static_cast<int>(x.size(1));
at::Tensor shared_expert_x{nullptr};
at::Tensor combined_x = at::empty({num_combined_tokens, hidden}, x.options());
std::optional<EventHandle> event;
std::string comm_log = "0";
char *comm_log_ptr = const_cast<char *>(comm_log.c_str());

EXEC_NPU_CMD(aclnnMoeDistributeCombine,
EXEC_NPU_CMD(aclnnMoeDistributeCombineV2,
expand_x,
expert_ids,
expand_idx,
assist_info_for_combine,
ep_send_counts,
expert_scales,
tp_send_counts,
@@ -269,6 +300,7 @@ std::tuple<at::Tensor, std::optional<EventHandle>, std::optional<std::function<v
weight_scale,
group_list,
expand_scales,
shared_expert_x,
hcom_ep_name,
num_ranks,
rank,
@@ -283,9 +315,14 @@ std::tuple<at::Tensor, std::optional<EventHandle>, std::optional<std::function<v
out_dtype,
comm_quant_mode,
group_list_type,
comm_log_ptr,
combined_x);
if (this->is_padding) {
combined_x = this->ori_x;
if (this->padding_cnt == 3) {
combined_x = this->ori_x;
} else {
combined_x = combined_x.slice(0, 0, 3 - this->padding_cnt);
}
}
return {combined_x, event, std::function<void()>([]{})};
}
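The host-side change in this file pads tiny batches before dispatch and strips the padding again after combine: when fewer than 3 tokens arrive, dummy rows are appended so the kernel always sees at least 3 tokens, `padding_cnt` records how many rows were added, and the combine path either restores the saved input (fully padded case) or slices the dummy rows off. The diff also switches the kernels to their V2 variants (`aclnnMoeDistributeDispatchV2` / `aclnnMoeDistributeCombineV2`) and threads through the extra `comm_log` and `shared_expert_x` arguments; the sketch below covers only the padding logic, which is independent of that change. It is a minimal, standalone libtorch sketch, not the PR's `Buffer` code; `PaddingState`, `pad_inputs`, and `unpad_output` are illustrative names, and the hidden size 7168 and top-k 8 simply mirror the constants hard-coded in the diff.

```cpp
// Minimal sketch of the pad/unpad round trip (assumption: plain libtorch,
// outside the Buffer class from this PR).
#include <torch/torch.h>
#include <utility>
#include <vector>

struct PaddingState {
    bool is_padding = false;
    int padding_cnt = 0;   // number of dummy rows appended (3 - original rows)
    torch::Tensor ori_x;   // saved input, returned as-is when the batch was empty
};

// Pad x / topk_idx up to a minimum of 3 rows before dispatch.
std::pair<torch::Tensor, torch::Tensor> pad_inputs(const torch::Tensor& x,
                                                   const torch::Tensor& topk_idx,
                                                   PaddingState& st) {
    if (topk_idx.size(0) > 2) return {x, topk_idx};  // nothing to do
    st.is_padding = true;
    st.padding_cnt = 3 - static_cast<int>(topk_idx.size(0));
    std::vector<torch::Tensor> x_blocks, topk_blocks;
    if (topk_idx.size(0) != 0) {
        x_blocks.push_back(x);
        topk_blocks.push_back(topk_idx);
    } else {
        st.ori_x = x.clone();  // empty batch: keep the original to return later
    }
    for (int i = 0; i < st.padding_cnt; ++i) {
        x_blocks.push_back(torch::ones({1, 7168}, x.options()));
        topk_blocks.push_back(torch::arange(0, 8, topk_idx.options()).reshape({1, 8}));
    }
    return {torch::cat(x_blocks, 0), torch::cat(topk_blocks, 0)};
}

// Drop the dummy rows again after combine.
torch::Tensor unpad_output(const torch::Tensor& combined_x, const PaddingState& st) {
    if (!st.is_padding) return combined_x;
    if (st.padding_cnt == 3) return st.ori_x;                    // batch was empty
    return combined_x.slice(/*dim=*/0, 0, 3 - st.padding_cnt);  // keep only real rows
}
```

Padding up to 3 tokens presumably works around a minimum-batch requirement of the V2 dispatch/combine kernels (the diff does not state the reason); the key invariant is that the number of real rows, `3 - padding_cnt`, is recoverable on the combine side.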
1 change: 1 addition & 0 deletions csrc/deepep/deep_ep.hpp
@@ -23,6 +23,7 @@ struct Buffer {

bool low_latency_mode = false;
bool is_padding = false;
int padding_cnt = 0;
at::Tensor ori_x;
at::Tensor new_topk_idx;
at::Tensor new_scales;
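The header change adds `padding_cnt` next to the existing `is_padding` / `ori_x` members, so the state written during dispatch survives until combine. To make that relationship concrete, here is a hypothetical end-to-end use of the sketch above (again illustrative names, not the Buffer API), showing the empty-batch case where the whole padded result is discarded and the saved input comes back:

```cpp
// Hypothetical round trip using the sketch above.
#include <torch/torch.h>

int main() {
    PaddingState st;
    auto x = torch::empty({0, 7168}, torch::kBFloat16);
    auto topk_idx = torch::empty({0, 8}, torch::kLong);

    auto [padded_x, padded_idx] = pad_inputs(x, topk_idx, st);
    // padded_x: [3, 7168], padded_idx: [3, 8], st.padding_cnt == 3

    // ... dispatch / expert compute / combine would run on the padded batch ...
    auto combined_x = padded_x;  // stand-in for the combine output

    auto out = unpad_output(combined_x, st);
    // out is st.ori_x (shape [0, 7168]) because every row was padding
    return 0;
}
```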
1 change: 1 addition & 0 deletions csrc/utils/whatsup/666
@@ -0,0 +1 @@
666
84 changes: 0 additions & 84 deletions tests/python/deepep/test_low_latency.py

This file was deleted.

File renamed without changes.