43 changes: 38 additions & 5 deletions csrc/deepep/deep_ep.cpp
@@ -12,6 +12,10 @@ constexpr int PADDING_SIZE = 3;
constexpr size_t HCOMM_NAME_LEN = 128;
constexpr uint32_t NO_SCALES = 0;
constexpr uint32_t DYNAMIC_SCALES = 2;
// In a shared header
constexpr int LOCAL_RANK_SIZE = 8;
constexpr int MAX_BATCH_SIZE = 4096;
constexpr int EXPERT_DATA_SIZE = 1 + 2 * MAX_BATCH_SIZE; // 8193

Buffer::Buffer(int64_t rank, int64_t num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_bytes, bool low_latency_mode,
std::string moe_all_to_all_group_name)
@@ -73,15 +77,44 @@ Buffer::get_dispatch_layout(const torch::Tensor &topk_idx, int num_experts, std:

const int num_tokens = new_topk_idx.size(0);
const int num_topk = new_topk_idx.size(1);
const int local_ranksize = LOCAL_RANK_SIZE;
auto server_num = num_ranks / local_ranksize;

auto device = new_topk_idx.device();
auto num_tokens_per_expert = at::zeros({num_experts}, at::dtype(at::kInt).device(device));
auto num_tokens_per_rank = at::zeros({num_ranks}, at::dtype(at::kInt).device(device));
auto is_token_in_rank = torch::empty({num_tokens, num_ranks}, at::dtype(at::kInt).device(device));

EXEC_NPU_CMD(aclnnDispatchLayout, new_topk_idx, num_tokens, num_ranks, num_experts, num_topk, num_tokens_per_rank,
num_tokens_per_expert, is_token_in_rank);

auto is_token_in_rank = at::zeros({num_tokens, num_ranks}, at::dtype(at::kInt).device(device));
const int notify_send_data_size =
num_experts * EXPERT_DATA_SIZE + server_num + MAX_BATCH_SIZE * (1 + 2 * server_num + num_topk);
/*
The notify send data consists of 8 fields, ordered as follows:
1. The number of tokens each expert receives from this NPU.
size:[numExpert]
2. The number of tokens received by each server from this NPU (deduplicated).
size:[serverNum]
3. The number of tokens sent from this NPU to each server (without deduplication).
size:[MAX_BS, serverNum]
4. The number of servers each token of this NPU is sent to.
size:[MAX_BS]
5. The order in which each token of this NPU is sent to the servers.
size:[MAX_BS, serverNum]
6. The order in which each token is sent to the experts.
size:[MAX_BS, numTopk]
7. The server offset of the tokens received by each expert from this NPU.
size:[numExpert, MAX_BS]
8. The origin offset, on the source NPU, of the tokens received by each expert.
size:[numExpert, MAX_BS]
*/
auto notify_send_data = at::zeros({notify_send_data_size}, at::dtype(at::kInt).device(device));
notify_send_data
.index({at::indexing::Slice(num_experts + server_num + MAX_BATCH_SIZE * (server_num + 1),
num_experts + server_num + MAX_BATCH_SIZE * (server_num * 2 + 1))})
.fill_(-1);
// The per-token server send order (field 5 above) is initialized to -1.
EXEC_NPU_CMD(aclnnDispatchLayout, new_topk_idx, num_tokens, num_ranks, num_experts, num_topk, local_ranksize,
num_tokens_per_rank, num_tokens_per_expert, is_token_in_rank, notify_send_data);

this->notify_send_data = notify_send_data;
std::optional<torch::Tensor> num_tokens_per_rdma_rank = std::nullopt;
std::optional<EventHandle> output_event = std::nullopt;
auto is_token_in_rank_bool = is_token_in_rank.to(at::kBool);
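
As a sanity check on the size formula and the fill_(-1) slice above, here is a sketch that writes out the eight regions explicitly (names mirror the diff; server_num = num_ranks / LOCAL_RANK_SIZE; illustrative only, not part of the patch):

    // Region sizes of notify_send_data, in order (sketch, mirrors the layout comment above):
    const int r1 = num_experts;                   // 1. tokens received per expert
    const int r2 = server_num;                    // 2. tokens received per server (deduplicated)
    const int r3 = MAX_BATCH_SIZE * server_num;   // 3. per-token, per-server send counts (no dedup)
    const int r4 = MAX_BATCH_SIZE;                // 4. servers each token is sent to
    const int r5 = MAX_BATCH_SIZE * server_num;   // 5. per-token send order across servers
    const int r6 = MAX_BATCH_SIZE * num_topk;     // 6. per-token send order across experts
    const int r7 = num_experts * MAX_BATCH_SIZE;  // 7. server offset per expert
    const int r8 = num_experts * MAX_BATCH_SIZE;  // 8. origin offset per expert
    // Total: r1 + ... + r8
    //   = num_experts * (1 + 2 * MAX_BATCH_SIZE) + server_num
    //     + MAX_BATCH_SIZE * (1 + 2 * server_num + num_topk)
    //   = num_experts * EXPERT_DATA_SIZE + server_num
    //     + MAX_BATCH_SIZE * (1 + 2 * server_num + num_topk)   // == notify_send_data_size
    // The fill_(-1) slice covers exactly region 5:
    //   begin = r1 + r2 + r3 + r4 = num_experts + server_num + MAX_BATCH_SIZE * (server_num + 1)
    //   end   = begin + r5        = num_experts + server_num + MAX_BATCH_SIZE * (2 * server_num + 1)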
1 change: 1 addition & 0 deletions csrc/deepep/deep_ep.hpp
@@ -26,6 +26,7 @@ struct Buffer {
at::Tensor ori_x;
at::Tensor new_topk_idx;
at::Tensor new_scales;
at::Tensor notify_send_data;

int64_t shared_expert_rank_num;
int64_t shared_expert_num = 1;
24 changes: 21 additions & 3 deletions csrc/deepep/ops/op_host/dispatch_layout.cpp
@@ -16,6 +16,7 @@ class DispatchLayout : public OpDef
this->Attr("num_ranks").Int();
this->Attr("num_experts").Int();
this->Attr("num_topk").Int();
this->Attr("local_ranksize").Int();

this->Output("numTokensPerRank")
.ParamType(REQUIRED)
@@ -32,9 +33,14 @@ class DispatchLayout : public OpDef
.DataType({ge::DT_INT32})
.Format({ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND});
this->Output("notifySendData")
.ParamType(REQUIRED)
.DataType({ge::DT_INT32})
.Format({ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND});

OpAICoreConfig aicore_config;
aicore_config.DynamicCompileStaticFlag(true)
OpAICoreConfig a3_config;
a3_config.DynamicCompileStaticFlag(true)
.DynamicFormatFlag(true)
.DynamicRankSupportFlag(true)
.DynamicShapeSupportFlag(true)
@@ -44,7 +50,19 @@ class DispatchLayout : public OpDef
.ExtendCfgInfo("jitCompile.flag", "static_true")
.ExtendCfgInfo("multiKernelSupportDynamicGraph.value", "multi_kernel");

this->AICore().AddConfig("ascend910_93", aicore_config);
OpAICoreConfig a2_config;
a2_config.DynamicCompileStaticFlag(true)
.DynamicFormatFlag(true)
.DynamicRankSupportFlag(true)
.DynamicShapeSupportFlag(true)
.NeedCheckSupportFlag(false)
.PrecisionReduceFlag(true)
.ExtendCfgInfo("aclnnSupport.value", "support_aclnn")
.ExtendCfgInfo("jitCompile.flag", "static_false")
.ExtendCfgInfo("multiKernelSupportDynamicGraph.value", "multi_kernel");

this->AICore().AddConfig("ascend910_93", a3_config);
this->AICore().AddConfig("ascend910b", a2_config);
}
};

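
A hedged summary of the resulting op prototype, for reviewers skimming the interface change (the topkIdx input and the num_tokens attr are not visible in this hunk and are inferred from the tiling and aclnn code below):

    // DispatchLayout prototype after this patch (sketch, partially inferred):
    //   Input  : topkIdx            int64, ND
    //   Attrs  : num_tokens, num_ranks, num_experts, num_topk, local_ranksize (new)   all Int
    //   Output : numTokensPerRank   int32, ND
    //   Output : numTokensPerExpert int32, ND
    //   Output : isTokenInRank      int32, ND
    //   Output : notifySendData     int32, ND   (new)
    //   AICore : ascend910_93 -> a3_config, ascend910b -> a2_config (new; jitCompile static_false)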
50 changes: 45 additions & 5 deletions csrc/deepep/ops/op_host/dispatch_layout_tiling.cc
@@ -26,17 +26,24 @@ constexpr uint32_t INPUT_TOPK_IDX_INDEX = 0;
constexpr uint32_t OUTPUT_NUM_TOKEN_PER_RANK_INDEX = 0;
constexpr uint32_t OUTPUT_NUM_TOKEN_PER_EXPERT_INDEX = 1;
constexpr uint32_t OUTPUT_IS_TOKEN_IN_RANK_INDEX = 2;
constexpr uint32_t OUTPUT_NOTIFY_SEND_DATA_INDEX = 3;

constexpr uint32_t ATTR_NUM_TOKENS_INDEX = 0;
constexpr uint32_t ATTR_NUM_RANKS_INDEX = 1;
constexpr uint32_t ATTR_NUM_EXPERTS_INDEX = 2;
constexpr uint32_t ATTR_NUM_TOPK_INDEX = 3;
constexpr uint32_t ATTR_LOCAL_RANKSIZE_INDEX = 4;
const int64_t MAX_COMM_WORLD_SIZE = 384;
const int64_t MAX_MOE_EXPERTS_NUM = 512;
const int64_t MAX_LOCAL_RANKSIZE = 8;

constexpr uint32_t SYSTEM_NEED_WORKSPACE = 16 * 1024 * 1024;
constexpr uint32_t KERNEL_USE_WORKSPACE = 1 * 1024 * 1024;
constexpr uint32_t KERNEL_A2_ARG_SIZE = 1 * 1024 * 1024;

constexpr static int TILING_KEY_INT = 23;
constexpr static int TILING_KEY_A2_TYPE = 100;

constexpr uint32_t TWO_DIMS = 2;
constexpr uint32_t K_MAX = 16;
} // namespace
@@ -48,9 +55,24 @@ static void PrintTilingDataInfo(const char *nodeName, DispatchLayoutTilingData &
OP_LOGD(nodeName, "numRanks is %u.", tilingData.dispatchLayoutInfo.numRanks);
OP_LOGD(nodeName, "numExperts is %u.", tilingData.dispatchLayoutInfo.numExperts);
OP_LOGD(nodeName, "numTopk is %u.", tilingData.dispatchLayoutInfo.numTopk);
OP_LOGD(nodeName, "localRankSize is %u.", tilingData.dispatchLayoutInfo.localRankSize);
OP_LOGD(nodeName, "totalUbSize is %lu.", tilingData.dispatchLayoutInfo.totalUbSize);
}

static bool CheckIfA2Machine(gert::TilingContext *context)
{
fe::PlatFormInfos *platformInfoPtr = context->GetPlatformInfo();
fe::PlatFormInfos &platformInfo = *platformInfoPtr;

std::string socVersion;
(void)platformInfo.GetPlatformResWithLock("version", "Short_SoC_version", socVersion);

if (socVersion == "Ascend910B") {
return true;
}
return false;
}

static ge::graphStatus GetAttrAndSetTilingData(gert::TilingContext *context, const char *nodeName,
DispatchLayoutTilingData &tilingData)
{
@@ -61,11 +83,14 @@ static ge::graphStatus GetAttrAndSetTilingData(gert::TilingContext *context, con
auto numRanksPtr = attrs->GetAttrPointer<int64_t>(static_cast<int>(ATTR_NUM_RANKS_INDEX));
auto numExpertsPtr = attrs->GetAttrPointer<int64_t>(ATTR_NUM_EXPERTS_INDEX);
auto numTopkPtr = attrs->GetAttrPointer<int64_t>(static_cast<int>(ATTR_NUM_TOPK_INDEX));
auto localRankSizePtr = attrs->GetAttrPointer<int64_t>(static_cast<int>(ATTR_LOCAL_RANKSIZE_INDEX));

OP_TILING_CHECK(numTokensPtr == nullptr, OP_LOGE(nodeName, "numTokensPtr is null."), return ge::GRAPH_FAILED);
OP_TILING_CHECK(numRanksPtr == nullptr, OP_LOGE(nodeName, "numRanksPtr is null."), return ge::GRAPH_FAILED);
OP_TILING_CHECK(numExpertsPtr == nullptr, OP_LOGE(nodeName, "numExpertsPtr is null."), return ge::GRAPH_FAILED);
OP_TILING_CHECK(numTopkPtr == nullptr, OP_LOGE(nodeName, "numTopkPtr is null."), return ge::GRAPH_FAILED);
OP_TILING_CHECK(localRankSizePtr == nullptr, OP_LOGE(nodeName, "localRankSizePtr is null."),
return ge::GRAPH_FAILED);

OP_TILING_CHECK((*numRanksPtr <= 0) || (*numRanksPtr > MAX_COMM_WORLD_SIZE),
OP_LOGE(nodeName, "rankSize is invalid, only support (0, %ld], but got rankSize=%ld.",
@@ -80,10 +105,19 @@ static ge::graphStatus GetAttrAndSetTilingData(gert::TilingContext *context, con
OP_LOGE(nodeName, "numTopkPtr is invalid, only support (0, %u], but got numTopk=%ld.", K_MAX, *numTopkPtr),
return ge::GRAPH_FAILED);

if (CheckIfA2Machine(context)) {
OP_TILING_CHECK(
(*localRankSizePtr <= 0) || (*localRankSizePtr > MAX_LOCAL_RANKSIZE),
OP_LOGE(nodeName, "localRankSizePtr is invalid, only support (0, %ld], but got localRankSize=%ld.",
MAX_LOCAL_RANKSIZE, *localRankSizePtr),
return ge::GRAPH_FAILED);
}

tilingData.dispatchLayoutInfo.numTokens = static_cast<uint32_t>(*numTokensPtr);
tilingData.dispatchLayoutInfo.numRanks = static_cast<uint32_t>(*numRanksPtr);
tilingData.dispatchLayoutInfo.numExperts = static_cast<uint32_t>(*numExpertsPtr);
tilingData.dispatchLayoutInfo.numTopk = static_cast<uint32_t>(*numTopkPtr);
tilingData.dispatchLayoutInfo.localRankSize = static_cast<uint32_t>(*localRankSizePtr);

return ge::GRAPH_SUCCESS;
}
@@ -102,11 +136,13 @@ static bool CheckTensorDataType(gert::TilingContext *context, const char *nodeNa
auto numTokensPerRank = context->GetOutputDesc(OUTPUT_NUM_TOKEN_PER_RANK_INDEX);
auto numTokensPerExpert = context->GetOutputDesc(OUTPUT_NUM_TOKEN_PER_EXPERT_INDEX);
auto isTokenInRank = context->GetOutputDesc(OUTPUT_IS_TOKEN_IN_RANK_INDEX);
auto notifySendData = context->GetOutputDesc(OUTPUT_NOTIFY_SEND_DATA_INDEX);

OP_TILING_CHECK(topkIdx == nullptr, OP_LOGE(nodeName, "topkIdx is null."), return false);
OP_TILING_CHECK(numTokensPerRank == nullptr, OP_LOGE(nodeName, "numTokensPerRank is null."), return false);
OP_TILING_CHECK(numTokensPerExpert == nullptr, OP_LOGE(nodeName, "numTokensPerExpert is null."), return false);
OP_TILING_CHECK(isTokenInRank == nullptr, OP_LOGE(nodeName, "isTokenInRank is null."), return false);
OP_TILING_CHECK(notifySendData == nullptr, OP_LOGE(nodeName, "notifySendData is null."), return false);

OP_TILING_CHECK((topkIdx->GetDataType() != ge::DT_INT64),
OP_LOGE(nodeName, "topkIdx datatype is invalid, datatype should be int, but is %d.",
@@ -124,6 +160,10 @@ static bool CheckTensorDataType(gert::TilingContext *context, const char *nodeNa
OP_LOGE(nodeName, "isTokenInRank datatype is invalid, datatype should be int, but is %d.",
static_cast<ge::DataType>(isTokenInRank->GetDataType())),
return false);
OP_TILING_CHECK((notifySendData->GetDataType() != ge::DT_INT32),
OP_LOGE(nodeName, "notifySendData datatype is invalid, datatype should be int, but is %d.",
static_cast<ge::DataType>(notifySendData->GetDataType())),
return false);

return true;
}
@@ -169,11 +209,11 @@ static ge::graphStatus DispatchLayoutTilingFuncImpl(gert::TilingContext *context
OP_TILING_CHECK(SetWorkSpace(context, nodeName) != ge::GRAPH_SUCCESS,
OP_LOGE(nodeName, "Tiling set workspace failed."), return ge::GRAPH_FAILED);

fe::PlatFormInfos *platformInfoPtr = context->GetPlatformInfo();
fe::PlatFormInfos &platformInfo = *platformInfoPtr;

std::string socVersion;
(void)platformInfo.GetPlatformResWithLock("version", "Short_SoC_version", socVersion);
int tilingKey = TILING_KEY_INT;
if (CheckIfA2Machine(context)) {
tilingKey = tilingKey + TILING_KEY_A2_TYPE;
}
context->SetTilingKey(tilingKey);

auto ascendcPlatform = platform_ascendc::PlatformAscendC(context->GetPlatformInfo());
uint32_t blockDim;
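
Putting the constants above together with the kernel entry (op_kernel/dispatch_layout.cpp below), the tiling key resolves to 23 on Ascend910_93 and 23 + 100 = 123 on Ascend910B. A minimal sketch of that selection, for cross-reference only:

    // Sketch of the key selection implied by the constants above (not new logic):
    int tilingKey = TILING_KEY_INT;            // 23: A3 path (Ascend910_93)
    if (CheckIfA2Machine(context)) {           // Short_SoC_version == "Ascend910B"
        tilingKey += TILING_KEY_A2_TYPE;       // 23 + 100 = 123: A2 path
    }
    // 23  -> MoeDispatchLayout::DispatchLayout      (TILING_KEY_INT in the kernel)
    // 123 -> MoeDispatchLayoutA2::DispatchLayoutA2  (TILING_KEY_A2_INT in the kernel)
    context->SetTilingKey(tilingKey);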
10 changes: 6 additions & 4 deletions csrc/deepep/ops/op_host/op_api/aclnn_dispatch_layout.cpp
@@ -15,12 +15,14 @@ extern "C" {
#endif

aclnnStatus aclnnDispatchLayoutGetWorkspaceSize(const aclTensor *topkIdx, int64_t numTokens, int64_t numRanks,
int64_t numExperts, int64_t numTopk, const aclTensor *numTokensPerRank,
const aclTensor *numTokensPerExpert, const aclTensor *isTokenInRank,
int64_t numExperts, int64_t numTopk, int64_t localRankSize,
const aclTensor *numTokensPerRank, const aclTensor *numTokensPerExpert,
const aclTensor *isTokenInRank, const aclTensor *notifySendData,
uint64_t *workspaceSize, aclOpExecutor **executor)
{
return aclnnInnerDispatchLayoutGetWorkspaceSize(topkIdx, numTokens, numRanks, numExperts, numTopk, numTokensPerRank,
numTokensPerExpert, isTokenInRank, workspaceSize, executor);
return aclnnInnerDispatchLayoutGetWorkspaceSize(topkIdx, numTokens, numRanks, numExperts, numTopk, localRankSize,
numTokensPerRank, numTokensPerExpert, isTokenInRank, notifySendData,
workspaceSize, executor);
}

aclnnStatus aclnnDispatchLayout(void *workspace, uint64_t workspaceSize, aclOpExecutor *executor, aclrtStream stream)
6 changes: 4 additions & 2 deletions csrc/deepep/ops/op_host/op_api/aclnn_dispatch_layout.h
@@ -13,16 +13,18 @@ extern "C" {
* numRanks : required
* numExperts : required
* numTopk : required
* localRankSize : required
* numTokensPerRank : required
* numTokensPerExpert : required
* isTokenInRank : required
* notifySendData : required
* workspaceSize : size of workspace(output).
* executor : executor context(output).
*/
__attribute__((visibility("default"))) aclnnStatus aclnnDispatchLayoutGetWorkspaceSize(
const aclTensor *topkIdx, int64_t numTokens, int64_t numRanks, int64_t numExperts, int64_t numTopk,
const aclTensor *numTokensPerRank, const aclTensor *numTokensPerExpert, const aclTensor *isTokenInRank,
uint64_t *workspaceSize, aclOpExecutor **executor);
int64_t localRankSize, const aclTensor *numTokensPerRank, const aclTensor *numTokensPerExpert,
const aclTensor *isTokenInRank, const aclTensor *notifySendData, uint64_t *workspaceSize, aclOpExecutor **executor);

/* function: aclnnDispatchLayout
* workspace : workspace memory addr(input).
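
For context, a hedged caller-side sketch of the updated two-phase aclnn call with the new localRankSize and notifySendData arguments. Tensor construction and error handling are omitted; the tensor handles (topkIdx, numTokensPerRank, numTokensPerExpert, isTokenInRank, notifySendData) and the stream are assumed to exist and match the dtypes checked in the tiling code, and the workspace allocation path is an assumption:

    // Sketch only: two-phase aclnn invocation with the extended signature.
    uint64_t workspaceSize = 0;
    aclOpExecutor *executor = nullptr;
    aclnnStatus ret = aclnnDispatchLayoutGetWorkspaceSize(
        topkIdx, numTokens, numRanks, numExperts, numTopk, /*localRankSize=*/8,
        numTokensPerRank, numTokensPerExpert, isTokenInRank, notifySendData,
        &workspaceSize, &executor);
    if (ret == 0) {  // 0 indicates success
        void *workspace = nullptr;
        if (workspaceSize > 0) {
            aclrtMalloc(&workspace, workspaceSize, ACL_MEM_MALLOC_HUGE_FIRST);  // assumed allocation policy
        }
        ret = aclnnDispatchLayout(workspace, workspaceSize, executor, stream);
    }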
20 changes: 16 additions & 4 deletions csrc/deepep/ops/op_kernel/dispatch_layout.cpp
@@ -1,17 +1,29 @@
#include "kernel_operator.h"
#include "dispatch_layout.h"
#include "dispatch_layout_a2.h"
#include "dispatch_layout_tiling.h"

#define TILING_KEY_INT 23
#define TILING_KEY_A2_INT 123

extern "C" __global__ __aicore__ void dispatch_layout(GM_ADDR topkIdx, GM_ADDR numTokensPerRank,
GM_ADDR numTokensPerExpert, GM_ADDR isTokenInRank,
GM_ADDR workspace, GM_ADDR tiling)
GM_ADDR notifySendData, GM_ADDR workspace, GM_ADDR tiling)
{
REGISTER_TILING_DEFAULT(DispatchLayoutTilingData);
GET_TILING_DATA_WITH_STRUCT(DispatchLayoutTilingData, tilingData, tiling);

TPipe pipe;

DispatchLayout<int32_t> op;
op.Init(topkIdx, numTokensPerRank, numTokensPerExpert, isTokenInRank, workspace, &pipe, &tilingData);
op.Process();
if (TILING_KEY_IS(TILING_KEY_INT)) {
MoeDispatchLayout::DispatchLayout<int32_t> op;
op.Init(topkIdx, numTokensPerRank, numTokensPerExpert, isTokenInRank, notifySendData, workspace, &pipe,
&tilingData);
op.Process();
} else if (TILING_KEY_IS(TILING_KEY_A2_INT)) {
MoeDispatchLayoutA2::DispatchLayoutA2<int32_t> op;
op.Init(topkIdx, numTokensPerRank, numTokensPerExpert, isTokenInRank, notifySendData, workspace, &pipe,
&tilingData);
op.Process();
}
}
10 changes: 6 additions & 4 deletions csrc/deepep/ops/op_kernel/dispatch_layout.h
@@ -9,9 +9,7 @@
#include "sync_collectives.h"
#include "moe_distribute_base.h"
#include "dispatch_layout_tiling.h"

using namespace AscendC;
using namespace Moe;
namespace MoeDispatchLayout {

constexpr uint32_t UB_32_ALIGN = 32U;

@@ -23,14 +21,16 @@ __aicore__ inline void SyncFunc()
AscendC::WaitFlag<event>(eventID);
}

using namespace AscendC;
using namespace Moe;
template <typename T>
class DispatchLayout
{
public:
__aicore__ inline DispatchLayout(){};

__aicore__ inline void Init(GM_ADDR topkIdx, GM_ADDR numTokensPerRank, GM_ADDR numTokensPerExpert,
GM_ADDR isTokenInRank, GM_ADDR workspace, TPipe *pipe,
GM_ADDR isTokenInRank, GM_ADDR notifySendData, GM_ADDR workspace, TPipe *pipe,
const DispatchLayoutTilingData *tilingData)
{
numTokens_ = tilingData->dispatchLayoutInfo.numTokens;
@@ -42,6 +42,7 @@ class DispatchLayout
coreIdx_ = GetBlockIdx();
uint32_t maxAivNum = GetBlockNum();
aivNum_ = numTokens_ <= maxAivNum ? numTokens_ : maxAivNum;

if (coreIdx_ >= aivNum_) {
return;
}
@@ -157,5 +158,6 @@ class DispatchLayout
uint32_t numTokensPerExpert32AlignIntLen_{0};
uint32_t isTokenInRank32AlignIntLen_{0};
};
} // namespace MoeDispatchLayout

#endif // DISPATCH_LAYOUT_H